In [None]:
import os

# import numpy as np
import rootutils
import scanpy as sc
import torch
from scipy.sparse import csr_matrix

# from torch_geometric.data import Data, Dataset

rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)
from torchmetrics.clustering import AdjustedRandScore, NormalizedMutualInfoScore

from src.data.spatial_omics_datamodule import SpatialOmicsDataModule
from src.models.bgrl_domain_module import BGRLDomainLitModule
from src.utils.clustering_utils import set_leiden_resolution


def set_seed(seed: int):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, plus every CUDA device when CUDA is available).

    Args:
        seed: Integer seed applied to each generator.
    """
    # Imported locally: the notebook's top-level ``import numpy as np`` is
    # commented out, so ``np`` would otherwise be undefined on a fresh kernel.
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(44)

In [None]:
# Build the data module on the already-processed kNN-baseline graphs
# (redo_preprocess=False) and materialise its splits.
datamodule = SpatialOmicsDataModule(
    redo_preprocess=False,
    processed_dir="../data/domain/processed_knn_baseline/",
    data_dir="../data/domain/raw",
)
datamodule.prepare_data()
datamodule.setup()

# List every test batch: first sample name, then the node-feature shape.
test_dataloder = datamodule.test_dataloader()
for batch in test_dataloder:
    print(batch.sample_name[0], batch.x.shape, sep="\n")

In [None]:
desired_sample_name = "MERFISH_small4"
# desired_sample_name = "Zhuang-ABCA-4.002"

# Take the first test batch whose (first) sample name matches; the generator
# stops consuming the loader as soon as a match is found.
batch = next(
    (b for b in test_dataloder if b.sample_name[0] == desired_sample_name),
    None,
)
if batch is None:
    # Fail here with a clear message rather than with an AttributeError on
    # `None` in the prints below (or in a later cell).
    raise ValueError(f"Sample {desired_sample_name!r} not found in the test dataloader")

print(batch)
print(batch.sample_name)

In [None]:
# checkpoint_path = "../logs/train_domain/runs/2025-03-31_15-41-12/checkpoints/epoch_011.ckpt"
checkpoint_path = "../logs/augmentation/runs/2025-04-01_12-34-30/checkpoints/epoch_011.ckpt"
# checkpoint_path = "../logs/abca/runs/2025-04-02_17-10-41/checkpoints/epoch_008.ckpt"

# Restore the Lightning module from the checkpoint and keep only its online encoder.
model = BGRLDomainLitModule.load_from_checkpoint(checkpoint_path).net.online_encoder
# A freshly restored module is in training mode; switch to eval mode so that any
# train-time-only layers (e.g. dropout, batch norm) behave deterministically
# when the encoder is used for inference.
model.eval()
# Bare expression: show the encoder architecture as the cell output.
model

In [None]:
# Embed the cells of the selected batch with the encoder (no gradients needed).
# The batch is first moved to the device holding the encoder's parameters, so
# the forward pass also works when the checkpoint was restored onto a GPU.
model_device = next(model.parameters()).device
batch = batch.to(model_device)
with torch.no_grad():
    node_embeddings = model(batch.x, batch.edge_index, batch.edge_weight)

# load adata object
sample_name = batch.sample_name[0]
file_path = os.path.join("../data/domain/processed_knn_baseline", sample_name + ".h5ad")
adata = sc.read_h5ad(file_path)

# append cell embeddings to adata object
# NOTE(review): assumes graph nodes are ordered like the rows of adata.obs — confirm.
cell_embeddings_np = node_embeddings.cpu().numpy()
adata.obsm["cell_embeddings"] = cell_embeddings_np

# get ground truth labels: the obs column holding them depends on the dataset,
# which is identified by the sample-name prefix
domain_column_by_prefix = {
    "MERFISH_small": "domain",
    "STARmap": "region",
    "BaristaSeq": "layer",
    "Zhuang": "parcellation_structure",
}
domain_name = next(
    (
        column
        for prefix, column in domain_column_by_prefix.items()
        if sample_name.startswith(prefix)
    ),
    None,
)
if domain_name is None:
    # Previously this fell through with domain_name=None and failed later
    # with an unrelated-looking KeyError on adata.obs[None].
    raise ValueError(
        f"No ground-truth column is known for sample {sample_name!r}; "
        f"expected a name starting with one of {sorted(domain_column_by_prefix)}"
    )
num_domains = adata.obs[domain_name].nunique()

# determine resolution based on number of ground truth labels
# (set_leiden_resolution is a project helper; presumably it searches for a
# resolution giving `target_num_clusters` clusters — see src.utils.clustering_utils)
sc.pp.neighbors(adata, use_rep="cell_embeddings")
resolution = set_leiden_resolution(adata, target_num_clusters=num_domains)
# perform leiden clustering
sc.tl.leiden(adata, resolution=resolution)

# convert ground truth labels and leiden labels to integer-coded PyTorch tensors
ground_truth_labels = torch.tensor(
    adata.obs[domain_name].astype("category").cat.codes.values, dtype=torch.long
)
leiden_labels = torch.tensor(
    adata.obs["leiden"].astype("category").cat.codes.values, dtype=torch.long
)

# calculate metrics; torchmetrics expects (preds, target). Both scores are
# symmetric in their arguments, so this ordering does not change the values.
test_nmi = NormalizedMutualInfoScore()
test_ars = AdjustedRandScore()
nmi = test_nmi(leiden_labels, ground_truth_labels)
ari = test_ars(leiden_labels, ground_truth_labels)

# log metrics for each graph
print(f"test/nmi: {nmi}")
print(f"test/ari: {ari}")

# spatial maps: ground-truth domains first, then the Leiden clusters
sc.pl.embedding(adata, basis="spatial", color=domain_name, size=30)
sc.pl.embedding(adata, basis="spatial", color="leiden", size=30)

In [None]:
# Read the AnnData files saved by the 2025-04-06_17-57-24 augmentation run.
# The directory is spelled "logs" (lower case) like every other log path in
# this notebook; the previous "logS" only resolved on case-insensitive filesystems.
merfish4 = sc.read_h5ad(
    "../logs/augmentation/runs/2025-04-06_17-57-24/adata_files/MERFISH_small4.h5ad"
)
merfish5 = sc.read_h5ad(
    "../logs/augmentation/runs/2025-04-06_17-57-24/adata_files/MERFISH_small5.h5ad"
)

In [None]:
# Two spatial maps of MERFISH_small4 with identical styling: the annotated
# domains, then the Leiden clusters (metric values in the title are hard-coded).
for color_key, panel_title in (
    ("domain", "Ground Truth"),
    ("leiden", "NMI: 0.61, ARI: 0.53"),
):
    sc.pl.embedding(merfish4, basis="spatial", color=color_key, size=30, title=panel_title)

In [None]:
# Re-orient the spatial coordinates for display: negate every column, then
# negate column 0 again, i.e. column 0 is unchanged and the rest is mirrored.
# The result is always derived from a saved copy of the original coordinates,
# so re-running this cell no longer toggles the orientation back and forth.
if "spatial_unflipped" not in merfish5.obsm:
    merfish5.obsm["spatial_unflipped"] = merfish5.obsm["spatial"].copy()
spatial_flipped = -merfish5.obsm["spatial_unflipped"]
spatial_flipped[:, 0] *= -1
merfish5.obsm["spatial"] = spatial_flipped

# Ground-truth domains, then Leiden clusters (metric values in the title are hard-coded).
sc.pl.embedding(merfish5, basis="spatial", color="domain", size=40, title="Ground Truth")
sc.pl.embedding(
    merfish5,
    basis="spatial",
    color="leiden",
    size=40,
    title="NMI: 0.64, ARI: 0.59, HOM: 0.66, COM: 0.63",
)

In [None]:
# Permute the Leiden palette: new colour i is the original colour at
# leiden_color_order[i]. The permutation is applied to a saved copy of the
# original palette, so re-running this cell gives the same result instead of
# permuting the already-permuted colours again.
# NOTE(review): the order is hard-coded for exactly 8 clusters.
leiden_color_order = [0, 1, 2, 4, 5, 7, 3, 6]
if "leiden_colors_unpermuted" not in merfish5.uns:
    merfish5.uns["leiden_colors_unpermuted"] = list(merfish5.uns["leiden_colors"])
merfish5.uns["leiden_colors"] = [
    merfish5.uns["leiden_colors_unpermuted"][i] for i in leiden_color_order
]

In [None]:
import matplotlib.pyplot as plt

# Flip and mirror the spatial embedding: negate every column, then negate
# column 0 again (column 0 ends up unchanged, the rest is mirrored).
# The flipped coordinates are derived from a saved copy of the original ones,
# so this cell is idempotent and does not undo a flip applied by another cell.
if "spatial_unflipped" not in merfish5.obsm:
    merfish5.obsm["spatial_unflipped"] = merfish5.obsm["spatial"].copy()
spatial_flipped = -merfish5.obsm["spatial_unflipped"]
spatial_flipped[:, 0] *= -1
merfish5.obsm["spatial"] = spatial_flipped

# Plot the ground truth with a custom figure size
fig, ax = plt.subplots(figsize=(6, 8))  # width, height in inches (6 wide, 8 tall)
sc.pl.embedding(
    merfish5, basis="spatial", color="domain", size=60, title="Ground Truth", ax=ax, show=False
)
plt.show()

# Plot the Leiden clusters with a custom figure size
# (metric values in the title are hard-coded)
fig, ax = plt.subplots(figsize=(6, 8))  # same size as the ground-truth figure
sc.pl.embedding(
    merfish5,
    basis="spatial",
    color="leiden",
    size=60,
    title="NMI: 0.64, HOM: 0.66, COM: 0.63",
    ax=ax,
    show=False,
)
plt.show()

In [None]:
# Read the MERFISH_small5 AnnData file from ../data. This rebinds `adata`,
# which an earlier cell assigned to the sample used for embedding/clustering.
# NOTE(review): the plotting cell that follows colours by "domain_annotation"
# and "leiden" — presumably both columns are stored in this file; confirm.
adata = sc.read_h5ad("../data/MERFISH_small5.h5ad")
# Bare expression: show the AnnData summary as the cell output.
adata

In [None]:
import matplotlib.pyplot as plt

# Flip and mirror the spatial embedding: negate every column, then negate
# column 0 again (column 0 ends up unchanged, the rest is mirrored).
# The flipped coordinates are derived from a saved copy of the original ones,
# so re-running this cell does not toggle the orientation back.
if "spatial_unflipped" not in adata.obsm:
    adata.obsm["spatial_unflipped"] = adata.obsm["spatial"].copy()
spatial_flipped = -adata.obsm["spatial_unflipped"]
spatial_flipped[:, 0] *= -1
adata.obsm["spatial"] = spatial_flipped

# Plot the ground truth with a custom figure size
fig, ax = plt.subplots(figsize=(6, 8))  # width, height in inches (6 wide, 8 tall)
sc.pl.embedding(
    adata,
    basis="spatial",
    color="domain_annotation",
    size=60,
    title="Ground Truth",
    ax=ax,
    show=False,
)
plt.show()

# Plot the Leiden clusters with a custom figure size
# (metric values in the title are hard-coded)
fig, ax = plt.subplots(figsize=(6, 8))  # same size as the ground-truth figure
sc.pl.embedding(
    adata,
    basis="spatial",
    color="leiden",
    size=60,
    title="NMI: 0.58, HOM: 0.57, COM: 0.60",
    ax=ax,
    show=False,
)
plt.show()

In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import torch
from torch_geometric.data import Data
from tqdm import tqdm

# Collect the Euclidean length of every graph edge across all saved graphs in
# `graph_dir` and plot their distribution.
graph_dir = "../data/domain/processed_123/"
edge_distances = []  # one 1-D tensor of edge lengths per graph

# sorted() makes the traversal order independent of the filesystem.
for fname in tqdm(sorted(os.listdir(graph_dir))):
    if not fname.endswith("graph.pt"):
        continue
    path = os.path.join(graph_dir, fname)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load graph files from a trusted source.
    data: Data = torch.load(path, weights_only=False)

    # Skip graphs without node positions or with fewer than two nodes.
    if hasattr(data, "position") and data.position is not None and data.position.size(0) > 1:
        src, dst = data.edge_index
        dists = torch.norm(data.position[src] - data.position[dst], dim=1)
        # Keep strictly positive lengths only (drops self-loops and edges
        # between nodes at identical coordinates).
        edge_distances.append(dists[dists > 0])

if not edge_distances:
    # Without this guard torch.cat fails on an empty list with a less helpful message.
    raise RuntimeError(f"No graph files with node positions found in {graph_dir!r}")

all_distances = torch.cat(edge_distances).cpu().numpy()

plt.hist(all_distances, bins=100, density=True)
plt.title("Domain123: Distribution of Intra-Graph Neighbor Distances")
plt.xlabel("Distance")
plt.ylabel("Density")
plt.grid(True)
plt.show()

print(
    f"Mean: {all_distances.mean():.4f}, Median: {np.median(all_distances):.4f}, Std: {all_distances.std():.4f}"
)

In [None]:
# For every domain annotation, pool the node features of all graphs and report
# the per-feature standard deviation (averaged over features).
graph_dir = "../data/domain/processed_123/"
adata_dir = graph_dir
graph_suffix = "_graph.pt"

domain_features = defaultdict(list)  # domain label -> list of (n_cells, ...) feature tensors
stds = []

# sorted() makes the traversal (and therefore the print order) reproducible.
for fname in sorted(os.listdir(graph_dir)):
    if not fname.endswith(graph_suffix):
        continue
    # Strip only the trailing suffix (str.replace would also remove an
    # occurrence in the middle of the file name).
    base = fname[: -len(graph_suffix)]
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load graph files from a trusted source.
    graph = torch.load(os.path.join(graph_dir, fname), weights_only=False)
    adata = sc.read_h5ad(os.path.join(adata_dir, base + ".h5ad"))

    # NOTE(review): assumes graph nodes are ordered like the rows of adata.obs — confirm.
    labels = list(adata.obs["domain_annotation"])
    if len(labels) != graph.x.size(0):
        raise ValueError(
            f"{base}: {len(labels)} domain labels but {graph.x.size(0)} graph nodes"
        )
    graph.y = labels

    # Group row indices by domain first, then slice the feature matrix once
    # per domain instead of once per cell.
    rows_by_domain = defaultdict(list)
    for row, domain in enumerate(labels):
        rows_by_domain[domain].append(row)
    for domain, rows in rows_by_domain.items():
        domain_features[domain].append(graph.x[rows])

if not domain_features:
    # Without this guard torch.stack below fails on an empty list.
    raise RuntimeError(f"No '*{graph_suffix}' files found in {graph_dir!r}")

for domain, feats in domain_features.items():
    all_feats = torch.cat(feats, dim=0)
    # torch's default std applies Bessel's correction, so a domain with a
    # single cell yields NaN here (and makes the overall mean NaN).
    std = all_feats.std(dim=0)
    stds.append(std)
    print(f"Domain '{domain}': mean std = {std.mean().item():.4f}")
print(f"Overall mean std: {torch.mean(torch.stack(stds)).item():.4f}")

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Read the CSV metrics log (csv/version_0) of the selected
# phenotype_nsclc_train_baseline run; an alternative run is kept commented out.
# Judging by the variable name this is the pretraining stage's log.
pretrain_metrics = pd.read_csv(
    # "../logs/phenotype_nsclc_train_baseline/runs/2025-05-12_21-45-49/csv/version_0/metrics.csv"
    "../logs/phenotype_nsclc_train_baseline/runs/2025-05-13_14-14-38/csv/version_0/metrics.csv"
)
# Bare expression: display the raw metrics frame as the cell output.
pretrain_metrics

In [None]:
# Keep only the rows that carry an epoch-level training loss, restricted to the
# columns used for the training curves. Note that this rebinds `pretrain_metrics`.
pretrain_metrics = pretrain_metrics.loc[
    pretrain_metrics["train/loss_epoch"].notna(),
    ["epoch", "step", "train/loss_epoch", "train/lr_epoch"],
]
pretrain_metrics

In [None]:
# Training loss per epoch; labels are set on the axes returned by seaborn.
sns.set(style="whitegrid")
ax = sns.lineplot(data=pretrain_metrics, x="epoch", y="train/loss_epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Pretraining Loss vs. Epoch")
plt.show()

In [None]:
# Learning rate per epoch; labels are set on the axes returned by seaborn.
sns.set(style="whitegrid")
ax = sns.lineplot(data=pretrain_metrics, x="epoch", y="train/lr_epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("LR")
ax.set_title("Pretraining LR vs. Epoch")
plt.show()

In [None]:
# Read the CSV metrics log (csv/version_1) of the selected
# phenotype_nsclc_train_baseline run; an alternative run is kept commented out.
# Judging by the variable name this is the finetuning stage's log.
finetune_metrics = pd.read_csv(
    # "../logs/phenotype_nsclc_train_baseline/runs/2025-05-12_21-45-49/csv/version_1/metrics.csv"
    "../logs/phenotype_nsclc_train_baseline/runs/2025-05-13_15-13-30/csv/version_1/metrics.csv"
)
# Bare expression: display the raw metrics frame as the cell output.
finetune_metrics

In [None]:
# Training view of the finetuning log: rows that carry an epoch-level training
# loss, restricted to the columns used for the training curves.
finetune_metrics_train = finetune_metrics.loc[
    finetune_metrics["train/loss_epoch"].notna(),
    ["epoch", "step", "train/loss_epoch", "train/lr_epoch"],
]
finetune_metrics_train

In [None]:
# Training loss per epoch; labels are set on the axes returned by seaborn.
sns.set(style="whitegrid")
ax = sns.lineplot(data=finetune_metrics_train, x="epoch", y="train/loss_epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Finetuning Loss vs. Epoch")
plt.show()

In [None]:
# Learning rate per epoch during finetuning.
# Uses the filtered training view (`finetune_metrics_train`), like the loss
# plot and the pretraining LR plot, rather than the raw log, whose rows for
# other metrics carry no "train/lr_epoch" value.
sns.set(style="whitegrid")
sns.lineplot(data=finetune_metrics_train, x="epoch", y="train/lr_epoch")
plt.xlabel("Epoch")
plt.ylabel("LR")
plt.title("Finetuning LR vs. Epoch")
plt.show()

In [None]:
# Validation view of the finetuning log: rows that carry a validation loss,
# restricted to the validation metric columns.
finetune_metrics_val = finetune_metrics.loc[
    finetune_metrics["val/loss"].notna(),
    ["epoch", "step", "val/loss", "val/f1", "val/accuracy", "val/precision", "val/recall"],
]
finetune_metrics_val

In [None]:
# Validation loss per epoch; labels are set on the axes returned by seaborn.
sns.set(style="whitegrid")
ax = sns.lineplot(data=finetune_metrics_val, x="epoch", y="val/loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Finetuning Val Loss vs. Epoch")
plt.show()

In [None]:
sns.set(style="whitegrid")

# (column, legend label) pairs, drawn in this order onto the same axes.
val_curves = [
    ("val/loss", "Val Loss"),
    ("val/f1", "Val F1"),
    ("val/accuracy", "Val Accuracy"),
    ("val/precision", "Val Precision"),
    ("val/recall", "Val Recall"),
]
for metric_column, legend_label in val_curves:
    sns.lineplot(data=finetune_metrics_val, x="epoch", y=metric_column, label=legend_label)

plt.xlabel("Epoch")
plt.ylabel("Metrics")
plt.title("Finetuning Metrics vs. Epoch")
plt.legend()
plt.show()