In [2]:
!pip install einops
!pip install warp-rnnt
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

import math
import numpy as np
from einops import rearrange
from torchaudio.models import RNNT
import os
import sys

import torch.optim as optim
from torchaudio.functional import rnnt_loss

if os.getenv("COLAB_RELEASE_TAG"):
  from google.colab import drive
  drive.mount('/content/drive')
  py_file_location = '/content/drive/MyDrive/models/'
  sys.path.append(py_file_location)
from activation_functions import aptx, sigmaptx, gelu, glu, relu
from adam_variant import ScaledAdam
from attention_mechanisms import MultiHeadAttention, MultiHeadSelfAttention
from positional_embedding import absolutepositionalembedding, rotarypositionalembedding
from decoders import DecoderRNNT
# from warp_rnnt import rnnt_loss

Collecting warp-rnnt
  Using cached warp_rnnt-0.7.0.tar.gz (15 kB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>[0m See above for output.

[1;35mnote[0m: This is an issue with the package mentioned above, not pip.
[1;36mhint[0m: See above for details.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [26]:
# helper functions
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

# helper classes
class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward_Horizontal(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            sigmaptx(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class FeedForward_Vertical(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            aptx(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class ConformerConvModule_Horizontal(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim, 1),
            gelu(),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            sigmaptx(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule_Vertical(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            glu(dim = 1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            aptx(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Conformer Block

class ConformerBlock_Vertical(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward_Vertical(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.position = rotarypositionalembedding(d_model = dim)
        self.attn = MultiHeadSelfAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, linear_bias = False, include_local_attention = True, local_attention_window = 9, local_attention_dim_vertical = True)
        self.conv = ConformerConvModule_Vertical(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward_Vertical(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.position(x)
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x

class ConformerBlock_Horizontal(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward_Horizontal(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.position = absolutepositionalembedding(d_model = dim)
        self.attn = MultiHeadSelfAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, linear_bias = True, include_local_attention = True, local_attention_window = 9, local_attention_dim_vertical = True)
        self.conv = ConformerConvModule_Horizontal(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward_Horizontal(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.position(x)
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        seq_length,
        depth,
        output_dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        self.output_linear = nn.Linear(dim, output_dim, bias = True)
        self.layers_vertical = nn.ModuleList([])
        self.layers_horizontal = nn.ModuleList([])

        for _ in range(int(depth/2)):
            self.layers_vertical.append(ConformerBlock_Vertical(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                conv_causal = conv_causal

            ))
            self.layers_horizontal.append(ConformerBlock_Horizontal(
                dim = seq_length,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                conv_causal = conv_causal

            ))

    def forward(self, x):
        x_vertical = x
        x_horizontal = x.transpose(-2, -1)
        for block in self.layers_vertical:
            x_vertical = block(x_vertical)
        for block in self.layers_horizontal:
            x_horizontal = block(x_horizontal)
        x_horizontal = x_horizontal.transpose(-2, -1)
        shape = x_vertical.shape
        assert x_vertical.shape == x_horizontal.shape, "Input tensors must have the same shape"

        # Define dynamic weights as learnable parameters with the same shape as the inputs
        weight1 = nn.Parameter(torch.randn(*shape), requires_grad=True)
        weight2 = nn.Parameter(torch.randn(*shape), requires_grad=True)

        # Compute the weighted sum
        weighted_sum = weight1 * x_vertical + weight2 * x_horizontal
        # Linear layer to get it in the output dim
        output = self.output_linear(weighted_sum)
        return output


# JointNet to use the the decoding and transducer part of RNNT
#Imporved upon the code in https://github.com/ZhengkunTian/rnn-transducer/blob/master/rnnt/ for the Base Encoders, Decoders and the overall transducer
class JointNet(nn.Module):
    def __init__(self, input_size, hidden_size, vocab_size):
        super(JointNet, self).__init__()
        self.forward_layer = nn.Linear(input_size, hidden_size, bias=True)
        self.tanh = nn.Tanh()
        self.project_layer = nn.Linear(hidden_size, vocab_size, bias=True)

    def forward(self, enc_state, dec_state):
        if enc_state.dim() == 3 and dec_state.dim() == 3:
            dec_state = dec_state.unsqueeze(1)
            enc_state = enc_state.unsqueeze(2)
            t = enc_state.size(1)
            u = dec_state.size(2)
            enc_state = enc_state.repeat([1, 1, u, 1])
            dec_state = dec_state.repeat([1, t, 1, 1])
        else:
            assert enc_state.dim() == dec_state.dim()

        concat_state = torch.cat((enc_state, dec_state), dim=-1)
        outputs = self.forward_layer(concat_state)
        outputs = self.tanh(outputs)
        outputs = self.project_layer(outputs)
        outputs = outputs.mean(dim=2)
        # outputs = F.log_softmax(outputs, dim=-1)
        return outputs


# Conformer-RNNT Model
class ConformerRNNT(nn.Module):
    def __init__(self, input_dim, seq_len, num_enc_layers, conv_kernel_size, hidden_dim, output_dim, num_dec_layers, conv_dropout=0.1, enc_has_cont_val = True, share_embedding = True):
        super(ConformerRNNT, self).__init__()
        self.encoder = Conformer(dim = input_dim, seq_length = seq_length, depth = num_enc_layers, output_dim = output_dim, conv_kernel_size = conv_kernel_size, conv_dropout = conv_dropout)
        self.decoder = DecoderRNNT(input_dim = output_dim, hidden_dim = hidden_dim, output_dim = output_dim, num_layers = num_dec_layers, enc_has_cont_val = enc_has_cont_val)
        self.joint = JointNet(
            input_size=2*output_dim,
            hidden_size=hidden_dim,
            vocab_size=output_dim
        )
        if share_embedding and not enc_has_cont_val:
            assert self.decoder.embedding.weight.size() == self.joint.project_layer.weight.size(), '%d != %d' % (self.decoder.embedding.weight.size(1),  self.joint.project_layer.weight.size(1))
            self.joint.project_layer.weight = self.decoder.embedding.weight

    def forward(self, inputs, targets, inputs_length = None, targets_length = None):
        enc_state = self.encoder(inputs)
        dec_state, _ = self.decoder(targets, targets_length)
        output = self.joint(enc_state, dec_state)
        return output

    def recognize(self, inputs, inputs_length):
        batch_size = inputs.size(0)
        enc_states = self.encoder(inputs, inputs_length)
        zero_token = torch.LongTensor([[0]])
        if inputs.is_cuda:
            zero_token = zero_token.cuda()

        def decode(enc_state, lengths):
            token_list = []
            dec_state, hidden = self.decoder(zero_token)
            for t in range(lengths):
                logits = self.joint(enc_state[t].view(-1), dec_state.view(-1))
                out = F.softmax(logits, dim=0).detach()
                pred = torch.argmax(out, dim=0)
                pred = int(pred.item())
                if pred != 0:
                    token_list.append(pred)
                    token = torch.LongTensor([[pred]])
                    if enc_state.is_cuda:
                        token = token.cuda()
                    dec_state, hidden = self.decoder(token, hidden=hidden)
            return token_list
        results = []
        for i in range(batch_size):
            decoded_seq = decode(enc_states[i], inputs_length[i])
            results.append(decoded_seq)
        return results

In [None]:
batch_size = 10
seq_length = 20
input_dim = 128
output_dim = 256
hidden_dim = 512
num_enc_layers = 16
num_heads = 8
ff_dim = 2048
conv_kernel_size = 8
num_dec_layers = 16

torch.autograd.set_detect_anomaly(True)
# Create dummy input data
inputs = torch.randn(batch_size, seq_length, input_dim, requires_grad = True)

# Initialize the model
model = ConformerRNNT(input_dim, seq_length, num_enc_layers, conv_kernel_size, hidden_dim, output_dim, num_dec_layers)

# Define the optimizer
# optimizer = optim.Adam(model.parameters(), lr = 0.001)
optimizer = ScaledAdam(model.parameters(), lr = 0.00001)

targets = torch.randn(batch_size, seq_length, output_dim)  # (batch, target_seq_len)
input_lengths = torch.full((batch_size,), seq_length, dtype=torch.long)  # Input lengths (all same in this example)
target_lengths = torch.randint(5, output_dim, (batch_size,), dtype=torch.long)  # Random target lengths

criterion = nn.MSELoss()
# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Perform forward pass
    output = model(inputs, targets)

    # Calculate loss
    loss = criterion(output, targets)

    # Backward pass
    loss.backward()

    # Perform optimization step
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

Epoch [1/100], Loss: 1.0281822681427002
Epoch [2/100], Loss: 1.0243932008743286
Epoch [3/100], Loss: 1.078877329826355
Epoch [4/100], Loss: 1.029217004776001
Epoch [5/100], Loss: 1.0327353477478027
Epoch [6/100], Loss: 1.0327922105789185
Epoch [7/100], Loss: 1.0366309881210327
Epoch [8/100], Loss: 1.027967095375061
Epoch [9/100], Loss: 1.0295909643173218
Epoch [10/100], Loss: 1.023797869682312
Epoch [11/100], Loss: 1.0262619256973267
Epoch [12/100], Loss: 1.0228439569473267
Epoch [13/100], Loss: 2.1144676208496094
Epoch [14/100], Loss: 1.7771921157836914
Epoch [15/100], Loss: 1.0769721269607544
Epoch [16/100], Loss: 1.0611263513565063
Epoch [17/100], Loss: 1.04413640499115
Epoch [18/100], Loss: 1.0298364162445068
Epoch [19/100], Loss: 1.0333975553512573
Epoch [20/100], Loss: 1.0369511842727661
Epoch [21/100], Loss: 1.0311294794082642
Epoch [22/100], Loss: 1.0752931833267212
Epoch [23/100], Loss: 1.1231662034988403
Epoch [24/100], Loss: 1.113823413848877
Epoch [25/100], Loss: 1.53447222