# Evaluate ablation models on held-out genes

In [None]:
import numpy as np
import pandas as pd
import anndata
import os, sys

import torch
from tqdm import tqdm

# Make the project's own modules (read_hdf5, lightning) importable.
# NOTE(review): the directory is appended (not prepended) to sys.path, so an
# installed package with the same top-level name (e.g. PyPI `lightning`) would
# take precedence over the local module — confirm the intended one is imported.
sys.path.append('/code/decima/src/decima')
from read_hdf5 import HDF5Dataset, list_genes
from lightning import LightningModel
# Let float32 matmuls use faster reduced-precision internal math
# (PyTorch "medium" setting; see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision("medium")

## Paths

In [None]:
# Base directory; every input used below is resolved relative to it.
save_dir = "/gstore/data/resbioai/grelu/decima/20240823"
matrix_file, h5_file, ckpt_dir = (
    os.path.join(save_dir, fname)
    for fname in ("aggregated.h5ad", "data.h5", "lightning_logs")
)

## Load test data

In [None]:
# Load the expression matrix and keep only the genes whose var.dataset is 'test'.
ad = anndata.read_h5ad(matrix_file)
# .copy() makes the subset a standalone AnnData: on a view, every ad.X access
# re-slices the parent matrix, which is expensive inside the evaluation loops.
ad = ad[:, ad.var.dataset == 'test'].copy()

# The HDF5 'test' split must list genes in exactly the same order as ad.var;
# otherwise predictions and targets would be misaligned column-by-column.
hdf5_test_genes = list(list_genes(h5_file, key='test'))
assert hdf5_test_genes == ad.var_names.tolist(), (
    "Gene order of the HDF5 'test' split does not match ad.var_names"
)

In [None]:
# Test-split dataset built alongside `ad`. seq_len is 2**19 == 524288
# (presumably base pairs); max_seq_shift=0 presumably disables random
# shift augmentation for evaluation — confirm against HDF5Dataset.
ds = HDF5Dataset(
    h5_file=h5_file,
    key='test',
    ad=ad,
    max_seq_shift=0,
    seq_len=2 ** 19,
)

## Checkpoints

In [None]:
# Checkpoints to compare, keyed by ablation name, as
# (run directory under ckpt_dir, checkpoint file name). Deriving the paths
# from ckpt_dir keeps them consistent with save_dir instead of repeating the
# absolute prefix eleven times. Insertion order = evaluation/print order.
# NOTE(review): there is no 'weight_1e-4' entry — confirm whether 'original'
# corresponds to that setting or whether it was omitted on purpose.
_ckpt_runs = [
    ('original', 'kugrjb50', 'epoch=3-step=2920.ckpt'),
    ('no_borzoi', 'ojt3zkrg', 'epoch=13-step=10220.ckpt'),
    ('head_only', 'rvafp25k', 'epoch=11-step=8760.ckpt'),
    ('poisson_only', '8opbecd4', 'epoch=6-step=5110.ckpt'),
    ('weight_1', 'qab9xk0h', 'epoch=8-step=6570.ckpt'),
    ('weight_1e-1', 'iquyvr35', 'epoch=4-step=3650.ckpt'),
    ('weight_1e-2', 's3eaiu2q', 'epoch=14-step=10950.ckpt'),
    ('weight_1e-3', '7430cfec', 'epoch=7-step=5840.ckpt'),
    ('weight_1e-5', '8xnp7q4u', 'epoch=1-step=1460.ckpt'),
    ('weight_1e-6', '004n0rm4', 'epoch=2-step=2190.ckpt'),
    ('no_mask', '9c9pjysd', 'epoch=4-step=3650.ckpt'),
]
ckpts = {
    name: os.path.join(ckpt_dir, run_id, 'checkpoints', fname)
    for name, run_id, fname in _ckpt_runs
}

## Evaluate each ablation checkpoint on the test genes

In [None]:
# Score every checkpoint in `ckpts` on the test genes: mean Pearson r across
# pseudobulks (rows of ad) and across genes (columns of ad).
# Targets are densified once here; indexing `ad[i]` / `ad.X[:, i]` inside the
# comprehensions would build an AnnData view or re-slice the matrix per item.
true_expr = np.asarray(ad.X)

for k, v in ckpts.items():
    model = LightningModel.load_from_checkpoint(v).eval()
    # Transposed so rows/columns line up with ad (checked by the assert below).
    preds = model.predict_on_dataset(ds, devices=2, batch_size=10, num_workers=64).T
    if isinstance(preds, torch.Tensor):
        # Tensors must be on CPU (and a NumPy-supported dtype) to convert.
        preds = preds.float().cpu().numpy()
    preds = np.asarray(preds)
    assert preds.shape == ad.shape, f"{k}: prediction shape {preds.shape} != data shape {ad.shape}"

    print(k)

    per_pb_corrs = [np.corrcoef(t, p)[0, 1] for t, p in zip(true_expr, preds)]
    print(f"Mean Pearson Correlation per pseudobulk: True: {np.mean(per_pb_corrs).round(2)}")

    per_gene_corrs = [np.corrcoef(t, p)[0, 1] for t, p in zip(true_expr.T, preds.T)]
    print(f"Mean Pearson Correlation per gene: True: {np.mean(per_gene_corrs).round(2)}")

    print("")
    del model

## Evaluate the model with cropping

In [None]:
# Start from the 'original' checkpoint (same path as ckpts['original']) rather
# than repeating its absolute path; the wrapped network is replaced below.
model = LightningModel.load_from_checkpoint(ckpts['original'])

In [None]:
from decima_model import DecimaCropModel

# Swap the wrapped network for the cropped architecture, with one output task
# per pseudobulk (i.e. per row of ad).
model.model = DecimaCropModel(n_tasks=ad.n_obs)

In [None]:
# Load the cropped model's trained weights into the LightningModel.
# map_location='cpu' avoids materializing the checkpoint tensors on the GPU
# of this notebook process, and `ckpt` is deleted afterwards so the whole
# checkpoint dict is not kept alive in the global namespace.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True;
# if unpickling this checkpoint fails, pass weights_only=False (trusted files only).
ckpt = torch.load(
    os.path.join(ckpt_dir, 'a1vj8bsi', 'checkpoints', 'last.ckpt'),
    map_location='cpu',
)
model.load_state_dict(ckpt['state_dict'])
del ckpt
model = model.eval()

In [None]:
# Predict with gradients disabled under bfloat16 autocast, then report mean
# Pearson r across pseudobulks (rows of ad) and across genes (columns of ad).
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        preds = model.predict_on_dataset(ds, devices=2, batch_size=10, num_workers=64).T
if isinstance(preds, torch.Tensor):
    # bfloat16 tensors (possible under autocast) cannot be converted to NumPy
    # directly, so upcast to float32 on CPU first.
    preds = preds.float().cpu().numpy()
preds = np.asarray(preds)
assert preds.shape == ad.shape, f"prediction shape {preds.shape} != data shape {ad.shape}"

# Densify targets once instead of slicing ad for every row/column.
true_expr = np.asarray(ad.X)

per_pb_corrs = [np.corrcoef(t, p)[0, 1] for t, p in zip(true_expr, preds)]
print(f"Mean Pearson Correlation per pseudobulk: True: {np.mean(per_pb_corrs).round(2)}")

per_gene_corrs = [np.corrcoef(t, p)[0, 1] for t, p in zip(true_expr.T, preds.T)]
print(f"Mean Pearson Correlation per gene: True: {np.mean(per_gene_corrs).round(2)}")