In [15]:
# Extensions
# Load autoreload only when it is not already registered, so re-running this cell is harmless.
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload
# Mode 2: reload all modules before executing each cell, so edits to the package are picked up live.
%autoreload 2

In [27]:
# Imports
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from eugene.models.base import _layers as layers

In [2]:
# Configuration
# Fix the RNG seed so that the random fixtures (and every test that draws random
# tensors afterwards) are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Shared dummy batch used by the demo cells below: 10 samples, each a 4 x 100 tensor.
ID = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
x = torch.randn(10, 4, 100)
x_rev_comp = torch.randn(10, 4, 100)
y = torch.randint(0, 2, (10, 1))  # binary labels, one per sample
batch = (ID, x, x_rev_comp, y)

In [4]:
# Helper functions
def move_module_to_cpu(module):
    """Move ``module`` and every nested submodule to the CPU, in place.

    ``nn.Module.to`` already applies itself to all registered submodules,
    parameters and buffers, so a single call is sufficient. Recursing over
    ``children()`` by hand only repeats the same move once per ancestor.

    Args:
        module: The ``torch.nn.Module`` to move.

    Returns:
        None. The module is modified in place.
    """
    module.to("cpu")

## Activations

### Identity

In [5]:
# Smoke check: build the Identity layer, run it on the shared `x`, and display the layer and output shape.
layer = layers.Identity(inplace=False)
layer_out = layer(x)
layer, layer_out.shape

(Identity(), torch.Size([10, 4, 100]))

In [6]:
def test_identity_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3).requires_grad_(True)
    
    # Initialize Identity layer
    identity_layer = layers.Identity()
    
    # Test forward pass
    output_tensor = identity_layer(input_tensor)
    assert torch.all(torch.eq(input_tensor, output_tensor))
    
    # Test backward pass
    output_tensor.sum().backward()
    assert torch.all(torch.eq(input_tensor.grad, torch.ones_like(input_tensor)))

In [7]:
# Run the Identity checks; a failed assertion raises in this cell.
test_identity_layer()

### Exponential

In [8]:
# Smoke check: build the Exponential layer, run it on the shared `x`, and display the layer and output shape.
layer = layers.Exponential(inplace=False)
layer_out = layer(x)
layer, layer_out.shape

(Exponential(), torch.Size([10, 4, 100]))

In [9]:
def test_exponential_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3).requires_grad_(True)
    
    # Initialize Exponential layer
    exponential_layer = layers.Exponential()
    
    # Test forward pass
    output_tensor = exponential_layer(input_tensor)
    assert torch.all(torch.eq(output_tensor, torch.exp(input_tensor)))
    
    # Test backward pass
    output_tensor.sum().backward()
    assert torch.all(torch.eq(input_tensor.grad, output_tensor))

In [10]:
# Run the Exponential checks; a failed assertion raises in this cell.
test_exponential_layer()

### GELU
Written because PyTorch's `nn.GELU` has no `inplace` option.

In [11]:
# Smoke check: build the GELU layer, run it on the shared `x`, and display the layer and output shape.
layer = layers.GELU(inplace=False)
layer_out = layer(x)
layer, layer_out.shape

(GELU(), torch.Size([10, 4, 100]))

### Sigmoid

In [12]:
# Smoke check: build the Sigmoid layer, run it on the shared `x`, and display the layer and output shape.
layer = layers.Sigmoid(inplace=False)
layer_out = layer(x)
layer, layer_out.shape

(Sigmoid(), torch.Size([10, 4, 100]))

### Softplus

In [13]:
# Smoke check: build the Softplus layer, run it on the shared `x`, and display the layer and output shape.
layer = layers.Softplus(inplace=False)
layer_out = layer(x)
layer, layer_out.shape

(Softplus(), torch.Size([10, 4, 100]))

## Convolutional

### BiConv1D

In [14]:
# Smoke check: BiConv1D with positional arguments (4, 8, 3) applied to the shared `x`;
# displays the layer and the output shape.
layer = layers.BiConv1D(4, 8, 3)
layer_out = layer(x)
layer, layer_out.shape

(BiConv1D(4, 8, kernel_size=3, stride=1, padding=same, dilation=1, groups=1, bias=True),
 torch.Size([10, 8, 100]))

In [15]:
def test_biconv1d_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3, 10)
    
    # Initialize BiConv1D layer
    biconv1d_layer = layers.BiConv1D(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding="same", dilation=1, groups=1, bias=True, dropout_rate=0.0)
    
    # Test forward pass
    output_tensor = biconv1d_layer(input_tensor)
    assert output_tensor.shape == (2, 4, 10)
    
    # Test backward pass
    output_tensor.sum().backward()
    assert biconv1d_layer.weight.grad.shape == (4, 3, 3)
    if biconv1d_layer.bias is not None:
        assert biconv1d_layer.bias.grad.shape == (4,)

In [16]:
# Run the BiConv1D checks; a failed assertion raises in this cell.
test_biconv1d_layer()

### InceptionConv1D

In [17]:
# Smoke check: InceptionConv1D with default branch settings, mapping 4 input channels to 16 output channels.
layer = layers.InceptionConv1D(in_channels=4, out_channels=16)
layer_out = layer(x)
layer, layer_out.shape

(InceptionConv1D(
   (conv1): Conv1d(4, 4, kernel_size=(1,), stride=(1,))
   (conv2_1): Conv1d(4, 4, kernel_size=(1,), stride=(1,))
   (conv2_2): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=same)
   (conv3_1): Conv1d(4, 4, kernel_size=(1,), stride=(1,))
   (conv3_2): Conv1d(4, 4, kernel_size=(5,), stride=(1,), padding=same)
   (maxpool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
   (conv_maxpool): Conv1d(4, 4, kernel_size=(1,), stride=(1,))
 ),
 torch.Size([10, 16, 100]))

In [18]:
def test_inception_conv1d_layer():
    """Check the output shape of InceptionConv1D and of each of its four branches.

    The layer is built with every branch producing ``out_channels // 4`` channels.
    Each branch is then applied by hand through the submodules ``conv1``,
    ``conv2_1``/``conv2_2``, ``conv3_1``/``conv3_2`` and ``maxpool``/``conv_maxpool``,
    and its output shape is checked. The branch channel counts must add up to
    ``out_channels``.
    """
    # Named `input_tensor` rather than `x` so the notebook-level `x` fixture is not shadowed.
    input_tensor = torch.randn(1, 3, 10)

    # Layer configuration: four branches with an equal share of the output channels.
    in_channels = 3
    out_channels = 16
    kernel_size2 = 3
    kernel_size3 = 5
    conv_maxpool_kernel_size = 3
    conv1_out_channels = out_channels // 4
    conv2_out_channels = out_channels // 4
    conv3_out_channels = out_channels // 4
    conv_maxpool_out_channels = out_channels // 4
    layer = layers.InceptionConv1D(
        in_channels=in_channels,
        out_channels=out_channels,
        conv1_out_channels=conv1_out_channels,
        conv2_out_channels=conv2_out_channels,
        kernel_size2=kernel_size2,
        conv3_out_channels=conv3_out_channels,
        kernel_size3=kernel_size3,
        conv_maxpool_kernel_size=conv_maxpool_kernel_size,
        conv_maxpool_out_channels=conv_maxpool_out_channels
    )

    # Full forward pass.
    output = layer(input_tensor)
    assert output.shape == (1, out_channels, 10)

    # Each branch, applied manually, must keep the length and give its own channel count.
    conv1_out = layer.conv1(input_tensor)
    assert conv1_out.shape == (1, conv1_out_channels, 10)

    conv2_out = layer.conv2_2(layer.conv2_1(input_tensor))
    assert conv2_out.shape == (1, conv2_out_channels, 10)

    conv3_out = layer.conv3_2(layer.conv3_1(input_tensor))
    assert conv3_out.shape == (1, conv3_out_channels, 10)

    conv_maxpool_out = layer.conv_maxpool(layer.maxpool(input_tensor))
    assert conv_maxpool_out.shape == (1, conv_maxpool_out_channels, 10)

    # The branch channel counts must sum to the requested total. This value was
    # previously computed but never asserted.
    expected_out_channels = conv1_out_channels + conv2_out_channels + conv3_out_channels + conv_maxpool_out_channels
    assert expected_out_channels == out_channels

    # The channel-wise concatenation of the branches has the same shape as the layer output.
    expected_output = torch.cat([conv1_out, conv2_out, conv3_out, conv_maxpool_out], dim=1)
    assert expected_output.shape == (1, expected_out_channels, 10)
    assert output.shape == expected_output.shape

In [19]:
# Run the InceptionConv1D checks; a failed assertion raises in this cell.
test_inception_conv1d_layer()

## Transformer

### MultiHeadAttention

In [21]:
# Smoke check: 2-head attention with head_dim=10. The shared `x` is (10, 4, 100), so it is
# transposed to (10, 100, 4) to put the size-4 axis last, matching input_dim=4. No mask is applied.
layer = layers.MultiHeadAttention(
    input_dim=4,
    head_dim=10,
    num_heads=2
)
layer_out = layer(x.transpose(1, 2), mask=None)
layer, layer_out.shape

(MultiHeadAttention(
   (qkv): Linear(in_features=4, out_features=60, bias=False)
   (softmax): Softmax(dim=-1)
   (dropout_layer): Dropout(p=0.0, inplace=False)
   (projection_layer): Sequential(
     (0): Linear(in_features=20, out_features=4, bias=True)
     (1): Dropout(p=0.0, inplace=False)
   )
 ),
 torch.Size([10, 100, 4]))

In [23]:
def test_multihead_attention_layer():
    """MultiHeadAttention must preserve the input shape and propagate gradients to its input."""
    tokens = torch.randn(2, 3, 4).requires_grad_(True)
    attention = layers.MultiHeadAttention(input_dim=4, head_dim=2, num_heads=2, dropout_rate=0.0)

    # Forward: the output has the same shape as the input.
    attended = attention(tokens)
    assert attended.shape == (2, 3, 4)

    # Backward: a gradient must reach the input tensor.
    attended.sum().backward()
    assert tokens.grad is not None

test_multihead_attention_layer()

In [24]:
def test_attention_mechanism():
    """Compare the layer's stored attention scores against hard-coded expected values.

    NOTE(review): this function is defined but never called in the visible notebook,
    so there is no evidence that it passes.
    NOTE(review): the layer is constructed with freshly initialised weights and no
    seed is set here; fixed expected values can only hold if the scores do not
    depend on those weights -- confirm.
    """
    # Test input tensor
    input_tensor = torch.arange(12).reshape(2, 3, 2).float()

    # Initialize MultiHeadAttention layer
    multihead_attention_layer = layers.MultiHeadAttention(input_dim=2, head_dim=1, num_heads=2, dropout_rate=0.0)

    # Test forward pass
    output_tensor = multihead_attention_layer(input_tensor)

    # Test attention mechanism
    # NOTE(review): assumes the layer exposes an `attention_scores` attribute after
    # the forward pass; the layer repr shown in this notebook does not list it -- confirm.
    attention_scores = multihead_attention_layer.attention_scores
    # NOTE(review): the expected tensor is (2, 3, 3) although the layer has num_heads=2;
    # confirm the shape of `attention_scores` (per-head vs. averaged).
    expected_scores = torch.tensor([[[1.0000, 0.0000, 0.0000],
                                      [0.7311, 0.2689, 0.0000],
                                      [0.5000, 0.5000, 0.0000]],
                                     [[1.0000, 0.0000, 0.0000],
                                      [0.7311, 0.2689, 0.0000],
                                      [0.5000, 0.5000, 0.0000]]])
    assert torch.allclose(attention_scores, expected_scores, rtol=1e-3, atol=1e-3)

def test_masking_functionality():
    """Compare the stored attention scores under a mask against hard-coded expected values.

    NOTE(review): this function is defined but never called in the visible notebook,
    so there is no evidence that it passes.
    NOTE(review): the layer is constructed with freshly initialised weights and no
    seed is set here; fixed expected values can only hold if the scores do not
    depend on those weights -- confirm.
    """
    # Test input tensor
    input_tensor = torch.arange(12).reshape(2, 3, 2).float()

    # Define mask
    # NOTE(review): the mask is (2, 3); whether 1 means "keep" or "mask out", and
    # whether it applies to query or key positions, is not visible here -- confirm.
    mask = torch.tensor([[1, 1, 0], [1, 0, 0]])

    # Initialize MultiHeadAttention layer
    multihead_attention_layer = layers.MultiHeadAttention(input_dim=2, head_dim=1, num_heads=2, dropout_rate=0.0)

    # Test forward pass
    output_tensor = multihead_attention_layer(input_tensor, mask=mask)

    # Test masking functionality
    # NOTE(review): assumes the layer exposes an `attention_scores` attribute -- confirm.
    attention_scores = multihead_attention_layer.attention_scores
    # NOTE(review): the expected rows are identical for both batch elements even though
    # their masks differ ([1, 1, 0] vs [1, 0, 0]); this looks inconsistent with any
    # single mask convention -- verify the expected values against the implementation.
    expected_scores = torch.tensor([[[1.0000, 0.0000, 0.0000],
                                      [0.7311, 0.2689, 0.0000],
                                      [0.0000, 0.0000, 0.0000]],
                                     [[1.0000, 0.0000, 0.0000],
                                      [0.7311, 0.2689, 0.0000],
                                      [0.0000, 0.0000, 0.0000]]])
    assert torch.allclose(attention_scores, expected_scores, rtol=1e-3, atol=1e-3)

def test_projection_layer():
    """Check that the attention output, with projection enabled, has the input's shape.

    The layer repr in this notebook shows the projection as an ``nn.Linear``
    followed by dropout. A linear map has no bounded output range, so asserting
    that all values lie strictly inside (-1, 1) would make the test fail
    randomly. Only the shape and finiteness of the output are checked.
    """
    input_tensor = torch.randn(2, 3, 4)

    # Initialize MultiHeadAttention layer with the output projection enabled.
    multihead_attention_layer = layers.MultiHeadAttention(input_dim=4, head_dim=2, num_heads=2, dropout_rate=0.0, need_projection=True)

    # Forward pass
    output_tensor = multihead_attention_layer(input_tensor)

    # The projection maps back to input_dim, so the shape matches the input.
    expected_shape = (2, 3, 4)
    assert output_tensor.shape == expected_shape
    # No NaN/inf values may be produced.
    assert torch.isfinite(output_tensor).all()

def test_dropout_layer():
    """Check MultiHeadAttention with a non-zero dropout rate.

    Dropout masks are random, so the number of exact zeros in a training-mode
    output is not a fixed quantity and cannot be asserted. The deterministic
    properties checked here are:

    * training mode preserves the input shape and gives finite values;
    * evaluation mode disables dropout, so two forward passes on the same
      input give identical outputs.
    """
    input_tensor = torch.randn(2, 3, 4)

    # Initialize MultiHeadAttention layer with dropout enabled.
    multihead_attention_layer = layers.MultiHeadAttention(input_dim=4, head_dim=2, num_heads=2, dropout_rate=0.5)

    # Training mode: dropout is active; only shape and finiteness are deterministic.
    multihead_attention_layer.train()
    train_output = multihead_attention_layer(input_tensor)
    assert train_output.shape == (2, 3, 4)
    assert torch.isfinite(train_output).all()

    # Evaluation mode: dropout is a no-op, so repeated passes must agree exactly.
    multihead_attention_layer.eval()
    with torch.no_grad():
        first_pass = multihead_attention_layer(input_tensor)
        second_pass = multihead_attention_layer(input_tensor)
    assert first_pass.shape == (2, 3, 4)
    assert torch.equal(first_pass, second_pass)

## Wrappers

### Residual

In [26]:
# Smoke check: wrap a shape-preserving Conv1d (4 -> 4 channels, "same" padding) in Residual
# and run it on the shared `x`; displays the layer and the output shape.
conv_layer = torch.nn.Conv1d(4, 4, 5, padding="same")
layer = layers.Residual(conv_layer)
layer_out = layer(x)
layer, layer_out.shape

(Residual(
   (wrapped): Conv1d(4, 4, kernel_size=(5,), stride=(1,), padding=same)
 ),
 torch.Size([10, 4, 100]))

In [None]:
def test_residual_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3, 4).requires_grad_(True)
    
    # Initialize Residual layer with a linear layer as the wrapped module
    residual_layer = layers.Residual(nn.Linear(4, 4))
    
    # Test forward pass
    output_tensor = residual_layer(input_tensor)
    assert output_tensor.shape == (2, 3, 4)
    assert torch.all(torch.eq(output_tensor, input_tensor + residual_layer.wrapped(input_tensor)))
    
    # Test backward pass
    #output_tensor.sum().backward()
    #assert torch.all(torch.eq(input_tensor.grad, output_tensor.grad + residual_layer.wrapped.weight.grad))

In [None]:
# Run the Residual checks; a failed assertion raises in this cell.
test_residual_layer()

## Gluers

### Flatten

In [28]:
# Smoke check: Flatten applied to the shared `x`; displays the layer plus input and output shapes.
layer = layers.Flatten()
layer_out = layer(x)
layer, x.shape, layer_out.shape

(Flatten(), torch.Size([10, 4, 100]), torch.Size([10, 400]))

In [31]:
def test_flatten_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3, 4).requires_grad_(True)
    
    # Initialize Flatten layer
    flatten_layer = layers.Flatten()
    
    # Test forward pass
    output_tensor = flatten_layer(input_tensor)
    assert output_tensor.shape == (2, 12)
    assert torch.all(torch.eq(output_tensor, input_tensor.permute(0, 2, 1).contiguous().view(input_tensor.shape[0], -1)))
    
    # Test backward pass
    output_tensor.sum().backward()
    assert torch.all(torch.eq(input_tensor.grad, output_tensor.grad.permute(0, 2, 1).contiguous().view(input_tensor.shape)))


In [32]:
# Run the Flatten checks; any failure surfaces in this cell's output.
test_flatten_layer()

  assert torch.all(torch.eq(input_tensor.grad, output_tensor.grad.permute(0, 2, 1).contiguous().view(input_tensor.shape)))


AttributeError: 'NoneType' object has no attribute 'permute'

### Unsqueeze

In [33]:
# Smoke check: Unsqueeze(1) applied to the shared `x`; displays the layer plus input and output shapes.
layer = layers.Unsqueeze(1)
layer_out = layer(x)
layer, x.shape, layer_out.shape

(Unsqueeze(), torch.Size([10, 4, 100]), torch.Size([10, 1, 4, 100]))

In [34]:
def test_unsqueeze_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3, 4)
    
    # Initialize Unsqueeze layer
    unsqueeze_layer = layers.Unsqueeze(1)
    
    # Test forward pass
    output_tensor = unsqueeze_layer(input_tensor)
    assert output_tensor.shape == (2, 1, 3, 4)
    assert torch.all(torch.eq(output_tensor, input_tensor.unsqueeze(1)))
    
    # Test backward pass
    output_tensor.sum().backward()
    assert torch.all(torch.eq(input_tensor.grad, output_tensor.grad.squeeze(1)))


In [35]:
# Run the Unsqueeze checks; any failure surfaces in this cell's output.
test_unsqueeze_layer()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

### View

In [37]:
# Smoke check: View((100, 4)) applied to the shared `x`; the displayed shapes show that the
# leading dimension of size 10 is kept and the rest is reshaped to (100, 4).
layer = layers.View((100, 4))
layer_out = layer(x)
layer, x.shape, layer_out.shape

(View(100, 4), torch.Size([10, 4, 100]), torch.Size([10, 100, 4]))

In [39]:
def test_view_layer():
    # Test input tensor
    input_tensor = torch.randn(2, 3, 4)
    
    # Initialize View layer
    view_layer = layers.View((3, 8))
    
    # Test forward pass
    output_tensor = view_layer(input_tensor)
    assert output_tensor.shape == (2, 3, 8)
    assert torch.all(torch.eq(output_tensor, input_tensor.view(2, 3, 3, 8)))
    
    # Test backward pass
    output_tensor.sum().backward()
    assert torch.all(torch.eq(input_tensor.grad, output_tensor.grad.view(2, 3, 4)))

In [40]:
# Run the View checks; any failure surfaces in this cell's output.
test_view_layer()

RuntimeError: shape '[2, 3, 8]' is invalid for input of size 24

## Miscellaneous

### Clip

In [41]:
# Smoke check: Clip(0, 0.1) applied to the shared `x`; displays the layer, both shapes,
# and the output's min and max.
layer = layers.Clip(0, 0.1)
layer_out = layer(x)
layer, x.shape, layer_out.shape, layer_out.min(), layer_out.max()

(Clip(),
 torch.Size([10, 4, 100]),
 torch.Size([10, 4, 100]),
 tensor(0.),
 tensor(0.1000))

In [42]:
def test_clip_forward():
    # Test the forward method of the Clip class
    clip = layers.Clip(min=0, max=1)
    x = torch.tensor([-1, 0.5, 2])
    expected_output = torch.tensor([0, 0.5, 1])
    assert torch.all(torch.eq(clip(x), expected_output))

In [43]:
# Run the Clip checks; a failed assertion raises in this cell.
test_clip_forward()

### RevComp

In [45]:
# Smoke check: RevComp with default arguments applied to the shared `x`; displays the layer
# plus input and output shapes.
layer = layers.RevComp()
layer_out = layer(x)
layer, x.shape, layer_out.shape

(RevComp(), torch.Size([10, 4, 100]), torch.Size([10, 4, 100]))

In [46]:
def test_revcomp_forward():
    # Test the forward method of the RevComp class
    revcomp = layers.RevComp(dim=[1,2])
    x = torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
    expected_output = torch.tensor([[[4,3],[2,1]],[[8,7],[6,5]]])
    assert torch.all(torch.eq(revcomp(x), expected_output))

In [47]:
# Run the RevComp checks; a failed assertion raises in this cell.
test_revcomp_forward()

# DONE

---

# Scratch