In [1]:
import torch
import torchvision
import os
import sys
from os.path import join as j_
from PIL import Image
import pandas as pd
import numpy as np
import time
import random
from tqdm import tqdm
# print(torch.version)
# print(torch.version.cuda)
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Sampler
from torchvision.transforms import Lambda
from PIL import Image
# Disable Pillow's decompression-bomb pixel limit so very large slide/patch images can be opened.
# NOTE(review): this removes a safety check; only acceptable because the images are local, trusted data.
Image.MAX_IMAGE_PIXELS = None
# loading all packages here to start
from uni import get_encoder
from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe
from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot
from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote
from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics
from uni.downstream.utils import concat_images
# Use the default CUDA device when one is visible to PyTorch, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# configs
# Batching is driven by SlideBatchSampler in create_dataloader (each slide's tiles are loaded
# sequentially as one batch); BATCH_SIZE is kept for reference and is not passed to the DataLoader.
BATCH_SIZE = 1
# Raw strings use single backslashes: inside r"..." a "\\" is two literal backslashes, which the
# previous values mixed in inconsistently (and which then leaked into every saved feature path).
K_FOLDS_PATH = r"E:\KSA Project\dataset\splits\kfolds_IDARS.csv"
DATA_PATH = r"E:\KSA Project\dataset\testing\Patches"
FEATURES_SAVE_DIR = r"E:\KSA Project\dataset\testing\uni_features"

### Downloading UNI weights + Creating Model

The function `get_encoder` performs the download-and-load steps for us: it downloads the checkpoint into the `./assets/ckpts/` path (relative to the UNI GitHub repository) and returns the model together with its default preprocessing transform.

In [3]:
from uni import get_encoder
# If the checkpoint download requires Hugging Face authentication, log in outside the notebook
# (e.g. `huggingface-cli login` or the HF_TOKEN environment variable) -- never paste tokens here.
# NOTE(review): two access tokens were previously committed in this cell as comments; revoke them.
# The returned `transform` is replaced by a FiveCrop pipeline in the Data Loaders section.
model, transform = get_encoder(enc_name='uni', device=device)

### Data Loaders

In [4]:
# Import the project's data loading helpers from the parent directory (a file named dataloader.py there).
# NOTE(review): a relative sys.path entry like ".." is resolved against the kernel's current working
# directory, not against this notebook's location.
sys.path.append("..")
from dataloader import PatchLoader, SlideBatchSampler

# PatchLoader loading mode: 1 = sequential patch loading (used here), 2 = random loading.
mode = 1

# Normalization statistics of UNI's own preprocessing (the `transform` returned by get_encoder).
# The FiveCrop pipeline below replaces that transform, so it must re-apply the normalization;
# otherwise the encoder receives unnormalized [0, 1] inputs.
# NOTE: embeddings saved before this change were computed without normalization. Delete them, or
# the extraction step's "already processed" skip will keep the stale features.
UNI_MEAN = (0.485, 0.456, 0.406)
UNI_STD = (0.229, 0.224, 0.225)

# Built once and reused for every crop (instead of constructing ToTensor inside the lambda per crop).
crop_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=UNI_MEAN, std=UNI_STD),
])

transform = transforms.Compose([
    transforms.FiveCrop(224),  # tuple of 5 crops: four corners + center
    # normalize each crop and stack them -> (5, 3, 224, 224) per patch
    Lambda(lambda crops: torch.stack([crop_to_tensor(crop) for crop in crops])),
])

def create_dataloader(label_file, data_path, transform, num_samples, mode):
    """Build a DataLoader over PatchLoader whose batches come from SlideBatchSampler.

    Args:
    - label_file: split/label CSV forwarded to PatchLoader.
    - data_path: root directory of the patches, forwarded to PatchLoader.
    - transform: per-patch transform forwarded to PatchLoader.
    - num_samples: forwarded to PatchLoader (None is used in this notebook).
    - mode: PatchLoader loading mode (1 = sequential, 2 = random).
    Returns:
    - torch DataLoader driven by a batch sampler built from ``dataset.ntiles``.
    """
    patch_dataset = PatchLoader(
        label_file=label_file,
        data_path=data_path,
        transform=transform,
        num_samples=num_samples,
        mode=mode,
    )
    # Batch composition/order comes from the sampler, not from shuffle/batch_size.
    # Presumably one slide's tiles per batch (the printed loader length equals the slide count) -- confirm in dataloader.py.
    slide_sampler = SlideBatchSampler(patch_dataset.ntiles)
    return DataLoader(
        patch_dataset,
        batch_sampler=slide_sampler,
        num_workers=0,
        pin_memory=False,
    )

# Create DataLoaders
data_loader = create_dataloader(
    label_file=K_FOLDS_PATH,
    data_path=DATA_PATH,
    transform=transform,
    num_samples=None,
    mode=mode,
)
print(f"Length of data_loader: {len(data_loader)}")

Number of Slides: 5
Number of tiles: 25
Length of data_loader: 5


In [5]:
# Sanity check: show the image shape and labels of the first few batches, then stop.
for batch_number, (images, labels) in enumerate(data_loader, start=1):
    print(f"Batch {batch_number}")
    print(f"Images shape: {images.shape}")
    print(f"Labels: {labels}")
    if batch_number == 6:  # at most six batches are printed
        break

Batch 1
Images shape: torch.Size([7, 5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0, 0, 0, 0])
Batch 2
Images shape: torch.Size([9, 5, 3, 224, 224])
Labels: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])
Batch 3
Images shape: torch.Size([3, 5, 3, 224, 224])
Labels: tensor([1, 1, 1])
Batch 4
Images shape: torch.Size([4, 5, 3, 224, 224])
Labels: tensor([0, 0, 0, 0])
Batch 5
Images shape: torch.Size([2, 5, 3, 224, 224])
Labels: tensor([0, 0])


### DataLoader testing

In [13]:

def print_unique_class_representation(dataloader, loader_name="Test Loader"):
    """Print the distinct labels yielded by `dataloader` and how many samples carry each.

    Args:
    - dataloader: iterable of ``(images, labels)`` batches; ``labels`` may be a tensor or
      anything ``torch.as_tensor`` accepts.
    - loader_name: prefix of the printed header (default keeps the original "Test Loader").
    Returns:
    - None: results are only printed.

    Note: this iterates the whole loader, so the images of every batch are loaded too.
    """
    def get_unique_classes(loader):
        # Collect flat label tensors and concatenate once, instead of growing a Python list per element.
        label_chunks = [torch.as_tensor(labels).reshape(-1) for _, labels in loader]
        if label_chunks:
            all_labels = torch.cat(label_chunks)
        else:
            # Empty loader: an integer empty tensor (torch.tensor([]) would be float).
            all_labels = torch.empty(0, dtype=torch.long)
        return torch.unique(all_labels, return_counts=True)

    unique_labels, counts = get_unique_classes(dataloader)

    print(f"\n{loader_name} Unique Labels and Counts:")
    print(f"Unique labels: {unique_labels}")
    print(f"Counts: {counts}")

# Example usage: count labels over the whole loader (this iterates every batch, loading all patches).
print_unique_class_representation(data_loader)


Test Loader Unique Labels and Counts:
Unique labels: tensor([0, 1])
Counts: tensor([26839,  4934])


### ROI Feature Extraction at the FiveCrop Patch Level (with Saving)

In [17]:
def _patch_embedding_path(save_dir_wsi, wsi_name, patch_idx):
    """Return the .pt path that holds the embeddings of patch `patch_idx` of `wsi_name`."""
    return os.path.join(save_dir_wsi, f'{wsi_name}_{patch_idx}.pt')


def _save_tensor_atomically(tensor, save_path):
    """Save `tensor` to `save_path` via a temporary file so an interrupted run never leaves a
    truncated .pt under the final name."""
    tmp_path = save_path + '.tmp'
    torch.save(tensor, tmp_path)
    os.replace(tmp_path, save_path)  # atomic rename on the same filesystem; overwrites stale files


@torch.no_grad()
def extract_embeddings_patch_by_patch(model, dataloader, save_dir):
    """
    Extract and save embeddings for each WSI, patch by patch, without averaging the five crops.

    Each batch from `dataloader` is treated as one WSI whose name is
    ``dataloader.dataset.slides[batch_idx]`` (assumes the batch sampler visits slides in the same
    order as ``dataset.slides`` -- confirm in dataloader.py). For every patch, the model is run on
    its crops (the crop axis acts as the model's batch axis) and the output is saved to
    ``<save_dir>/<wsi_name>/<wsi_name>_<patch_idx>.pt``.

    Re-running is safe: a WSI is skipped only when all of its patch files exist; a partially
    processed WSI (e.g. from an interrupted run) resumes at its missing patches.

    Args:
    - model: The model used to extract embeddings (switched to eval mode during extraction,
      then restored to its previous mode).
    - dataloader: Dataloader providing WSI patches and labels, expected one WSI per batch.
    - save_dir: Directory where the extracted embeddings will be saved.
    Returns:
    - None: The function saves the extracted embeddings to disk.
    Raises:
    - ValueError: if a patch tensor is not 4-D (num_crops, channels, height, width).
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()  # inference features must not depend on train-mode layers; restored in `finally`
    print(f'The size of input dataloader is {len(dataloader)}')

    try:
        for batch_idx, (images, labels) in tqdm(enumerate(dataloader), total=len(dataloader)):
            wsi_name = dataloader.dataset.slides[batch_idx]
            print(f"Processing WSI {batch_idx+1} {wsi_name}")
            # print WSI shape and number of patches
            print(f"Images shape: {images.shape} and Labels shape: {labels.shape} and no of patches: {len(images)}")
            save_dir_wsi = os.path.join(save_dir, f'{wsi_name}')
            patch_paths = [_patch_embedding_path(save_dir_wsi, wsi_name, i) for i in range(len(images))]

            # Skip only WSIs whose every patch file exists; a directory missing some files is resumed.
            if patch_paths and all(os.path.exists(p) for p in patch_paths):
                print(f"WSI {batch_idx+1} {wsi_name} already processed. Skipping...")
                continue
            os.makedirs(save_dir_wsi, exist_ok=True)

            for image, save_path in zip(images, patch_paths):
                if os.path.exists(save_path):
                    continue  # finished before an earlier interruption
                if image.dim() != 4:
                    raise ValueError(
                        f"Expected a (num_crops, channels, height, width) patch tensor, got shape {tuple(image.shape)}"
                    )
                with torch.inference_mode():
                    embeddings = model(image.to(device)).cpu()
                _save_tensor_atomically(embeddings, save_path)
    finally:
        model.train(was_training)


In [None]:
# Run the feature extraction over every slide; already-processed slides are skipped inside the function.
extract_embeddings_patch_by_patch(model=model, dataloader=data_loader, save_dir=FEATURES_SAVE_DIR)