In [1]:
import numpy as np
import pandas as pd
import torch
import openslide
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torch.utils.data import DataLoader
from tqdm import tqdm


from dataset import Visium_HES_Dataset
from model import load_UNI2h

import os
# Change the working directory so that the relative "data_hdd/..." paths used
# in the cells below resolve against this folder.
# NOTE(review): hard-coded absolute path — presumably the project root inside the dev container; confirm.
os.chdir('/workspaces/HES_feature_extraction')

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

# Load the model and its transform

In [3]:
# Instantiate the UNI2 model, passing the target device to the loader
model = load_UNI2h(device)

# Resolve the preprocessing settings from the model's pretrained configuration,
# then build the matching timm transform from them
data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**data_config)

# Open the slide using openslide

In [4]:
# Identifier of the slide to process
slide_id = "Visium_FFPE_V43T08-041_A"

# Open the slide's OME-TIFF with openslide
slide_path = f"data_hdd/visium/results/{slide_id}/images/annotation_images/visium_region.ome.tiff"
slide = openslide.OpenSlide(slide_path)

# Optional sanity check: preview a thumbnail of the slide
# slide.get_thumbnail(size=(256, 256))

# Load the warped Visium coordinates

In [5]:
# Read the warped coordinates table stored with this slide's results
coordinates_path = f"data_hdd/visium/results/{slide_id}/warped_coordinates.csv"
coordinates = pd.read_csv(coordinates_path)

# Create the dataset

In [6]:
# Wrap the slide, its coordinates and the model transform in a dataset
dataset = Visium_HES_Dataset(slide=slide, coordinates=coordinates, transform=transform)

# Batch the dataset without shuffling so the output order is deterministic;
# the embeddings are later indexed with `coordinates.index`
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)

# Perform inference on the dataset

In [7]:
# Make sure the model is in evaluation mode before extracting features:
# torch.inference_mode() only disables autograd tracking, it does not switch
# train-time layers (e.g. dropout) to their inference behaviour.
model.eval()

# Collect one NumPy array of embeddings per batch
feature_batches = []

# Iterate over the dataloader with a progress bar
for batch in tqdm(dataloader, desc="Processing batches"):

    # Forward pass without autograd bookkeeping
    with torch.inference_mode():
        output = model(batch.to(device))

    # Move the embeddings back to the CPU and keep them as NumPy
    feature_batches.append(output.cpu().numpy())

# Stack the per-batch arrays along the first axis into a single array
feature_emb = np.concatenate(feature_batches)

Processing batches: 100%|██████████| 20/20 [05:24<00:00, 16.24s/it]


In [8]:
# Release the cached GPU memory that PyTorch's allocator holds but no longer uses.
# Tensors that are still referenced (e.g. the model itself) are not freed by this call.
torch.cuda.empty_cache()

# Save the Results

In [9]:
# Build a dataframe of embeddings that shares its index with `coordinates`
feature_emb_df = pd.DataFrame(data=feature_emb, index=coordinates.index)

# Persist the embeddings as a pickle file in this slide's results folder
output_path = f"data_hdd/visium/results/{slide_id}/hes_features.pkl"
feature_emb_df.to_pickle(output_path)