In [None]:
%load_ext autoreload
%autoreload 2
import sys
print(sys.executable)
import socket
print(socket.gethostname())
import os
os.chdir("/home/icb/alioguz.can/projects/scPortrait4i")

In [None]:
import torch
import pandas as pd
import scanpy as sc
import pytorch_lightning as pl

from pathlib import Path
from torchvision import transforms
from models import TransformerModel
from scportrait.tools.ml.utils import split_dataset_fractions
from scportrait.tools.ml.datasets import H5ScSingleCellDataset
from utils import feature_extractor, get_anndata_obj, inference

# Release unused memory held by PyTorch's CUDA caching allocator before data/model loading
# (a no-op when CUDA has not been initialised in this process).
torch.cuda.empty_cache()

In [None]:
# scPortrait single-cell image archive for the Xenium ovarian cancer project (relative to the project root).
data_path = Path(
    "scportrait_manuscript/input_data/Xenium_ovarian_cancer/processed_data/scPortrait_project_xenium/extraction/data/single_cells.h5sc"
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-cell preprocessing, applied in order: centre-crop to 128 px, resize to 224 px,
# then the project's feature_extractor.
preprocessing_steps = [
    transforms.CenterCrop(128),
    transforms.Resize(224),
    feature_extractor,
]
t = transforms.Compose(preprocessing_steps)

# Shared mailbox written by the forward hook below and passed to `inference(...)` later on.
layer_outputs = {}
def get_intermediate(module, input, output):
    """Forward hook that stores the hooked module's output under the key 'encoder_layernorm'.

    `module` and `input` are part of the PyTorch hook signature and are not used.
    Each call overwrites the previously stored value.
    """
    layer_outputs.update(encoder_layernorm=output)
    
# Dataset configuration.
# random_indices=True -> seeded 90/5/5 split via split_dataset_fractions;
# random_indices=False -> reuse the predefined train/test/val index files.
random_indices = False
savedir = '/lustre/groups/ml01/workspace/alioguz.can/scportrait4i/training_output' ## UPDATE PATHS
class_list = [0]
return_id = True


def _load_predefined_split(indices_folder, split_name):
    """Build the dataset for one predefined split.

    Reads column 0 of the header-less ``{split_name}_set_indexes.csv`` in ``indices_folder``
    and restricts the single-cell archive at ``data_path`` to those indices.

    Returns:
        tuple: (list of indices, H5ScSingleCellDataset for that split)
    """
    split_indexes = pd.read_csv(f'{indices_folder}/{split_name}_set_indexes.csv', header=None)[0].tolist()
    split_dataset = H5ScSingleCellDataset(
        dir_list=[data_path],
        dir_labels=class_list,
        select_channel=[2, 3, 4],
        transform=t,
        return_id=return_id,
        index_list=[split_indexes],
    )
    return split_indexes, split_dataset


if not random_indices:
    ## init from predefined indices
    indices_folder = "scportrait_manuscript/input_data/Xenium_ovarian_cancer/processed_data/test_val_datasets"
    train_set_indexes, train_dataset = _load_predefined_split(indices_folder, "train")
    test_set_indexes, test_dataset = _load_predefined_split(indices_folder, "test")
    val_set_indexes, val_dataset = _load_predefined_split(indices_folder, "val")
else:
    dataset = H5ScSingleCellDataset(
        dir_list=[data_path],
        dir_labels=class_list,
        select_channel=[2, 3, 4],
        transform=t,
        return_id=return_id,
    )
    train_dataset, val_dataset, test_dataset = split_dataset_fractions(
        [dataset],
        fractions=[0.9, 0.05, 0.05],
        seed=42,
    )

In [None]:
# Generate dataloaders
# ====================
# The three loaders share batch size and worker count; only the training loader shuffles.
# NOTE(review): shuffle=True also randomises the row order of the train features extracted below.
print("Generating the dataloaders...")
shared_loader_kwargs = dict(batch_size=64, num_workers=10, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_dataloader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)
test_dataloader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)
print("Dataloaders are initialized.")

In [None]:
# Initialize model from path
vitMAE_xenium = TransformerModel(finetune=True, in_channels=3)
# map_location loads the checkpoint tensors onto the device chosen above. Without it, a checkpoint
# saved from GPU tensors cannot be deserialised on the CPU fallback path.
checkpoint = torch.load("/lustre/groups/ml01/workspace/alioguz.can/scportrait4i/training_output/Xenium_ovarian_cancer_finetune_epoch=epoch=119-v1.ckpt", map_location=device) ## UPDATE PATHS
vitMAE_xenium.load_state_dict(checkpoint['state_dict'])
# On every forward pass, store the ViT's final layernorm output in `layer_outputs` via get_intermediate.
# The handle is kept so the hook can be detached with `layernorm_hook.remove()`.
layernorm_hook = vitMAE_xenium.model.vit.layernorm.register_forward_hook(get_intermediate)
# Everything below is inference/visualisation only, so switch off training-mode behaviour (e.g. dropout).
vitMAE_xenium = vitMAE_xenium.to(device).eval()

## Inference on the test set (`test_dataloader`)

In [None]:
# Extract CLS and pooled patch embeddings (plus labels and cell ids) for every test-set cell.
cls_outputs, pooled_patch_outputs, out_labels, out_cell_ids = inference(
    model=vitMAE_xenium.model,
    dataloader=test_dataloader,
    device=device,
    layer_outputs=layer_outputs,
)

In [None]:
# Build an AnnData object from the pooled patch embeddings, scale it, compute the kNN graph and UMAP.
adata = get_anndata_obj(pooled_patch_outputs, out_labels, cell_ids=out_cell_ids, channels=None)
sc.pp.scale(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
# The title names the split so the test/train/val UMAPs can be told apart.
# To also write the figure to disk, pass save="_test_vit_data_per_cell.png".
sc.pl.umap(adata, color='targets', title="VIT-MAE Xenium Inference (test set)")

In [None]:
# Leiden clustering (resolution 0.25) on the neighbour graph, shown on the UMAP.
leiden_key = "leiden_res0_25"
sc.tl.leiden(adata, resolution=0.25, key_added=leiden_key)
sc.pl.umap(adata, color=[leiden_key], legend_loc="right margin")

In [None]:
# Write the test-set AnnData (with the embedding/clustering results computed above) to disk.
test_features_h5ad = "/lustre/groups/ml01/workspace/alioguz.can/scportrait4i/xenium_vitmae_test_features.h5ad"  ## UPDATE PATHS
adata.write_h5ad(test_features_h5ad)

## Inference on the training set (`train_dataloader`)

In [None]:
# Extract CLS and pooled patch embeddings (plus labels and cell ids) for every training-set cell.
cls_outputs, pooled_patch_outputs, out_labels, out_cell_ids = inference(
    model=vitMAE_xenium.model,
    dataloader=train_dataloader,
    device=device,
    layer_outputs=layer_outputs,
)

In [None]:
# Build an AnnData object from the pooled patch embeddings, scale it, compute the kNN graph and UMAP.
adata = get_anndata_obj(pooled_patch_outputs, out_labels, cell_ids=out_cell_ids, channels=None)
sc.pp.scale(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
# The title names the split so this plot is not confused with the test/val UMAPs.
# To also write the figure to disk, pass save="_train_vit_data_per_cell.png"
# (a split-specific name, so it does not overwrite the test-set figure).
sc.pl.umap(adata, color='targets', title="VIT-MAE Xenium Inference (train set)")

In [None]:
# Leiden clustering (resolution 0.25) on the neighbour graph, shown on the UMAP.
leiden_key = "leiden_res0_25"
sc.tl.leiden(adata, resolution=0.25, key_added=leiden_key)
sc.pl.umap(adata, color=[leiden_key], legend_loc="right margin")

In [None]:
# Write the training-set AnnData (with the embedding/clustering results computed above) to disk.
train_features_h5ad = "/lustre/groups/ml01/workspace/alioguz.can/scportrait4i/xenium_vitmae_train_features.h5ad"  ## UPDATE PATHS
adata.write_h5ad(train_features_h5ad)

## Inference on the validation set (`val_dataloader`)

In [None]:
# Validation set: extract features, embed, cluster and save in one cell.
cls_outputs, pooled_patch_outputs, out_labels, out_cell_ids = inference(
    model=vitMAE_xenium.model, dataloader=val_dataloader, layer_outputs=layer_outputs, device=device
)
adata = get_anndata_obj(pooled_patch_outputs, out_labels, cell_ids=out_cell_ids, channels=None)
sc.pp.scale(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
# The title names the split so this plot is not confused with the test/train UMAPs.
# To also write the figure to disk, pass save="_val_vit_data_per_cell.png"
# (a split-specific name, so it does not overwrite the test-set figure).
sc.pl.umap(adata, color='targets', title="VIT-MAE Xenium Inference (validation set)")
sc.tl.leiden(adata, key_added="leiden_res0_25", resolution=0.25)
sc.pl.umap(
    adata,
    color=["leiden_res0_25"],
    legend_loc="right margin",
)
adata.write_h5ad(
    "/lustre/groups/ml01/workspace/alioguz.can/scportrait4i/xenium_vitmae_val_features.h5ad", ## UPDATE PATHS
)

## Reconstruction: original vs. ViT-MAE reconstruction of a test image

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Take the first test batch (images, labels, cell ids, since return_id=True) and run the model on it.
imgs, labels, ids = next(iter(test_dataloader))
# Visualisation only: no gradients are needed, so don't build an autograd graph (saves GPU memory).
with torch.no_grad():
    recons = vitMAE_xenium(imgs.to(device))
recons["logits"].shape

In [None]:
reconstructed_imgs = vitMAE_xenium.unpatchify(recons["logits"])

# Visualize original vs reconstruction
idx = 12  # image index within the batch
assert idx < len(imgs), f"batch only has {len(imgs)} images"

# Intensity window for the original image. imshow ignores vmin/vmax for RGB(A) input,
# so the window has to be applied to the pixel values themselves.
orig_vmin, orig_vmax = 0.0, 0.05


def _to_rgb_display(img_chw, vmin=0.0, vmax=1.0):
    """Convert a (C, H, W) tensor into an (H, W, C) numpy array for imshow.

    Values are mapped linearly from [vmin, vmax] to [0, 1] and clipped. With the defaults
    this matches imshow's own clipping of float RGB data, minus the clipping warning.
    """
    img_hwc = img_chw.detach().permute(1, 2, 0).cpu().float()
    return ((img_hwc - vmin) / (vmax - vmin)).clamp(0.0, 1.0).numpy()


fig, axs = plt.subplots(1, 2, figsize=(8, 4))

# Original image, displayed with the [orig_vmin, orig_vmax] window
axs[0].imshow(_to_rgb_display(imgs[idx], vmin=orig_vmin, vmax=orig_vmax))
axs[0].set_title("Original")
axs[0].axis("off")

# Reconstructed image, clipped to [0, 1]
axs[1].imshow(_to_rgb_display(reconstructed_imgs[idx]))
axs[1].set_title("Reconstruction")
axs[1].axis("off")

plt.tight_layout()
plt.show()