In [1]:
import mpramnist
from mpramnist.Agarwal.dataset import AgarwalDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import pytorch_lightning as L

In [2]:
NUM_EPOCHS = 50
BATCH_SIZE = 1024
NUM_WORKERS = 8
lr = 0.01

In [3]:
left_flank = AgarwalDataset.LEFT_FLANK # original flanks from human_legnet
right_flank = AgarwalDataset.RIGHT_FLANK

## First, we read the MPRAdata, preprocess them and encapsulate them into dataloader form.

In [4]:
# preprocessing
train_transform = t.Compose([
    t.AddFlanks("", right_flank), # this is original parameters for human_legnet
    t.RightCrop(230,250), # this is using for shifting
    t.CenterCrop(230),
    t.Seq2Tensor(),
    t.ReverseComplement(0.5),

])
test_transform = t.Compose([ # трансформы теста слегка другие
    t.Seq2Tensor(), 
    t.ReverseComplement(0), # обратно комплементарная трансформация для всех последовательностей без веротяности 0.5

])

# load the data
train_dataset = AgarwalDataset(cell_type = "HepG2", 
                              split="train", 
                              transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = AgarwalDataset(cell_type = "HepG2", 
                            split="val", 
                            transform=test_transform) # use "val" for default validation set or use list
test_dataset = AgarwalDataset(cell_type = "HepG2", 
                             split="test", 
                             transform=test_transform) # use "test" for default test set or use list

In [5]:
print(train_dataset)
print("------------")
print(test_dataset)

Dataset AgarwalDataset of size 98336 (MpraDaraset)
    Number of datapoints: 98336
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
------------
Dataset AgarwalDataset of size 12298 (MpraDaraset)
    Number of datapoints: 12298
    Used split fold: [10]


In [15]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [6]:
in_channels = len(train_dataset[0][0])
out_channels = 1

In [17]:
model = HumanLegNet(in_ch=in_channels,
                     output_dim = out_channels,
                     stem_ch=64,
                     stem_ks=11,
                     ef_ks=9,
                     ef_block_sizes=[80, 96, 112, 128],
                     pool_sizes=[2,2,2,2],
                     resize_factor=4)
model.apply(initialize_weights)

seq_model = LitModel(model = model,
                           loss = nn.MSELoss(),
                           weight_decay = 1e-1, lr = 1e-2, print_each = 5)

## Train LegNet model on HepG2 cell type for 50 epochs

In [18]:
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=30,
    gradient_clip_val=1,
    precision='16-mixed', 
    enable_progress_bar = False,
    num_sanity_val_steps=0
)

# Train the model
trainer.fit(seq_model,
            train_dataloaders = train_loader,
            val_dataloaders = val_loader)
trainer.test(seq_model, dataloaders = test_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
120       Modules in train mode
0         Modules in eval mode



-------------------------------------------------------------------------------
| Epoch: 4 | Val Loss: 0.52628 | Val Pearson: 0.62366 | Train Pearson: 0.67534 
-------------------------------------------------------------------------------


-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.37344 | Val Pearson: 0.66087 | Train Pearson: 0.73332 
-------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 14 | Val Loss: 0.34909 | Val Pearson: 0.65818 | Train Pearson: 0.76906 
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.30849 | Val Pearson: 0.72334 | Train Pearson: 0.81470 
--------------------------------------------------------------------------------


-------------------------

`Trainer.fit` stopped: `max_epochs=30` reached.



--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.25519 | Val Pearson: 0.76679 | Train Pearson: 0.90628 
--------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


[{'test_loss': 0.24528789520263672, 'test_pearson': 0.7782077789306641}]

## Lets train now for K562 cell type for 50 epochs

In [14]:
# load the data
train_dataset = AgarwalDataset(cell_type = "K562", split="train", transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = AgarwalDataset(cell_type = "K562", split="val", transform=test_transform) # use "val" for default validation set or use list

test_dataset = AgarwalDataset(cell_type = "K562", split="test", transform=test_transform) # use "test" for default test set or use list

In [None]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)