In [1]:
# Import modules
import os
import sys
import numpy as np

import pandas as pd
import torch

import seqdata as sd
import seqpro as sp
import xarray as xr

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [2]:
# Set wd: root directory of the deAlmeida S2 UMI-STARR-seq dataset.
# Override with the DEALMEIDA_STARR_DIR environment variable on machines where
# the data lives elsewhere; when it is unset the original path is used.
DATA_DIR = os.environ.get(
    "DEALMEIDA_STARR_DIR",
    "/cellar/users/aklie/data/datasets/deAlmeida_DrosophilaS2_UMI-STARR-seq",
)
os.chdir(DATA_DIR)

# Seed the global torch and numpy RNGs so model initialisation (and,
# presumably, dataloader shuffling) repeat across Restart & Run All.
# NOTE(review): whether seqpro's jitter draws from numpy's global RNG is not
# visible here — confirm if batch-exact reproducibility is required.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load data

In [33]:
# Load SeqData: open the processed zarr store (path is relative to the working
# directory set above), then call .load() so later cells work on in-memory data
# (assuming sd.open_zarr returns a lazily-backed xarray Dataset — confirm).
sdata = sd.open_zarr(
    "training/2023_12_20/bp_res/deepstarr_training_bpres_processed.zarr"
)
sdata.load()

: 

: 

In [None]:
# Count ambiguous b"N" bases across all of "cleaned_seq". Only upper-case
# b"N" is matched; a lower-case b"n" (if the data contains any) is not counted.
is_n = sdata["cleaned_seq"] == b"N"
is_n.sum().values

array(203080)

In [None]:
# Sanity check: total counts of the first sequence in each signal/input track
tuple(
    sdata[track][0].values.sum()
    for track in ("dev_signal", "hk_signal", "dev_input", "hk_input")
)

(17461, 5535, 1469, 3300)

In [6]:
# Split sequences on the "train_val" flag: True -> training, False -> validation.
# The explicit `== True` / `== False` comparisons are kept on purpose because the
# flag's dtype is not visible here.
is_train = (sdata["train_val"] == True).compute()  # noqa: E712
is_valid = (sdata["train_val"] == False).compute()  # noqa: E712
sdata_train = sdata.sel(_sequence=is_train)
sdata_valid = sdata.sel(_sequence=is_valid)

In [20]:
# Define training transformations
from eugene.dataload._augment import RandomRC

def seq_trans(x):
    """Upper-case raw training sequences and one-hot encode them channel-first.

    The one-hot array from ``sp.ohe`` has its axes 1 and 2 swapped, giving the
    (batch, 4, length) layout seen in the batch shapes printed below. No tensor
    conversion happens here; the dataloader's ``to_tensor`` transform does that.

    NOTE(review): a later cell redefines ``seq_trans`` for validation; the
    dataloader keeps a reference to this training version.
    """
    upper = np.char.upper(x)
    one_hot = sp.ohe(upper, sp.alphabets.DNA)
    return one_hot.swapaxes(1, 2)

def cov_dtype(x):
    """Give every coverage array a channel axis at position 1 and cast it to float32.

    ``x`` is the iterable of arrays the dataloader routes to this transform;
    the converted arrays are returned as a tuple in the same order.
    """
    converted = []
    for track in x:
        converted.append(np.expand_dims(track, axis=1).astype('f4'))
    return tuple(converted)

def jitter(x):
    """Randomly shift a group of arrays along their last axis via ``sp.jitter``.

    ``x`` is the iterable of variables the dataloader routes to this transform.
    ``max_jitter=128`` is used; judging by the shapes printed in this notebook,
    the output is 256 positions shorter than the input (validation instead
    crops a fixed 128 per side). ``jitter_axes=0`` presumably draws an
    independent shift per sequence — confirm in the seqpro docs.
    """
    arrays = tuple(x)
    return sp.jitter(*arrays, length_axis=-1, jitter_axes=0, max_jitter=128)

def to_tensor(x):
    """Copy each array in ``x`` into a new float32 torch tensor; return them as a tuple."""
    def _as_float32(arr):
        return torch.tensor(arr, dtype=torch.float32)

    return tuple(map(_as_float32, x))

# TODO(review): reverse-complement augmentation is disabled. Re-enabling it
# needs both this helper and the matching commented-out `transforms` entry
# below; the RandomRC import in this cell is only used by this helper.
#def random_rc(x):
#    return RandomRC(rc_prob=0.5)(*x)

# Get the train dataloader: shuffled batches of 32 training sequences carrying
# the sequence, the hk control track and the hk signal track.
dl = sd.get_torch_dataloader(
    sdata_train,
    sample_dims=['_sequence'],
    variables=['cleaned_seq', 'hk_input', 'hk_signal'],
    prefetch_factor=None,
    batch_size=32,
    # NOTE(review): assumes seqdata applies these transforms in dict insertion
    # order (jitter first, tensor conversion last) — confirm in seqdata docs.
    transforms={
        # Random shift over all three variables at once; presumably the same
        # shift per sequence is used for each so they stay aligned. Per the
        # batch shapes printed below, 2114 positions remain afterwards.
        ('cleaned_seq', 'hk_input', 'hk_signal'): jitter,
        # Upper-case + one-hot the sequence into (batch, 4, length).
        'cleaned_seq': seq_trans,
        # Crop the target to the 1000 bp output window:
        # 557 == (2114 - 1000) // 2, the same per-side trimming given to BPNet.
        'hk_signal': lambda x: x[..., 557:-557],
        # Add a channel axis and cast the coverage tracks to float32.
        ('hk_input', 'hk_signal'): cov_dtype,
        ('hk_input', 'cleaned_seq', 'hk_signal'): to_tensor,
        #('hk_signal', 'hk_input', 'cleaned_seq'): random_rc
    },
    # Yield tuples in `variables` order (sequence, control, signal), matching
    # the `X, X_ctl, y = batch` unpacking used later.
    return_tuples=True,
    shuffle=True,
)

In [21]:
# Pull one shuffled batch and show the shape of each element
batch = next(iter(dl))
list(map(lambda tensor: tensor.shape, batch))

[torch.Size([32, 4, 2114]),
 torch.Size([32, 1, 2114]),
 torch.Size([32, 1, 1000])]

In [22]:
# Get the validation data
def seq_trans(x):
    """Prepare validation sequences for BPNet as a float32 one-hot tensor.

    Crops 128 positions from each end of the last axis (the fixed validation
    counterpart of the training-time jitter), upper-cases the bases to match
    the training ``seq_trans``, one-hot encodes with ``sp.alphabets.DNA`` and
    swaps axes 1 and 2 to give (n_sequences, 4, length).

    NOTE(review): this shadows the training ``seq_trans`` defined earlier; the
    already-built dataloader keeps its own reference to the training version.
    """
    x = x[..., 128:-128]
    # np.char.upper returns a NEW array. The previous code discarded it, so
    # lower-case bases reached sp.ohe un-normalised, unlike the training path.
    x = np.char.upper(x)
    x = sp.ohe(x, sp.alphabets.DNA)
    x = x.swapaxes(1, 2)
    return torch.as_tensor(x.astype('f4'))

def ctl_trans(x):
    """Prepare the validation control track for BPNet.

    Drops 128 positions from each end of the last axis (the fixed validation
    counterpart of the training-time jitter), inserts a singleton channel axis
    at position 1 and returns a float32 torch tensor.
    """
    margin = 128
    cropped = x[..., margin:-margin]
    return torch.as_tensor(cropped[:, np.newaxis].astype(np.float32))

def cov_trans(x):
    """Prepare the validation target signal for BPNet.

    Crops the 128-position jitter margin plus the 557-position BPNet per-side
    trimming ((2114 - 1000) // 2) from each end of the last axis, inserts a
    singleton channel axis at position 1 and returns a float32 torch tensor.
    """
    margin = 128 + 557
    cropped = x[..., margin:-margin]
    return torch.as_tensor(cropped[:, np.newaxis].astype(np.float32))

X_valid = seq_trans(sdata_valid["cleaned_seq"].values)
X_ctl_valid = ctl_trans(sdata_valid["hk_input"].values)
y_valid = cov_trans(sdata_valid["hk_signal"].values)

# Fail fast if the held-out tensors disagree with each other; a mismatch would
# otherwise only surface as an opaque error inside model.fit.
assert X_ctl_valid.shape[0] == X_valid.shape[0] == y_valid.shape[0], \
    "validation tensors disagree on the number of sequences"
assert X_valid.shape[-1] == X_ctl_valid.shape[-1], \
    "sequence and control tracks must cover the same positions"
X_valid.shape, X_ctl_valid.shape, y_valid.shape

(torch.Size([40570, 4, 2114]),
 torch.Size([40570, 1, 2114]),
 torch.Size([40570, 1, 1000]))

In [23]:
from bpnetlite import BPNet

In [24]:
# Build BPNet with a single output track and a single control track.
# `trimming` is the per-side crop: 2114 bp of input -> 1000 bp of predicted profile.
model = BPNet(
    n_outputs=1,
    n_control_tracks=1,
    trimming=(2114 - 1000) // 2,
    name="test.bpnet.seqdata",
)
model

BPNet(
  (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
    (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
    (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
    (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
    (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
    (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
  )
  (rrelus): ModuleList(
    (0-7): 8 x ReLU()
  )
  (fconv): Conv1d(65, 1, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=65, out_features=1, bias=True)
)

In [25]:
# Draw a fresh batch and unpack it into sequence, control track and target signal
batch = next(iter(dl))
X, X_ctl, y = batch
tuple(tensor.shape for tensor in (X, X_ctl, y))


(torch.Size([32, 4, 2114]),
 torch.Size([32, 1, 2114]),
 torch.Size([32, 1, 1000]))

In [26]:
# Forward pass through the untrained model. Calling the module itself (rather
# than .forward) goes through nn.Module.__call__, so any registered hooks run.
y_profile, y_counts = model(X, X_ctl)
y_profile.shape, y_counts.shape

(torch.Size([32, 1, 1000]), torch.Size([32, 1]))

In [27]:
# Flatten (batch, 1, length) -> (batch, length) and turn the profile logits into
# per-example log-probabilities over positions for MNLLLoss below.
y_profile = torch.nn.functional.log_softmax(
    y_profile.reshape(y_profile.shape[0], -1), dim=-1
)
y = y.reshape(y.shape[0], -1)

In [28]:
# Calculate the profile and count losses, each averaged over the batch:
# MNLLLoss on the normalised profile, log1pMSELoss on total counts per example.
from bpnetlite.losses import MNLLLoss, log1pMSELoss
total_counts = y.sum(dim=-1).reshape(-1, 1)  # (batch, 1), same layout as y_counts
profile_loss = MNLLLoss(y_profile, y).mean()
count_loss = log1pMSELoss(y_counts, total_counts).mean()

In [29]:
# Display both batch-averaged losses for the untrained model
(profile_loss, count_loss)

(tensor(3160.8975, grad_fn=<MeanBackward0>),
 tensor(43.6154, grad_fn=<MeanBackward0>))

In [30]:
# Send the model's parameters and buffers to the default CUDA device.
# Module.cuda() returns the module itself, so rebinding keeps the same object.
model = model.cuda()
model

BPNet(
  (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
    (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
    (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
    (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
    (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
    (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
  )
  (rrelus): ModuleList(
    (0-7): 8 x ReLU()
  )
  (fconv): Conv1d(65, 1, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=65, out_features=1, bias=True)
)

In [31]:
# Adam over every BPNet parameter. It is created after the model was moved to
# the GPU, which is the order PyTorch recommends for Module.cuda().
LEARNING_RATE = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [32]:
# Train with BPNet.fit: stream shuffled batches from `dl` and periodically
# evaluate on the held-out tensors built above (the log below shows one
# validation row every 100 iterations).
res = model.fit(
    dl,
    optimizer,
    X_valid=X_valid,
    X_ctl_valid=X_ctl_valid,
    y_valid=y_valid,
    max_epochs=50,
)

Epoch	Iteration	Training Time	Validation Time	Training MNLL	Training Count MSE	Validation MNLL	Validation Profile Pearson	Validation Count Pearson	Validation Count MSE	Saved?
0	0	3.4571	7.9835	2816.4172	44.5365	3869.6133	0.06639323	0.37229145	38.1434	True
0	100	12.2705	6.9255	1909.5676	2.5685	3333.1897	0.16152142	0.3803567	3.371	True
0	200	11.3792	6.9389	2682.4829	3.6374	2914.5569	0.24122709	0.406646	2.8059	True
0	300	10.9439	6.9538	2459.9956	3.1004	2841.1924	0.26697826	0.43771183	2.8311	True
0	400	12.0476	6.9665	2718.2988	3.0288	2742.7915	0.23302194	0.40544364	3.1068	True
0	500	11.5844	6.9703	1149.9036	2.7677	2721.3228	0.26276597	0.39829412	4.2158	True
0	600	11.1144	7.028	2633.1011	4.1051	2528.2271	0.2732668	0.44062012	2.61	True
0	700	11.7196	6.9806	1410.5408	2.8166	2588.2927	0.29798362	0.4345279	2.8816	False
0	800	11.7622	6.9777	2581.9556	1.6955	2818.5645	0.21183512	0.4345917	2.7239	False
0	900	11.8255	6.9773	3127.3149	2.7681	2767.9744	0.29188547	0.43575004	2.7797	False
0	1000	11.037

KeyboardInterrupt: 

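# DONE! (Note: training above was stopped manually with a KeyboardInterrupt during epoch 0, after ~1000 iterations; per the log's "Saved?" column, the last saved checkpoint is from iteration 600.)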

---