In [None]:
import hvplot.polars  # noqa
import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import rootutils
import seaborn as sns
import torch
from rich import print
from tqdm import tqdm

# Let float32 matmuls use the faster reduced-precision kernels ("high" setting).
torch.set_float32_matmul_precision("high")
# Resolve the project root starting from the parent directory and put it on the
# Python path — presumably so the `src.*` imports in later cells resolve.
rootutils.setup_root("../", pythonpath=True)

In [None]:
# Set up paths: read the key/value pairs of the local dotenv file into a dict.
from pathlib import Path

from dotenv import dotenv_values

paths = dotenv_values()

In [None]:
# NOTE(review): this displays every dotenv entry; clear the output before
# sharing if the file also holds credentials.
paths

In [None]:
# Directory that holds the training runs (the `RUN_DIR` dotenv entry).
run_dir = Path(paths["RUN_DIR"])
run_dir

In [None]:
from src.data.vcc_embedding_module import VCCDataModule

In [None]:
# Compose the `train` config with the overrides used for this analysis.
gene_embedding_path = (
    paths["DATA_DIR"] + "/gene_embeddings/poincare_go_gaf_logmapped_256.parquet"
)
overrides = [
    "model=model_film_full",
    "data=dataset_embedding",
    # "trainer.devices=[3]",
    "trainer.precision=32",
    f"data.datamodule.gene_embedding_path={gene_embedding_path}",
]

with hydra.initialize(config_path="../config", version_base=None):
    conf = hydra.compose(config_name="train", overrides=overrides)

In [None]:
# Dump the composed configuration for inspection (`print` is rich's print here).
print(conf)

In [None]:
# Build the datamodule described by the `data.datamodule` config node.
datamodule: VCCDataModule = hydra.utils.instantiate(conf.data.datamodule)
datamodule

In [None]:
# Load the raw checkpoint dict of the run. `map_location="cpu"` makes the load
# independent of the device the run was trained on: without it the tensors are
# restored onto their original device, which fails when that device is missing
# and otherwise occupies accelerator memory for as long as `checkpoint` is alive.
# NOTE(review): `weights_only=False` unpickles arbitrary Python objects — only
# use it on checkpoints from a trusted source.
checkpoint = torch.load(
    run_dir / "VCC_17_09_2025/22-02-39/last.ckpt",
    map_location="cpu",
    weights_only=False,
)

In [None]:
# Instantiate the module described by the `model` config node.
model = hydra.utils.instantiate(conf.model)

In [None]:
# Copy the weights stored under the checkpoint's "state_dict" key into the model.
model.load_state_dict(checkpoint["state_dict"])

In [None]:
# from src.models.vcc_lightning import VCCModule
from src.models.projectionvcc_lightning import VCCModule

# NOTE(review): this rebinds `model`, discarding the instance built and loaded
# in the two previous cells, and it reads a checkpoint from a different run
# directory (VCC_11_09_2025/14-05-05 here, VCC_17_09_2025/22-02-39 for
# `checkpoint`) — confirm which model the analysis below is meant to use.
# net = hydra.utils.instantiate(conf.model.net)
model = VCCModule.load_from_checkpoint(run_dir / "VCC_11_09_2025/14-05-05/last.ckpt")

In [None]:
# Materialise all parameter tensors for a quick visual check.
[param for param in model.parameters()]

In [None]:
# Switch the module to evaluation mode before predicting.
model.eval()

In [None]:
# Compile the module; `fullgraph=True` requests a single graph without breaks.
model.compile(fullgraph=True, mode="reduce-overhead")

In [None]:
# Build the trainer described by the `trainer` config node.
trainer = hydra.utils.instantiate(conf.trainer)

In [None]:
# Run the predict loop; the result is indexed per batch in the cells below.
prediction = trainer.predict(model, datamodule)

In [None]:
# Number of prediction outputs returned.
len(prediction)

In [None]:
# First element of the first output (kept as the prediction in the next cell).
prediction[0][0]

In [None]:
# Split every predict output into its two parts: element 1 is collected in
# `latent`, element 0 replaces the raw output in `prediction`.
# NOTE: `prediction` is rebound here, so this cell must run exactly once per
# `trainer.predict` call — a second run would index into the already split items.
batch_latents, batch_predictions = [], []
for batch_output in prediction:
    batch_latents.append(batch_output[1])
    batch_predictions.append(batch_output[0])
latent, prediction = batch_latents, batch_predictions
latent[:2]

# Studying performance

In [None]:
# Load the predictions stored by an earlier run.
# NOTE(review): this value is overwritten two cells below by the predictions
# computed in this session — keep only the source that is actually intended.
y_pred = torch.load(run_dir / "VCC_07_09_2025/16-09-10/predictions.pt")

In [None]:
# Prepare the datamodule for the "predict" stage; later cells read
# `datamodule.test_data`.
datamodule.setup("predict")

In [None]:
# Concatenate the per-batch predictions into a single float32 tensor.
# NOTE(review): replaces the `y_pred` loaded from disk above.
y_pred = torch.cat(prediction).to(torch.float32)
y_pred

In [None]:
# Clamp negative predicted values to zero.
y_pred = torch.relu(y_pred)

In [None]:
# Distinct perturbed genes of the test data and the number of entries per gene.
genes, counts = np.unique(datamodule.test_data.perturbed_genes, return_counts=True)
genes, counts

In [None]:
# Median number of entries per perturbed gene.
np.median(counts)

In [None]:
# Standard deviation of every output column, taken across dim 0.
y_pred.std(dim=0)

In [None]:
from src.utils.umap_utilities import Turbo256, perform_umap, plot_output_plotly

In [None]:
# One colour per target gene: pick indices spread evenly over the whole palette,
# then shuffle them so neighbouring legend entries get contrasting colours.
# - Evenly spaced indices always stay inside the palette, whereas stepping by a
#   fixed `i * 5` runs past the end once there are enough genes.
# - A permutation never repeats a colour; `np.random.choice` samples with
#   replacement and could give several genes the same colour.
# - The fixed seed keeps the gene -> colour mapping stable across kernel restarts.
palette_idx = np.linspace(0, len(Turbo256) - 1, num=len(genes)).round().astype(int)
spread_palette = np.asarray(Turbo256)[np.random.default_rng(0).permutation(palette_idx)]

In [None]:
# UMAP of the predictions, labelled with the perturbed gene of each row.
# `perform_umap` / `plot_output_plotly` are project helpers; later cells read the
# columns `UMAP-0`, `UMAP-1` and `gene` from `embedding_df`.
embedding_df = perform_umap(y_pred, genes=datamodule.test_data.perturbed_genes, n_jobs=10)
fig = plot_output_plotly(embedding_df)
fig

In [None]:
# Same projection for the concatenated latent outputs of the model.
projection_df = perform_umap(
    torch.cat(latent),
    genes=datamodule.test_data.perturbed_genes,
)
plot_output_plotly(projection_df)

In [None]:
# For every row, the column indices of the 5 largest absolute differences
# between the prediction and the control expression.
topk_indices = (
    (y_pred - datamodule.test_data.control_expression)
    .abs()
    .topk(5)
    .indices.numpy()
    .tolist()
)
topk_indices[:10]

In [None]:
top_genes = []
genes_all = np.asarray(list(datamodule.test_data.gene_embeddings.keys()))

In [None]:
for _ in tqdm(topk_indices):
    top_genes.append(genes_all[_].tolist())

In [None]:
# One row per entry: its list of top genes and the gene that was perturbed.
top_genes_df = pl.DataFrame(
    {"top_genes": top_genes, "target_genes": datamodule.test_data.perturbed_genes}
)
top_genes_df

In [None]:
# Number of rows per target gene.
top_genes_df["target_genes"].value_counts()

In [None]:
# For every target gene, count how often each gene shows up in the per-row
# `top_genes` lists and express that count as a fraction of the target's rows.
processed_top_df = (
    # One row per (input row, top gene).
    top_genes_df.explode("top_genes")
    # Collect all top genes of a target gene into one list.
    .group_by("target_genes")
    .agg(pl.col("top_genes"))
    # Attach the number of input rows of each target gene as `num_samples`.
    .join(
        top_genes_df["target_genes"].value_counts().rename({"count": "num_samples"}),
        on="target_genes",
    )
    .explode("top_genes")
    .group_by("target_genes")
    # `top_genes` becomes a list of {top_genes, count} structs; `num_samples`
    # becomes a list that repeats the same per-target value.
    .agg(pl.col("top_genes").value_counts(), "num_samples")
    # One row per (target gene, distinct top gene).
    .explode("top_genes")
    # Split the struct into plain `top_genes` and `count` columns.
    .with_columns(
        pl.col("top_genes").struct.unnest(),
    )
    # The list mean recovers the single repeated `num_samples` value.
    .with_columns(
        fraction=pl.col("count") / pl.col("num_samples").list.mean(),
        num_samples=pl.col("num_samples").list.mean(),
    )
    .sort("fraction", descending=True)
    # .filter(pl.col('fraction') > 0.25)
)
processed_top_df

In [None]:
# Median of the `fraction` column.
processed_top_df["fraction"].quantile(0.5)

In [None]:
import altair as alt

# Use the VegaFusion data transformer and the "latimes" chart theme for all
# Altair charts below.
alt.data_transformers.enable("vegafusion")
alt.theme.enable("latimes")

In [None]:
# Histogram of `fraction` (log-scaled counts) with a vertical rule at its
# 95th percentile.
fraction_hist = (
    alt.Chart(processed_top_df)
    .mark_bar()
    .encode(
        alt.X("fraction", title="Fraction of samples").bin(maxbins=50),
        alt.Y("count()").scale(type="log"),
    )
)
p95_rule = (
    alt.Chart()
    .mark_rule()
    .encode(x=alt.X(datum=processed_top_df["fraction"].quantile(0.95)))
)
fraction_hist + p95_rule

In [None]:
# Keep the (target gene, top gene) pairs whose raw `count` lies above the 95th
# percentile of all counts.
# NOTE(review): the preceding histogram cell draws its rule at the 95th
# percentile of `fraction`, while this filter thresholds `count` — confirm
# which column is intended.
count_cutoff = pl.col("count").quantile(0.95)
filtered_top_genes = processed_top_df.filter(pl.col("count") > count_cutoff)
filtered_top_genes

In [None]:
# Linked view: a UMAP scatter (left) and a donut chart with the gene composition
# of the brushed region (right).
# NOTE(review): the 5000-row subsample is unseeded, so the picture changes
# between runs, and `sample` may fail if the frame has fewer rows than that.
small_df = embedding_df.sample(5000)

# Interactions: rectangular brush on the scatter, click on a donut slice,
# click on a legend entry.
brush = alt.selection_interval(zoom=False)
selector = alt.selection_point(fields=["gene"])
selector_legend = alt.selection_point(fields=["gene"], bind="legend")

# Points of the gene picked in the legend keep their colour and opacity;
# everything else is drawn in faint gray.
gene_color = alt.Color(
    "gene:N", scale=alt.Scale(range=spread_palette), title="Target gene"
).legend(columns=3, symbolLimit=50)
legend_color = alt.when(selector_legend).then(gene_color).otherwise(alt.value("lightgray"))
legend_opacity = alt.when(selector_legend).then(alt.value(0.91)).otherwise(alt.value(0.05))

scatter = (
    alt.Chart(small_df, title=alt.Title("UMAP plot", anchor="middle"))
    .mark_circle()
    .encode(
        alt.X("UMAP-0:Q"),
        alt.Y("UMAP-1:Q"),
        color=legend_color,
        opacity=legend_opacity,
        tooltip=["gene"],
    )
    .add_params(brush, selector_legend)
    # Show only the gene clicked in the donut chart.
    .transform_filter(selector)
)

composition = (
    alt.Chart(small_df, title=alt.Title("Composition", anchor="middle"))
    .mark_arc(innerRadius=70, outerRadius=170)
    .encode(theta="count(gene):Q", color="gene:N", tooltip=["gene"])
    # Count only the points inside the brushed region of the scatter.
    .transform_filter(brush)
    .add_params(selector)
)

scatter.properties(width=500, height=500) | composition.properties(width=500, height=500)

In [None]:
# Bubble chart of the filtered pairs: top gene on x, target gene on y (ordered
# by median fraction, descending), bubble size and colour driven by `fraction`.
top_gene_axis = alt.X("top_genes:N", axis=alt.Axis(labelAngle=45))
target_gene_axis = alt.Y("target_genes:O").sort(
    field="fraction", op="median", order="descending"
)

(
    alt.Chart(filtered_top_genes)
    .mark_circle()
    .encode(
        top_gene_axis,
        target_gene_axis,
        color=alt.Color("fraction:N").bin(maxbins=7),
        size=alt.Size("fraction:Q"),
        tooltip=["top_genes", "target_genes", "fraction"],
    )
    .properties(width=1200)
)

In [None]:
# Re-display the full table of per-target gene fractions.
processed_top_df

In [None]:
# Wide layout: one row per target gene, one column per top gene, the cell value
# is `fraction`; pairs that never occur are filled with 0.
matrix_df = processed_top_df.pivot(
    on="top_genes", values="fraction", index="target_genes"
).fill_null(0)
matrix_df

In [None]:
# Re-display the filtered pairs.
filtered_top_genes

In [None]:
# Per target gene: every column collected into lists, plus the median fraction
# of the group as `group_fraction`.
filtered_top_genes.group_by("target_genes").agg(
    pl.all(), group_fraction=pl.col("fraction").median()
)

In [None]:
# Heatmap of `fraction` for the filtered (target gene, top gene) pairs.
# Clicking a cell selects every cell with the same `target_genes` value.
# `empty=False` is the boolean replacement for the deprecated string form
# "none": while nothing is selected, no cell counts as selected.
row_selection = alt.selection_point(fields=["target_genes"], on="click", empty=False)

# Selected cells are painted black; all others use the binned viridis scale.
fraction_color = alt.Color("fraction:Q", bin=True).scale(scheme="viridis")
color_selection = (
    alt.when(row_selection).then(alt.value("black")).otherwise(fraction_color)
)

base = (
    alt.Chart(filtered_top_genes)
    .mark_rect()
    .encode(
        # Columns ordered by the median fraction of each top gene, descending.
        alt.X("top_genes", axis=alt.Axis(labelAngle=45)).sort(
            field="fraction", op="median", order="descending"
        ),
        alt.Y("target_genes"),
        color=color_selection,
        tooltip=["fraction"],
    )
    .add_params(row_selection)
)


base

In [None]:
# Attention weights stored on the network, averaged over dim 0.
model.net.last_attn_weights.mean(0)

In [None]:
# The same averaged weights rendered as an image with a colour bar.
plt.imshow(model.net.last_attn_weights.mean(0))
plt.colorbar()

In [None]:
# NOTE(review): duplicate of the display two cells above.
model.net.last_attn_weights.mean(0)

In [None]:
from sklearn.metrics import (
    calinski_harabasz_score,
    davies_bouldin_score,
    silhouette_score,
)

In [None]:
# Clustering-quality scores of the predictions, using the perturbed gene of
# each row as the cluster label.
silhouette_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

In [None]:
davies_bouldin_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

In [None]:
calinski_harabasz_score(y_pred.cpu(), datamodule.test_data.perturbed_genes)

In [None]:
# Tensor of the first named parameter of the `perturbation_effect` submodule.
[param for param in model.net.perturbation_effect.named_parameters()][0][1]

In [None]:
# Inspect the `b1` attribute of the network.
model.net.b1

# Trying out cell-eval

In [None]:
# Shape of the concatenated predictions.
torch.cat(prediction).shape

In [None]:
# Non-negative predictions as a float16 NumPy array, so the expensive steps
# above do not have to be repeated once the array is stored on disk.
# NOTE(review): float16 overflows to inf above 65504 and keeps about three
# significant digits — confirm this is acceptable for the value range.
X = torch.cat(prediction).relu().to(torch.float16).numpy()
print(X.shape)
# np.savez("../results/preds.npz", X)

In [None]:
# Careful: this replaces the `X` computed in the previous cell with the array
# stored on disk. The matching `np.savez` call above is commented out, so the
# file can be stale or missing.
X = np.load("../results/preds.npz")["arr_0"]  # overwrites the in-memory X
X

In [None]:
# Root directory of the data files (the `DATA_DIR` dotenv entry).
data_path = Path(paths["DATA_DIR"])

In [None]:
# Validation table; later cells read its `target_gene` and `n_cells` columns.
validation_data = pl.read_csv(data_path / "processed-data/pert_counts_Validation.csv")
validation_data

In [None]:
# Column names of the control expression table without the `sample_index`
# bookkeeping column. Only the parquet schema is read, not the data.
# The column is dropped by name rather than by position, the same way the
# non-targeting matrix is built from this file, so the gene names and the
# matrix columns stay aligned even if `sample_index` is not the first column.
control_columns = (
    pl.scan_parquet(conf.data.datamodule.control_data_path).collect_schema().names()
)
gene_names = [name for name in control_columns if name != "sample_index"]
gene_names[:10]

In [None]:
# Predictions as a frame with one column per gene name (display only; the
# cells below keep working with the array `X`).
pred_df = pl.DataFrame(X, schema=gene_names)
pred_df

In [None]:
# Row annotations: repeat every target gene `n_cells` times so there is one
# label per cell, then hand the single-column frame to pandas.
per_cell_target = pl.col("target_gene").repeat_by("n_cells").explode()
target_genes = validation_data.select(per_cell_target).to_pandas()

In [None]:
# Non-targeting (control) cells: one "non-targeting" label per row and the
# expression values as a float32 matrix.
nt_frame = pl.read_parquet(conf.data.datamodule.control_data_path)

nt_target_genes = (
    nt_frame.with_columns(target_gene=pl.lit("non-targeting"))
    .select("target_gene")
    .to_pandas()
)

# Every column except the `sample_index` bookkeeping column.
nt_data = nt_frame.select(pl.exclude("sample_index")).to_numpy().astype(np.float32)

In [None]:
# Append the non-targeting cells below the predicted cells and extend the row
# annotations accordingly.
# NOTE: `X` and `target_genes` are extended in place, so run this cell exactly
# once — every further run appends the control block again.
X = np.vstack([X.astype(np.float32), nt_data])
# `ignore_index=True` renumbers the rows; otherwise both frames keep their own
# 0..n-1 labels and the combined annotation table has duplicate row names.
target_genes = pd.concat([target_genes, nt_target_genes], ignore_index=True)
# Fail here, with a clear message, if matrix and annotations are out of step.
assert X.shape[0] == len(target_genes), (
    f"{X.shape[0]} expression rows but {len(target_genes)} row annotations"
)

target_genes

In [None]:
import anndata as ad

In [None]:
# Assemble the AnnData object: expression matrix, genes as variables and the
# target gene of every cell as the observation annotation.
result = ad.AnnData(
    X=X.astype(np.float32),  # float32 so that scipy.sparse accepts the matrix
    var=pd.DataFrame(index=gene_names),
    obs=target_genes,
)
result

In [None]:
# Sum of the predicted values over the last dimension, one value per row.
y_pred.sum(-1)

In [None]:
# Write the AnnData object to disk.
# NOTE(review): the output name is hard-coded with a date; re-running
# overwrites the file.
result.write_h5ad("../results/prediction_150925.h5ad")