In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torch import Tensor
from torch.nn import functional as F

In [2]:
# ENCODER PARAMS
# Encoder inputs
enc_n_feats = 256   # spectrogram height
enc_d_model = 1024  # spectrogram max width

# DECODER PARAMS
# Decoder inputs--
dec_n_feats = 38    # len(tokens)
dec_d_model = 152   # maximum symbols in sentence

In [71]:
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

    
class AbsolutePositionEncoding(nn.Module):
    def __init__(self):
        super().__init__()
        
    def pos_f(self, row, column, emb_dim):
        func = (np.sin, np.cos)[row % 2]
        w_k = 1/np.power(10000, 2*row/emb_dim)
        pe_i_j = func(column * w_k)
        return torch.Tensor([pe_i_j])
    
    def position_encoding(self, X):
        b, _, h, w = X.shape
        pe = torch.zeros((b, h, w))
        for k in range(b):
            for i in range(h):
                for j in range(w):
                    pe[k][i][j] = self.pos_f(i, j, h)
                
        pe = pe.reshape(b, 1, h, w)
        return pe
    
    def forward(self, x):
        PE = self.position_encoding(x)
        return PE
                
                
class LFFN(nn.Module):
    def __init__(self, inputs_dim, dim_hid):
        """
        Args:
        inputs_dim - tuple, 
            (N, C, H, W) of inputs
        dim_hid - int, 
            number of hidden units
        """
        super().__init__()
        _, _, dim_bn, dim = inputs_dim
        self.E1 = nn.Linear(in_features=dim, out_features=dim_bn, bias=False)
        self.D1 = nn.Linear(in_features=dim_bn, out_features=dim_hid, bias=False)
        self.swish = Swish()
        self.dropout = nn.Dropout(0.5)
        self.E2 = nn.Linear(in_features=dim_hid, out_features=dim_bn, bias=False)
        self.D2 = nn.Linear(in_features=dim_bn, out_features=dim, bias=False)
        
    def forward(self, inputs):
        x = self.E1(inputs)
        x = self.D1(x)
        x = self.swish(x)
        x = self.dropout(x)
        x = self.E2(x)
        y = self.D2(x)
        return y
    

class LightAttention(nn.Module):
    def __init__(self, dim_input, dim_q, dim_k):
        super().__init__()
        self.softmax_col = nn.Softmax(dim=-1)
        self.softmax_row = nn.Softmax(dim=-2)
        
        self.W_q = nn.Linear(in_features=dim_input, out_features=dim_q)
        self.W_k = nn.Linear(in_features=dim_input, out_features=dim_k)
        self.W_v = nn.Linear(in_features=dim_input, out_features=dim_k)
        
        self.d_q = torch.pow(torch.Tensor([dim_q]), 1/4)
        self.d_k = torch.pow(torch.Tensor([dim_k]), 1/4)
        
    def forward(self, x_q, x_k, x_v):
        Q = self.W_q(x_q)
        K = self.W_k(x_k)
        V = self.W_v(x_v)
        
        A = self.softmax_row(Q / self.d_q)
        B = torch.matmul(self.softmax_col(K / self.d_k), V)
        Z = torch.matmul(A, B)
        
        return Z


class MHLA(nn.Module):
    def __init__(self, 
                 num_heads, 
                 dim_input,
                 dim_q,
                 dim_k
                ):
        super().__init__()
        heads = [LightAttention(dim_input, dim_q, dim_k) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)                
        self.W_o = nn.Linear(dim_k*num_heads, dim_input)
        
    def forward(self, x_q, x_k, x_v):
        x = torch.cat([latt(x_q, x_k, x_v) for latt in self.heads], dim=-1)
        y = self.W_o(x)
        return y
    
    
class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
    
    
class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.conv1 = nn.Conv2d(chan_in, chan_in, kernel_size=(1, kernel_size), padding=(0, padding))
        self.conv2 = nn.Conv2d(chan_in, chan_out, kernel_size=(1, 1))

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x
    
    
class PointWiseConv(nn.Module):
    def __init__(self, chan_in):
        super().__init__()
        self.pw_conv = nn.Conv2d(in_channels=chan_in, out_channels=1, kernel_size=1)
        
    def forward(self, inputs):
        x = self.pw_conv(inputs)
        return x
    
    
class ConvModule(nn.Module):
    def __init__(self, dim_C, dim_H, dim_W, dropout=0.3):
        super().__init__()
        self.ln1 = nn.LayerNorm([dim_H, dim_W*dim_C])
        self.pw_conv1 = PointWiseConv(chan_in=dim_C)
        self.glu = GLU(-2)
        self.dw_conv1d = DepthWiseConv1d(dim_C, dim_C*2, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(dim_C)
        self.swish = Swish()
        self.pw_conv2 = PointWiseConv(chan_in=dim_C)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, inputs):
        x = inputs
        b, c, h, w = x.shape
        x = x.reshape(b, h, w*c)
        
        x = self.ln1(x)
        x = x.reshape(b, c, h, w)
        
        x = self.pw_conv1(x)
        x = self.glu(x)
        x = self.dw_conv1d(x)
        x = x.reshape(b, c, -1, w)
        
        x = self.bn(x)
        x = self.swish(x)
        x = self.pw_conv2(x)
        x = self.dropout(x)
        return x
    
    
class LAC(nn.Module):
    def __init__(self, dim_B=1, dim_C=1, dim_H=64, dim_W=256):
        super().__init__()
        self.lffn1 = LFFN(inputs_dim=(dim_B, dim_C, dim_H, dim_W), dim_hid=1024)
        self.mhlsa = MHLA(num_heads=4, dim_input=dim_W, dim_q=64, dim_k=64)    
        self.conv_module = ConvModule(dim_C, dim_H, dim_W)
        self.lffn2 = LFFN(inputs_dim=(dim_B, dim_C, dim_H, dim_W), dim_hid=1024)
        self.ln = nn.LayerNorm([dim_C, dim_H, dim_W])
        
    def forward(self, inputs):
        x = inputs
        x = x + 1/2 * self.lffn1(x)
        x = x + self.mhlsa(x, x, x)
        x = x + self.conv_module(x)
        x = x + 1/2 * self.lffn2(x)
        x = self.ln(x)
        return inputs + x
    
    
class Encoder(nn.Module):
    def __init__(self, lacs_n=2):
        super().__init__()
        self.lacs = nn.Sequential(*[LAC() for i in range(lacs_n)])
        
    def forward(self, inputs):
        x = self.lacs(inputs)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, memory, outputs):
        x = outputs
        return x

In [72]:
class Model(nn.Module):
    def __init__(self, lacs_n=2):
        super().__init__()
        self.input_preprocess = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=2),
            Swish(),
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1),
            Swish()
        )
        self.pos_enc_inp = AbsolutePositionEncoding()
        self.pos_enc_out = AbsolutePositionEncoding()
        self.encoder = Encoder(lacs_n)
        self.decoder = Decoder()
        
    def forward(self, inputs, tgt):
        x = self.input_preprocess(inputs)
        print("after input_preprocess:", x)
        x = x + self.pos_enc_inp(x)
        print("after pos_enc:", x)
        x = self.encoder(x)
        print("after LACs:", x)
        
        y = self.pos_enc_out(tgt)
        y = self.decoder(x, y)
        
        return x
        
        
src = torch.rand(1, 1, 256, 1024)
tgt = torch.rand(1, 1, 38, 152)

model = Model()
result = model(src, tgt)
result.shape

after input_preprocess: tensor([[[[-0.0396, -0.0196, -0.0582,  ..., -0.0689, -0.0155, -0.0364],
          [-0.0602, -0.0462, -0.0494,  ...,  0.0083, -0.0743, -0.0553],
          [-0.0684, -0.0234, -0.0437,  ..., -0.0251, -0.0333, -0.0617],
          ...,
          [-0.0782, -0.0426, -0.0447,  ..., -0.0545, -0.0551, -0.0296],
          [-0.0353, -0.0024,  0.0077,  ..., -0.0430, -0.0614, -0.0359],
          [-0.0940, -0.0447, -0.0509,  ..., -0.0822, -0.0787, -0.0115]]]],
       grad_fn=<MulBackward0>)
after pos_enc: tensor([[[[-0.0396,  0.8219,  0.8511,  ...,  0.9260,  0.4365, -0.5428],
          [ 0.9398,  0.6855,  0.0215,  ...,  0.3447, -0.4700, -0.9708],
          [-0.0684,  0.5098,  0.8584,  ..., -0.8089, -1.0275, -0.9601],
          ...,
          [ 0.9218,  0.9574,  0.9553,  ...,  0.9455,  0.9449,  0.9704],
          [-0.0353, -0.0024,  0.0077,  ..., -0.0429, -0.0614, -0.0359],
          [ 0.9060,  0.9553,  0.9491,  ...,  0.9178,  0.9213,  0.9885]]]],
       grad_fn=<AddBackward0>)

torch.Size([1, 1, 64, 256])

In [65]:
print(256 >> 2 == 256 / 4)
print(1024 >> 2 == 1024 / 4)

True
True


In [None]:
X = torch.rand(1, 1, 64, 256)
E = torch.rand(256, 64)
D = torch.rand(64, 1024)
1/2 * torch.matmul(torch.matmul(X, E), D)

In [None]:
m = nn.Softmax(dim=-1) # softmax row
m = nn.Softmax(dim=-2) # softmax col
input = torch.randn(2, 3)
print(input)
output = m(input)
print(output)

In [None]:
torch.pow(torch.Tensor(16), 1/4)

In [None]:
a = torch.rand(8,9,10,11)
print(a.transpose(3,2).shape)
print(a.permute(3,2,1,0).size())

In [None]:
torch.cuda.device_count()

In [None]:
A = torch.rand(1, 1, 64, 64)
B = torch.rand(1, 1, 64, 64)

torch.cat([A, B], dim=-1).shape

In [None]:
A = torch.rand(1, 1, 64, 64)
A1, A2 = A.chunk(2, dim=-1)
A1.shape, A2.shape

In [None]:
from torch.nn import functional as F
A = torch.rand(1, 1, 64, 256)
print(F.glu(A, -1).shape)

glu = torch.nn.GLU()
print(glu(A).shape)

In [None]:
A = torch.rand(2, 1, 64, 256)

b, c, h, w = A.shape
A = A.reshape(b, h, -1)
A.shape

In [57]:
n = 2
A = torch.rand(64, 256)
linears = nn.Sequential(*[nn.Linear(256, 256) for i in range(n)])
linears(A).shape

torch.Size([64, 256])