In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torch import Tensor
from torch.nn import functional as F

In [2]:
# ENCODER PARAMS
# Encoder inputs
enc_n_feats = 256   # spectrogram height
enc_d_model = 1024  # spectrogram max width

# DECODER PARAMS
# Decoder inputs--
dec_n_feats = 38    # len(tokens)
dec_d_model = 152   # maximum symbols in sentence

In [115]:
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

    
class AbsolutePositionEncoding(nn.Module):
    def __init__(self):
        super().__init__()
        
    def pos_f(self, row, column, emb_dim):
        func = (np.sin, np.cos)[row % 2]
        w_k = 1/np.power(10000, 2*row/emb_dim)
        pe_i_j = func(column * w_k)
        return torch.Tensor([pe_i_j])
    
    def position_encoding(self, X):
        b, _, h, w = X.shape
        pe = torch.zeros((b, h, w))
        for k in range(b):
            for i in range(h):
                for j in range(w):
                    pe[k][i][j] = self.pos_f(i, j, h)
                
        pe = pe.reshape(b, 1, h, w)
        return pe
    
    def forward(self, x):
        PE = self.position_encoding(x)
        return PE
                
                
class LFFN(nn.Module):
    def __init__(self, inputs_dim, dim_hid):
        """
        Args:
        inputs_dim - tuple, 
            (N, C, H, W) of inputs
        dim_hid - int, 
            number of hidden units
        """
        super().__init__()
        _, _, dim_bn, dim = inputs_dim
        self.E1 = nn.Linear(in_features=dim, out_features=dim_bn, bias=False)
        self.D1 = nn.Linear(in_features=dim_bn, out_features=dim_hid, bias=False)
        self.swish = Swish()
        self.dropout = nn.Dropout(0.5)
        self.E2 = nn.Linear(in_features=dim_hid, out_features=dim_bn, bias=False)
        self.D2 = nn.Linear(in_features=dim_bn, out_features=dim, bias=False)
        
    def forward(self, inputs):
        x = self.E1(inputs)
        x = self.D1(x)
        x = self.swish(x)
        x = self.dropout(x)
        x = self.E2(x)
        y = self.D2(x)
        return y
    

class LightAttention(nn.Module):
    def __init__(self, dim_input_q, dim_input_kv, dim_q, dim_k, with_mask=False):
        super().__init__()
        self.with_mask = with_mask
        self.softmax_col = nn.Softmax(dim=-1)
        self.softmax_row = nn.Softmax(dim=-2)
        
        self.W_q = nn.Linear(in_features=dim_input_q, out_features=dim_q)
        self.W_k = nn.Linear(in_features=dim_input_kv, out_features=dim_k)
        self.W_v = nn.Linear(in_features=dim_input_kv, out_features=dim_k)
        
        self.d_q = torch.pow(torch.Tensor([dim_q]), 1/4)
        self.d_k = torch.pow(torch.Tensor([dim_k]), 1/4)
        
    def mask(self, dim: (int, int)) -> Tensor :
        a, b = dim
        mask = torch.ones(b, a)
        mask = torch.triu(mask, diagonal=0)
        mask = torch.log(mask.T)
        return mask
        
    def forward(self, x_q, x_k, x_v):
        Q = self.W_q(x_q)
        K = self.W_k(x_k)
        V = self.W_v(x_v)
        
        if self.with_mask == True:
            Q += self.mask(Q.shape[-2:])
        
        A = self.softmax_row(Q / self.d_q)
        B = torch.matmul(self.softmax_col(K.transpose(-2, -1) / self.d_k), V)
        Z = torch.matmul(A, B)
        
        return Z


class MHLA(nn.Module):
    def __init__(self, 
                 num_heads, 
                 dim_input_q,
                 dim_input_kv,
                 dim_q = 64,
                 dim_k = 64,
                 mask=False
                ):
        """
        Args:
        dim_input - if shape is (B, C, H, W), then dim_input is W
        """
        super().__init__()
        heads = [LightAttention(dim_input_q, dim_input_kv, dim_q, dim_k, mask) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)                
        self.W_o = nn.Linear(dim_k*num_heads, dim_input_kv)
        
    def forward(self, x_q, x_k, x_v):
        x = torch.cat([latt(x_q, x_k, x_v) for latt in self.heads], dim=-1)
        y = self.W_o(x)
        return y
    
    
class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
    
    
class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.conv1 = nn.Conv2d(chan_in, chan_in, kernel_size=(1, kernel_size), padding=(0, padding))
        self.conv2 = nn.Conv2d(chan_in, chan_out, kernel_size=(1, 1))

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x
    
    
class PointWiseConv(nn.Module):
    def __init__(self, chan_in):
        super().__init__()
        self.pw_conv = nn.Conv2d(in_channels=chan_in, out_channels=1, kernel_size=1)
        
    def forward(self, inputs):
        x = self.pw_conv(inputs)
        return x
    
    
class ConvModule(nn.Module):
    def __init__(self, dim_C, dim_H, dim_W, dropout=0.3):
        super().__init__()
        self.ln1 = nn.LayerNorm([dim_H, dim_W*dim_C])
        self.pw_conv1 = PointWiseConv(chan_in=dim_C)
        self.glu = GLU(-2)
        self.dw_conv1d = DepthWiseConv1d(dim_C, dim_C*2, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(dim_C)
        self.swish = Swish()
        self.pw_conv2 = PointWiseConv(chan_in=dim_C)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, inputs):
        x = inputs
        b, c, h, w = x.shape
        x = x.reshape(b, h, w*c)
        
        x = self.ln1(x)
        x = x.reshape(b, c, h, w)
        
        x = self.pw_conv1(x)
        x = self.glu(x)
        x = self.dw_conv1d(x)
        x = x.reshape(b, c, -1, w)
        
        x = self.bn(x)
        x = self.swish(x)
        x = self.pw_conv2(x)
        x = self.dropout(x)
        return x
    
    
class LAC(nn.Module):
    def __init__(self, dim_B=1, dim_C=1, dim_H=64, dim_W=256):
        super().__init__()
        self.lffn1 = LFFN(inputs_dim=(dim_B, dim_C, dim_H, dim_W), dim_hid=1024)
        self.mhlsa = MHLA(num_heads=4, dim_input_q=dim_W, dim_input_kv=dim_W)    
        self.conv_module = ConvModule(dim_C, dim_H, dim_W)
        self.lffn2 = LFFN(inputs_dim=(dim_B, dim_C, dim_H, dim_W), dim_hid=1024)
        self.ln = nn.LayerNorm([dim_C, dim_H, dim_W])
        
    def forward(self, inputs):
        x = inputs
        x = x + 1/2 * self.lffn1(x)
        x = x + self.mhlsa(x, x, x)
        x = x + self.conv_module(x)
        x = x + 1/2 * self.lffn2(x)
        x = self.ln(x)
        return inputs + x
    
    
class Encoder(nn.Module):
    def __init__(self, lacs_n=2):
        super().__init__()
        self.lacs = nn.Sequential(*[LAC() for i in range(lacs_n)])
        
    def forward(self, inputs):
        x = self.lacs(inputs)
        return x
    

class DecoderBlock(nn.Module):
    def __init__(self, dim_shape_tgt, dim_shape_mem):
        super().__init__()
        dim_B, dim_C, dim_H, dim_tgt = dim_shape_tgt
        dim_mem = dim_shape_mem[-1]
        self.mhla_with_mask = MHLA(num_heads=2, dim_input_q=dim_tgt, dim_input_kv=dim_tgt)
        self.ln1 = nn.LayerNorm([dim_C, dim_H, dim_tgt])
        self.mhla_with_memory = MHLA(num_heads=2, dim_input_q=dim_tgt, dim_input_kv=dim_mem, mask=True)
        self.ln2 = nn.LayerNorm([dim_C, dim_H, dim_mem])
        self.lffn = LFFN(inputs_dim=(dim_B, dim_C, dim_H, dim_mem), dim_hid=1024)
        self.ln3 = nn.LayerNorm([dim_C, dim_H, dim_mem])
        
    def forward(self, mem, y):
        y = y + self.mhla_with_mask(y, y, y)
        y = self.ln1(y)
        y = y + self.mhla_with_memory(y, mem, mem)
        y = self.ln2(y)
        y = y + self.lffn(y)
        y = self.ln3(y)
        return y
    
    
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        # TODO: + linar modules to reshape decoder target to (1, 1, 64, 256)
        self.dec_blocks = DecoderBlock(dim_shape_tgt=(1, 1, 38, 152), dim_shape_mem=(1, 1, 38, 256))
        self.dec_blocks1 = DecoderBlock(dim_shape_tgt=(1, 1, 38, 152), dim_shape_mem=(1, 1, 38, 256))
        
    def forward(self, mem, tgt):
        y = tgt
        y = self.dec_blocks(mem, y)
        y = self.dec_blocks1(mem, y)
        print(f"Outputs shape: {y.shape}")
        return y

In [116]:
class Model(nn.Module):
    def __init__(self, lacs_n=2):
        super().__init__()
        self.input_preprocess = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=2),
            Swish(),
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1),
            Swish()
        )
        self.pos_enc_inp = AbsolutePositionEncoding()
        self.pos_enc_out = AbsolutePositionEncoding()
        self.encoder = Encoder(lacs_n)
        self.decoder = Decoder()
        
    def forward(self, inputs, tgt):
        x = self.input_preprocess(inputs)
        print("after input_preprocess:", x)
        x = x + self.pos_enc_inp(x)
        print("after pos_enc:", x)
        x = self.encoder(x)
        print("after LACs:", x)
        
        y = self.pos_enc_out(tgt)
        y = self.decoder(x, y)
        
        return x
        
        
src = torch.rand(1, 1, 256, 1024)
tgt = torch.rand(1, 1, 38, 152)

model = Model()
result = model(src, tgt)
result.shape

after input_preprocess: tensor([[[[-0.0876, -0.0960, -0.1383,  ..., -0.1009, -0.0930, -0.1066],
          [-0.1115, -0.1203, -0.1131,  ..., -0.0719, -0.0952, -0.0917],
          [-0.1087, -0.1045, -0.1353,  ..., -0.0994, -0.0794, -0.0857],
          ...,
          [-0.0833, -0.1275, -0.1126,  ..., -0.1012, -0.1118, -0.0886],
          [-0.1125, -0.1299, -0.1230,  ..., -0.1233, -0.0888, -0.1110],
          [-0.1021, -0.1036, -0.1086,  ..., -0.0877, -0.0808, -0.0915]]]],
       grad_fn=<MulBackward0>)
after pos_enc: tensor([[[[-0.0876,  0.7455,  0.7710,  ...,  0.8940,  0.3590, -0.6130],
          [ 0.8885,  0.6114, -0.0421,  ...,  0.2645, -0.4908, -1.0072],
          [-0.1087,  0.4286,  0.7668,  ..., -0.8831, -1.0736, -0.9842],
          ...,
          [ 0.9167,  0.8725,  0.8874,  ...,  0.8988,  0.8882,  0.9114],
          [-0.1125, -0.1299, -0.1230,  ..., -0.1233, -0.0888, -0.1110],
          [ 0.8979,  0.8964,  0.8914,  ...,  0.9123,  0.9192,  0.9085]]]],
       grad_fn=<AddBackward0>)

RuntimeError: The size of tensor a (152) must match the size of tensor b (256) at non-singleton dimension 3

In [5]:
print(256 >> 2 == 256 / 4)
print(1024 >> 2 == 1024 / 4)

True
True


In [6]:
X = torch.rand(1, 1, 64, 256)
E = torch.rand(256, 64)
D = torch.rand(64, 1024)
1/2 * torch.matmul(torch.matmul(X, E), D)

tensor([[[[1131.3501, 1084.0927, 1065.3154,  ..., 1081.5212, 1017.0113,
           1042.0787],
          [1131.8226, 1081.1790, 1062.1481,  ..., 1078.3734, 1015.0567,
           1039.3541],
          [1136.4053, 1086.8667, 1068.2117,  ..., 1084.2784, 1019.6134,
           1045.2657],
          ...,
          [1088.2290, 1044.4198, 1022.9773,  ..., 1040.7019,  977.4701,
            998.2241],
          [1076.0087, 1030.0990, 1010.1744,  ..., 1024.8882,  965.7994,
            986.8787],
          [1092.4478, 1049.4794, 1028.0149,  ..., 1046.0243,  983.2220,
           1005.9284]]]])

In [7]:
m = nn.Softmax(dim=-1) # softmax row
m = nn.Softmax(dim=-2) # softmax col
input = torch.randn(2, 3)
print(input)
output = m(input)
print(output)

tensor([[-0.3066,  1.0177, -0.7403],
        [-0.1246,  1.4291, -0.0796]])
tensor([[0.4546, 0.3986, 0.3406],
        [0.5454, 0.6014, 0.6594]])


In [8]:
torch.pow(torch.Tensor(16), 1/4)

tensor([3.0318e-10, 3.1558e-10, 3.2400e-10, 2.9185e-10, 3.1034e-10, 3.1987e-10,
        3.2264e-10, 1.9396e-11, 2.9185e-10, 2.8999e-10, 2.9092e-10, 2.9458e-10,
        2.8416e-10, 2.7796e-10, 2.8316e-10, 2.8904e-10])

In [9]:
a = torch.rand(8,9,10,11)
print(a.transpose(3,2).shape)
print(a.permute(3,2,1,0).size())

torch.Size([8, 9, 11, 10])
torch.Size([11, 10, 9, 8])


In [10]:
torch.cuda.device_count()

1

In [11]:
A = torch.rand(1, 1, 64, 64)
B = torch.rand(1, 1, 64, 64)

torch.cat([A, B], dim=-1).shape

torch.Size([1, 1, 64, 128])

In [12]:
A = torch.rand(1, 1, 64, 64)
A1, A2 = A.chunk(2, dim=-1)
A1.shape, A2.shape

(torch.Size([1, 1, 64, 32]), torch.Size([1, 1, 64, 32]))

In [13]:
from torch.nn import functional as F
A = torch.rand(1, 1, 64, 256)
print(F.glu(A, -1).shape)

glu = torch.nn.GLU()
print(glu(A).shape)

torch.Size([1, 1, 64, 128])
torch.Size([1, 1, 64, 128])


In [14]:
A = torch.rand(2, 1, 64, 256)

b, c, h, w = A.shape
A = A.reshape(b, h, -1)
A.shape

torch.Size([2, 64, 256])

In [15]:
n = 2
A = torch.rand(64, 256)
linears = nn.Sequential(*[nn.Linear(256, 256) for i in range(n)])
linears(A).shape

torch.Size([64, 256])

In [57]:
def mask(dim: (int, int)) -> Tensor :
    mask = torch.ones(dim)
    mask = torch.triu(mask, diagonal=0)
    mask = torch.log(mask.T)
    return mask


A = torch.ones(10, 10)
m = mask(A.shape)
A += m
A.transpose(-2, -1)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [-inf, 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [-inf, -inf, 1., 1., 1., 1., 1., 1., 1., 1.],
        [-inf, -inf, -inf, 1., 1., 1., 1., 1., 1., 1.],
        [-inf, -inf, -inf, -inf, 1., 1., 1., 1., 1., 1.],
        [-inf, -inf, -inf, -inf, -inf, 1., 1., 1., 1., 1.],
        [-inf, -inf, -inf, -inf, -inf, -inf, 1., 1., 1., 1.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 1., 1., 1.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 1., 1.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 1.]])

In [72]:
Y = torch.rand(38, 152)
X = torch.rand(64, 256)

W_q = nn.Linear(in_features=152, out_features=64)
W_k = nn.Linear(in_features=256, out_features=64)
W_v = nn.Linear(in_features=256, out_features=64)

Q = W_q(Y)
K = W_k(X)
V = W_v(X)

print(f"Q shape: {Q.shape}")
print(f"K, V shape: {K.shape}")

output = torch.matmul(torch.matmul(Q, K.T), V)
output.shape

Q shape: torch.Size([38, 64])
K, V shape: torch.Size([64, 64])


torch.Size([38, 64])

In [85]:
A = torch.rand(1, 1, 64, 256)
b, c = A.shape[-2:]
b, c

(64, 256)

In [98]:
a = torch.Tensor([1])

torch.exp(-a/0)

tensor([0.])