# Testing `ResidualBind` model class

**Authorship:**
Adam Klie, *11/05/2022*
***
**Description:**
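Notebook for testing out the custom `ResidualBind` model class. Note that the cells below also define two reference `BPNet` implementations (taken from bpnetlite and yuzu) and instantiate a `Basset` model.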

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Autoreload extension: load it only if it is not already loaded (keeps the
# cell safe to re-run), then use mode 2 so edits to imported modules are
# picked up without restarting the kernel.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import eugene as eu

Global seed set to 13


In [None]:
from captum.attr import DeepLiftShap

In [3]:
# bpnetlite
import time 
import numpy
import torch

from .losses import MNLLLoss
from .losses import log1pMSELoss

from .performance import pearson_corr
from .performance import calculate_performance_measures

# Enable cuDNN's convolution-algorithm auto-tuner. This is a process-wide
# setting and only matters when running on a CUDA device.
# NOTE(review): the training log further down reports "GPU available: False",
# so this presumably has no effect in this notebook run.
torch.backends.cudnn.benchmark = True

class BPNet(torch.nn.Module):
	"""A basic BPNet model with stranded profile and total count prediction.

	This is a reference implementation for BPNet. The model takes in
	one-hot encoded sequence, runs it through:

	(1) a single wide convolution operation
	THEN
	(2) a user-defined number of dilated residual convolutions
	THEN
	(3a) profile predictions done using a very wide convolution layer
	that also takes in stranded control tracks
	AND
	(3b) total count prediction done using an average pooling on the output
	from 2 followed by concatenation with the log1p of the sum of the
	stranded control tracks and then run through a dense layer.

	This implementation differs from the original BPNet implementation in
	four ways:

	(1) The model concatenates stranded control tracks for profile
	prediction as opposed to adding the two strands together and also then
	smoothing that track
	(2) The control input for the count prediction task is the log1p of
	the strand-wise sum of the control tracks, as opposed to the raw
	counts themselves.
	(3) A single log softmax is applied across both strands such that
	the logsumexp of both strands together is 0. Put another way, the
	two strands are concatenated together, a log softmax is applied,
	and the MNLL loss is calculated on the concatenation.
	(4) The count prediction task is predicting the total counts across
	both strands. The counts are then distributed across strands according
	to the single log softmax from 3.

	Note that the log softmax and the losses mentioned above are not part
	of this class; `forward` returns the raw profile convolution output
	and the raw count prediction.

	Parameters
	----------
	n_filters: int, optional
		The number of filters to use per convolution. Default is 64.
	n_layers: int, optional
		The number of dilated residual layers to include in the model.
		Default is 8.
	n_outputs: int, optional
		The number of profile outputs from the model. Generally either 1 or 2
		depending on if the data is unstranded or stranded. Default is 2.
	n_control_tracks: int, optional
		The number of control tracks concatenated to the hidden
		representation before the profile convolution. When greater than 0,
		one extra input feature is also reserved in the count head for the
		summed control signal. Default is 2.
	alpha: float, optional
		The weight to put on the count loss. Only stored on the instance;
		it is not used by `forward`. Default is 1.
	profile_output_bias: bool, optional
		Whether the profile output convolution has a bias term.
		Default is True.
	count_output_bias: bool, optional
		Whether the count output linear layer has a bias term.
		Default is True.
	name: str or None, optional
		The name to save the model to during training. If None (or empty),
		"bpnet.{n_filters}.{n_layers}" is used.
	trimming: int or None, optional
		The amount to trim from both sides of the input window to get the
		output window. This value is removed from both sides, so the total
		number of positions removed is 2*trimming. If None, 2 ** n_layers
		is used. An explicit 0 is respected and disables trimming.
	"""

	def __init__(self, n_filters=64, n_layers=8, n_outputs=2,
		n_control_tracks=2, alpha=1, profile_output_bias=True,
		count_output_bias=True, name=None, trimming=None):
		super(BPNet, self).__init__()
		self.n_filters = n_filters
		self.n_layers = n_layers
		self.n_outputs = n_outputs
		self.n_control_tracks = n_control_tracks

		self.alpha = alpha
		self.name = name or "bpnet.{}.{}".format(n_filters, n_layers)
		# Compare against None explicitly: `trimming or ...` would silently
		# replace a deliberate trimming=0 with the default.
		self.trimming = 2 ** n_layers if trimming is None else trimming

		self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
		# Activations are registered as submodules (one per use) rather than
		# being created inside forward, so that module hooks can reach them.
		# They hold no parameters, so the state_dict is unaffected.
		self.irelu = torch.nn.ReLU()

		# Dilation doubles every layer (2, 4, ..., 2**n_layers); padding equal
		# to the dilation keeps the sequence length unchanged for kernel 3.
		self.rconvs = torch.nn.ModuleList([
			torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i,
				dilation=2**i) for i in range(1, self.n_layers+1)
		])
		self.rrelus = torch.nn.ModuleList([
			torch.nn.ReLU() for _ in range(self.n_layers)
		])

		self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs, kernel_size=75,
			padding=37, bias=profile_output_bias)

		# The count head sees a single scalar summarising all control tracks.
		n_count_control = 1 if n_control_tracks > 0 else 0
		self.linear = torch.nn.Linear(n_filters+n_count_control, 1,
			bias=count_output_bias)

	def forward(self, X, X_ctl=None):
		"""A forward pass of the model.

		This method takes in a nucleotide sequence X and, optionally, a
		corresponding per-position control track, and makes predictions for
		the profile and for the counts. The per-locus control value used by
		the count head is computed here as log(sum(X_ctl) + 1) over the
		count window.

		Parameters
		----------
		X: torch.tensor, shape=(batch_size, 4, sequence_length)
			The one-hot encoded batch of sequences.
		X_ctl: torch.tensor or None, shape=(batch_size, n_control_tracks, sequence_length)
			A value representing the signal of the control at each position in the
			sequence. The layers are sized for the concatenated input, so this
			must be given when the model was built with n_control_tracks > 0
			and omitted when it was built with n_control_tracks=0.

		Returns
		-------
		y_profile: torch.tensor, shape=(batch_size, n_outputs, out_length)
			The profile predictions, where out_length is
			sequence_length - 2*trimming.
		y_counts: torch.tensor, shape=(batch_size, 1)
			The count prediction for each example.
		"""

		seq_len = X.shape[2]
		start, end = self.trimming, seq_len - self.trimming

		X = self.irelu(self.iconv(X))
		for rconv, rrelu in zip(self.rconvs, self.rrelus):
			# Residual connection around each dilated convolution.
			X = torch.add(X, rrelu(rconv(X)))

		if X_ctl is None:
			X_w_ctl = X
		else:
			X_w_ctl = torch.cat([X, X_ctl], dim=1)

		y_profile = self.fconv(X_w_ctl)[:, :, start:end]

		# counts prediction
		# The count window is the output window widened by the profile
		# convolution's padding on each side. It is clamped to the sequence,
		# because a negative slice start (trimming < padding) would wrap
		# around and pool over the wrong positions.
		flank = self.fconv.padding[0]
		lo = max(start - flank, 0)
		hi = min(end + flank, seq_len)

		X = torch.mean(X[:, :, lo:hi], dim=2)

		if X_ctl is not None:
			# Total control signal over all tracks and positions in the window.
			X_ctl = torch.sum(X_ctl[:, :, lo:hi], dim=(1, 2))
			X_ctl = X_ctl.unsqueeze(-1)
			X = torch.cat([X, torch.log(X_ctl+1)], dim=-1)

		y_counts = self.linear(X).reshape(X.shape[0], 1)
		return y_profile, y_counts

In [None]:
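# yuzu variant of BPNet. NOTE: this class has the same name as the bpnetlite BPNet defined in the cell above, so running this cell shadows that definition.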
class BPNet(torch.nn.Module):
	"""A small, inference-only BPNet-style network with a single output channel.

	The model applies one wide convolution, three dilated convolutions
	(dilations 2, 4 and 8, each followed by a ReLU and without residual
	connections), and a final wide convolution producing one channel.
	All convolutions are padded so that, for odd kernel sizes, the output
	has the same length as the input.

	Deeper dilated layers (dilations 16 through 128) and a trailing log
	softmax were present in earlier drafts as commented-out code and are
	intentionally not part of the model.

	Parameters
	----------
	n_inputs: int
		The number of input channels (e.g. 4 for one-hot encoded DNA).
	n_filters: int, optional
		The number of filters per convolution. Default is 64.
	kernel_size: int, optional
		The kernel size of the first convolution. Padding is set to
		kernel_size // 2, which preserves the sequence length for odd
		kernel sizes. Default is 21.
	seq_len: int or None, optional
		Accepted for interface compatibility; currently unused.
	n_layers: int, optional
		Accepted for interface compatibility; currently unused. The model
		always builds exactly three dilated layers.
	random_state: int, optional
		Seed passed to torch.manual_seed before the layers are created.
		Note that this reseeds torch's global RNG. Default is 0.
	"""

	def __init__(self, n_inputs, n_filters=64, kernel_size=21, seq_len=None, n_layers=4, random_state=0):
		super(BPNet, self).__init__()
		# Seeds the global torch RNG so weight initialisation is reproducible.
		torch.manual_seed(random_state)

		# Honour the kernel_size argument instead of a hard-coded 21/10 pair;
		# the default (21 -> padding 10) is unchanged.
		self.iconv = torch.nn.Conv1d(n_inputs, n_filters, kernel_size=kernel_size,
			padding=kernel_size // 2)
		self.irelu = torch.nn.ReLU()

		# Padding equal to the dilation keeps the length fixed for kernel 3.
		self.dconv1 = torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2, dilation=2)
		self.drelu1 = torch.nn.ReLU()

		self.dconv2 = torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=4, dilation=4)
		self.drelu2 = torch.nn.ReLU()

		self.dconv3 = torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=8, dilation=8)
		self.drelu3 = torch.nn.ReLU()

		self.fconv = torch.nn.Conv1d(n_filters, 1, kernel_size=75, padding=37)

	def forward(self, X):
		"""Run the network on a batch without tracking gradients.

		Parameters
		----------
		X: torch.tensor, shape=(batch_size, n_inputs, sequence_length)
			The input batch.

		Returns
		-------
		torch.tensor, shape=(batch_size, 1, sequence_length)
			The raw output of the final convolution (sequence length is
			preserved for odd kernel sizes). Computed under torch.no_grad(),
			so the result is detached from the autograd graph and the model
			cannot be trained through this method.
		"""
		with torch.no_grad():
			X = self.irelu(self.iconv(X))

			X = self.drelu1(self.dconv1(X))
			X = self.drelu2(self.dconv2(X))
			X = self.drelu3(self.dconv3(X))

			X = self.fconv(X)
			return X

In [4]:
# NOTE(review): `Basset` is not defined or imported in any cell visible in this
# notebook, so this line presumably relies on leftover kernel state and would
# raise NameError on a fresh "Restart & Run All" — confirm where it should be
# imported from. The argument 2 is presumably the output dimension (the output
# of the next cell has two columns) — TODO confirm.
model = Basset(2)



In [5]:
# Smoke test: push a random (non one-hot) tensor through the model and display
# the result.
# assumes the layout is (batch=10, channels=4, sequence_length=100) — TODO confirm
# against the input shape the model expects.
x = torch.randn(10, 4, 100)
model(x)

tensor([[0.5405, 0.4398],
        [0.6291, 0.5908],
        [0.5014, 0.6883],
        [0.3922, 0.7281],
        [0.5423, 0.5404],
        [0.4932, 0.5351],
        [0.5355, 0.4664],
        [0.4262, 0.4420],
        [0.4865, 0.6399],
        [0.4958, 0.5314]], grad_fn=<SigmoidBackward0>)

In [25]:
# Load the bundled random1000 dataset as a SeqData object.
sdata = eu.datasets.random1000()
# One-hot encode the sequences; per the output below this modifies `sdata`
# in place by adding `ohe_seqs`.
eu.pp.ohe_seqs_sdata(sdata)
# Mark the train/validation split; per the output below this adds a
# `train_val` entry to `sdata.seqs_annot` in place.
eu.pp.train_test_split_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added
SeqData object modified:
    seqs_annot:
        + train_val


In [26]:
# Train `model` for a single epoch on the "activity_0" target as a quick
# end-to-end check of the training loop.
# NOTE(review): the captured log below refers to "ssResidualBind_regression",
# which does not match the `Basset` model created above — the output looks
# stale; re-run from a fresh kernel to confirm.
eu.train.fit(model, sdata, target_keys="activity_0", epochs=1, batch_size=32)

Global seed set to 13
Missing logger folder: /workspaces/EUGENe/tests/notebooks/implement/models/eugene_logs/ssResidualBind_regression
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name           | Type                      | Params
-------------------------------------------------------------
0 | hp_metric      | R2Score                   | 0     
1 | conv           | BasicConv1D               | 4.5 K 
2 | residual_block | ResidualModule            | 83.8 K
3 | average_pool   | AvgPool1d                 | 0     
4 | dropout        | Dropout                   | 0     
5 | flatten        | Flatten                   | 0     
6 | fc             | BasicFullyConnectedModule | 2.0 M 
-------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.320     Total estimated model params size (MB)


Dropping 0 sequences with NaN targets.
No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]