# Set-up

In [None]:
# autoreload: reload edited modules (e.g. the local `utils` helpers imported
# further down) before each cell runs, so edits are picked up without a restart
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import os
import shutil

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import xarray as xr
import yaml

import seqdata as sd
import seqexplainer as se
import seqmodels as sm
import seqpro as sp

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU, and report the choice
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device)

In [None]:
# Run from the DeepSTARR use-case directory; every relative path written below
# (log/, best_model.ckpt, predictions, figures, attributions/) lands here.
# NOTE(review): absolute, user-specific path — consider making it configurable.
WORK_DIR = "/cellar/users/aklie/projects/ML4GLand/SeqModels/use_cases/case_3/DeepSTARR"
os.chdir(WORK_DIR)

In [None]:
# Set seeds for every RNG the notebook relies on. numpy alone is not enough:
# weight init, DataLoader shuffling and dropout all draw from torch's RNG.
# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
SEED = 1234
np.random.seed(SEED)
torch.manual_seed(SEED);

# Data

In [None]:
# Load the pre-built training and test SeqData zarr stores and pull them fully into memory
DATA_DIR = "/cellar/users/aklie/data/datasets/deAlmeida_DrosophilaS2_UMI-STARR-seq/training/2023_12_19/seqdatasets"
training_sdata = sd.open_zarr(f"{DATA_DIR}/deAlmeida22_training.zarr").load()
test_sdata = sd.open_zarr(f"{DATA_DIR}/deAlmeida22_test.zarr").load()

In [None]:
# Create a single 2-column target variable (Dev, Hk) to predict, laid out as (_sequence, _targets)
TARGET_VARS = ["Dev_log2_enrichment", "Hk_log2_enrichment"]


def _stack_targets(ds):
    """Concatenate the TARGET_VARS of `ds` along a new `_targets` dim, sequences first."""
    stacked = xr.concat([ds[var] for var in TARGET_VARS], dim="_targets")
    return stacked.transpose("_sequence", "_targets")


training_sdata["target"] = _stack_targets(training_sdata)
test_sdata["target"] = _stack_targets(test_sdata)

In [None]:
# Grab 10 one-hot sequences as a dummy batch for forward-pass shape checks
# (drawn from the training set; only output shapes are inspected below)
test_dict = {"seq": torch.tensor(training_sdata["ohe_seq"][:10].values, dtype=torch.float32)}
test_seqs = test_dict["seq"]

# Architecture

In [None]:
# Load the DeepSTARR architecture from SeqModels with one output per target (Dev, Hk)
# NOTE(review): input_len=249 is hard-coded; presumably it must match the length
# axis of `ohe_seq` — confirm against the data rather than relying on the constant.
arch = sm.DeepSTARR(input_len=249, output_dim=2)
# Display the architecture and the output shape for the 10-sequence smoke batch
arch, arch(test_seqs).shape

# Training module

In [None]:
from seqmodels import Module

In [None]:
# Wrap the architecture in a SeqModels training Module: MSE regression of `output`
# against the stacked `target`, tracking the same metrics on train and validation,
# with a reduce-on-plateau LR scheduler.
metric_names = ["r2", "pearson", "spearman"]
module_config = dict(
    input_vars=["ohe_seq"],
    output_vars=["output"],
    target_vars=["target"],
    loss_fxn="mse",
    train_metrics_fxn=list(metric_names),
    val_metrics_fxn=list(metric_names),
    scheduler="reduce_lr_on_plateau",
)
module = Module(arch=arch, **module_config)
# Display the module and the output shape for the smoke batch
module, module(test_dict).shape

# DataLoaders

In [None]:
# Split the training store into training and validation sets using the per-sequence
# `train_val` flag (True -> train, False -> validation). The explicit == comparisons
# are element-wise on the DataArray and yield boolean masks for .sel.
is_train = (training_sdata["train_val"] == True).compute()  # noqa: E712
is_valid = (training_sdata["train_val"] == False).compute()  # noqa: E712
train_sdata = training_sdata.sel(_sequence=is_train)
valid_sdata = training_sdata.sel(_sequence=is_valid)

# Every sequence must land in exactly one split (catches missing/unexpected flag values)
n_train, n_valid = train_sdata.sizes["_sequence"], valid_sdata.sizes["_sequence"]
assert n_train + n_valid == training_sdata.sizes["_sequence"], "train_val flag does not partition the sequences"

# Use .sizes: Dataset.dims as a name->length mapping is deprecated in xarray
n_train, n_valid

In [None]:
# Training dataloader: shuffled mini-batches of one-hot sequences and stacked targets
train_loader_opts = {
    "sample_dims": "_sequence",
    "variables": ["ohe_seq", "target"],
    "batch_size": 128,
    "shuffle": True,
    "num_workers": 0,
    "drop_last": False,
    "pin_memory": True,
}
train_dl = sd.get_torch_dataloader(train_sdata.load(), **train_loader_opts)

# Peek at one batch to confirm input and target shapes
batch = next(iter(train_dl))
batch["ohe_seq"].shape, batch["target"].shape

In [None]:
# Validation dataloader: same variables as training but in a fixed (unshuffled) order
valid_loader_opts = {
    "sample_dims": "_sequence",
    "variables": ["ohe_seq", "target"],
    "batch_size": 128,
    "shuffle": False,
    "num_workers": 0,
    "drop_last": False,
    "pin_memory": True,
}
valid_dl = sd.get_torch_dataloader(valid_sdata.load(), **valid_loader_opts)

# Peek at one batch to confirm input and target shapes
batch = next(iter(valid_dl))
batch["ohe_seq"].shape, batch["target"].shape

# Trainer

In [None]:
from pytorch_lightning import Trainer

In [None]:
# Logger: CSV metrics rooted at ./log, with empty name and version sub-directories
from pytorch_lightning.loggers import CSVLogger

LOG_DIR = "log"
logger = CSVLogger(save_dir=LOG_DIR, name="", version="")

In [None]:
# Callbacks: keep the 5 best checkpoints by epoch-level validation loss, stop after
# 10 validation checks without improvement, and log the learning rate.
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

checkpoint_dir = os.path.join(logger.save_dir, logger.name, logger.version, "checkpoints")
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_top_k=5,
    monitor="val_loss_epoch",
    mode="min",
)
early_stop_cb = EarlyStopping(monitor="val_loss_epoch", patience=10, mode="min")
lr_monitor_cb = LearningRateMonitor()

callbacks = [checkpoint_cb, early_stop_cb, lr_monitor_cb]

In [None]:
# Trainer: up to 100 epochs, running validation after every epoch
trainer_kwargs = dict(
    logger=logger,
    callbacks=callbacks,
    max_epochs=100,
    check_val_every_n_epoch=1,
)
trainer = Trainer(**trainer_kwargs)

# Fit

In [None]:
# Fit the weights: train on train_dl, evaluating on valid_dl as configured on the Trainer
trainer.fit(module, train_dl, valid_dl)

In [None]:
# Copy the best checkpoint (lowest monitored validation loss) to a stable,
# well-known path so later cells can reload it without knowing the run layout.
# shutil avoids shell interpolation of the path and raises on failure, unlike
# os.system("cp ..."), which silently returns a non-zero exit code.
best_model_path = trainer.checkpoint_callback.best_model_path
if not best_model_path:
    raise RuntimeError("No checkpoint was saved during training; cannot copy the best model.")
copy_path = "best_model.ckpt"
shutil.copy(best_model_path, copy_path)
copy_path

# Training Summary

In [None]:
from utils import training_summary

In [None]:
# Plot loss and metric curves from the CSV logs under logger.save_dir and save the figure
# NOTE(review): `training_summary` is a local helper (utils); the metric names passed here
# presumably must match those logged by the Module (r2/pearson/spearman) — confirm.
training_summary(logger.save_dir, logger="csv", metrics=["r2", "pearson", "spearman"], save="training_summary.png")

# Performance

In [None]:
from utils import scatter

In [None]:
# Reload the best checkpoint into eval mode on the selected device.
# `.to(device)` instead of a hard-coded `.cuda()` keeps this cell working on CPU-only machines.
module = Module.load_from_checkpoint("best_model.ckpt", arch=arch).eval().to(device)

In [None]:
# Get predictions and targets as arrays
# NOTE(review): `module.predict` receives a float32 numpy array under key "seq" (the same
# key as the smoke-test dict); presumably it converts and moves it to the module's
# device internally — confirm.
preds_dict = module.predict({"seq": test_sdata["ohe_seq"].values.astype("float32")})
# NOTE(review): assuming the output is (N, 2), `.squeeze()` is a no-op, but with a
# single test sequence it would drop the sequence axis and break the `[:, i]` indexing below.
preds = preds_dict["output"].cpu().numpy().squeeze()
targets = test_sdata["target"].values

In [None]:
# Save observed targets followed by their predictions (pred_ prefix), one column per target
prediction_targets = ["Dev_log2_enrichment", "Hk_log2_enrichment"]
columns = {name: targets[:, col] for col, name in enumerate(prediction_targets)}
columns.update({f"pred_{name}": preds[:, col] for col, name in enumerate(prediction_targets)})
df = pd.DataFrame(columns)
df.to_csv("test_predictions.csv", index=False)

In [None]:
# Scatter observed vs predicted values on the test set, one figure per target.
# Axis labels name the actual target (log2 enrichment), not "binding scores".
for col, name in enumerate(["Dev_log2_enrichment", "Hk_log2_enrichment"]):
    label = name.replace("_", " ")
    scatter(
        x=targets[:, col],
        y=preds[:, col],
        c="#4682B4",
        alpha=0.8,
        xlabel=f"Experimental {label}",
        ylabel=f"Predicted {label}",
        density=True,
        rasterized=True,
        s=5,
        save=f"{name}_scatter.png",
    )

# Attribution

In [None]:
from bpnetlite.attributions import hypothetical_attributions
from seqexplainer.attributions import plot_attribution_logo
from seqexplainer.attributions._references import k_shuffle_ref_inputs

In [None]:
# Need the number of sequences and number of references per sequence.
# Attribution runs on n_seqs * n_refs (input, reference) pairs, so cost scales with the product.
n_seqs, n_refs = 100, 100

In [None]:
# Grab the first n_seqs test sequences and n_refs k=2 shuffled references per sequence.
# Slice with n_seqs rather than a literal so the (n_seqs, n_refs, ...) view used when
# averaging attributions stays consistent if n_seqs is changed.
seqs = test_sdata["ohe_seq"].values[:n_seqs]
refs = torch.tensor(k_shuffle_ref_inputs(seqs, k=2, n_per_input=n_refs), dtype=torch.float32)

In [None]:
# Pair every sequence with each of its references: repeat each sequence n_refs times
# (sequence-major order) and merge the leading (sequence, reference) axes of refs.
inputs = torch.repeat_interleave(torch.tensor(seqs, dtype=torch.float32), n_refs, dim=0)
baselines = refs.flatten(0, 1)

In [None]:
# Get hypothetical DeepLift attributions for target 0 (the first output column) over
# all (input, reference) pairs. Use the device chosen at set-up instead of a hard-coded
# "cuda" so this also runs on CPU-only machines.
attrs = se.attribute(
    model=module.arch,
    inputs=inputs[:n_seqs*n_refs],
    method="DeepLift",
    references=baselines[:n_seqs*n_refs],
    target=0,
    batch_size=128,
    device=device.type,
    custom_attribution_func=hypothetical_attributions,
    hypothetical=True,
)

Computing attributions on batches of size 128:   0%|          | 0/79 [00:00<?, ?it/s]

In [None]:
# Get the average hypothetical attributions per sequence.
# Rows are ordered sequence-major (inputs were built with repeat_interleave), so viewing
# them as (n_seqs, n_refs, ...) groups each sequence's references on dim 1 for the mean.
# NOTE(review): if se.attribute already returns a tensor, torch.tensor copies it with a
# UserWarning; torch.as_tensor would avoid that — confirm the return type.
attrs = torch.tensor(attrs, dtype=torch.float32)
attr_shape = (n_seqs, n_refs) + attrs.shape[1:]
attrs = torch.mean(attrs.view(attr_shape), dim=1, keepdim=False)

In [None]:
# Multiply the averaged hypothetical attributions by the inputs. `ohe_seq` is presumably
# one-hot, so this keeps only the score of the base present at each position.
hypothetical = attrs.cpu()
attrs = (hypothetical * seqs[:n_seqs]).numpy()

In [None]:
# Save the one-hot test sequences and their attribution scores as compressed .npz files.
# Create the output directory first; np.savez_compressed does not create parent dirs.
os.makedirs("attributions", exist_ok=True)
np.savez_compressed("attributions/test_ohe.npz", seqs)
np.savez_compressed("attributions/test_shap.npz", attrs)

# DONE!

---