In [1]:
import os

import pandas as pd
import pytorch_lightning as L
import torch
import torch.nn as nn
import torch.utils.data as data

import mpramnist
import mpramnist.target_transforms as t_t
import mpramnist.transforms as t
from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Vaishnav
from mpramnist.vaishnavdataset import VaishnavDataset

In [2]:
# Run configuration shared by every DataLoader / Trainer cell below.
SEED = 42
BATCH_SIZE = 1024
# 103 workers matches the machine the recorded run used; never request more
# worker processes than this host has CPUs (PyTorch warns about and
# oversubscribes otherwise).
NUM_WORKERS = min(103, os.cpu_count() or 1)

# Seed Python, NumPy and torch (including DataLoader worker processes) so that
# shuffling, random reverse-complement augmentation and weight initialisation
# are reproducible across Restart & Run All.
L.seed_everything(SEED, workers=True);

**Important note**: Sequence lengths vary. To standardize them:

* Original flanks should be preserved

* Any missing flank sequence is filled in from the source plasmid

* All sequences must be adjusted to the default 110 bp length (as in the original protocol)

In [3]:
# Build the flanks used to pad every insert to a fixed `length` bp window.
# The left flank is the `length` bp of plasmid sequence immediately upstream of
# the 80-N insert placeholder; the right flank is the dataset's RIGHT_FLANK.
length = 110
plasmid = VaishnavDataset.PLASMID.upper()
insert_start = plasmid.find("N" * 80)
# str.find returns -1 when the placeholder is absent, which would silently turn
# the slice below into the plasmid's tail; fail loudly instead.
if insert_start == -1:
    raise ValueError("80-N insert placeholder not found in VaishnavDataset.PLASMID")
# A negative slice start would also yield a wrong (typically empty) flank.
if insert_start < length:
    raise ValueError(
        f"only {insert_start} bp upstream of the insert; "
        f"cannot build a {length} bp left flank"
    )
right_flank = VaishnavDataset.RIGHT_FLANK
left_flank = plasmid[insert_start - length : insert_start]

In [None]:
# Preprocessing pipelines. Every sequence is padded with the plasmid-derived
# flanks, cropped to `length` bp, converted to a tensor and passed through
# ReverseComplement(p). Training uses p=0.5 and validation/test use p=0
# (p is presumably the flip probability).
def _build_transform(rc_prob):
    """Return the shared preprocessing pipeline with reverse-complement setting `rc_prob`."""
    pipeline_steps = [
        t.AddFlanks(left_flank, right_flank),
        t.LeftCrop(length, length),
        t.Seq2Tensor(),
        t.ReverseComplement(rc_prob),
    ]
    return t.Compose(pipeline_steps)


train_transform = _build_transform(0.5)
val_test_transform = _build_transform(0)

The original study investigated two complementary environments that exert opposing selective pressures on expression of the URA3 gene (which encodes an enzyme required for uracil biosynthesis):

`defined` environment, where organismal fitness increases with gene expression (up to saturation);

`complex` environment supplemented with 5-FOA, where fitness decreases with Ura3p expression (Ura3p converts 5-FOA into a toxic compound).

Use the `dataset_type` parameter to select either `'defined'` or `'complex'`.

## Dataset Specifications:

`defined`: (1) contains ~21 million sequences (18.9 M train + 2.1 M validation in the run below); (2) 10% are held out for validation; (3) the remainder is used for training.

`complex`: (1) contains 31 million sequences; (2) 10% are held out for validation; (3) the remainder is used for training.

In [4]:
# Load the `defined`-environment data: the train/val splits plus the reference
# (`native`) test set. Validation and test share `val_test_transform`.
train_dataset = VaishnavDataset(
    split="train",
    dataset_type="defined",
    transform=train_transform,
)
val_dataset = VaishnavDataset(
    split="val",
    dataset_type="defined",
    transform=val_test_transform,
)
test_dataset_native = VaishnavDataset(
    split="test",
    dataset_type="defined",
    dataset_origin_type="native",
    transform=val_test_transform,
)

In [5]:
# Summary of the training split (size and split name).
print(str(train_dataset))

Dataset VaishnavDataset of size 18933667 (MpraDaraset)
    Number of datapoints: 18933667
    Used split fold: train


In [6]:
# Summaries of the validation split and the native test set, separated by a rule.
print(val_dataset, "------------", test_dataset_native, sep="\n")

Dataset VaishnavDataset of size 2103740 (MpraDaraset)
    Number of datapoints: 2103740
    Used split fold: val
------------
Dataset VaishnavDataset of size 3978 (MpraDaraset)
    Number of datapoints: 3978
    Used split fold: test


## Test Sequences:

Test sequences are divided into three categories per environment:

* Reference (`native`)

* Alternative (`drift`)

* Paired (`paired`)

In [15]:
# Alternative ("drift") test sequences for the defined environment, processed
# with the same transform as validation.
test_dataset_drift = VaishnavDataset(
    split="test",
    dataset_type="defined",
    dataset_origin_type="drift",
    transform=val_test_transform,
)
print(test_dataset_drift)
test_loader_drift = data.DataLoader(
    dataset=test_dataset_drift,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

Dataset VaishnavDataset of size 2986 (MpraDaraset)
    Number of datapoints: 2986
    Used split fold: test


## Note on paired sequences:

Each paired test example contains (1) a reference sequence (`seq`), (2) an alternative sequence (`seq_alt`), and (3) the differential expression between them, used as the target.

In [21]:
# Paired test set: each item couples a reference and an alternative sequence
# with their differential expression; processed with the validation transform.
test_dataset_paired = VaishnavDataset(
    split="test",
    dataset_type="defined",
    dataset_origin_type="paired",
    transform=val_test_transform,
)
print(test_dataset_paired)
test_loader_paired = data.DataLoader(
    dataset=test_dataset_paired,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

Dataset VaishnavDataset of size 2986 (MpraDaraset)
    Number of datapoints: 2986
    Used split fold: test


In [23]:
# Peek at one raw (untransformed) paired example to show its structure: a dict
# holding the reference ('seq') and alternative ('seq_alt') sequences, plus
# the differential-expression target.
dataset_paired = VaishnavDataset(
    split="test",
    dataset_type="defined",
    dataset_origin_type="paired",
)
dataset_paired[0]

({'seq': 'CTTTCAATTGGGTGGGGACGCGACGGCGCCCCGGCTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT',
  'seq_alt': 'CTTTAAATTCGGTGGGGACGCGTCGGCGCCCCGGCTAGGATGCTAGCGTACTATGCTGCCTGAAAGTCTATAGGAGCATT'},
 tensor(-0.6423))

In [7]:
# Wrap each split in a DataLoader; only the training split is shuffled.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = data.DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
test_loader = data.DataLoader(dataset=test_dataset_native, shuffle=False, **loader_kwargs)

In [8]:
# Infer the input channel count from the first element of one transformed
# training item (presumably the encoded sequence of a (sequence, target)
# pair). The model predicts one scalar per sequence.
first_encoded_seq = train_dataset[0][0]
in_channels = len(first_encoded_seq)
out_channels = 1

In [9]:
# HumanLegNet backbone; architecture hyper-parameters as in the recorded run.
legnet_config = dict(
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model = HumanLegNet(in_ch=in_channels, output_dim=out_channels, **legnet_config)
model.apply(initialize_weights)

# Lightning wrapper with an MSE regression loss. print_each=1 presumably sets
# the metric-printing interval (the recorded run prints every epoch).
# NOTE(review): the recorded val loss jumps to 25.05 at epoch 2 before
# recovering; lr=1e-2 / weight_decay=1e-1 may be worth revisiting.
seq_model = LitModel_Vaishnav(
    model=model,
    in_ch=in_channels,
    out_ch=out_channels,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=1e-2,
    print_each=1,
)

In [10]:
# Initialize a trainer: one GPU, mixed precision, gradient clipping, 5 epochs.
# The recorded run pinned GPU index 1 (the second card). When fewer than two
# GPUs are visible, fall back to a single device (devices=1) so the cell still
# runs instead of failing on a missing device index.
TRAIN_DEVICES = [1] if torch.cuda.device_count() > 1 else 1
trainer = L.Trainer(
    accelerator="gpu",
    devices=TRAIN_DEVICES,
    max_epochs=5,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
)

# Train the model
trainer.fit(
    seq_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
# Evaluate the weights left after the final epoch on the native test set
# (a model instance is passed and no ckpt_path, so no checkpoint is reloaded).
trainer.test(seq_model, dataloaders=test_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-04-06 21:12:12.987894: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-06 21:12:13.004233: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------
| Epoch: 0 | Val Loss: 4.92517 | Val Pearson: 0.85692 | Train Pearson: 0.84454 
-------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------
| Epoch: 1 | Val Loss: 10.69147 | Val Pearson: 0.83675 | Train Pearson: 0.86066 
--------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------
| Epoch: 2 | Val Loss: 25.05196 | Val Pearson: 0.58701 | Train Pearson: 0.86305 
--------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------
| Epoch: 3 | Val Loss: 5.02126 | Val Pearson: 0.85050 | Train Pearson: 0.87030 
-------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=5` reached.



-------------------------------------------------------------------------------
| Epoch: 4 | Val Loss: 3.68661 | Val Pearson: 0.88673 | Train Pearson: 0.88137 
-------------------------------------------------------------------------------



LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 0.5523836016654968, 'test_pearson': 0.9833166599273682}]

In [18]:
# Evaluate the trained model on the alternative ("drift") test sequences.
trainer.test(model=seq_model, dataloaders=test_loader_drift)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 0.33399754762649536, 'test_pearson': 0.9872218370437622}]

In [19]:
# Evaluate on paired (reference vs. alternative) sequences.
# NOTE(review): the recorded run reports test_pearson of about -0.857 here. A strongly
# negative correlation suggests the predicted difference and the target use
# opposite sign conventions (alt - ref vs. ref - alt); confirm in LitModel_Vaishnav.
trainer.test(model=seq_model, dataloaders=test_loader_paired)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 6.491516590118408, 'test_pearson': -0.8568166494369507}]