In [1]:
import time
import os
import argparse
import pdb
from functools import partial

import torch
import torch.nn as nn
import timm
from torch.utils.data import DataLoader
from PIL import Image
import h5py
import openslide
from tqdm import tqdm

import numpy as np
from glob import glob

from utils.file_utils import save_hdf5
from dataset_modules.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag_FP
from models import get_encoder

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
# Run configuration for feature extraction, collected in one namespace.
# NOTE: all paths below are absolute and specific to the original machine.
args = argparse.Namespace(
    # Folder holding the H5 data (a 'patches' sub-folder is read from it below).
    data_h5_dir='/home/z/zeyugao/dataset/WSIData/Camelyon/clam_gen/',
    # Folder holding the annotation PNG files.
    anno_dir='/home/shared/su123/Camelyon/annotation/',
    # Folder holding the slide image files.
    data_slide_dir='/home/shared/su123/Camelyon/WSIs/',
    # File extension of the slide files.
    slide_ext='.tif',
    # CSV file listing the slides to process.
    csv_path='/home/z/zeyugao/dataset/WSIData/Camelyon/clam_gen/process_list_autogen.csv',
    # Output folder for the extracted features.
    feat_dir='/home/z/zeyugao/dataset/WSIData/Camelyon/clam_gen/dsmil',
    # Name of the encoder to build.
    model_name='dsmil_camel',
    # Number of patches per batch.
    batch_size=128,
    # Auto-skip switch (original note: "whether to auto-skip"; the flag name is
    # negated, so False presumably leaves auto-skip on — TODO confirm).
    no_auto_skip=False,
    # Target patch size handed to the encoder factory.
    target_patch_size=224,
    # Suffix appended to output file names.
    suffix='_0_512',
    # Passed to the patch dataset as `ori_patch_size`.
    patch_size=512,
)

In [17]:
# Build the dataset of slide entries from the configured CSV file; it is
# indexed by position and measured with len() in later cells.
csv_path = args.csv_path
bags_dataset = Dataset_All_Bags(csv_path)

In [18]:
# Create the feature output folder; exist_ok=True makes this a no-op if it already exists.
os.makedirs(args.feat_dir, exist_ok=True)

In [19]:
# Names of the entries already present in the feature output folder
# (not referenced again in the cells visible here).
dest_files = os.listdir(args.feat_dir)
# Full paths of every '*.png' file directly inside the annotation folder;
# used later for a membership test against each slide's expected annotation path.
anno_list = glob(os.path.join(args.anno_dir, '*.png'))

In [20]:
# Build the encoder selected by `model_name`; the factory also returns the
# image transforms to apply to each patch.
model, img_transforms = get_encoder(args.model_name, target_img_size=args.target_patch_size)
# Switch to evaluation mode; assigning to `_` keeps the returned module's
# repr from being echoed as the cell output.
_ = model.eval()

loading model checkpoint
IClassifier(
  (feature_extractor): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): InstanceN



In [21]:
# Move the encoder to the selected device.
model = model.to(device)
# Number of entries in the slide dataset; bounds the slide loop below.
total = len(bags_dataset)

In [22]:
# Extra DataLoader options: worker processes and pinned host memory are only
# requested on a CUDA device; otherwise the DataLoader defaults are used.
if device.type == "cuda":
    loader_kwargs = {'num_workers': 8, 'pin_memory': True}
else:
    loader_kwargs = {}

In [23]:
# Build the patch dataset for one slide.
# NOTE: the loop body ends with `break`, so only the first entry of
# `bags_dataset` is processed; remove the `break` to visit every slide.
# bag_candidate_idx = 100
for bag_candidate_idx in tqdm(range(total)):
    # Look the entry up once and derive everything else from it.
    bag_path = bags_dataset[bag_candidate_idx]
    slide_id = os.path.basename(bag_path).split(args.slide_ext)[0]
    # Name of the folder that directly contains the slide file.
    uuid = bag_path.split('/')[-2]
    bag_name = slide_id+'.h5'
    h5_file_path = os.path.join(args.data_h5_dir, 'patches', bag_name)
    slide_file_path = os.path.join(args.data_slide_dir, uuid, slide_id + args.slide_ext)
    # Build the annotation path the same way `anno_list` was built
    # (os.path.join), so the membership test still matches when `anno_dir`
    # is given without a trailing slash.
    anno_path = os.path.join(args.anno_dir, slide_id + '.png')
    if anno_path not in anno_list:
        # No annotation file for this slide.
        anno_path = None
    output_path = os.path.join(args.feat_dir, slide_id + args.suffix + '.npy')
    wsi = openslide.open_slide(slide_file_path)
    dataset = Whole_Slide_Bag_FP(file_path=h5_file_path,
                                 anno_path=anno_path,
                                 wsi=wsi,
                                 ori_patch_size=args.patch_size,
                                 img_transforms=img_transforms)
    break

  0%|          | 0/399 [00:00<?, ?it/s]

downsample [1. 1.]
downsampled_level_dim [ 97792 221184]
level_dim [ 97792 221184]
name normal_001
patch_level 0
patch_size 512
save_path /home/z/zeyugao/dataset/WSIData/Camelyon/clam_gen/patches

feature extraction settings
transformations:  Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
)





In [29]:
# Batch the patches of the slide dataset built above; no shuffle argument is
# passed, so the DataLoader's default ordering applies.
loader = DataLoader(dataset=dataset, batch_size=args.batch_size, **loader_kwargs)

In [30]:
# Run the encoder over every batch from `loader`, accumulating the features,
# the 'coord' entries and the 'inst_label' entries in three parallel lists.
features_list, indexs_list, inst_labels_list = [], [], []
for count, data in enumerate(tqdm(loader)):
    with torch.inference_mode():
        batch, coords = data['img'], data['coord']
        inst_labels = data['inst_label'].tolist()
        batch = batch.to(device, non_blocking=True)

        # Forward pass, then bring the result back to the host as float32 numpy.
        features = model(batch).cpu().numpy().astype(np.float32)

        features_list.append(features)
        indexs_list.extend(coords)
        inst_labels_list.extend(inst_labels)

100%|██████████| 18/18 [00:06<00:00,  2.86it/s]


In [32]:
# Shape of the feature array produced for the first batch.
features_list[0].shape

(128, 512)

In [30]:
# dsmil 
import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

class IClassifier(nn.Module):
    """Wrap a feature extractor and return its output flattened per sample.

    Args:
        feature_extractor: Callable/module applied to the input batch.
        feature_size: Input width of the linear layer ``fc``.
        output_class: Output width of the linear layer ``fc``.

    Note:
        ``fc`` is created so the module keeps its ``fc.weight`` / ``fc.bias``
        entries in the state dict, but ``forward`` does not apply it: only the
        flattened features are returned.
    """

    def __init__(self, feature_extractor, feature_size, output_class):
        super().__init__()

        self.feature_extractor = feature_extractor
        self.fc = nn.Linear(feature_size, output_class)

    def forward(self, x):
        """Return the extractor output reshaped to ``(N, -1)``.

        Args:
            x: Input batch whose first dimension is the sample dimension.

        Returns:
            A 2-D tensor with one row of features per input sample.
        """
        feats = self.feature_extractor(x)  # N x K
        # The classification step (self.fc on the flattened features) is
        # intentionally not applied here.
        # reshape() rather than view(): identical for contiguous tensors, but
        # it also handles non-contiguous extractor outputs instead of raising.
        return feats.reshape(feats.shape[0], -1)

In [40]:
# Backbone: torchvision ResNet-18 without pretrained weights, built with
# InstanceNorm2d passed as its normalisation layer.
norm=nn.InstanceNorm2d
resnet = models.resnet18(pretrained=False, norm_layer=norm)
# Replace the final fully connected layer with a pass-through so the network
# outputs the features that would have fed that layer.
resnet.fc = nn.Identity()

In [41]:
# Wrap the backbone and place it on the device chosen at the top of the
# notebook, so the cell also runs on machines without CUDA.
i_classifier = IClassifier(resnet, 512, output_class=1).to(device)

In [42]:
weight_path = "/home/z/zeyugao/PreModel/dsmil/model-v1-lung.pth"

# Deserialize onto the CPU so the checkpoint can be read regardless of the
# device it was saved from; load_state_dict() below copies the values into
# the model's own tensors wherever they live.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict_weights = torch.load(weight_path, map_location='cpu')

# Drop the last four checkpoint entries (presumably parameters with no
# counterpart in this model — TODO confirm).
for i in range(4):
    state_dict_weights.popitem()

# Re-key the checkpoint purely by position: the i-th checkpoint tensor is
# stored under the i-th key of this model's state dict. zip() stops at the
# shorter of the two, so trailing model keys without a checkpoint tensor are
# simply left out (reported as missing keys by the non-strict load).
state_dict_init = i_classifier.state_dict()
new_state_dict = OrderedDict(
    zip(state_dict_init.keys(), state_dict_weights.values())
)
i_classifier.load_state_dict(new_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

In [43]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor, Compose

# Create a random 224x224 RGB image to use as a stand-in input.
random_image_array = np.random.randint(
    low=0, high=256, size=(224, 224, 3), dtype=np.uint8
)
random_image = Image.fromarray(random_image_array)

# Transform pipeline: convert the PIL image to a tensor.
transform = Compose([ToTensor()])

# Apply the transform.
transformed_image = transform(random_image)

In [44]:
# Add the leading batch dimension exactly once: the guard keeps the cell
# idempotent, so re-running it does not stack additional dimensions.
if transformed_image.dim() == 3:
    transformed_image = transformed_image.unsqueeze(0)

In [45]:
# Smoke-test the extractor on the random image. The input is moved to the
# device selected at the top of the notebook (instead of a hard-coded CUDA
# call), and the forward pass runs without autograd tracking, matching the
# feature-extraction loop earlier in the notebook.
patches = transformed_image.float().to(device)
with torch.inference_mode():
    feats = i_classifier(patches)

In [46]:
# Shape of the extractor output for the single-image batch.
feats.shape

torch.Size([1, 512])

In [53]:
# Check whether 'conch_v1' is one of the keys of MODEL2CONSTANTS.
# NOTE(review): MODEL2CONSTANTS is not imported or defined in any cell visible
# here, so this cell relies on kernel state from elsewhere — confirm its
# source and import it explicitly.
'conch_v1' in MODEL2CONSTANTS.keys()

True