In [1]:
import mpramnist
from mpramnist.Malinois.dataset import MalinoisDataset

from mpramnist.models import BassetBranched

from mpramnist.trainers import LitModel_Malinois

from mpramnist import transforms as t
from mpramnist import target_transforms as t_t

import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.utils.data as data
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch import loggers as pl_loggers

In [2]:
# Batch size and worker count handed to every DataLoader in this notebook.
BATCH_SIZE, NUM_WORKERS = 1076, 103
# Cell lines being modeled; the list length is used as the number of model outputs.
activity_columns = ["K562", "HepG2", "SKNSH"]

In [3]:
class L1KLmixed(nn.Module):
    """Weighted average of an L1 loss and a KL-divergence loss.

    The L1 term compares raw predictions with raw targets. The KL term
    compares both after each is log-normalized over its last dimension
    (value minus ``logsumexp``). The result is
    ``(alpha * L1 + beta * KL) / (alpha + beta)``.

    Args:
        reduction: Reduction mode forwarded to ``nn.KLDivLoss``. The same
            string with ``'batch'`` removed is forwarded to ``nn.L1Loss``.
        alpha: Weight of the L1 term.
        beta: Weight of the KL term.
    """

    def __init__(self, reduction='mean', alpha=1.0, beta=5.0):
        super().__init__()

        self.reduction = reduction
        self.alpha = alpha
        self.beta = beta

        # 'batchmean' becomes 'mean' for the L1 term; other modes pass through unchanged.
        l1_reduction = reduction.replace('batch', '')
        # NOTE: the attribute is called MSE but it wraps nn.L1Loss.
        self.MSE = nn.L1Loss(reduction=l1_reduction)
        # log_target=True: both arguments of the KL term are log-probabilities.
        self.KL = nn.KLDivLoss(reduction=reduction, log_target=True)

    @staticmethod
    def _log_normalize(values):
        """Subtract the logsumexp over the last dimension (log-softmax form)."""
        return values - torch.logsumexp(values, dim=-1, keepdim=True)

    def forward(self, preds, targets):
        """Return the alpha/beta-weighted average of the L1 and KL terms."""
        l1_term = self.MSE(preds, targets)
        kl_term = self.KL(self._log_normalize(preds), self._log_normalize(targets))

        weighted_sum = l1_term * self.alpha + kl_term * self.beta
        return weighted_sum / (self.alpha + self.beta)

In [4]:
from torchmetrics import PearsonCorrCoef

def meaned_prediction(forw, rev, trainer, seq_model, name):
    """Average the predictions of two passes and score them with Pearson correlation.

    ``trainer.predict`` is run once per dataloader. The two prediction tensors
    are averaged element-wise, so both loaders must yield the same examples in
    the same order. Targets are taken from the first pass only.

    Args:
        forw: Dataloader for the first pass; its batches supply the targets.
        rev: Dataloader for the second pass (in this notebook, the
            reverse-complemented copy of the same split).
        trainer: Object whose ``predict(model, dataloaders=...)`` returns an
            iterable of dicts holding a ``"predicted"`` tensor (and a
            ``"target"`` tensor for ``forw``).
        seq_model: Model handed to ``trainer.predict``.
        name: Label printed before ``" Pearson correlation"``.

    Returns:
        The value returned by ``PearsonCorrCoef`` for the averaged predictions
        against the targets (one coefficient per output column).
    """
    predictions_forw = trainer.predict(seq_model, dataloaders=forw)
    targets = torch.cat([batch["target"] for batch in predictions_forw])
    y_preds_forw = torch.cat([batch["predicted"] for batch in predictions_forw])

    predictions_rev = trainer.predict(seq_model, dataloaders=rev)
    y_preds_rev = torch.cat([batch["predicted"] for batch in predictions_rev])

    # Element-wise mean of the two passes (previously stored under the
    # misleading name `mean_forw`).
    mean_pred = torch.mean(torch.stack([y_preds_forw, y_preds_rev]), dim=0)

    # Derive the number of outputs from the targets instead of hard-coding 3,
    # so the helper works for any number of activity columns.
    num_outputs = targets.shape[1] if targets.dim() > 1 else 1
    pears = PearsonCorrCoef(num_outputs=num_outputs)
    print(name + " Pearson correlation")

    return pears(mean_pred, targets)

# Testing the Original Dataset Parameters

To use the **exact same parameters** as in the original study, set `filtration = 'original'`. This will apply all filtering criteria from the [**Original Study (PMC10441439)**](https://pmc.ncbi.nlm.nih.gov/articles/PMC10441439/).

## Key Parameters (Original Settings)

| Parameter               | Value  | Description |
|-------------------------|--------|-------------|
| `stderr_threshold`      | 1.0    | Threshold for feature standard error of examples to be included in train/val/test sets |
| `std_multiple_cut`      | 6.0    | Multiplier for standard deviation to define bounds for trusted measurements (removes extreme outliers) |
| `up_cutoff_move`        | 3.0    | Shift factor for upper bound of outlier filter |
| `duplication_cutoff`    | 0.5    | Sequences with max activities higher than this value are duplicated. **Use only during training!** |
| `activity_columns`      | `['K562', 'HepG2', 'SKNSH']` | Column headers containing features to be modeled |
| `stderr_columns`        | `['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE']` | Column headers containing standard errors of features to be modeled |

## Additional Original Settings

- **Sequence padding**: Automatically pads sequences from length 200 to 600 (as done in the original study)
- **Reverse complements**: Enable with `use_original_reverse_complement = True`

## Required Transform

For sequence conversion to one-hot, we recommend using `mpramnist.transforms.Seq2Tensor()`

In [5]:
# The only transform configured here converts each sequence with Seq2Tensor.
transform_to_tensor = t.Compose([t.Seq2Tensor()])

# Arguments shared by all three splits of the author-filtered dataset.
_original_split_kwargs = dict(
    filtration="original",  # use "original" for the authors' parameters
    transform=transform_to_tensor,
    root="../data/",
)

train_dataset = MalinoisDataset(
    split="train",
    duplication_cutoff=0.5,  # don't forget duplication_cutoff for the training split
    use_original_reverse_complement=True,  # this parameter pads sequences and adds reverse complements
    **_original_split_kwargs,
)

val_dataset = MalinoisDataset(split="val", **_original_split_kwargs)

test_dataset = MalinoisDataset(split="test", **_original_split_kwargs)

In [6]:
# Print the three split summaries, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "=" * 50 + "\n")

Dataset MalinoisDataset of size 1864176 (MpraDaraset)
    Number of datapoints: 1864176
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 58809 (MpraDaraset)
    Number of datapoints: 58809
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 62582 (MpraDaraset)
    Number of datapoints: 62582
    Used split fold: ['7', '13']


## Train

In [7]:
# in_channels: length of the first item of the first training example.
# out_channels: one output per entry of activity_columns.
in_channels, out_channels = len(train_dataset[0][0]), len(activity_columns)

In [8]:
# Wrap each split in a DataLoader; only the training split is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_loader = data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

val_loader = data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

test_loader = data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [9]:
# The network and its Lightning wrapper each receive their own L1KLmixed instance.
model = BassetBranched(
    n_outputs=out_channels,
    loss_criterion=L1KLmixed(),
)

seq_model = LitModel_Malinois(
    model=model,
    num_outputs=out_channels,
    loss=L1KLmixed(),
    lr=0.0032658700881052086,
    weight_decay=0.0003438210249762151,
    print_each=100,
)

In [10]:
# Metric key watched by both callbacks. The spelling must match the key the
# Lightning module logs, so it is kept exactly as is.
monitored_metric = 'enthropy_spearman'

# Keep only the single best checkpoint (highest monitored value).
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints_malinois/',
    filename='best_model-{epoch:02d}-{enthropy_spearman:.3f}',
    monitor=monitored_metric,
    mode='max',
    save_top_k=1,
    save_last=False,
)
# Stop once the metric has not improved for 30 checks (min_epochs below still applies).
early_stopping = EarlyStopping(monitor=monitored_metric, mode='max', patience=30)

logger = pl_loggers.TensorBoardLogger("./logs", name="Malinois")

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    precision='16-mixed',
    min_epochs=60,
    max_epochs=200,
    num_sanity_val_steps=0,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [11]:
# Fit on train_loader, using val_loader for validation.
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | BassetBranched  | 4.1 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.429    Total estimated model params size (MB)
38        Modules in trai

Training: |                                                                                       | 0/? [00:00…



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Trainer was signaled to stop but the required `min_epochs=60` or `min_steps=None` has not been met. Training will continue...


Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

In [12]:
# Reload the best checkpoint recorded by the checkpoint callback, then score it on the test split.
best_model_path = checkpoint_callback.best_model_path

# Constructor arguments forwarded to load_from_checkpoint alongside the checkpoint path.
_restore_kwargs = dict(
    model=model,
    num_outputs=out_channels,
    loss=L1KLmixed(),
    weight_decay=0.0003438210249762151,
    lr=0.0032658700881052086,
    print_each=1,
)
seq_model = LitModel_Malinois.load_from_checkpoint(best_model_path, **_restore_kwargs)

trainer.test(seq_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_HepG2_pearson       0.8562631011009216
    test_K562_pearson       0.8555692434310913
   test_SKNSH_pearson       0.8487597107887268
        test_loss           0.07432707399129868
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.07432707399129868,
  'test_K562_pearson': 0.8555692434310913,
  'test_HepG2_pearson': 0.8562631011009216,
  'test_SKNSH_pearson': 0.8487597107887268}]

## Mean prediction

In [13]:
# Evaluating with 'filter = original'
# Two views of the same test split: one as-is, one passed through ReverseComplement(1).
forw_transform = t.Compose([t.Seq2Tensor()])
rev_transform = t.Compose([t.Seq2Tensor(), t.ReverseComplement(1)])

_test_split_kwargs = dict(split="test", filtration="original", root="../data/")
test_forw = MalinoisDataset(transform=forw_transform, **_test_split_kwargs)
test_rev = MalinoisDataset(transform=rev_transform, **_test_split_kwargs)

# shuffle=False keeps both loaders in the same order, so predictions can be paired by position.
_eval_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw = data.DataLoader(test_forw, **_eval_loader_kwargs)
rev = data.DataLoader(test_rev, **_eval_loader_kwargs)

meaned_prediction(forw, rev, trainer, seq_model, "Malinois")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

Malinois Pearson correlation


tensor([0.8670, 0.8677, 0.8596])

# Using Custom Dataset Parameters

To apply your own filtering criteria, set:  
`filtration = 'own'`

## Custom Transformation Pipeline

Use the following transforms for flexible data processing:

| Transform | Example Usage | Description |
|-----------|---------------|-------------|
| **Padding** | `t.AddFlanks(left_flank, right_flank)` | Adds custom flanking sequences |
| **Cropping** | `t.CenterCrop(600)` | Centers and crops sequences to specified length (600bp) |
| **Reverse Complement** | `t.ReverseComplement(0.5)` | Applies reverse complement transformation with 0.5 probability |

### Key Benefits:
- **Reproduce author methodology**:  
  Combine `AddFlanks` + `CenterCrop` to exactly replicate the original padding approach
- **Memory efficiency**:  
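  Applying `ReverseComplement(0.5)` on the fly, instead of storing every sequence together with its reverse complement, maintains prediction quality while halving the processed training set (932,088 vs. 1,864,176 examples in this notebook)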

In [14]:
# Flanking sequences exposed as class attributes of the dataset.
left_flank, right_flank = MalinoisDataset.LEFT_FLANK, MalinoisDataset.RIGHT_FLANK


def _flank_and_crop():
    """Return fresh AddFlanks + CenterCrop(600) steps for one pipeline."""
    return [t.AddFlanks(left_flank, right_flank), t.CenterCrop(600)]


# preprocessing
# Training additionally applies ReverseComplement(0.5) before the tensor conversion.
train_transform = t.Compose(
    _flank_and_crop() + [t.ReverseComplement(0.5), t.Seq2Tensor()]
)
val_test_transform = t.Compose(_flank_and_crop() + [t.Seq2Tensor()])

In [15]:
# load the data
def _own_filter_kwargs():
    """Build a fresh dict of the filtering arguments shared by all three splits."""
    return dict(
        filtration="own",
        activity_columns=['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],  # change as you want
        stderr_columns=['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],       # change as you want
        stderr_threshold=1.0,   # change as you want
        std_multiple_cut=6.0,   # change as you want
        up_cutoff_move=3.0,     # change as you want
        root="../data/",
    )


train_dataset = MalinoisDataset(
    split="train",
    transform=train_transform,
    duplication_cutoff=0.5,  # change as you want; passed for the training split only
    **_own_filter_kwargs(),
)
val_dataset = MalinoisDataset(
    split="val",
    transform=val_test_transform,
    **_own_filter_kwargs(),
)
test_dataset = MalinoisDataset(
    split="test",
    transform=val_test_transform,
    **_own_filter_kwargs(),
)

In [16]:
# Display the first training example; the recorded output shows a pair of
# (sequence tensor, tensor of three activity values).
train_dataset[0]

(tensor([[0., 1., 0.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 1.,  ..., 1., 0., 0.]]),
 tensor([ 0.3796,  0.0046, -0.2444]))

In [17]:
# Print the three split summaries, separated by a 50-character rule.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "=" * 50 + "\n")

Dataset MalinoisDataset of size 932088 (MpraDaraset)
    Number of datapoints: 932088
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 58809 (MpraDaraset)
    Number of datapoints: 58809
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 62582 (MpraDaraset)
    Number of datapoints: 62582
    Used split fold: ['7', '13']


## Train

In [18]:
# in_channels: length of the first item of the first training example.
# out_channels: taken from the notebook-level activity_columns list, which has the
# same length as the three *_log2FC columns requested from the dataset.
in_channels, out_channels = len(train_dataset[0][0]), len(activity_columns)

In [19]:
# Wrap each split in a DataLoader; only the training split is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_loader = data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

val_loader = data.DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

test_loader = data.DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [20]:
# Fresh network and Lightning wrapper for the second experiment; each gets its own L1KLmixed instance.
model = BassetBranched(
    n_outputs=out_channels,
    loss_criterion=L1KLmixed(),
)

seq_model = LitModel_Malinois(
    model=model,
    num_outputs=out_channels,
    loss=L1KLmixed(),
    lr=0.0032658700881052086,
    weight_decay=0.0003438210249762151,
    print_each=100,
)

In [21]:
# Metric key watched by both callbacks. The spelling must match the key the
# Lightning module logs, so it is kept exactly as is.
monitored_metric = 'enthropy_spearman'

# Keep only the single best checkpoint (highest monitored value).
checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints_malinois/',
    filename='best_model-{epoch:02d}-{enthropy_spearman:.3f}',
    monitor=monitored_metric,
    mode='max',
    save_top_k=1,
    save_last=False,
)
# Stop once the metric has not improved for 30 checks (min_epochs below still applies).
early_stopping = EarlyStopping(monitor=monitored_metric, mode='max', patience=30)

logger = pl_loggers.TensorBoardLogger("./logs", name="Malinois")

trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    precision='16-mixed',
    min_epochs=60,
    max_epochs=200,
    num_sanity_val_steps=0,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [22]:
# Fit on train_loader, using val_loader for validation.
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/nios/miniconda3/envs/mpramnist/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:658: Checkpoint directory /home/nios/5Term/examples/checkpoints_malinois exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | BassetBranched  | 4.1 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
--

Training: |                                                                                       | 0/? [00:00…



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

In [23]:
# Reload the checkpoint that the checkpoint callback recorded as best and
# evaluate it on the test loader.
best_model_path = checkpoint_callback.best_model_path

# Constructor arguments handed to load_from_checkpoint alongside the path.
# NOTE(review): presumably these are required because they were not stored as
# hyperparameters inside the checkpoint -- confirm against LitModel_Malinois.
restore_kwargs = dict(
    model=model,
    num_outputs=out_channels,
    loss=L1KLmixed(),
    weight_decay=0.0003438210249762151,
    lr=0.0032658700881052086,
    print_each=1,
)
seq_model = LitModel_Malinois.load_from_checkpoint(best_model_path, **restore_kwargs)

# Kept as the final expression so the notebook displays the returned metrics.
trainer.test(seq_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_HepG2_pearson       0.8532677888870239
    test_K562_pearson       0.8506614565849304
   test_SKNSH_pearson       0.8485007286071777
        test_loss           0.07524599134922028
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.07524599134922028,
  'test_K562_pearson': 0.8506614565849304,
  'test_HepG2_pearson': 0.8532677888870239,
  'test_SKNSH_pearson': 0.8485007286071777}]

## Mean prediction

In [24]:
# Evaluate on the test split twice -- once with the forward-strand transform and
# once with the reverse-complement transform -- and hand both loaders to
# meaned_prediction, which runs trainer.predict on each of them.
forw_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.Seq2Tensor()
])
rev_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.Seq2Tensor(),
    # NOTE(review): the argument is presumably the probability of applying the
    # reverse complement (1 = always). Here it runs after Seq2Tensor, whereas
    # the train_transform in this notebook applies it before -- confirm the
    # transform accepts both string and tensor input.
    t.ReverseComplement(1)
])


def _filtered_test_dataset(transform):
    """Build the "own"-filtered test split with the given sequence transform.

    Both strand views go through this single definition so that they are
    always constructed with identical filtering arguments; only the
    transform differs between them.
    """
    return MalinoisDataset(
        split="test",
        filtration="own",
        activity_columns=['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
        stderr_columns=['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
        stderr_threshold=1.0,
        std_multiple_cut=6.0,
        up_cutoff_move=3.0,
        transform=transform,
        root="../data/",
    )


def _ordered_loader(dataset):
    """Wrap a dataset in a non-shuffling DataLoader.

    shuffle=False keeps the iteration order fixed, so the forward and reverse
    loaders yield samples in the same order.
    """
    return data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


test_forw = _filtered_test_dataset(forw_transform)
test_rev = _filtered_test_dataset(rev_transform)

forw = _ordered_loader(test_forw)
rev = _ordered_loader(test_rev)

# Kept as the final expression so the notebook shows whatever it returns.
meaned_prediction(forw, rev, trainer, seq_model, "Malinois")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

Malinois Pearson correlation


tensor([0.8622, 0.8656, 0.8594])

# Bypassing Data Filtering

To **disable all filtering and cutoff duplication** and use the **raw** dataset, set:  
`filtration = "none"`

In [25]:
# Flanking sequences come from the constants exposed on the dataset class.
left_flank, right_flank = MalinoisDataset.LEFT_FLANK, MalinoisDataset.RIGHT_FLANK

# Training preprocessing: add flanks, center-crop (argument 600), apply
# ReverseComplement with argument 0.5, then convert the sequence to a tensor.
train_transform = t.Compose(
    [
        t.AddFlanks(left_flank, right_flank),
        t.CenterCrop(600),
        t.ReverseComplement(0.5),
        t.Seq2Tensor(),
    ]
)

# Validation/test preprocessing: the same steps without ReverseComplement.
val_test_transform = t.Compose(
    [
        t.AddFlanks(left_flank, right_flank),
        t.CenterCrop(600),
        t.Seq2Tensor(),
    ]
)

In [26]:
# Build the three splits with filtration="none"; no activity/stderr column or
# threshold arguments are passed here. Only the training split uses
# train_transform; validation and test share val_test_transform.
train_dataset = MalinoisDataset(
    split="train",
    filtration="none",
    transform=train_transform,
    root="../data/",
)
val_dataset = MalinoisDataset(
    split="val",
    filtration="none",
    transform=val_test_transform,
    root="../data/",
)
test_dataset = MalinoisDataset(
    split="test",
    filtration="none",
    transform=val_test_transform,
    root="../data/",
)

In [27]:
# Show the summary of each split, with a 50-character "=" rule between them.
print(train_dataset, val_dataset, test_dataset, sep="\n" + "=" * 50 + "\n")

Dataset MalinoisDataset of size 668946 (MpraDaraset)
    Number of datapoints: 668946
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 62406 (MpraDaraset)
    Number of datapoints: 62406
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 66712 (MpraDaraset)
    Number of datapoints: 66712
    Used split fold: ['7', '13']
