# Using bpnet-lite and SeqData to train a CTCF profile model
This notebook uses SeqData () and to load in a CTCF ChIP-seq dataset and bpnet-lite  (https://github.com/jmschrei/bpnet-lite/tree/master) to train a base-pair resolution model in the BPNet style. 

In [1]:
import torch
import numpy as np
import xarray as xr
import seqdata as sd
import seqpro as sp
from pathlib import Path
from bpnetlite import BPNet

from eugene import preprocess as pp

# Define paths
data_dir = Path("/cellar/shared/carterlab/data/ml4gland/ENCSR000EGM/data")
fasta = Path("/cellar/users/aklie/data/ml4gland/use_cases/avsec21/reference/hg38.fa")
peaks = data_dir / "peaks.bed"
signals = [data_dir / "plus.bw", data_dir / "minus.bw"]
controls = [data_dir / "control_plus.bw", data_dir / "control_minus.bw"]
control_samples = ['plus', 'minus']
bigwigs = signals + controls
sample_names = ['signal+', 'signal-', 'control+', 'control-']
out = '/cellar/users/dlaub/projects/ML4GLand/use_cases/avsec21/avsec21.zarr'

# Define training and validation chromosomes
training_chroms = ['chr{}'.format(i) for i in range(1, 17)]
valid_chroms = ['chr{}'.format(i) for i in range(18, 23)]

# Load data

In [2]:
# sdata = sd.from_region_files(
#     sd.GenomeFASTA(
#         'seq',
#         fasta,
#         batch_size=2048,
#         n_threads=4,
#     ),
#     sd.BigWig(
#         'cov',
#         bigwigs,
#         sample_names,
#         batch_size=2048,
#         n_jobs=4,
#         threads_per_job=2,
#     ),
#     path=out,
#     fixed_length=2114,
#     bed=peaks,
#     overwrite=True,
#     max_jitter=128
# )

In [44]:
# Load in the SeqData
sdata = sd.open_zarr(out)

In [45]:
# Split cov into control and signal
sdata['control'] = (
    sdata.cov.sel(cov_sample=['control+', 'control-'])
    .rename({'cov_sample': 'cov_strand'})
    .assign_coords({'cov_strand': ['+', '-']})
)
sdata['signal'] = (
    sdata.cov.sel(cov_sample=['signal+', 'signal-'])
    .rename({'cov_sample': 'cov_strand'})
    .assign_coords({'cov_strand': ['+', '-']})
)
sdata = sdata.drop_vars(['cov', 'cov_sample'])
sdata

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,uint64 numpy.ndarray,uint64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type uint64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,uint64 numpy.ndarray,uint64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type object numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type int64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type int64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type object numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type object numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type float64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,129.38 MiB,4.63 MiB
Shape,"(57244, 2370)","(2048, 2370)"
Dask graph,28 chunks in 2 graph layers,28 chunks in 2 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray
"Array Chunk Bytes 129.38 MiB 4.63 MiB Shape (57244, 2370) (2048, 2370) Dask graph 28 chunks in 2 graph layers Data type |S1 numpy.ndarray",2370  57244,

Unnamed: 0,Array,Chunk
Bytes,129.38 MiB,4.63 MiB
Shape,"(57244, 2370)","(2048, 2370)"
Dask graph,28 chunks in 2 graph layers,28 chunks in 2 graph layers
Data type,|S1 numpy.ndarray,|S1 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type object numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,object numpy.ndarray,object numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type int64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 447.22 kiB 223.61 kiB Shape (57244,) (28622,) Dask graph 2 chunks in 2 graph layers Data type int64 numpy.ndarray",57244  1,

Unnamed: 0,Array,Chunk
Bytes,447.22 kiB,223.61 kiB
Shape,"(57244,)","(28622,)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,517.53 MiB,9.26 MiB
Shape,"(57244, 2, 2370)","(2048, 1, 2370)"
Dask graph,56 chunks in 3 graph layers,56 chunks in 3 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 517.53 MiB 9.26 MiB Shape (57244, 2, 2370) (2048, 1, 2370) Dask graph 56 chunks in 3 graph layers Data type uint16 numpy.ndarray",2370  2  57244,

Unnamed: 0,Array,Chunk
Bytes,517.53 MiB,9.26 MiB
Shape,"(57244, 2, 2370)","(2048, 1, 2370)"
Dask graph,56 chunks in 3 graph layers,56 chunks in 3 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,517.53 MiB,9.26 MiB
Shape,"(57244, 2, 2370)","(2048, 1, 2370)"
Dask graph,56 chunks in 3 graph layers,56 chunks in 3 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray
"Array Chunk Bytes 517.53 MiB 9.26 MiB Shape (57244, 2, 2370) (2048, 1, 2370) Dask graph 56 chunks in 3 graph layers Data type uint16 numpy.ndarray",2370  2  57244,

Unnamed: 0,Array,Chunk
Bytes,517.53 MiB,9.26 MiB
Shape,"(57244, 2, 2370)","(2048, 1, 2370)"
Dask graph,56 chunks in 3 graph layers,56 chunks in 3 graph layers
Data type,uint16 numpy.ndarray,uint16 numpy.ndarray


In [46]:
# Need to upper case the seqs for ohe
sdata["cleaned_seq"] = xr.DataArray(np.char.upper(sdata["seq"]), dims=["_sequence", "_length"])

In [47]:
# Check how many b'N' characters exist in "cleaned_seq"
(sdata["cleaned_seq"] == b"N").sum().values

array(4246)

In [48]:
# Load the training data into memory for faster training
sdata[['cleaned_seq', 'control', 'signal']].load()

In [49]:
# Keep only training and validation chromosomes
sdata = sdata.sel(_sequence=((sdata["chrom"].isin(training_chroms)) | (sdata["chrom"].isin(valid_chroms))).compute())

In [50]:
# Train-test split based on chromosomes
pp.train_test_chrom_split(sdata, test_chroms=valid_chroms)
sdata_train = sdata.sel(_sequence=(sdata["train_val"]==True).compute())
sdata_valid = sdata.sel(_sequence=(sdata["train_val"]==False).compute())

In [51]:
# Check that the chromosomes are correct
sdata_train["chrom"].to_series().value_counts(), sdata_valid["chrom"].to_series().value_counts()

(chr1     6085
 chr2     4260
 chr6     3557
 chr3     3437
 chr7     3220
 chr11    3055
 chr5     2990
 chr12    2801
 chr10    2668
 chr8     2480
 chr4     2378
 chr9     2173
 chr16    2120
 chr15    1915
 chr14    1601
 chr13    1063
 Name: chrom, dtype: int64,
 chr19    2387
 chr20    1589
 chr22    1301
 chr18    1086
 chr21     688
 Name: chrom, dtype: int64)

In [52]:
# Define training transformations
from eugene.dataload._augment import RandomRC

def seq_trans(x):
    x = np.char.upper(x)
    x = sp.ohe(x, sp.alphabets.DNA)
    x = x.swapaxes(1, 2)
    return x

def cov_dtype(x):
    return tuple(arr.astype('f4') for arr in x)

def jitter(x):
    return sp.jitter(*x, max_jitter=128, length_axis=-1, jitter_axes=0)

def to_tensor(x):
    return tuple(torch.tensor(arr, dtype=torch.float32) for arr in x)

def random_rc(x):
    return RandomRC(rc_prob=0.5)(*x)

# Get the train dataloader
dl = sd.get_torch_dataloader(
    sdata_train,
    sample_dims=['_sequence'],
    variables=['cleaned_seq', 'control', 'signal'],
    prefetch_factor=None,
    batch_size=32,
    transforms={
        ('cleaned_seq', 'control', 'signal'): jitter,
        'cleaned_seq': seq_trans,
        'signal': lambda x: x[..., 557:-557],
        ('control', 'signal'): cov_dtype,
        ('control', 'cleaned_seq', 'signal'): to_tensor,
        ('signal', 'control', 'cleaned_seq'): random_rc
    },
    return_tuples=True,
    shuffle=True,
)

In [53]:
# Test a batch
batch = next(iter(dl))
[x.shape for x in batch]

In [None]:
# Get the validation data
def seq_trans(x):
    x = x[..., 128:-128]
    np.char.upper(x)
    x = sp.ohe(x, sp.alphabets.DNA)
    x = x.swapaxes(1, 2)
    return torch.as_tensor(x.astype('f4'))

def ctl_trans(x):
    x = x[..., 128:-128]
    return torch.as_tensor(x.astype('f4'))

def cov_trans(x):
    x = x[..., 128+557:-128-557]
    return torch.as_tensor(x.astype('f4'))

X_valid = seq_trans(sdata_valid["cleaned_seq"].values)
X_ctl_valid = ctl_trans(sdata_valid["control"].values)
y_valid = cov_trans(sdata_valid["signal"].values)
X_valid.shape, X_ctl_valid.shape, y_valid.shape

(torch.Size([7051, 4, 2114]),
 torch.Size([7051, 2, 2114]),
 torch.Size([7051, 2, 1000]))

# Train a model

In [None]:
# Create the model, we need 2 channels for the control tracks, 2 channels for the signal tracks, and to trim to 1000 bp of input
model = BPNet(n_outputs=2, n_control_tracks=2, trimming=(2114 - 1000) // 2, name="bpnet.seqdata")

In [None]:
# Test out a prediction batch
batch = next(iter(dl))
pred_ctl = model.forward(*batch[:2])
pred_ctl[0].shape, pred_ctl[1].shape

(torch.Size([32, 2, 1256]), torch.Size([32, 1]))

In [None]:
# Send the model to the GPU
model.cuda()

BPNet(
  (iconv): Conv1d(4, 64, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    (2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
    (3): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
    (4): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(32,), dilation=(32,))
    (5): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(64,), dilation=(64,))
    (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(128,), dilation=(128,))
    (7): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(256,), dilation=(256,))
  )
  (rrelus): ModuleList(
    (0-7): 8 x ReLU()
  )
  (fconv): Conv1d(66, 2, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=65, out_features=1, bias=True)
)

In [None]:
# Quickly define your optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [43]:
# Use the models fit_generator method to train the model
res = model.fit(
    dl,
    optimizer,
    X_valid=X_valid,
	X_ctl_valid=X_ctl_valid,
    y_valid=y_valid,
    max_epochs=50
)

Epoch	Iteration	Training Time	Validation Time	Training MNLL	Training Count MSE	Validation MNLL	Validation Profile Pearson	Validation Count Pearson	Validation Count MSE	Saved?
0	0	4.6641	1.8916	719.5731	38.1668	630.3253	0.0016708192	-0.358666	25.7244	True
0	100	8.5274	1.2695	577.7614	4.8351	340.1001	0.09122563	-0.12384434	2.5671	True
0	200	7.5811	1.2688	603.0289	1.1336	313.1705	0.237045	-0.11252865	0.702	True
0	300	7.9923	1.2632	440.5643	0.5202	290.3875	0.33387583	-0.15623756	1.0023	True
0	400	7.3477	1.2624	454.6971	0.7394	281.2389	0.36048895	-0.090743326	0.7031	True
0	500	8.0346	1.2708	506.2081	0.8303	277.2175	0.37812486	-0.040716868	0.6542	True
0	600	8.0278	1.3541	461.7531	0.5143	274.5461	0.3866451	0.00065450364	0.7518	True
0	700	6.8262	1.2604	403.9507	0.5806	272.9304	0.39379632	0.02435911	0.8997	True
0	800	7.6983	1.265	505.2987	0.5619	272.3665	0.39651522	0.14550175	0.9497	True
0	900	8.059	1.2655	483.8552	0.4399	271.5637	0.39877772	0.088804096	1.5097	True
0	1000	7.8436	1.2782	411.9011

KeyboardInterrupt: 

# Scratch

In [None]:
import numpy as np
from typing import Tuple, Union

def random_rc(x: np.ndarray, rc_prob: float = 0.5) -> Union[np.ndarray, Tuple[np.ndarray]]:
    """Randomly applies a reverse-complement transformation to each sequence in a training
    batch according to a user-defined probability, rc_prob. This is applied to each sequence
    independently.

    Parameters
    ----------
    x : np.ndarray
        Batch (or tuple of batches) of one-hot sequences (shape: (N, A, L)).
    rc_prob : float, optional
        Probability to apply a reverse-complement transformation, defaults to 0.5.

    Returns
    -------
    np.ndarray
        Sequences with random reverse-complements applied.
    """
    n = x.shape[0]
    # randomly select sequences to apply rc transformation
    ind_rc = np.random.rand(n) < rc_prob

    # make a copy of the sequence
    x_aug = np.copy(x)

    # apply reverse-complement transformation
    x_aug[ind_rc] = np.flip(x_aug[ind_rc], axis=(1, 2))

    return x_aug

In [None]:
# Define training transformations
def seq_trans(x):
    x = np.char.upper(x)
    x = sp.ohe(x, sp.alphabets.DNA)
    x = x.swapaxes(1, 2)
    #print("Post ohe:", (x.shape, x.dtype))
    return x

def cov_dtype(x):
    #[print("Pre cov:", (x.shape, x.dtype)) for x in x]
    return tuple(arr.astype('f4') for arr in x)

def jitter(x):
    #[print("Pre jitter:", (x.shape, x.dtype)) for x in x]
    return sp.jitter(*x, max_jitter=128, length_axis=-1, jitter_axes=0)

def to_tensor(x):
    #[print("Pre tensor:", x.shape, x.dtype) for x in x]
    return tuple(torch.tensor(arr, dtype=torch.float32) for arr in x)

def random_rc(x):
    #[print("Pre rc:",x.shape, x.dtype) for x in x]
    return RandomRC()(*x)

# Get the train dataloader
dl = sd.get_torch_dataloader(
    sdata_train,
    sample_dims=['_sequence'],
    variables=['cleaned_seq', 'control', 'signal'],
    prefetch_factor=None,
    batch_size=32,
    transforms={
        ('cleaned_seq', 'control', 'signal'): jitter,
        'cleaned_seq': seq_trans,
        'signal': lambda x: x[..., 557:-557],
        ('control', 'signal'): cov_dtype,
        ('control', 'cleaned_seq', 'signal'): to_tensor,
        ('signal', 'control', 'cleaned_seq'): random_rc
    },
    return_tuples=True
)

In [None]:
# 
_transforms = {
    ('seq', 'control', 'signal'): jitter,
    'seq': seq_trans,
    'signal': lambda x: x[..., 557:-557],
    ('control', 'signal'): cov_dtype
}

In [None]:
vars_with_transforms = set()
for k in _transforms:
    if isinstance(k, tuple):
        vars_with_transforms.update(k)
    else:
        vars_with_transforms.add(k)

In [None]:
def seq_trans(x):
    x = sp.ohe(x, sp.alphabets.DNA)
    x = x.swapaxes(1, 2)
    jitter = RandomJitter(128, -1)
    x = jitter(x)
    return torch.as_tensor(x.astype('f4'))

def ctl_trans(x):
    x = x[..., 128:-128]
    return torch.as_tensor(x.astype('f4'))

def cov_trans(x):
    x = x[..., 128+557:-128-557]
    return torch.as_tensor(x.astype('f4'))

dl = sd.get_torch_dataloader(
    sdata,
    sample_dims=['_sequence'],
    variables=['seq', 'control', 'signal'],
    prefetch_factor=None,
    batch_size=8,
    transforms={
        'seq': seq_trans,
        'control': ctl_trans,
        'signal': cov_trans,
    },
    return_tuples=True
)

In [None]:
# Test out
example_seq = sdata_train["cleaned_seq"][0:64].values
example_ctl = sdata_train["control"][0:64].values
example_signal = sdata_train["signal"][0:64].values

In [None]:
# Jitter first
jitter_out = jitter((example_seq, example_ctl, example_signal))
jittered_seq = jitter_out[0]
jittered_ctl = jitter_out[1]
jittered_signal = jitter_out[2]
jittered_seq.shape, jittered_ctl.shape, jittered_signal.shape

((64, 2114), (64, 2, 2114), (64, 2, 2114))

In [None]:
# Then complete transforms
seq = seq_trans(jittered_seq)
ctl, signal = cov_dtype((jittered_ctl, jittered_signal[..., 557:-557]))

In [None]:
# How does this look?
seq.shape, signal.shape, ctl.shape

((64, 4, 2114), (64, 2, 1000), (64, 2, 2114))