In [1]:
import mpramnist
from mpramnist.Agarwal.dataset import AgarwalDataset
from mpramnist.Kircher.dataset import KircherDataset

from mpramnist.models import HumanLegNet
from mpramnist.models import initialize_weights
from mpramnist.trainers import LitModel_Kircher

import mpramnist.transforms as t
import mpramnist.target_transforms as t_t

import numpy as np
import random
import torch
import torch.nn as nn
import torch.utils.data as data
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint

from torchmetrics import PearsonCorrCoef

# Kircher Dataset and Experimental Design

The Kircher dataset is based on **MPRA (Massively Parallel Reporter Assay)** results from **saturation mutagenesis** experiments. The study characterized **44,647** variants of regulatory elements, including:

1. 5 enhancers (genes: IRF4, IRF6, MYC, SORT1, ZFAND3)

2. 9 promoters (genes: F9, GP1BB, HBB, HBG, HNF4A, LDLR, MSMB, PKLR, TERT)

across multiple cell lines.

# Proposed KircherDataset Benchmark Application

These experimentally validated sequences are recommended as a reference dataset for evaluating the performance of machine learning models.

# Current Workflow

In this notebook, we:

1. Train the **HumanLegNet** model on the **Agarwal dataset**

2. Assess its predictive power using **Kircher's saturation mutagenesis data**

Focus: regulatory elements evaluated with the **HepG2** model (F9, LDLR, SORT1) and with the **K562** model (PKLR)



# Define some default parameters and functions

In [2]:
# Notebook-wide settings reused by the seeding cell and by every DataLoader.
seed = 777  # value fed to all random number generators
NUM_WORKERS = 8  # `num_workers` passed to each DataLoader
BATCH_SIZE = 1024  # `batch_size` passed to each DataLoader

In [3]:
# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with the same value.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# cuDNN: request deterministic behaviour and switch off the benchmark mode.
torch.backends.cudnn.benchmark = False  # type: ignore
torch.backends.cudnn.deterministic = True  # type: ignore

In [4]:
# Flanks as used in the original human_legnet setup.
left_flank = AgarwalDataset.LEFT_FLANK
right_flank = AgarwalDataset.RIGHT_FLANK

# Constant flanks that have to be attached to every sequence.
constant_left_flank = AgarwalDataset.CONSTANT_LEFT_FLANK
constant_right_flank = AgarwalDataset.CONSTANT_RIGHT_FLANK

In [5]:
# Forward strand: attach the constant flanks, then convert to a tensor.
forward_steps = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.Seq2Tensor(),
]
forw_transform = t.Compose(forward_steps)

# Reverse strand: same pipeline with a reverse-complement step (probability 1)
# inserted before the tensor conversion.
reverse_steps = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.ReverseComplement(1),
    t.Seq2Tensor(),
]
rev_transform = t.Compose(reverse_steps)

In [6]:
def meaned_prediction(forw, rev, trainer, seq_model, name, is_kircher=False):
    """Correlate strand-averaged predictions with the targets.

    ``trainer.predict`` is run on ``forw`` and then on ``rev``; the
    "ref_predicted" outputs of both runs are averaged element-wise and
    compared with the "target" values collected from the forward run.

    Args:
        forw: dataloader predicted first; its batches also supply the targets.
        rev: dataloader predicted second. Presumably it yields the same
            samples in the same order as ``forw`` (reverse-complemented);
            the element-wise average relies on that - verify against caller.
        trainer: object exposing ``predict(model, dataloaders=...)`` that
            returns a list of per-batch dicts.
        seq_model: model handed to ``trainer.predict``.
        name: label printed in front of " Pearson correlation".
        is_kircher: when true, the score is computed on the difference
            between the strand-averaged "alt_predicted" and the
            strand-averaged "ref_predicted" outputs.

    Returns:
        Whatever a fresh ``PearsonCorrCoef()`` returns for the (possibly
        differenced) predictions and the targets.
    """

    def gather(batches, key):
        # Concatenate one field of the per-batch prediction dicts.
        return torch.cat([batch[key] for batch in batches])

    def strand_mean(forward_values, reverse_values):
        # Element-wise mean of the forward and reverse predictions.
        return torch.mean(torch.stack([forward_values, reverse_values]), dim=0)

    forward_batches = trainer.predict(seq_model, dataloaders=forw)
    targets = gather(forward_batches, "target")
    ref_forward = gather(forward_batches, "ref_predicted")

    reverse_batches = trainer.predict(seq_model, dataloaders=rev)
    ref_reverse = gather(reverse_batches, "ref_predicted")

    ref_mean = strand_mean(ref_forward, ref_reverse)

    pearson = PearsonCorrCoef()
    print(name + " Pearson correlation")

    if not is_kircher:
        return pearson(ref_mean, targets)

    # Variant-effect mode: score the alt - ref difference instead.
    alt_mean = strand_mean(
        gather(forward_batches, "alt_predicted"),
        gather(reverse_batches, "alt_predicted"),
    )
    return pearson(alt_mean - ref_mean, targets)

In [7]:
# Training-time preprocessing.
# `shift` = number of characters taken from the end of `left_flank` and from
# the start of `right_flank` and attached around the already-flanked sequence.
shift = 15

train_steps = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.AddFlanks(left_flank[-shift:], right_flank[:shift]),
    # NOTE(review): the two crops presumably bring the sequence back to a
    # fixed length after the extra flanks were added - confirm against the
    # transforms' documentation.
    t.LeftCrop(230, 260),
    t.RightCrop(230, 230),
    t.ReverseComplement(0.5),
    t.Seq2Tensor(),
]
train_transform = t.Compose(train_steps)

# Evaluation-time preprocessing: no extra flanks and no crops.
test_steps = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    # Argument 0 instead of the 0.5 used for training; presumably this means
    # the reverse complement is never applied here - TODO confirm.
    t.ReverseComplement(0),
    t.Seq2Tensor(),
]
test_transform = t.Compose(test_steps)

In [8]:
# Kircher elements grouped by the cell line whose model is used to score them.
elements_for_k562 = [
    "PKLR-24h",
    "PKLR-48h",
]
elements_for_hepg2 = [
    "F9",
    "LDLR.2",
    "LDLR",
    "SORT1.2",
    "SORT1",
    "SORT1-flip",
]

# Pretrain on Agarwal's HepG2

In [9]:
# Agarwal HepG2 data, one dataset per split.
# `split` also accepts a list of folds, e.g. [1,2,5,6,7,8].
train_dataset = AgarwalDataset(
    root="../data/",
    cell_type="HepG2",
    split="train",
    transform=train_transform,  # augmenting pipeline
)

# "val" selects the default validation set.
val_dataset = AgarwalDataset(
    root="../data/",
    cell_type="HepG2",
    split="val",
    transform=test_transform,
)

# "test" selects the default test set.
test_dataset = AgarwalDataset(
    root="../data/",
    cell_type="HepG2",
    split="test",
    transform=test_transform,
)

In [10]:
# Show the three splits, separated by a 50-character rule.
divider = "=" * 50
print(train_dataset, divider, val_dataset, divider, test_dataset, sep="\n")

Dataset AgarwalDataset of size 98336 (MpraDaraset)
    Number of datapoints: 98336
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
Dataset AgarwalDataset of size 12292 (MpraDaraset)
    Number of datapoints: 12292
    Used split fold: [9]
Dataset AgarwalDataset of size 12298 (MpraDaraset)
    Number of datapoints: 12298
    Used split fold: [10]


In [11]:
# Wrap each split in a DataLoader; only the training loader shuffles.
loader_settings = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_settings)
val_loader = data.DataLoader(dataset=val_dataset, shuffle=False, **loader_settings)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_settings)

In [12]:
# Model output size (passed to the model as `output_dim`).
out_channels = 1
# Input channels: length of the first element of the first training sample.
in_channels = len(train_dataset[0][0])

In [13]:
# Build a freshly initialised HumanLegNet for the HepG2 run.
model = HumanLegNet(
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
    in_ch=in_channels,
    output_dim=out_channels,
)
model.apply(initialize_weights)

# Lightning wrapper: MSE loss, summary line printed every 10 epochs.
seq_model = LitModel_Kircher(
    model=model,
    loss=nn.MSELoss(),
    lr=1e-2,
    weight_decay=1e-1,
    print_each=10,
)

In [14]:
# Keep only the single checkpoint with the highest "val_pearson".
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson", mode="max", save_top_k=1, save_last=False
)

# The original run trained on the second GPU (index 1). Requesting that index
# on a machine with a single visible GPU makes the Trainer fail, so fall back
# to index 0 there; with two or more GPUs the behaviour is unchanged.
gpu_index = 1 if torch.cuda.device_count() > 1 else 0

# Initialize a trainer: 50 epochs, mixed 16-bit precision, gradient clipping
# at 1, and no sanity-check validation steps before training.
trainer_hepg2 = L.Trainer(
    accelerator="gpu",
    devices=[gpu_index],
    max_epochs=50,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [15]:
# Fit the HepG2 model; validation uses the validation loader.
trainer_hepg2.fit(
    seq_model,
    val_dataloaders=val_loader,
    train_dataloaders=train_loader,
)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290  

Training: |                                                                                    | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…


-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.33054 | Val Pearson: 0.67351 | Train Pearson: 0.73167 
-------------------------------------------------------------------------------



Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.34740 | Val Pearson: 0.69029 | Train Pearson: 0.77449 
--------------------------------------------------------------------------------



Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…


--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.26379 | Val Pearson: 0.74997 | Train Pearson: 0.83037 
--------------------------------------------------------------------------------



Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…


--------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.24712 | Val Pearson: 0.76918 | Train Pearson: 0.89280 
--------------------------------------------------------------------------------



Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

Validation: |                                                                                  | 0/? [00:00<?,…

`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.24414 | Val Pearson: 0.77688 | Train Pearson: 0.92977 
--------------------------------------------------------------------------------



## HepG2 Evaluation

In [16]:
# Reload the best checkpoint tracked by the callback (highest "val_pearson").
# To evaluate a previously saved checkpoint instead, point `best_model_path`
# at it, e.g. "../data/Kircher/best_model_test10_val9.ckpt".
best_model_path = checkpoint_callback.best_model_path
seq_model_hepg2 = LitModel_Kircher.load_from_checkpoint(
    best_model_path,
    model=model,
    loss=nn.MSELoss(),
    lr=0.01,
    weight_decay=0.1,
    print_each=1,
)

# Score the restored model on the Agarwal HepG2 test loader.
trainer_hepg2.test(seq_model_hepg2, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: |                                                                                     | 0/? [00:00<?,…

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.2340375781059265
      test_pearson          0.7890375852584839
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.2340375781059265, 'test_pearson': 0.7890375852584839}]

In [18]:
# Forward and reverse-complement views of the HepG2 test split.
forward_steps_hepg2 = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.Seq2Tensor(),
]
forw_transform_hepg2 = t.Compose(forward_steps_hepg2)

reverse_steps_hepg2 = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.ReverseComplement(1),
    t.Seq2Tensor(),
]
rev_transform_hepg2 = t.Compose(reverse_steps_hepg2)

test_forw = AgarwalDataset(
    root="../data/",
    cell_type="HepG2",
    split="test",
    transform=forw_transform_hepg2,
)
test_rev = AgarwalDataset(
    root="../data/",
    cell_type="HepG2",
    split="test",
    transform=rev_transform_hepg2,
)

# Unshuffled loaders with identical settings, so that forward and reverse
# predictions can be averaged element-wise.
hepg2_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_hepg2 = data.DataLoader(dataset=test_forw, **hepg2_loader_kwargs)
rev_hepg2 = data.DataLoader(dataset=test_rev, **hepg2_loader_kwargs)

# Pearson correlation of the strand-averaged predictions with the targets.
meaned_prediction(
    forw=forw_hepg2,
    rev=rev_hepg2,
    trainer=trainer_hepg2,
    seq_model=seq_model_hepg2,
    name="HepG2",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

HepG2 Pearson correlation


tensor(0.8082)

## Kircher's sequences evaluation

In [19]:
# Kircher elements scored with the HepG2 model in the cells shown below.
elements_for_hepg2 = [
    "F9",
    "LDLR.2",
    "LDLR",
    "SORT1.2",
    "SORT1",
    "SORT1-flip",
]

### LDLR

In [20]:
# LDLR variants, once per strand: same elements, different transform.
kircher_dataset_forw = KircherDataset(
    root="../data/",
    elements=["LDLR", "LDLR.2"],
    length=200,
    transform=forw_transform,
)
kircher_dataset_rev = KircherDataset(
    root="../data/",
    elements=["LDLR", "LDLR.2"],
    length=200,
    transform=rev_transform,
)
print("LDLR info")
print(kircher_dataset_forw)

# Unshuffled loaders with identical settings keep both strands aligned.
kircher_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
kircher_forw = data.DataLoader(dataset=kircher_dataset_forw, **kircher_loader_kwargs)
kircher_rev = data.DataLoader(dataset=kircher_dataset_rev, **kircher_loader_kwargs)

# Correlate the predicted alt - ref differences with the targets.
meaned_prediction(
    forw=kircher_forw,
    rev=kircher_rev,
    trainer=trainer_hepg2,
    seq_model=seq_model_hepg2,
    name="LDLR",
    is_kircher=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


LDLR info
Dataset KircherDataset of size 2176 (MpraDaraset)
    Number of datapoints: 2176
    Used split fold: test


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

LDLR Pearson correlation


tensor(0.3660)

### F9

In [21]:
# F9 variants, once per strand: same element, different transform.
kircher_dataset_forw = KircherDataset(
    root="../data/",
    elements=["F9"],
    length=200,
    transform=forw_transform,
)
kircher_dataset_rev = KircherDataset(
    root="../data/",
    elements=["F9"],
    length=200,
    transform=rev_transform,
)
print("F9 info")
print(kircher_dataset_forw)

# Unshuffled loaders with identical settings keep both strands aligned.
kircher_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
kircher_forw = data.DataLoader(dataset=kircher_dataset_forw, **kircher_loader_kwargs)
kircher_rev = data.DataLoader(dataset=kircher_dataset_rev, **kircher_loader_kwargs)

# Correlate the predicted alt - ref differences with the targets.
meaned_prediction(
    forw=kircher_forw,
    rev=kircher_rev,
    trainer=trainer_hepg2,
    seq_model=seq_model_hepg2,
    name="F9",
    is_kircher=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


F9 info
Dataset KircherDataset of size 984 (MpraDaraset)
    Number of datapoints: 984
    Used split fold: test


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

F9 Pearson correlation


tensor(0.4769)

### SORT1 

In [22]:
# SORT1 variants (three constructs), once per strand.
kircher_dataset_forw = KircherDataset(
    root="../data/",
    elements=["SORT1.2", "SORT1", "SORT1-flip"],
    length=200,
    transform=forw_transform,
)
kircher_dataset_rev = KircherDataset(
    root="../data/",
    elements=["SORT1.2", "SORT1", "SORT1-flip"],
    length=200,
    transform=rev_transform,
)
print("SORT1 info")
print(kircher_dataset_forw)

# Unshuffled loaders with identical settings keep both strands aligned.
kircher_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
kircher_forw = data.DataLoader(dataset=kircher_dataset_forw, **kircher_loader_kwargs)
kircher_rev = data.DataLoader(dataset=kircher_dataset_rev, **kircher_loader_kwargs)

# Correlate the predicted alt - ref differences with the targets.
meaned_prediction(
    forw=kircher_forw,
    rev=kircher_rev,
    trainer=trainer_hepg2,
    seq_model=seq_model_hepg2,
    name="SORT1",
    is_kircher=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


SORT1 info
Dataset KircherDataset of size 5898 (MpraDaraset)
    Number of datapoints: 5898
    Used split fold: test


Predicting: |                                                                                  | 0/? [00:00<?,…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                  | 0/? [00:00<?,…

SORT1 Pearson correlation


tensor(0.4938)

# Pretrain on Agarwal's K562

In [23]:
# Agarwal K562 data, one dataset per split.
# `split` also accepts a list of folds, e.g. [1,2,5,6,7,8].
train_dataset = AgarwalDataset(
    root="../data/",
    cell_type="K562",
    split="train",
    transform=train_transform,  # augmenting pipeline
)

# "val" selects the default validation set.
val_dataset = AgarwalDataset(
    root="../data/",
    cell_type="K562",
    split="val",
    transform=test_transform,
)

# "test" selects the default test set.
test_dataset = AgarwalDataset(
    root="../data/",
    cell_type="K562",
    split="test",
    transform=test_transform,
)

In [98]:
# Show the three K562 splits, separated by a 50-character rule.
divider = "=" * 50
print(train_dataset, divider, val_dataset, divider, test_dataset, sep="\n")

Dataset AgarwalDataset of size 157328 (MpraDaraset)
    Number of datapoints: 157328
    Used split fold: [1, 2, 3, 4, 5, 6, 7, 8]
Dataset AgarwalDataset of size 19666 (MpraDaraset)
    Number of datapoints: 19666
    Used split fold: [9]
Dataset AgarwalDataset of size 19670 (MpraDaraset)
    Number of datapoints: 19670
    Used split fold: [10]


In [99]:
# Wrap each K562 split in a DataLoader; only the training loader shuffles.
loader_settings = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

train_loader = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_settings)
val_loader = data.DataLoader(dataset=val_dataset, shuffle=False, **loader_settings)
test_loader = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_settings)

In [24]:
# Model output size (passed to the model as `output_dim`).
out_channels = 1
# Input channels: length of the first element of the first training sample.
in_channels = len(train_dataset[0][0])

In [25]:
# Build a freshly initialised HumanLegNet for the K562 run
# (same hyper-parameters as the HepG2 model above).
model = HumanLegNet(
    stem_ch=64,
    stem_ks=11,
    ef_ks=9,
    ef_block_sizes=[80, 96, 112, 128],
    pool_sizes=[2, 2, 2, 2],
    resize_factor=4,
    in_ch=in_channels,
    output_dim=out_channels,
)
model.apply(initialize_weights)

# Lightning wrapper: MSE loss, summary line printed every 10 epochs.
seq_model = LitModel_Kircher(
    model=model,
    loss=nn.MSELoss(),
    lr=1e-2,
    weight_decay=1e-1,
    print_each=10,
)

In [26]:
# Keep only the single checkpoint with the highest "val_pearson".
# Note: this rebinds `checkpoint_callback`, so the HepG2 callback is no
# longer reachable through this name afterwards.
checkpoint_callback = ModelCheckpoint(
    monitor="val_pearson", mode="max", save_top_k=1, save_last=False
)

# The original run trained on the second GPU (index 1). Requesting that index
# on a machine with a single visible GPU makes the Trainer fail, so fall back
# to index 0 there; with two or more GPUs the behaviour is unchanged.
gpu_index = 1 if torch.cuda.device_count() > 1 else 0

# Initialize a trainer: 50 epochs, mixed 16-bit precision, gradient clipping
# at 1, and no sanity-check validation steps before training.
trainer_k562 = L.Trainer(
    accelerator="gpu",
    devices=[gpu_index],
    max_epochs=50,
    gradient_clip_val=1,
    precision="16-mixed",
    enable_progress_bar=True,
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [27]:
# Fit the K562 model; validation uses the validation loader.
trainer_k562.fit(
    seq_model,
    val_dataloaders=val_loader,
    train_dataloaders=train_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name          | Type            | Params | Mode 
----------------------------------------------------------
0 | model         | HumanLegNet     | 1.3 M  | train
1 | loss          | MSELoss         | 0      | train
2 | train_pearson | PearsonCorrCoef | 0      | train
3 | val_pearson   | PearsonCorrCoef | 0      | train
4 | test_pearson  | PearsonCorrCoef | 0      | train
----------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.290     Total estimated model params size (MB)
120       Modules in train mode
0         Modules in eval mode


Training: |                                                                                       | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


-------------------------------------------------------------------------------
| Epoch: 9 | Val Loss: 0.14667 | Val Pearson: 0.70573 | Train Pearson: 0.75466 
-------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------
| Epoch: 19 | Val Loss: 0.17103 | Val Pearson: 0.74526 | Train Pearson: 0.78996 
--------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------
| Epoch: 29 | Val Loss: 0.12919 | Val Pearson: 0.75422 | Train Pearson: 0.82394 
--------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…


--------------------------------------------------------------------------------
| Epoch: 39 | Val Loss: 0.13276 | Val Pearson: 0.77974 | Train Pearson: 0.87416 
--------------------------------------------------------------------------------



Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

Validation: |                                                                                     | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=50` reached.



--------------------------------------------------------------------------------
| Epoch: 49 | Val Loss: 0.09762 | Val Pearson: 0.81021 | Train Pearson: 0.91190 
--------------------------------------------------------------------------------



## K562 Evaluation

In [28]:
# Reload the best K562 checkpoint tracked by the callback (highest "val_pearson").
best_model_path = checkpoint_callback.best_model_path
seq_model_k562 = LitModel_Kircher.load_from_checkpoint(
    best_model_path,
    model=model,
    loss=nn.MSELoss(),
    lr=0.01,
    weight_decay=0.1,
    print_each=1,
)

In [29]:
# Forward and reverse-complement views of the K562 test split.
forward_steps_k562 = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.Seq2Tensor(),
]
forw_transform_k562 = t.Compose(forward_steps_k562)

reverse_steps_k562 = [
    t.AddFlanks(constant_left_flank, constant_right_flank),
    t.ReverseComplement(1),
    t.Seq2Tensor(),
]
rev_transform_k562 = t.Compose(reverse_steps_k562)

test_forw = AgarwalDataset(
    root="../data/",
    cell_type="K562",
    split="test",
    transform=forw_transform_k562,
)
test_rev = AgarwalDataset(
    root="../data/",
    cell_type="K562",
    split="test",
    transform=rev_transform_k562,
)

# Unshuffled loaders with identical settings, so that forward and reverse
# predictions can be averaged element-wise.
k562_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
forw_k562 = data.DataLoader(dataset=test_forw, **k562_loader_kwargs)
rev_k562 = data.DataLoader(dataset=test_rev, **k562_loader_kwargs)

# Pearson correlation of the strand-averaged predictions with the targets.
meaned_prediction(
    forw=forw_k562,
    rev=rev_k562,
    trainer=trainer_k562,
    seq_model=seq_model_k562,
    name="K562",
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

K562 Pearson correlation


tensor(0.8280)

## Kircher's sequences evaluation

In [30]:
# Kircher PKLR elements associated with the K562 model.
elements_for_k562 = [
    "PKLR-24h",
    "PKLR-48h",
]

### PKLR (48h)

In [40]:
# PKLR variants (only the "PKLR-48h" element), once per strand.
kircher_dataset_forw = KircherDataset(
    root="../data/",
    elements=["PKLR-48h"],
    length=200,
    transform=forw_transform,
)
kircher_dataset_rev = KircherDataset(
    root="../data/",
    elements=["PKLR-48h"],
    length=200,
    transform=rev_transform,
)
print("PKLR info")
print(kircher_dataset_forw)

# Unshuffled loaders with identical settings keep both strands aligned.
kircher_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
kircher_forw = data.DataLoader(dataset=kircher_dataset_forw, **kircher_loader_kwargs)
kircher_rev = data.DataLoader(dataset=kircher_dataset_rev, **kircher_loader_kwargs)

# Correlate the predicted alt - ref differences with the targets,
# using the K562 model.
meaned_prediction(
    forw=kircher_forw,
    rev=kircher_rev,
    trainer=trainer_k562,
    seq_model=seq_model_k562,
    name="PKLR",
    is_kircher=True,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


PKLR info
Dataset KircherDataset of size 1794 (MpraDaraset)
    Number of datapoints: 1794
    Used split fold: test


Predicting: |                                                                                     | 0/? [00:00…

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                     | 0/? [00:00…

PKLR Pearson correlation


tensor(0.5429)