In [None]:
import hvplot.polars  # noqa
import hydra
import numpy as np
import pandas as pd
import polars as pl
import rootutils
import torch
from rich import print

# Locate the project root (searching from ../) and put it on the import path so the
# `src.*` imports further down resolve from inside the notebooks directory.
rootutils.setup_root("../", pythonpath=True)
# Allow PyTorch to trade some float32 matmul precision for speed ("high" mode).
torch.set_float32_matmul_precision("high")

In [None]:
# Read machine-specific locations from a .env file into a dict; the keys used
# below are RUN_DIR (training runs / checkpoints) and DATA_DIR (input data).
from pathlib import Path

from dotenv import dotenv_values

# NOTE(review): no explicit path is given, so python-dotenv locates the .env file
# itself — confirm it resolves to the project's .env when run from this notebook.
paths = dotenv_values()

In [None]:
# Directory containing training runs; the model checkpoint is loaded from below it.
run_dir = Path(paths["RUN_DIR"])
# Fail fast on a wrong .env entry instead of deep inside checkpoint loading.
assert run_dir.is_dir(), f"RUN_DIR does not point to an existing directory: {run_dir}"
run_dir

In [None]:
from src.data.vcc_datamodule import VCCDataModule
from src.models.vcc_lightning import VCCModule

In [None]:
# Compose the "train" config from ../config, overriding the model group to the
# attention variant.
with hydra.initialize(config_path="../config", version_base=None):
    conf = hydra.compose(config_name="train", overrides=["model=model_attention"])

In [None]:
# Inspect the fully composed config (uses rich's print, imported at the top).
print(conf)

In [None]:
# Build the datamodule object described by the `data.datamodule` config node.
datamodule = hydra.utils.instantiate(config=conf.data.datamodule)
datamodule

In [None]:
# The network architecture comes from the composed config; the weights come from the checkpoint.
net = hydra.utils.instantiate(conf.model.net)
ckpt_path = run_dir / "VCC_02_08_2025/18-40/last.ckpt"
assert ckpt_path.is_file(), f"Checkpoint not found: {ckpt_path}"
# map_location="cpu" lets a checkpoint saved on GPU load on any machine; the Lightning
# trainer is expected to move the model to its configured accelerator for prediction.
model = VCCModule.load_from_checkpoint(ckpt_path, net=net, map_location="cpu")

In [None]:
# Disable training-mode behaviour (dropout etc.); nn.Module.eval() is defined as train(False).
model.train(False)

In [None]:
# Lightning Trainer configured by the `trainer` node of the composed config.
trainer = hydra.utils.instantiate(config=conf.trainer)

In [None]:
# Run inference using the datamodule's predict dataloader; passing it by keyword makes
# it explicit that this is a datamodule rather than a dataloader.
prediction = trainer.predict(model=model, datamodule=datamodule)

In [None]:
# Number of entries returned by trainer.predict (presumably one per predict batch — confirm).
len(prediction)

# Trying out cell-eval

In [None]:
# Cache the stacked predictions so the expensive steps above don't have to be re-run.
X = np.vstack(prediction).astype(np.float32)
print(X.shape)
preds_path = Path("../results/preds.npz")
# np.savez does not create missing directories.
preds_path.parent.mkdir(parents=True, exist_ok=True)
# Passed positionally, the array is stored under numpy's default key "arr_0".
np.savez(preds_path, X)

In [None]:
# Reload the cached predictions (this overwrites any in-memory X), so the steps below
# can start here without re-running inference. The `with` block closes the .npz file
# handle that np.load keeps open; indexing reads the array fully into memory first.
with np.load("../results/preds.npz") as cached:
    X = cached["arr_0"]
X

In [None]:
# Root data directory (from DATA_DIR in .env); fail fast on a wrong .env entry.
data_path = Path(paths["DATA_DIR"])
assert data_path.is_dir(), f"DATA_DIR does not point to an existing directory: {data_path}"

In [None]:
# Per-perturbation cell counts for the validation set (target_gene, n_cells), used
# further down to build one annotation row per predicted cell.
validation_csv = data_path / "processed-data/pert_counts_Validation.csv"
validation_data = pl.read_csv(validation_csv)
validation_data

In [None]:
# Gene names = the control parquet's columns minus its index column, read from the
# schema without loading any data. The control-matrix cell drops the column named
# "sample_index", so verify that it really is the (first) column dropped here.
# Otherwise the var names would silently misalign with the expression columns.
control_columns = pl.scan_parquet(conf.data.datamodule.control_data_path).collect_schema().names()
assert control_columns[0] == "sample_index", (
    f"expected 'sample_index' as the first column, got {control_columns[0]!r}"
)
gene_names = control_columns[1:]
gene_names[:10]

In [None]:
# One row per predicted cell, one column per gene. orient="row" is stated explicitly
# instead of letting polars infer the orientation from the array shape.
pred_df = pl.DataFrame(X, schema=gene_names, orient="row")
pred_df

In [None]:
# Histogram of the predicted values for two example genes.
pred_df.hvplot(kind="hist", y=["SAMD11", "CCNL2"], bins=50)

In [None]:
# Gene list for the evaluation tooling: one bare name per line, no header, in the same
# order as the prediction columns. No trailing comma, so CSV readers see a single
# column instead of an extra empty one.
gene_names_path = Path("../results/gene_names.csv")
gene_names_path.parent.mkdir(parents=True, exist_ok=True)
with open(gene_names_path, "w", encoding="utf-8") as f:
    f.writelines(f"{gene}\n" for gene in gene_names)

In [None]:
# Row annotations for the predictions: each perturbation's target_gene repeated
# n_cells times, in validation-table order.
# NOTE(review): assumes the predict dataloader yields cells in this same order — confirm.
target_genes = (
    validation_data
    # explode() turns an empty repeat (n_cells == 0) into a null row, so drop those first.
    .filter(pl.col("n_cells") > 0)
    .select(pl.col("target_gene").repeat_by("n_cells").explode())
    .to_pandas()
)
# Annotations and predicted expression rows must match one-to-one.
assert len(target_genes) == X.shape[0], (
    f"{len(target_genes)} annotations vs {X.shape[0]} predicted cells"
)

In [None]:
# Control ("non-targeting") cells: read the control matrix and label every row.
nt_frame = pl.read_parquet(conf.data.datamodule.control_data_path)
nt_target_genes = (
    nt_frame.with_columns(pl.lit("non-targeting").alias("target_gene"))
    .select("target_gene")
    .to_pandas()
)
# Select the expression columns by gene_names so their order is guaranteed to match
# the prediction columns / AnnData var names (raises if a gene is missing).
nt_data = nt_frame.select(gene_names).to_numpy().astype(np.float32)

In [None]:
# Stack predicted cells above the control cells, with annotations in the same order.
# NOTE: this rebinds X / target_genes, so running this cell twice appends the controls
# again — re-run from the cell that reloads the cached predictions instead.
X = np.vstack([X.astype(np.float32, copy=False), nt_data])
# ignore_index: both frames come from to_pandas() with a 0-based RangeIndex, so keeping
# their labels would produce duplicate row labels (non-unique obs_names in the AnnData).
target_genes = pd.concat([target_genes, nt_target_genes], ignore_index=True)
assert X.shape[0] == len(target_genes), "expression rows and annotations are misaligned"

target_genes

In [None]:
import anndata as ad

In [None]:
# Assemble the AnnData: rows (obs) = cells, with their target_gene label; columns (var) = genes.
obs_annotations = target_genes.reset_index(drop=True)
# Give AnnData unique string obs_names explicitly, rather than relying on its implicit
# int -> str index conversion (which also keeps any duplicate labels).
obs_annotations.index = obs_annotations.index.astype(str)
result = ad.AnnData(
    X=X.astype(np.float32, copy=False),  # store as float32; no copy if X already is
    var=pd.DataFrame(index=gene_names),
    obs=obs_annotations,
)
result

In [None]:
# Persist the combined predictions + controls for downstream evaluation (cell-eval).
output_file = "../results/prediction_020825.h5ad"
result.write_h5ad(output_file)