In [9]:
from dino.vision_transformer import DINOHead, VisionTransformer
from dino.vim.models_mamba import VisionMamba
from dino.config import configurations
from dino.main import get_args_parser
from functools import partial

from torch import nn
import torch
from dino.utils import load_pretrained_weights

from MIL.models.resnet_custom import resnet50_baseline


In [10]:
# Build the DINO argument parser and keep only the options it recognises;
# parse_known_args returns (namespace, leftover_argv), so unrecognised
# command-line entries are collected instead of raising.
parser = get_args_parser()
args, _unparsed = parser.parse_known_args()

In [11]:
# --- Backbone selection and feature-extraction settings -------------------
args.arch = 'vit-s'
args.image_size = 224
args.patch_size = 16
args.num_classes = 2
args.n_last_blocks = 4
args.avgpool_patchtokens = False

if args.arch in configurations:
    # Work on a shallow copy so that re-running this cell (or switching
    # architecture) never sees keys written into the shared dict by a previous run.
    config = dict(configurations[args.arch])
    config['img_size'] = args.image_size
    config['patch_size'] = args.patch_size
    config['num_classes'] = args.num_classes

    # The config stores the norm layer as a string; turn it into a constructor.
    if config.get('norm_layer') == "nn.LayerNorm":
        config['norm_layer'] = partial(nn.LayerNorm, eps=config['eps'])
    config['drop_path_rate'] = 0
    if args.arch.startswith('vim'):
        config['final_pool_type'] = 'all'
        model = VisionMamba(return_features=True, **config)
        embed_dim = model.embed_dim
    elif args.arch.startswith('vit'):
        model = VisionTransformer(**config)
        # One CLS token per requested block, plus one slot for the averaged
        # patch tokens when that option is enabled.
        embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
    else:
        # Previously this fell through with `model`/`embed_dim` undefined.
        raise ValueError(f"Architecture {args.arch!r} is configured but has no model constructor here")
    print('EMBEDDED DIM:', embed_dim)
elif args.arch == 'resnet':
    model = resnet50_baseline(pretrained=True)
else:
    # Fail here rather than with a NameError on `model` a few lines below.
    raise ValueError(f"Unknown architecture: {args.arch}")

# Use the GPU when present; fall back to CPU like the extraction cells do.
model.to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
model.eval()
# NOTE(review): absolute, machine-specific checkpoint path; consider moving it to a config cell.
args.pretrained_weights = '/home/ubuntu/checkpoints/camelyon16_224_10x/vit-s_224-96/checkpoint.pth'
args.checkpoint_key = 'teacher'
load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

EMBEDDED DIM: 1536
Take key teacher in provided checkpoint dict
Skipping loading parameter head.weight due to size mismatch or it not being present in the checkpoint.
Skipping loading parameter head.bias due to size mismatch or it not being present in the checkpoint.
Pretrained weights found at /home/ubuntu/checkpoints/camelyon16_224_10x/vit-s_224-96/checkpoint.pth and loaded with msg: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=['head.mlp.0.weight', 'head.mlp.0.bias', 'head.mlp.2.weight', 'head.mlp.2.bias', 'head.mlp.4.weight', 'head.mlp.4.bias', 'head.last_layer.weight_g', 'head.last_layer.weight_v'])


In [47]:
import os
import glob
from natsort import os_sorted
import time
import openslide

In [48]:
def get_feaures(model, inp):
    """
    Compute feature vectors for a batch without tracking gradients.

    args:
        model: backbone; for "vit" architectures it must expose
            get_intermediate_layers(inp, n)
        inp: input batch passed straight to the model

    returns:
        For architectures whose name contains "vit": the first token of each of
        the last `args.n_last_blocks` intermediate outputs, concatenated along
        the last dimension (optionally combined with the mean of the remaining
        tokens of the final output). Otherwise: the plain `model(inp)` output.

    Reads `args.arch`, `args.n_last_blocks` and `args.avgpool_patchtokens`
    from the notebook-level `args` namespace.
    """
    with torch.no_grad():
        if "vit" not in args.arch:
            return model(inp)

        layer_outputs = model.get_intermediate_layers(inp, args.n_last_blocks)
        first_tokens = [layer[:, 0] for layer in layer_outputs]
        feats = torch.cat(first_tokens, dim=-1)
        if not args.avgpool_patchtokens:
            return feats

        # NOTE(review): stacking along a new last axis needs both tensors to
        # agree on every other dimension, which presumably only holds when
        # n_last_blocks == 1 -- confirm before enabling avgpool with more blocks.
        patch_mean = torch.mean(layer_outputs[-1][:, 1:], dim=1)
        stacked = torch.cat((feats.unsqueeze(-1), patch_mean.unsqueeze(-1)), dim=-1)
        return stacked.reshape(stacked.shape[0], -1)

In [49]:
def compute_w_loader(file_path, output_path, wsi, model,
    batch_size = 8, verbose = 0, print_every=20, pretrained=True, 
    custom_downsample=1, target_patch_size=-1):
    """
    Extract features for every patch of one slide and write them to an .h5 file.

    args:
        file_path: path to the bag of patch coordinates (.h5 file)
        output_path: path of the .h5 file that receives features and coords
        wsi: opened whole-slide image handle, forwarded to Whole_Slide_Bag_FP
        model: pytorch model
        batch_size: batch_size for computing features in batches
        verbose: level of feedback
        print_every: currently unused; kept for interface compatibility
        pretrained: forwarded to Whole_Slide_Bag_FP
        custom_downsample: custom defined downscale factor of image patches
        target_patch_size: custom defined, rescaled image size before embedding

    returns:
        output_path, unchanged
    """
    # Run on whatever device the model already lives on instead of relying on
    # a notebook-level `device` variable that may not be defined yet.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    dataset = Whole_Slide_Bag_FP(file_path=file_path, wsi=wsi, pretrained=pretrained, 
        custom_downsample=custom_downsample, target_patch_size=target_patch_size)
    # The first item is fetched but its value is not used; kept so that a bag
    # that cannot be read fails here, before any loader workers start.
    _ = dataset[0]
    kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {}
    loader = DataLoader(dataset=dataset, batch_size=batch_size, **kwargs, collate_fn=collate_features)

    if verbose > 0:
        print('processing {}: total of {} batches'.format(file_path,len(loader)))

    # First batch creates/truncates the output file, later batches append.
    mode = 'w'
    for count, (batch, coords) in enumerate(loader):
        with torch.no_grad():
            batch = batch.to(device, non_blocking=True)

            features = get_feaures(model, batch)
            features = features.cpu().numpy()

            asset_dict = {'features': features, 'coords': coords}
            save_hdf5(output_path, asset_dict, attr_dict= None, mode=mode)
            mode = 'a'

    return output_path

In [50]:
from MIL.datasets.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag_FP
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from MIL.utils.utils import print_network, collate_features
from MIL.utils.file_utils import save_hdf5
import h5py
from tqdm.notebook import tqdm
import pandas as pd
from torchvision import transforms

In [51]:
def create_data_split(feat_dir, classes, slide_folder, h5_dir, skip_existing=False):
    """
    Extract features for every slide of every class and store them on disk.

    For each `*tif` slide under `slide_folder/<class>/`, features are computed
    with `compute_w_loader` into `feat_dir/h5_files/<slide_id>.h5`, then the
    'features' dataset of that file is re-saved as a tensor in
    `feat_dir/pt_files/<slide_id>.pt`.

    args:
        feat_dir: output root; 'pt_files' and 'h5_files' are created inside it
        classes: iterable of sub-folder names to process
        slide_folder: root folder holding one sub-folder of slides per class
        h5_dir: root folder holding '<class>/patches/<slide_id>.h5' bags
        skip_existing: when True, slides whose .pt file already exists are
            skipped (default False keeps the original always-recompute behaviour)

    Side effects: sets `args.batch_size` and `args.data_h5_dir` on the
    notebook-level `args` namespace and uses the notebook-level `model`.
    """
    pt_dir = os.path.join(feat_dir, 'pt_files')
    h5_out_dir = os.path.join(feat_dir, 'h5_files')
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(h5_out_dir, exist_ok=True)

    # Loop-invariant; was previously re-assigned for every slide.
    args.batch_size = 16

    for class_name in classes:
        print(f"Processing for class {class_name}")
        args.data_h5_dir = os.path.join(h5_dir, class_name)
        slide_paths = glob.glob(os.path.join(slide_folder, class_name, '*tif'))
        for slide_file_path in tqdm(slide_paths):
            slide_id = os.path.splitext(os.path.basename(slide_file_path))[0]
            pt_path = os.path.join(pt_dir, slide_id + '.pt')
            if skip_existing and os.path.exists(pt_path):
                continue

            bag_name = slide_id + '.h5'
            h5_file_path = os.path.join(args.data_h5_dir, 'patches', bag_name)
            output_path = os.path.join(h5_out_dir, bag_name)

            wsi = openslide.open_slide(slide_file_path)
            try:
                output_file_path = compute_w_loader(h5_file_path, output_path, wsi,
                                                    model = model, batch_size = args.batch_size, verbose = 0,
                                                    print_every = 20, custom_downsample=1,
                                                    target_patch_size=args.image_size)
            finally:
                # Release the slide handle even if extraction fails; otherwise
                # one handle per slide stays open for the life of the kernel.
                wsi.close()

            with h5py.File(output_file_path, "r") as file:
                features = torch.from_numpy(file['features'][:])
            torch.save(features, pt_path)

In [None]:
# Sanity check: reload one saved feature tensor. torch.load takes the file
# path as its first argument (the second positional argument is map_location),
# so the tensor must not be passed here.
# Requires `feat_dir` and `slide_id` to be defined in the notebook beforehand.
torch.load(os.path.join(feat_dir, 'pt_files', slide_id+'.pt'))

In [52]:
# --- Training split: extract features, then write the label CSV -----------
classes = ['normal', 'tumor']
output_dir = 'clam_data/'
# is_available must be *called*; the bare function object is always truthy,
# which made the CPU fallback unreachable.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
split = 'training'
print(f"Processing split: {split}")
slide_folder = f'/home/ubuntu/Downloads/Camelyon16/{split}/'
feat_dir = f'{output_dir}/{args.image_size}_10x/{args.arch}/{split}'
h5_dir = f'dataset/Camelyon16/{split}/{args.image_size}_10x/h5/'

create_data_split(feat_dir, classes, slide_folder, h5_dir)

# create csv

# Slide ids are the .h5 file names up to the first dot; the label is the part
# of the slide id before the first underscore.
slide_ids = os_sorted(os.listdir(os.path.join(feat_dir, 'h5_files')))
slide_ids = [i.split('.')[0] for i in slide_ids ]
label = [i.split('_')[0] for i in slide_ids ]
df = pd.DataFrame([slide_ids, slide_ids, label]).T
df.columns = ['case_id', 'slide_id', 'label']
df.to_csv(os.path.join(feat_dir, 'tumor_vs_normal.csv'))

Processing split: training
Processing for class normal


  0%|          | 0/158 [00:00<?, ?it/s]

Processing for class tumor


  0%|          | 0/110 [00:00<?, ?it/s]

In [53]:
# --- Testing split: extract features, then build the label CSV ------------
output_dir = 'clam_data/'
# is_available must be *called*; the bare function object is always truthy,
# which made the CPU fallback unreachable.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
split = 'testing'
print(f"Processing split: {split}")
slide_folder = f'/home/ubuntu/Downloads/Camelyon16/{split}/'
feat_dir = f'{output_dir}/{args.image_size}_10x/{args.arch}/{split}'
h5_dir = f'dataset/Camelyon16/{split}/{args.image_size}_10x/h5/'

classes = ['images']
create_data_split(feat_dir, classes, slide_folder, h5_dir)

# create csv
# The reference file has no header row: column 0 becomes both case_id and
# slide_id, column 1 becomes the (lower-cased) label, columns 2 and 3 are dropped.
df = pd.read_csv('/home/ubuntu/Downloads/Camelyon16/testing/reference.csv', header=None)
df = df.drop([2, 3], axis=1)
df[2] = df[0]
df = df[[0, 2, 1]]
df.columns = ['case_id', 'slide_id', 'label']
df['label'] = df['label'].apply(str.lower)
df.to_csv(os.path.join(feat_dir, 'tumor_vs_normal.csv'))

Processing split: testing
Processing for class images


  0%|          | 0/129 [00:00<?, ?it/s]