In [11]:
import os
import torch
import numpy as np
from pathlib import Path

# bpnet-lite

In [2]:
from bpnetlite.io import extract_loci
from bpnetlite.io import PeakGenerator

In [5]:
# Input locations for the bpnet-lite loaders: peak BED file, genome FASTA,
# and the plus/minus bigWig files for signal and control.
data_dir = Path("/cellar/shared/carterlab/data/ml4gland/ENCSR000EGM/data")
reference_dir = "/cellar/users/aklie/data/eugene/avsec21/reference"
seqs = "/cellar/users/aklie/data/ml4gland/use_cases/avsec21/reference/hg38.fa"
# os.path.join turns the Path-based data_dir into plain string paths.
peaks = os.path.join(data_dir, "peaks.bed")
signals = [os.path.join(data_dir, bw_name) for bw_name in ("plus.bw", "minus.bw")]
controls = [os.path.join(data_dir, bw_name) for bw_name in ("control_plus.bw", "control_minus.bw")]

In [None]:
# Chromosome split: chr1-chr16 for training, chr18-chr22 for validation.
# Note that chr17 appears in neither list.
training_chroms = list(map('chr{}'.format, range(1, 17)))
valid_chroms = list(map('chr{}'.format, range(18, 23)))

# Build the generator over the training peaks, restricted to the training
# chromosomes; this takes just under 4 minutes to complete.
training_data = PeakGenerator(
    peaks,
    seqs,
    signals,
    controls,
    chroms=training_chroms,
)

In [6]:
# Size reported by the bpnet-lite training generator built from the training chromosomes.
len(training_data)

1432

In [7]:
# Batch size attribute exposed by the bpnet-lite training generator.
training_data.batch_size

32

In [4]:
# Extract the validation loci for the validation chromosomes. max_jitter=0
# turns jittering off; per the original note there is also no augmenting or
# shuffling, and the call takes about 30 seconds.
X_valid, y_valid, X_ctl_valid = extract_loci(
    peaks,
    seqs,
    signals,
    controls,
    chroms=valid_chroms,
    max_jitter=0,
)

# SeqData

In [12]:
import xarray as xr
import seqdata as sd
import seqpro as sp
from eugene import preprocess as pp

# Input locations for the SeqData workflow, kept as pathlib.Path objects.
data_dir = Path("/cellar/shared/carterlab/data/ml4gland/ENCSR000EGM/data")
fasta = Path("/cellar/users/aklie/data/ml4gland/use_cases/avsec21/reference/hg38.fa")
peaks = data_dir / "peaks.bed"
signals = [data_dir / bw_name for bw_name in ("plus.bw", "minus.bw")]
controls = [data_dir / bw_name for bw_name in ("control_plus.bw", "control_minus.bw")]
control_samples = ['plus', 'minus']
# bigWig files in signal-then-control order; sample_names follows the same order.
bigwigs = [*signals, *controls]
sample_names = ['signal+', 'signal-', 'control+', 'control-']
# Location of the zarr store opened later with sd.open_zarr.
out = '/cellar/users/dlaub/projects/ML4GLand/use_cases/avsec21/avsec21.zarr'

# Chromosome split: chr1-chr16 for training, chr18-chr22 for validation.
# Note that chr17 appears in neither list.
training_chroms = list(map('chr{}'.format, range(1, 17)))
valid_chroms = list(map('chr{}'.format, range(18, 23)))

In [13]:
# Open the SeqData zarr store written at `out`
sdata = sd.open_zarr(out)

# Split `cov` into separate `control` and `signal` variables. For each one,
# select its two samples, rename the sample dimension to `cov_strand`, and
# relabel the two entries as '+' and '-'.
samples_by_variable = {
    'control': ['control+', 'control-'],
    'signal': ['signal+', 'signal-'],
}
for var_name, sample_labels in samples_by_variable.items():
    selected = sdata.cov.sel(cov_sample=sample_labels)
    renamed = selected.rename({'cov_sample': 'cov_strand'})
    sdata[var_name] = renamed.assign_coords({'cov_strand': ['+', '-']})

# The combined coverage variable and its sample coordinate are no longer needed
sdata = sdata.drop_vars(['cov', 'cov_sample'])

# One-hot encoding needs upper-case sequence characters, so store an
# upper-cased copy of `seq` as `cleaned_seq` on the same two dimensions.
upper_seq = np.char.upper(sdata["seq"])
sdata["cleaned_seq"] = xr.DataArray(upper_seq, dims=["_sequence", "_length"])

# Pull the variables used for training into memory for faster training.
# NOTE(review): the return value is discarded, so this relies on load()
# updating the variables shared with `sdata` in place - confirm.
training_vars = ['cleaned_seq', 'control', 'signal']
sdata[training_vars].load()

# Drop every sequence whose chromosome is in neither the training nor the
# validation list.
on_train_chrom = sdata["chrom"].isin(training_chroms)
on_valid_chrom = sdata["chrom"].isin(valid_chroms)
sdata = sdata.sel(_sequence=(on_train_chrom | on_valid_chrom).compute())

# Chromosome-based split with the validation chromosomes as the test set.
# The call's return value is unused; the "train_val" variable is read from
# sdata right afterwards (presumably added in place by the call - TODO confirm).
pp.train_test_chrom_split(sdata, test_chroms=valid_chroms)
train_mask = (sdata["train_val"] == True).compute()
valid_mask = (sdata["train_val"] == False).compute()
sdata_train = sdata.sel(_sequence=train_mask)
sdata_valid = sdata.sel(_sequence=valid_mask)

In [1]:
# Define training transformations
from eugene.dataload._augment import RandomRC

def seq_trans(x):
    """Upper-case a batch of sequence characters and one-hot encode it.

    The characters are upper-cased, encoded with seqpro's DNA alphabet, and
    axes 1 and 2 of the encoded array are then swapped.
    """
    # presumably (batch, length, alphabet) -> (batch, alphabet, length); TODO confirm
    encoded = sp.ohe(np.char.upper(x), sp.alphabets.DNA)
    return encoded.swapaxes(1, 2)

def cov_dtype(x):
    """Cast every array in ``x`` to 32-bit float ('f4') and return them as a tuple."""
    converted = []
    for track in x:
        converted.append(track.astype('f4'))
    return tuple(converted)

def jitter(x):
    """Apply seqpro's jitter to all arrays in ``x`` in a single call.

    The arrays are unpacked as positional arguments and jittered with a
    maximum jitter of 128 along the last axis.
    """
    # presumably jitter_axes=0 draws one offset per example, shared by all
    # arrays so they stay aligned - TODO confirm against seqpro
    jitter_options = dict(max_jitter=128, length_axis=-1, jitter_axes=0)
    return sp.jitter(*x, **jitter_options)

def to_tensor(x):
    """Convert every array in ``x`` to a float32 torch tensor and return them as a tuple."""
    tensors = [torch.tensor(arr, dtype=torch.float32) for arr in x]
    return tuple(tensors)

def random_rc(x):
    """Apply eugene's ``RandomRC`` augmentation to the arrays in ``x``.

    A new ``RandomRC`` instance is created on every call and invoked with the
    arrays unpacked as positional arguments.
    """
    # presumably a random reverse-complement of the batch - TODO confirm
    augment = RandomRC()
    return augment(*x)

# Transforms for the training dataloader. Each key names the variable(s)
# handed to the callable; tuple keys pass several variables together.
# NOTE(review): presumably applied in insertion order - confirm against SeqData.
train_transforms = {
    ('cleaned_seq', 'control', 'signal'): jitter,
    'cleaned_seq': seq_trans,
    # Trim 557 positions from each end of the last axis of the signal
    'signal': lambda x: x[..., 557:-557],
    ('control', 'signal'): cov_dtype,
    ('control', 'cleaned_seq', 'signal'): to_tensor,
    ('signal', 'control', 'cleaned_seq'): random_rc
}

# Build the shuffled training dataloader over sdata_train, batching along
# the `_sequence` dimension and returning tuples.
dl = sd.get_torch_dataloader(
    sdata_train,
    sample_dims=['_sequence'],
    variables=['cleaned_seq', 'control', 'signal'],
    prefetch_factor=None,
    batch_size=32,
    transforms=train_transforms,
    shuffle=True,
    return_tuples=True,
)

KeyboardInterrupt: 

In [17]:
# Length of the SeqData training dataloader (constructed with batch_size=32).
len(dl)

1432

In [2]:
from eugene import plot as pl

In [7]:
%matplotlib inline

In [8]:
# Summarise the training run logged under this directory using eugene's plot module.
pl.training_summary("/cellar/users/aklie/projects/ML4GLand/EUGENe_paper/logs/revision/bpnet/BPNet/v2")

In [None]:
# Count how many elements of "cleaned_seq" are equal to the byte b"N".
# The comparison gives a boolean array, .sum() counts the True entries, and
# .values unwraps the result into a NumPy value.
(sdata["cleaned_seq"] == b"N").sum().values