In [None]:
import os
import time

import h5py
import numpy as np
import openslide
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm  # Import tqdm for progress bar

from dataset_modules.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag, Whole_Slide_Bag_FP, get_eval_transforms
from models import get_encoder
from utils.file_utils import save_hdf5


In [None]:

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def compute_w_loader(output_path, loader, model, verbose=0):
    """
    Run ``model`` over every batch of ``loader`` and write the results to an .h5 file.

    Each batch is handed to ``save_hdf5`` as soon as it is computed. The first batch
    uses mode 'w' and every later batch uses mode 'a', so the file is presumably
    (re)created by the first batch and extended by the rest (TODO confirm against
    ``save_hdf5``).

    args:
        output_path: path of the .h5 file that receives the 'features' and 'coords' entries
        loader: iterable of dict batches; each batch must provide an 'img' tensor and a
            'coord' tensor, and ``len(loader)`` must be defined when ``verbose > 0``
        model: pytorch model mapping a batch of images to a batch of features
        verbose: level of feedback; > 0 prints the output path and the number of batches

    returns:
        ``output_path``, unchanged.

    NOTE(review): if ``loader`` yields no batches, ``save_hdf5`` is never called, so
    nothing is written to ``output_path``.
    """
    if verbose > 0:
        print('Processing {}: total of {} batches'.format(output_path, len(loader)))

    mode = 'w'
    for data in tqdm(loader):
        batch = data['img'].to(device, non_blocking=True)
        coords = data['coord'].numpy().astype(np.int32)

        # Inference only: disable autograd so no graph or gradients are kept.
        with torch.no_grad():
            features = model(batch)

        features = features.cpu().numpy()

        asset_dict = {'features': features, 'coords': coords}
        save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
        # Only the first batch is written with 'w'; all subsequent batches use 'a'.
        mode = 'a'

    return output_path


In [None]:

# Run configuration for the feature-extraction cell below.
# The path strings are placeholders; replace them before running.

# Encoder / inference settings
model_name = 'uni_v1'  # encoder name handed to get_encoder
target_patch_size = 224  # image size handed to get_encoder as target_img_size
batch_size = 256

# Input / output locations
data_dir = "path_to_data_directory"
csv_path = "path_to_csv_file"
feat_dir = "path_to_feature_directory"
slide_ext = '.svs'

# Set True to recompute features even for slides that already have a .pt file.
no_auto_skip = False


In [None]:
print('Initializing dataset')

# Paths used only by this cell. csv_path, feat_dir, slide_ext and the model/run
# options (model_name, target_patch_size, batch_size, no_auto_skip) come from the
# configuration cell; redefining them here would silently override that cell.
data_h5_dir = "path_to_h5_directory"
data_slide_dir = "path_to_slide_directory"

bags_dataset = Dataset_All_Bags(csv_path)

os.makedirs(feat_dir, exist_ok=True)
os.makedirs(os.path.join(feat_dir, 'pt_files'), exist_ok=True)
os.makedirs(os.path.join(feat_dir, 'h5_files'), exist_ok=True)
# Only used for membership tests in the skip check, so keep it as a set.
dest_files = set(os.listdir(os.path.join(feat_dir, 'pt_files')))

model, img_transforms = get_encoder(model_name, target_img_size=target_patch_size)

_ = model.eval()
model = model.to(device)
total = len(bags_dataset)

# Worker processes and pinned memory are only requested when running on CUDA.
loader_kwargs = {'num_workers': 8, 'pin_memory': True} if device.type == "cuda" else {}

for bag_candidate_idx in tqdm(range(total)):
    slide_id = bags_dataset[bag_candidate_idx].split(slide_ext)[0]
    bag_name = slide_id + '.h5'
    h5_file_path = os.path.join(data_h5_dir, 'patches', bag_name)
    slide_file_path = os.path.join(data_slide_dir, slide_id + slide_ext)
    # 1-based so the last slide reads "total/total".
    print('\nProgress: {}/{}'.format(bag_candidate_idx + 1, total))
    print(slide_id)

    # Skip slides whose .pt features already exist unless recomputation is forced.
    if not no_auto_skip and slide_id + '.pt' in dest_files:
        print('Skipped {}'.format(slide_id))
        continue

    output_path = os.path.join(feat_dir, 'h5_files', bag_name)
    time_start = time.time()
    wsi = openslide.open_slide(slide_file_path)
    try:
        dataset = Whole_Slide_Bag_FP(file_path=h5_file_path,
                                     wsi=wsi,
                                     img_transforms=img_transforms)
        loader = DataLoader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
        output_file_path = compute_w_loader(output_path, loader=loader, model=model, verbose=1)
    finally:
        # Close the slide even if feature extraction fails, so handles don't pile up.
        wsi.close()

    time_elapsed = time.time() - time_start
    print('\nComputing features for {} took {} s'.format(output_file_path, time_elapsed))

    with h5py.File(output_file_path, "r") as file:
        features = file['features'][:]
        print('Features size: ', features.shape)
        print('Coordinates size: ', file['coords'].shape)

    features = torch.from_numpy(features)
    bag_base, _ = os.path.splitext(bag_name)
    torch.save(features, os.path.join(feat_dir, 'pt_files', bag_base + '.pt'))
