In [44]:
import os
import sys
import logging

# Work from the repository root so relative paths such as "./data/..." resolve the same way on every run.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("../")
# Make the repository root importable via an absolute path. A relative "../" entry
# appended before the chdir would be resolved against the *new* cwd when first used
# (pointing above the repository), and re-running the cell would add duplicates.
REPO_ROOT = os.getcwd()
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

import hydra
from omegaconf import DictConfig, OmegaConf
import wandb
import pandas as pd
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from src.utils.preprocess import preprocess_data, get_tokens, validate_test_data
from src.utils.utils import set_seed, setup_logger, count_parameters 
from src.training.data_module import YeastDataModule
from src.models.bpnet import BPNet
from src.models.transformer_lora import LoraBPNet
from src.models.lora import Lora
from src.training.loss import TotalBPNetLoss
import seaborn as sns
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2
set_seed(42)
configs = ['config', 'lora', 'llm_bpnet']
names = ['bpnet.ckpt', 'lora.ckpt', 'llm_bpnet.ckpt']
model_metrics = {}
os.getcwd()

Seed set to 42


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'/data/ceph/hdd/project/node_07/ml4rg_students/2024/Project07_PolyB/Regulate-Me'

In [45]:
def load_config(config_path: str, config_name: str) -> DictConfig:
    """Compose the Hydra config ``config_name`` found under ``config_path``.

    Hydra is initialized only for the duration of the ``with`` block; the
    composed config object is returned to the caller.
    """
    with hydra.initialize(version_base="1.3", config_path=config_path):
        composed = hydra.compose(config_name=config_name)
    return composed


def summarize_metrics(metrics):
    """Collapse per-track numeric lists into their arithmetic mean.

    A value is averaged when it is a non-empty list whose first element is an
    int or float; every other value (scalars, strings, empty lists, ...) is
    copied through unchanged. Key order is preserved.
    """
    def _is_numeric_list(candidate):
        return (
            isinstance(candidate, list)
            and bool(candidate)
            and isinstance(candidate[0], (int, float))
        )

    return {
        name: (sum(vals) / len(vals) if _is_numeric_list(vals) else vals)
        for name, vals in metrics.items()
    }


def print_metrics(metrics):
    """Print a human-readable metrics summary and return it as a long-format table.

    List-valued metrics whose first element is an int/float are printed one
    line per element ("Track 1", "Track 2", ...) and produce one row each with
    a 1-based ``track``. Any other value is printed as-is and produces a single
    row with ``track`` set to None.

    Returns:
        pandas DataFrame with columns ``metric``, ``track`` and ``value``.
    """
    print("\nMetrics Summary:")
    rows = []
    for name, vals in metrics.items():
        per_track = isinstance(vals, list) and vals and isinstance(vals[0], (int, float))
        if not per_track:
            print(f"{name}: {vals}")
            rows.append({"metric": name, "track": None, "value": vals})
            continue
        print(f"{name}:")
        for track, val in enumerate(vals, start=1):
            print(f"  Track {track}: {val:.4f}")
            rows.append({"metric": name, "track": track, "value": val})
    return pd.DataFrame(rows)

In [46]:
# Function to process each model
def _model_class(model_type):
    """Return the LightningModule class for ``model_type`` or raise ValueError if unknown."""
    # Built per call so `%autoreload` always sees the current class objects.
    model_classes = {
        "bpnet": BPNet,
        "llm_bpnet": LoraBPNet,
        "lora": Lora,
    }
    if model_type not in model_classes:
        raise ValueError(f"Model {model_type} not recognized")
    return model_classes[model_type]


def _prepare_inputs(config, seq_len):
    """Load counts + preprocessed dataset, split them, and build the model inputs.

    Returns ``(train_idx, val_idx, test_idx, data, counts)``; ``data`` is the
    one-hot encoding for plain BPNet and k-mer tokens for every other model.
    """
    counts = torch.load(os.path.join(config.data.data_dir, config.data.counts_file))
    dataset = pd.read_parquet(
        os.path.join(config.data.data_dir, config.data.preprocessed_file)
    )

    train_idx, val_idx, test_idx, one_hots, counts, dataset = preprocess_data(
        dataset,
        counts,
        config.data.restrict_seq_len,
        config.data.seq_col,
        set(config.data.val_chroms),
        set(config.data.test_chroms),
    )

    if config.model_name == "bpnet":
        return train_idx, val_idx, test_idx, one_hots, counts

    # Trim the count profiles to seq_len - kmer_size + 1 positions -- presumably the
    # number of stride-1 k-mer tokens per sequence (TODO confirm for stride != 1).
    output_seq_len = seq_len - config.tokenizer.kmer_size + 1
    counts = counts[:, :, :output_seq_len]
    data = get_tokens(
        dataset,
        config.tokenizer.stride,
        config.data.seq_col,
        config.tokenizer.kmer_size,
    )
    return train_idx, val_idx, test_idx, data, counts


def process_model(config_name, model_name, config_path="../configs", seq_len=300):
    """Evaluate one trained checkpoint on the test split.

    Args:
        config_name: Hydra config name under ``config_path`` describing model, data and loss.
        model_name: checkpoint file name inside ``config.data.best_model_path``.
        config_path: Hydra config directory (defaults to the previously hard-coded "../configs").
        seq_len: input sequence length used to derive the token models' output length
            (defaults to the previously hard-coded 300).

    Returns:
        A one-row DataFrame indexed by ``config.model_name`` holding the averaged
        test metrics plus a ``test_loss`` column.

    Raises:
        ValueError: if ``config.model_name`` is not "bpnet", "llm_bpnet" or "lora".
    """
    config = load_config(config_path=config_path, config_name=config_name)
    set_seed(config.seed)
    # Fail fast on an unknown model type, before the expensive data loading.
    model_cls = _model_class(config.model_name)

    train_idx, val_idx, test_idx, data, counts = _prepare_inputs(config, seq_len)

    data_module = YeastDataModule(
        batch_size=config.training.batch_size,
        train_idx=train_idx,
        val_idx=val_idx,
        test_idx=test_idx,
        data=data,
        counts=counts,
    )

    loss_fn = TotalBPNetLoss(
        alpha=config.loss.alpha,
        beta=config.loss.beta,
        profile_loss_type=config.loss.profile_loss_type,
        eps=config.loss.eps,
    )

    model = model_cls.load_from_checkpoint(
        os.path.join(config.data.best_model_path, model_name)
    )
    model.eval()
    data_module.setup()
    metrics, loss = validate_test_data(model, data_module.test_dataloader(), loss_fn)

    df_metrics = pd.DataFrame(summarize_metrics(metrics), index=[config.model_name])
    df_metrics["test_loss"] = loss
    return df_metrics

In [47]:
# Loop over models and collect metrics
def collect_all_metrics(configs, names):
    """Evaluate each (config, checkpoint) pair and stack the results into one table.

    Args:
        configs: Hydra config names, one per model.
        names: checkpoint file names, paired with ``configs`` by position.

    Returns:
        DataFrame with one row per model (as returned by ``process_model``);
        an empty DataFrame when no models are given.

    Raises:
        ValueError: if ``configs`` and ``names`` differ in length (``zip`` would
            otherwise silently drop the unmatched entries).
    """
    configs = list(configs)
    names = list(names)
    if len(configs) != len(names):
        raise ValueError(
            f"configs and names must pair up one-to-one, got {len(configs)} configs "
            f"and {len(names)} checkpoint names"
        )

    # Collect the per-model frames and concatenate once: growing a DataFrame with
    # concat inside the loop is quadratic and starts from an empty placeholder frame.
    frames = []
    for config_name, checkpoint_name in zip(configs, names):
        frames.append(process_model(config_name, checkpoint_name))
        print(f"Model processed: {config_name} ")
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)

## Test results

In [48]:
# Evaluate every (config, checkpoint) pair on the test split -- one row per model.
all_model_metrics = collect_all_metrics(configs=configs, names=names)
all_model_metrics

Seed set to 42
Seed set to 42


Model processed: config 


Map (num_proc=4):   0%|          | 0/6580 [00:00<?, ? examples/s]

test


Seed set to 42


Model processed: lora 


Map (num_proc=4):   0%|          | 0/6580 [00:00<?, ? examples/s]

test
Model processed: llm_bpnet 


Unnamed: 0,count_r2,profile_pearson_median,profile_pearson_mean,profile_auprc,profile_auroc,test_loss
bpnet,0.037942,0.730465,0.68194,0.604532,0.920428,939.202738
lora,0.103778,0.809382,0.738675,0.639585,0.931284,711.48426
llm_bpnet,0.028984,0.771082,0.70298,0.623123,0.925758,844.047709


In [49]:
# Rebind instead of dropping in place: the in-place drop raised KeyError when the
# cell was re-run, because the column was already gone.
all_model_metrics = all_model_metrics.drop(columns=["count_r2"], errors="ignore")

In [50]:
# Human-readable labels for the report table. Map by name rather than by position
# so a label can never attach to the wrong column/row if the order changes;
# re-running the cell is a no-op.
all_model_metrics = all_model_metrics.rename(
    columns={
        "profile_pearson_median": "Pearson Median",
        "profile_pearson_mean": "Pearson Mean",
        "profile_auprc": "AUPRC",
        "profile_auroc": "AUROC",
        "test_loss": "Test Loss",
    },
    index={
        "bpnet": "BPNet",
        "lora": "SpeciesLM + LoRA",
        "llm_bpnet": "SpeciesLM",
    },
)
all_model_metrics


Unnamed: 0,Pearson Median,Pearson Mean,AUPRC,AUROC,Test Loss
BPNet,0.730465,0.68194,0.604532,0.920428,939.202738
SpeciesLM + LoRA,0.809382,0.738675,0.639585,0.931284,711.48426
SpeciesLM,0.771082,0.70298,0.623123,0.925758,844.047709


In [51]:
# Export the model-comparison table as LaTeX (3 decimals), creating the output folder if needed.
figures_dir = "./data/figures"
os.makedirs(figures_dir, exist_ok=True)
all_model_metrics.to_latex(
    f"{figures_dir}/all_model_metrics.tex",
    index=True,
    float_format="%.3f",
)