In [2]:
import mpramnist
from mpramnist.Agarwal.dataset import AgarwalDataset
from mpramnist.Kircher.dataset import KircherDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Kircher

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import pandas as pd
import numpy as np
import torch
import torch.utils.data as data
import pytorch_lightning as L

from torchmetrics import PearsonCorrCoef

In [3]:
# Data-loading configuration.
BATCH_SIZE = 1024
NUM_WORKERS = 8

# Seed Python, NumPy and PyTorch RNGs (and dataloader workers) so that shuffling,
# random reverse-complement augmentation and weight initialization are reproducible
# on Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

In [4]:
# Flank sequences provided by AgarwalDataset (the original human_legnet flanks).
# Only the right flank is passed to AddFlanks in the training transform.
left_flank, right_flank = AgarwalDataset.LEFT_FLANK, AgarwalDataset.RIGHT_FLANK

# Pretrain on Agarwal's data

In [5]:
# Cell line used throughout the pretraining section.
CELL_TYPE = "HepG2"

# Training preprocessing: right flank only, crop, tensor conversion, and random
# reverse-complement augmentation (these are the original human_legnet parameters).
train_transform = t.Compose([
    t.AddFlanks("", right_flank),  # no left flank, right flank only
    t.RightCrop(230, 250),
    t.CenterCrop(230),
    t.Seq2Tensor(),
    t.ReverseComplement(0.5),  # reverse-complement with probability 0.5 (augmentation)
])

# Evaluation preprocessing differs from training: no flanking/cropping and no randomness.
test_transform = t.Compose([
    t.Seq2Tensor(),
    # Probability 0: sequences are never reverse-complemented, so evaluation is
    # deterministic. NOTE(review): presumably kept only to mirror the training
    # pipeline — confirm it is a no-op for ReverseComplement.
    t.ReverseComplement(0),
])

# Load the default splits. `split` also accepts a list of fold indices,
# e.g. [1, 2, 5, 6, 7, 8], to select specific folds instead.
train_dataset = AgarwalDataset(cell_type=CELL_TYPE, split="train", transform=train_transform)
val_dataset = AgarwalDataset(cell_type=CELL_TYPE, split="val", transform=test_transform)
test_dataset = AgarwalDataset(cell_type=CELL_TYPE, split="test", transform=test_transform)

In [6]:
# Wrap each split in a DataLoader; only the training split is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = data.DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [7]:
# Input channel count = length of the first dimension of an encoded training
# sequence; the model predicts a single value per sequence.
first_sequence = train_dataset[0][0]
in_channels = len(first_sequence)
out_channels = 1

In [8]:
# LitModel_Kircher wraps the default architecture (HumanLegNet).
seq_model = LitModel_Kircher(
    in_ch=in_channels,
    out_ch=out_channels,
    weight_decay=1e-1,
    lr=1e-2,
    print_each=5,  # presumably: print a metrics summary every 5 epochs (matches the logged output)
)

In [8]:
# Single-GPU, mixed-precision training configuration.
trainer_options = {
    "accelerator": "gpu",
    "devices": [0],
    "max_epochs": 10,
    "gradient_clip_val": 1,
    "precision": "16-mixed",
    "enable_progress_bar": False,
    "num_sanity_val_steps": 0,  # skip the validation sanity check before training
}
trainer = L.Trainer(**trainer_options)

# Fit on train/val, then report metrics on the held-out test split.
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(seq_model, dataloaders=test_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-04-05 17:10:53.819124: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-05 17:10:53.834125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one 


-----------------------------------------------------------------------------------------------------
| Epoch: 4 | Val Loss: 0.35766 | Val Pearson: 0.65240 | Train Loss: 0.33538 | Train Pearson: 0.69676 
-----------------------------------------------------------------------------------------------------



`Trainer.fit` stopped: `max_epochs=10` reached.



-----------------------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.26829 | Val Pearson: 0.74576 | Train Loss: 0.20822 | Train Pearson: 0.82842 
-----------------------------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


[{'test_loss': 0.2599364221096039, 'test_pearson': 0.7570507526397705}]

# Evaluation

In [9]:
def _gather_outputs(predictions, key):
    """Concatenate the per-batch ``key`` outputs of ``trainer.predict`` into one tensor."""
    values = torch.cat([batch[key] for batch in predictions])
    # Under precision='16-mixed' model outputs may be half precision (TODO confirm for
    # this model); upcast so averaging and Pearson statistics are not done in fp16.
    if values.dtype in (torch.float16, torch.bfloat16):
        values = values.float()
    return values


def _strand_average(predictions_forw, predictions_rev, key):
    """Average the forward- and reverse-strand ``key`` predictions position-wise."""
    forward = _gather_outputs(predictions_forw, key)
    reverse = _gather_outputs(predictions_rev, key)
    if forward.shape != reverse.shape:
        raise ValueError(
            f"forward and reverse loaders produced different '{key}' shapes "
            f"{tuple(forward.shape)} vs {tuple(reverse.shape)}; both loaders must "
            "yield the same samples in the same order (shuffle=False)"
        )
    return (forward + reverse) / 2


def meaned_prediction(forw, rev, trainer, seq_model, is_kircher = False):
    """Pearson correlation of strand-averaged predictions with the targets.

    Runs ``trainer.predict`` on a forward-strand loader and on a reverse-complement
    loader over the same samples, averages the two predictions position-wise, and
    correlates the result with the targets reported for the forward loader.

    Args:
        forw: DataLoader over the forward-strand sequences.
        rev: DataLoader over the same sequences, reverse-complemented. It must
            yield samples in the same order as ``forw`` (use ``shuffle=False``).
        trainer: Lightning trainer used to run ``predict``.
        seq_model: model whose predict outputs are dicts containing ``"target"``
            and ``"ref_predicted"`` (plus ``"alt_predicted"`` when ``is_kircher``).
        is_kircher: if True, score variant effects: correlate
            ``mean(alt) - mean(ref)`` with the targets instead of ``mean(ref)``.

    Returns:
        The Pearson correlation as returned by torchmetrics ``PearsonCorrCoef``.

    Raises:
        ValueError: if the forward and reverse loaders yield differently shaped
            predictions (i.e. they do not cover the same samples).
    """
    predictions_forw = trainer.predict(seq_model, dataloaders = forw)
    predictions_rev = trainer.predict(seq_model, dataloaders = rev)

    targets = _gather_outputs(predictions_forw, "target")
    mean_ref = _strand_average(predictions_forw, predictions_rev, "ref_predicted")

    pearson = PearsonCorrCoef()
    if not is_kircher:
        return pearson(mean_ref, targets)

    mean_alt = _strand_average(predictions_forw, predictions_rev, "alt_predicted")
    # Predicted variant effect: strand-averaged alt minus strand-averaged ref.
    return pearson(mean_alt - mean_ref, targets)

In [10]:
# Two deterministic views for test-time averaging: the sequences as-is, and every
# sequence reverse-complemented (probability 1).
forw_transform = t.Compose([t.Seq2Tensor()])
rev_transform = t.Compose([t.Seq2Tensor(), t.ReverseComplement(1)])

# Both loaders must iterate the HepG2 test split in the same order (no shuffling).
test_forw = AgarwalDataset(cell_type="HepG2", split="test", transform=forw_transform)
test_rev = AgarwalDataset(cell_type="HepG2", split="test", transform=rev_transform)

eval_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
forw_hepg2 = data.DataLoader(dataset=test_forw, **eval_loader_kwargs)
rev_hepg2 = data.DataLoader(dataset=test_rev, **eval_loader_kwargs)

meaned_prediction(forw_hepg2, rev_hepg2, trainer, seq_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


tensor(0.7719)

# Kircher dataset examination

In [12]:
# Kircher promoter sets, grouped by the cell line they are evaluated against.
promoters_for_hepg2 = [
    "F9",
    "LDLR.2",
    "LDLR",
    "SORT1.2",
    "SORT1",
    "SORT1-flip",
]
promoters_for_k562 = [
    "PKLR-24h",
    "PKLR-48h",
]

Quick check: variant-effect prediction quality on a single promoter (LDLR)

In [20]:
# Variant-effect check on a single promoter (LDLR): forward and reverse-complement
# views of the same elements, scored as strand-averaged alt - ref.
kircher_dataset_forw = KircherDataset(length=230, promoters=["LDLR"], transform=forw_transform)
kircher_dataset_rev = KircherDataset(length=230, promoters=["LDLR"], transform=rev_transform)

kircher_forw, kircher_rev = (
    data.DataLoader(
        dataset=kircher_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    for kircher_ds in (kircher_dataset_forw, kircher_dataset_rev)
)

meaned_prediction(kircher_forw, kircher_rev, trainer, seq_model, is_kircher=True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


tensor(0.5170)

In [25]:
# Evaluate a previously saved checkpoint on the HepG2 test split.
# weight_decay / lr are passed explicitly to load_from_checkpoint.
KIRCHER_CKPT = "../datasets/KircherDataset/best_model_test10_val9.ckpt"
kircher_model = LitModel_Kircher.load_from_checkpoint(KIRCHER_CKPT, weight_decay=1e-1, lr=1e-2)
meaned_prediction(forw_hepg2, rev_hepg2, trainer, kircher_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


tensor(0.7965)

In [26]:
# Variant-effect (alt - ref) correlation of the checkpointed model on the LDLR loaders.
meaned_prediction(
    kircher_forw,
    kircher_rev,
    trainer,
    kircher_model,
    is_kircher=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


tensor(0.5929)