In [1]:
import os
import sys
import importlib
import math
import time

import itertools
from collections import defaultdict

import numpy as np
import pandas as pd
import scipy.stats as ss
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.auto import tqdm as tqdm_auto
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from IPython.display import clear_output

In [2]:
sys.path.append("../../model")
import utrdata_cl as utrdata
from legnet import LegNetClassifier
from pl_regressor import RNARegressor

In [3]:
# Which UTR library this notebook trains on; also used in checkpoint paths
# and in the name of the output JSON file.
UTR_TYPE = "utr3"

# Value handed to the model as `seqsize`: 50 for "utr5", 240 for anything else.
if UTR_TYPE == "utr5":
    SEQSIZE = 50
else:
    SEQSIZE = 240

## Loading data

In [4]:
from ablation_utils import load_data

# `splits` is indexed by "train", "val" and "test" further down; each entry is
# passed to `utrdata.UTRData(df=...)` (presumably pandas DataFrames — confirm
# against `ablation_utils.load_data`).
splits = load_data(utr_type=UTR_TYPE, prefix="..")

In [5]:
# Mini-batch size of the training DataLoader (the val/test loaders divide it
# by the test-time shift window inside `launch_model`).
batch_size = 1024
# Number of full training batches per epoch, never less than 1; handed to
# OneCycleLR as `steps_per_epoch`. The training loader uses drop_last=True,
# so the trailing partial batch is not counted.
steps_per_epoch = max(1, splits["train"].shape[0] // batch_size)

In [6]:
# Worker processes per DataLoader; the same value is used for the train,
# validation and test loaders.
num_workers = 32

In [7]:
def _correlation_metrics(test_df, cell_types, model_name: str, seed: int) -> list:
    """Correlate predicted and measured mass center, pooled and per cell type.

    Args:
        test_df: Frame with "seq", "cell_type", "mass_center" and
            "pred_mass_center" columns.
        cell_types: Labels to evaluate. The label "all" averages both value
            columns per "seq" before correlating; any other label selects the
            rows whose "cell_type" equals it.
        model_name: Stored in the "model" field of every record.
        seed: Stored in the "seed" field of every record.

    Returns:
        One dict per label holding the Pearson and Spearman statistics and
        their p-values as built-in floats.
    """
    records = []
    for ct in cell_types:
        if ct == "all":
            # Pool cell types: one averaged (measured, predicted) pair per sequence.
            grouping = test_df.groupby("seq")
            real = grouping["mass_center"].mean()
            pred = grouping["pred_mass_center"].mean()
        else:
            ct_filter = test_df["cell_type"] == ct
            real = test_df.loc[ct_filter, "mass_center"]
            pred = test_df.loc[ct_filter, "pred_mass_center"]
        r = ss.pearsonr(pred, real)
        rho = ss.spearmanr(pred, real)
        records.append({
            "model": model_name,
            "cell type": ct,
            "seed": seed,
            # Cast to built-in floats: these records are JSON-dumped later and
            # NumPy scalar types are not guaranteed to be JSON-serializable.
            "pearsonr": float(r.statistic),
            "pearsonr_pvalue": float(r.pvalue),
            "spearmanr": float(rho.statistic),
            "spearmanr_pvalue": float(rho.pvalue),
        })
    return records


def launch_model(
    model_name: str,
    seed: int,
    train_ds_kws: dict,
    val_ds_kws: dict,
    model_class,
    model_kws: dict,
    criterion_class,
    criterion_kws: dict,
    optimizer_class,
    optimizer_kws: dict,
    lr_scheduler_class,
    lr_scheduler_kws: dict,
    test_time_validation: bool,
    epochs: int,
) -> list:
    """Train one configuration and score its best checkpoint on the test split.

    Reads the notebook-level globals `splits`, `batch_size`, `num_workers`
    and `UTR_TYPE`.

    Args:
        model_name: Used in the checkpoint directory, as the TensorBoard run
            name and as the "model" field of the returned records.
        seed: Passed to `pl.seed_everything`; also part of the checkpoint
            path and of the returned records.
        train_ds_kws: Keyword arguments for the training `utrdata.UTRData`.
        val_ds_kws: Keyword arguments shared by the validation and test
            `utrdata.UTRData`. When it contains
            `augment_kws["shift_left"]` and `augment_kws["shift_right"]`, the
            evaluation batch size is divided by their sum plus one.
        model_class: Forwarded to `RNARegressor`.
        model_kws: Forwarded to `RNARegressor` with `in_channels` set from the
            training dataset (overriding any `in_channels` already present).
        criterion_class: Forwarded to `RNARegressor`.
        criterion_kws: Forwarded to `RNARegressor`.
        optimizer_class: Forwarded to `RNARegressor`.
        optimizer_kws: Forwarded to `RNARegressor`.
        lr_scheduler_class: Forwarded to `RNARegressor`.
        lr_scheduler_kws: Forwarded to `RNARegressor`.
        test_time_validation: Forwarded to `RNARegressor`.
        epochs: `max_epochs` of the Lightning trainer.

    Returns:
        A list of metric dicts: one for the label "all" followed by one per
        key of the test dataset's `celltype_codes`.

    Raises:
        AssertionError: If the train and validation datasets disagree on
            `num_channels`, or if the targets returned by prediction do not
            match the "mass_center" column of the test split.
    """
    pl.seed_everything(seed)

    # Creating Datasets
    train_set = utrdata.UTRData(
        df=splits["train"],
        **train_ds_kws,
    )
    val_set = utrdata.UTRData(
        df=splits["val"],
        **val_ds_kws,
    )
    test_set = utrdata.UTRData(
        df=splits["test"],
        **val_ds_kws,
    )

    assert train_set.num_channels == val_set.num_channels

    # Size of the shift window configured for evaluation; presumably each
    # evaluation sample expands into this many shifted copies — confirm in
    # `utrdata.UTRData`.
    try:
        div_factor = val_ds_kws["augment_kws"]["shift_left"] + \
                     val_ds_kws["augment_kws"]["shift_right"] + 1
    except KeyError:
        div_factor = 1
    # DataLoader rejects batch_size=0, so never let the division reach zero.
    eval_batch_size = max(1, batch_size // div_factor)

    # Creating DataLoaders
    dl_train = DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        drop_last=True
    )
    dl_val = DataLoader(
        val_set,
        batch_size=eval_batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False
    )
    dl_test = DataLoader(
        test_set,
        batch_size=eval_batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False
    )

    model = RNARegressor(
        model_class=model_class,
        model_kws=model_kws | dict(
            in_channels=train_set.num_channels
        ),
        criterion_class=criterion_class,
        criterion_kws=criterion_kws,
        optimizer_class=optimizer_class,
        optimizer_kws=optimizer_kws,
        lr_scheduler_class=lr_scheduler_class,
        lr_scheduler_kws=lr_scheduler_kws,
        test_time_validation=test_time_validation,
    )
    model.model_name = model_name

    # Keep only the checkpoint with the highest "val_pearson_r_0".
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=f"saved_models/{UTR_TYPE}_{model_name}/seed={seed:02d}/",
        save_top_k=1,
        save_last=False,
        monitor="val_pearson_r_0",
        mode="max"
    )
    # NOTE(review): Lightning documents `refresh_rate` as an int — confirm 0.5 is intended.
    progressbar_callback = pl.callbacks.TQDMProgressBar(refresh_rate=0.5)

    logger = pl.loggers.tensorboard.TensorBoardLogger("tb_logs", name=model.model_name)
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, progressbar_callback],
        logger=logger,
        accelerator="gpu",
        devices=1,
        deterministic=True,
        max_epochs=epochs,
        num_sanity_val_steps=0,
        # gradient_clip_val=1e-5,
        # gradient_clip_algorithm="norm",
    )
    trainer.fit(model=model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    best_model = RNARegressor.load_from_checkpoint(checkpoint_callback.best_model_path)

    # Each predicted batch is unpacked as a (predictions, targets) pair.
    prediction = trainer.predict(model=best_model, dataloaders=dl_test)
    test_pred, test_real = zip(*prediction)
    test_pred = torch.concat(test_pred).numpy()
    test_real = torch.concat(test_real).numpy()
    test_mass_center_pred = test_pred[:, -1]  # Last (or the only) column should contain the predicted mass center
    test_mass_center_real = test_real[:, -1]
    test_df = splits["test"].copy()
    # The test loader is not shuffled, so rows must line up with the split.
    assert np.allclose(test_df["mass_center"].values, test_mass_center_real)
    test_df["pred_mass_center"] = test_mass_center_pred

    cell_types = ["all", *test_set.celltype_codes.keys()]
    return _correlation_metrics(test_df, cell_types, model_name, seed)

In [None]:
def parameter_ablation(
    model_name: str,
    parameter_update: dict,
):
    """Train and evaluate LegNet over a grid of settings for one ablation.

    The default grid has five axes: "seed" (0..29), "features",
    "augment_dict", "epochs" and "predict_cols". `parameter_update` replaces
    whole axes; every combination of the resulting grid is passed to
    `launch_model`.

    Reads the notebook-level globals `UTR_TYPE`, `SEQSIZE` and
    `steps_per_epoch`.

    Args:
        model_name: Forwarded to `launch_model` for every run.
        parameter_update: Maps an axis name to the list of values to try for
            that axis. An empty dict runs the default grid.

    Returns:
        A list with one element per grid point, in grid order; each element
        is the list of metric dicts returned by `launch_model` for that run.

    Raises:
        ValueError: If `parameter_update` contains a key that is not an axis
            of the grid. Such a key would otherwise be ignored by the run
            configuration and the ablation would silently train the baseline.
    """
    checked = {
        "seed": range(0, 30),
        "features": [
            # ("sequence", "revcomp", "intensity", "positional", "conditions"),
            ("sequence", "positional", "conditions"),
        ],
        "augment_dict": [
            dict(
                extend_left=0,
                extend_right=0,
                shift_left=0,
                shift_right=0,
                revcomp=False,
            ),
        ],
        "epochs": [10],
        "predict_cols": [['diff', 'mass_center']],
    }

    unknown = sorted(str(key) for key in parameter_update if key not in checked)
    if unknown:
        raise ValueError(
            f"Unknown ablation parameter(s) {unknown}; "
            f"expected a subset of {sorted(checked)}"
        )
    checked.update(parameter_update)

    all_metrics = list()

    for combo in itertools.product(*checked.values()):
        params = dict(zip(checked.keys(), combo))
        # Augmentation is on as soon as any augmentation setting is truthy,
        # and test-time augmentation follows the same switch.
        use_augmentation = any(params["augment_dict"].values())
        augment_test_time = use_augmentation

        run_metrics = launch_model(
            model_name=model_name,
            seed=params["seed"],
            train_ds_kws=dict(
                predict_cols=params["predict_cols"],
                construct_type=UTR_TYPE,
                features=params["features"],  # ("sequence", "conditions", "positional", "revcomp")
                augment=use_augmentation,
                augment_test_time=False,
                augment_kws=params["augment_dict"],
            ),
            val_ds_kws=dict(
                predict_cols=params["predict_cols"],
                construct_type=UTR_TYPE,
                features=params["features"],  # ("sequence", "conditions", "positional", "revcomp")
                augment=False,
                augment_test_time=augment_test_time,
                augment_kws=params["augment_dict"],
            ),
            model_class=LegNetClassifier,
            model_kws=dict(
                seqsize=SEQSIZE,
                ks=3,
                # One output channel per predicted column.
                out_channels=len(params["predict_cols"]),
                conv_sizes=(128, 64, 64, 32, 32),
                mapper_size=256,
                linear_sizes=None,
                use_max_pooling=False,
                final_activation=nn.Identity
            ),
            criterion_class=nn.MSELoss,
            criterion_kws=dict(),
            optimizer_class=torch.optim.AdamW,
            optimizer_kws=dict(
                # lr=0.01,
                weight_decay=0.1,
            ),
            lr_scheduler_class=torch.optim.lr_scheduler.OneCycleLR,
            lr_scheduler_kws=dict(
                max_lr=0.01,
                steps_per_epoch=steps_per_epoch,
                epochs=params["epochs"],
                pct_start=0.3,
                three_phase=False,
                cycle_momentum=True,
            ),
            test_time_validation=augment_test_time,
            epochs=params["epochs"],
        )
        all_metrics.append(run_metrics)
    return all_metrics

## Launching ablation

In [None]:
# Run name -> overrides for the parameter grid inside `parameter_ablation`.
# An empty dict keeps the default grid; each override replaces a whole axis.
ablation_tests = {
    "NoChanges": {},
    # Predict only 'mass_center' instead of ['diff', 'mass_center'].
    "RemovedDiff": {
        "predict_cols": [['mass_center']],
    },
    # Drop the "positional" feature from the default feature set.
    "RemovedPositionalEncoding": {
        "features": [
            ("sequence", "conditions"),
        ],
    },
    # Both of the above at once.
    "RemovedDiffAndPositionalEncoding": {
        "predict_cols": [['mass_center']],
        "features": [
            ("sequence", "conditions"),
        ],
    },
}

# Accumulates one element per training run across all ablations; each element
# is the list of metric dicts that `launch_model` returned for that run.
ablation_metrics = list()
for test_name, update_dict in ablation_tests.items():
    iter_metrics = parameter_ablation(model_name=test_name, parameter_update=update_dict)
    ablation_metrics.extend(iter_metrics)

In [None]:
import json

# Serialize before opening the file: if any value turns out not to be
# JSON-serializable, the error is raised while an existing results file is
# still intact, instead of leaving it truncated or half-written.
serialized_metrics = json.dumps(ablation_metrics)
with open(f"{UTR_TYPE}_ablation.json", "wt") as handle:
    handle.write(serialized_metrics)

---