In [1]:
from tqdm import tqdm
from dataset import VikramDataset
import transforms as t

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [2]:
NUM_EPOCHS = 3
BATCH_SIZE = 64
lr = 0.001

## First, we read the MPRAdata, preprocess them and encapsulate them into dataloader form.

In [3]:
# preprocessing
train_transform = t.Compose([
    t.AddFlanks("",""),
    t.LeftCrop(230,230),
    t.Seq2Tensor(),
    t.Reverse(),
    t.AddReverseChannel()

])
test_transform = t.Compose([
    t.AddFlanks("",""),
    t.LeftCrop(230,230),
    t.Seq2Tensor(),
    t.AddReverseChannel()
])

# load the data
train_dataset = VikramDataset(cell_type = "HepG2", split="train", transform=train_transform) # could use a list e.g. [1,2,5,6,7,8] 
                                                                                            # for needed folds
val_dataset = VikramDataset(cell_type = "HepG2", split="val", transform=test_transform) # use "val" for default validation set
test_dataset = VikramDataset(cell_type = "HepG2", split="test", transform=test_transform) # use "test" for default test set

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
print(train_dataset)
print("===================")
print(test_dataset) # must to realize __repr__

Dataset VikramDataset of size 98336 (MpraDaraset)
    Number of datapoints: 98336
    Split: train
    Sequence size: 230
    Number of channels: 5
Dataset VikramDataset of size 12298 (MpraDaraset)
    Number of datapoints: 12298
    Split: test
    Sequence size: 230
    Number of channels: 5


## Then, we define a simple model for illustration, object function and optimizer that we use to classify.

In [5]:
class Net(nn.Module):
    
    def __init__(self, seq_len = 230, in_ch = 4, block_sizes=[16, 24, 32, 40, 48], kernel_size=7):
        
        super().__init__()
        self.seq_len = seq_len
        self.in_ch = in_ch
        out_ch = 64
        nn_blocks = []
      
        for in_bs, out_bs in zip([in_ch] + block_sizes, block_sizes):
            
            block = nn.Sequential(
                nn.Conv1d(in_bs, out_bs, kernel_size=kernel_size, padding=1),
                nn.SiLU(),
                nn.BatchNorm1d(out_bs)
            )
            nn_blocks.append(block)
            
        self.conv_net = nn.Sequential(
            *nn_blocks,
            nn.Flatten(),
            nn.Linear(block_sizes[-1] * (seq_len + len(block_sizes)*(3-kernel_size)), out_ch),
        )
        self.head = nn.Sequential(nn.Linear(out_ch, out_ch),
                                   nn.BatchNorm1d(out_ch),
                                   nn.SiLU(),
                                   nn.Linear(out_ch, 1))

    def forward(self, x):
       
        out = self.conv_net(x)
        out = self.head(out)
        
        return out
        
model = Net(in_ch=len(train_dataset[0][0]), seq_len = len(train_dataset[0][0][0]))
model = model.to("cuda")

# define loss function and optimizer
criterion = nn.MSELoss()

optimizer = torch.optim.AdamW(model.parameters(), 
                                      lr=lr,
                                      weight_decay=lr)

## Now we can start to train

Use "cuda" device to select gpu

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [7]:
from torchmetrics import PearsonCorrCoef
for epoch in range(NUM_EPOCHS):
    
    model.train()
    for inputs, targets in tqdm(train_loader):
        
        # forward + backward + optimize
        optimizer.zero_grad()
    
        outputs = model(inputs.to(device))
        outputs = outputs.view(-1)
 
        targets = targets.squeeze().to(device)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
    model.eval()
    y_loss = 0
    y_val_pearson = 0
    loss_mse = nn.MSELoss() 
    val_pearson_coef = PearsonCorrCoef().to(device)
    
    with torch.no_grad():
        
        for inputs, targets in tqdm(val_loader):
            outputs = model(inputs.to(device))

            outputs = outputs.view(-1)
            targets = targets.squeeze().to(device)
            
            val_pearson = val_pearson_coef(outputs, targets)
            y_val_pearson = val_pearson

            y_loss = loss_mse(outputs, targets)
        print('%s  loss: %.3f  val_pearson:%.3f' % ("val", y_loss, y_val_pearson))


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:32<00:00, 16.55it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:09<00:00, 20.21it/s]


val  loss: 0.227  val_pearson:0.701


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:31<00:00, 16.79it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:09<00:00, 19.43it/s]


val  loss: 0.204  val_pearson:0.094


100%|███████████████████████████████████████████████████████████████████████████████| 1537/1537 [01:31<00:00, 16.72it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 193/193 [00:08<00:00, 21.52it/s]

val  loss: 0.166  val_pearson:0.777





In [8]:
# evaluation

def test(split):
    model.eval()
    y_loss = 0
    y_val_pearson = 0
    loss_mse = nn.MSELoss() 
    val_pearson_coef = PearsonCorrCoef().to(device)
    if split == "test":
        data_loader = test_loader
    elif split == "val":
        data_loader = val_loader
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))

            outputs = outputs.view(-1)
            targets = targets.squeeze().to(device)
            
            val_pearson = val_pearson_coef(outputs, targets)
            y_val_pearson = val_pearson

            loss = loss_mse(outputs, targets)
            y_loss = loss
    
        print('%s  loss: %.3f  val_pearson:%.3f' % (split, y_loss, y_val_pearson))

        
print('==> Evaluating ...')
test("val")
test('test')


==> Evaluating ...
val  loss: 0.166  val_pearson:0.777
test  loss: 0.464  val_pearson:0.783
