In [1]:
import mpramnist
from mpramnist.DeepPromoter.dataset import DeepPromoterDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_DeepPromoter

from mpramnist import transforms as t
from mpramnist import target_transforms as t_t

import torch
import torch.nn as nn
import torch.utils.data as data

import lightning.pytorch as L

In [2]:
# Global configuration: everything a reader might want to tune lives here.
BATCH_SIZE = 128
NUM_WORKERS = 8
# Fixed seed so "Restart Kernel -> Run All" reproduces shuffling, augmentation and weight
# init (assuming the mpramnist transforms draw from the python/numpy/torch RNGs).
SEED = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seeds python `random`, numpy and torch; `workers=True` additionally gives each DataLoader
# worker process its own reproducible seed (NUM_WORKERS > 0 below).
L.seed_everything(SEED, workers=True)

In [3]:
# Training pipeline: random reverse-complement augmentation (the argument is presumably the
# probability of applying it, 0.5 here), then conversion of the sequence to a tensor.
train_transform = t.Compose([t.ReverseComplement(0.5), t.Seq2Tensor()])

# Validation/test pipeline: no augmentation (reverse-complement is never applied),
# only the same tensor conversion used for training.
val_test_transform = t.Compose([t.Seq2Tensor()])

In [4]:
# Load the three DeepPromoter splits from one data root. Only the training split uses the
# stochastic augmentation pipeline; val/test are deterministic.
DATA_ROOT = "../data/"  # relative to the notebook's working directory

train_dataset = DeepPromoterDataset(
    split="train", transform=train_transform, root=DATA_ROOT
)
val_dataset = DeepPromoterDataset(
    split="val", transform=val_test_transform, root=DATA_ROOT
)
test_dataset = DeepPromoterDataset(
    split="test", transform=val_test_transform, root=DATA_ROOT
)

# Fail fast if a split came back empty (e.g. DATA_ROOT points at the wrong directory)
# instead of failing later inside the DataLoader / Trainer.
for split_name, split_dataset in (
    ("train", train_dataset),
    ("val", val_dataset),
    ("test", test_dataset),
):
    assert len(split_dataset) > 0, f"empty {split_name} split under {DATA_ROOT!r}"

In [5]:
# Summarize each split, separated by a horizontal rule of 50 '=' characters.
rule = "=" * 50
print(train_dataset, val_dataset, test_dataset, sep=f"\n{rule}\n")

Dataset DeepPromoterDataset of size 9000 (MpraDaraset)
    Number of datapoints: 9000
    Used split fold: ['train']
Dataset DeepPromoterDataset of size 1000 (MpraDaraset)
    Number of datapoints: 1000
    Used split fold: ['val']
Dataset DeepPromoterDataset of size 1884 (MpraDaraset)
    Number of datapoints: 1884
    Used split fold: ['test']


In [6]:
# Wrap the datasets in DataLoaders. Only the training loader shuffles; validation/test
# order stays fixed so per-epoch metrics are comparable.
def _make_loader(dataset, shuffle, persistent=True):
    """Build a DataLoader using the notebook-wide BATCH_SIZE and NUM_WORKERS.

    Args:
        dataset: the split to iterate over.
        shuffle: whether to reshuffle every epoch (training only).
        persistent: keep worker processes alive between epochs. Worth it for loaders that
            are iterated every epoch (train/val); pointless for the one-shot test loader.

    pin_memory is enabled only when CUDA is available, where it speeds up host->GPU copies.
    persistent_workers requires num_workers > 0, hence the guard.
    """
    return data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=persistent and NUM_WORKERS > 0,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False, persistent=False)

In [7]:
# Infer the input channel count from one encoded training sequence (presumably item [0]
# of a sample is the Seq2Tensor output laid out channels-first; reverse-complementing is
# assumed not to change this dimension). One regression target -> one output unit.
first_seq_tensor = train_dataset[0][0]
in_channels = len(first_seq_tensor)
out_channels = 1

In [8]:
# HumanLegNet architecture hyperparameters: stem conv (channels / kernel size), four
# "ef" blocks with their widths and kernel size, one pooling size per block.
legnet_hparams = {
    "stem_ch": 64,
    "stem_ks": 11,
    "ef_ks": 9,
    "ef_block_sizes": [80, 96, 112, 128],
    "pool_sizes": [2, 2, 2, 2],
    "resize_factor": 4,
}
model = HumanLegNet(in_ch=in_channels, output_dim=out_channels, **legnet_hparams)
# nn.Module.apply runs initialize_weights recursively on every submodule (in place).
model.apply(initialize_weights)

# Lightning wrapper with an MSE regression loss; print_each=5 controls how often the
# epoch summary is printed (every 5th epoch in the logged run).
# NOTE(review): the logged validation MSE is ~2e8 and train Pearson collapses to 0, which
# suggests raw (unnormalized) targets. `t_t` (target_transforms) is imported but unused --
# consider a target normalization/log transform; TODO confirm the mpramnist API.
seq_model = LitModel_DeepPromoter(
    model=model,
    loss=nn.MSELoss(),
    weight_decay=1e-4,
    lr=1e-3,
    print_each=5,
)

In [9]:
# Initialize a trainer on whatever hardware is actually available. `device` comes from the
# config cell; hardcoding accelerator="gpu" / devices=[0] fails on CPU-only machines.
use_cuda = device.type == "cuda"
trainer = L.Trainer(
    accelerator="gpu" if use_cuda else "cpu",
    devices=[0] if use_cuda else 1,
    max_epochs=25,
    gradient_clip_val=1,
    # fp16 autocast is a GPU feature; fall back to full fp32 on CPU.
    # NOTE(review): the logged validation MSE (~2e8, i.e. RMSE ~1.4e4) is within a few
    # multiples of the fp16 maximum (65504). If training stalls again (train Pearson stuck
    # at 0), try "bf16-mixed" on Ampere+ GPUs or normalize the targets.
    precision="16-mixed" if use_cuda else "32-true",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
# Train for max_epochs, running validation on val_loader after each epoch.
trainer.fit(
    model=seq_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-07-22 19:02:39.268683: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-22 19:02:39.285680: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753200159.308692 2259581 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------------
| Epoch: 4 | Val Loss: 208404080.00000 | Val Pearson: 0.07277 | Train Pearson: 0.15751 
---------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


---------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 208206432.00000 | Val Pearson: 0.00000 | Train Pearson: 0.00000 
---------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 14 | Val Loss: 208272336.00000 | Val Pearson: 0.13363 | Train Pearson: 0.00000 
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


----------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 208128464.00000 | Val Pearson: 0.18396 | Train Pearson: 0.00000 
----------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=25` reached.



----------------------------------------------------------------------------------------
| Epoch: 24 | Val Loss: 208182000.00000 | Val Pearson: 0.18046 | Train Pearson: 0.00000 
----------------------------------------------------------------------------------------



In [11]:
# Evaluate on the held-out test split. Because the model object is passed and no ckpt_path
# is given, Lightning uses its current (last-epoch) weights; the result is a list with one
# metrics dict per dataloader.
test_results = trainer.test(model=seq_model, dataloaders=test_loader)
test_results

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 148308160.0, 'test_pearson': 0.25248026847839355}]