In [1]:
import torch
from torch import nn
from pathlib import Path
import random

# from tokenizers import Tokenizer
from torch.utils.data import Dataset, DataLoader

In [None]:
#Data
# from dotenv import load_dotenv
import os

# load_dotenv()

# Read the Hugging Face access token from the environment instead of embedding it in the
# notebook, so the credential never ends up in saved or shared copies of this file.
# When HF_TOKEN is unset this is None; the Hugging Face libraries are documented to fall
# back to the token stored by `huggingface-cli login` in that case.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Load model directly
from transformers import AutoTokenizer

# Start from the GPT-2 tokenizer and register a dedicated padding token.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", token=HF_TOKEN)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Control tokens used to frame every transcript.
SOT = '<|startoftranscript|>'
EOT = '<|endoftranscript|>'
transcribe = '<|transcribe|>'
prev = '<|prev|>'

special_tokens_dict = {
    'additional_special_tokens': [SOT, EOT, transcribe, prev]
}

# Update the tokenizer with the new special tokens
tokenizer.add_special_tokens(special_tokens_dict)
# model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# tokenizer("hi")

In [None]:

# Install Weights & Biases, import it, and log in from the shell.
# wandb.init / wandb.log are used by the training loop further down.
!pip install wandb
import wandb
!wandb login

In [4]:
#Hyperparameters

# --- Optimisation / runtime ---
epochs = 10
batch_size = 64
max_lr = 2e-4
weight_decay_optim = 0.01
device = 'cuda:0'

# --- Sequence and vocabulary sizes ---
block_size = 64                  # frames per spectrogram and tokens per transcript
tgt_vocab_size = len(tokenizer)  # tokenizer length after the special tokens were added
max_t = 500                      # overwritten later by the measured maximum frame count

# --- Transformer shape ---
embeddings_dims = 384
no_of_heads = 6           # IMP: needs to be thoroughly calculated
no_of_decoder_layers = 6  # IMP: needs to be thoroughly calculated
attn_dropout = 0.1
dropout = 0.1
kernel_size = 3
stride = (2, 10)

# --- Audio front end ---
SAMPLING_RATE = 16000
sr = 16000
N_MELS = 80               # 80-channel Mel spectrogram
log_mel_features = 80
n_channels = N_MELS
WINDOW_DURATION = 0.025   # 25 milliseconds
STRIDE_DURATION = 0.010   # 10 milliseconds

In [5]:
# Make `device` the default placement for tensors and modules created from here on.
torch.set_default_device(device)

In [None]:

# Install the Hugging Face `datasets` package and load the GigaSpeech corpus.
!pip install datasets
# NOTE(review): `verbose` is not referenced anywhere in the visible notebook; this looks
# like a stray auto-import. Confirm before removing.
from tabnanny import verbose
from datasets import load_dataset

gs = load_dataset("speechcolab/gigaspeech", "xs", token=HF_TOKEN, trust_remote_code=True) # Loads the "xs" configuration; later cells index its 'train', 'test' and 'validation' splits

# see structure
print(gs)

# load audio sample on the fly
audio_input = gs['train'][0]["audio"]  # "audio" field of the first training example
transcription = gs["train"][0]["text"]  # first transcription


In [None]:
# Display the first training example to inspect its fields.
gs['train'][0]

In [None]:
MAX_DURATION_IN_SECONDS = 10

import librosa
from tqdm import tqdm


def is_audio_length_in_range(input_length):
    """Return True when a clip is strictly shorter than MAX_DURATION_IN_SECONDS."""
    return input_length < MAX_DURATION_IN_SECONDS


def _clip_durations(split):
    """Return the duration of every clip in `split`, in dataset order."""
    durations = []
    for row in tqdm(range(len(split))):
        durations.append(librosa.get_duration(path=split[row]['audio']['path']))
    return durations


def _keep_short_clips(split, durations):
    """Attach `durations` as a "duration" column and drop clips that are too long."""
    with_duration = split.add_column("duration", durations)
    return with_duration.filter(is_audio_length_in_range, input_columns=["duration"])


# The 'test' split is used as the training set here because it has more rows.
train_new_column = _clip_durations(gs['test'])
gs_ = _keep_short_clips(gs['test'], train_new_column)
truncated_gs_train = gs_.remove_columns(["duration"])

# Same filtering for the validation split.
val_new_column = _clip_durations(gs['validation'])
gs_ = _keep_short_clips(gs['validation'], val_new_column)
truncated_gs_val = gs_.remove_columns(["duration"])

In [None]:

import numpy as np

# NOTE(review): both sizes are multiplied by MAX_DURATION_IN_SECONDS, so with the values
# configured above they are 10x the 25 ms / 10 ms that WINDOW_DURATION and STRIDE_DURATION
# describe. The dataset class later crops every spectrogram to `block_size` frames, so
# changing these values also changes how much audio each example covers. Confirm intent.
n_fft = int(WINDOW_DURATION * MAX_DURATION_IN_SECONDS * SAMPLING_RATE)
hop_length = int(STRIDE_DURATION * MAX_DURATION_IN_SECONDS * SAMPLING_RATE)

# Settings shared by every melspectrogram call below.
mel_kwargs = dict(
    sr=SAMPLING_RATE,
    n_mels=N_MELS,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=n_fft,
    fmax=SAMPLING_RATE // 2,
)

# Training examples: one dB-scaled mel spectrogram and one transcript per clip.
train_outputs, train_texts = [], []
for i in tqdm(range(len(truncated_gs_train))):
    example = truncated_gs_train[i]
    S = librosa.feature.melspectrogram(y=example['audio']['array'], **mel_kwargs)
    S_dB = librosa.power_to_db(S, ref=np.max)
    train_outputs.append(S_dB)
    train_texts.append(example['text'])

# Validation examples, processed the same way.
val_outputs, val_texts = [], []
for i in tqdm(range(len(truncated_gs_val))):
    example = truncated_gs_val[i]
    S = librosa.feature.melspectrogram(y=example['audio']['array'], **mel_kwargs)
    S_dB = librosa.power_to_db(S, ref=np.max)
    val_outputs.append(S_dB)
    val_texts.append(example['text'])

In [None]:
# Shape of the first training spectrogram.
train_outputs[0].shape

In [None]:

# Frame counts (second axis) of every spectrogram, training and validation together
import numpy as np

train_t_lengths = [spectrogram.shape[1] for spectrogram in train_outputs + val_outputs]

# Longest spectrogram in the dataset; this replaces the `max_t` hyperparameter set earlier
max_t = max(train_t_lengths)
print(f"Maximum t in the dataset: {max_t}")

# Mean frame count. The label says "training" but the list covers both splits.
avg_t_train = np.mean(train_t_lengths)

print(f"Average t (training): {avg_t_train}")
# print(f"Average t (validation): {avg_t_val}")

In [None]:
import re

# Example text
text = "AS THEY'RE LEAVING <COMMA> CAN KASH PULL ZAHRA ASIDE REALLY QUICKLY <QUESTIONMARK>"

# Strip every <...> tag; the whitespace around each tag is left in place
tag_pattern = re.compile(r'<[^>]*>')
cleaned_text = tag_pattern.sub('', text)

print(cleaned_text)

In [None]:
# First raw training transcript, before the <...> tags are stripped.
train_texts[0]

In [14]:
# import math
# print(round(random.random(), 1))
class GigaSpeechDataset(Dataset):
  """Pairs each dB-scaled mel spectrogram with its tokenized transcript.

  Every item is a `(spectrogram, token_ids)` tuple:
    * `spectrogram`: float32 tensor, cropped or padded to `block_size` frames along its
      second axis and min-max scaled to [-1, 1].
    * `token_ids`: 1-D tensor of `block_size` token ids for
      `<|startoftranscript|>en<|transcribe|>{text}<|endoftranscript|>`, truncated or
      padded by the module-level `tokenizer`.
  """

  def __init__(self, outputs, texts):
    """Store the spectrograms (`outputs`) and their transcripts (`texts`)."""
    self.data = outputs
    self.texts = texts
    self.max_t = block_size

  def __len__(self):
    return len(self.data)

  def pad_to_max_t(self, spectrogram, max_t, pad_value=None):
    """Crop or right-pad a 2-D spectrogram to exactly `max_t` frames.

    Args:
      spectrogram: 2-D numpy array; the second axis is the one cropped or padded.
      max_t: target number of frames.
      pad_value: fill value for padded frames. Defaults to the spectrogram's own
        minimum. The spectrograms in this notebook come from
        `power_to_db(..., ref=np.max)`, where 0 dB is the peak value, so padding with
        zeros would append full-scale frames rather than silence.

    Returns:
      Array whose second axis has length `max_t`.
    """
    n_mels, t = spectrogram.shape
    if t >= max_t:
      return spectrogram[:, :max_t]

    if pad_value is None:
      # An empty array has no minimum; fall back to 0 there.
      pad_value = float(spectrogram.min()) if spectrogram.size else 0.0
    pad_width = ((0, 0), (0, max_t - t))
    return np.pad(spectrogram, pad_width, mode='constant', constant_values=pad_value)

  def clean(self, desc):
    """Remove every `<...>` tag from `desc`; surrounding whitespace is kept."""
    cleaned_text = re.sub(r'<[^>]*>', '', desc)
    return cleaned_text

  def __getitem__(self, idx):
    SOT = '<|startoftranscript|>'
    EOT = '<|endoftranscript|>'
    transcribe = '<|transcribe|>'

    # Audio side: fixed number of frames, then scale to [-1, 1].
    spectrogram = self.pad_to_max_t(self.data[idx], self.max_t)
    spectrogram = torch.tensor(spectrogram, dtype=torch.float32)
    spectrogram_min = spectrogram.min()
    spectrogram_max = spectrogram.max()
    epsilon = 1e-8  # To avoid division by zero on constant spectrograms
    spectrogram = 2 * ((spectrogram - spectrogram_min) / (spectrogram_max - spectrogram_min + epsilon)) - 1

    # Text side: strip tags, lowercase, wrap in control tokens, tokenize to block_size ids.
    text = self.clean(self.texts[idx]).lower()
    text = SOT + 'en' + transcribe + text + EOT
    # NOTE(review): truncation=True can cut the trailing EOT token off long transcripts.
    tokenized_text = tokenizer(text, truncation=True, padding='max_length', max_length=block_size, return_tensors='pt')['input_ids']

    # The tokenizer returns a (1, block_size) tensor; drop the leading axis.
    tokenized_text = tokenized_text.squeeze(0)
    return spectrogram, tokenized_text


In [15]:


shuffle = True

train_dataset = GigaSpeechDataset(train_outputs, train_texts)
val_dataset = GigaSpeechDataset(val_outputs, val_texts)

# Random generator created on `device` and shared by both loaders.
generator = torch.Generator(device=device)

# Both loaders use the same settings: fixed-size batches, shuffled, and the final
# incomplete batch dropped so every batch has exactly `batch_size` items.
loader_kwargs = dict(
    batch_size=batch_size,
    generator=generator,
    shuffle=shuffle,
    drop_last=True,
)

train_dataloader = DataLoader(train_dataset, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, **loader_kwargs)


In [None]:

# Pull one batch from the training loader and show the shape of its token-id tensor.
spec, texts = next(iter(train_dataloader))

texts.shape
# spec.shape

In [None]:


# NOTE: `S_dB` is whatever the last iteration of the spectrogram loops left behind;
# its length is the size of its first axis.
len(S_dB)

In [None]:
# Tokenizer length; the same expression is used for `tgt_vocab_size`.
len(tokenizer)

In [19]:
from typing import Sequence


class PositionEmbeddings(nn.Module):
    """Fixed sinusoidal positional encodings.

    Position `p` and channel pair `i` are encoded as
        PE[p, 2i]   = sin(p / 10000 ** (2i / d_model))
        PE[p, 2i+1] = cos(p / 10000 ** (2i / d_model))

    The previous implementation used the embedding index in place of the position, so
    every position received the same vector. It also tied its output to a fixed
    `(batch_size, block_size, embeddings_dims)` shape, and its inference path returned
    `None` after overwriting the input in place.

    A table for `block_size` positions is cached as a non-persistent buffer, so it
    follows `.to(device)` without adding an entry to the state dict.
    """

    def __init__(self):
        super().__init__()
        self.d_model = embeddings_dims
        table = self._sinusoid_table(block_size, embeddings_dims, device=device)
        self.register_buffer("pe", table, persistent=False)

    @staticmethod
    def _sinusoid_table(seq_len, d_model, device=None):
        """Build the `(seq_len, d_model)` float32 table of sinusoidal encodings."""
        positions = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        # One frequency per (even, odd) channel pair: 10000 ** (-2i / d_model).
        pair_index = torch.arange(0, d_model, 2, dtype=torch.float32, device=device)
        inv_freq = 10000.0 ** (-pair_index / d_model)
        angles = positions * inv_freq  # (seq_len, ceil(d_model / 2))

        table = torch.zeros(seq_len, d_model, dtype=torch.float32, device=device)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd channel than even channels.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return table

    def pe_for_inference(self, x):
        """Return encodings shaped like `x` (batch, seq_len, dims), computed on the fly.

        `x` itself is not modified.
        """
        batch, seq_len, dims = x.shape
        table = self._sinusoid_table(seq_len, dims, device=x.device)
        return table.unsqueeze(0).expand(batch, -1, -1)

    def forward(self, x, inference=False):
        """Return the positional encodings to add to `x`.

        Args:
            x: tensor of shape (batch, seq_len, dims); only its shape and device are used.
            inference: when True, always recompute the table for `x`'s shape instead of
                slicing the cached one.

        Returns:
            Tensor with the same shape as `x`, identical across the batch axis.
        """
        batch, seq_len, dims = x.shape
        # Recompute when the cached table cannot serve this shape.
        if inference or seq_len > self.pe.size(0) or dims != self.pe.size(1):
            return self.pe_for_inference(x)
        return self.pe[:seq_len].unsqueeze(0).expand(batch, -1, -1)


In [20]:
# c = torch.arange(0, block_size)
# odd = c[1::2]
# even = c[::2]
# res = torch.empty((odd.size(0) + even.size(0),), dtype=odd.dtype)
# res[::2] = even
# res[1::2] = odd



In [24]:


# Text embeddings
class TgtTextEmbeddings(nn.Module):
    """Lookup table mapping target token ids to embedding vectors.

    Args:
        vocab_size: number of rows in the table (one per token id).
        embeddings_dims: width of each embedding vector.
    """

    def __init__(
        self,
        vocab_size = tgt_vocab_size,
        embeddings_dims = embeddings_dims
    ):
        super().__init__()
        # Size the table from the `vocab_size` argument; it was previously ignored in
        # favour of the module-level `tgt_vocab_size`.
        self.embeddings_table = nn.Embedding(num_embeddings = vocab_size, embedding_dim=embeddings_dims, device=device)
        # nn.init.normal_(self.embeddings_table.weight.data, mean=0, std=0.02)

    def forward(self, x):
        """Return the embeddings for the token ids in `x`."""
        return self.embeddings_table(x)

In [26]:



#Layer Normalization

class LayerNormalization(nn.Module):
    """Thin wrapper around `nn.LayerNorm` with `embeddings_dims` as the normalized shape."""

    def __init__(self, embeddings_dims = embeddings_dims):
        super().__init__()
        self.norm = nn.LayerNorm(embeddings_dims)

    def forward(self, x):
        """Apply layer normalization to `x`."""
        normalized = self.norm(x)
        return normalized

In [27]:


#FeedForward Neural Network

class MLPBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        dropout: dropout probability applied to the block's output.
        embeddings_size: input and output width; the hidden layer is four times wider.
    """

    def __init__(
        self,
        dropout = dropout,
        embeddings_size = embeddings_dims,
        # inner_dimensional_states: int = 3072
    ):
        super().__init__()

        # Derive the hidden width from the argument; it was previously taken from the
        # module-level `embeddings_dims`, which breaks for any other `embeddings_size`.
        hidden_size = 4 * embeddings_size
        self.mlp = nn.Sequential(
            nn.Linear(device=device, in_features=embeddings_size, out_features=hidden_size),
            nn.GELU(),
            nn.Linear(device=device, in_features=hidden_size, out_features=embeddings_size),
            nn.Dropout(p = dropout)
        )

    def forward(self, x):
        """Apply the feed-forward stack to `x`."""
        # mlp_weights_init = self.mlp.apply(weights_init)
        return self.mlp(x)

In [28]:


class MaskedAttentionHead(nn.Module):
    """One causal self-attention head.

    Projects the input to `embeddings_dims // no_of_heads` features for queries, keys
    and values, and lets each position attend only to itself and earlier positions.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)

    def forward(self, x):
        """Return the attended values for `x` of shape (batch, seq_len, embeddings_dims)."""
        _, seq_len, _ = x.shape
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)

        # Lower-triangular matrix: entry (i, j) is 1 when position i may look at j.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=device))

        # Scaled dot-product scores; disallowed positions get -inf before the softmax.
        scale = k.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale
        scores = scores.masked_fill(causal == 0, float('-inf'))

        attention = self.dropout(nn.functional.softmax(scores, dim=-1))
        return attention @ v


In [29]:



class MaskedMHA(nn.Module):
    """Multi-head causal self-attention.

    Runs `no_of_heads` MaskedAttentionHead modules on the same input, concatenates their
    outputs along the last axis, then applies an output projection and dropout.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        head_modules = []
        for _ in range(no_of_heads):
            head_modules.append(
                MaskedAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
            )
        self.heads = nn.ModuleList(head_modules)
        self.dropout = nn.Dropout(p = attn_dropout)
        # Output projection over the concatenated head outputs.
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x):
        """Apply all heads to `x` and merge their outputs."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.linear(merged))

In [30]:

#Single Attention Head

class CrossAttentionHead(nn.Module):
    """One cross-attention head with learned query/key projections.

    The previous forward pass never applied the projection layers it constructed, so
    every head in a CrossMHA computed exactly the same output. The query and key
    projections are now applied before the dot product, which lets each head learn its
    own attention pattern.

    The attended `value` is deliberately returned at its original width:
    CrossMHA concatenates the head outputs and feeds them to a linear layer sized for
    full-width head outputs, and that layer serves as the per-head value/output
    projection. `self.values` is therefore kept only so the module's parameter set is
    unchanged; it is not used in `forward`.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device,bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)


    def forward(self, query, key, value, mask=None):
        """Attend from `query` positions over `key` positions and mix `value` rows.

        Args:
            query: (batch, q_len, embeddings_dims) tensor.
            key: (batch, kv_len, embeddings_dims) tensor.
            value: (batch, kv_len, width) tensor; returned width equals `width`.
            mask: optional tensor; after `unsqueeze(1)` it must broadcast against the
                (batch, q_len, kv_len) score matrix. Zero entries are masked out.

        Returns:
            (batch, q_len, width) tensor, with dropout applied.
        """
        q = self.query(query)
        k = self.keys(key)

        # Scaled dot-product scores between projected queries and keys.
        weights = (q @ torch.transpose(k, dim0=-2, dim1=-1)) * (k.shape[-1] ** -0.5)
        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        weights_normalized = nn.functional.softmax(weights, dim=-1)
        out = weights_normalized @ value
        return self.dropout(out)

In [31]:
#Single Attention Head

class FullAttentionHead(nn.Module):
    """One bidirectional (non-causal) self-attention head with an optional mask.

    Queries, keys and values are projections of the same input to
    `embeddings_dims // no_of_heads` features.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.head_size = embeddings_dims // no_of_heads
        self.query = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device, bias=False)
        self.keys = nn.Linear(in_features=embeddings_dims, out_features=self.head_size,device=device, bias=False)
        self.values = nn.Linear(in_features=embeddings_dims, out_features=self.head_size, device=device,bias=False)
        self.dropout = nn.Dropout(p = attn_dropout)


    def forward(self, x, mask=None):
        """Return the attended values for `x`.

        Args:
            x: (batch, seq_len, embeddings_dims) tensor.
            mask: optional tensor; after `unsqueeze(1)` it must broadcast against the
                (batch, seq_len, seq_len) score matrix. Zero entries are masked out.

        Returns:
            (batch, seq_len, head_size) tensor, with dropout applied to the output.
        """
        k = self.keys(x)
        q = self.query(x)
        v = self.values(x)

        weights = q @ torch.transpose(k, dim0=-2, dim1=-1) * (k.shape[-1] ** -0.5)
        # `is not None` is the correct test for an optional tensor argument; the two
        # previously duplicated branches differed only in this masking step.
        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))

        weights_normalized = nn.functional.softmax(weights, dim=-1)
        out = weights_normalized @ v
        return self.dropout(out)

In [32]:

class FullMHA(nn.Module):
    """Multi-head bidirectional self-attention.

    Runs `no_of_heads` FullAttentionHead modules on the same input, concatenates their
    outputs along the last axis, then applies an output projection and dropout.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        head_modules = []
        for _ in range(no_of_heads):
            head_modules.append(
                FullAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
            )
        self.heads = nn.ModuleList(head_modules)
        self.dropout = nn.Dropout(p = attn_dropout)
        # Output projection over the concatenated head outputs.
        self.linear = nn.Linear(in_features=embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, x, mask=None):
        """Apply all heads to `x` (with the optional `mask`) and merge their outputs."""
        per_head = [head(x, mask) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.linear(merged))

In [33]:


class CrossMHA(nn.Module):
    """Multi-head cross-attention.

    Runs `no_of_heads` CrossAttentionHead modules on the same (query, key, value)
    triple, concatenates their outputs along the last axis, then projects back to
    `embeddings_dims` and applies dropout.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
    ):
        super().__init__()
        self.heads = nn.ModuleList([CrossAttentionHead(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads) for _ in range(no_of_heads)])
        self.dropout = nn.Dropout(p = attn_dropout)
        # Each head returns a tensor as wide as the value tensor it is given (assumed to
        # be `embeddings_dims` wide here), so the concatenation has
        # no_of_heads * embeddings_dims features. This was previously sized with
        # `no_of_decoder_layers`, which only worked while the layer count happened to
        # equal the head count.
        self.linear = nn.Linear(in_features=no_of_heads * embeddings_dims, out_features=embeddings_dims, device=device, bias=False)

    def forward(self, query, key, x, mask=None):
        """Apply every head to (query, key, x) and merge the results.

        `x` is passed to each head as its value tensor.
        """
        concat = torch.cat([head(query, key, x,  mask) for head in self.heads], dim=-1)
        linear_layer = self.linear(concat)
        out = self.dropout(linear_layer)
        return out

In [34]:
# Decoder Block

class TransformerDecoderBlock(nn.Module):
    """Post-norm decoder layer: causal self-attention, cross-attention, feed-forward.

    Each sub-layer's output is added to its input and the sum is layer-normalized.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
        # vocab_size = vocab_size
    ):
        super().__init__()

        self.cross = CrossMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.masked = MaskedMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.layer_norm1 = LayerNormalization(embeddings_dims)
        self.layer_norm2 = LayerNormalization(embeddings_dims)
        # self.layer_norm3 = LayerNormalization(embeddings_dims=embeddings_dims)
        self.layer_norm4 = LayerNormalization(embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, key, value, x, mask=None):
        """Run one decoder layer.

        Args:
            key: encoder output used as cross-attention keys.
            value: encoder output used as cross-attention values.
            x: decoder hidden states, (batch, tgt_len, embeddings_dims).
            mask: optional mask forwarded to the cross-attention heads.

        Returns:
            Updated decoder hidden states with the same shape as `x`.
        """
        # 1) Causal self-attention over the decoder states.
        x = self.layer_norm1(x + self.masked(x))
        # 2) Cross-attention: the decoder states are the queries and the encoder output
        #    supplies keys and values. The operands were previously swapped (encoder as
        #    query, decoder states as value), which mixed decoder states across all
        #    target positions without a causal mask and never read the encoder content.
        x = self.layer_norm2(x + self.cross(x, key, value, mask))
        # 3) Position-wise feed-forward network.
        x = self.layer_norm4(x + self.mlp_block(x))

        return x

In [35]:
# Decoder Block

class DecoderModel(nn.Module):
    """Stack of TransformerDecoderBlock layers over already-embedded target tokens.

    `forward` adds positional encodings to its input and runs the layers in order; it
    does not apply `tgt_text_embds` or `linear_layer` itself.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        no_of_decoder_layers = no_of_decoder_layers,
        # vocab_size = vocab_size
    ):
        super().__init__()

        # NOTE(review): neither of these two modules is used in forward(); the
        # Transformer wrapper owns its own embedding table and output projection.
        self.tgt_text_embds = TgtTextEmbeddings(vocab_size=tgt_vocab_size, embeddings_dims=embeddings_dims)
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=tgt_vocab_size, device=device, bias=False)

        layers = []
        for _ in range(no_of_decoder_layers):
            layers.append(
                TransformerDecoderBlock(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, dropout=dropout)
            )
        self.decoder_layers = nn.ModuleList(layers)

        # Initialise every Linear / Embedding registered so far.
        self.apply(self._init_weights)
        self.positional_embeddings_tgt = PositionEmbeddings()

    def _init_weights(self, module):
        """Normal(0, 0.02) weights for Linear and Embedding modules; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, key, value, x, mask):
        """Add positional encodings to `x`, then apply each decoder layer in turn.

        `key`, `value` and `mask` are handed unchanged to every layer.
        """
        hidden = x + self.positional_embeddings_tgt(x)
        for layer in self.decoder_layers:
            hidden = layer(key, value, hidden, mask)
        return hidden

In [36]:

#Encoder

In [37]:



class TransformerEncoderBlock(nn.Module):
    """Post-norm encoder layer: bidirectional self-attention followed by a feed-forward net.

    The `mask` constructor argument is accepted but not stored; masking is supplied per
    call through `forward`.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        dropout = dropout,
        mask=None
    ):
        super().__init__()

        self.mha = FullMHA(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads)
        self.layer_norm1 = LayerNormalization(embeddings_dims)
        self.layer_norm2 = LayerNormalization(embeddings_dims)
        self.mlp_block = MLPBlock(dropout=dropout, embeddings_size=embeddings_dims)

    def forward(self, x, mask=None):
        """Run self-attention and the MLP, each wrapped in residual + layer norm."""
        attended = self.layer_norm1(x + self.mha(x, mask))
        return self.layer_norm2(attended + self.mlp_block(attended))

In [38]:



class EncoderModel(nn.Module):
    """Audio encoder: two Conv1d + GELU stages followed by a stack of encoder layers.

    The convolutions map `n_channels` input channels to `embeddings_dims` channels; the
    result is permuted so channels become the last axis before positional encodings and
    the transformer layers are applied.
    """

    def __init__(
        self,
        attn_dropout = attn_dropout,
        embeddings_dims = embeddings_dims,
        no_of_heads = no_of_heads,
        block_size = block_size,
        dropout = dropout,
        no_of_decoder_layers = no_of_decoder_layers,
        # vocab_size = vocab_size
    ):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=n_channels, out_channels=embeddings_dims, kernel_size=kernel_size, device=device, padding=1)
        self.conv2 = nn.Conv1d(in_channels=embeddings_dims, out_channels=embeddings_dims, kernel_size=kernel_size, device=device, padding=1)

        self.positional_embeddings_src = PositionEmbeddings()

        # The encoder depth reuses the `no_of_decoder_layers` setting.
        layers = []
        for _ in range(no_of_decoder_layers):
            layers.append(
                TransformerEncoderBlock(attn_dropout=attn_dropout, embeddings_dims=embeddings_dims, no_of_heads=no_of_heads, dropout=dropout)
            )
        self.encoder_layers = nn.ModuleList(layers)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) weights for Linear and Embedding modules; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x, mask):
        """Encode a batch of spectrograms.

        Args:
            x: (batch, n_channels, frames) tensor.
            mask: optional mask handed to every encoder layer.

        Returns:
            (batch, frames', embeddings_dims) tensor of encoder states.
        """
        features = torch.nn.functional.gelu(self.conv1(x))
        features = torch.nn.functional.gelu(self.conv2(features))

        # (batch, channels, frames) -> (batch, frames, channels) for the attention layers.
        hidden = features.permute(0, 2, 1)
        hidden = hidden + self.positional_embeddings_src(hidden)

        for layer in self.encoder_layers:
            hidden = layer(hidden, mask)
        return hidden



In [39]:


class Transformer(nn.Module):
    """Encoder-decoder model mapping a spectrogram batch and target token ids to logits."""

    def __init__(
        self,

    ):
        super().__init__()

        self.encoder = EncoderModel()
        self.decoder = DecoderModel()
        self.tgt_text_embds = TgtTextEmbeddings(vocab_size=tgt_vocab_size, embeddings_dims=embeddings_dims)
        # Projects decoder states of width embeddings_dims to tgt_vocab_size logits.
        self.linear_layer = nn.Linear(in_features=embeddings_dims, out_features=tgt_vocab_size, device=device, bias=False)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Return vocabulary logits for every target position.

        Args:
            src: spectrogram batch passed to the encoder.
            tgt: target token ids passed through the embedding table.
            src_mask: optional mask forwarded to the encoder.
            tgt_mask: accepted but not used; the decoder is always called with mask=None.
        """
        memory = self.encoder(src, src_mask)
        tgt_embeddings = self.tgt_text_embds(tgt)
        # The encoder output serves as both the keys and the values of the decoder.
        decoded = self.decoder(memory, memory, tgt_embeddings, None)
        return self.linear_layer(decoded)



In [40]:
#Instantiating the model
# Build the encoder-decoder model, wrap it with torch.compile, then move it to `device`.
model = Transformer()
model = torch.compile(model)
# model = model.to(device)
model = model.to(device)


In [None]:

# Install torchinfo and print a per-layer summary of the model for one real batch.
!pip install torchinfo
from torchinfo import summary

# Fetch a single training batch to use as example input.
spec, text = next(iter(train_dataloader))
spec = spec.to(device)
# NOTE: this rebinds the notebook-level name `texts` to this batch's token ids.
texts = text.to(device)

summary(model=model,
        input_data=(spec, texts),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

In [42]:

# Optimizer and loss setup
# out = {"Train": None, "val": None}
# Pass the configured weight decay explicitly; `weight_decay_optim` was previously
# defined but never handed to the optimizer, so changing it had no effect.
optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay_optim)

# NOTE(review): the training loop calls nn.functional.cross_entropy directly, so
# `loss_fn` is not used in the visible notebook.
loss_fn = nn.CrossEntropyLoss()


In [43]:
# Select the 'high' float32 matmul precision setting.
torch.set_float32_matmul_precision('high')

# Gradient scaler; the training loop uses it to scale the loss before backward()
# and to step the optimizer.
scaler = torch.amp.GradScaler(enabled=True)

In [None]:
# torch.autocast is documented to take a device *type* ('cuda', 'cpu', ...), not a full
# device string such as 'cuda:0'.
autocast_device_type = torch.device(device).type
pad_id = tokenizer.pad_token_id


def next_token_targets(tokens):
    """Return `tokens` shifted left by one position, with a pad id appended.

    The logits at position t are then scored against token t+1. Scoring them against
    token t (as before) asks the decoder to reproduce a token it already receives as
    input. The appended pad id is skipped by `ignore_index`, so the tensor keeps the
    same (batch, seq_len) shape as the logits.
    """
    pad_column = torch.full_like(tokens[:, :1], pad_id)
    return torch.cat([tokens[:, 1:], pad_column], dim=1)


model.train()
train_losses = torch.zeros(len(train_dataloader))
val_losses = torch.zeros(len(val_dataloader))
wandb.init(
    project='Whisper-From-Scratch'
)
for epoch in range(epochs):

    print("Starting train...")
    for count, (X, y) in enumerate(train_dataloader):
        X = X.to(device)
        y = y.to(device)

        with torch.autocast(device_type=autocast_device_type, dtype=torch.float16):
            logits = model(X, y)
            # Local names: unpacking into `batch_size` / `block_size` would overwrite
            # the notebook-level hyperparameters.
            n_batch, n_tokens, vocab = logits.shape
            loss = nn.functional.cross_entropy(
                logits.reshape(n_batch * n_tokens, vocab),
                next_token_targets(y).reshape(n_batch * n_tokens),
                ignore_index=pad_id,
            )
        train_losses[count] = loss.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    model.eval()
    print("Starting val...")
    # No gradients are needed for validation; skipping them avoids building autograd
    # graphs for every validation batch.
    with torch.no_grad():
        for count, (X, y) in enumerate(val_dataloader):
            X = X.to(device)
            y = y.to(device)

            logits = model(X, y)
            n_batch, n_tokens, vocab = logits.shape
            loss = nn.functional.cross_entropy(
                logits.reshape(n_batch * n_tokens, vocab),
                next_token_targets(y).reshape(n_batch * n_tokens),
                ignore_index=pad_id,
            )
            val_losses[count] = loss.item()

    # print("eval")
    # print("Generating text...")
    # generated_text = topk_sampling(model, 'Ich fahre heute mit dem Rad zur Schule', de_tokenizer, device=ModelArgs.device, max_length=50, top_k=50, temperature=1.0)

    # print(generated_text)

    model.train()
    # Log plain Python floats rather than tensors.
    epoch_train_loss = train_losses.mean().item()
    epoch_val_loss = val_losses.mean().item()
    wandb.log({
        "Train Loss": epoch_train_loss,
        "Val Loss": epoch_val_loss,
        "epoch": epoch
    })
    print("Epoch: ", epoch, "|", "Train Loss: ", epoch_train_loss, "|", "Val Loss: ", epoch_val_loss)
