<a href="https://colab.research.google.com/github/JustinZorig/transformer_transducer_test/blob/main/transformer_transducer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import os
from tempfile import TemporaryDirectory
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, embedding_dim:int = 512, dense1:int = 2048, dense2:int = 512, dropout:float = 0.1):
        """
        embedding_dim: width of the incoming vectors (512 by default).
        dense1: number of units in the first (hidden) linear layer.
        dense2: number of units in the second (output) linear layer. Set it equal to
                embedding_dim when the output is added back to the input as a residual.
        dropout: probability used by the dropout placed between the two linear layers.
        """
        super().__init__()
        expand = nn.Linear(embedding_dim, dense1)
        activation = nn.ReLU()
        drop = nn.Dropout(p = dropout)
        project = nn.Linear(dense1, dense2)
        # Kept as a single Sequential named `ff` so parameter names stay ff.0.* / ff.3.*.
        self.ff = nn.Sequential(expand, activation, drop, project)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the two-layer MLP along the last dimension of ``input``."""
        hidden = self.ff(input)
        return hidden


class PositionalEncoding(nn.Module):
    """
    Implement the positional encoding (PE) function.

    PE_(pos, 2i)    =  sin(pos / 10000 ** (2i / d_model))
    PE_(pos, 2i+1)  =  cos(pos / 10000 ** (2i / d_model))

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``[1, max_len, d_model]``. Odd ``d_model`` is
    supported: the last column then holds only a sine term.

    Adapted from: https://github.com/upskyy/Transformer-Transducer/blob/main/transformer_transducer/module.py
    """
    def __init__(
            self,
            d_model: int = 512,
            max_len: int = 5000
    ) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000 ** (2i / d_model), one entry per even column index 2i.
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model div_term has one
        # extra entry, which would otherwise cause a shape mismatch here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        """
        Return the encodings for the first ``length`` positions.

        Returns:
            Tensor of shape ``[1, length, d_model]``.

        Raises:
            ValueError: if ``length`` is negative or exceeds the precomputed ``max_len``
                (slicing would otherwise silently return a shorter table).
        """
        max_len = self.pe.size(1)
        if not 0 <= length <= max_len:
            raise ValueError(f"length must be in [0, {max_len}], got {length}")
        return self.pe[:, :length, :]


class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a feed-forward sublayer.

    Each sublayer's output is passed through dropout, added to that sublayer's
    input (residual connection), and then layer-normalised with its own LayerNorm.
    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``, so
    inputs are expected as ``[seq_len, batch_size, embedding_dim]``.
    """

    def __init__(self, embedding_dim:int =512, num_heads:int =8, dense1:int = 2048,
                 dense2:int = 512, dropout:float = 0.1):
        """
        embedding_dim: model width; nn.MultiheadAttention requires it to be divisible by num_heads.
        num_heads: number of attention heads.
        dense1: hidden width of the feed-forward sublayer.
        dense2: output width of the feed-forward sublayer; must equal embedding_dim
                because that output is added back to its input.
        dropout: dropout probability for the attention weights, inside the
                 feed-forward sublayer, and on each sublayer output.

        Raises:
            ValueError: if dense2 != embedding_dim.
        """
        super().__init__()
        if dense2 != embedding_dim:
            # Fail at construction instead of with an opaque shape error in forward().
            raise ValueError(
                f"dense2 ({dense2}) must equal embedding_dim ({embedding_dim}) "
                "for the feed-forward residual connection"
            )

        self.attention = nn.MultiheadAttention(embed_dim = embedding_dim, num_heads = num_heads, dropout = dropout)
        self.norm = nn.LayerNorm(embedding_dim)     # normalises the attention sublayer output
        self.ff = FeedForward(embedding_dim, dense1, dense2, dropout)
        self.ff_norm = nn.LayerNorm(embedding_dim)  # separate LayerNorm for the feed-forward sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, input:Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
        """
        input: expected shape ``[seq_len, batch_size, embedding_dim]``.
        attn_mask: optional mask passed unchanged to ``nn.MultiheadAttention``.

        Returns a tensor with the same shape as ``input``.
        """
        # nn.MultiheadAttention returns (output, weights); only the output is used.
        attn_out, _ = self.attention(input, input, input, attn_mask=attn_mask, need_weights=False)
        x = self.norm(input + self.dropout(attn_out))

        ff_out = self.ff(x)
        return self.ff_norm(x + self.dropout(ff_out))




class TransformerModel(nn.Module):
    """Transformer encoder over sequences of feature vectors.

    Projects ``input_dim`` features to ``embedding_dim``, optionally adds
    sinusoidal positional encodings, and runs a stack of ``TransformerBlock``s.
    Tensors use the sequence-first layout ``[seq_len, batch_size, dim]``.
    """

    def __init__(self, input_dim: int, num_layers: int, embedding_dim: int = 512, num_heads: int = 8,
                 ff_hidden_dim: int = 2048, dropout: float = 0.1, add_positional: bool = True,
                 max_len: int = 5000):
        """
        input_dim: size of each input feature vector (last input dimension).
        num_layers: number of stacked TransformerBlocks.
        embedding_dim: model width used by every block.
        num_heads: attention heads per block.
        ff_hidden_dim: hidden width of each block's feed-forward sublayer.
        dropout: dropout probability for the input projection and for every block.
        add_positional: whether forward() adds positional encodings.
        max_len: longest sequence the positional-encoding table supports.
        """
        super().__init__()
        # Stored on the module so forward() can read it.
        self.add_positional = add_positional

        # Input dropout, then a linear projection from input_dim to embedding_dim.
        self.embedding_net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim, embedding_dim),
        )
        self.pos_encoder = PositionalEncoding(embedding_dim, max_len)

        # dense2 == embedding_dim keeps each block's residual connections shape-compatible;
        # dropout is forwarded so the constructor argument actually reaches the blocks.
        self.layers = nn.ModuleList([
            TransformerBlock(embedding_dim=embedding_dim, num_heads=num_heads,
                             dense1=ff_hidden_dim, dense2=embedding_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, input: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
        """
        Arguments:
            input: Tensor, shape ``[seq_len, batch_size, input_dim]``
            src_mask: optional attention mask (e.g. ``[seq_len, seq_len]``), passed to every layer

        Returns:
            output Tensor of shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = self.embedding_net(input)
        if self.add_positional:
            # pos_encoder returns [1, seq_len, embedding_dim]; transpose it to
            # [seq_len, 1, embedding_dim] so it broadcasts over the batch axis.
            x = x + self.pos_encoder(x.size(0)).transpose(0, 1)

        # nn.ModuleList is not callable; apply the blocks one after another.
        for layer in self.layers:
            x = layer(x) if src_mask is None else layer(x, src_mask)
        return x




In [None]:
class JoinerNetwork(nn.Module):
    """Combine two representations into per-position output scores.

    ``audio_out`` and ``label_out`` are summed elementwise (so their shapes must
    broadcast together and end in ``embedding_dim``). The sum goes through a ReLU
    and is then linearly projected to ``num_outputs`` values.
    """

    def __init__(self, embedding_dim, num_outputs):
        """
        embedding_dim: size of the last dimension of audio_out / label_out.
        num_outputs: number of output values per position.
        """
        super().__init__()
        self.linear = nn.Linear(embedding_dim, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, audio_out, label_out) -> Tensor:
        combined = audio_out + label_out
        return self.linear(self.relu(combined))

In [None]:
# TransformerModel's default embedding_dim, made explicit so the joiner can match it.
ENCODER_DIM = 512

model = TransformerModel(input_dim=50, num_layers=8, embedding_dim=ENCODER_DIM)
print(model)

# JoinerNetwork adds its two inputs before its Linear layer, so its embedding_dim
# must equal the encoder's output width (ENCODER_DIM), not the 50-dim input features.
g = JoinerNetwork(ENCODER_DIM, 8)
print(g)

TransformerModel(
  (embedding_net): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=50, out_features=512, bias=True)
  )
  (pos_encoder): PositionalEncoding()
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): FeedForward(
        (ff): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=2048, out_features=512, bias=True)
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
JoinerNetwork(
  (linear): Linear(in_features=50, out_features=8, bias=True)
  (relu): ReLU()
)
