In [1]:
import mpramnist
from mpramnist.Dream.dataset import DreamDataset

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Dream

import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data

import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch import loggers as pl_loggers

# Define some required parameters and functions

In [2]:
BATCH_SIZE = 1024
NUM_WORKERS = 8

In [3]:
length = 120
plasmid = DreamDataset.PLASMID.upper()
insert_start = plasmid.find("N"*80)
right_flank = DreamDataset.RIGHT_FLANK
left_flank = plasmid[insert_start - length : insert_start]

In [4]:
from torchmetrics import PearsonCorrCoef

forw_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.LeftCrop(length,length),
    t.Seq2Tensor()
])
rev_transform = t.Compose([
     t.AddFlanks(left_flank, right_flank),
    t.LeftCrop(length,length),
    t.Seq2Tensor(),
    t.ReverseComplement(1)
])

def meaned_prediction(forw, rev, trainer, seq_model, name, is_paired = False):

    predictions_forw = trainer.predict(seq_model, dataloaders = forw)
    targets = torch.cat([pred["target"] for pred in predictions_forw])
    y_preds_forw = torch.cat([pred["ref_predicted"] for pred in predictions_forw])
    
    predictions_rev = trainer.predict(seq_model, dataloaders = rev)
    y_preds_rev = torch.cat([pred["ref_predicted"] for pred in predictions_rev])
    
    mean_forw = torch.mean(torch.stack([y_preds_forw, y_preds_rev]), dim=0)
    
    pears = PearsonCorrCoef()
    print("Task '" + name + "' Pearson r^2")
    
    if is_paired:
        y_preds_forw_alt = torch.cat([pred["alt_predicted"] for pred in predictions_forw])
        y_preds_rev_alt = torch.cat([pred["alt_predicted"] for pred in predictions_rev])
        mean_alt = torch.mean(torch.stack([y_preds_forw_alt, y_preds_rev_alt]), dim=0)
        pred = mean_alt - mean_forw
        return pears(pred, targets) * pears(pred, targets) 
        
    return pears(mean_forw, targets) * pears(mean_forw, targets) 

# Download data 

In [5]:
DreamDataset.TYPES

['high',
 'low',
 'yeast',
 'challenging',
 'random',
 'all',
 'snv',
 'perturbation',
 'tiling']

In [6]:
# preprocessing
train_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.LeftCrop(length,length),
    t.Seq2Tensor(),
    t.ReverseComplement(0.5)
])
val_test_transform = t.Compose([
    t.AddFlanks(left_flank, right_flank),
    t.LeftCrop(length, length),
    t.Seq2Tensor(), 
    t.ReverseComplement(0)
])

In [7]:
# load the data
train_dataset = DreamDataset(split="train", transform = train_transform, root = "../data")     

val_dataset = DreamDataset(split="val", data_type = ["all"], transform = val_test_transform, root = "../data") 

test_dataset = DreamDataset(split="test", data_type = ["all"], transform = val_test_transform, root = "../data")

In [8]:
print(train_dataset)

Dataset DreamDataset of size 6739258 (MpraDaraset)
    Number of datapoints: 6739258
    Used split fold: train


In [9]:
print(val_dataset)
print("------------")
print(test_dataset)

Dataset DreamDataset of size 9045 (MpraDaraset)
    Number of datapoints: 9045
    Used split fold: public
------------
Dataset DreamDataset of size 62058 (MpraDaraset)
    Number of datapoints: 62058
    Used split fold: private


# Train

In [10]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers = NUM_WORKERS)

val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = NUM_WORKERS)

In [11]:
in_channels = len(train_dataset[0][0])
out_channels = 1

In [12]:
model = HumanLegNet(in_ch=in_channels,
                     output_dim = out_channels,
                     stem_ch=64,
                     stem_ks=11,
                     ef_ks=9,
                     ef_block_sizes=[80, 96, 112, 128],
                     pool_sizes=[1,1,1,1],
                     resize_factor=4)
model.apply(initialize_weights)

seq_model = LitModel_Dream(model = model,
                           loss = nn.MSELoss(),
                           weight_decay = 1e-2, lr = 1e-2, print_each = 10)

In [13]:
checkpoint_callback = ModelCheckpoint(
        monitor='val_pearson', 
        mode='max',  
        save_top_k=1,
        dirpath='./checkpoints_dream/',
        filename='best_model-{epoch:02d}-{val_pearson:.3f}',
        save_last=False
    )
logger = pl_loggers.TensorBoardLogger("./logs", name="Dream")
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[1],
    max_epochs = 60,
    gradient_clip_val=1,
    precision='16-mixed', 
    enable_progress_bar = True,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
    logger = logger
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [14]:
# Train the model
trainer.fit(seq_model,
            train_dataloaders = train_loader,
            val_dataloaders = val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-07-27 15:34:22.523869: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-27 15:34:22.538585: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753619662.556517  300615 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 145.07829 | Val Pearson: 0.94852 | Train Pearson: 0.73891 
---------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 135.91905 | Val Pearson: 0.94994 | Train Pearson: 0.73880 
----------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 137.08499 | Val Pearson: 0.94977 | Train Pearson: 0.74161 
----------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 142.79858 | Val Pearson: 0.95525 | Train Pearson: 0.74675 
----------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 141.80118 | Val Pearson: 0.96352 | Train Pearson: 0.75554 
----------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=60` reached.



----------------------------------------------------------------------------------
| Epoch: 59 | Val Loss: 138.87439 | Val Pearson: 0.96022 | Train Pearson: 0.76490 
----------------------------------------------------------------------------------



# Evaluate

In [15]:
# Load the best model
best_model_path = checkpoint_callback.best_model_path
best_model_path = "./checkpoints_dream/best_model_without_pooling-epoch=43-val_pearson=0.967.ckpt"
seq_model = LitModel_Dream.load_from_checkpoint(
    best_model_path,
    model=model,
    loss = nn.MSELoss(),
    weight_decay = 1e-2, 
    lr = 1e-2, 
    print_each = 10
)

trainer.test(seq_model, dataloaders=test_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-08-01 15:22:43.163678: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-01 15:22:43.187480: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754050963.217084 1602780 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin 

Testing: |                                                                                     | 0/? [00:00<?,…

[{'test_loss': 134.9116668701172, 'test_pearson': 0.9633164405822754}]

## Single

### All Sequences

In [16]:
test_forw = DreamDataset(split="test", data_type = ["all"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["all"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "All Sequences")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'All Sequences' Pearson r^2


tensor(0.9327)

### High

In [17]:
test_forw = DreamDataset(split="test", data_type = ["high"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["high"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "High")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'High' Pearson r^2


tensor(0.3519)

### Low

In [18]:
test_forw = DreamDataset(split="test", data_type = ["low"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["low"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Low")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Low' Pearson r^2


tensor(0.2876)

### Native

In [19]:
test_forw = DreamDataset(split="test", data_type = ["yeast"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["yeast"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Native")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Native' Pearson r^2


tensor(0.7448)

### Random

In [20]:
test_forw = DreamDataset(split="test", data_type = ["random"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["random"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Random")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Random' Pearson r^2


tensor(0.9540)

### Challenging

In [21]:
test_forw = DreamDataset(split="test", data_type = ["challenging"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["challenging"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Challenging")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Challenging' Pearson r^2


tensor(0.9243)

## Paired

### SNVs

In [25]:
test_forw = DreamDataset(split="test", data_type = ["snv"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["snv"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "SNVs", is_paired = True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'SNVs' Pearson r^2


tensor(0.6918)

### Motif Perturbation

In [26]:
test_forw = DreamDataset(split="test", data_type = ["perturbation"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["perturbation"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Motif Perturbation", is_paired = True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Motif Perturbation' Pearson r^2


tensor(0.9508)

### Motif Tiling

In [27]:
test_forw = DreamDataset(split="test", data_type = ["tiling"], transform = forw_transform, root = "../data")
test_rev = DreamDataset(split="test", data_type = ["tiling"], transform = rev_transform, root = "../data")

forw = data.DataLoader(dataset = test_forw, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)
rev = data.DataLoader(dataset = test_rev, batch_size = BATCH_SIZE, shuffle = False, num_workers = NUM_WORKERS, pin_memory = True)

meaned_prediction(forw, rev, trainer, seq_model, "Motif Tiling", is_paired = True)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

Task 'Motif Tiling' Pearson r^2


tensor(0.8840)