In [1]:
import mpramnist
from mpramnist.Vaishnav.dataset import VaishnavDataset

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Vaishnav

import torch
import torch.nn as nn
import torch.utils.data as data

import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import loggers as pl_loggers

from torchmetrics import PearsonCorrCoef

# Define some required parameters and functions

In [2]:
import os

# Number of sequences per mini-batch for every DataLoader in this notebook.
BATCH_SIZE = 1024
# DataLoader worker processes. The original value (103) is kept as the upper
# bound, but it is capped at the number of CPUs reported by the OS so the
# notebook does not oversubscribe machines with fewer cores.
# os.cpu_count() may return None, hence the fallback to a single worker.
NUM_WORKERS = min(103, os.cpu_count() or 1)

In [3]:
def meaned_prediction(forw, rev, trainer, seq_model, name, is_paired=False):
    """Score strand-averaged predictions against the targets with Pearson r.

    The model is run twice through ``trainer.predict``: once on ``forw`` and
    once on ``rev``. The per-batch outputs are dictionaries; the entries
    ``"target"`` and ``"ref_predicted"`` (plus ``"alt_predicted"`` when
    ``is_paired`` is true) are concatenated over batches, and the predictions
    from the two loaders are averaged element-wise.

    Args:
        forw: dataloader passed to the first ``trainer.predict`` call; the
            targets are taken from this pass.
        rev: dataloader passed to the second ``trainer.predict`` call.
            It must yield the samples in the same order as ``forw`` because
            the two prediction vectors are averaged position by position.
        trainer: object exposing ``predict(model, dataloaders=...)``.
        seq_model: model handed to ``trainer.predict``.
        name: label printed before the result is returned.
        is_paired: when true, the score is computed on the difference
            ``mean(alt) - mean(ref)`` instead of on ``mean(ref)``.

    Returns:
        The value produced by ``PearsonCorrCoef()`` for the averaged
        prediction (or prediction difference) versus the targets.
    """

    def _gather(batches, key):
        # Concatenate one field of the per-batch prediction dictionaries.
        return torch.cat([batch[key] for batch in batches])

    def _strand_average(first, second):
        # Element-wise mean of two equally shaped prediction tensors.
        return torch.stack([first, second]).mean(dim=0)

    forward_batches = trainer.predict(seq_model, dataloaders=forw)
    targets = _gather(forward_batches, "target")
    ref_forward = _gather(forward_batches, "ref_predicted")

    reverse_batches = trainer.predict(seq_model, dataloaders=rev)
    ref_reverse = _gather(reverse_batches, "ref_predicted")

    ref_mean = _strand_average(ref_forward, ref_reverse)

    pearson = PearsonCorrCoef()
    print(name + " Pearson correlation")

    if not is_paired:
        return pearson(ref_mean, targets)

    alt_forward = _gather(forward_batches, "alt_predicted")
    alt_reverse = _gather(reverse_batches, "alt_predicted")
    alt_mean = _strand_average(alt_forward, alt_reverse)
    # Paired data is scored on the predicted alt-minus-ref difference.
    return pearson(alt_mean - ref_mean, targets)

**Important note**: Sequence lengths vary. To standardize them:

* Original flanks should be preserved

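* Missing regions need to be supplemented from the source plasmid: the plasmid sequence immediately upstream of the insert is used as the left flank (`left_flank`, added with `AddFlanks`)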

* All sequences must be adjusted to the default 110 bp length (as in the original protocol)

In [4]:
# Final length (bp) of every model input, as in the original protocol.
length = 110
plasmid = VaishnavDataset.PLASMID.upper()

# The insert position is located by searching the plasmid for a run of 80 "N".
insert_start = plasmid.find("N" * 80)
# str.find returns -1 on a miss; slicing with it would silently produce a
# wrong flank, so fail loudly instead.
if insert_start < 0:
    raise ValueError(
        "Could not find the 80-nt 'N' placeholder in VaishnavDataset.PLASMID"
    )
# A negative slice start would wrap around to the end of the plasmid.
if insert_start < length:
    raise ValueError(
        f"Only {insert_start} bp precede the insert in the plasmid; "
        f"{length} bp are needed for the left flank"
    )

right_flank = VaishnavDataset.RIGHT_FLANK
# The `length` bases directly upstream of the insert.
left_flank = plasmid[insert_start - length : insert_start]

In [5]:
# preprocessing
def _sequence_pipeline(reverse_complement_arg):
    """Build the flank -> crop -> strand -> tensor pipeline for one split.

    ``reverse_complement_arg`` is forwarded unchanged to
    ``t.ReverseComplement``; all other steps are the same for every split.
    """
    steps = [
        t.AddFlanks(left_flank, right_flank),
        t.LeftCrop(length, length),
        t.ReverseComplement(reverse_complement_arg),
        t.Seq2Tensor(),
    ]
    return t.Compose(steps)


# Training uses ReverseComplement(0.5) - presumably a 50% chance of flipping
# the strand as augmentation (confirm against mpramnist.transforms).
train_transform = _sequence_pipeline(0.5)
# Validation/test use ReverseComplement(0), presumably leaving the strand as is.
val_test_transform = _sequence_pipeline(0)

In the original study, two complementary environments with opposing selective pressures on URA3 gene expression (encoding an enzyme responsible for uracil synthesis) were investigated:

`defined` environment, where organismal fitness increases with gene expression (up to saturation);

`complex` environment + 5-FOA, where fitness decreases with Ura3p expression.

Use the `dataset_env_type` parameter to select either `'defined'` or `'complex'`.

# Dataset Specifications:

`defined`: (1) Contains about 21 million sequences (18,933,667 training + 2,103,740 validation) (2) 10% allocated for validation (3) Remainder used for training

`complex`: (1) Contains 31 million sequences (2) 10% allocated for validation (3) Remainder used for training

# Train **defined** env type

In [6]:
# load the data
# Both splits come from the "defined" environment under the same data root;
# they differ only in the split name and the transform pipeline.
defined_source = {"dataset_env_type": "defined", "root": "../data/"}

train_dataset = VaishnavDataset(
    split="train", transform=train_transform, **defined_source
)

val_dataset = VaishnavDataset(
    split="val", transform=val_test_transform, **defined_source
)

In [7]:
# Show the summaries of both splits, separated by a dashed line.
print(train_dataset, "------------", val_dataset, sep="\n")

Dataset VaishnavDataset of size 18933667 (MpraDaraset)
    Number of datapoints: 18933667
    Used split fold: train
------------
Dataset VaishnavDataset of size 2103740 (MpraDaraset)
    Number of datapoints: 2103740
    Used split fold: val


In [8]:
# encapsulate data into dataloader form
# Settings shared by the training and validation loaders.
defined_loader_settings = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = data.DataLoader(
    train_dataset, shuffle=True, **defined_loader_settings
)

val_loader = data.DataLoader(
    val_dataset, shuffle=False, **defined_loader_settings
)

In [9]:
# The number of input channels is read off the first training sample:
# element 0 of the sample is presumably the encoded sequence produced by
# Seq2Tensor, whose length is the channel count.
first_defined_sample = train_dataset[0]
in_channels = len(first_defined_sample[0])
# The model predicts a single value per sequence.
out_channels = 1

In [10]:
# Architecture hyper-parameters of the HumanLegNet backbone. All pool sizes
# are 1 (the checkpoint file name calls this variant "without_pooling").
legnet_defined_config = dict(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[1, 1, 1, 1],
    resize_factor=4,
)
model = HumanLegNet(**legnet_defined_config)
# Apply the library's weight initialisation to every sub-module.
model.apply(initialize_weights)

# Lightning wrapper holding the loss and optimiser settings for training.
seq_model_defined = LitModel_Vaishnav(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=1e-2,
    print_each=1,
)

In [11]:
# Keep only the single checkpoint with the highest validation Pearson r.
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson",
    mode="max",
    save_top_k=1,
    dirpath="./checkpoints_vaishnav/",
    # The environment name is part of the file name because the "complex" run
    # later in this notebook writes into the same directory with an otherwise
    # identical pattern, which made the two checkpoints indistinguishable.
    filename="best_model_without_pooling-defined-{epoch:02d}-{val_pearson:.3f}",
    save_last=False,
)
logger = pl_loggers.TensorBoardLogger("./logs", name="Vaishnav")
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    # NOTE(review): pins the run to GPU index 1; adjust for your hardware.
    devices=[1],
    max_epochs=1,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    # Skip the sanity-check validation batches before training starts.
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
    logger=logger,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
# Train the model
# Fit on the "defined" training loader and validate on its validation loader.
trainer.fit(
    model=seq_model_defined,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-07-31 20:33:06.493257: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-31 20:33:06.508864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753983186.526717 3399413 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=1` reached.



-------------------------------------------------------------------------------
| Epoch: 0 | Val Loss: 3.75383 | Val Pearson: 0.88450 | Train Pearson: 0.84446 
-------------------------------------------------------------------------------



## Test Sequences:

Test sequences are divided into three categories per environment:

* Reference (`native`)

* Alternative (`drift`)

* Paired (`paired`)

Use the `test_dataset_type` parameter to select `'native'`, `'drift'`, or `'paired'`.

In [13]:
# Load the best model
best_model_path = checkpoint_callback.best_model_path
# best_model_path stays empty until the callback has written a checkpoint;
# fail with a clear message instead of an obscure error inside the loader.
if not best_model_path:
    raise RuntimeError(
        "No checkpoint has been saved yet - run trainer.fit before loading "
        "the best model"
    )
# The constructor arguments are passed again, mirroring how the module was
# built for training.
seq_model_defined = LitModel_Vaishnav.load_from_checkpoint(
    best_model_path,
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=1e-2,
    print_each=1,
)

In [14]:
# Two evaluation pipelines that differ only in the ReverseComplement argument:
# 0 for the forward pipeline and 1 for the reverse one (presumably "never flip"
# and "always flip" - confirm against mpramnist.transforms).
forw_transform, rev_transform = (
    t.Compose(
        [
            t.AddFlanks(left_flank, right_flank),
            t.LeftCrop(length, length),
            t.ReverseComplement(strand_arg),
            t.Seq2Tensor(),
        ]
    )
    for strand_arg in (0, 1)
)

In [15]:
# "native" test set of the "defined" environment, once per strand pipeline.
defined_native_kwargs = dict(
    split="test",
    dataset_env_type="defined",
    test_dataset_type="native",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **defined_native_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **defined_native_kwargs)
print("native info", test_forw, sep="\n")

# shuffle=False keeps both loaders in the same sample order, which the
# position-wise averaging in meaned_prediction relies on.
defined_native_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_defined_native = data.DataLoader(test_forw, **defined_native_loader_kwargs)
rev_defined_native = data.DataLoader(test_rev, **defined_native_loader_kwargs)

meaned_prediction(
    forw=forw_defined_native,
    rev=rev_defined_native,
    trainer=trainer,
    seq_model=seq_model_defined,
    name="native",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


native info
Dataset VaishnavDataset of size 3978 (MpraDaraset)
    Number of datapoints: 3978
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

native Pearson correlation


tensor(0.9779)

In [16]:
# "drift" test set of the "defined" environment, once per strand pipeline.
defined_drift_kwargs = dict(
    split="test",
    dataset_env_type="defined",
    test_dataset_type="drift",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **defined_drift_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **defined_drift_kwargs)
print("drift info", test_forw, sep="\n")

# Identical, unshuffled loader settings keep forward and reverse samples aligned.
defined_drift_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_defined_drift = data.DataLoader(test_forw, **defined_drift_loader_kwargs)
rev_defined_drift = data.DataLoader(test_rev, **defined_drift_loader_kwargs)

meaned_prediction(
    forw=forw_defined_drift,
    rev=rev_defined_drift,
    trainer=trainer,
    seq_model=seq_model_defined,
    name="drift",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


drift info
Dataset VaishnavDataset of size 2986 (MpraDaraset)
    Number of datapoints: 2986
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

drift Pearson correlation


tensor(0.9869)

### Note on paired sequences:

Each paired sequence contains: (1) A reference sequence (2) An alternative sequence (3) Differential expression column

In [17]:
# Load the paired test set without any transform so the raw sample is visible.
dataset_paired = VaishnavDataset(
    root="../data/",
    split="test",
    dataset_env_type="defined",
    test_dataset_type="paired",
)
# Display the first sample as the cell output.
dataset_paired[0]

({'seq': 'CTTTCAATTGGGTGGGGACGCGACGGCGCCCCGGCTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT',
  'seq_alt': 'CTTTAAATTCGGTGGGGACGCGTCGGCGCCCCGGCTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT'},
 tensor(0.6423))

In [18]:
# "paired" test set of the "defined" environment, once per strand pipeline.
defined_paired_kwargs = dict(
    split="test",
    dataset_env_type="defined",
    test_dataset_type="paired",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **defined_paired_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **defined_paired_kwargs)
print("paired info", test_forw, sep="\n")

# Identical, unshuffled loader settings keep forward and reverse samples aligned.
defined_paired_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_defined_paired = data.DataLoader(test_forw, **defined_paired_loader_kwargs)
rev_defined_paired = data.DataLoader(test_rev, **defined_paired_loader_kwargs)

# is_paired=True scores the predicted alt-minus-ref difference.
meaned_prediction(
    forw=forw_defined_paired,
    rev=rev_defined_paired,
    trainer=trainer,
    seq_model=seq_model_defined,
    name="paired",
    is_paired=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


paired info
Dataset VaishnavDataset of size 2986 (MpraDaraset)
    Number of datapoints: 2986
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

paired Pearson correlation


tensor(0.8529)

# Train **complex** env type

In [19]:
# load the data
# Both splits come from the "complex" environment under the same data root;
# they differ only in the split name and the transform pipeline.
complex_source = {"dataset_env_type": "complex", "root": "../data/"}

train_dataset = VaishnavDataset(
    split="train", transform=train_transform, **complex_source
)

val_dataset = VaishnavDataset(
    split="val", transform=val_test_transform, **complex_source
)

In [20]:
# Show the summaries of both splits, separated by a dashed line.
print(train_dataset, "------------", val_dataset, sep="\n")

Dataset VaishnavDataset of size 28214427 (MpraDaraset)
    Number of datapoints: 28214427
    Used split fold: train
------------
Dataset VaishnavDataset of size 3134936 (MpraDaraset)
    Number of datapoints: 3134936
    Used split fold: val


In [21]:
# encapsulate data into dataloader form
# Settings shared by the training and validation loaders.
complex_loader_settings = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = data.DataLoader(
    train_dataset, shuffle=True, **complex_loader_settings
)

val_loader = data.DataLoader(
    val_dataset, shuffle=False, **complex_loader_settings
)

In [22]:
# The number of input channels is read off the first training sample:
# element 0 of the sample is presumably the encoded sequence produced by
# Seq2Tensor, whose length is the channel count.
first_complex_sample = train_dataset[0]
in_channels = len(first_complex_sample[0])
# The model predicts a single value per sequence.
out_channels = 1

In [23]:
# Fresh HumanLegNet for the "complex" environment, with the same
# hyper-parameter values as the "defined" run (all pool sizes are 1).
legnet_complex_config = dict(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[1, 1, 1, 1],
    resize_factor=4,
)
model = HumanLegNet(**legnet_complex_config)
# Apply the library's weight initialisation to every sub-module.
model.apply(initialize_weights)

# Lightning wrapper holding the loss and optimiser settings for training.
seq_model_complex = LitModel_Vaishnav(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=1e-2,
    print_each=1,
)

In [24]:
# Keep only the single checkpoint with the highest validation Pearson r.
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson",
    mode="max",
    save_top_k=1,
    dirpath="./checkpoints_vaishnav/",
    # The environment name is part of the file name because the "defined" run
    # earlier in this notebook writes into the same directory with an otherwise
    # identical pattern, which made the two checkpoints indistinguishable.
    filename="best_model_without_pooling-complex-{epoch:02d}-{val_pearson:.3f}",
    save_last=False,
)
logger = pl_loggers.TensorBoardLogger("./logs", name="Vaishnav")
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    # NOTE(review): pins the run to GPU index 1; adjust for your hardware.
    devices=[1],
    max_epochs=1,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    # Skip the sanity-check validation batches before training starts.
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
    logger=logger,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [25]:
# Train the model
# Fit on the "complex" training loader and validate on its validation loader.
trainer.fit(
    model=seq_model_complex,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

/home/nios/miniconda3/envs/mpra/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/nios/5Term/examples/checkpoints_vaishnav exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
120       Modules in train mode
0         Modules in eval mode


Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=1` reached.



-------------------------------------------------------------------------------
| Epoch: 0 | Val Loss: 5.24360 | Val Pearson: 0.84209 | Train Pearson: 0.80531 
-------------------------------------------------------------------------------



## Test Sequences:

Test sequences are divided into three categories per environment:

* Reference (`native`)

* Alternative (`drift`)

* Paired (`paired`)

Use the `test_dataset_type` parameter to select `'native'`, `'drift'`, or `'paired'`.

In [26]:
# Load the best model
best_model_path = checkpoint_callback.best_model_path
# best_model_path stays empty until the callback has written a checkpoint;
# fail with a clear message instead of an obscure error inside the loader.
if not best_model_path:
    raise RuntimeError(
        "No checkpoint has been saved yet - run trainer.fit before loading "
        "the best model"
    )

# The constructor arguments are passed again, mirroring how the module was
# built for training.
seq_model_complex = LitModel_Vaishnav.load_from_checkpoint(
    best_model_path,
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=1e-2,
    print_each=1,
)

In [27]:
def _strand_transform(strand_arg):
    """Evaluation pipeline; ``strand_arg`` is forwarded to ``t.ReverseComplement``."""
    return t.Compose(
        [
            t.AddFlanks(left_flank, right_flank),
            t.LeftCrop(length, length),
            t.ReverseComplement(strand_arg),
            t.Seq2Tensor(),
        ]
    )


# ReverseComplement(0) for the forward pipeline, ReverseComplement(1) for the
# reverse one (presumably "never flip" / "always flip" - confirm in mpramnist).
forw_transform = _strand_transform(0)
rev_transform = _strand_transform(1)

In [28]:
# "native" test set of the "complex" environment, once per strand pipeline.
complex_native_kwargs = dict(
    split="test",
    dataset_env_type="complex",
    test_dataset_type="native",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **complex_native_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **complex_native_kwargs)
print("native info", test_forw, sep="\n")

# shuffle=False keeps both loaders in the same sample order, which the
# position-wise averaging in meaned_prediction relies on.
complex_native_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_complex_native = data.DataLoader(test_forw, **complex_native_loader_kwargs)
rev_complex_native = data.DataLoader(test_rev, **complex_native_loader_kwargs)

meaned_prediction(
    forw=forw_complex_native,
    rev=rev_complex_native,
    trainer=trainer,
    seq_model=seq_model_complex,
    name="native",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


native info
Dataset VaishnavDataset of size 3929 (MpraDaraset)
    Number of datapoints: 3929
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

native Pearson correlation


tensor(0.9746)

In [29]:
# "drift" test set of the "complex" environment, once per strand pipeline.
complex_drift_kwargs = dict(
    split="test",
    dataset_env_type="complex",
    test_dataset_type="drift",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **complex_drift_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **complex_drift_kwargs)
print("drift info", test_forw, sep="\n")

# Identical, unshuffled loader settings keep forward and reverse samples aligned.
complex_drift_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_complex_drift = data.DataLoader(test_forw, **complex_drift_loader_kwargs)
rev_complex_drift = data.DataLoader(test_rev, **complex_drift_loader_kwargs)

meaned_prediction(
    forw=forw_complex_drift,
    rev=rev_complex_drift,
    trainer=trainer,
    seq_model=seq_model_complex,
    name="drift",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


drift info
Dataset VaishnavDataset of size 2983 (MpraDaraset)
    Number of datapoints: 2983
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

drift Pearson correlation


tensor(0.9853)

### Note on paired sequences:

Each paired sequence contains: (1) A reference sequence (2) An alternative sequence (3) Differential expression column

In [30]:
# Load the paired test set without any transform so the raw sample is visible.
dataset_paired = VaishnavDataset(
    root="../data/",
    split="test",
    dataset_env_type="complex",
    test_dataset_type="paired",
)
# Display the first sample as the cell output.
dataset_paired[0]

({'seq': 'CTTTCAATTGGGTGGGGACGCGACGGCGCCCCGGCTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT',
  'seq_alt': 'CTTTCAATTGGGTGGGGACGCGACGGCGCCCCGACTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT'},
 tensor(-0.2812))

In [31]:
# "paired" test set of the "complex" environment, once per strand pipeline.
complex_paired_kwargs = dict(
    split="test",
    dataset_env_type="complex",
    test_dataset_type="paired",
    root="../data/",
)
test_forw = VaishnavDataset(transform=forw_transform, **complex_paired_kwargs)
test_rev = VaishnavDataset(transform=rev_transform, **complex_paired_kwargs)
print("paired info", test_forw, sep="\n")

# Identical, unshuffled loader settings keep forward and reverse samples aligned.
complex_paired_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_complex_paired = data.DataLoader(test_forw, **complex_paired_loader_kwargs)
rev_complex_paired = data.DataLoader(test_rev, **complex_paired_loader_kwargs)

# is_paired=True scores the predicted alt-minus-ref difference.
meaned_prediction(
    forw=forw_complex_paired,
    rev=rev_complex_paired,
    trainer=trainer,
    seq_model=seq_model_complex,
    name="paired",
    is_paired=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


paired info
Dataset VaishnavDataset of size 2983 (MpraDaraset)
    Number of datapoints: 2983
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

paired Pearson correlation


tensor(0.8796)