In [1]:
import mpramnist
from mpramnist.Malinois.dataset import MalinoisDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights

from mpramnist.trainers import LitModel_Malinois

from mpramnist import transforms as t
from mpramnist import target_transforms as t_t

import numpy as np

import torch
import torch.nn as nn
import torch.utils.data as data
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch import loggers as pl_loggers
from torchmetrics import PearsonCorrCoef

## Now, let’s test the hypothesis that HumanLegNet will perform on Malinois data at least as well as the BassetBranched model.

We’ll run the model to predict activity for each cell line **individually**. 

**Reverse-complement augmentation** will be applied to each sequence with a probability of 0.5 using the corresponding transform. 

We’ll also use the **original parameters** (as employed by the authors, `filtration = "original"`) and duplicate sequences with activity above 0.5 (`duplication_cutoff = 0.5`). 

Finally, we’ll average the model’s predictions across the forward and reverse-complement strands (see the `meaned_prediction` function below).

# Define some parameters and required functions

In [2]:
# Training configuration shared by the per-cell-line runs below.
BATCH_SIZE = 1076  # batch size passed to every DataLoader in this notebook
NUM_WORKERS = 103  # DataLoader worker processes; NOTE(review): tune to the CPU cores actually available
ACTIVITY_COLUMNS = ['K562', 'HepG2', 'SKNSH']  # NOTE(review): not referenced by the cells visible here; each section sets its own `activity_columns`
LR = 0.01  # learning rate handed to LitModel_Malinois (`lr = LR`)
WD = 0.1  # weight decay handed to LitModel_Malinois (`weight_decay = WD`)

In [3]:
class L1KLmixed(nn.Module):
    """Weighted blend of an L1 loss and a KL-divergence loss.

    The L1 term compares raw predictions with raw targets. The KL term compares
    the two after each has been normalised along the last dimension by
    subtracting its log-sum-exp (i.e. turned into log-probabilities).

    The result is ``(alpha * L1 + beta * KL) / (alpha + beta)``.

    Notes:
        * The attribute is called ``MSE`` for historical reasons but holds an
          ``nn.L1Loss``; the name is kept so existing references still work.
        * ``nn.L1Loss`` has no ``'batchmean'`` mode, so the substring ``'batch'``
          is stripped from ``reduction`` before it is passed to the L1 term.
        * When the last dimension has size 1, the log-sum-exp equals the single
          element, both log-probabilities are 0 and the KL term is exactly 0;
          the loss then reduces to ``alpha * L1 / (alpha + beta)``.

    Args:
        reduction: reduction mode forwarded to both loss terms.
        alpha: weight of the L1 term.
        beta: weight of the KL term.
    """

    def __init__(self, reduction='mean', alpha=1.0, beta=5.0):
        super().__init__()

        self.reduction = reduction
        self.alpha = alpha
        self.beta = beta

        l1_reduction = reduction.replace('batch', '')
        self.MSE = nn.L1Loss(reduction=l1_reduction)
        self.KL = nn.KLDivLoss(reduction=reduction, log_target=True)

    def forward(self, preds, targets):
        """Return the alpha/beta-weighted mixture of the L1 and KL terms."""
        # Log-probabilities along the last dimension for the KL term.
        log_p = preds - preds.logsumexp(dim=-1, keepdim=True)
        log_q = targets - targets.logsumexp(dim=-1, keepdim=True)

        l1_term = self.MSE(preds, targets)
        kl_term = self.KL(log_p, log_q)

        weighted_sum = l1_term.mul(self.alpha) + kl_term.mul(self.beta)
        return weighted_sum.div(self.alpha + self.beta)

In [4]:
# NOTE(review): `left_flank` / `right_flank` are only referenced by the commented-out
# `t.AddFlanks(...)` lines below, so they are unused in the cells visible here.
left_flank = MalinoisDataset.LEFT_FLANK
right_flank = MalinoisDataset.RIGHT_FLANK
# preprocessing
# Training pipeline: cropping steps, conversion to a tensor, then random reverse-complementing.
train_transform = t.Compose([
    # Per the original note, `filtration = "original"` (set on the datasets) automatically adds flanks
    # and pads sequences to length 600, making t.AddFlanks(left_flank, right_flank) unnecessary
    #t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(260),
    t.RightCrop(230,260), # used for shifting; presumably yields a randomly placed window - confirm against mpramnist.transforms
    t.LeftCrop(230, 230),
    t.Seq2Tensor(),
    t.ReverseComplement(0.5), # reverse-complement augmentation with probability 0.5
])

# Validation/test pipeline: a single 230 center crop and no strand augmentation.
val_test_transform = t.Compose([ 
    # Per the original note, `filtration = "original"` (set on the datasets) automatically adds flanks
    # and pads sequences to length 600, making t.AddFlanks(left_flank, right_flank) unnecessary
    #t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(230),
    t.Seq2Tensor(), 
    t.ReverseComplement(0), # probability 0, i.e. the strand is presumably never flipped - TODO confirm
])

In [5]:
# Forward-strand evaluation pipeline. It lists the same steps as `val_test_transform`.
forw_transform = t.Compose([ 
    # Per the original note, `filtration = "original"` (set on the datasets) automatically adds flanks
    # and pads sequences to length 600, making t.AddFlanks(left_flank, right_flank) unnecessary
    #t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(230),
    t.Seq2Tensor(), 
    t.ReverseComplement(0), # probability 0: presumably keeps the forward strand - TODO confirm
])
# Reverse-strand evaluation pipeline: identical except for the reverse-complement probability.
rev_transform = t.Compose([ 
    # Per the original note, `filtration = "original"` (set on the datasets) automatically adds flanks
    # and pads sequences to length 600, making t.AddFlanks(left_flank, right_flank) unnecessary
    #t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(230),
    t.Seq2Tensor(), 
    t.ReverseComplement(1), # probability 1: presumably always reverse-complements - TODO confirm
])

def meaned_prediction(forw, rev, trainer, seq_model, num_outputs, name):
    """Pearson correlation between strand-averaged predictions and the targets.

    Runs ``trainer.predict`` on the forward loader and then on the reverse
    loader, averages the two sets of predictions element-wise and correlates
    the average with the targets.

    Targets are taken from the forward loader only, so both loaders are
    assumed to yield the same samples in the same order.

    Args:
        forw: dataloader for the forward-strand view of the data.
        rev: dataloader for the reverse-complement view of the same data.
        trainer: object whose ``predict`` returns a list of dicts holding
            ``"predicted"`` and ``"target"`` tensors.
        seq_model: model handed to ``trainer.predict``.
        num_outputs: forwarded to ``PearsonCorrCoef``.
        name: label used in the printed header line.

    Returns:
        The value returned by ``PearsonCorrCoef`` for the averaged predictions.
    """

    def gather(batches, key):
        # Concatenate one field of the per-batch prediction dicts.
        return torch.cat([batch[key] for batch in batches])

    forward_batches = trainer.predict(seq_model, dataloaders = forw)
    targets = gather(forward_batches, "target")
    forward_preds = gather(forward_batches, "predicted")

    reverse_batches = trainer.predict(seq_model, dataloaders = rev)
    reverse_preds = gather(reverse_batches, "predicted")

    # Element-wise mean over the two strands.
    strand_mean = torch.stack([forward_preds, reverse_preds]).mean(dim=0)

    pearson = PearsonCorrCoef(num_outputs = num_outputs)
    print(name + " Pearson correlation")

    return pearson(strand_mean, targets)

# K562

For the **K562** cell line, the authors reported a Pearson correlation coefficient of **0.88**

In [6]:
# load the data
# K562 only: a single activity column and its matching standard-error column.
# NOTE(review): the column literals below repeat `activity_columns`; keep them in sync.
activity_columns = ['K562']
# Only the training split passes `duplication_cutoff`; the notebook text describes it as
# duplicating sequences with activity above 0.5.
train_dataset = MalinoisDataset(split = "train",  transform = train_transform, duplication_cutoff = 0.5,
                                filtration = "original", activity_columns = ['K562'], 
                                stderr_columns = ['K562_lfcSE'], root = "../data/") 

# Validation and test splits share the non-augmenting `val_test_transform`.
val_dataset = MalinoisDataset(split = "val", transform = val_test_transform, 
                              filtration = "original", activity_columns = ['K562'], 
                              stderr_columns = ['K562_lfcSE'], root = "../data/") 

test_dataset = MalinoisDataset(split = "test", transform = val_test_transform, 
                               filtration = "original", activity_columns = ['K562'], 
                               stderr_columns = ['K562_lfcSE'], root = "../data/")

In [7]:
# Summaries of the three splits, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "="*50 + "\n")

Dataset MalinoisDataset of size 865249 (MpraDaraset)
    Number of datapoints: 865249
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 60663 (MpraDaraset)
    Number of datapoints: 60663
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 64881 (MpraDaraset)
    Number of datapoints: 64881
    Used split fold: ['7', '13']


## Train

In [8]:
# Length of the first element of the first training item; used as the network's `in_ch`.
# assumes that element is the encoded sequence with channels as its first axis - TODO confirm
in_channels = len(train_dataset[0][0])
# One model output per activity column (a single one in this section).
out_channels = len(activity_columns)

In [9]:
# encapsulate data into dataloader form
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

# NOTE(review): `test_loader` is not consumed by the cells visible here; the test
# evaluation builds its own `forw` / `rev` loaders.
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [10]:
# Build a fresh HumanLegNet for K562; the same architecture settings are used for every cell line.
model_K562 = HumanLegNet(in_ch=in_channels,
                     output_dim = out_channels,
                     stem_ch=64,
                     stem_ks=11,
                     ef_ks=9,
                     ef_block_sizes=[80, 96, 112, 128],
                     pool_sizes=[2,2,2,2],
                     resize_factor=4)
# Run the project's `initialize_weights` over every submodule.
model_K562.apply(initialize_weights)

# Wrap the network in the Lightning module together with the loss and optimiser settings.
# `print_each = 10` matches the metric summaries shown for epochs 9, 19 and 29 in the training output.
seq_model_K562 = LitModel_Malinois(model = model_K562, num_outputs = out_channels,
                           activity_columns = activity_columns,
                           loss = L1KLmixed(), 
                           use_one_cycle = True, # This parameter in LitModel_Malinois defines the OneCycleLR scheduler and AdamW optimizer — exactly as required for LegNet architectures.
                           weight_decay = WD, lr = LR, print_each = 10)

In [11]:
# Keep only the single checkpoint with the highest `val_pearson`; no extra "last" checkpoint.
checkpoint_callback = ModelCheckpoint(
        monitor='val_pearson', 
        mode='max',  
        save_top_k=1,
        save_last=False
    )
# TensorBoard logs go to ./malinois_logs/Malinois_model_legnet_K562.
logger = pl_loggers.TensorBoardLogger("./malinois_logs", name="Malinois_model_legnet_K562")

trainer_K562 = L.Trainer(
        accelerator="gpu",
        devices=[1],  # NOTE(review): hard-coded GPU index 1; adjust on machines with a different GPU layout
        precision='16-mixed',  # 16-bit automatic mixed precision (see the trainer output)
        enable_progress_bar=True,
        max_epochs=35,
        callbacks=[checkpoint_callback],
        logger = logger
    )

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
# Train the K562 model on `train_loader`, validating on `val_loader`.
trainer_K562.fit(seq_model_K562,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290  

Sanity Checking: |                                                                                | 0/? [00:00…



Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.11004 | Harm Mean Loss: 0.74758 | Enthropy Spearman: 1.00000 |
| Val Pearson K562: 0.84680 |
| Train Pearson K562: 0.83674 |
---------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.09040 | Harm Mean Loss: 0.61135 | Enthropy Spearman: 1.00000 |
| Val Pearson K562: 0.86641 |
| Train Pearson K562: 0.86465 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.06594 | Harm Mean Loss: 0.36235 | Enthropy Spearman: 1.00000 |
| Val Pearson K562: 0.90972 |
| Train Pearson K562: 0.91312 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=35` reached.


In [13]:
# Reload the best checkpoint (highest `val_pearson`, as tracked by `checkpoint_callback`)
# for the test-set evaluation. The keyword arguments mirror the ones used when the module
# was first constructed, so the reloaded module matches the trained configuration.
best_model_path = checkpoint_callback.best_model_path
seq_model_K562 = LitModel_Malinois.load_from_checkpoint(
    best_model_path,
    model=model_K562,
    num_outputs=out_channels,
    activity_columns=activity_columns,  # same labels the module was trained with
    loss=L1KLmixed(),
    weight_decay=WD,  # reuse the notebook-level constants instead of re-typed literals
    lr=LR,
    print_each=1
)

In [14]:
# Two views of the same K562 test split: one through `forw_transform`, one through `rev_transform`.
test_forw = MalinoisDataset(split = "test", transform = forw_transform, 
                               filtration = "original", activity_columns = ['K562'], 
                               stderr_columns = ['K562_lfcSE'], root = "../data/")
test_rev = MalinoisDataset(split = "test", transform = rev_transform, 
                               filtration = "original", activity_columns = ['K562'], 
                               stderr_columns = ['K562_lfcSE'], root = "../data/")

# shuffle = False on both loaders keeps the two views aligned sample-by-sample,
# which `meaned_prediction` relies on when it averages them.
forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

# NOTE(review): the literals 1 and "K562" duplicate `out_channels` / `activity_columns`.
meaned_prediction(forw, rev, trainer_K562, seq_model_K562, 1, "K562")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

K562 Pearson correlation


tensor(0.8981)

# HepG2

For the **HepG2** cell line, the authors reported a Pearson correlation coefficient of **0.89**

In [15]:
# load the data
# HepG2 only: a single activity column and its matching standard-error column.
# NOTE(review): this rebinds `activity_columns` and the dataset names used by the previous
# section; the column literals below repeat `activity_columns`, so keep them in sync.
activity_columns = ['HepG2']
# Only the training split passes `duplication_cutoff`; the notebook text describes it as
# duplicating sequences with activity above 0.5.
train_dataset = MalinoisDataset(split = "train",  transform = train_transform, duplication_cutoff = 0.5,
                                filtration = "original", activity_columns = ['HepG2'], 
                                stderr_columns = ['HepG2_lfcSE'], root = "../data/") 

# Validation and test splits share the non-augmenting `val_test_transform`.
val_dataset = MalinoisDataset(split = "val", transform = val_test_transform, 
                              filtration = "original", activity_columns = ['HepG2'], 
                              stderr_columns = ['HepG2_lfcSE'], root = "../data/") 

test_dataset = MalinoisDataset(split = "test", transform = val_test_transform, 
                               filtration = "original", activity_columns = ['HepG2'], 
                               stderr_columns = ['HepG2_lfcSE'], root = "../data/")

In [16]:
# Summaries of the three splits, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "="*50 + "\n")

Dataset MalinoisDataset of size 876218 (MpraDaraset)
    Number of datapoints: 876218
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 61471 (MpraDaraset)
    Number of datapoints: 61471
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 65157 (MpraDaraset)
    Number of datapoints: 65157
    Used split fold: ['7', '13']


## Train

In [17]:
# Length of the first element of the first training item; used as the network's `in_ch`.
# assumes that element is the encoded sequence with channels as its first axis - TODO confirm
in_channels = len(train_dataset[0][0])
# One model output per activity column (a single one in this section).
out_channels = len(activity_columns)

In [18]:
# encapsulate data into dataloader form
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

# NOTE(review): `test_loader` is not consumed by the cells visible here; the test
# evaluation builds its own `forw` / `rev` loaders.
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [19]:
# Build a fresh HumanLegNet for HepG2; the same architecture settings are used for every cell line.
model_HepG2 = HumanLegNet(in_ch=in_channels,
                     output_dim = out_channels,
                     stem_ch=64,
                     stem_ks=11,
                     ef_ks=9,
                     ef_block_sizes=[80, 96, 112, 128],
                     pool_sizes=[2,2,2,2],
                     resize_factor=4)
# Run the project's `initialize_weights` over every submodule.
model_HepG2.apply(initialize_weights)

# Wrap the network in the Lightning module together with the loss and optimiser settings.
# `print_each = 10` matches the metric summaries shown for epochs 9, 19 and 29 in the training output.
seq_model_HepG2 = LitModel_Malinois(model = model_HepG2, num_outputs = out_channels,
                           activity_columns = activity_columns,
                           loss = L1KLmixed(), 
                           use_one_cycle = True, # This parameter in LitModel_Malinois defines the OneCycleLR scheduler and AdamW optimizer — exactly as required for LegNet architectures.
                           weight_decay = WD, lr = LR, print_each = 10)

In [20]:
# Keep only the single checkpoint with the highest `val_pearson`; no extra "last" checkpoint.
# NOTE(review): rebinds `checkpoint_callback`, so the previous section's best path is no longer reachable through it.
checkpoint_callback = ModelCheckpoint(
        monitor='val_pearson', 
        mode='max',  
        save_top_k=1,
        save_last=False
    )
# TensorBoard logs go to ./malinois_logs/Malinois_model_legnet_HepG2.
logger = pl_loggers.TensorBoardLogger("./malinois_logs", name="Malinois_model_legnet_HepG2")

trainer_HepG2 = L.Trainer(
        accelerator="gpu",
        devices=[1],  # NOTE(review): hard-coded GPU index 1; adjust on machines with a different GPU layout
        precision='16-mixed',  # 16-bit automatic mixed precision (see the trainer output)
        enable_progress_bar=True,
        max_epochs=35,
        callbacks=[checkpoint_callback],
        logger = logger
    )

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [21]:
# Train the HepG2 model on `train_loader`, validating on `val_loader`.
trainer_HepG2.fit(seq_model_HepG2,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
122       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                | 0/? [00:00…



Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.09291 | Harm Mean Loss: 0.58388 | Enthropy Spearman: 1.00000 |
| Val Pearson HepG2: 0.81839 |
| Train Pearson HepG2: 0.83293 |
---------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.07186 | Harm Mean Loss: 0.41185 | Enthropy Spearman: 1.00000 |
| Val Pearson HepG2: 0.86198 |
| Train Pearson HepG2: 0.85822 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.06068 | Harm Mean Loss: 0.27767 | Enthropy Spearman: 1.00000 |
| Val Pearson HepG2: 0.90607 |
| Train Pearson HepG2: 0.90882 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=35` reached.


In [22]:
# Reload the best checkpoint (highest `val_pearson`, as tracked by `checkpoint_callback`)
# for the test-set evaluation. The keyword arguments mirror the ones used when the module
# was first constructed, so the reloaded module matches the trained configuration.
best_model_path = checkpoint_callback.best_model_path
seq_model_HepG2 = LitModel_Malinois.load_from_checkpoint(
    best_model_path,
    model=model_HepG2,
    num_outputs=out_channels,
    activity_columns=activity_columns,  # same labels the module was trained with
    loss=L1KLmixed(),
    weight_decay=WD,  # reuse the notebook-level constants instead of re-typed literals
    lr=LR,
    print_each=1
)

In [23]:
# Two views of the same HepG2 test split: one through `forw_transform`, one through `rev_transform`.
test_forw = MalinoisDataset(split = "test", transform = forw_transform, 
                               filtration = "original", activity_columns = ['HepG2'], 
                               stderr_columns = ['HepG2_lfcSE'], root = "../data/")
test_rev = MalinoisDataset(split = "test", transform = rev_transform, 
                               filtration = "original", activity_columns = ['HepG2'], 
                               stderr_columns = ['HepG2_lfcSE'], root = "../data/")

# shuffle = False on both loaders keeps the two views aligned sample-by-sample,
# which `meaned_prediction` relies on when it averages them.
forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

# NOTE(review): the literals 1 and "HepG2" duplicate `out_channels` / `activity_columns`.
meaned_prediction(forw, rev, trainer_HepG2, seq_model_HepG2, 1, "HepG2")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

HepG2 Pearson correlation


tensor(0.8963)

# SK-N-SH

For the **SK-N-SH** cell line, the authors reported a Pearson correlation coefficient of **0.88**

In [24]:
# load the data
# SK-N-SH only: a single activity column and its matching standard-error column.
# NOTE(review): this rebinds `activity_columns` and the dataset names used by the previous
# section; the column literals below repeat `activity_columns`, so keep them in sync.
activity_columns = ['SKNSH']
# Only the training split passes `duplication_cutoff`; the notebook text describes it as
# duplicating sequences with activity above 0.5.
train_dataset = MalinoisDataset(split = "train",  transform = train_transform, duplication_cutoff = 0.5,
                                filtration = "original", activity_columns = ['SKNSH'], 
                                stderr_columns = ['SKNSH_lfcSE'], root = "../data/") 

# Validation and test splits share the non-augmenting `val_test_transform`.
val_dataset = MalinoisDataset(split = "val", transform = val_test_transform, 
                              filtration = "original", activity_columns = ['SKNSH'], 
                              stderr_columns = ['SKNSH_lfcSE'], root = "../data/") 

test_dataset = MalinoisDataset(split = "test", transform = val_test_transform, 
                               filtration = "original", activity_columns = ['SKNSH'], 
                               stderr_columns = ['SKNSH_lfcSE'], root = "../data/")

In [25]:
# Summaries of the three splits, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "="*50 + "\n")

Dataset MalinoisDataset of size 857838 (MpraDaraset)
    Number of datapoints: 857838
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 59731 (MpraDaraset)
    Number of datapoints: 59731
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 63730 (MpraDaraset)
    Number of datapoints: 63730
    Used split fold: ['7', '13']


## Train

In [26]:
# Length of the first element of the first training item; used as the network's `in_ch`.
# assumes that element is the encoded sequence with channels as its first axis - TODO confirm
in_channels = len(train_dataset[0][0])
# One model output per activity column (a single one in this section).
out_channels = len(activity_columns)

In [27]:
# encapsulate data into dataloader form
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

# NOTE(review): `test_loader` is not consumed by the cells visible here.
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [28]:
# Build a fresh HumanLegNet for SK-N-SH; the same architecture settings are used for every cell line.
model_SKNSH = HumanLegNet(in_ch=in_channels,
                     output_dim = out_channels,
                     stem_ch=64,
                     stem_ks=11,
                     ef_ks=9,
                     ef_block_sizes=[80, 96, 112, 128],
                     pool_sizes=[2,2,2,2],
                     resize_factor=4)
# Run the project's `initialize_weights` over every submodule.
model_SKNSH.apply(initialize_weights)

# Wrap the network in the Lightning module together with the loss and optimiser settings.
# `print_each = 10` matches the metric summaries shown for epochs 9 and 19 in the training output.
seq_model_SKNSH = LitModel_Malinois(model = model_SKNSH, num_outputs = out_channels,
                           activity_columns = activity_columns,
                           loss = L1KLmixed(), 
                           use_one_cycle = True, # This parameter in LitModel_Malinois defines the OneCycleLR scheduler and AdamW optimizer — exactly as required for LegNet architectures.
                           weight_decay = WD, lr = LR, print_each = 10)

In [29]:
# Keep only the single checkpoint with the highest `val_pearson`; no extra "last" checkpoint.
# NOTE(review): rebinds `checkpoint_callback`, so the previous section's best path is no longer reachable through it.
checkpoint_callback = ModelCheckpoint(
        monitor='val_pearson', 
        mode='max',  
        save_top_k=1,
        save_last=False
    )
# TensorBoard logs go to ./malinois_logs/Malinois_model_legnet_SKNSH.
logger = pl_loggers.TensorBoardLogger("./malinois_logs", name="Malinois_model_legnet_SKNSH")

trainer_SKNSH = L.Trainer(
        accelerator="gpu",
        devices=[1],  # NOTE(review): hard-coded GPU index 1; adjust on machines with a different GPU layout
        precision='16-mixed',  # 16-bit automatic mixed precision (see the trainer output)
        enable_progress_bar=True,
        max_epochs=35,
        callbacks=[checkpoint_callback],
        logger = logger
    )

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [30]:
# Train the SK-N-SH model on `train_loader`, validating on `val_loader`.
trainer_SKNSH.fit(seq_model_SKNSH,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
122       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                | 0/? [00:00…



Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.08406 | Harm Mean Loss: 0.54031 | Enthropy Spearman: 1.00000 |
| Val Pearson SKNSH: 0.85933 |
| Train Pearson SKNSH: 0.83740 |
---------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.08045 | Harm Mean Loss: 0.54147 | Enthropy Spearman: 1.00000 |
| Val Pearson SKNSH: 0.86656 |
| Train Pearson SKNSH: 0.86201 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.07372 | Harm Mean Loss: 0.39896 | Enthropy Spearman: 1.00000 |
| Val Pearson SKNSH: 0.89860 |
| Train Pearson SKNSH: 0.90875 |
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=35` reached.


In [31]:
# Reload the best checkpoint recorded by `checkpoint_callback` so that the
# evaluation below uses those weights rather than the last-epoch weights.
# Optimizer settings reuse the notebook-level LR / WD constants (same values
# as the literals they replace) so they cannot drift from the config cell.
best_model_path = checkpoint_callback.best_model_path
seq_model_SKNSH = LitModel_Malinois.load_from_checkpoint(
    best_model_path,
    model=model_SKNSH,
    num_outputs=out_channels,
    loss=L1KLmixed(),
    weight_decay=WD,
    lr=LR,
    print_each=1,
)

In [32]:
# Two views of the SK-N-SH test split that differ only in the transform
# applied to each sequence (`forw_transform` vs. `rev_transform`).
test_forw, test_rev = (
    MalinoisDataset(
        split="test",
        transform=strand_transform,
        filtration="original",
        activity_columns=['SKNSH'],
        stderr_columns=['SKNSH_lfcSE'],
        root="../data/",
    )
    for strand_transform in (forw_transform, rev_transform)
)

# Unshuffled loaders, so both views are iterated in the same sample order.
forw, rev = (
    data.DataLoader(
        dataset=strand_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    for strand_dataset in (test_forw, test_rev)
)

# Last expression of the cell: its result is displayed as the cell output.
meaned_prediction(forw, rev, trainer_SKNSH, seq_model_SKNSH, 1, "SKNSH")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

SKNSH Pearson correlation


tensor(0.8900)

# Recall that the authors reported the following Pearson correlations in their work:

0.88 for K562

0.89 for HepG2

0.88 for SKNSH

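HumanLegNet demonstrates **strong, comparable results** to those reported by the authors when trained on individual cell lines. However, there’s also an option to train HumanLegNet on **all three cell lines simultaneously** without sacrificing performance. To do this, pass all three activity columns (and their matching standard-error columns) to `MalinoisDataset` and give the model one output per cell line, as shown below.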

# All 3 cell lines together

In [33]:
# load the data
# Training split for the joint three-cell-line task: uses the training
# transform and `duplication_cutoff = 0.5`.
train_dataset = MalinoisDataset(
    split="train",
    transform=train_transform,
    duplication_cutoff=0.5,
    filtration="original",
    activity_columns=['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
    stderr_columns=['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
    root="../data/",
)

# Validation and test splits share every setting except the split name.
val_dataset, test_dataset = (
    MalinoisDataset(
        split=held_out_split,
        transform=val_test_transform,
        filtration="original",
        activity_columns=['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
        stderr_columns=['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
        root="../data/",
    )
    for held_out_split in ("val", "test")
)

In [34]:
# Print the three dataset summaries, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "=" * 50 + "\n")

Dataset MalinoisDataset of size 932088 (MpraDaraset)
    Number of datapoints: 932088
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 58809 (MpraDaraset)
    Number of datapoints: 58809
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 62582 (MpraDaraset)
    Number of datapoints: 62582
    Used split fold: ['7', '13']


## Train

In [35]:
# The length of the first element of a training sample is used as the
# model's `in_ch`; the model gets one output per entry in ACTIVITY_COLUMNS.
first_sample = train_dataset[0]
in_channels = len(first_sample[0])
out_channels = len(ACTIVITY_COLUMNS)

In [36]:
# encapsulate data into dataloader form
# Wrap each split in a DataLoader; only the training split is reshuffled.
train_loader, val_loader, test_loader = (
    data.DataLoader(
        dataset=split_dataset,
        batch_size=BATCH_SIZE,
        shuffle=reshuffle,
        num_workers=NUM_WORKERS,
    )
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [37]:
# Architecture settings for the HumanLegNet that predicts all cell lines at once.
legnet_kwargs = dict(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model_3_outputs = HumanLegNet(**legnet_kwargs)
model_3_outputs.apply(initialize_weights)

# Lightning wrapper around the network, reporting one metric per activity column.
seq_model_3_outputs = LitModel_Malinois(
    model=model_3_outputs,
    num_outputs=out_channels,
    activity_columns=ACTIVITY_COLUMNS,
    loss=L1KLmixed(),
    # In LitModel_Malinois this flag selects the OneCycleLR scheduler together
    # with the AdamW optimizer, which is the setup required for LegNet models.
    use_one_cycle=True,
    weight_decay=WD,
    lr=LR,
    print_each=10,
)

In [38]:
# Keep a single checkpoint: the one with the highest logged `val_pearson`.
checkpoint_callback = ModelCheckpoint(
    monitor='val_pearson',
    mode='max',
    save_top_k=1,
    save_last=False,
)

# TensorBoard event files for this run are written under ./malinois_logs.
logger = pl_loggers.TensorBoardLogger(
    save_dir="./malinois_logs",
    name="Malinois_model_legnet_3_outputs",
)

# GPU trainer pinned to device index 1, using 16-bit mixed precision.
trainer_3_outputs = L.Trainer(
    accelerator="gpu",
    devices=[1],
    precision='16-mixed',
    max_epochs=35,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[checkpoint_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [39]:
# Train the three-output model on the joint train/validation loaders.
trainer_3_outputs.fit(
    model=seq_model_3_outputs,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.292     Total estimated model params size (MB)
122       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                | 0/? [00:00…



Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.09699 | Harm Mean Loss: 0.59355 | Enthropy Spearman: 0.44523 |
| Val Pearson K562: 0.83312 | Val Pearson HepG2: 0.83570 | Val Pearson SKNSH: 0.84148 | Mean Val Pearson: 0.83677 |
| Train Pearson K562: 0.85064 | Train Pearson HepG2: 0.84123 | Train Pearson SKNSH: 0.84176 | Mean Train Pearson: 0.84455 |
--------------------------------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.11418 | Harm Mean Loss: 0.55763 | Enthropy Spearman: 0.41368 |
| Val Pearson K562: 0.88031 | Val Pearson HepG2: 0.87994 | Val Pearson SKNSH: 0.88404 | Mean Val Pearson: 0.88143 |
| Train Pearson K562: 0.87656 | Train Pearson HepG2: 0.87105 | Train Pearson SKNSH: 0.87085 | Mean Train Pearson: 0.87282 |
--------------------------------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.07244 | Harm Mean Loss: 0.31045 | Enthropy Spearman: 0.58516 |
| Val Pearson K562: 0.91479 | Val Pearson HepG2: 0.91537 | Val Pearson SKNSH: 0.91198 | Mean Val Pearson: 0.91405 |
| Train Pearson K562: 0.92207 | Train Pearson HepG2: 0.91881 | Train Pearson SKNSH: 0.91479 | Mean Train Pearson: 0.91856 |
--------------------------------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=35` reached.


In [40]:
# Reload the checkpoint that `checkpoint_callback` ranked best on `val_pearson`
# so that the evaluation below uses those weights.
# The wrapper is rebuilt with the same arguments used when it was first
# constructed: the notebook-level LR / WD constants replace the repeated
# literals, and `activity_columns` is passed explicitly instead of relying on
# whatever LitModel_Malinois falls back to when it is omitted.
best_model_path = checkpoint_callback.best_model_path
seq_model_3_outputs = LitModel_Malinois.load_from_checkpoint(
    best_model_path,
    model=model_3_outputs,
    num_outputs=out_channels,
    activity_columns=ACTIVITY_COLUMNS,
    loss=L1KLmixed(),
    weight_decay=WD,
    lr=LR,
    print_each=1,
)

In [41]:
# Two views of the three-cell-line test split that differ only in the
# transform applied to each sequence (`forw_transform` vs. `rev_transform`).
test_forw, test_rev = (
    MalinoisDataset(
        split="test",
        transform=strand_transform,
        filtration="original",
        activity_columns=['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
        stderr_columns=['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
        root="../data/",
    )
    for strand_transform in (forw_transform, rev_transform)
)

# Unshuffled loaders, so both views are iterated in the same sample order.
forw, rev = (
    data.DataLoader(
        dataset=strand_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    for strand_dataset in (test_forw, test_rev)
)

# Last expression of the cell: its result is displayed as the cell output.
meaned_prediction(forw, rev, trainer_3_outputs, seq_model_3_outputs, 3, "K562, HepG2, SK-N-SH")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

K562, HepG2, SK-N-SH Pearson correlation


tensor([0.9078, 0.9080, 0.8984])