# Testing `DeepSTARR` model class

**Authorship:**
Adam Klie, *11/05/2022*
***
**Description:**
Notebook for testing out the custom `DeepSTARR` model class.

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

# Autoreload extension: load it only when the extension manager does not
# already list it, then enable mode 2 so imported modules are reloaded
# before each cell executes (edits to library code take effect without
# restarting the kernel).
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import eugene as eu

Global seed set to 13


In [25]:
from eugene.models.base import BaseModel
import torch.nn as nn
import torch.nn.functional as F
from eugene.models.base._utils import GetFlattenDim
from eugene.models.base import BasicFullyConnectedModule, BasicConv1D


class DeepSTARR(BaseModel):
    """DeepSTARR model from de Almeida et al., 2022;
        see <https://www.nature.com/articles/s41588-022-01048-5>

    A ``BasicConv1D`` feature extractor followed by a
    ``BasicFullyConnectedModule`` applied to the flattened convolutional
    output.

    Parameters
    ----------
    input_len : int
        Length of the input sequences. Defaults to 249.
    output_dim : int
        Number of outputs produced by the fully connected module.
        Defaults to 2.
    strand : str
        Forwarded to ``BaseModel``. Defaults to "ss".
    task : str
        Forwarded to ``BaseModel``. Defaults to "regression".
    aggr : optional
        Forwarded to ``BaseModel``. Defaults to None.
    loss_fxn : str
        Forwarded to ``BaseModel``. Defaults to "mse".
    conv_kwargs : dict, optional
        Keyword arguments for ``BasicConv1D``. Keys that are not supplied
        are filled with the defaults listed in ``kwarg_handler``. The dict
        that is passed in is copied and never modified.
    fc_kwargs : dict, optional
        Keyword arguments for ``BasicFullyConnectedModule``. Handled the
        same way as ``conv_kwargs``.
    **kwargs
        Any additional keyword arguments are forwarded to ``BaseModel``.
    """
    def __init__(
        self,
        input_len: int = 249,
        output_dim = 2,
        strand = "ss",
        task = "regression",
        aggr = None,
        loss_fxn = "mse",
        conv_kwargs = None,
        fc_kwargs = None,
        **kwargs
    ):
        super().__init__(
            input_len,
            output_dim,
            strand=strand,
            task=task,
            aggr=aggr,
            loss_fxn=loss_fxn,
            **kwargs
        )
        # kwarg_handler returns fresh dicts, so neither a shared default nor
        # the caller's own dict is mutated when the defaults are filled in.
        self.conv_kwargs, self.fc_kwargs = self.kwarg_handler(conv_kwargs, fc_kwargs)
        self.convnet = BasicConv1D(
            input_len=input_len,
            **self.conv_kwargs)
        self.fcn = BasicFullyConnectedModule(
            input_dim=self.convnet.flatten_dim,
            output_dim=output_dim,
            **self.fc_kwargs
        )

    def forward(self, x, x_rev_comp=None):
        """Run the convolutional stack, flatten, and apply the dense layers.

        Parameters
        ----------
        x : torch.Tensor
            Input batch; the first dimension is treated as the batch size.
        x_rev_comp : torch.Tensor, optional
            Accepted for signature compatibility but not used by this method.

        Returns
        -------
        torch.Tensor
            Output of the fully connected module.
        """
        x = self.convnet(x)
        # Collapse everything after the batch dimension into one vector per
        # sample; flatten_dim is the size reported by the conv module.
        x = x.view(x.size(0), self.convnet.flatten_dim)
        x = self.fcn(x)
        return x

    def kwarg_handler(self, conv_kwargs, fc_kwargs):
        """Sets default kwargs for conv and fc modules if not specified

        Returns new dicts in which user-supplied entries override the
        defaults. The input dicts (or ``None``) are left untouched.
        """
        # NOTE(review): the published DeepSTARR architecture is usually
        # described with 256 filters in the first conv layer; 246 may be a
        # typo -- confirm before changing, since it alters parameter counts.
        conv_defaults = {
            "channels": [4, 246, 60, 60, 120],
            "conv_kernels": [7, 3, 5, 3],
            "conv_strides": [1, 1, 1, 1],
            "padding": "same",
            "pool_kernels": [2, 2, 2, 2],
            "omit_final_pool": False,
            "dropout_rates": 0.0,
            "batchnorm": True,
        }
        fc_defaults = {
            "hidden_dims": [256, 256],
            "dropout_rate": 0.4,
            "batchnorm": True,
        }
        # Later entries win, so anything the caller provided takes precedence.
        merged_conv = {**conv_defaults, **(conv_kwargs or {})}
        merged_fc = {**fc_defaults, **(fc_kwargs or {})}
        return merged_conv, merged_fc

In [31]:
# Instantiate the model for length-100 inputs with two outputs;
# all other settings fall back to the class defaults.
model = DeepSTARR(
    input_len=100,
    output_dim=2,
)

In [32]:
# Random stand-in batch of shape (10, 4, 100); the last axis matches the
# input_len=100 the model was built with.
x = torch.randn((10, 4, 100))

In [36]:
# Smoke test: one output row per input sample and two output columns
# (output_dim=2). Assert it so a regression fails loudly instead of relying
# on someone reading the printed shape.
out_shape = model(x, x_rev_comp=x).shape
assert out_shape == (x.size(0), 2), f"unexpected output shape: {tuple(out_shape)}"
out_shape

torch.Size([10, 2])

In [33]:
# Load the `random1000` dataset shipped with eugene (the progress bar in the
# cell output reports 1000 sequences).
sdata = eu.datasets.random1000()
# One-hot encode the sequences; the cell output reports that `ohe_seqs` is
# added to `sdata` in place.
eu.pp.ohe_seqs_sdata(sdata)
# Annotate the train/validation split; the cell output reports a new
# `train_val` entry under `seqs_annot`, again modifying `sdata` in place.
eu.pp.train_test_split_sdata(sdata)

One-hot encoding sequences:   0%|          | 0/1000 [00:00<?, ?it/s]

SeqData object modified:
	ohe_seqs: None -> 1000 ohe_seqs added
SeqData object modified:
    seqs_annot:
        + train_val


In [37]:
# Single-epoch training run as an end-to-end check that the model works
# with eugene's fit routine on both activity targets.
eu.train.fit(
    model,
    sdata,
    target_keys=["activity_0", "activity_1"],
    epochs=1,
    batch_size=32,
)

Global seed set to 13
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type                      | Params
--------------------------------------------------------
0 | hp_metric | R2Score                   | 0     
1 | convnet   | BasicConv1D               | 92.2 K
2 | fcn       | BasicFullyConnectedModule | 251 K 
--------------------------------------------------------
344 K     Trainable params
0         Non-trainable params
344 K     Total params
1.377     Total estimated model params size (MB)


Dropping 0 sequences with NaN targets.
No transforms given, assuming just need to tensorize.
No transforms given, assuming just need to tensorize.


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
Global seed set to 13
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]