# Testing `ResidualBind` model class

**Authorship:**
Adam Klie, *11/05/2022*
***
**Description:**
Notebook for testing out the custom `ResidualBind` model class.

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import eugene as eu

Global seed set to 13


In [2]:
from eugene.models.base import BasicConv1D
class ResidualModule(nn.Module):
    """Generates a PyTorch module with the residual binding architecture described in:

    Parameters
    ----------
    input_len : int
        Length of the input sequence

    """
    def __init__(
        self, 
        input_len, 
        channels, 
        conv_kernels, 
        conv_strides,
        dilations, 
        pool_kernels=None, 
        activation="relu", 
        pool_strides=None, 
        dropout_rates=0.0, 
        padding="same", 
        batchnorm=True
    ):
        super().__init__()
        self.module = BasicConv1D(
            input_len=input_len,
            channels=channels,
            conv_kernels=conv_kernels,
            conv_strides=conv_strides,
            pool_kernels=pool_kernels,
            activation=activation,
            pool_strides=pool_strides,
            dropout_rates=dropout_rates,
            dilations=dilations,
            padding=padding,
            batchnorm=batchnorm
        )

    def forward(self, x):
        x_fwd = self.module(x)
        return F.relu(x_fwd + x)

In [8]:
ResidualModule(
    input_len=100,
    channels=[96, 96, 96, 96],
    conv_kernels=[3, 3, 3],
    conv_strides=[1, 1, 1],
    dilations=[1, 2, 4],
    dropout_rates=0.1,
)

ResidualModule(
  (module): BasicConv1D(
    (module): Sequential(
      (0): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same)
      (1): ReLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (5): ReLU()
      (6): Dropout(p=0.1, inplace=False)
      (7): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (9): ReLU()
      (10): Dropout(p=0.1, inplace=False)
      (11): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

In [22]:
from eugene.models.base import BaseModel
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base._utils import GetFlattenDim
from eugene.models.base import BasicFullyConnectedModule

class ResidualBind(BaseModel):
    def __init__(
        self,
        input_len,
        output_dim,
        strand="ss",
        task="regression",
        aggr=None,
        conv_channels=[96],
        conv_kernel_size=[11],
        conv_stride_size=[1],
        conv_dilation_rate=[1],
        conv_padding="same",
        conv_activation="relu",
        conv_batchnorm=True,
        conv_dropout=0.1,
        residual_channels=[3, 3, 3],
        residual_kernel_size=[11, 11, 11],
        residual_stride_size=[1, 1, 1],
        residual_dilation_rate=[1, 1, 1],
        residual_padding="same",
        residual_activation="relu",
        residual_batchnorm=True,
        residual_dropout=0.1,
        pool_kernel_size=10,
        pool_dropout=0.2,
        fc_hidden_dims=[256],
        fc_activation="relu",
        fc_batchnorm=True,
        fc_dropout=0.0,
        **kwargs
    ):
        super().__init__(
            input_len, output_dim, strand=strand, task=task, aggr=aggr, **kwargs
        )
        if isinstance(conv_channels, int):
            conv_channels = [conv_channels]
        self.conv = BasicConv1D(
            input_len=input_len,
            channels=[4] + conv_channels,
            conv_kernels=conv_kernel_size,
            conv_strides=conv_stride_size,
            pool_kernels=None,
            activation=conv_activation,
            pool_strides=None,
            dropout_rates=conv_dropout,
            dilations=conv_dilation_rate,
            padding=conv_padding,
            batchnorm=conv_batchnorm
        )
        res_block_input_len = GetFlattenDim(self.conv.module, seq_len=input_len)
        self.residual_block = ResidualModule(
            input_len=res_block_input_len,
            channels=[self.conv.out_channels] + residual_channels,
            conv_kernels=residual_kernel_size,
            conv_strides=residual_stride_size,
            pool_kernels=None,
            activation=residual_activation,
            pool_strides=None,
            dropout_rates=residual_dropout,
            dilations=residual_dilation_rate,
            padding=residual_padding,
            batchnorm=residual_batchnorm
        )
        self.average_pool = nn.AvgPool1d(pool_kernel_size, stride=1)
        self.dropout = nn.Dropout(pool_dropout)
        self.flatten = nn.Flatten()
        self.fc = BasicFullyConnectedModule(
            input_dim=self.residual_block.module.out_channels*(res_block_input_len-pool_kernel_size+1),
            output_dim=output_dim,
            hidden_dims=fc_hidden_dims,
            activation=fc_activation,
            batchnorm=fc_batchnorm,
            dropout_rate=fc_dropout
        )

    def forward(self, x, x_rev):
        x = self.conv(x)
        x = self.residual_block(x)
        x = self.average_pool(x)
        x = self.dropout(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


In [23]:
model = ResidualBind(
    input_len=100,
    output_dim=1,
    conv_channels=[96],
    conv_kernel_size=[11],
    conv_stride_size=[1],
    conv_dilation_rate=[1],
    conv_padding="valid",
    conv_activation="relu",
    conv_batchnorm=True,
    conv_dropout=0.1,
    residual_channels=[96, 96, 96],
    residual_kernel_size=[3, 3, 3],
    residual_stride_size=[1, 1, 1],
    residual_dilation_rate=[1, 2, 4],
    residual_padding="same",
    residual_activation="relu",
    residual_batchnorm=True,
    residual_dropout=0.1,
    pool_dropout=0.2,
    fc_hidden_dims=[256],
    fc_activation="relu",
    fc_batchnorm=True,
    fc_dropout=0.5,
)
model

ResidualBind(
  (hp_metric): R2Score()
  (conv): BasicConv1D(
    (module): Sequential(
      (0): Conv1d(4, 96, kernel_size=(11,), stride=(1,), padding=valid)
      (1): ReLU()
      (2): Dropout(p=0.1, inplace=False)
      (3): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (residual_block): ResidualModule(
    (module): BasicConv1D(
      (module): Sequential(
        (0): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
        (5): ReLU()
        (6): Dropout(p=0.1, inplace=False)
        (7): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): Conv1d(96, 96, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
        (9): ReLU()
        

In [24]:
x = torch.randn(10, 4, 100)
model(x, x)

tensor([[ 0.3461],
        [-0.3445],
        [-0.2476],
        [-0.3241],
        [ 0.2127],
        [-0.0573],
        [ 0.3616],
        [-0.1408],
        [ 0.3614],
        [-0.1719]], grad_fn=<AddmmBackward0>)

In [25]:
sdata = eu.datasets.random1000()
eu.pp.ohe_seqs_sdata(sdata)
eu.pp.train_test_split_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added
SeqData object modified:
    seqs_annot:
        + train_val


In [26]:
eu.train.fit(model, sdata, target_keys="activity_0", epochs=1, batch_size=32)

Global seed set to 13
Missing logger folder: /workspaces/EUGENe/tests/notebooks/implement/models/eugene_logs/ssResidualBind_regression
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name           | Type                      | Params
-------------------------------------------------------------
0 | hp_metric      | R2Score                   | 0     
1 | conv           | BasicConv1D               | 4.5 K 
2 | residual_block | ResidualModule            | 83.8 K
3 | average_pool   | AvgPool1d                 | 0     
4 | dropout        | Dropout                   | 0     
5 | flatten        | Flatten                   | 0     
6 | fc             | BasicFullyConnectedModule | 2.0 M 
-------------------------------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.320     Total estimated model params size (MB)


Dropping 0 sequences with NaN targets.
No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]