# Testing `Basset` model class

**Authorship:**
Adam Klie, *11/06/2022*
***
**Description:**
Notebook for testing out the custom `Basset` model class.

In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import eugene as eu

Global seed set to 13


In [3]:
from eugene.models.base import BaseModel
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base._utils import GetFlattenDim
from eugene.models.base import BasicFullyConnectedModule, BasicConv1D


class Basset(BaseModel):
    """
    """
    def __init__(
        self, 
        input_len: int = 1000,
        output_dim = 1, 
        strand = "ss",
        task = "binary_classification",
        aggr = None,
        loss_fxn = "bce",
        conv_kwargs = {},
        fc_kwargs = {},
        **kwargs
    ):
        super().__init__(
            input_len, 
            output_dim, 
            strand=strand, 
            task=task, 
            aggr=aggr, 
            loss_fxn=loss_fxn, 
            **kwargs
        )
        self.conv_kwargs, self.fc_kwargs = self.kwarg_handler(conv_kwargs, fc_kwargs)
        self.convnet = BasicConv1D(
            input_len=input_len, 
            **self.conv_kwargs)
        self.fcn = BasicFullyConnectedModule(
            input_dim=self.convnet.flatten_dim, 
            output_dim=output_dim, 
            **self.fc_kwargs
        )

    def forward(self, x, x_rev_comp=None):
        x = self.convnet(x)
        x = x.view(x.size(0), self.convnet.flatten_dim)
        x = self.fcn(x)
        return x
        
    def kwarg_handler(self, conv_kwargs, fc_kwargs):
        """Sets default kwargs for conv and fc modules if not specified"""
        conv_kwargs.setdefault("channels", [4, 300, 200, 200])
        conv_kwargs.setdefault("conv_kernels", [19, 11, 7])
        conv_kwargs.setdefault("conv_strides", [1, 1, 1])
        conv_kwargs.setdefault("padding", [9, 5, 3])
        conv_kwargs.setdefault("pool_kernels", [3, 4, 4])
        conv_kwargs.setdefault("omit_final_pool", False)
        conv_kwargs.setdefault("dropout_rates", 0.0)
        conv_kwargs.setdefault("batchnorm", True)
        conv_kwargs.setdefault("activation", "relu")
        fc_kwargs.setdefault("hidden_dims", [1000, 164])
        fc_kwargs.setdefault("dropout_rate", 0.0)
        fc_kwargs.setdefault("batchnorm", True)
        fc_kwargs.setdefault("activation", "relu")
        return conv_kwargs, fc_kwargs

In [9]:
model = Basset(task="regression")

In [10]:
x = torch.randn(10, 4, 1000)
model(x).shape

torch.Size([10, 200, 20])


torch.Size([10, 1])

In [11]:
sdata = eu.dl.SeqData(seqs=eu.utils.random_seqs(1000, 1000))
sdata.make_names_unique()
sdata["activity"] = np.random.randn(1000)

In [12]:
eu.pp.ohe_seqs_sdata(sdata)
eu.pp.train_test_split_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added
SeqData object modified:
    seqs_annot:
        + train_val


In [13]:
eu.train.fit(model, sdata, target_keys="activity", epochs=1, batch_size=32)

Global seed set to 13
Missing logger folder: /workspaces/EUGENe/tests/notebooks/implement/models/eugene_logs/ssBasset_regression
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type                      | Params
--------------------------------------------------------
0 | hp_metric | R2Score                   | 0     
1 | convnet   | BasicConv1D               | 964 K 
2 | fcn       | BasicFullyConnectedModule | 4.2 M 
--------------------------------------------------------
5.1 M     Trainable params
0         Non-trainable params
5.1 M     Total params
20.530    Total estimated model params size (MB)


Dropping 0 sequences with NaN targets.
No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13


torch.Size([32, 200, 20])
torch.Size([32, 200, 20])


  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])


Validating: 0it [00:00, ?it/s]

torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([32, 200, 20])
torch.Size([8, 200, 20])


In [20]:
from eugene.models.base import BaseModel
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base._utils import GetFlattenDim
from eugene.models.base import BasicFullyConnectedModule, BasicConv1D

class FactorizedBasset(BaseModel):
	def __init__(
		self, 
		input_len: int = 1000,
		output_dim = 1, 
		strand = "ss",
		task = "binary_classification",
		aggr = None,
		loss_fxn = "bce",
		conv1_kwargs = {},
		conv2_kwargs = {},
		conv3_kwargs = {},
		maxpool_kernels = None,
		fc_kwargs = {},
		**kwargs
	):
		super().__init__(
			input_len, 
			output_dim, 
			strand=strand, 
			task=task, 
			aggr=aggr, 
			loss_fxn=loss_fxn, 
			**kwargs
		)
		self.conv1_kwargs, self.conv2_kwargs, self.conv3_kwargs, self.maxpool_kernels, self.fc_kwargs = self.kwarg_handler(
			conv1_kwargs, 
			conv2_kwargs, 
			conv3_kwargs, 
			maxpool_kernels, 
			fc_kwargs
		)
		self.convnet1 = BasicConv1D(
			input_len=input_len, 
			**self.conv1_kwargs
		)
		self.maxpool1 = nn.MaxPool1d(self.maxpool_kernels[0])
		self.out1 = self.convnet1.flatten_dim/self.convnet1.out_channels // self.maxpool_kernels[0]
		self.convnet2 = BasicConv1D(
			input_len=self.out1,
			**self.conv2_kwargs
		)
		self.maxpool2 = nn.MaxPool1d(self.maxpool_kernels[1])
		self.out2 = self.convnet2.flatten_dim/self.convnet2.out_channels // self.maxpool_kernels[1]
		self.convnet3 = BasicConv1D(
			input_len=self.out2,
			**self.conv3_kwargs
		)
		self.maxpool3 = nn.MaxPool1d(self.maxpool_kernels[2])
		self.out3 = self.convnet3.flatten_dim/self.convnet3.out_channels // self.maxpool_kernels[2]
		self.fcnet_in = int(self.out3*self.convnet3.out_channels)
		self.fcnet = BasicFullyConnectedModule(
			input_dim=self.fcnet_in,
			output_dim=output_dim, 
			**self.fc_kwargs
		)

	def forward(self, x, x_rev_comp=None):
		x = self.convnet1(x)
		x = self.maxpool1(x)
		x = self.convnet2(x)
		x = self.maxpool2(x)
		x = self.convnet3(x)
		x = self.maxpool3(x)
		x = x.view(x.size(0), self.fcnet_in)
		x = self.fcnet(x)
		return x
        
	def kwarg_handler(self, conv1_kwargs, conv2_kwargs, conv3_kwargs, maxpool_kernels, fc_kwargs):
		"""Sets default kwargs for conv and fc modules if not specified"""
		conv1_kwargs.setdefault("channels", [4, 48, 64, 100, 150, 300])
		conv1_kwargs.setdefault("conv_kernels", [3, 3, 3, 7, 7])
		conv1_kwargs.setdefault("conv_strides", [1, 1, 1, 1, 1])
		conv1_kwargs.setdefault("padding", [1, 1, 1, 3, 3])
		conv1_kwargs.setdefault("pool_kernels", None)
		conv1_kwargs.setdefault("dropout_rates", 0.0)
		conv1_kwargs.setdefault("batchnorm", True)
		conv1_kwargs.setdefault("activation", "relu")
		conv2_kwargs.setdefault("channels", [300, 200, 200, 200])
		conv2_kwargs.setdefault("conv_kernels", [7, 3, 3])
		conv2_kwargs.setdefault("conv_strides", [1, 1, 1])
		conv2_kwargs.setdefault("padding", [3, 1, 1])
		conv2_kwargs.setdefault("pool_kernels", None)
		conv2_kwargs.setdefault("dropout_rates", 0.0)
		conv2_kwargs.setdefault("batchnorm", True)
		conv2_kwargs.setdefault("activation", "relu")
		conv3_kwargs.setdefault("channels", [200, 200])
		conv3_kwargs.setdefault("conv_kernels", [7])
		conv3_kwargs.setdefault("conv_strides", [1])
		conv3_kwargs.setdefault("padding", [3])
		conv3_kwargs.setdefault("pool_kernels", None)
		conv3_kwargs.setdefault("dropout_rates", 0.0)
		conv3_kwargs.setdefault("batchnorm", True)
		conv3_kwargs.setdefault("activation", "relu")
		maxpool_kernels = [3, 4, 4] if maxpool_kernels is None else maxpool_kernels
		fc_kwargs.setdefault("hidden_dims", [1000, 164])
		fc_kwargs.setdefault("dropout_rate", 0.0)
		fc_kwargs.setdefault("batchnorm", True)
		fc_kwargs.setdefault("activation", "relu")
		return conv1_kwargs, conv2_kwargs, conv3_kwargs, maxpool_kernels, fc_kwargs

In [23]:
model = FactorizedBasset(input_len=1000, task="regression")

In [25]:
x = torch.randn(10, 4, 1000)
model(x)

tensor([[ 0.1525],
        [-0.0843],
        [-0.1764],
        [ 0.1329],
        [-0.6349],
        [ 1.2232],
        [ 0.1299],
        [-0.6046],
        [ 0.1690],
        [-0.8385]], grad_fn=<AddmmBackward0>)

In [26]:
sdata = eu.dl.SeqData(seqs=eu.utils.random_seqs(1000, 1000))
sdata.make_names_unique()
sdata["activity"] = np.random.randn(1000)

In [27]:
eu.pp.ohe_seqs_sdata(sdata)
eu.pp.train_test_split_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added
SeqData object modified:
    seqs_annot:
        + train_val


In [28]:
eu.train.fit(model, sdata, target_keys="activity", epochs=1, batch_size=32)

Global seed set to 13
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type                      | Params
--------------------------------------------------------
0 | hp_metric | R2Score                   | 0     
1 | convnet1  | BasicConv1D               | 450 K 
2 | maxpool1  | MaxPool1d                 | 0     
3 | convnet2  | BasicConv1D               | 661 K 
4 | maxpool2  | MaxPool1d                 | 0     
5 | convnet3  | BasicConv1D               | 280 K 
6 | maxpool3  | MaxPool1d                 | 0     
7 | fcnet     | BasicFullyConnectedModule | 4.2 M 
--------------------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.244    Total estimated model params size (MB)


Dropping 0 sequences with NaN targets.
No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

: 

---