# I have a custom class, let's test it!

In [1]:
import datetime
import math
from typing import ForwardRef

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna')

import src.models.nn.utils as U
import src.utils as utils
import src.utils.config
from src.models.sequence.block import SequenceResidualBlock
from src.models.nn.components import Normalization
from src.utils.enformer_pytorch import exponential_linspace_int, MaxPool, AttentionPool, ConvBlock, Residual

In [2]:
class Encoder(nn.Module):
    """Base interface for input encoders.

    An encoder takes the main tensor positionally; every other input must be
    supplied as a keyword argument. It hands back a tensor together with a
    dict of keyword arguments.
    Encoders are chained with U.PassthroughSequential, which threads those
    kwargs through the pipeline; the accumulated kwargs are then passed into
    the model backbone.
    """

    def forward(self, x, **kwargs):
        """Identity encoding: pass the input through untouched.

        x: input tensor
        **kwargs: additional info from the dataset (e.g. sequence lengths);
            ignored by this base implementation

        Returns:
        y: output tensor (here, x itself)
        kwargs: dict of extra arguments for the model backbone (here, empty)
        """
        backbone_kwargs = {}
        return x, backbone_kwargs

In [16]:
class EnformerEncoder(Encoder):
    """Convolutional encoder made of a stem plus an optional conv tower.

    The stem is a width-15 Conv1d from 4 input channels to ``dim // 2``
    channels, a residual ConvBlock and a pooling layer with ``pool_size = 2``.
    The conv tower is a stack of (ConvBlock, residual ConvBlock, pool) stages
    whose channel widths come from ``exponential_linspace_int`` and go from
    ``dim // 2`` up to ``dim``. The tower is always built, but only applied in
    ``forward`` when ``conv_tower`` is True.

    Args:
        d_input, d_model, filter_sizes, flat, use_tf_gamma, **kwargs:
            accepted for call-site compatibility; not used by this class.
        dim: channel width at the end of the conv tower; the stem outputs
            ``dim // 2`` channels.
        num_downsamples: total number of pooling stages (stem + tower); the
            tower gets ``num_downsamples - 1`` stages.
        dim_divisible_by: forwarded to ``exponential_linspace_int`` as
            ``divisible_by``.
        pool_type: ``'max'`` selects MaxPool; any other value selects
            AttentionPool.
        conv_tower: whether ``forward`` runs the conv tower after the stem.
    """

    def __init__(self, d_input, d_model, filter_sizes, flat=False,
                dim = 1536,
                num_downsamples = 7,    # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
                dim_divisible_by = 128,
                use_tf_gamma = False,
                pool_type = 'max',
                conv_tower = False,
                **kwargs,
             ):
        super().__init__()

        self.dim = dim
        self.num_downsamples = num_downsamples
        self.dim_divisible_by = dim_divisible_by
        self.pool_type = pool_type
        self.use_conv_tower = conv_tower

        # Anything other than 'max' falls through to attention pooling.
        Pool = MaxPool if self.pool_type == 'max' else AttentionPool
        half_dim = self.dim // 2

        # The stem's first conv is hard-wired to 4 input channels, so the
        # input is expected channel-first, i.e. (batch, 4, length).
        self.stem = nn.Sequential(
            nn.Conv1d(4, half_dim, 15, padding = 7),
            Residual(ConvBlock(half_dim)),
            Pool(half_dim, pool_size = 2)
        )

        # Channel widths of the tower stages, prefixed with the stem width so
        # consecutive pairs give each stage's (in, out) channels.
        filter_list = exponential_linspace_int(half_dim, self.dim, num = (self.num_downsamples - 1), divisible_by = self.dim_divisible_by)
        filter_list = [half_dim, *filter_list]

        conv_layers = [
            nn.Sequential(
                ConvBlock(dim_in, dim_out, kernel_size = 5),
                Residual(ConvBlock(dim_out, dim_out, 1)),
                Pool(dim_out, pool_size = 2)
            )
            for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:])
        ]

        self.conv_tower = nn.Sequential(*conv_layers)

    def forward(self, x, **kwargs):
        """Run the stem and, if enabled, the conv tower.

        x: input tensor, channel-first (batch, 4, length)
        **kwargs: accepted so the encoder can sit in a kwargs-passing
            pipeline like other Encoders; ignored here

        Returns:
        x: encoded tensor
        True: second tuple element, kept as-is for existing callers
        """
        x = self.stem(x)
        if self.use_conv_tower:
            x = self.conv_tower(x)
        # NOTE(review): the Encoder base class returns a dict of kwargs as the
        # second element; confirm downstream code really expects `True` here.
        return x, True

In [4]:
#obviously need to add a self.forward but let's just test it directly here! Then we can see the output of both and determine what's needed
#literally just self.stem and self.conv_tower see the outputs!!

#let's define it
encoder = EnformerEncoder(1,1,1,output_heads = dict(human=1), pool_type = 'max', conv_tower = False)
encoder

EnformerEncoder(
  (stem): Sequential(
    (0): Conv1d(4, 768, kernel_size=(15,), stride=(1,), padding=(7,))
    (1): Residual(
      (fn): Sequential(
        (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): GELU()
        (2): Conv1d(768, 768, kernel_size=(1,), stride=(1,))
      )
    )
    (2): MaxPool()
  )
  (conv_tower): Sequential(
    (0): Sequential(
      (0): Sequential(
        (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): GELU()
        (2): Conv1d(768, 768, kernel_size=(5,), stride=(1,), padding=(2,))
      )
      (1): Residual(
        (fn): Sequential(
          (0): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): GELU()
          (2): Conv1d(768, 768, kernel_size=(1,), stride=(1,))
        )
      )
      (2): MaxPool()
    )
    (1): Sequential(
      (0): Sequential(
        (0): BatchNorm1d(768, eps=1e-05, momentum=

In [6]:
#let's make a simple one hot encoded sequence
pattern = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]
sequence = torch.tensor([pattern*224]).float()
sequence

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         ...,
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])

In [13]:
sequence.shape

torch.Size([1, 896, 4])

In [15]:
#let's input it
out = encoder.stem(sequence.transpose(1,2))

NotImplementedError: Need to test this, particularly to see what the shape of x is. X should be b, d, n. Test this but probably fine?

In [18]:
#I added that error in max pool, let's test it and ensure
out1 = encoder.stem[0](sequence.transpose(1,2))
out1.shape

torch.Size([1, 768, 896])

In [20]:
#turns this into 768 channels which is good, and keeps it the same length with its padding
out2 = encoder.stem[1](out1)
out2.shape #this is the residual ConvBlock (BatchNorm -> GELU -> 1x1 conv), so the shape is unchanged

torch.Size([1, 768, 896])

In [22]:
#and now the next step
class MaxPool(torch.nn.Module):
    """1D max pooling over the last axis with optional 'same'-style padding.

    Args:
        dim: unused; accepted only so the constructor is call-compatible with
            AttentionPool.
        pool_size: window size, also used as the stride.
        padding: 'same' pads the input by ``pool_size - 1`` in total so the
            output length is ``ceil(n / pool_size)``; any other value applies
            no padding.
    """

    def __init__(self, dim = None, pool_size=2, padding='same'): #note dim is taken as a useless argument since only used for compatibility with AttentionPool
        super().__init__()
        self.pool_size = pool_size
        self.padding = padding

    def forward(self, x):
        """Pool x of shape (b, d, n) along n; the channel axis is untouched."""
        if self.padding == 'same':
            pad_total = self.pool_size - 1
            pad_left = pad_total // 2
            pad_right = pad_total - pad_left
            # Pad with -inf rather than zeros: a padded position must never
            # win the max, otherwise windows touching the border return 0
            # whenever all real values in them are negative. Every window
            # still contains at least one real element, so no -inf leaks out.
            x = F.pad(x, (pad_left, pad_right), value=float('-inf'))
        return F.max_pool1d(x, self.pool_size, stride=self.pool_size)
out3 = MaxPool()(out2)

In [23]:
out3.shape #yup reduces the length dimension not the channel, this is good!!
#input should be 1 x 4 x length where 4 is the one hot encoding and length is the length of the sequence

torch.Size([1, 768, 448])

In [9]:
#now we have the output of the model here, what if we look at the conv tower
#fixed the max pool so we can just do the stem
out1 = encoder.stem(sequence.transpose(1,2))
print(out1.shape)

torch.Size([1, 768, 448])


In [10]:
#now check the conv tower
out2 = encoder.conv_tower(out1)
print(out2.shape) #so length 896 turns to 7... so 128x downsampled

torch.Size([1, 1536, 7])


In [13]:
#but the initial stem is only a 2x downsample. So we can try both or maybe stick to 1? we have options!
#wait one of the benefits of 128x downsample is that when we do 128 bp we can just do 1x1 conv to get the output...
#so there is a point, maybe we'll try both?
encoder.use_conv_tower = True

In [15]:
out1 = encoder(sequence.transpose(1,2))
out1.shape

torch.Size([1, 1536, 7])

In [17]:
#let's test a new model
encoder = EnformerEncoder(1,1,1, dim = 256,pool_type = 'max', conv_tower = True) #we'll make sure it's the right model dimension!

In [18]:
out1 = encoder(sequence.transpose(1,2))
out1.shape

torch.Size([1, 256, 7])

In [19]:
#let's test it with attention pool
encoder = EnformerEncoder(1,1,1, dim = 256,pool_type = 'attention', conv_tower = True) #we'll make sure it's the right model dimension!

In [20]:
out1 = encoder(sequence.transpose(1,2))
out1.shape

torch.Size([1, 256, 7])

In [21]:
#also let's see what the output of the stem is!
out1 = encoder.stem(sequence.transpose(1,2))
out1.shape
#so it's half, meaning if we just want to use the stem we set dim to 2x the model dimension we want
#but since we're downsampling we can increase the model dim

torch.Size([1, 128, 448])

# testing the decoder as well

In [2]:
#so after the main model part, my model's output is batch x seqlen x dmodel
import sys
sys.path.append('/data/leslie/sarthak/hyena/hyena-dna')
from einops.layers.torch import Rearrange
from src.utils.enformer_pytorch import exponential_linspace_int, MaxPool, AttentionPool, ConvBlock, Residual, GELU
import torch.nn as nn
# Build the final pointwise block to try it out on its own.
dropout_rate = 0.1
d_model = 256
half_dim = d_model//2
twice_dim = d_model*2
# Same channel schedule as the encoder's conv tower (6 steps, widths divisible by 128);
# only its last entry is used below, as the ConvBlock input width.
filter_list = exponential_linspace_int(half_dim, d_model, num = (7 - 1), divisible_by = 128)
filter_list = [half_dim, *filter_list]
final_pointwise = nn.Sequential(
    # input arrives as (batch, seqlen, d_model); move channels first for the ConvBlock
    Rearrange('b n d -> b d n'),
    # kernel size 1 conv block: filter_list[-1] channels in, twice_dim channels out
    ConvBlock(filter_list[-1], twice_dim, 1),
    # back to (batch, seqlen, channels)
    Rearrange('b d n -> b n d'),
    nn.Dropout(dropout_rate / 4),
    GELU()
)

In [4]:
#now let's make a test sample
import torch
seq = torch.randn(1,1024,256)
#now let's see what happens if we pass it through the final pointwise
out = final_pointwise(seq)

In [5]:
print(out.shape)

torch.Size([1, 1024, 512])


In [6]:
#actually before that we have this 
class TargetLengthCrop(nn.Module):
    """Center-crop the sequence axis (second-to-last) to a fixed length.

    Args:
        target_length: number of positions to keep; ``-1`` disables cropping.
    """

    def __init__(self, target_length):
        super().__init__()
        self.target_length = target_length

    def forward(self, x):
        """Return the central ``target_length`` positions of x.

        x: tensor whose second-to-last axis is the sequence axis,
            e.g. (batch, seqlen, d_model)

        Returns x itself when cropping is disabled or the length already
        matches.

        Raises:
            ValueError: if the sequence is shorter than ``target_length``.
        """
        seq_len, target_len = x.shape[-2], self.target_length

        if target_len == -1:
            return x

        if seq_len < target_len:
            raise ValueError(f'sequence length {seq_len} is less than target length {target_len}')

        if seq_len == target_len:
            return x

        # Use an explicit start offset instead of a symmetric negative trim:
        # floor-dividing a negative odd difference removed one element too
        # many. With an odd surplus the extra element is dropped on the right.
        start = (seq_len - target_len) // 2

        # Slice the same axis the length was read from (second-to-last).
        return x[..., start:start + target_len, :]

target_length = 896
crop = TargetLengthCrop(target_length)

In [7]:
seq = torch.randn(1,1024,256)
out = crop(seq)
print(out.shape)
out = final_pointwise(out)
print(out.shape)

torch.Size([1, 896, 256])
torch.Size([1, 896, 512])


In [None]:
#then after that convolution it does a linear