In [1]:
import os

import numpy as np
import pytorch_lightning as L
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchmetrics import PearsonCorrCoef

import mpramnist
import mpramnist.target_transforms as t_t
import mpramnist.transforms as t
from mpramnist.AgarwalJoint.dataset import AgarwalJointDataset
from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_AgarwalJoint

In [2]:
# Run configuration: everything a reader might want to tune lives here.
SEED = 42
# Seed Python, NumPy and torch (and DataLoader worker processes) so the random
# crops, reverse-complement augmentation, shuffling and weight init are reproducible.
L.seed_everything(SEED, workers=True)

BATCH_SIZE = 1024
# The original run used 103 workers on a large server; cap at the number of
# available CPUs so the notebook also runs sensibly on smaller machines
# (PyTorch warns when num_workers exceeds the CPU count).
NUM_WORKERS = min(103, os.cpu_count() or 1)
lr = 0.01  # learning rate for the Lightning module

In [3]:
# Flanking sequences shipped with the dataset (the original human_legnet flanks).
# Only the right flank is fed to AddFlanks in the training pipeline; the left
# flank is kept for reference.
left_flank, right_flank = AgarwalJointDataset.LEFT_FLANK, AgarwalJointDataset.RIGHT_FLANK

## Load the MPRA data, preprocess it, and wrap it in DataLoaders.

In [4]:
# --- Preprocessing pipelines ---------------------------------------------
# Training: append only the right flank (the original human_legnet setup),
# apply RightCrop(230, 250) and RandomCrop(230), convert to a tensor, then
# reverse-complement presumably with probability 0.5 (see mpramnist.transforms
# for the exact semantics of each transform).
train_transform = t.Compose([
    t.AddFlanks("", right_flank),  # original human_legnet parameters
    t.RightCrop(230, 250),
    t.RandomCrop(230),
    t.Seq2Tensor(),
    t.ReverseComplement(0.5),
])

# Evaluation transforms differ from training: no flanking, no cropping.
# The "_rev" variant always reverse-complements, so forward and reverse
# predictions can be averaged at test time.
test_transform = t.Compose([
    t.Seq2Tensor(),
])
test_transform_rev = t.Compose([
    t.Seq2Tensor(),
    t.ReverseComplement(1),
])

activity_columns = ["HepG2", "K562", "WTC11"]


def _load_split(split, transform):
    """Build an AgarwalJointDataset over all activity columns for one split."""
    return AgarwalJointDataset(cell_type=activity_columns, split=split, transform=transform)


# Default folds via "train" / "val" / "test"; `split` can also be a list of
# fold indices (e.g. [1, 2, 5, 6, 7, 8]) to select specific folds.
train_dataset = _load_split("train", train_transform)
val_dataset = _load_split("val", test_transform)
test_dataset = _load_split("test", test_transform)
test_dataset_rev = _load_split("test", test_transform_rev)


def _make_loader(dataset, shuffle=False):
    """Wrap a dataset in a DataLoader with the configured batch size and workers."""
    return data.DataLoader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


# Only the training loader shuffles; evaluation loaders keep dataset order.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset)
test_loader = _make_loader(test_dataset)
test_loader_rev = _make_loader(test_dataset_rev)

In [5]:
# Summaries of the training and test splits (size and folds used),
# separated by a divider line.
print(train_dataset, "===================", test_dataset, sep="\n")

Dataset AgarwalJointDataset of size 44264 (MpraDaraset)
    Number of datapoints: 44264
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
Dataset AgarwalJointDataset of size 5541 (MpraDaraset)
    Number of datapoints: 5541
    Used split fold: [10]


In [6]:
# Model input width = channel count of one encoded training sequence (the first
# element of a sample; note this passes through the random train transforms);
# output width = one prediction per activity column.
in_channels, out_channels = len(train_dataset[0][0]), len(activity_columns)

In [7]:
# HumanLegNet with a stem conv (64 channels, kernel 11) followed by four blocks
# of widths 80/96/112/128 with pooling size 2 each (per the parameter names;
# see HumanLegNet for exact semantics). One output per activity column.
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model.apply(initialize_weights)

# Lightning wrapper trained with MSE over all outputs jointly. The learning rate
# comes from the config cell rather than a repeated literal, so changing it
# there actually takes effect.
seq_model = LitModel_AgarwalJoint(
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    weight_decay=1e-1,
    lr=lr,
    print_each=1,
)

In [8]:
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],
    max_epochs=1,
    gradient_clip_val=1,
    precision='16-mixed', 
    enable_progress_bar = True,
    num_sanity_val_steps=0
)

# Train the model
trainer.fit(seq_model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-04-10 19:39:12.582475: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-10 19:39:12.597459: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=1` reached.



-------------------------------------------------------------------------------
| Epoch: 0 | Val Loss: 1.27111 | Val Pearson: 0.42286 | Train Pearson: 0.39261 
-------------------------------------------------------------------------------



In [10]:
# Test-time augmentation: predict on the forward and the reverse-complemented
# test set, average the two predictions, and score per-output Pearson r.
predictions_forw = trainer.predict(seq_model, dataloaders=test_loader)
targets = torch.cat([batch["target"] for batch in predictions_forw])
y_preds_forw = torch.cat([batch["predicted"] for batch in predictions_forw])

predictions_rev = trainer.predict(seq_model, dataloaders=test_loader_rev)
targets_rev = torch.cat([batch["target"] for batch in predictions_rev])
y_preds_rev = torch.cat([batch["predicted"] for batch in predictions_rev])

# Averaging is only valid if both loaders yield the same sequences in the same
# order (both use shuffle=False); verify by comparing their targets.
assert torch.equal(targets, targets_rev), "forward and reverse test sets are misaligned"

# Mean of forward and reverse-strand predictions (not forward-only, despite the
# legacy name kept below for any later cells that use it).
y_preds_mean = torch.stack([y_preds_forw, y_preds_rev]).mean(dim=0)
mean_forw = y_preds_mean

pears = PearsonCorrCoef(num_outputs=out_channels)
pears(y_preds_mean, targets)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

tensor([0.4719, 0.3864, 0.4671])