In [1]:
import mpramnist
from mpramnist.vikramdataset import VikramDataset

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import logging
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import TQDMProgressBar

In [2]:
# Global hyper-parameters for the notebook.
NUM_EPOCHS = 50  # intended number of training epochs
BATCH_SIZE = 1024  # batch size used by every DataLoader built below
NUM_WORKERS = 8  # worker processes per DataLoader
lr = 0.01  # learning rate handed to SeqModel (optimizer lr and OneCycleLR max_lr)

In [3]:
# Constant flanking sequences exposed as class attributes of VikramDataset.
left_flank = VikramDataset.LEFT_FLANK # original flanks from human_legnet
right_flank = VikramDataset.RIGHT_FLANK

## First, we read the MPRA data, preprocess it, and wrap it in dataloaders.

In [7]:
# preprocessing
# Training pipeline: the transforms are composed in the order listed.
train_transform = t.Compose([
    t.AddFlanks("", right_flank), # these are the original parameters for human_legnet
    t.RightCrop(230,250), # this is used for shifting
    t.CenterCrop(230),
    t.Seq2Tensor(),
    t.Reverse(0.5),

])
test_transform = t.Compose([ # the test transforms are slightly different
    t.Seq2Tensor(), 
    t.Reverse(0), # reverse-complement transform for all sequences, without the 0.5 probability
                  # NOTE(review): the argument 0 presumably is the probability, i.e. nothing is reversed - confirm against t.Reverse

])

# load the data
# The same cell type is used for the train/val/test splits; only the split and transform differ.
train_dataset = VikramDataset(cell_type = "HepG2", 
                              split="train", 
                              transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = VikramDataset(cell_type = "HepG2", 
                            split="val", 
                            transform=test_transform) # use "val" for default validation set or use list
test_dataset = VikramDataset(cell_type = "HepG2", 
                             split="test", 
                             transform=test_transform) # use "test" for default test set or use list

In [8]:
# Show the train and test dataset summaries, separated by a dashed line.
print(train_dataset, "------------", test_dataset, sep="\n")

Dataset VikramDataset of size 36946 (MpraDaraset)
    Number of datapoints: 36946
    Default split folds: {'train': '1, 2, 3, 4, 5, 6, 7, 8', 'val': 9, 'test': 10}
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
    Scalar features: {}
    Vector features: {}
    Cell types: ['HepG2', 'K562', 'WTC11']
    Сell type used: WTC11
    Target columns that can be used: {'averaged_expression', 'expression'}
    Sequence size: 230
    Number of samples: {'train': 98336, 'val': 12292, 'test': 12292}
------------
Dataset VikramDataset of size 4622 (MpraDaraset)
    Number of datapoints: 4622
    Default split folds: {'train': '1, 2, 3, 4, 5, 6, 7, 8', 'val': 9, 'test': 10}
    Used split fold: [10]
    Scalar features: {}
    Vector features: {}
    Cell types: ['HepG2', 'K562', 'WTC11']
    Сell type used: WTC11
    Target columns that can be used: {'averaged_expression', 'expression'}
    Sequence size: 230
    Number of samples: {'train': 98336, 'val': 12292, 'test': 12292}


In [None]:
# Input channel count: the length of the first element of the first training sample
# (presumably the encoded-sequence tensor with channels on its first axis - TODO confirm).
in_channels = len(train_dataset[0][0])
out_channels = 1  # passed to SeqModel as out_ch (the network's output_dim)

In [6]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def initialize_weights(m):
    """Re-initialise the parameters of one module in place.

    Meant to be handed to ``nn.Module.apply`` so it visits every sub-module.

    * ``nn.Conv1d``: weights ~ N(0, sqrt(2 / (kernel_size * out_channels))),
      bias (if any) set to 0.
    * ``nn.BatchNorm1d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights ~ N(0, 0.001), bias (if any) set to 0.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.BatchNorm1d):
        m.weight.data.fill_(1)
        m.bias.data.fill_(0)
        return

    if isinstance(m, nn.Conv1d):
        fan = m.kernel_size[0] * m.out_channels
        std = math.sqrt(2 / fan)
    elif isinstance(m, nn.Linear):
        std = 0.001
    else:
        return

    # Conv1d and Linear share the same recipe, differing only in the std.
    nn.init.normal_(m.weight.data, mean=0, std=std)
    if m.bias is not None:
        m.bias.data.fill_(0)

class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate for 3-D inputs.

    The input is averaged over its last axis, passed through a two-layer
    bottleneck MLP (``inp -> inp // reduction -> inp``) ending in a sigmoid,
    and the resulting per-channel factor rescales the input.
    """

    def __init__(self, inp, reduction=4):
        super().__init__()
        hidden = int(inp // reduction)
        self.fc = nn.Sequential(
            nn.Linear(inp, hidden),
            nn.SiLU(),
            nn.Linear(hidden, inp),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # The three-way unpacking requires a 3-D input.
        batch, channels, _ = x.size()
        squeezed = x.view(batch, channels, -1).mean(dim=2)
        gate = self.fc(squeezed).view(batch, channels, 1)
        return x * gate

class EffBlock(nn.Module):
    """Inverted-bottleneck 1-D convolutional block with a channel gate.

    Layout: 1x1 expansion conv -> depthwise conv (kernel ``ks``) -> SELayer
    -> 1x1 projection conv; every conv is followed by BatchNorm1d and the
    activation. All convolutions use ``padding='same'`` and no bias.

    Args:
        in_ch: number of input channels.
        ks: kernel size of the depthwise convolution.
        resize_factor: channel expansion factor; the inner width is
            ``in_ch * resize_factor``.
        activation: zero-argument callable returning an activation module
            (e.g. ``nn.SiLU``).
        out_ch: number of output channels; defaults to ``in_ch``.
        se_reduction: reduction factor for the SELayer; defaults to
            ``resize_factor``.
    """

    def __init__(self, in_ch, ks, resize_factor, activation, out_ch=None, se_reduction=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = self.in_ch if out_ch is None else out_ch
        self.resize_factor = resize_factor
        self.se_reduction = resize_factor if se_reduction is None else se_reduction
        self.ks = ks
        self.inner_dim = self.in_ch * self.resize_factor

        block = nn.Sequential(
                        # 1x1 expansion: in_ch -> inner_dim
                        nn.Conv1d(
                            in_channels=self.in_ch,
                            out_channels=self.inner_dim,
                            kernel_size=1,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(self.inner_dim),
                       activation(),

                       # depthwise convolution (one group per channel)
                       nn.Conv1d(
                            in_channels=self.inner_dim,
                            out_channels=self.inner_dim,
                            kernel_size=ks,
                            groups=self.inner_dim,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(self.inner_dim),
                       activation(),
                       SELayer(self.inner_dim, reduction=self.se_reduction),
                       # 1x1 projection: inner_dim -> out_ch. The width used to
                       # be hard-wired to in_ch, which silently ignored the
                       # out_ch argument; the two agree whenever out_ch is
                       # left at its default or equals in_ch.
                       nn.Conv1d(
                            in_channels=self.inner_dim,
                            out_channels=self.out_ch,
                            kernel_size=1,
                            padding='same',
                            bias=False
                       ),
                       nn.BatchNorm1d(self.out_ch),
                       activation(),
        )

        self.block = block

    def forward(self, x):
        """Apply the block to ``x`` and return the result."""
        return self.block(x)

class LocalBlock(nn.Module):
    """Plain convolutional unit: Conv1d -> BatchNorm1d -> activation.

    The convolution uses ``padding='same'`` and has no bias term.

    Args:
        in_ch: number of input channels.
        ks: convolution kernel size.
        activation: zero-argument callable returning an activation module.
        out_ch: number of output channels; defaults to ``in_ch``.
    """

    def __init__(self, in_ch, ks, activation, out_ch=None):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = in_ch if out_ch is None else out_ch
        self.ks = ks

        conv = nn.Conv1d(self.in_ch, self.out_ch, self.ks, padding='same', bias=False)
        norm = nn.BatchNorm1d(self.out_ch)
        self.block = nn.Sequential(conv, norm, activation())

    def forward(self, x):
        return self.block(x)

class ResidualConcat(nn.Module):
    """Wrap ``fn`` so its output is concatenated with its input along dim 1.

    The transformed tensor comes first, followed by the untouched input.
    Extra keyword arguments given to ``forward`` are forwarded to ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        transformed = self.fn(x, **kwargs)
        return torch.cat((transformed, x), dim=1)

class MapperBlock(nn.Module):
    """BatchNorm1d followed by a 1x1 convolution that changes channel count.

    Args:
        in_features: number of input channels.
        out_features: number of output channels.
        activation: accepted for signature compatibility; this block does
            not use it.
    """

    def __init__(self, in_features, out_features, activation=nn.SiLU):
        super().__init__()
        norm = nn.BatchNorm1d(in_features)
        pointwise = nn.Conv1d(in_features, out_features, kernel_size=1)
        self.block = nn.Sequential(norm, pointwise)

    def forward(self, x):
        return self.block(x)

class HumanLegNet(nn.Module):
    """Convolutional sequence model assembled from the blocks defined above.

    Pipeline: stem ``LocalBlock`` -> one stage per entry of
    ``ef_block_sizes`` -> ``MapperBlock`` -> global average pooling -> MLP
    head. Each stage is ``ResidualConcat(EffBlock)`` followed by a
    ``LocalBlock`` that maps the doubled channel count to the stage width,
    then a ``MaxPool1d`` (or ``Identity`` when the pool size is 1).

    Args:
        in_ch: number of input channels.
        output_dim: size of the final linear layer's output.
        stem_ch: channels produced by the stem.
        stem_ks: kernel size of the stem convolution.
        ef_ks: kernel size used inside every stage.
        ef_block_sizes: output channels of each stage.
        pool_sizes: max-pool size after each stage; must have the same
            length as ``ef_block_sizes``.
        resize_factor: expansion factor passed to every ``EffBlock``.
        activation: zero-argument callable returning an activation module.
    """

    def __init__(self,
                 in_ch,
                 output_dim,
                 stem_ch,
                 stem_ks,
                 ef_ks,
                 ef_block_sizes,
                 pool_sizes,
                 resize_factor,
                 activation=nn.SiLU,
                 ):
        super().__init__()
        assert len(pool_sizes) == len(ef_block_sizes)

        self.in_ch = in_ch
        self.stem = LocalBlock(in_ch=in_ch,
                               out_ch=stem_ch,
                               ks=stem_ks,
                               activation=activation)
        self.output_dim = output_dim

        stages = []
        width = stem_ch  # channel count entering the next stage
        for pool_sz, stage_ch in zip(pool_sizes, ef_block_sizes):
            residual = ResidualConcat(
                EffBlock(in_ch=width,
                         out_ch=width,
                         ks=ef_ks,
                         resize_factor=resize_factor,
                         activation=activation)
            )
            # The concat doubles the channels, hence the `width * 2` input.
            fuse = LocalBlock(in_ch=width * 2,
                              out_ch=stage_ch,
                              ks=ef_ks,
                              activation=activation)
            pool = nn.Identity() if pool_sz == 1 else nn.MaxPool1d(pool_sz)
            stages.append(nn.Sequential(residual, fuse, pool))
            width = stage_ch
        self.main = nn.Sequential(*stages)

        # `width` is the last stage's channel count (or stem_ch if no stages).
        head_dim = width * 2
        self.mapper = MapperBlock(in_features=width,
                                  out_features=head_dim)
        self.head = nn.Sequential(nn.Linear(head_dim, head_dim),
                                  nn.BatchNorm1d(head_dim),
                                  activation(),
                                  nn.Linear(head_dim, self.output_dim))

    def forward(self, x):
        features = self.mapper(self.main(self.stem(x)))
        # Collapse the length axis to a single value per channel.
        pooled = F.adaptive_avg_pool1d(features, 1).squeeze(-1)
        # The final squeeze drops the last axis only when output_dim == 1.
        return self.head(pooled).squeeze(-1)

In [7]:
import pytorch_lightning as L
from torchmetrics import MeanSquaredError
from torchmetrics import PearsonCorrCoef

class SeqModel(L.LightningModule):
    """LightningModule that trains a HumanLegNet regressor with MSE loss.

    Pearson correlation is tracked separately for train, validation and test.
    The dataloader hooks read the module-level ``train_dataset``,
    ``val_dataset``, ``test_dataset``, ``BATCH_SIZE`` and ``NUM_WORKERS`` at
    call time, so rebinding those globals changes what the next
    ``Trainer.fit`` / ``Trainer.test`` call uses.

    Args:
        in_ch: number of input channels for the network.
        out_ch: network output dimension.
        print_each: print a metric summary every ``print_each`` epochs.
        weight_decay: AdamW weight decay.
        lr: AdamW learning rate, also used as the OneCycleLR ``max_lr``.
    """

    def __init__(self, in_ch, out_ch, print_each = 1, weight_decay = 1e-2, lr=3e-4):
        super().__init__()

        self.model = HumanLegNet(in_ch=in_ch,
                                 output_dim = out_ch,
                                 stem_ch=64,
                                 stem_ks=11,
                                 ef_ks=9,
                                 ef_block_sizes=[80, 96, 112, 128],
                                 pool_sizes=[2,2,2,2],
                                 resize_factor=4)
        self.model.apply(initialize_weights)
        self.loss = nn.MSELoss()
        self.print_each = print_each
        self.weight_decay = weight_decay

        self.lr = lr
        self.train_pearson = PearsonCorrCoef()
        self.val_pearson = PearsonCorrCoef()
        self.test_pearson = PearsonCorrCoef()

    def forward(self, x):
        """Return the network's predictions for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_nb):
        """Compute and log the MSE loss for one training batch."""
        X, y = batch
        y_hat = self.model(X)
        loss = self.loss(y_hat, y)

        self.log("train_loss", loss, prog_bar=True,  on_step=False, on_epoch=True, logger = True)
        self.train_pearson.update(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the MSE loss for one validation batch and update the metric."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss(y_hat, y)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.val_pearson.update(y_hat, y)

    def on_validation_epoch_end(self):
        """Log train/val Pearson, optionally print a summary, reset metrics."""
        if self.trainer.sanity_checking:
            # The sanity-check pass runs before any training batch, so there
            # are no train statistics to compute; drop what it accumulated.
            self.val_pearson.reset()
            return

        train_pearson = self.train_pearson.compute()
        val_pearson = self.val_pearson.compute()

        self.log("val_pearson", val_pearson, prog_bar=True)
        self.log("train_pearson", train_pearson)

        if (self.current_epoch + 1) % self.print_each == 0:
            metrics = self.trainer.callback_metrics
            missing = float("nan")
            val_loss = metrics.get("val_loss", missing)
            # The epoch-level train loss may not be published yet when this
            # hook first fires, so fall back to NaN rather than raising
            # KeyError.
            train_loss = metrics.get("train_loss", missing)

            res_str = f"| Epoch: {self.current_epoch} "
            res_str += f"| Val Loss: {val_loss:.5f} "
            res_str += f"| Val Pearson: {val_pearson:.5f} "

            res_str += f"| Train Loss: {train_loss:.5f} "
            res_str += f"| Train Pearson: {train_pearson:.5f} "
            border = '-'*len(res_str)
            print("\n".join(['', border, res_str, border, '']))

        self.train_pearson.reset()
        self.val_pearson.reset()

    def test_step(self, batch, _):
        """Log the MSE loss for one test batch and update the metric."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss(y_hat, y)

        self.log('test_loss',
                 loss,
                 prog_bar=True,
                 on_step=False,
                 on_epoch=True)
        self.test_pearson.update(y_hat, y)

    def on_test_epoch_end(self):
        """Log the test Pearson correlation and reset the metric."""
        test_pearson = self.test_pearson.compute()
        self.log('test_pearson', test_pearson, prog_bar=True)
        self.test_pearson.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return targets and predictions for a batch as float CPU tensors."""
        x, y = batch
        y_hat = self.model(x)

        return {"y": y.cpu().detach().float(), "pred": y_hat.cpu().detach().float()}

    def train_dataloader(self):
        """Shuffled loader over the module-level ``train_dataset``."""
        return data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS, pin_memory=True)

    def val_dataloader(self):
        """Unshuffled loader over the module-level ``val_dataset``."""
        return data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS, pin_memory=True)

    def test_dataloader(self):
        """Unshuffled loader over the module-level ``test_dataset``."""
        return data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS, pin_memory=True)

    def configure_optimizers(self):
        """Build AdamW plus a per-step OneCycleLR schedule."""

        self.optimizer = torch.optim.AdamW(self.parameters(),
                                               lr=self.lr,
                                               weight_decay = self.weight_decay)

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer,
                                                        max_lr=self.lr,
                                                        three_phase=False,
                                                        total_steps=self.trainer.estimated_stepping_batches, # type: ignore
                                                        pct_start=0.3,
                                                        cycle_momentum =False)
        lr_scheduler_config = {
                    "scheduler": lr_scheduler,
                    "interval": "step",
                    "frequency": 1,
                    "name": "cycle_lr"
            }

        return [self.optimizer], [lr_scheduler_config]

## Train LegNet model on HepG2 cell type for 50 epochs

In [8]:
# Fresh model for this run; the dataloaders come from the datasets currently bound
# to train_dataset / val_dataset / test_dataset.
seq_model = SeqModel(in_ch = in_channels, out_ch = out_channels,  weight_decay = 1e-4, lr = lr, print_each = 5)

# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=NUM_EPOCHS,  # was hard-coded to 30, contradicting NUM_EPOCHS and the "50 epochs" heading
    gradient_clip_val=1,
    precision='16-mixed', 
    enable_progress_bar = False,
    num_sanity_val_steps=0
)

# Train the model
trainer.fit(seq_model)
trainer.test(seq_model)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-01-15 18:52:12.647955: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-15 18:52:12.662410: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one 

[{'test_loss': 0.23584365844726562, 'test_pearson': 0.7789967060089111}]

## Now let's train on the K562 cell type for 50 epochs

In [14]:
# load the data
# Rebinds the module-level dataset names to the K562 cell type; SeqModel's dataloader
# hooks read these names, so the next training run uses K562.
train_dataset = VikramDataset(cell_type = "K562", split="train", transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = VikramDataset(cell_type = "K562", split="val", transform=test_transform) # use "val" for default validation set or use list
test_dataset = VikramDataset(cell_type = "K562", split="test", transform=test_transform) # use "test" for default test set or use list

# encapsulate data into dataloader form
# NOTE(review): these loaders are not passed to the Trainer in the cells shown here;
# SeqModel builds its own loaders from the datasets above - confirm whether they are still needed.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [15]:
# Fresh model for the K562 run; the dataloaders come from the datasets currently bound
# to train_dataset / val_dataset / test_dataset.
seq_model = SeqModel(in_ch = in_channels, out_ch = out_channels,  weight_decay = 1e-4, lr = lr, print_each = 5)

# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=NUM_EPOCHS,  # was hard-coded to 30; the recorded log reports `max_epochs=50`
    gradient_clip_val=1,
    precision='16-mixed', 
    enable_progress_bar = False,
    num_sanity_val_steps=0
)

# Train the model
trainer.fit(seq_model)

trainer.test(seq_model)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name        | Type            | Params | Mode 
--------------------------------------------------------
0 | model       | HumanLegNet     | 1.3 M  | train
1 | loss        | MSELoss         | 0      | train
2 | val_pearson | PearsonCorrCoef | 0      | train
--------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
118       Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=50` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.09412148594856262
      test_pearson          0.7980344891548157
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.09412148594856262, 'test_pearson': 0.7980344891548157}]