In [None]:
import hvplot.polars  # noqa
import hydra
import lightning
import numpy as np
import pandas as pd
import polars as pl
import rootutils
import torch
from rich import print

# Locate the project root (searching from the parent directory) and put it on sys.path,
# so the `src.*` imports further down resolve.
rootutils.setup_root("../", pythonpath=True)
# Allow reduced-precision (e.g. TF32) kernels for float32 matmuls where the hardware supports it.
torch.set_float32_matmul_precision("high")

In [None]:
# setting up paths
# Read the key/value pairs of the project's .env file into `paths` (python-dotenv locates the
# file via find_dotenv() when no path is given). Values such as RUN_DIR are taken from it.
from pathlib import Path

from dotenv import dotenv_values

paths = dotenv_values()

In [None]:
# Root directory of training runs, taken from RUN_DIR in the .env file.
# dotenv_values() maps a bare key (no "=") to None and an empty value to "", which would
# otherwise crash in Path() or silently resolve to the current directory; fail loudly instead.
run_dir_value = paths.get("RUN_DIR")
if not run_dir_value:
    raise KeyError("RUN_DIR is missing or empty in the .env file read by dotenv_values()")
run_dir = Path(run_dir_value)
run_dir

In [None]:
from src.data.vcc_embedding_module import VCCDataModule
from src.models.vcc_lightning import VCCModule

In [None]:
# Compose the "train" config with the attention model and dataset selected.
# NOTE(review): 18080 is presumably the number of KO/gene input features — confirm against the data.
train_overrides = [
    "model=worst_model_attention",
    "data=dataset",
    "model.net.ko_processor_args.input_size=18080",
]
with hydra.initialize(config_path="../config", version_base=None):
    conf = hydra.compose(config_name="train", overrides=train_overrides)

In [None]:
# Inspect the composed model section of the config.
print(conf.model)

In [None]:
# Build the datamodule from the config. setup() is NOT called here (the line below is
# commented out); a later cell calls `datamodule.setup()` explicitly.
# NOTE(review): confirm which dataloaders/attributes are usable before setup() has run.
datamodule: VCCDataModule = hydra.utils.instantiate(conf.data.datamodule)
# datamodule.setup(stage='fit')

In [None]:
# Build the network from the config, then restore the Lightning module from a specific run,
# passing the freshly built `net` in as a constructor argument.
# NOTE(review): the run sub-path is hard-coded; change it to evaluate a different run.
net = hydra.utils.instantiate(conf.model.net)
checkpoint_path = run_dir / "VCC_07_08_2025/20-41/last.ckpt"
model = VCCModule.load_from_checkpoint(checkpoint_path, net=net)

In [None]:
# Switch to evaluation mode (affects layers such as dropout/batch norm); the cell shows the module repr.
model.eval()

In [None]:
# Trainer built from the config with max_steps=1, i.e. any fit() with it stops after one optimizer step.
trainer = hydra.utils.instantiate(conf.trainer, max_steps=1)

In [None]:
# Shape of the packed input-projection weight of the attention block.
# NOTE(review): presumably a torch.nn.MultiheadAttention, where this is (3 * embed_dim, embed_dim) — confirm.
model.net.attention.mhattention.in_proj_weight.shape

In [None]:
# Parameter summary of the model; max_depth=-1 expands every nested submodule.
from lightning.pytorch.utilities.model_summary import ModelSummary

summary = ModelSummary(model, max_depth=-1)
summary

In [None]:
# Forward-hook storage: maps each hooked module to a detached copy of its most recent output.
activations = {}
# Backward-compatible alias for the original (misspelled) name that later cells display.
acitvations = activations
# Reserved for weight inspection; not populated anywhere in this notebook.
weights = {}


def _detach_output(value):
    """Return ``value`` with every tensor detached, recursing into tuples, lists and dicts.

    Modules such as ``torch.nn.MultiheadAttention`` return a tuple (output, attention
    weights), so the hook cannot assume its output is a single tensor.
    """
    if isinstance(value, torch.Tensor):
        return value.detach()
    if isinstance(value, tuple):
        return tuple(_detach_output(item) for item in value)
    if isinstance(value, list):
        return [_detach_output(item) for item in value]
    if isinstance(value, dict):
        return {key: _detach_output(item) for key, item in value.items()}
    return value


def hook_fn(module, input, output):
    """Forward hook that records ``module``'s output, detached from the autograd graph.

    The previous per-call debug print is gone: the hook stays attached during later
    forward passes, and printing on every call floods the output.

    Args:
        module: the module the hook is registered on; used as the key in ``activations``.
        input: positional inputs passed to ``module.forward`` (unused).
        output: whatever ``module.forward`` returned — a tensor or a (nested) tuple/list/dict.

    Returns:
        None, so the module's output is passed through unchanged.
    """
    activations[module] = _detach_output(output)

In [None]:
# Attach the capture hook to the attention sub-module. If a handle from an earlier run of
# this cell exists, remove it first so re-running does not stack duplicate hooks.
if "hook" in globals():
    hook.remove()
target_layer = model.net.attention
hook = target_layer.register_forward_hook(hook_fn)

In [None]:
# Peek at the first entry of named_modules(): the root module itself, under the empty name "".
name, module = next(model.named_modules())
print(name, module)

In [None]:
# Take a single batch from the predict loader for the forward-pass inspection below.
# Fail here with a clear message, rather than with a NameError later, if the loader is empty.
batch = next(iter(datamodule.predict_dataloader()), None)
assert batch is not None, "predict_dataloader() yielded no batches"

In [None]:
# The loader yields a sequence whose first element is the model input (a dict in the
# training loader further down; assumed the same here — TODO confirm). Unwrap only while
# `batch` is still a tuple/list, so re-running this cell does not index into the input again.
if isinstance(batch, (tuple, list)):
    batch = batch[0]

In [None]:
# Cast all floating-point parameters/buffers to float16 in place (the `;` hides the module repr).
# NOTE(review): this is permanent for `model`, which is later passed to trainer2.fit(); cast it
# back (e.g. model.float()) before training. Also confirm the batch dtype matches float16 weights.
model.to(torch.float16);

In [None]:
# Inference-only forward pass to trigger the forward hook. no_grad avoids building an
# autograd graph that is never used; the hook still fires and records the attention output.
with torch.no_grad():
    model_output = model(batch)
model_output

In [None]:
# Inspect what the forward hook captured, keyed by the hooked module.
acitvations

In [None]:
# Activations are captured; detach the hook so it does not keep firing during the
# predictions and the training run below.
hook.remove()

# Studying performance

In [None]:
# Collect model predictions over the predict split. trainer.predict returns one entry per
# batch; this assumes VCCModule.predict_step returns a tensor shaped (batch, genes) —
# TODO confirm against the module before relying on the concatenation.
predictions = trainer.predict(model, datamodule=datamodule)
y_pred = torch.cat(predictions)
y_pred

In [None]:
# Back-transform predictions to linear scale, assuming they are log1p-transformed (TODO confirm).
# expm1 is the accurate inverse of log1p near zero. Stay in float32: in float16, linear-scale
# values, and especially their variance (squared magnitudes), overflow past ~65504 to inf.
y_pred_linear = torch.expm1(y_pred.float())
y_pred_linear

In [None]:
# Variance of each output column across predictions, sorted ascending (values + indices).
per_column_var = y_pred_linear.var(dim=0)
torch.sort(per_column_var)

In [None]:
# Mean of each output column across predictions (same reduction axis as the variance above).
torch.mean(y_pred_linear, dim=0)

In [None]:
# `test_data` is presumably created by datamodule.setup() (TODO confirm). Run setup first
# when it is missing, so this cell also works on a fresh kernel; the hasattr guard avoids
# re-running setup on every execution.
if not hasattr(datamodule, "test_data"):
    datamodule.setup()
datamodule.test_data.perturbed_genes

In [None]:
# Set up the datamodule (no stage argument) before drawing training batches below.
datamodule.setup()

In [None]:
# Per-feature variance of the KO and expression vectors in one training batch.
# Uses its own name so the predict-time `batch` from the cells above is not overwritten.
train_batch = next(iter(datamodule.train_dataloader()), None)
assert train_batch is not None, "train_dataloader() yielded no batches"
ko_vec_var = train_batch[0]["ko_vec"].var(dim=0)
exp_vec_var = train_batch[0]["exp_vec"].var(dim=0)

print(f"KO var: {ko_vec_var}, Expression var:{exp_vec_var}")

In [None]:
# Gradient norm of every parameter that has one; prints nothing until a backward pass
# has populated `.grad`.
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    print(f"{name}: {param.grad.data.norm()}")

In [None]:
from lightning.pytorch.utilities.model_summary import ModelSummary

# Re-summarise the model with every nested submodule expanded, printed as plain text.
summary = ModelSummary(model, max_depth=-1)
print(summary)

# Try overfitting

In [None]:
# from lightning.pytorch import Trainer
# from lightning.pytorch.utilities import grad_norm

In [None]:
# def track_grad(optimizer=None):
#     print("Self:", self)
#     print("optimizer:", optimizer)
#     self.log_dict(grad_norm(self, norm_type=2))

# model.on_before_optimizer_step = track_grad

In [None]:
# Inspect the logging section of the config.
conf.logging

In [None]:
# Build the W&B logger from the config, tagging the run "debug".
wandblogger = hydra.utils.instantiate(conf.logging.wandb, tags=["debug"])

In [None]:
# Display the logger object.
wandblogger

In [None]:
# Trainer for the overfitting run: log to W&B, train 30 epochs, write no checkpoint files.
trainer2_overrides = {
    "logger": wandblogger,
    "max_epochs": 30,
    "enable_checkpointing": False,
}
trainer2 = hydra.utils.instantiate(conf.trainer, **trainer2_overrides)
trainer2

In [None]:
# Earlier inspection cells cast `model` to float16 and attached a debug forward hook.
# Before training, restore float32 weights (training precision is then governed by the
# trainer config) and detach the hook. RemovableHandle.remove() is a no-op if already removed.
hook.remove()
model.float()
trainer2.fit(model, datamodule)