In [1]:
import mpramnist
from mpramnist.Agarwal.dataset import AgarwalDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Agarwal

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import pytorch_lightning as L

In [2]:
# Run configuration shared by the HepG2, K562 and WTC11 sections.
SEED = 42  # arbitrary fixed value; any integer works
NUM_EPOCHS = 50
BATCH_SIZE = 1024
NUM_WORKERS = 8
lr = 0.01

# Fix the Python / NumPy / PyTorch random state before any stochastic step
# (weight initialisation, batch shuffling, random augmentations) so that
# "Restart Kernel -> Run All" is repeatable. workers=True asks Lightning to
# also seed the DataLoader workers of loaders handed to the Trainer.
L.seed_everything(SEED, workers=True)

# HepG2

In [3]:
# Constant flanks: required for each sequence; every transform pipeline in this
# notebook adds them.
# NOTE: `constant_rigtht_flank` is misspelled, but later cells reference this
# exact name, so it is kept as is.
constant_left_flank, constant_rigtht_flank = (
    AgarwalDataset.CONSTANT_LEFT_FLANK,
    AgarwalDataset.CONSTANT_RIGHT_FLANK,
)

# Original flanks from human_legnet.
left_flank, right_flank = AgarwalDataset.LEFT_FLANK, AgarwalDataset.RIGHT_FLANK

## First, we read the MPRA data, preprocess it, and encapsulate it in DataLoader form.

In [4]:
# --- Preprocessing pipelines ---
# Training pipeline: constant flanks, an extra right flank, two crops, tensor
# conversion and a reverse complement applied with parameter 0.5.
train_transform = t.Compose(
    [
        t.AddFlanks(constant_left_flank, constant_rigtht_flank),
        t.AddFlanks("", right_flank),  # original human_legnet parameters
        t.RightCrop(230, 250),  # used for shifting the sequence
        t.CenterCrop(230),
        t.Seq2Tensor(),
        t.ReverseComplement(0.5),
    ]
)

# Evaluation pipeline: constant flanks and tensor conversion only.
# ReverseComplement(0) is presumably a no-op (probability 0) -- TODO confirm.
test_transform = t.Compose(
    [
        t.AddFlanks(constant_left_flank, constant_rigtht_flank),
        t.Seq2Tensor(),
        t.ReverseComplement(0),
    ]
)

# --- HepG2 datasets ---
# `split` takes a named split ("train" / "val" / "test") or a list of fold
# numbers, e.g. [1, 2, 5, 6, 7, 8].
train_dataset = AgarwalDataset(
    cell_type="HepG2",
    split="train",
    transform=train_transform,
    root="../data/",
)

val_dataset = AgarwalDataset(
    cell_type="HepG2",
    split="val",  # default validation set
    transform=test_transform,
    root="../data/",
)

test_dataset = AgarwalDataset(
    cell_type="HepG2",
    split="test",  # default test set
    transform=test_transform,
    root="../data/",
)

In [5]:
# Show the train and test dataset summaries, separated by a dashed rule.
print(train_dataset, "------------", test_dataset, sep="\n")

Dataset AgarwalDataset of size 98336 (MpraDaraset)
    Number of datapoints: 98336
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
------------
Dataset AgarwalDataset of size 12298 (MpraDaraset)
    Number of datapoints: 12298
    Used split fold: [10]


In [6]:
# Wrap the HepG2 datasets in DataLoaders. Only the training loader shuffles;
# validation and test keep the dataset order.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

val_loader = data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [7]:
# Input channels are read off the first training sample; the model has a
# single regression output.
# Assumes each dataset item is (sequence_tensor, target) with channels as the
# leading dimension of the sequence tensor -- TODO confirm.
in_channels, out_channels = len(train_dataset[0][0]), 1

In [8]:
# Build a fresh HumanLegNet regressor for HepG2 and run the package's
# `initialize_weights` over it via `model.apply`.
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model.apply(initialize_weights)

# Lightning wrapper trained with an MSE loss. The learning rate is taken from
# the `lr` constant of the configuration cell rather than a duplicated literal,
# so changing the configuration actually changes the run.
seq_model_HepG2 = LitModel_Agarwal(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=lr,
    print_each=10,
)

## Train LegNet model on HepG2 cell type for 50 epochs

In [9]:
# Initialize a trainer: GPU device 0, 16-bit mixed precision, gradient clipping
# at 1, no progress bar and no sanity-check validation pass.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=NUM_EPOCHS,  # configuration constant instead of a duplicated literal
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=False,
    num_sanity_val_steps=0,
)

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
# Fit on the HepG2 training loader, validating on the validation loader ...
trainer.fit(
    seq_model_HepG2,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
# ... then score the model on the held-out test loader. The metrics list is
# the cell's displayed value.
trainer.test(seq_model_HepG2, dataloaders=test_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290  


-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.36528 | Val Pearson: 0.67847 | Train Pearson: 0.72798 
-------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.31451 | Val Pearson: 0.69521 | Train Pearson: 0.78115 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.33094 | Val Pearson: 0.71588 | Train Pearson: 0.84101 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.25184 | Val Pearson: 0.76680 | Train Pearson: 0.89410 
--------------------------------------------------------------------------------



`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.24526 | Val Pearson: 0.77403 | Train Pearson: 0.93257 
--------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.2384251356124878
      test_pearson           0.783002495765686
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.2384251356124878, 'test_pearson': 0.783002495765686}]

## Mean evaluation: average the forward and reverse-complement predictions

In [11]:
from torchmetrics import PearsonCorrCoef

# Forward-strand pipeline: constant flanks, then tensor conversion.
forw_transform = t.Compose(
    [
        t.AddFlanks(constant_left_flank, constant_rigtht_flank),
        t.Seq2Tensor(),
    ]
)

# Reverse-strand pipeline: same steps followed by ReverseComplement(1)
# (presumably "always reverse-complement" -- TODO confirm the argument is a
# probability).
rev_transform = t.Compose(
    [
        t.AddFlanks(constant_left_flank, constant_rigtht_flank),
        t.Seq2Tensor(),
        t.ReverseComplement(1),
    ]
)

def meaned_prediction(forw, rev, trainer, seq_model, name):
    """Pearson correlation of strand-averaged predictions against the targets.

    Runs ``trainer.predict`` once per dataloader, averages the two prediction
    tensors element-wise, prints a ``"<name> Pearson correlation"`` header and
    returns the correlation with the targets collected from ``forw``.

    Args:
        forw: Dataloader for the first (forward) prediction pass. Its batches
            supply both the predictions and the targets.
        rev: Dataloader for the second (reverse) pass. It must yield the same
            samples in the same order as ``forw``, because the two outputs are
            averaged position by position.
        trainer: Object whose ``predict(model, dataloaders=...)`` returns an
            iterable of per-batch mappings holding a ``"predicted"`` tensor
            (and a ``"target"`` tensor for the ``forw`` pass).
        seq_model: Model forwarded to ``trainer.predict``.
        name: Label prepended to the printed header.

    Returns:
        The value produced by ``PearsonCorrCoef()`` for the averaged
        predictions and the targets.
    """

    def _collect(batches, key):
        # Join the per-batch tensors stored under `key` into a single tensor.
        return torch.cat([batch[key] for batch in batches])

    forward_batches = trainer.predict(seq_model, dataloaders=forw)
    targets = _collect(forward_batches, "target")
    forward_preds = _collect(forward_batches, "predicted")

    reverse_batches = trainer.predict(seq_model, dataloaders=rev)
    reverse_preds = _collect(reverse_batches, "predicted")

    # Element-wise mean of the two passes.
    averaged_preds = torch.stack([forward_preds, reverse_preds]).mean(dim=0)

    metric = PearsonCorrCoef()
    print(name + " Pearson correlation")

    return metric(averaged_preds, targets)

In [12]:
# HepG2 test split, once per strand pipeline.
test_forw = AgarwalDataset(
    cell_type="HepG2", split="test", transform=forw_transform, root="../data/"
)
test_rev = AgarwalDataset(
    cell_type="HepG2", split="test", transform=rev_transform, root="../data/"
)

# Neither loader shuffles, so both yield the samples in dataset order.
forw_hepg2 = data.DataLoader(
    dataset=test_forw,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
rev_hepg2 = data.DataLoader(
    dataset=test_rev,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Pearson correlation of the strand-averaged predictions (cell output).
meaned_prediction(forw_hepg2, rev_hepg2, trainer, seq_model_HepG2, "HepG2")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


HepG2 Pearson correlation


tensor(0.8039)

# K562

In [13]:
# Load the K562 data. `split` takes a named split ("train" / "val" / "test")
# or a list of fold numbers, e.g. [1, 2, 5, 6, 7, 8].
train_dataset = AgarwalDataset(
    cell_type="K562",
    split="train",
    transform=train_transform,
    root="../data/",
)

val_dataset = AgarwalDataset(
    cell_type="K562",
    split="val",  # default validation set
    transform=test_transform,
    root="../data/",
)

test_dataset = AgarwalDataset(
    cell_type="K562",
    split="test",  # default test set
    transform=test_transform,
    root="../data/",
)

In [14]:
# Wrap the K562 datasets in DataLoaders. Only the training loader shuffles;
# validation and test keep the dataset order.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

val_loader = data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [15]:
# Build a fresh HumanLegNet regressor for K562 and run the package's
# `initialize_weights` over it via `model.apply`.
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model.apply(initialize_weights)

# Lightning wrapper trained with an MSE loss. The learning rate is taken from
# the `lr` constant of the configuration cell rather than a duplicated literal,
# so changing the configuration actually changes the run.
seq_model_K562 = LitModel_Agarwal(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=lr,
    print_each=10,
)

In [16]:
# Initialize a new trainer for the K562 run: GPU device 0, 16-bit mixed
# precision, gradient clipping at 1, no progress bar and no sanity-check
# validation pass.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=NUM_EPOCHS,  # configuration constant instead of a duplicated literal
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=False,
    num_sanity_val_steps=0,
)

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
# Fit on the K562 training loader, validating on the validation loader ...
trainer.fit(
    seq_model_K562,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
# ... then score the model on the held-out test loader. The metrics list is
# the cell's displayed value.
trainer.test(seq_model_K562, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
120       Modules in train mode
0         Modules in eval mode



-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.12408 | Val Pearson: 0.75037 | Train Pearson: 0.76056 
-------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.11672 | Val Pearson: 0.76982 | Train Pearson: 0.79418 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.12490 | Val Pearson: 0.77975 | Train Pearson: 0.82812 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.09681 | Val Pearson: 0.81341 | Train Pearson: 0.87800 
--------------------------------------------------------------------------------



`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.09479 | Val Pearson: 0.81651 | Train Pearson: 0.91554 
--------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.0938987210392952
      test_pearson          0.8150753378868103
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.0938987210392952, 'test_pearson': 0.8150753378868103}]

## Mean evaluation: average the forward and reverse-complement predictions

In [18]:
# K562 test split, once per strand pipeline.
test_forw = AgarwalDataset(
    cell_type="K562", split="test", transform=forw_transform, root="../data/"
)
test_rev = AgarwalDataset(
    cell_type="K562", split="test", transform=rev_transform, root="../data/"
)

# Neither loader shuffles, so both yield the samples in dataset order.
forw_k562 = data.DataLoader(
    dataset=test_forw,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
rev_k562 = data.DataLoader(
    dataset=test_rev,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Pearson correlation of the strand-averaged predictions (cell output).
meaned_prediction(forw_k562, rev_k562, trainer, seq_model_K562, "K562")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


K562 Pearson correlation


tensor(0.8294)

# WTC11

In [19]:
# Load the WTC11 data. `split` takes a named split ("train" / "val" / "test")
# or a list of fold numbers, e.g. [1, 2, 5, 6, 7, 8].
train_dataset = AgarwalDataset(
    cell_type="WTC11",
    split="train",
    transform=train_transform,
    root="../data/",
)

val_dataset = AgarwalDataset(
    cell_type="WTC11",
    split="val",  # default validation set
    transform=test_transform,
    root="../data/",
)

test_dataset = AgarwalDataset(
    cell_type="WTC11",
    split="test",  # default test set
    transform=test_transform,
    root="../data/",
)

In [20]:
# Wrap the WTC11 datasets in DataLoaders. Only the training loader shuffles;
# validation and test keep the dataset order.
train_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

val_loader = data.DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

test_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [21]:
# Build a fresh HumanLegNet regressor for WTC11 and run the package's
# `initialize_weights` over it via `model.apply`.
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model.apply(initialize_weights)

# Lightning wrapper trained with an MSE loss. The learning rate is taken from
# the `lr` constant of the configuration cell rather than a duplicated literal,
# so changing the configuration actually changes the run.
seq_model_WTC11 = LitModel_Agarwal(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=lr,
    print_each=10,
)

In [22]:
# Initialize a new trainer for the WTC11 run: GPU device 0, 16-bit mixed
# precision, gradient clipping at 1, no progress bar and no sanity-check
# validation pass.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=NUM_EPOCHS,  # configuration constant instead of a duplicated literal
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=False,
    num_sanity_val_steps=0,
)

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [23]:
# Fit on the WTC11 training loader, validating on the validation loader ...
trainer.fit(
    seq_model_WTC11,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
# ... then score the model on the held-out test loader. The metrics list is
# the cell's displayed value.
trainer.test(seq_model_WTC11, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/home/nios/miniconda3/envs/mpramnist/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (37) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
120       Modules in tr


-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 1.21388 | Val Pearson: 0.60671 | Train Pearson: 0.71072 
-------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.84714 | Val Pearson: 0.63644 | Train Pearson: 0.78602 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.70407 | Val Pearson: 0.67494 | Train Pearson: 0.84055 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.65775 | Val Pearson: 0.69793 | Train Pearson: 0.89084 
--------------------------------------------------------------------------------



`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.64369 | Val Pearson: 0.70745 | Train Pearson: 0.92010 
--------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.6499876976013184
      test_pearson          0.6997539401054382
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.6499876976013184, 'test_pearson': 0.6997539401054382}]

## Mean evaluation: average the forward and reverse-complement predictions

In [24]:
# WTC11 test split, once per strand pipeline.
test_forw = AgarwalDataset(
    cell_type="WTC11", split="test", transform=forw_transform, root="../data/"
)
test_rev = AgarwalDataset(
    cell_type="WTC11", split="test", transform=rev_transform, root="../data/"
)

# Neither loader shuffles, so both yield the samples in dataset order.
forw_wtc11 = data.DataLoader(
    dataset=test_forw,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
rev_wtc11 = data.DataLoader(
    dataset=test_rev,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Pearson correlation of the strand-averaged predictions (cell output).
meaned_prediction(forw_wtc11, rev_wtc11, trainer, seq_model_WTC11, "WTC11")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


WTC11 Pearson correlation


tensor(0.7267)