In [1]:
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torch import Tensor
from torch.nn import functional as F

import matplotlib.pyplot as plt

In [40]:
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

# YEP
class EncoderInputsProc(nn.Module):
    def __init__(self, d_inputs, d_model, device="cpu"):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv1d(d_inputs, 256, kernel_size=7, stride=3),
            nn.ReLU(),
            nn.Conv1d(256, d_model, kernel_size=7, stride=3),
            nn.ReLU()
        ).to(device)
        
    def forward(self, X):
        out = X.transpose(-1, -2).contiguous()
        out = self.convs(out)
        out = out.transpose(-1, -2).contiguous()
        return out

# YEP
class DecoderInputsProc(nn.Module):
    def __init__(self, d_inputs, d_model, device="cpu"):
        super().__init__()
        self.lin = nn.Sequential(
            nn.Linear(d_inputs, d_model),
            nn.ReLU()
        ).to(device)
                
    def forward(self, X):
        out = self.lin(X)
        return out
    
# YEP
class AbsolutePositionEncoding(nn.Module):
    def __init__(self):
        super().__init__()
        
    def pos_f(self, row, column, emb_dim):
        func = (np.sin, np.cos)[column % 2]
        w_k = 1/np.power(10000, 2*column/emb_dim)
        pe_i_j = func(row * w_k)
        return torch.Tensor([pe_i_j])
    
    def position_encoding(self, X):
        assert len(X.shape) >= 3, "X shape must have more then 3 dimension"
        b = X.shape[0]
        h = X.shape[-2]
        w = X.shape[-1]
        pe = torch.zeros((b, h, w))
        for k in range(b):
            for i in range(h):
                for j in range(w):
                    pe[k][i][j] = self.pos_f(i, j, h)
                
        pe = pe.reshape(b, h, w)
        return pe
    
    def forward(self, x):
        PE = self.position_encoding(x)
        return PE


# YEP
class LFFN(nn.Module):
    def __init__(self, dim, dim_bn, dim_hid):
        """
        Args:
        dim_bn - int,
            bottleneck dimention
        dim - int,
            dim of input data
        dim_hid - int,
            number of hidden units
        """
        super().__init__()
        self.E1 = nn.Linear(in_features=dim, out_features=dim_bn, bias=False)
        self.D1 = nn.Linear(in_features=dim_bn, out_features=dim_hid, bias=False)
        self.swish = Swish()
        self.dropout = nn.Dropout(0.5)
        self.E2 = nn.Linear(in_features=dim_hid, out_features=dim_bn, bias=False)
        self.D2 = nn.Linear(in_features=dim_bn, out_features=dim, bias=False)

    def forward(self, inputs):
        x = self.E1(inputs)
        x = self.D1(x)
        x = self.swish(x)
        x = self.dropout(x)
        x = self.E2(x)
        y = self.D2(x)
        return y

# YEP
class MHLA2(nn.Module):
    def __init__(self,
                 num_heads,
                 dim_input_q,
                 dim_input_kv,
                 device="cpu",
                 mask=False
                 ):
        """
        Args:

        """
        super().__init__()
        assert dim_input_q % num_heads == 0, "dim_input_q must be devided on num_heads"
        assert dim_input_kv % num_heads == 0, "dim_input_kv must be devided on num_heads"
        dim_q = dim_input_q // num_heads
        dim_k = dim_input_kv // num_heads
        self.device = device
        self.with_mask = mask
        self.W_Q = torch.ones((num_heads, dim_input_q, dim_q), device=device, requires_grad=True)
        self.W_K = torch.ones((num_heads, dim_input_kv, dim_q), device=device, requires_grad=True)
        self.W_V = torch.ones((num_heads, dim_input_kv, dim_q), device=device, requires_grad=True)
        self.W_O = nn.Linear(dim_k * num_heads, dim_k * num_heads, bias=False)
        self.W_Q = nn.init.xavier_uniform_(self.W_Q)
        self.W_K = nn.init.xavier_uniform_(self.W_K)
        self.W_V = nn.init.xavier_uniform_(self.W_V)
        self.d_q = torch.pow(torch.Tensor([dim_q]).to(device), 1 / 4)
        self.d_k = torch.pow(torch.Tensor([dim_k]).to(device), 1 / 4)
        self.softmax_col = nn.Softmax(dim=-1)
        self.softmax_row = nn.Softmax(dim=-2)

    def mask(self, dim: (int, int)) -> Tensor:
        a, b = dim
        mask = torch.ones(b, a)
        mask = torch.triu(mask, diagonal=0)
        mask = torch.log(mask.T)
        return mask.to(self.device)

    def forward(self, x_q, x_k, x_v):
        x_q, x_k, x_v = x_q.unsqueeze(dim=1), x_k.unsqueeze(dim=1), x_v.unsqueeze(dim=1)
        print(f"{self.W_Q.shape=}")
        print(f"{x_q.shape=}")
        Q = torch.matmul(x_q, self.W_Q)
        K = torch.matmul(x_k, self.W_K)
        V = torch.matmul(x_v, self.W_V)
        if self.with_mask == True:
            Q += self.mask(Q.shape[-2:])
        A = torch.matmul(self.softmax_col(K.transpose(-1, -2).contiguous() / self.d_k), V)
        B = torch.matmul(self.softmax_row(Q / self.d_q), A)
        #print(f"{B.shape}")
        b, h, w, d = B.shape
        B = self.W_O(B.view(b, w, h * d))

        return B

# YEP
class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

# YEP
class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size=3, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(chan_in, chan_in, kernel_size=(1, kernel_size), padding=(0, padding))
        self.conv2 = nn.Conv2d(chan_in, chan_out, kernel_size=(1, 1))

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x

# YEP
class PointWiseConv(nn.Module):
    def __init__(self, chan_in):
        super().__init__()
        self.pw_conv = nn.Conv2d(in_channels=chan_in, out_channels=1, kernel_size=1)

    def forward(self, inputs):
        x = self.pw_conv(inputs)
        return x

# YEP
class ConvModule(nn.Module):
    def __init__(self, dim_W, dim_bn, dropout=0.3):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim_W)
        self.pw_conv1 = PointWiseConv(chan_in=1)
        self.glu = GLU(dim=-1)
        self.dw_conv1d = DepthWiseConv1d(1, dim_bn, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(dim_bn)
        self.swish = Swish()
        self.pw_conv2 = PointWiseConv(chan_in=dim_bn)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        x = inputs.unsqueeze(dim=1)

        x = self.ln1(x)

        x = self.pw_conv1(x)
        x = self.glu(x)
        x = self.dw_conv1d(x)

        x = self.bn(x)
        x = self.swish(x)
        x = self.pw_conv2(x)
        x = self.dropout(x)
        x = inputs.squeeze(dim=1)
        return x

# YEP
class LAC(nn.Module):
    def __init__(self, d_model, n_heads=2, device="cpu", dropout=0.1):
        super().__init__()
        self.lffn1 = LFFN(dim=d_model, dim_bn=256, dim_hid=1024)
        self.do1 = nn.Dropout(dropout)
        self.mhlsa = MHLA2(num_heads=n_heads, dim_input_q=d_model, dim_input_kv=d_model, device=device)
        self.do2 = nn.Dropout(dropout)
        self.conv_module = ConvModule(d_model, dim_bn=8)
        self.do3 = nn.Dropout(dropout)
        self.lffn2 = LFFN(dim=d_model, dim_bn=256, dim_hid=1024)
        self.do4 = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, inputs):
        x = inputs
        x = x + 1 / 2 * self.do1(self.lffn1(x))
        x = x + self.do2(self.mhlsa(x, x, x))
        x = x + self.do3(self.conv_module(x))
        x = x + 1 / 2 * self.do4(self.lffn2(x))
        x = self.ln(x)
        return inputs + x

# YEP
class Encoder(nn.Module):
    def __init__(self, d_model, n_encoders=2, device="cpu"):
        super().__init__()
        self.lacs = nn.Sequential(*[LAC(d_model=d_model, device=device) for i in range(n_encoders)])

    def forward(self, inputs):
        x = self.lacs(inputs)
        return x


# YEP
class DecoderBlock(nn.Module):
    def __init__(self, dim_tgt, dim_mem, device="cpu", dropout=0.1):
        super().__init__()
        self.mhla_with_mask = MHLA2(num_heads=2, dim_input_q=dim_tgt, dim_input_kv=dim_tgt, mask=True, device=device)
        self.do1 = nn.Dropout(dropout)
        self.ln1 = nn.LayerNorm(dim_tgt)
        self.mhla_with_memory = MHLA2(num_heads=2, dim_input_q=dim_tgt, dim_input_kv=dim_mem, device=device)
        self.do2 = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(dim_tgt)
        self.lffn = LFFN(dim=dim_mem, dim_bn=256, dim_hid=1024)
        self.do3 = nn.Dropout(dropout)
        self.ln3 = nn.LayerNorm(dim_tgt)

    def forward(self, mem, y):
        y = y + self.do1(self.mhla_with_mask(y, y, y))
        y = self.ln1(y)
        y = y + self.do2(self.mhla_with_memory(y, mem, mem))
        y = self.ln2(y)
        y = y + self.do3(self.lffn(y))
        y = self.ln3(y)
        return y
    

# YEP
class Decoder(nn.Module):
    def __init__(self, d_model, dropout=0.3, device="cpu", n_decoders=4):
        super().__init__()
        self.device = device
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dim_tgt=d_model, dim_mem=d_model, device=device)
            for _ in range(n_decoders)])
        self.classifier = nn.Sequential(
            nn.Linear(64, 38).to(device),
            nn.Dropout(dropout),
        )

    def forward(self, mem, tgt):
        y = tgt.to(self.device)
        print(f"{y.shape=}")

        for dec in self.dec_blocks:
            y = dec(mem, y)

        y = self.classifier(y)
        return y



In [41]:
def show_matrix(matrix: torch.Tensor, caption):
    print(caption)
    b, c, w, h = matrix.shape
    m = matrix.reshape(w, h)
    plt.figure(figsize=(10, 6))
    plt.ylabel("position")
    plt.xlabel("dim")
    plt.imshow(m.detach().numpy(), origin='lower')
    plt.show()


class Conformer(nn.Module):
    def __init__(self, n_encoders=2, n_decoders=2, d_model=64, device="cpu", dropout=0.3):
        super().__init__()
        self.enc_proc = EncoderInputsProc(d_inputs=768, d_model=d_model, device=device)
        self.dec_proc = DecoderInputsProc(d_inputs=38, d_model=d_model, device=device)
        self.pos_enc_inp = AbsolutePositionEncoding()
        self.pos_enc_out = AbsolutePositionEncoding()
        self.encoder = Encoder(n_encoders=n_encoders, d_model=64, device=device)
        self.decoder = Decoder(n_decoders=n_decoders, d_model=64, device=device)
        self.device = device
        self.to(device)

    def to(self, device, *args, **kwargs):
        self = super().to(device, *args, **kwargs)
        self.device = device
        return self

    def forward(self, inputs, tgt):
        x = self.enc_proc(inputs)
        x = x + self.pos_enc_inp(x).to(self.device)
        x = self.encoder(x)        
        y = self.dec_proc(tgt)
        y = y + self.pos_enc_out(y).to(self.device)
        y = self.decoder(x, y)

        return x, y
    
    
d_model = 64
    
X1 = torch.randn(1, 206, 768) #X1 = torch.randn(1, 768, 206)
X2 = torch.randn(2, 1024, 768) #X2 = torch.randn(2, 768, 1024)

Y1 = torch.randn(1, 29, 38)
Y2 = torch.randn(2, 151, 38)

enc_proc = EncoderInputsProc(768, d_model) # YEP
X1_p = enc_proc(X1)
X2_p = enc_proc(X2)

dec_proc = DecoderInputsProc(38, d_model) # YEP
Y1_p = dec_proc(Y1)
Y2_p = dec_proc(Y2)

print(f"Model {X1.shape=}")
print(f"Model {X2.shape=}")
tgt = Y2_p
out = X2_p
APE = AbsolutePositionEncoding() # YEP
PE = APE(out)
out = out + PE
print(f"PE {PE.shape}")
print(f"X + PE {out.shape=}")
lffn = LFFN(dim=d_model, dim_bn=256, dim_hid=1024) # YEP
out = lffn(out)
print(f"lffn {out.shape=}")

mhla2 = MHLA2(num_heads=2, dim_input_q=d_model, dim_input_kv=d_model) # YEP
out = mhla2(out, out, out)
print(f"mhla2 {out.shape=}")

conv_m = ConvModule(d_model, dim_bn=8) # YEP
out = conv_m(out)
print(f"conv_m {out.shape=}")

lac = LAC(d_model) # YEP
out = lac(out)
print(f"lac {out.shape=}")

enc = Encoder(d_model) # YEP
enc_out = enc(out)
print(f"enc {out.shape=}")

print(f"{enc_out.shape=}")
print(f"{X2.shape=}")
print("out.shape == X2.shape is", out.shape == X2.shape)

dec_b = DecoderBlock(dim_tgt=d_model, dim_mem=d_model) # YEP
out = dec_b(mem=out, y=tgt)
print(f"dec_b {out.shape=}")

dec = Decoder(d_model=d_model) # YEP
out = dec(mem=enc_out, tgt=Y2)
print(f"dec {out.shape=}")

conf = Conformer(d_model=d_model)
emb, out = conf(inputs=X2, tgt=Y2)

out.shape == Y2.shape

Model X1.shape=torch.Size([1, 206, 768])
Model X2.shape=torch.Size([2, 1024, 768])
PE torch.Size([2, 112, 64])
X + PE out.shape=torch.Size([2, 112, 64])
lffn out.shape=torch.Size([2, 112, 64])
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 112, 64])
mhla2 out.shape=torch.Size([2, 112, 64])
conv_m out.shape=torch.Size([2, 112, 64])
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 112, 64])
lac out.shape=torch.Size([2, 112, 64])
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 112, 64])
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 112, 64])
enc out.shape=torch.Size([2, 112, 64])
enc_out.shape=torch.Size([2, 112, 64])
X2.shape=torch.Size([2, 1024, 768])
out.shape == X2.shape is False
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 151, 64])
self.W_Q.shape=torch.Size([2, 64, 32])
x_q.shape=torch.Size([2, 1, 151, 64])
dec_b out.shape=torch.Size([2, 151, 64])
y.shape=torch.Size([2, 151, 38])
self

RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [94]:
d_model = 768
d_tgt = 38

print(d_model/d_tgt)

20.210526315789473


In [18]:
X = torch.randn(1, 768, 206)
conv1 = nn.Conv1d(768, 256, kernel_size=7, stride=3)
conv2 = nn.Conv1d(256, 64, kernel_size=7, stride=3)
out = conv1(X)
out = conv2(out)
out.shape

torch.Size([1, 64, 21])

In [17]:
64/21, 768/206

(3.0476190476190474, 3.7281553398058254)