In [4]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# (e.g. eugene) before each cell runs, so library edits take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
import torch
from eugene import models

In [13]:
from eugene.models.base import _layers as layers

In [21]:
# Scratch check of floor division (10 // 4 == 2), the same integer
# arithmetic used for `padding = kernel_size // 2` further down.
10 // 4

2

In [50]:
# Build an InceptionConv1D layer; the positional args are presumably
# (in_channels=4, out_channels=12) — confirm against _layers.InceptionConv1D.
test_leo = layers.InceptionConv1D(4, 12)

In [51]:
# Random test input in torch's Conv1d (N, C, L) layout: batch 1, 4 channels, length 1000.
x = torch.randn(1, 4, 1000)

In [52]:
# Recorded output: channel count becomes 12 while the length 1000 is preserved.
test_leo(x).shape

torch.Size([1, 12, 1000])

In [58]:
# Rebuild with explicit kernel sizes. out_channels equals in_channels (4),
# presumably so the residual wrapper in the next cell can add input and
# output of matching shape — confirm against layers.Residual.
test_leo = layers.InceptionConv1D(
    in_channels=4,
    out_channels=4,
    kernel_size2=4,
    kernel_size3=8,
    conv_maxpool_kernel_size=3,
)

In [59]:
# Wrap the Inception layer in a residual connection
# (presumably x + test_leo(x) — confirm in _layers.Residual).
test_leo_w_res = layers.Residual(test_leo)

In [60]:
# Recorded output: the residual block keeps the input shape (1, 4, 1000).
test_leo_w_res(x).shape

torch.Size([1, 4, 1000])

In [54]:
# NOTE(review): the recorded 64-channel output comes from In [54], which ran
# before test_leo was rebuilt in In [58]; the definition that produced it is
# not in this notebook. Re-running now reflects the current test_leo.
test_leo(x).shape

torch.Size([1, 64, 1000])

In [48]:
# Stride-1 max pool with kernel 9 and padding 9 // 2 == 4. Output length is
# L + 2*4 - 9 + 1 == L, so an odd kernel with half padding keeps the length.
test = torch.nn.MaxPool1d(9, stride=1, padding=9 // 2)

In [49]:
# Recorded output: length 1000 is preserved (odd kernel 9, padding 4, stride 1).
test(x).shape

torch.Size([1, 4, 1000])

In [37]:
# Scratch check of floor division (2 // 2 == 1).
2 // 2

1

In [36]:
# NOTE(review): the recorded length 1001 is from In [36], before test_leo was
# rebuilt in In [58]. A stride-1 max pool with an even kernel k and padding
# k // 2 yields L + 1, which would explain it — confirm the kernel size used then.
test_leo.maxpool(x).shape

torch.Size([1, 4, 1001])

In [74]:
# Attention over inputs whose last dimension is 1000. The repr below shows
# qkv as Linear(1000 -> 1536), i.e. 3 * num_heads * head_dim.
test_mha = layers.MultiHeadAttention(
    input_dim=1000, head_dim=64, num_heads=8, dropout_rate=0.1
)

In [75]:
# Fresh random input with the same (1, 4, 1000) shape as before.
x = torch.randn(1, 4, 1000)

In [83]:
# Chain the residual Inception block (output (1, 4, 1000)) into the attention
# layer, whose input_dim=1000 matches that last dimension.
test_mha(test_leo_w_res(x))

tensor([[[-0.0898,  0.2110,  0.2210,  ..., -0.1978,  0.1926,  0.3191],
         [-0.0787, -0.0740,  0.0343,  ..., -0.0000,  0.0000,  0.0000],
         [-0.1636,  0.1576,  0.2302,  ..., -0.1308,  0.1619,  0.0623],
         [-0.1708,  0.3417,  0.2593,  ..., -0.0605,  0.1148,  0.0000]]],
       grad_fn=<MulBackward0>)

In [76]:
# Display the attention layer's sub-module structure.
test_mha

MultiHeadAttention(
  (qkv): Linear(in_features=1000, out_features=1536, bias=False)
  (softmax): Softmax(dim=-1)
  (dropout_layer): Dropout(p=0.1, inplace=False)
  (projection_layer): Sequential(
    (0): Linear(in_features=512, out_features=1000, bias=True)
    (1): Dropout(p=0.1, inplace=False)
  )
)

In [77]:
# Recorded 512 == 8 * 64, presumably num_heads * head_dim; it is also the
# in_features of the output projection in the repr above.
test_mha.projection_dim

512

In [81]:
# Recorded output: attention preserves the input shape (1, 4, 1000).
test_mha(x).shape

torch.Size([1, 4, 1000])

In [84]:
import torch.nn as nn

In [92]:
from eugene.models.base import _blocks as blocks

In [160]:
class Satori(nn.Module):
    """Conv -> multi-head attention -> flatten -> dense model for 1D inputs.

    Parameters
    ----------
    input_len : int
        Input sequence length; used as the default ``conv_kwargs["input_len"]``.
    output_dim : int
        Number of model outputs; used as the default ``dense_kwargs["output_dim"]``.
    conv_kwargs, mha_kwargs, dense_kwargs : dict, optional
        Overrides for the ``Conv1DBlock``, ``MultiHeadAttention`` and
        ``DenseBlock`` constructors. Missing keys are filled in by
        ``kwarg_handler``. The dicts passed in are never modified.
    """

    def __init__(
        self,
        input_len,
        output_dim,
        conv_kwargs: "dict | None" = None,
        mha_kwargs: "dict | None" = None,
        dense_kwargs: "dict | None" = None,
        #task: str = "regression",
        #loss_fxn: str ="mse",
    ):
        super().__init__()

        self.input_len = input_len
        self.output_dim = output_dim
        self.conv_kwargs, self.mha_kwargs, self.dense_kwargs = self.kwarg_handler(
            conv_kwargs, mha_kwargs, dense_kwargs
        )

        self.conv_block = blocks.Conv1DBlock(**self.conv_kwargs)
        # The attention layer's Linear maps act on the last dimension of the
        # conv output, so size it from the conv block's reported output size.
        self.mha_layer = layers.MultiHeadAttention(
            input_dim=self.conv_block.output_size[-1], **self.mha_kwargs
        )
        self.flatten = nn.Flatten()
        # Assumes attention preserves its input shape (as in the test_mha(x)
        # cell), so the flattened size is channels * length of the conv output.
        self.dense_block = blocks.DenseBlock(
            input_dim=self.conv_block.output_channels * self.conv_block.output_size[-1],
            **self.dense_kwargs,
        )

    def forward(self, x):
        """Apply conv block, attention, flatten and dense block in sequence.

        ``x`` is presumably (batch, input_channels, input_len) — confirm
        against ``blocks.Conv1DBlock``.
        """
        x = self.conv_block(x)
        x = self.mha_layer(x)
        x = self.flatten(x)
        return self.dense_block(x)

    def kwarg_handler(self, conv_kwargs, mha_kwargs, dense_kwargs):
        """Return new kwarg dicts for the conv, attention and dense modules.

        Each result is the module's defaults overlaid with the caller's
        values; caller-supplied keys always win. ``None`` is treated as an
        empty dict. The input dicts are not mutated, so a dict reused across
        several ``Satori`` instances keeps its original contents. Fresh
        default objects (including the ``hidden_dims`` list) are built on
        every call.
        """
        conv_defaults = {
            "input_len": self.input_len,
            "input_channels": 4,
            "output_channels": 320,
            "conv_kernel": 26,
            "conv_padding": "same",
            "norm_type": "batchnorm",
            "activation": "relu",
            "conv_bias": False,
            "pool_type": "max",
            "pool_kernel": 3,
            "pool_padding": 1,
            "dropout_rate": 0.2,
            "order": "conv-norm-act-pool-dropout",
        }
        mha_defaults = {
            "head_dim": 64,
            "num_heads": 8,
            "dropout_rate": 0.1,
        }
        dense_defaults = {
            "hidden_dims": [],
            "output_dim": self.output_dim,
        }
        return (
            {**conv_defaults, **(conv_kwargs or {})},
            {**mha_defaults, **(mha_kwargs or {})},
            {**dense_defaults, **(dense_kwargs or {})},
        )

In [161]:
# Build Satori for length-1000 inputs with a single output, kept on CPU.
# The text recorded under this cell is stderr noise emitted during construction.
test_satori = Satori(1000, 1).to("cpu")

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


In [162]:
# Display the model structure.
# NOTE(review): the recorded repr shows the dense Linear with in_features=334,
# which does not match the current definition's
# output_channels * output_size[-1] sizing — this output is stale.
test_satori

Satori(
  (conv_block): Conv1DBlock(
    (layers): Sequential(
      (conv): Conv1d(4, 320, kernel_size=(26,), stride=(1,), padding=same, bias=False)
      (norm): BatchNorm1d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
      (pool): MaxPool1d(kernel_size=3, stride=3, padding=1, dilation=1, ceil_mode=False)
      (dropout): Dropout(p=0.2, inplace=False)
    )
  )
  (mha_layer): MultiHeadAttention(
    (qkv): Linear(in_features=334, out_features=1536, bias=False)
    (softmax): Softmax(dim=-1)
    (dropout_layer): Dropout(p=0.1, inplace=False)
    (projection_layer): Sequential(
      (0): Linear(in_features=512, out_features=334, bias=True)
      (1): Dropout(p=0.1, inplace=False)
    )
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (dense_block): DenseBlock(
    (layers): Sequential(
      (0): Linear(in_features=334, out_features=1, bias=True)
    )
  )
)

In [164]:
# Fresh random test input, shape (1, 4, 1000).
x = torch.randn(1, 4, 1000)

In [165]:
# NOTE(review): the recorded torch.Size line came from a debug print inside
# Satori.forward at run time. The RuntimeError shows a flattened size of
# 106880 == 320 * 334 (channels * length) against a dense layer expecting
# 334. The current dense sizing (output_channels * output_size[-1])
# presumably fixes this; re-run to refresh the output.
test_satori(x)

torch.Size([1, 320, 334])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x106880 and 334x1)

In [90]:
# NOTE(review): this (1, 1) output is from In [90], which ran before the
# current Satori definition (In [160]); it reflects an earlier version of the model.
test_satori(x).shape

torch.Size([1, 1])

In [None]:
class DarthMaul