In [1]:
import os

import mpramnist
from mpramnist.AgarwalJoint.dataset import AgarwalJointDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_AgarwalJoint

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import torch
import torch.nn as nn
import torch.utils.data as data

import lightning.pytorch as L

In [2]:
# Notebook-wide settings.
BATCH_SIZE = 1024  # samples per DataLoader batch

# Worker processes per DataLoader. 103 is the original setting; it is capped at
# the number of CPUs reported by the OS so the notebook does not oversubscribe
# smaller machines (four loaders are built with this value).
NUM_WORKERS = min(103, os.cpu_count() or 1)

lr = 0.01  # learning rate

In [3]:
# Constant flanks: required for each sequence.
constant_left_flank = AgarwalJointDataset.CONSTANT_LEFT_FLANK
# NOTE: the name is misspelt ("rigtht") but is referenced under this exact
# spelling by later cells, so it is kept as is.
constant_rigtht_flank = AgarwalJointDataset.CONSTANT_RIGHT_FLANK

# Original flanks from human_legnet.
left_flank = AgarwalJointDataset.LEFT_FLANK
right_flank = AgarwalJointDataset.RIGHT_FLANK

## First, we read the MPRA data, preprocess it, and wrap it in DataLoaders.

In [4]:
# Preprocessing pipelines.
def _with_constant_flanks(*middle_steps):
    """Compose a pipeline: constant flanks -> `middle_steps` -> `Seq2Tensor`.

    All pipelines in this cell start by adding the constant flanks and end
    with the tensor conversion; only the steps in between differ.
    """
    return t.Compose(
        [
            t.AddFlanks(constant_left_flank, constant_rigtht_flank),
            *middle_steps,
            t.Seq2Tensor(),
        ]
    )


# Training pipeline: extra right flank, crops and ReverseComplement(0.5).
train_transform = _with_constant_flanks(
    t.AddFlanks("", right_flank),  # these are the original parameters for human_legnet
    t.RightCrop(230, 250),
    t.RandomCrop(230),
    t.ReverseComplement(0.5),
)

# The test transforms are slightly different: no extra flank and no cropping.
test_transform = _with_constant_flanks()

# Test pipeline with ReverseComplement(1) inserted before the tensor conversion.
test_transform_reversed = _with_constant_flanks(t.ReverseComplement(1))

# Cell types passed to every dataset as `cell_type`.
activity_columns = ["HepG2", "K562", "WTC11"]


def _load_split(split, transform):
    """Create an AgarwalJointDataset for `split`, preprocessed with `transform`."""
    return AgarwalJointDataset(
        cell_type=activity_columns,
        split=split,
        transform=transform,
        root="../data/",
    )


# Load the data. "train" / "val" / "test" select the default splits; per the
# original note, a list of folds (e.g. [1, 2, 5, 6, 7, 8]) could be used instead.
train_dataset = _load_split("train", train_transform)
val_dataset = _load_split("val", test_transform)
test_dataset = _load_split("test", test_transform)

# Reversed test sequences: same split, reverse-complement pipeline.
test_dataset_rev = _load_split("test", test_transform_reversed)

def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the notebook's BATCH_SIZE and NUM_WORKERS."""
    return data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


# Encapsulate the data into dataloader form; only the training loader shuffles.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# Loader for the reversed test sequences.
test_loader_rev = _make_loader(test_dataset_rev, shuffle=False)

In [5]:
# Print the train and test dataset summaries with a divider line between them.
print(train_dataset, "===================", test_dataset, sep="\n")

Dataset AgarwalJointDataset of size 44264 (MpraDaraset)
    Number of datapoints: 44264
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
Dataset AgarwalJointDataset of size 5541 (MpraDaraset)
    Number of datapoints: 5541
    Used split fold: [10]


In [6]:
# Input channels: length of the first element of the first training sample.
first_train_sample = train_dataset[0]
in_channels = len(first_train_sample[0])

# Output channels: one per entry in `activity_columns`.
out_channels = len(activity_columns)

In [7]:
# Build the HumanLegNet network; input/output sizes come from the data
# (`in_channels`, `out_channels`), the remaining values are fixed here.
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
# Apply `initialize_weights` to every submodule of the model.
model.apply(initialize_weights)

# Wrap the network in the Lightning module used for training and prediction.
seq_model = LitModel_AgarwalJoint(
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    # Use the learning rate from the config cell (0.01) instead of repeating
    # the literal here, so changing `lr` there actually takes effect.
    lr=lr,
    print_each=10,
)

In [8]:
# Trainer settings, collected in one place.
trainer_options = {
    "accelerator": "gpu",
    "devices": [0],
    "max_epochs": 50,
    "gradient_clip_val": 1,
    "precision": "16-mixed",
    "enable_progress_bar": True,
    "num_sanity_val_steps": 0,
}

# Initialize the trainer from those settings.
trainer = L.Trainer(**trainer_options)

# Train the model, validating on `val_loader`.
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.
/home/nios/miniconda3/envs/mpramnist/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (44) is smaller than t

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=2` reached.


In [9]:
from torchmetrics import PearsonCorrCoef

# Predictions on the forward test sequences. Each element returned by
# `trainer.predict` is indexed by "target" and "predicted" below.
predictions_forw = trainer.predict(seq_model, dataloaders=test_loader)

# The trainer was created with precision="16-mixed", so the returned tensors
# may be in half precision. Cast to float32 before averaging and computing the
# correlation; this is a no-op if they are already float32.
targets = torch.cat([batch["target"] for batch in predictions_forw]).float()
y_preds_forw = torch.cat([batch["predicted"] for batch in predictions_forw]).float()

# Predictions on the reversed test sequences.
predictions_rev = trainer.predict(seq_model, dataloaders=test_loader_rev)
y_preds_rev = torch.cat([batch["predicted"] for batch in predictions_rev]).float()

# Average the forward and reversed predictions element-wise. Both test loaders
# are created with shuffle=False over the same split, so rows are expected to
# line up.
mean_forw = torch.mean(torch.stack([y_preds_forw, y_preds_rev]), dim=0)
pears = PearsonCorrCoef(num_outputs=out_channels)

# Pearson correlation, one value per output (cell output).
pears(mean_forw, targets)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

tensor([0.6269, 0.5636, 0.5804])