In [1]:
import mpramnist
from mpramnist.DeepStarr.dataset import DeepStarrDataset

from mpramnist.models import DeepStarr
from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_DeepStarr

from mpramnist import transforms as t
from mpramnist import target_transforms as t_t

import torch
import torch.nn as nn
import torch.utils.data as data

import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint

from torchmetrics import PearsonCorrCoef

In [2]:
BATCH_SIZE = 1024
NUM_WORKERS = 103

activity_columns = DeepStarrDataset.ACTIVITY_COLUMNS

In [3]:
def meaned_prediction(forw, rev, trainer, seq_model, name):
    predictions_forw = trainer.predict(seq_model, dataloaders=forw)
    targets = torch.cat([pred["target"] for pred in predictions_forw])
    y_preds_forw = torch.cat([pred["predicted"] for pred in predictions_forw])

    predictions_rev = trainer.predict(seq_model, dataloaders=rev)
    y_preds_rev = torch.cat([pred["predicted"] for pred in predictions_rev])

    mean_forw = torch.mean(torch.stack([y_preds_forw, y_preds_rev]), dim=0)

    pears = PearsonCorrCoef(num_outputs=2)
    print(name + " Pearson correlation")

    return pears(mean_forw, targets)

# Reverse-complement

In case you want to use reverse-complement as it was used in the original study(*Dataset size: 2N (N original + N reverse-complemented sequences)*), then use *use_original_reverse_complement* attribute. Reverse-complement will be added in case if you use `split = 'train'`.

To turn off reverse-complement with training set use `use_original_reverse_complement = False, split = 'train'`

**WARNING**: in the original study, reverse-complement  was applied **only** to the training set.

For example:

In [4]:
train_transform = t.Compose(
    [
        t.Seq2Tensor(),  # t.Reverse(0.5) is not needed here
    ]
)
val_test_transform = t.Compose(
    [
        t.Seq2Tensor(),
    ]
)
orig_rev_comp_train_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="train",
    transform=train_transform,
    root="../data/",
)
# VAL NO CHANGES
val_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="val",
    transform=val_test_transform,
    root="../data/",
)
# TEST NO CHANGES
test_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=val_test_transform,
    root="../data/",
)

Note: The training set contains reverse-complement augmentation as implemented in the original study.  
• Dataset size: 2N (N original + N reverse-complemented sequences)  
• Label consistency: y_rc ≡ y_original  
• Do not reapply this transformation during preprocessing. 


In [5]:
print(len(orig_rev_comp_train_dataset))
print(len(val_dataset))
print(len(test_dataset))

402278
40570
41186


But we suggest using reverse-complement transformation by writing **transforms.Reverse(prob = 0.5)**. *prob = 0.5* means that the sequence will be reversed with probability of 0.5. So use `use_original_reverse_complement = False, split = 'train'`

For example:

In [6]:
train_transform = t.Compose(
    [
        t.ReverseComplement(0.5),
        t.Seq2Tensor(),
    ]
)
val_test_transform = t.Compose(
    [
        t.ReverseComplement(0),
        t.Seq2Tensor(),
    ]
)

In [7]:
train_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    use_original_reverse_complement=False,
    split="train",
    transform=train_transform,
    root="../data/",
)

val_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="val",
    transform=val_test_transform,
    root="../data/",
)

test_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=val_test_transform,
    root="../data/",
)

In [8]:
print(len(train_dataset))
print(len(val_dataset))
print(len(test_dataset))

201139
40570
41186


# Data splitting

Sequences from the **first** and **second** half of chr2R were held out for validation and testing, respectively.

The remainig chromosomes are used for training set.

You can use *"train"*, *"val"*, *"test"* to define the training, validation or test set respectively **using the same approach as the original study**

For example:

In [9]:
train_transform = t.Compose(
    [
        t.Seq2Tensor(),
    ]
)
val_test_transform = t.Compose(
    [
        t.Seq2Tensor(),
    ]
)

In [10]:
orig_train_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="train",
    transform=train_transform,
    root="../data/",
)

val_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="val",
    transform=val_test_transform,
    root="../data/",
)

test_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=val_test_transform,
    root="../data/",
)

Note: The training set contains reverse-complement augmentation as implemented in the original study.  
• Dataset size: 2N (N original + N reverse-complemented sequences)  
• Label consistency: y_rc ≡ y_original  
• Do not reapply this transformation during preprocessing. 


In [11]:
print(len(orig_train_dataset))
print(len(val_dataset))
print(len(test_dataset))

402278
40570
41186


On the other hand, you can define a list of specific chromosomes that you want to use as training, validation ot test set

For example:

In [12]:
list_of_chr = DeepStarrDataset.LIST_OF_CHR
print(list_of_chr)

['chr2L', 'chr2LHet', 'chr2RHet', 'chr3L', 'chr3LHet', 'chr3R', 'chr3RHet', 'chr4', 'chrX', 'chrXHet', 'chrYHet', 'chr2R']


In [13]:
my_train_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split=[
        "chr2L",
        "chr2LHet",
        "chr2RHet",  # Reverse complement transformation is disabled for chromosome list splits.
        "chr3L",
        "chr3LHet",
        "chr3R",  # Set use_original_reverse_complement=True to apply original paper augmentation.
        "chr3RHet",
        "chr4",
    ],
    use_original_reverse_complement=False,
    transform=train_transform,
    root="../data/",
)

my_val_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split=["chrX", "chrXHet", "chrYHet"],
    transform=val_test_transform,
    root="../data/",
)

my_test_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="chr2R",
    transform=val_test_transform,
    root="../data/",
)

In [14]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

val_loader = data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

test_loader = data.DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

In [15]:
print(len(my_train_dataset))
print(len(my_val_dataset))
print(len(my_test_dataset))

151641
49498
81756


# Regression task

In [16]:
in_channels = len(train_dataset[0][0])
out_channels = len(activity_columns)

## Trainer

In [17]:
model = DeepStarr(out_channels)

seq_model = LitModel_DeepStarr(
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    activity_columns=activity_columns,
    weight_decay=1e-6,
    lr=2e-3,
    print_each=10,
)

In [18]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson", mode="max", save_top_k=1, save_last=False
)
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[1],
    max_epochs=50,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [19]:
# Train the model
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
2025-07-29 17:15:36.631922: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-29 17:15:36.647320: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753798536.665449 3808966 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin 

Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 1.56910 
| Val Pearson Dev_log2: 0.57692 | Val Pearson Hk_log2: 0.68949 | Mean Val Pearson: 0.63320 |
| Train Pearson Dev_log2: 0.59807 | Train Pearson Hk_log2: 0.69319 | Mean Train Pearson: 0.64563 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 1.41722 
| Val Pearson Dev_log2: 0.63139 | Val Pearson Hk_log2: 0.71465 | Mean Val Pearson: 0.67302 |
| Train Pearson Dev_log2: 0.67677 | Train Pearson Hk_log2: 0.74817 | Mean Train Pearson: 0.71247 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 1.41323 
| Val Pearson Dev_log2: 0.63774 | Val Pearson Hk_log2: 0.71505 | Mean Val Pearson: 0.67639 |
| Train Pearson Dev_log2: 0.72856 | Train Pearson Hk_log2: 0.78651 | Mean Train Pearson: 0.75753 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 1.40563 
| Val Pearson Dev_log2: 0.63076 | Val Pearson Hk_log2: 0.71479 | Mean Val Pearson: 0.67277 |
| Train Pearson Dev_log2: 0.76995 | Train Pearson Hk_log2: 0.81603 | Mean Train Pearson: 0.79299 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=50` reached.



-------------------------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 1.41771 
| Val Pearson Dev_log2: 0.62772 | Val Pearson Hk_log2: 0.71193 | Mean Val Pearson: 0.66983 |
| Train Pearson Dev_log2: 0.78580 | Train Pearson Hk_log2: 0.82674 | Mean Train Pearson: 0.80627 |
-------------------------------------------------------------------------------------------------



In [20]:
best_model_path = checkpoint_callback.best_model_path
seq_model = LitModel_DeepStarr.load_from_checkpoint(
    best_model_path,
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    weight_decay=1e-6,
    lr=2e-3,
    print_each=1,
)

trainer.test(seq_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 1.3943507671356201,
  'test_Developmental_pearson': 0.6484516263008118,
  'test_HouseKeeping_pearson': 0.7330984473228455}]

In [21]:
forw_transform = t.Compose([t.Seq2Tensor()])
rev_transform = t.Compose(
    [
        t.ReverseComplement(1),
        t.Seq2Tensor(),
    ]
)

test_forw = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=forw_transform,
    root="../data/",
)
test_rev = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=rev_transform,
    root="../data/",
)

forw = data.DataLoader(
    dataset=test_forw,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
rev = data.DataLoader(
    dataset=test_rev,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

meaned_prediction(forw, rev, trainer, seq_model, "DeepStarr")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

DeepStarr Pearson correlation


tensor([0.6756, 0.7539])

# Now let's train using HumanLegNet

In [22]:
train_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    use_original_reverse_complement=False,
    split="train",
    transform=train_transform,
    root="../data/",
)
val_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="val",
    transform=val_test_transform,
    root="../data/",
)

test_dataset = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=val_test_transform,
    root="../data/",
)

In [23]:
print(train_dataset)
print("=" * 50)
print(val_dataset)
print("=" * 50)
print(test_dataset)

Dataset DeepStarrDataset of size 201139 (MpraDaraset)
    Number of datapoints: 201139
    Used split fold: ['train']
Dataset DeepStarrDataset of size 40570 (MpraDaraset)
    Number of datapoints: 40570
    Used split fold: ['val']
Dataset DeepStarrDataset of size 41186 (MpraDaraset)
    Number of datapoints: 41186
    Used split fold: ['test']


In [24]:
# encapsulate data into dataloader form
train_loader = data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

val_loader = data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

test_loader = data.DataLoader(
    dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

In [25]:
in_channels = len(train_dataset[0][0])
out_channels = len(activity_columns)

In [31]:
model = HumanLegNet(
    in_ch=in_channels,
    output_dim=out_channels,
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
)
model.apply(initialize_weights)

seq_model = LitModel_DeepStarr(
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    activity_columns=activity_columns,
    weight_decay=1e-1,
    lr=1e-2,
    print_each=10,
)

In [32]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson", mode="max", save_top_k=1, save_last=False
)
# Initialize a trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=[1],
    max_epochs=50,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [33]:
# Train the model
trainer.fit(seq_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.291     Total estimated model params size (MB)
120       Modules in train mode
0         Modules in eval mode


Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 1.49161 
| Val Pearson Dev_log2: 0.62380 | Val Pearson Hk_log2: 0.71212 | Mean Val Pearson: 0.66796 |
| Train Pearson Dev_log2: 0.71595 | Train Pearson Hk_log2: 0.77383 | Mean Train Pearson: 0.74489 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 2.04394 
| Val Pearson Dev_log2: 0.57375 | Val Pearson Hk_log2: 0.61260 | Mean Val Pearson: 0.59317 |
| Train Pearson Dev_log2: 0.76593 | Train Pearson Hk_log2: 0.80722 | Mean Train Pearson: 0.78657 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 2.09911 
| Val Pearson Dev_log2: 0.56413 | Val Pearson Hk_log2: 0.60888 | Mean Val Pearson: 0.58651 |
| Train Pearson Dev_log2: 0.84033 | Train Pearson Hk_log2: 0.86603 | Mean Train Pearson: 0.85318 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 1.71451 
| Val Pearson Dev_log2: 0.57991 | Val Pearson Hk_log2: 0.67396 | Mean Val Pearson: 0.62693 |
| Train Pearson Dev_log2: 0.96162 | Train Pearson Hk_log2: 0.96601 | Mean Train Pearson: 0.96382 |
-------------------------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=50` reached.



-------------------------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 1.68054 
| Val Pearson Dev_log2: 0.56814 | Val Pearson Hk_log2: 0.66344 | Mean Val Pearson: 0.61579 |
| Train Pearson Dev_log2: 0.99488 | Train Pearson Hk_log2: 0.99505 | Mean Train Pearson: 0.99497 |
-------------------------------------------------------------------------------------------------



In [34]:
best_model_path = checkpoint_callback.best_model_path
seq_model = LitModel_DeepStarr.load_from_checkpoint(
    best_model_path,
    model=model,
    num_outputs=out_channels,
    loss=nn.MSELoss(),
    weight_decay=1e-6,
    lr=2e-3,
    print_each=1,
)

trainer.test(seq_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                        | 0/? [00:00…

[{'test_loss': 1.5605711936950684,
  'test_Developmental_pearson': 0.6390811204910278,
  'test_HouseKeeping_pearson': 0.747373640537262}]

In [35]:
forw_transform = t.Compose([t.Seq2Tensor()])
rev_transform = t.Compose(
    [
        t.ReverseComplement(1),
        t.Seq2Tensor(),
    ]
)

test_forw = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=forw_transform,
    root="../data/",
)
test_rev = DeepStarrDataset(
    activity_column=activity_columns,
    split="test",
    transform=rev_transform,
    root="../data/",
)

forw = data.DataLoader(
    dataset=test_forw,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
rev = data.DataLoader(
    dataset=test_rev,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

meaned_prediction(forw, rev, trainer, seq_model, "DeepStarr")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

DeepStarr Pearson correlation


tensor([0.6706, 0.7654])