# Set-up

In [153]:
# Imports
import os
import numpy as np
import xarray as xr
import torch

import seqdata as sd
import seqpro as sp

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

from bpnetlite import BPNet
from seqmodels import Module

# local
os.chdir("/cellar/users/aklie/projects/ML4GLand/tutorials/bulk_atac_basepair")
from BPNet import bpnetlite_loss, bpnetlite_metrics

In [154]:
# Define training and validation chromosomes
train_chroms = ['chr{}'.format(i) for i in range(1, 23)]
valid_chroms = ['chr8', 'chr20']
test_chroms = ['chr1', 'chr3', 'chr6']

In [155]:
path_peaks = "/cellar/users/aklie/data/datasets/K562_ATAC-seq/data/K562_ATAC-seq_peaks.zarr"
path_negatives = "/cellar/users/aklie/data/datasets/K562_ATAC-seq/data/K562_ATAC-seq_negatives.zarr"

# Read data

In [156]:
peaks = sd.open_zarr(path_peaks)
neg = sd.open_zarr(path_negatives)
peaks.dims["_sequence"], neg.dims["_sequence"]

(269800, 269774)

# Train bias model

In [157]:
beta = 0.5
seq_len = 2114
target_len = 1000
trimming = (seq_len - target_len) // 2
seq_len, target_len, trimming

(2114, 1000, 557)

In [158]:
arch = BPNet(
    n_layers=4,
    n_outputs=1,
    n_control_tracks=0,
    verbose=True,
    trimming=trimming,
)
arch

BPNet(
  (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
    (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
  )
  (rrelus): ModuleList(
    (0-3): 4 x ReLU()
  )
  (fconv): Conv1d(64, 1, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=64, out_features=1, bias=True)
)

In [159]:
counts = torch.tensor(peaks["cov"].values.sum(axis=(1, 2)), dtype=torch.float32)
max_counts = torch.quantile(counts, 0.1).item() * beta
print("Max Counts: ", max_counts)

Max Counts:  283.0


In [160]:
msk = (neg["cov"].values.sum(axis=(1, 2)) > max_counts)
print("Num > max_counts: ", msk.sum())

Num > max_counts:  11878


In [161]:
# Filter out negatives with too many counts
filtered_neg = neg.isel(_sequence=~msk)
filtered_neg

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 1.97 MiB 257.55 kiB Shape (257896,) (32967,) Dask graph 8 chunks in 3 graph layers Data type object numpy.ndarray",257896  1,

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 1.97 MiB 257.55 kiB Shape (257896,) (32967,) Dask graph 8 chunks in 3 graph layers Data type int64 numpy.ndarray",257896  1,

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 1.97 MiB 257.55 kiB Shape (257896,) (32967,) Dask graph 8 chunks in 3 graph layers Data type int64 numpy.ndarray",257896  1,

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,88.87 MiB
Shape,"(257896, 1, 2370)","(9830, 1, 2370)"
Dask graph,27 chunks in 3 graph layers,27 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 2.28 GiB 88.87 MiB Shape (257896, 1, 2370) (9830, 1, 2370) Dask graph 27 chunks in 3 graph layers Data type float32 numpy.ndarray",2370  1  257896,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,88.87 MiB
Shape,"(257896, 1, 2370)","(9830, 1, 2370)"
Dask graph,27 chunks in 3 graph layers,27 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,582.90 MiB,22.22 MiB
Shape,"(257896, 2370)","(9830, 2370)"
Dask graph,27 chunks in 3 graph layers,27 chunks in 3 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray
"Array Chunk Bytes 582.90 MiB 22.22 MiB Shape (257896, 2370) (9830, 2370) Dask graph 27 chunks in 3 graph layers Data type |S1 numpy.ndarray",2370  257896,

Unnamed: 0,Array,Chunk
Bytes,582.90 MiB,22.22 MiB
Shape,"(257896, 2370)","(9830, 2370)"
Dask graph,27 chunks in 3 graph layers,27 chunks in 3 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 1.97 MiB 257.55 kiB Shape (257896,) (32967,) Dask graph 8 chunks in 3 graph layers Data type object numpy.ndarray",257896  1,

Unnamed: 0,Array,Chunk
Bytes,1.97 MiB,257.55 kiB
Shape,"(257896,)","(32967,)"
Dask graph,8 chunks in 3 graph layers,8 chunks in 3 graph layers
Data type,object numpy.ndarray,object numpy.ndarray


In [162]:
# Split into train, valid, test
train_neg = filtered_neg.sel(_sequence=(filtered_neg["chrom"].isin(train_chroms)).compute())
valid_neg = filtered_neg.sel(_sequence=(filtered_neg["chrom"].isin(valid_chroms)).compute())
test_neg = filtered_neg.sel(_sequence=(filtered_neg["chrom"].isin(test_chroms)).compute())

In [205]:
def train_transform(batch):
    batch['seq'], batch['cov'] = sp.jitter(batch['seq'], batch['cov'], max_jitter=128, length_axis=-1, jitter_axes=0)  # jitter
    batch['cov'] = batch['cov'][..., trimming:-trimming]  # crop 
    batch['seq'] = sp.DNA.ohe(batch['seq']).transpose(0, 2, 1)  # one hot encode
    if np.random.rand() < 0.5:  # reverse complement
        #batch['seq'] = sp.reverse_complement(batch['seq'], alphabet=sp.DNA, length_axis=-1, ohe_axis=1)
        #batch['cov'] = np.flip(batch['cov'], axis=-1)
        pass
    return batch


def transform(batch):
    batch['seq'] = batch['seq'][..., 128:-128]  # crop 
    batch['cov'] = batch['cov'][..., 128+trimming:-128-trimming]  # crop 
    batch['seq'] = sp.DNA.ohe(batch['seq']).transpose(0, 2, 1)  # one hot encode
    return batch

In [206]:
# Get the train dataloader
train_dl = sd.get_torch_dataloader(
    train_neg,
    sample_dims=['_sequence'],
    variables=['seq', 'cov'],
    prefetch_factor=None,
    batch_size=32,
    transform=train_transform,
    shuffle=True,
)
batch = next(iter(train_dl))
batch['seq'].shape, batch['cov'].shape

(torch.Size([1, 4, 2114]), torch.Size([1, 1, 1000]))

In [207]:
# Valid dataloader
valid_dl = sd.get_torch_dataloader(
    valid_neg,
    sample_dims=['_sequence'],
    variables=['seq', 'cov'],
    prefetch_factor=None,
    batch_size=32,
    transform=transform,
    shuffle=False,
)
batch = next(iter(valid_dl))
batch["seq"].shape, batch["cov"].shape

(torch.Size([1, 4, 2114]), torch.Size([1, 1, 1000]))

In [208]:
from bpnetlite.losses import MNLLLoss, log1pMSELoss
from bpnetlite.performance import calculate_performance_measures

In [209]:
def bpnetlite_loss(outputs_dict, targets_dict, alpha=1):
    y_profile = outputs_dict['profile']
    y_counts = outputs_dict['counts'].reshape(-1, 1)
    y = targets_dict['cov']
    y_profile = y_profile.reshape(y_profile.shape[0], -1)
    y_profile = torch.nn.functional.log_softmax(y_profile, dim=-1)
    y = y.reshape(y.shape[0], -1)
    profile_loss = MNLLLoss(y_profile, y)
    count_loss = log1pMSELoss(y_counts, y.sum(dim=-1).reshape(-1, 1))
    loss = profile_loss + alpha * count_loss
    return{
        "loss": loss,
        "profile_loss": profile_loss,
        "count_loss": count_loss,
    }


def bpnetlite_metrics(outputs_dict, targets_dict, alpha=1):
    y_profile = outputs_dict['profile']
    y_counts = outputs_dict['counts']
    y = targets_dict['cov']
    z = y_profile.shape
    y_profile = y_profile.reshape(y_profile.shape[0], -1)
    y_profile = torch.nn.functional.log_softmax(y_profile, dim=-1)
    y_profile = y_profile.reshape(*z)
    measures = calculate_performance_measures(
        y_profile, 
        y, 
        y_counts, 
        kernel_sigma=7, 
        kernel_width=81, 
        measures=['profile_mnll', 'profile_pearson', 'count_mse', 'count_pearson']
    )
    profile_mnll = measures['profile_mnll'].cpu()
    count_mse  = measures['count_mse'].cpu()
    profile_corr = measures['profile_pearson'].cpu()
    count_corr = measures['count_pearson'].cpu()
    loss = measures['profile_mnll'].cpu() + alpha * measures['count_mse'].cpu()
    return{
        "profile_mnll": profile_mnll,
        "count_mse": count_mse,
        "profile_corr": profile_corr,
        "count_corr": count_corr,
        "loss": loss,
    }

In [210]:
module = Module(
    arch=arch,
    input_vars=["seq"],
    output_vars=["profile", "counts"],
    target_vars=["cov"],
    loss_fxn=bpnetlite_loss,
    val_metrics_fxn=bpnetlite_metrics,
    val_metrics_kwargs={"alpha": arch.alpha},
    optimizer="adam",
    optimizer_lr=1e-3,
)

In [211]:
test_out = module({"seq": batch["seq"]})
test_out[0].shape, test_out[1].shape

(torch.Size([1, 1, 1000]), torch.Size([1, 1]))

In [212]:
logger = CSVLogger(save_dir="log", name="bias", version="v0.0.1")

callbacks = [
    ModelCheckpoint(
        dirpath=os.path.join(
            logger.save_dir, 
            logger.name, 
            logger.version, 
            "checkpoints"
        ),
        save_top_k=5,
        monitor="val_loss_epoch",
    )
]

# Trainer
trainer = Trainer(
    logger=logger,
    callbacks=callbacks,
    max_epochs=50,
)

/cellar/users/aklie/opt/miniconda3/envs/eugene_tools/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /cellar/users/aklie/opt/miniconda3/envs/eugene_tools ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [213]:
# Fit the weigths
trainer.fit(module, train_dl, valid_dl)

/cellar/users/aklie/opt/miniconda3/envs/eugene_tools/lib/python3.11/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory log/bias/v0.0.1 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type          | Params | Mode 
----------------------------------------------------------
0 | arch            | BPNet         | 59.7 K | eval 
1 | loss_fxn        | GeneralLoss   | 0      | train
2 | val_metrics_fxn | GeneralMetric | 0      | train
----------------------------------------------------------
59.7 K    Trainable params
0         Non-trainable params
59.7 K    Total params
0.239     Total estimated model params size (MB)
2         Modules in train mode
15        Modules in eval mode


                                                                            

/cellar/users/aklie/opt/miniconda3/envs/eugene_tools/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/cellar/users/aklie/opt/miniconda3/envs/eugene_tools/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/7318 [00:00<?, ?it/s] 

Epoch 49: 100%|██████████| 7318/7318 [00:57<00:00, 126.27it/s, v_num=.0.1, train_loss=223.0, val_loss_epoch=160.0, val_profile_loss_epoch=159.0, val_count_loss_epoch=0.675, val_profile_mnll_epoch=159.0, val_count_mse_epoch=0.675, val_profile_corr_epoch=0.391, val_count_corr_epoch=0.000, train_loss_epoch=155.0, train_profile_loss_epoch=154.0, train_count_loss_epoch=1.020]   

`Trainer.fit` stopped: `max_epochs=50` reached.


Epoch 49: 100%|██████████| 7318/7318 [00:57<00:00, 126.21it/s, v_num=.0.1, train_loss=223.0, val_loss_epoch=160.0, val_profile_loss_epoch=159.0, val_count_loss_epoch=0.675, val_profile_mnll_epoch=159.0, val_count_mse_epoch=0.675, val_profile_corr_epoch=0.391, val_count_corr_epoch=0.000, train_loss_epoch=155.0, train_profile_loss_epoch=154.0, train_count_loss_epoch=1.020]



# Fit ChromBPNet