In [None]:
import hvplot.polars  # noqa
import hydra
import numpy as np
import pandas as pd
import polars as pl
import rootutils
import torch
from rich import print
from tqdm import tqdm

# Select PyTorch's "high" float32 matmul precision mode (permits faster,
# reduced-precision matmul kernels than the default "highest").
torch.set_float32_matmul_precision("high")
# Set up the project root starting from the parent directory and put it on the
# Python path so the `src.*` imports further down resolve.
rootutils.setup_root("../", pythonpath=True)

In [None]:
# Setting up paths: read the key/value pairs of the local .env file.
# The RUN_DIR and DATA_DIR keys are used by later cells.
from pathlib import Path

from dotenv import dotenv_values

paths = dotenv_values()

In [None]:
# Show the loaded mapping.
# NOTE(review): if the .env file ever holds secrets, this output is stored in
# the notebook — clear it before sharing.
paths

In [None]:
# Directory holding the training runs; checkpoints and saved predictions are
# read relative to it below.
run_dir = Path(paths["RUN_DIR"])
run_dir

In [None]:
from src.data.vcc_embedding_module import VCCDataModule

In [None]:
# Build the gene-embedding path up front so the override list stays readable.
gene_embedding_path = paths["DATA_DIR"] + "/gene_embeddings/poincare_go_gaf_logmapped_256.parquet"
overrides = [
    "model=projection_model_base",
    "data=dataset_embedding",
    "trainer.devices=[3]",
    "trainer.precision=32",
    f"data.datamodule.gene_embedding_path={gene_embedding_path}",
]

# Compose the `train` config from ../config with the overrides above.
with hydra.initialize("../config", version_base=None):
    conf = hydra.compose("train", overrides)

In [None]:
# Pretty-print the composed config (`print` is rich.print, imported in the first cell).
print(conf)

In [None]:
# Instantiate the data module described by the `data.datamodule` config node.
datamodule: VCCDataModule = hydra.utils.instantiate(conf.data.datamodule)
datamodule

In [None]:
# Alternative module, kept for reference:
# from src.models.vcc_lightning import VCCModule
from src.models.projectionvcc_lightning import VCCModule

# Restore the trained module from the `last.ckpt` checkpoint of the
# VCC_31_08_2025/00-52-18 run. Building the net from the config is not needed:
# net = hydra.utils.instantiate(conf.model.net)
model = VCCModule.load_from_checkpoint(run_dir / "VCC_31_08_2025/00-52-18/last.ckpt")

In [None]:
# Put the module in evaluation mode before predicting.
model.eval()

In [None]:
# Inspect the `gene_hyperbolic` attribute of the model's contrastive loss.
model.contrastive_loss.gene_hyperbolic

In [None]:
# Compile the module in place (presumably forwards to torch.compile — confirm
# for VCCModule). fullgraph=True asks for a single graph with no graph breaks;
# "reduce-overhead" is the torch.compile mode aimed at cutting per-call overhead.
model.compile(fullgraph=True, mode="reduce-overhead")

In [None]:
# Build the trainer from the `trainer` config node (devices and precision were
# overridden when composing the config).
trainer = hydra.utils.instantiate(conf.trainer)

In [None]:
# Run prediction; the result is indexed below as a list with one entry per
# batch, each entry holding (prediction, latent).
prediction = trainer.predict(model, datamodule)

In [None]:
# Number of entries returned by trainer.predict.
len(prediction)

In [None]:
# Peek at the first element of the first entry (the prediction part, see the split below).
prediction[0][0]

In [None]:
# Each entry is indexed as (prediction, latent); split them into two lists.
# NOTE: `prediction` is rebound here, so this cell must not be re-run without
# re-running trainer.predict first.
batch_predictions, latent = [], []
for batch_output in prediction:
    latent.append(batch_output[1])
    batch_predictions.append(batch_output[0])
prediction = batch_predictions
latent[:2]

# Studying performance

In [None]:
# Run the datamodule's "predict" setup explicitly; later cells read `datamodule.test_data`.
# NOTE(review): trainer.predict above has presumably already triggered this —
# confirm the second call is needed.
datamodule.setup("predict")

In [None]:
# Concatenate the per-batch latents into a single tensor.
# NOTE: `latent` is rebound from a list to a tensor, so re-running this cell
# alone no longer receives the list of batches it expects.
latent = torch.cat(latent)

In [None]:
# Load predictions saved by an earlier run (VCC_29_08_2025/20-40-00).
# NOTE(review): the next cell immediately rebinds `y_pred` from the in-memory
# `prediction` list, so this value is discarded when running top to bottom.
# torch.load unpickles the file — only use it on files you trust.
y_pred = torch.load(run_dir / "VCC_29_08_2025/20-40-00/predictions.pt")

In [None]:
# Concatenate the per-batch predictions (along dim 0) into one float32 tensor.
y_pred = torch.cat(prediction).float()
y_pred

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Distinct perturbed genes in the test data and how many rows each one has.
genes, counts = np.unique(datamodule.test_data.perturbed_genes, return_counts=True)
genes, counts

In [None]:
# Median number of rows per perturbed gene.
np.median(counts)

In [None]:
# Correlation matrix over the rows of the gene-embedding table
# (torch.corrcoef treats each row as a variable).
torch.from_numpy(datamodule.test_data.gene_embeddings.values).corrcoef()

In [None]:
# Same idea restricted to the perturbed genes: select their embedding rows,
# transpose so each gene is a column, and take the pairwise column correlation.
datamodule.test_data.gene_embeddings.loc[genes].T.corr().to_numpy()

In [None]:
from umap import UMAP

# 2-D UMAP of the predictions. A fixed seed keeps the layout reproducible
# across kernel restarts (umap-learn runs single-threaded when random_state is set).
embedder = UMAP(n_neighbors=15, random_state=42)
# Hand UMAP a plain CPU NumPy array, matching the `.cpu()` used for the
# sklearn clustering metrics later in the notebook.
embedding = embedder.fit_transform(y_pred.cpu().numpy())

In [None]:
# UMAP scatter of the predictions, coloured by perturbed gene, with each gene's
# mean position annotated by name.
fig, ax = plt.subplots(figsize=(9, 8))
embedding_df = pl.DataFrame(
    data={
        "umap-0": embedding[:, 0],
        "umap-1": embedding[:, 1],
        "gene": datamodule.test_data.perturbed_genes,
    }
).sort("gene")
# Mean UMAP coordinates per gene: anchor points for the text labels.
avg_data = embedding_df.group_by("gene").agg(pl.all().mean())

sns.scatterplot(
    x=embedding[:, 0],
    y=embedding[:, 1],
    hue=datamodule.test_data.perturbed_genes,
    ax=ax,
    alpha=0.7,
)

# Alphabetical legend. Sort on the label only: falling through to the artist
# handles on a label tie would raise TypeError (handles are not orderable).
handles, labels = ax.get_legend_handles_labels()
sorted_pairs = sorted(zip(labels, handles), key=lambda pair: pair[0])
sorted_labels = [label for label, _ in sorted_pairs]
sorted_handles = [handle for _, handle in sorted_pairs]
ax.legend(sorted_handles, sorted_labels, bbox_to_anchor=(1, 1), ncol=3)

for row in avg_data.iter_rows(named=True):
    ax.text(x=row["umap-0"], y=row["umap-1"], s=row["gene"], ha="center", va="center")

ax.set_xlabel("umap-0")
ax.set_ylabel("umap-1")
ax.set_title("Projection UMAP", size=12);

In [None]:
# Map the predictions back to linear scale, i.e. exp(x) - 1.
# torch.expm1 evaluates this without the precision loss that `exp(x) - 1`
# suffers for values close to zero.
y_pred_linear = torch.expm1(y_pred)
y_pred_linear

In [None]:
# Variance of each output column across all rows.
y_pred_linear.var(dim=0)

In [None]:
# Mean of each output column across all rows.
y_pred_linear.mean(dim=0)

In [None]:
# Perturbed-gene label of every test row (used as cluster labels below).
datamodule.test_data.perturbed_genes

In [None]:
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_score,
)

In [None]:
# Silhouette score of the predictions grouped by perturbed gene
# (range -1..1, higher means better-separated groups).
silhouette_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

In [None]:
# Davies-Bouldin index for the same grouping (lower is better, 0 is the minimum).
davies_bouldin_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

In [None]:
# Calinski-Harabasz score for the same grouping (higher is better).
calinski_harabasz_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

# Trying out cell-eval

In [None]:
# Shape of the concatenated prediction tensor.
torch.cat(prediction).shape

In [None]:
# Materialise the predictions as a float16 NumPy array (presumably to keep the
# cached file small). The np.savez call below is currently disabled, so nothing
# is written here; re-enable it to refresh the cache that the next cell loads
# and avoid repeating the previous steps.
X = torch.cat(prediction).to(torch.float16).numpy()
print(X.shape)
# np.savez("../results/preds.npz", X)

In [None]:
# Reload cached predictions from disk.
# Care: this overwrites the X computed in the previous cell, and the np.savez
# call that refreshes this file is commented out there, so the file can be
# stale relative to the model loaded in this session.
# np.load on an .npz returns a lazy NpzFile holding an open file handle; the
# context manager closes it once the array has been read into memory.
with np.load("../results/preds.npz") as npz_file:
    X = npz_file["arr_0"]
X

In [None]:
# Root directory of the input data, taken from the .env settings.
data_path = Path(paths["DATA_DIR"])

In [None]:
# Perturbation counts for the validation set; its `target_gene` and `n_cells`
# columns are used below to build the row annotations.
validation_data = pl.read_csv(data_path / "processed-data/pert_counts_Validation.csv")
validation_data

In [None]:
# Gene (column) names of the control matrix, read from the parquet schema only
# (no data is loaded). Drop the `sample_index` column by name — the same column
# the non-targeting cell further down excludes — instead of assuming it is the
# first column.
control_columns = pl.scan_parquet(conf.data.datamodule.control_data_path).collect_schema().names()
gene_names = [name for name in control_columns if name != "sample_index"]
gene_names[:10]

In [None]:
# Wrap the prediction matrix in a polars frame with one named column per gene,
# for inspection only.
# Assumes X has exactly len(gene_names) columns — TODO confirm.
pred_df = pl.DataFrame(X, schema=gene_names)
pred_df

In [None]:
# Row annotations for the predictions: every target gene repeated `n_cells`
# times and flattened into a single `target_gene` column.
per_cell_gene = pl.col("target_gene").repeat_by("n_cells").explode()
target_genes = validation_data.select(per_cell_gene).to_pandas()

In [None]:
# Processing non-targeting (control) data: load the table once, derive the
# annotation frame and the float32 expression matrix, then drop the polars frames.
control_df = pl.read_parquet(conf.data.datamodule.control_data_path)

# One "non-targeting" label per control row.
labelled = control_df.with_columns(target_gene=pl.lit("non-targeting"))
nt_target_genes = labelled.select("target_gene").to_pandas()

# Expression values without the `sample_index` column.
nt_data = control_df.select(pl.exclude("sample_index")).to_numpy().astype(np.float32)
del control_df, labelled

In [None]:
# Append the non-targeting control rows below the predicted rows and extend the
# row annotations to match.
# NOTE: this cell rebinds X and target_genes, so run it exactly once per
# session — a second run appends the controls again.
X = np.vstack([X.astype(np.float32), nt_data])
# ignore_index=True: both frames carry their own 0..n-1 index, so a plain
# concat would produce duplicate row labels, which end up as non-unique
# AnnData obs names.
target_genes = pd.concat([target_genes, nt_target_genes], ignore_index=True)
assert len(target_genes) == X.shape[0], "row annotations and expression matrix are out of sync"

target_genes

In [None]:
import anndata as ad

In [None]:
# Assemble the AnnData container: rows annotated by target gene (obs), columns
# named by gene (var).
gene_annotations = pd.DataFrame(index=gene_names)
expression = X.astype(np.float32)  # So scipy.sparse doesn't cry
result = ad.AnnData(X=expression, obs=target_genes, var=gene_annotations)
result

In [None]:
# Persist the assembled predictions plus controls as an .h5ad file.
# NOTE(review): the file name encodes 280825 while the checkpoint loaded above
# comes from the VCC_31_08_2025 run — confirm the intended name.
result.write_h5ad("../results/prediction_280825.h5ad")