In [None]:
!pip install einops
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

import math
import numpy as np
from einops import rearrange
from torchaudio.models import RNNT
import os
import sys

import torch.optim as optim
from torchaudio.functional import rnnt_loss

# On Google Colab (detected through the COLAB_RELEASE_TAG environment variable):
# mount Google Drive and put the shared model directory on sys.path so the
# project modules imported below (activation_functions, adam_variant, ...)
# can be resolved.
if os.getenv("COLAB_RELEASE_TAG"):
    from google.colab import drive

    drive.mount('/content/drive')
    py_file_location = '/content/drive/MyDrive/models/'
    sys.path.append(py_file_location)
from activation_functions import aptx, sigmaptx, glu
from adam_variant import ScaledAdam
from attention_mechanisms import MultiHeadAttention, MultiHeadSelfAttention
from positional_embedding import absolutepositionalembedding, rotarypositionalembedding
from decoders import DecoderRNNT

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# helper functions
def exists(val):
    """Return True unless ``val`` is None (falsy values like 0 or '' still exist)."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val

def calc_same_padding(kernel_size):
    """Return the ``(left, right)`` padding that preserves sequence length.

    The total padding is ``kernel_size - 1``, so a stride-1, dilation-1
    convolution keeps its input length. Odd kernels are padded symmetrically;
    for even kernels the extra element goes on the left.
    """
    left = kernel_size // 2
    right = left if kernel_size % 2 else left - 1
    return (left, right)

# helper classes
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution (``groups = chan_in``) preceded by explicit padding.

    Args:
        chan_in: input channels; also the number of convolution groups.
        chan_out: output channels (must be a multiple of ``chan_in``).
        kernel_size: convolution kernel size.
        padding: passed unchanged to ``F.pad`` on the last axis, i.e. a
            ``(left, right)`` pair. The convolution itself adds no padding.
    """

    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        # Pad first so asymmetric (e.g. causal) padding is possible.
        return self.conv(F.pad(x, self.padding))

# attention, feedforward, and conv module

class Scale(nn.Module):
    """Call the wrapped ``fn`` and multiply its output by a constant ``scale``."""

    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        # Extra keyword arguments (e.g. ``mask``) are forwarded to ``fn``.
        out = self.fn(x, **kwargs)
        return out * self.scale

class PreNorm(nn.Module):
    """Normalize the input with ``LayerNorm(dim)`` and then call the wrapped ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # Keyword arguments (e.g. ``mask``) are passed straight through to ``fn``.
        return self.fn(self.norm(x), **kwargs)


class FeedForward_Horizontal(nn.Module):
    """Feed-forward sub-layer on the last axis using the ``sigmaptx`` activation.

    Linear(dim -> dim * mult) -> sigmaptx -> Dropout -> Linear(dim * mult -> dim) -> Dropout.

    Args:
        dim: size of the last axis of the input and output.
        mult: hidden-width multiplier.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        # The list order fixes the Sequential indices (net.0 ... net.4) that
        # appear in state_dict keys, so it must not be rearranged.
        layers = [
            nn.Linear(dim, hidden),
            sigmaptx(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class FeedForward_Vertical(nn.Module):
    """Feed-forward sub-layer on the last axis using the ``aptx`` activation.

    Linear(dim -> dim * mult) -> aptx -> Dropout -> Linear(dim * mult -> dim) -> Dropout.

    Args:
        dim: size of the last axis of the input and output.
        mult: hidden-width multiplier.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        # The list order fixes the Sequential indices (net.0 ... net.4) that
        # appear in state_dict keys, so it must not be rearranged.
        layers = [
            nn.Linear(dim, hidden),
            aptx(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class ConformerConvModule_Horizontal_Vertical(nn.Module):
    """Conformer convolution module operating on a 3-D ``(b, n, dim)`` tensor.

    Pipeline:
        LayerNorm(dim)
        -> transpose to (b, dim, n)
        -> 1x1 Conv to 2 * inner_dim channels
        -> glu(dim = 1)
        -> depth-wise conv
        -> BatchNorm1d (Identity when causal)
        -> aptx
        -> 1x1 Conv back to dim
        -> transpose back to (b, n, dim)
        -> Dropout

    ``glu`` is the project's gated unit and is assumed to halve channel axis 1
    to ``inner_dim``, which is what the depth-wise conv expects. TODO: confirm.

    Args:
        dim: feature size of the last input axis.
        causal: if True, pad only on the left of the sequence axis, so an
            output position depends only on inputs at or before it. BatchNorm
            is also replaced by Identity.
        expansion_factor: ``inner_dim = dim * expansion_factor``.
        kernel_size: depth-wise convolution kernel size.
        dropout: dropout probability on the module output.
    """

    def __init__(self, dim, causal = False, expansion_factor = 2, kernel_size = 31, dropout = 0.):
        super().__init__()

        expanded = dim * expansion_factor
        if causal:
            pad = (kernel_size - 1, 0)
        else:
            pad = calc_same_padding(kernel_size)

        # Built in exactly the original order: the Sequential indices form the
        # state_dict keys, and construction order determines parameter-init RNG use.
        stages = [
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, expanded * 2, 1),
            glu(dim = 1),
            DepthWiseConv1d(expanded, expanded, kernel_size = kernel_size, padding = pad),
        ]
        stages.append(nn.Identity() if causal else nn.BatchNorm1d(expanded))
        stages += [
            aptx(),
            nn.Conv1d(expanded, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

# Conformer Block

class ConformerBlock_Horizontal_Vertical(nn.Module):
    """One Conformer-style block with two self-attention sub-layers.

    Execution order:
        0.5 * FF(aptx)
        -> positional module
        -> attention (vertical local window)
        -> attention (horizontal local window)
        -> convolution module
        -> 0.5 * FF(sigmaptx)
        -> LayerNorm

    Every sub-layer except the positional module and the final LayerNorm is
    residual. The two attention sub-layers differ only in ``linear_bias`` and
    ``local_attention_dim_vertical``; both use ``local_attention_window = 3``.
    All arguments are keyword-only.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        # Sub-modules are assigned in the original order (ff1, position, attn1,
        # attn2, conv, ff2, post_norm). This keeps state_dict keys, parameters()
        # order and parameter-init RNG consumption unchanged. The LayerNorms
        # added by PreNorm do not draw random numbers.
        self.ff1 = Scale(0.5, PreNorm(dim, FeedForward_Vertical(dim = dim, mult = ff_mult, dropout = ff_dropout)))
        self.position = rotarypositionalembedding(d_model = dim)
        self.attn1 = PreNorm(dim, MultiHeadSelfAttention(
            dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout,
            linear_bias = False, include_local_attention = True,
            local_attention_window = 3, local_attention_dim_vertical = True,
        ))
        self.attn2 = PreNorm(dim, MultiHeadSelfAttention(
            dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout,
            linear_bias = True, include_local_attention = True,
            local_attention_window = 3, local_attention_dim_vertical = False,
        ))
        self.conv = ConformerConvModule_Horizontal_Vertical(
            dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor,
            kernel_size = conv_kernel_size, dropout = conv_dropout,
        )
        self.ff2 = Scale(0.5, PreNorm(dim, FeedForward_Horizontal(dim = dim, mult = ff_mult, dropout = ff_dropout)))
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        # NOTE(review): the positional module transforms the residual stream
        # itself, not q/k inside attention, and it runs once per block.
        # Confirm this is what rotarypositionalembedding is designed for.
        x = self.position(x)
        for attention in (self.attn1, self.attn2):
            x = attention(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        return self.post_norm(x)

# Conformer

class Conformer(nn.Module):
    """Stack of ``depth`` ConformerBlock_Horizontal_Vertical layers with a shared configuration.

    Args:
        dim: feature size of the input/output sequence (last axis).
        depth: number of stacked blocks (keyword-only).
        dim_head, heads: per-head size and head count of both attention sub-layers.
        ff_mult: hidden-width multiplier of both feed-forward sub-layers.
        conv_expansion_factor, conv_kernel_size, conv_causal: convolution-module settings.
        attn_dropout, ff_dropout, conv_dropout: dropout probabilities forwarded to
            every block. Previously these were accepted but silently ignored.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([
            ConformerBlock_Horizontal_Vertical(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal,
            )
            for _ in range(depth)
        ])

    def forward(self, x, mask = None):
        """Run ``x`` through every block in order.

        Args:
            x: input sequence whose last axis has size ``dim``.
            mask: optional attention mask forwarded unchanged to every block.
                The default None matches the previous behaviour.
        """
        for block in self.layers:
            x = block(x, mask = mask)
        return x


# Conformer-RNNT Model
class ConformerRNNT(nn.Module):
    """Conformer encoder followed by the project's ``DecoderRNNT`` head.

    Args:
        input_dim: feature size of ``src``. It is also the encoder width and the
            ``num_classes`` passed to DecoderRNNT.
        num_enc_layers: number of Conformer blocks (``depth``).
        conv_kernel_size: depth-wise convolution kernel size in each block.
        conv_dropout: passed to the Conformer as ``conv_dropout``.
    """

    def __init__(self, input_dim, num_enc_layers, conv_kernel_size, conv_dropout=0.1):
        super().__init__()
        self.encoder = Conformer(
            dim = input_dim,
            depth = num_enc_layers,
            conv_kernel_size = conv_kernel_size,
            conv_dropout = conv_dropout,
        )
        self.rnnt = DecoderRNNT(num_classes = input_dim)

    def forward(self, src):
        # DecoderRNNT returns a pair; only its first element is returned and
        # the second (hidden states) is discarded.
        decoded, _ = self.rnnt(self.encoder(src))
        return decoded

In [None]:
# Smoke-test configuration.
batch_size = 4
seq_length = 20
input_dim = 128
# Presumably must match DecoderRNNT's output width (num_classes = input_dim),
# otherwise MSELoss sees mismatched shapes. TODO: confirm.
output_dim = 128
model_dim = 256     # NOTE: currently unused; ConformerRNNT uses input_dim as its width
num_enc_layers = 4
num_heads = 8       # NOTE: currently unused; the Conformer default (heads = 8) applies
ff_dim = 2048       # NOTE: currently unused; FF hidden width is input_dim * ff_mult (default 4)
conv_kernel_size = 31

# Debug aid: reports the forward op behind NaN/Inf gradients, but noticeably
# slows every backward pass. Disable it for real training runs.
torch.autograd.set_detect_anomaly(True)

# Dummy input data of shape (batch, seq_length, input_dim). Only the model
# parameters are optimized, so the inputs do not need gradients; with
# requires_grad=True their .grad would also grow across epochs, since it is
# never zeroed.
inputs = torch.randn(batch_size, seq_length, input_dim)

# Initialize the model
model = ConformerRNNT(input_dim, num_enc_layers, conv_kernel_size)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr = 0.001)
# optimizer = ScaledAdam(model.parameters(), lr = 0.00001)

# Regression target for the MSE smoke test: (batch, seq_length, output_dim).
targets = torch.randn(batch_size, seq_length, output_dim)

# The two length tensors are only needed for an RNN-T objective
# (torchaudio rnnt_loss); the MSE loss below does not use them.
input_lengths = torch.full((batch_size,), seq_length, dtype=torch.long)  # all inputs are full length
# Target lengths cannot exceed the target sequence length (seq_length). The
# old bound was output_dim, a feature width, which produced lengths above seq_length.
target_lengths = torch.randint(5, seq_length + 1, (batch_size,), dtype=torch.long)

criterion = nn.MSELoss()
# Training loop
num_epochs = 25
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Perform forward pass
    output = model(inputs)

    # Calculate loss
    loss = criterion(output, targets)

    # Backward pass
    loss.backward()

    # Perform optimization step
    optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

Epoch [1/25], Loss: 1.0054991245269775
Epoch [2/25], Loss: 1.0010545253753662
Epoch [3/25], Loss: 0.9938554763793945
Epoch [4/25], Loss: 1.0127694606781006
Epoch [5/25], Loss: 0.9856674075126648
Epoch [6/25], Loss: 0.9850400686264038
Epoch [7/25], Loss: 0.9800266027450562
Epoch [8/25], Loss: 0.9746753573417664
Epoch [9/25], Loss: 0.9700107574462891
Epoch [10/25], Loss: 0.9653409123420715
Epoch [11/25], Loss: 0.9607140421867371
Epoch [12/25], Loss: 0.9560474157333374
Epoch [13/25], Loss: 0.9524437189102173
Epoch [14/25], Loss: 0.9498961567878723
Epoch [15/25], Loss: 0.9459260702133179
Epoch [16/25], Loss: 0.9473991394042969
Epoch [17/25], Loss: 0.9413295984268188
Epoch [18/25], Loss: 0.9383352398872375
Epoch [19/25], Loss: 0.934430718421936
Epoch [20/25], Loss: 0.9408702850341797
Epoch [21/25], Loss: 0.9322853088378906
Epoch [22/25], Loss: 1.0245431661605835
Epoch [23/25], Loss: 0.978649914264679
Epoch [24/25], Loss: 0.9788764715194702
Epoch [25/25], Loss: 0.9744836091995239
