# Testing CLAIM Modules

**Authorship:**
Adam Klie, *03/19/2022*
***
**Description:**
Notebook for testing out CLAIM modules for building EUGENE architectures

<div class="alert alert-block alert-warning">
<b>TODOs</b>:
<ul>
    <li><b></b></li>
    </ul>
</div>

In [1]:
import numpy as np
import pandas as pd
import torch

# Autoreload extension: load it only if this kernel has not loaded it yet,
# then enable mode 2 so imported modules are reloaded before each cell runs.
# This lets edits to the module under test show up without a kernel restart.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
%autoreload 2

import claim.modules as cm

# DeepSea module
<div class="alert alert-info" role="alert">
  <b>Test the basic functionality of a DeepSea module</b>
</div>

In [2]:
# Build a small DeepSea-style convolutional stack for length-66 inputs.
# NOTE(review): the recorded repr shows Dropout(p=0) even though
# dropout_rates=0.1 is passed here -- confirm DeepSeaModule honours this
# argument (or re-run the cell if the saved output is stale).
convnet = cm.DeepSeaModule(
    input_len=66,
    channels=[15, 5, 5],
    pool_kernels=[1, 1],
    dropout_rates=0.1,
)
convnet

DeepSeaModule(
  (module): Sequential(
    (0): Conv1d(4, 15, kernel_size=(8,), stride=(1,))
    (1): ReLU(inplace=True)
    (2): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0, inplace=False)
    (4): Conv1d(15, 5, kernel_size=(8,), stride=(1,))
    (5): ReLU(inplace=True)
    (6): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=False)
    (7): Dropout(p=0, inplace=False)
    (8): Conv1d(5, 5, kernel_size=(8,), stride=(1,))
    (9): ReLU(inplace=True)
    (10): Dropout(p=0, inplace=False)
  )
)

In [3]:
# Push a random batch (10 samples, 4 channels, length 66) through the module
# and check that the per-sample element count matches the advertised flatten_dim.
x = torch.randn(10, 4, 66)
out = convnet(x)
assert out[0].numel() == convnet.flatten_dim
out.shape

torch.Size([10, 5, 45])

# BasicConv1D module
<div class="alert alert-info" role="alert">
  <b>Generate a <code>conv_out</code> to pass to the recurrent module</b>
</div>

In [4]:
# Three convolutional layers (channels 4 -> 16 -> 32 -> 64) on length-66 input,
# with every pooling kernel and stride set to 1.
convnet = cm.BasicConv1D(
    input_len=66,
    channels=[4, 16, 32, 64],
    conv_kernels=[15, 5, 5],
    pool_kernels=[1, 1, 1],
    pool_strides=[1, 1, 1],
    dropout_rates=0.2,
)

In [6]:
# Produce conv_out from a random batch; later cells reuse it as input to the
# recurrent and fully connected modules.
x = torch.randn(10, 4, 66)
conv_out = convnet(x)
assert conv_out[0].numel() == convnet.flatten_dim
conv_out.shape

torch.Size([10, 64, 44])

# BasicRecurrent module
<div class="alert alert-info" role="alert">
  <b>Takes in input from the <code>convnet</code> and applies an RNN to it</b>
</div>

<div class="alert alert-info" role="alert">
  Start with a basic test of the architecture
</div>

In [7]:
# Stand-alone LSTM test: 4 input features per step, 32 output features,
# batch-first input layout.
rnn = cm.BasicRecurrent(
    input_dim=4,
    output_dim=32,
    unit_type="lstm",
    batch_first=True,
)

In [10]:
# Random batch laid out as (batch, seq_len, features) to match batch_first=True.
x = torch.randn(10, 66, 4)
# The second return value is unpacked as the final (hidden, cell) state pair
# -- presumably nn.LSTM's (h_n, c_n); confirm against BasicRecurrent.forward.
# (Previously named `seq, hidden`, which suggested a per-step sequence.)
out, (hidden, cell) = rnn(x)
# Per-step output keeps batch and sequence length; last axis is the RNN width.
assert out.shape == (x.shape[0], x.shape[1], rnn.out_dim)
print(out.shape, hidden.shape, cell.shape)

torch.Size([10, 66, 32]) torch.Size([1, 10, 32]) torch.Size([1, 10, 32])


<div class="alert alert-info" role="alert">
  Now pass in the conv_out
</div>

In [12]:
# Size the LSTM input to the number of channels produced by the conv stack
# (axis 1 of conv_out), unidirectional, batch-first.
rnn = cm.BasicRecurrent(
    input_dim=conv_out.size(1),
    output_dim=32,
    unit_type="lstm",
    bidirectional=False,
    batch_first=True,
)
rnn

BasicRecurrent(
  (module): LSTM(64, 32, batch_first=True)
)

In [13]:
# conv_out is (batch, channels, length); swap the last two axes so the
# batch-first LSTM (input_dim = channels) receives (batch, length, channels).
# The returned state is bound to a named variable rather than `_`, which
# IPython uses for the previous cell's output.
rnn_out, rnn_state = rnn(conv_out.transpose(1, 2))
assert rnn_out.shape[-1] == rnn.out_dim
# Show: full output shape, shape of the last time step, and the module's out_dim.
rnn_out.shape, rnn_out[:, -1, :].shape, rnn.out_dim

(torch.Size([10, 44, 32]), torch.Size([10, 32]), 32)

# Fully connected module
<div class="alert alert-info" role="alert">
  <b>Takes in input from anywhere (CNN, RNN, etc.)</b>
</div>


<div class="alert alert-info" role="alert">
  Start with a basic test of the architecture
</div>

In [15]:
# Fully connected stack: 100 inputs -> hidden sizes 25, 5, 5 -> 1 output,
# with ReLU, dropout 0.2 and batch norm enabled.
fcnet = cm.BasicFullyConnectedModule(
    100,
    1,
    [25, 5, 5],
    activation="relu",
    dropout_rate=0.2,
    batchnorm=True,
)
fcnet

BasicFullyConnectedModule(
  (module): Sequential(
    (0): Linear(in_features=100, out_features=25, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=25, out_features=5, bias=True)
    (5): ReLU(inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Linear(in_features=5, out_features=5, bias=True)
    (9): ReLU(inplace=True)
    (10): Dropout(p=0.2, inplace=False)
    (11): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Linear(in_features=5, out_features=1, bias=True)
  )
)

In [16]:
# Feed a random batch of 10 flat 100-feature vectors through the network and
# check that it emits a single output per sample.
x = torch.randn(10, 100)
out = fcnet(x)
assert out.size(1) == 1
out.shape

torch.Size([10, 1])

<div class="alert alert-info" role="alert">
  Test the conv_out too
</div>

In [18]:
# Same fully connected stack, but sized to accept the flattened conv output.
fcnet = cm.BasicFullyConnectedModule(
    convnet.flatten_dim,
    1,
    [25, 5, 5],
    activation="relu",
    dropout_rate=0.2,
    batchnorm=True,
)
fcnet

BasicFullyConnectedModule(
  (module): Sequential(
    (0): Linear(in_features=2816, out_features=25, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=25, out_features=5, bias=True)
    (5): ReLU(inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Linear(in_features=5, out_features=5, bias=True)
    (9): ReLU(inplace=True)
    (10): Dropout(p=0.2, inplace=False)
    (11): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Linear(in_features=5, out_features=1, bias=True)
  )
)

In [19]:
# Flatten conv_out to (batch, flatten_dim). The batch size is taken from
# conv_out itself; the original read it from `out`, a leftover of an earlier
# cell, which only worked because both batches happened to have 10 samples.
conv_flat = conv_out.view(conv_out.size(0), convnet.flatten_dim)
out = fcnet(conv_flat)
assert out.shape[1] == 1
out.shape

torch.Size([10, 1])

<div class="alert alert-info" role="alert">
  Finally, pass in the rnn_out
</div>

In [23]:
# Multi-class head on top of the RNN: input width is the RNN's out_dim,
# output width is the number of classes.
num_classes = 4
fcnet = cm.BasicFullyConnectedModule(
    rnn.out_dim,
    num_classes,
    [25, 5, 5],
    activation="relu",
    dropout_rate=0.2,
    batchnorm=True,
)
fcnet

BasicFullyConnectedModule(
  (module): Sequential(
    (0): Linear(in_features=32, out_features=25, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): BatchNorm1d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Linear(in_features=25, out_features=5, bias=True)
    (5): ReLU(inplace=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): Linear(in_features=5, out_features=5, bias=True)
    (9): ReLU(inplace=True)
    (10): Dropout(p=0.2, inplace=False)
    (11): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Linear(in_features=5, out_features=4, bias=True)
  )
)

In [24]:
# Classify from the last time step of the RNN output and check that one score
# per class comes back.
out = fcnet(rnn_out[:, -1, :])
assert out.size(1) == num_classes
out.shape

torch.Size([10, 4])