In [1]:
from tqdm import tqdm
import mpramnist
from mpramnist.vikramdataset import VikramDataset

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [6]:
NUM_EPOCHS = 5
BATCH_SIZE = 64
lr = 0.001

## First, we read the MPRAdata, preprocess them and encapsulate them into dataloader form.

In [7]:
# preprocessing
train_transform = t.Compose([
    t.AddFlanks("GGCCCGCTCTAGACCTGCAGG","CACTAGAGGGTATATAATGGAAGCTCGACTTCCAGCTTGGCAATCCGGTACTGT"),
    t.RandomCrop(230),
    t.Seq2Tensor(),
    t.Reverse(0.5),
    t.AddReverseChannel()

])
test_transform = t.Compose([
    t.Seq2Tensor(),
    t.AddReverseChannel()
])

# load the data
train_dataset = VikramDataset(cell_type = "HepG2", split="train", transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = VikramDataset(cell_type = "HepG2", split="val", transform=test_transform) # use "val" for default validation set
test_dataset = VikramDataset(cell_type = "HepG2", split="test", transform=test_transform) # use "test" for default test set

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
print(train_dataset)
print("===================")
print(test_dataset)

Dataset VikramDataset of size 98336 (MpraDaraset)
    Number of datapoints: 98336
    Default split folds: {'train': '1, 2, 3, 4, 5, 6, 7, 8', 'val': 9, 'test': 10}
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
    Scalar features: {}
    Vector features: {}
    Cell types: ['HepG2', 'K562', 'WTC11']
    Сell type used: HepG2
    Target columns that can be used: {'averaged_expression', 'expression'}
    Number of channels: 5
    Sequence size: 230
    Number of samples: {'train': 98336, 'val': 12292, 'test': 12292}
    Description: The VikramDataset is based on lentiMPRA assay, which determines the regulatory activity of over 680,000 sequences, representing a nearly comprehensive set of all annotated CREs among three cell types (HepG2, K562, and WTC11). HepG2 is a human liver cancer cell line, K562 is myelogenous leukemia cell line, WTC11 is pluripotent stem cell line derived from adult skin 
Dataset VikramDataset of size 12298 (MpraDaraset)
    Number of datapoints: 12298
    Default 

## Then, we define a simple model for illustration, object function and optimizer that we use to classify.

In [14]:
class Net(nn.Module):
    
    def __init__(self, seq_len = 230, in_ch = 4, block_sizes=[16, 24, 32, 40, 48], kernel_size=7):
        
        super().__init__()
        self.seq_len = seq_len
        self.in_ch = in_ch
        out_ch = 64
        nn_blocks = []
      
        for in_bs, out_bs in zip([in_ch] + block_sizes, block_sizes):
            
            block = nn.Sequential(
                nn.Conv1d(in_bs, out_bs, kernel_size=kernel_size, padding=1),
                nn.SiLU(),
                nn.BatchNorm1d(out_bs)
            )
            nn_blocks.append(block)
            
        self.conv_net = nn.Sequential(
            *nn_blocks,
            nn.Flatten(),
            nn.Linear(block_sizes[-1] * (seq_len + len(block_sizes)*(3-kernel_size)), out_ch),
        )
        self.head = nn.Sequential(nn.Linear(out_ch, out_ch),
                                   nn.BatchNorm1d(out_ch),
                                   nn.SiLU(),
                                   nn.Linear(out_ch, 1))

    def forward(self, x):
       
        out = self.conv_net(x)
        out = self.head(out)
        
        return out
        
model = Net(in_ch=len(train_dataset[0][0]), seq_len = len(train_dataset[0][0][0]))
model = model.to("cuda")

# define loss function and optimizer
criterion = nn.MSELoss()

optimizer = torch.optim.AdamW(model.parameters(), 
                                      lr=lr,
                                      weight_decay=lr)

## Now we can start to train

Use "cuda" device to select gpu

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [16]:
from torchmetrics import PearsonCorrCoef
for epoch in range(NUM_EPOCHS):
    
    model.train()
    for inputs, targets in tqdm(train_loader):
        
        # forward + backward + optimize
        optimizer.zero_grad()
    
        outputs = model(inputs.to(device))
        outputs = outputs.view(-1)
 
        targets = targets.squeeze().to(device)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
    model.eval()
    y_loss = 0
    y_val_pearson = 0
    loss_mse = nn.MSELoss() 
    val_pearson_coef = PearsonCorrCoef().to(device)
    
    with torch.no_grad():
        
        for inputs, targets in tqdm(val_loader):
            outputs = model(inputs.to(device))

            outputs = outputs.view(-1)
            targets = targets.squeeze().to(device)
            
            val_pearson_coef.update(outputs, targets)

            y_loss = loss_mse(outputs, targets)
        print('%s  loss: %.3f  val_pearson:%.3f' % ("val", y_loss, val_pearson_coef.compute()))

100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:12<00:00, 21.11it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:06<00:00, 28.10it/s]


val  loss: 0.270  val_pearson:0.494


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:18<00:00, 19.48it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:05<00:00, 38.16it/s]


val  loss: 0.209  val_pearson:0.589


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:08<00:00, 22.44it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:05<00:00, 34.44it/s]


val  loss: 0.030  val_pearson:0.622


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:17<00:00, 19.73it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:05<00:00, 36.15it/s]


val  loss: 0.065  val_pearson:0.637


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:14<00:00, 20.50it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:06<00:00, 30.84it/s]

val  loss: 0.165  val_pearson:0.647





train more epochs for better quality

In [17]:
# evaluation

def test(split):
    model.eval()
    y_loss = 0
    y_val_pearson = 0
    loss_mse = nn.MSELoss() 
    val_pearson_coef = PearsonCorrCoef().to(device)
    if split == "test":
        data_loader = test_loader
    elif split == "val":
        data_loader = val_loader
    else:
        raise Exception(f'Wrong split: {split}')
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))

            outputs = outputs.view(-1)
            targets = targets.squeeze().to(device)
            
            val_pearson_coef.update(outputs, targets)

            loss = loss_mse(outputs, targets)
            y_loss = loss
    
        print('%s  loss: %.3f  val_pearson:%.3f' % (split, y_loss, val_pearson_coef.compute()))

        
print('==> Evaluating ...')
test('val')
test('test')

==> Evaluating ...
val  loss: 0.165  val_pearson:0.647
test  loss: 0.324  val_pearson:0.667
