#  Using EUGENe to generate a BPNet model

In [1]:
from Bio import motifs

In [2]:
from Bio.motifs import jaspar

In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import eugene as eu

Global seed set to 13


# ProfileDataset

In [4]:
# Input locations: per-experiment data directory and shared reference genome directory
data_dir = "/cellar/users/aklie/data/eugene/avsec21/ENCSR000EGM/data"
reference_dir = "/cellar/users/aklie/data/eugene/avsec21/reference"

# Peak calls (BED) and the reference genome (FASTA)
peaks = os.path.join(data_dir, "peaks.bed")
#peaks = "/cellar/users/aklie/data/eugene/avsec21/ENCSR000EGM/toy.bed"
seqs = os.path.join(reference_dir, "hg38.fa")

# One bigWig per strand for the experiment signal and for the matched control
signals = [os.path.join(data_dir, f"{strand}.bw") for strand in ("plus", "minus")]
controls = [os.path.join(data_dir, f"control_{strand}.bw") for strand in ("plus", "minus")]

# Chromosome split: chr1-chr16 for training, chr18-chr22 for validation.
# NOTE(review): chr17 is in neither list -- presumably held out on purpose; confirm.
training_chroms = [f"chr{i}" for i in range(1, 17)]
valid_chroms = [f"chr{i}" for i in range(18, 23)]

# Load data

In [5]:
from eugene.dataload import ProfileDataset

In [6]:
# Training loci are read with max_jitter=128 (the recorded X_train width is 2370 = 2114 + 2*128).
X_train, y_train, X_ctl_train = eu.dl.read_profile(peaks, seqs, signals, controls, max_jitter=128, chroms=training_chroms)
# Validation loci are read without jitter. This load must run: later cells use
# X_val, y_val and X_ctl_val, which are otherwise undefined on a fresh kernel.
X_val, y_val, X_ctl_val = eu.dl.read_profile(peaks, seqs, signals, controls, max_jitter=0, chroms=valid_chroms)

In [7]:
X_train.shape

torch.Size([45803, 4, 2370])

In [None]:
X_val

In [658]:
# Wrap the loaded tensors in ProfileDataset objects.
# The training dataset is currently disabled; only the validation tensors are wrapped here.
#X_train_dataset = ProfileDataset(X_train, y_train, X_ctl_train)
X_val_dataset = ProfileDataset(X_val, y_val, X_ctl_val)

In [659]:
# The training loader is required by trainer.fit further down, so it is built here
# directly from the training tensors instead of being left commented out.
# NOTE(review): X_train was read with max_jitter=128; if ProfileDataset accepts a
# max_jitter argument, pass max_jitter=128 to enable jitter augmentation -- confirm the API.
X_train_loader = ProfileDataset(X_train, y_train, X_ctl_train).to_dataloader(batch_size=64, num_workers=4, shuffle=True)
# Validation batches keep their order so metrics are comparable across epochs.
X_val_loader = X_val_dataset.to_dataloader(batch_size=64, num_workers=4, shuffle=False)

# Instantiate a model

In [660]:
from eugene.models._profile_models import BPNet

# Build the BPNet model and move it to the GPU.
# input_len / output_dim are 2114 / 1000 and there are two output tracks and two control tracks.
# trimming is half the difference between input and output length: (2114 - 1000) // 2 = 557.
model = BPNet(
    input_len=2114,
    output_dim=1000,
    n_outputs=2,
    n_control_tracks=2, 
    trimming=(2114 - 1000) // 2
).cuda()

In [661]:
# Smoke test: push a single validation batch through the untrained model on the GPU.
# The loader is named X_val_loader (the original cell referenced an undefined X_dataloader_val).
# NOTE(review): assumes batch[0] is the sequence tensor and batch[1] the control tensor -- confirm
# against the ProfileDataset implementation in use.
batch = next(iter(X_val_loader))
outs = model(batch[0].cuda(), batch[1].cuda())
outs[0].shape, outs[1].shape

(torch.Size([64, 2, 1000]), torch.Size([64, 1]))

# Train with PyTorch Lightning

In [663]:
import pytorch_lightning as pl

: 

In [64]:
# TensorBoard logs are written under <logging_dir>/models.
logging_dir = "/cellar/users/aklie/projects/ML4GLand/use_cases/avsec21/"
# `progress_bar_refresh_rate` is deprecated since Lightning v1.5 (see the warning this cell
# used to emit); the refresh rate is now configured through the TQDMProgressBar callback.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=20)],
    logger=pl.loggers.TensorBoardLogger(logging_dir, name="models"),
)

  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [65]:
# Fit the model using the training loader and the validation loader.
# Requires X_train_loader and X_val_loader to have been created in the loader cell above.
trainer.fit(model, X_train_loader, X_val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /cellar/users/aklie/projects/ML4GLand/use_cases/avsec21/bpnet
Set SLURM handle signals.

  | Name   | Type       | Params
--------------------------------------
0 | iconv  | Conv1d     | 5.4 K 
1 | irelu  | ReLU       | 0     
2 | rconvs | ModuleList | 98.8 K
3 | rrelus | ModuleList | 0     
4 | fconv  | Conv1d     | 9.9 K 
5 | linear | Linear     | 66    
--------------------------------------
114 K     Trainable params
0         Non-trainable params
114 K     Total params
0.457     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 13




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




# Attributions

In [256]:
import seqexplainer

In [257]:
seqexplainer.attribute?

[0;31mSignature:[0m
[0mseqexplainer[0m[0;34m.[0m[0mattribute[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0minputs[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmethod[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mCallable[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtarget[0m[0;34m:[0m [0mint[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdevice[0m[0;34m:[0m [0mstr[0m [0;34m=[0m [0;34m'cpu'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m**[0m[0mkwargs[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/projects/ML4GLand/SeqExplainer/seqexplainer/_feature_attribution.py
[0;31mType:[0m      function


In [267]:
# Compute InputXGradient attributions for output target 0, forwarding the control
# tensor as an extra model input.
# NOTE(review): `pretrained_model`, `X` and `X_ctl` are not defined anywhere in the visible
# notebook -- confirm where they come from.
# NOTE(review): the recorded output of this cell is
# "AttributeError: 'tuple' object has no attribute 'shape'". The BPNet forward pass above
# returned two tensors, which is presumably the cause -- confirm.
seqexplainer.attribute(
    pretrained_model,
    X,
    method="InputXGradient",
    additional_forward_args=(X_ctl,),
    target=0
)

AttributeError: 'tuple' object has no attribute 'shape'

---

# DONE!

# Scratch

In [None]:
import numpy as np
import torch
import pandas as pd

import pyfaidx
import pyBigWig

from tqdm import tqdm

def extract_loci(
    loci,
    sequences,
    signals=None,
    controls=None,
    chroms=None,
    in_window=2114,
    out_window=1000,
    max_jitter=128,
    min_counts=None,
    max_counts=None,
    verbose=False
):
    """Extract sequences and signals at coordinates from a locus file.

    This function will take in genome-wide sequences, signals, and optionally
    controls, and extract the values of each at the coordinates specified in
    the locus file/s and return them as tensors.

    Signals and controls are both lists with the length of the list, n_s
    and n_c respectively, being the middle dimension of the returned
    tensors. Signals are returned with a last dimension of
    out_window + 2*max_jitter, while sequences and controls are returned
    with a last dimension of in_window + 2*max_jitter.

    The values for sequences, signals, and controls, can either be filepaths
    or dictionaries of np arrays or a mix of the two. When a filepath is
    passed in it is loaded using pyfaidx or pyBigWig respectively. The
    `signals` and `controls` lists passed by the caller are never modified,
    and any bigWig file opened by this function is closed before returning.

    Loci whose padded window would start before position 0 or run past the
    end of the chromosome are silently skipped.

    Parameters
    ----------
    loci: str or pd.DataFrame or list/tuple of such
        Either the path to a bed file or a pd DataFrame object whose first
        three columns are the chromosome, the start, and the end, of each
        locus to train on. Alternatively, a list or tuple of
        strings/DataFrames where the intention is to train on the
        interleaved concatenation, i.e., when you want to train on peaks and
        negatives. DataFrames passed in are not modified.
    sequences: str or dictionary
        Either the path to a fasta file to read from or a dictionary where the
        keys are the unique set of chromosomes and the values are one-hot
        encoded sequences as np arrays or memory maps. Dictionary values are
        sliced along their first axis and then transposed.
    signals: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are np arrays or memory
        maps. If None, no signal tensor is returned. Default is None.
    controls: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be read
        using pyBigWig, or a list of dictionaries where the keys are the same
        set of unique chromosomes and the values are np arrays or memory
        maps. If None, no control tensor is returned. Default is None.
    chroms: list or None, optional
        A set of chromosomes to extract loci from. Loci in other chromosomes
        in the locus file are ignored. If None, all loci are used. Default is
        None.
    in_window: int, optional
        The input window size. Default is 2114.
    out_window: int, optional
        The output window size. Default is 1000.
    max_jitter: int, optional
        The maximum amount of jitter to add, in either direction, to the
        midpoints that are passed in. Default is 128.
    min_counts: float or None, optional
        The minimum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no minimum. Only
        applied when signals are given. Default is None.
    max_counts: float or None, optional
        The maximum number of counts, summed across the length of each example
        and across all tasks, needed to be kept. If None, no maximum. Only
        applied when signals are given. Default is None.
    verbose: bool, optional
        Whether to display a progress bar while loading. Default is False.

    Returns
    -------
    seqs: torch.tensor, shape=(n, 4, in_window+2*max_jitter)
        The extracted sequences in the same order as the loci in the locus
        file after optional filtering by chromosome.
    signals: torch.tensor, shape=(n, len(signals), out_window+2*max_jitter)
        The extracted signals where the first dimension is in the same order
        as loci in the locus file after optional filtering by chromosome and
        the second dimension is in the same order as the list of signal files.
        If no signal files are given, this is not returned.
    controls: torch.tensor, shape=(n, len(controls), in_window+2*max_jitter)
        The extracted controls where the first dimension is in the same order
        as loci in the locus file after optional filtering by chromosome and
        the second dimension is in the same order as the list of control files.
        If no control files are given, this is not returned.
    """

    seqs, signals_, controls_ = [], [], []
    in_width, out_width = in_window // 2, out_window // 2
    max_width = max(in_width, out_width)

    # Load the sequences
    if isinstance(sequences, str):
        sequences = pyfaidx.Fasta(sequences)

    names = ['chrom', 'start', 'end']
    if not isinstance(loci, (tuple, list)):
        loci = [loci]

    # Give every locus an interleaving index so that rows from several
    # locus sets alternate (set0[0], set1[0], set0[1], set1[1], ...).
    loci_dfs = []
    for i, df in enumerate(loci):
        if isinstance(df, str):
            df = pd.read_csv(df, sep='\t', usecols=[0, 1, 2], header=None, index_col=False, names=names)
        else:
            # Work on a copy of the first three columns so the caller's
            # DataFrame is untouched and column names agree across sets.
            df = df.iloc[:, [0, 1, 2]].copy()
            df.columns = names
        # Assigned for file paths AND DataFrames; previously DataFrames had
        # no 'idx' column and set_index("idx") below raised a KeyError.
        df['idx'] = np.arange(len(df)) * len(loci) + i
        loci_dfs.append(df)

    loci = pd.concat(loci_dfs).set_index("idx").sort_index().reset_index(drop=True)
    if chroms is not None:
        loci = loci[np.isin(loci['chrom'], chroms)]

    # Handles opened here are tracked so they can be closed on the way out.
    opened = []

    def _open_tracks(tracks):
        """Return a new list in which filepath entries are replaced by open bigWig handles."""
        if tracks is None:
            return None
        resolved = []
        for track in tracks:
            if isinstance(track, str):
                track = pyBigWig.open(track, "r")
                opened.append(track)
            resolved.append(track)
        return resolved

    def _read_window(track, chrom, start, end):
        """Slice [start, end) of one track, from either a dict of arrays or a bigWig handle."""
        if isinstance(track, dict):
            return track[chrom][start:end]
        # bigWig files report NaN where there is no coverage; treat that as zero.
        return np.nan_to_num(track.values(chrom, start, end, numpy=True))

    try:
        # Load the signal and optional control tracks if filenames are given,
        # without writing the handles back into the caller's lists.
        signal_tracks = _open_tracks(signals)
        control_tracks = _open_tracks(controls)

        rows = loci.values
        if verbose:
            rows = tqdm(rows, desc="Loading Loci")

        for chrom, start, end in rows:
            mid = start + (end - start) // 2

            # Skip loci whose padded window would fall off either chromosome end
            if start - max_width - max_jitter < 0:
                continue

            if end + max_width + max_jitter >= len(sequences[chrom]):
                continue

            # Signals use a window the size of the output (plus jitter slack)
            if signal_tracks is not None:
                lo = mid - out_width - max_jitter
                hi = mid + out_width + max_jitter
                signals_.append([_read_window(track, chrom, lo, hi) for track in signal_tracks])

            # For the sequences and controls extract a window the size of the input
            lo = mid - in_width - max_jitter
            hi = mid + in_width + max_jitter

            if control_tracks is not None:
                controls_.append([_read_window(track, chrom, lo, hi) for track in control_tracks])

            # Extract the sequence
            if isinstance(sequences, dict):
                seq = sequences[chrom][lo:hi].T
            else:
                seq = eu.pp.ohe_seq(sequences[chrom][lo:hi].seq.upper())

            seqs.append(seq)
    finally:
        for handle in opened:
            handle.close()

    seqs = torch.tensor(np.array(seqs), dtype=torch.float32)

    if signals is not None:
        signals_ = torch.tensor(np.array(signals_), dtype=torch.float32)

        # Keep only examples whose total counts lie strictly inside the bounds
        idxs = torch.ones(signals_.shape[0], dtype=torch.bool)
        if max_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) < max_counts)
        if min_counts is not None:
            idxs = (idxs) & (signals_.sum(dim=(1, 2)) > min_counts)

        if controls is not None:
            controls_ = torch.tensor(np.array(controls_), dtype=torch.float32)
            return seqs[idxs], signals_[idxs], controls_[idxs]

        return seqs[idxs], signals_[idxs]
    else:
        if controls is not None:
            controls_ = torch.tensor(np.array(controls_), dtype=torch.float32)
            return seqs, controls_

        return seqs

In [None]:
# Run the scratch extract_loci on the configured peaks, genome, signal and control tracks
# with no jitter slack, returning (sequences, signals, controls) tensors.
X_toy, y_toy, X_ctl_toy = extract_loci(peaks, seqs, signals, controls, max_jitter=0)

In [None]:
X_toy.shape, y_toy.shape, X_ctl_toy.shape

(torch.Size([100, 4, 2114]),
 torch.Size([100, 2, 1000]),
 torch.Size([100, 2, 2114]))

In [None]:
from torch.utils.data import Dataset, DataLoader
class ProfileDataset(Dataset):
    """Map-style dataset yielding BPNet training examples.

    Each item is a window cut out of pre-extracted sequence, signal and
    (optionally) control tensors. The stored tensors are expected to be
    wider than the requested windows by twice the maximum jitter, so a
    random offset can be drawn per item and the same offset applied to all
    tensors. For instance, with a window of 1000 and a maximum jitter of
    128 the stored data has length 1256 and the window starts at an offset
    drawn from [0, 256). Optionally, an item is reverse-complemented with
    probability one half. Wrap the dataset in a PyTorch DataLoader (see
    `to_dataloader`) to obtain batches.

    Parameters
    ----------
    sequences: torch.tensor, shape=(n, 4, in_window+2*max_jitter)
        One-hot encoded sequences for `n` examples.
    signals: torch.tensor, shape=(n, t, out_window+2*max_jitter)
        The targets to predict, usually counts, for `t` output tasks.
    controls: torch.tensor or None, optional
        Control tracks given to the model as an additional input. They are
        sliced with `in_window`, i.e. to the same width as the sequences.
        If None, items do not contain controls. Default is None.
    in_window: int, optional
        Width of the sequence (and control) window. Default is 2114.
    out_window: int, optional
        Width of the signal window. Default is 1000.
    max_jitter: int, optional
        Maximum jitter; offsets are drawn from [0, 2*max_jitter). A value
        of 0 disables jitter. Default is 0.
    reverse_complement: bool, optional
        Whether to randomly reverse-complement items. Default is False.
    random_state: int or None, optional
        Seed for the internal np.random.RandomState. Default is None.
    """

    def __init__(
        self,
        sequences,
        signals,
        controls=None,
        in_window=2114,
        out_window=1000,
        max_jitter=0,
        reverse_complement=False,
        random_state=None
    ):
        # Data tensors
        self.sequences = sequences
        self.signals = signals
        self.controls = controls

        # Window geometry
        self.in_window = in_window
        self.out_window = out_window
        self.max_jitter = max_jitter

        # Augmentation settings
        self.reverse_complement = reverse_complement
        self.random_state = np.random.RandomState(random_state)

    def __len__(self):
        """Number of examples, taken from the sequence tensor."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return (X, X_ctl, y) when controls are present, otherwise (X, y)."""
        has_controls = self.controls is not None

        # Draw the jitter offset first so the random stream is consumed in a fixed order
        if self.max_jitter != 0:
            offset = self.random_state.randint(self.max_jitter * 2)
        else:
            offset = 0

        in_stop = offset + self.in_window
        out_stop = offset + self.out_window

        X = self.sequences[idx][:, offset:in_stop]
        y = self.signals[idx][:, offset:out_stop]
        X_ctl = self.controls[idx][:, offset:in_stop] if has_controls else None

        # Flipping both axes reverses the positions and the channel/strand order
        if self.reverse_complement and self.random_state.choice(2) == 1:
            both_axes = [0, 1]
            X = torch.flip(X, both_axes)
            y = torch.flip(y, both_axes)
            if has_controls:
                X_ctl = torch.flip(X_ctl, both_axes)

        if has_controls:
            return X, X_ctl, y
        return X, y

    def to_dataloader(
        self,
        batch_size=None,
        pin_memory=True,
        shuffle=True,
        num_workers=0,
        **kwargs
    ):
        """Wrap this dataset in a PyTorch DataLoader.

        Parameters
        ----------
        batch_size: int or None, optional
            Batch size; when None, `eu.settings.batch_size` is used.
        pin_memory: bool, optional
            Forwarded to the DataLoader. Default is True.
        shuffle: bool, optional
            Whether to shuffle the dataset. Default is True.
        num_workers: int, optional
            Number of loader worker processes. Default is 0.
        **kwargs:
            Any further keyword arguments for DataLoader.

        Returns
        -------
        torch.utils.data.DataLoader
        """
        if batch_size is None:
            batch_size = eu.settings.batch_size

        loader_options = dict(
            batch_size=batch_size,
            pin_memory=pin_memory,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        return DataLoader(self, **loader_options, **kwargs)

In [None]:
# Wrap the toy tensors with the scratch ProfileDataset using its default settings
# (no jitter, no reverse-complement augmentation).
toy_dataset = ProfileDataset(
    X_toy,
    y_toy,
    X_ctl_toy,
)

In [None]:
toy_dataset[0][0].shape, toy_dataset[0][1].shape, toy_dataset[0][2].shape

(torch.Size([4, 2114]), torch.Size([2, 2114]), torch.Size([2, 1000]))

In [None]:
# No batch_size is given, so to_dataloader falls back to eu.settings.batch_size;
# shuffling is disabled to keep the loci in their original order.
toy_dataloader = toy_dataset.to_dataloader(shuffle=False)

In [None]:
# Pull the first batch and show the shapes of its three tensors
# (sequence, control, signal -- the order returned by ProfileDataset.__getitem__).
batch = next(iter(toy_dataloader))
batch[0].shape, batch[1].shape, batch[2].shape

(torch.Size([100, 4, 2114]),
 torch.Size([100, 2, 2114]),
 torch.Size([100, 2, 1000]))

In [None]:
batch[0]

tensor([[[1., 1., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 1.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 1., 0., 1.],
         [0., 1., 0.,  ..., 0., 1., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 1., 0., 0.],
         [0., 1., 0.,  ..., 0., 1., 0.]],

        [[1., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 1., 0.,  ..., 0., 0

In [None]:
from bpnetlite.io import PeakGenerator

In [None]:
# Build bpnet-lite's PeakGenerator on the same peaks, genome, signal and control inputs,
# presumably as a reference to compare against the ProfileDataset loader above -- confirm.
bpnetlite_dataloader = PeakGenerator(
    peaks,
    seqs,
    signals,
    controls,
    max_jitter=0,
    batch_size=128,
    random_state=13
)   

In [None]:
# Pull the first bpnet-lite batch and show the shapes of its first three tensors.
bpnetlite_batch = next(iter(bpnetlite_dataloader))
bpnetlite_batch[0].shape, bpnetlite_batch[1].shape, bpnetlite_batch[2].shape

(torch.Size([100, 4, 2114]),
 torch.Size([100, 2, 2114]),
 torch.Size([100, 2, 1000]))

In [None]:
bpnetlite_batch[0]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 1.]],

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 1.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 1., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 1., 0., 0.],
         [0., 1., 0.,  ..., 0., 1., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [1., 1., 1.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 1.,  ..., 1., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 1.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0