In [1]:
import mpramnist
from mpramnist.Malinois.dataset import MalinoisDataset

from mpramnist.models import BassetBranched

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Malinois

from mpramnist import transforms as t
from mpramnist import target_transforms as t_t
import pandas as pd

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import pytorch_lightning as L

In [2]:
BATCH_SIZE = 1076
NUM_WORKERS = 103

# Testing the Original Dataset Parameters

To use the **exact same parameters** as in the original study, set `filtration = 'original'`. This will apply all filtering criteria from the [**Original Study (PMC10441439)**](https://pmc.ncbi.nlm.nih.gov/articles/PMC10441439/).

## Key Parameters (Original Settings)

| Parameter               | Value  | Description |
|-------------------------|--------|-------------|
| `stderr_threshold`      | 1.0    | Threshold for feature standard error of examples to be included in train/val/test sets |
| `std_multiple_cut`      | 6.0    | Multiplier for standard deviation to define bounds for trusted measurements (removes extreme outliers) |
| `up_cutoff_move`        | 3.0    | Shift factor for upper bound of outlier filter |
| `duplication_cutoff`    | 0.5    | Sequences with max activities higher than this value are duplicated. **Use only during training!** |
| `activity_columns`      | `['K562', 'HepG2', 'SKNSH']` | Column headers containing features to be modeled |
| `stderr`               | `['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE']` | Column headers containing standard errors of features to be modeled |

## Additional Original Settings

- **Sequence padding**: Automatically pads sequences from length 200 to 600 (as done in the original study)
- **Reverse complements**: Enable with `use_original_reverse_complement = True`

## Recommended Transformation

For sequence conversion, we recommend using:
```python
mpramnist.transforms.Seq2Tensor()

In [3]:
transform_to_tensor = t.Compose([
    t.Seq2Tensor()
])

train_dataset = MalinoisDataset(split = "train",
                                filtration = "original", # use "original for author's parameters"
                                duplication_cutoff = 0.5,
                                use_original_reverse_complement = True, # this parameter paddes sequences and does rev comp
                                transform = transform_to_tensor,
                                root = "../data/"
                               )

val_dataset = MalinoisDataset(split = "val",
                              filtration = "original",
                              transform = transform_to_tensor,
                              root = "../data/") 

test_dataset = MalinoisDataset(split = "test",
                              filtration = "original",
                              transform = transform_to_tensor,
                              root = "../data/")

In [4]:
print(train_dataset)
print("="*50)
print(val_dataset)
print("="*50)
print(test_dataset)

Dataset MalinoisDataset of size 1864176 (MpraDaraset)
    Number of datapoints: 1864176
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 58809 (MpraDaraset)
    Number of datapoints: 58809
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 62582 (MpraDaraset)
    Number of datapoints: 62582
    Used split fold: ['7', '13']


# Using Custom Dataset Parameters

To apply your own filtering criteria, set:  
`filtration = 'own'`

## Custom Transformation Pipeline

Use the following transforms for flexible data processing:

| Transform | Example Usage | Description |
|-----------|---------------|-------------|
| **Padding** | `t.AddFlanks(left_flank, right_flank)` | Adds custom flanking sequences |
| **Cropping** | `t.CenterCrop(600)` | Centers and crops sequences to specified length (600bp) |
| **Reverse Complement** | `t.ReverseComplement(0.5)` | Applies reverse complement transformation with 0.5 probability |

### Key Benefits:
- **Reproduce author methodology**:  
  Combine `AddFlanks` + `CenterCrop` to exactly replicate the original padding approach
- **Memory efficiency**:  
  The `ReverseComplement(0.5)` transform maintains prediction quality while reducing processed dataset size by 50%

In [3]:
left_flank = MalinoisDataset.LEFT_FLANK
right_flank = MalinoisDataset.RIGHT_FLANK
# preprocessing
train_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.ReverseComplement(0.5),
    t.Seq2Tensor()
])
val_test_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.Seq2Tensor()
])

In [4]:
# load the data
train_dataset = MalinoisDataset( 
                              split = "train", 
                              transform = train_transform,
                              filtration = "own",
                              activity_columns = ['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'], # change as you want
                              stderr_columns = ['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],      # change as you want
                              stderr_threshold = 1.0,   # change as you want
                              std_multiple_cut = 6.0,   # change as you want
                              up_cutoff_move = 3.0,     # change as you want
                              duplication_cutoff = 0.5, # change as you want
                              root = "../data/") 
val_dataset = MalinoisDataset(
                              split = "val",
                              filtration = "own",
                              activity_columns = ['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
                              stderr_columns = ['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
                              stderr_threshold = 1.0,
                              std_multiple_cut = 6.0,
                              up_cutoff_move = 3.0,
                              transform = val_test_transform,
                              root = "../data/") 
test_dataset = MalinoisDataset(
                              split = "test", 
                              filtration = "own",
                              activity_columns = ['K562_log2FC', 'HepG2_log2FC', 'SKNSH_log2FC'],
                              stderr_columns = ['K562_lfcSE', 'HepG2_lfcSE', 'SKNSH_lfcSE'],
                              stderr_threshold = 1.0,
                              std_multiple_cut = 6.0,
                              up_cutoff_move = 3.0,
                              transform = val_test_transform,
                              root = "../data/")

In [5]:
train_dataset[0]

(tensor([[0., 0., 1.,  ..., 1., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 1., 0.]]),
 tensor([ 0.3796,  0.0046, -0.2444]))

In [6]:
print(train_dataset)
print("="*50)
print(val_dataset)
print("="*50)
print(test_dataset)

Dataset MalinoisDataset of size 932088 (MpraDaraset)
    Number of datapoints: 932088
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 58809 (MpraDaraset)
    Number of datapoints: 58809
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 62582 (MpraDaraset)
    Number of datapoints: 62582
    Used split fold: ['7', '13']


# Bypassing Data Filtering

To **disable all filtering** and use the raw dataset, set:  
`filtration = "none"`

In [None]:
left_flank = MalinoisDataset.LEFT_FLANK
right_flank = MalinoisDataset.RIGHT_FLANK
# preprocessing
train_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.ReverseComplement(0.5),
    t.Seq2Tensor()
])
val_test_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.CenterCrop(600),
    t.Seq2Tensor()
])

In [10]:
train_dataset = MalinoisDataset( 
                              split = "train", 
                              transform = train_transform,
                              filtration = "none",
                              root = "../data/") 
val_dataset = MalinoisDataset(
                              split = "val",
                              filtration = "none",
                              transform = val_test_transform,
                              root = "../data/") 
test_dataset = MalinoisDataset(
                              split = "test", 
                              filtration = "none",
                              transform = val_test_transform,
                              root = "../data/")

In [11]:
print(train_dataset)
print("="*50)
print(val_dataset)
print("="*50)
print(test_dataset)

Dataset MalinoisDataset of size 668946 (MpraDaraset)
    Number of datapoints: 668946
    Used split fold: ['1', '2', '3', '4', '5', '6', '8', '9', '10', '11', '12', '14', '15', '16', '17', '18', '20', '22', 'Y']
Dataset MalinoisDataset of size 62406 (MpraDaraset)
    Number of datapoints: 62406
    Used split fold: ['19', '21', 'X']
Dataset MalinoisDataset of size 66712 (MpraDaraset)
    Number of datapoints: 66712
    Used split fold: ['7', '13']


# Train

In [7]:
in_channels = len(train_dataset[0][0])
out_channels = len(activity_columns)

In [8]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [9]:
class L1KLmixed(nn.Module):
    
    def __init__(self, reduction='mean', alpha=1.0, beta=5.0):
        super().__init__()
        
        self.reduction = reduction
        self.alpha = alpha
        self.beta  = beta
        
        self.MSE = nn.L1Loss(reduction=reduction.replace('batch',''))
        self.KL  = nn.KLDivLoss(reduction=reduction, log_target=True)
        
    def forward(self, preds, targets):
        preds_log_prob  = preds   - torch.logsumexp(preds, dim=-1, keepdim=True)
        target_log_prob = targets - torch.logsumexp(targets, dim=-1, keepdim=True)
        
        MSE_loss = self.MSE(preds, targets)
        KL_loss  = self.KL(preds_log_prob, target_log_prob)
        
        combined_loss = MSE_loss.mul(self.alpha) + \
                        KL_loss.mul(self.beta)
        
        return combined_loss.div(self.alpha+self.beta)

## Train 5 times

In [11]:
with open('./results_summary_orig_opt_50_epochs.txt', 'w') as f:
    f.write("Run\tK562\tHepG2\tSKNSH\n")

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

test_results = []

n_runs = 5

for run in range(n_runs):
    
    model = BassetBranched(n_outputs = out_channels, loss_criterion = L1KLmixed())

    seq_model = LitModel_Malinois(model = model, num_outputs = out_channels,
                           loss = L1KLmixed(),
                           weight_decay = 0.0003438210249762151, lr = 0.0032658700881052086, print_each = 10)
    
    # Callback to save the best validation model
    checkpoint_callback = ModelCheckpoint(
        monitor='val_mean_pearson', 
        mode='max',  
        save_top_k=1,
        dirpath='./checkpoints_run/',
        filename=f'run_{run}_' + 'best_model-{epoch:02d}-{val_mean_pearson:.3f}',
        save_last=False
    )
    
    # Initialize a trainer
    trainer = L.Trainer(
        accelerator="gpu",
        devices=[0],
        max_epochs=50,
        gradient_clip_val=1,
        precision='16',
        enable_progress_bar=True,
        num_sanity_val_steps=0,
        callbacks=[checkpoint_callback]
    )
    
    trainer.fit(seq_model,
                train_dataloaders=train_loader,
                val_dataloaders=val_loader)
    
    # Load the best model
    best_model_path = checkpoint_callback.best_model_path
    seq_model = LitModel_Malinois.load_from_checkpoint(
        best_model_path,
        model=model,
        num_outputs=out_channels,
        loss=L1KLmixed(),
        weight_decay = 0.0003438210249762151,
        lr = 0.0032658700881052086,
        print_each=1
    )
    
    result = trainer.test(seq_model, dataloaders=test_loader)[0]
    
    line = f"{run+1}\t"
    for col in activity_columns:
        pearson = result[f'test_{col}_pearson']
        line += f"{pearson:.4f}\t"
    
    # write in file
    with open('./results_summary_orig_opt_50_epochs.txt', 'a') as f:
        f.write(line + '\n')

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/nios/miniconda3/envs/mpramnist/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/nios/5Term/examples/checkpoints_run exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | BassetBranched  | 4.1 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | tr

Training: |                                                                                                   …



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.11173 | Arith Mean Loss: 0.11190 | Harm Mean Loss: 0.61206 || Val Mean Pearson: 0.83539 |
| Val Pearson K562: 0.83946 | Val Pearson HepG2: 0.83033 | Val Pearson SKNSH: 0.83638 
| Train Pearson K562: 0.80760 | Train Pearson HepG2: 0.79510 | Train Pearson SKNSH: 0.80205 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.10173 | Arith Mean Loss: 0.10182 | Harm Mean Loss: 0.62119 || Val Mean Pearson: 0.83398 |
| Val Pearson K562: 0.83502 | Val Pearson HepG2: 0.82944 | Val Pearson SKNSH: 0.83747 
| Train Pearson K562: 0.79633 | Train Pearson HepG2: 0.78696 | Train Pearson SKNSH: 0.79700 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.09802 | Arith Mean Loss: 0.09807 | Harm Mean Loss: 0.60888 || Val Mean Pearson: 0.84625 |
| Val Pearson K562: 0.84863 | Val Pearson HepG2: 0.84269 | Val Pearson SKNSH: 0.84744 
| Train Pearson K562: 0.80860 | Train Pearson HepG2: 0.79972 | Train Pearson SKNSH: 0.80783 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.09359 | Arith Mean Loss: 0.09368 | Harm Mean Loss: 0.52188 || Val Mean Pearson: 0.85763 |
| Val Pearson K562: 0.86557 | Val Pearson HepG2: 0.85225 | Val Pearson SKNSH: 0.85507 
| Train Pearson K562: 0.81893 | Train Pearson HepG2: 0.81108 | Train Pearson SKNSH: 0.81826 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.09380 | Arith Mean Loss: 0.09387 | Harm Mean Loss: 0.56461 || Val Mean Pearson: 0.86132 |
| Val Pearson K562: 0.86757 | Val Pearson HepG2: 0.85688 | Val Pearson SKNSH: 0.85952 
| Train Pearson K562: 0.82927 | Train Pearson HepG2: 0.82133 | Train Pearson SKNSH: 0.82743 
--------------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_HepG2_pearson       0.8498441576957703
    test_K562_pearson       0.8495825529098511
   test_SKNSH_pearson       0.8478450179100037
       train_loss           0.07553713023662567
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | BassetBranched  | 4.1 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.429    Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.10503 | Arith Mean Loss: 0.10511 | Harm Mean Loss: 0.70362 || Val Mean Pearson: 0.81792 |
| Val Pearson K562: 0.82051 | Val Pearson HepG2: 0.81310 | Val Pearson SKNSH: 0.82015 
| Train Pearson K562: 0.81047 | Train Pearson HepG2: 0.80117 | Train Pearson SKNSH: 0.81032 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.10706 | Arith Mean Loss: 0.10726 | Harm Mean Loss: 0.58915 || Val Mean Pearson: 0.84144 |
| Val Pearson K562: 0.84450 | Val Pearson HepG2: 0.83927 | Val Pearson SKNSH: 0.84054 
| Train Pearson K562: 0.79885 | Train Pearson HepG2: 0.79194 | Train Pearson SKNSH: 0.79989 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.10033 | Arith Mean Loss: 0.10036 | Harm Mean Loss: 0.60023 || Val Mean Pearson: 0.83970 |
| Val Pearson K562: 0.84515 | Val Pearson HepG2: 0.83380 | Val Pearson SKNSH: 0.84017 
| Train Pearson K562: 0.81034 | Train Pearson HepG2: 0.80467 | Train Pearson SKNSH: 0.81157 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …


--------------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.09331 | Arith Mean Loss: 0.09335 | Harm Mean Loss: 0.55411 || Val Mean Pearson: 0.85888 |
| Val Pearson K562: 0.86284 | Val Pearson HepG2: 0.85665 | Val Pearson SKNSH: 0.85715 
| Train Pearson K562: 0.82093 | Train Pearson HepG2: 0.81420 | Train Pearson SKNSH: 0.82061 
--------------------------------------------------------------------------------------



Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.08991 | Arith Mean Loss: 0.09000 | Harm Mean Loss: 0.48397 || Val Mean Pearson: 0.87026 |
| Val Pearson K562: 0.87554 | Val Pearson HepG2: 0.86669 | Val Pearson SKNSH: 0.86855 
| Train Pearson K562: 0.83187 | Train Pearson HepG2: 0.82621 | Train Pearson SKNSH: 0.83285 
--------------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                                    …

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
   test_HepG2_pearson       0.8526626825332642
    test_K562_pearson       0.8501366376876831
   test_SKNSH_pearson       0.8464100956916809
       train_loss           0.07523658126592636
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | BassetBranched  | 4.1 M  | train
1 | loss          | L1KLmixed       | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.429    Total estimated model params size (MB)
38        Modules in train mode
0         Modules in eval mode


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …