### Extract the dataset (Google Colab)

In [None]:
# with zipfile.ZipFile("wav.zip", "r") as zip_ref:
#     zip_ref.extractall("/content/audio")

In [None]:
# with zipfile.ZipFile("original_txt.zip", "r") as zip_ref:
#     zip_ref.extractall("/content/transcript")

### Import libraries

In [None]:
import re
import os
import sys
import json
import math
import time
import zipfile
import threading

import fitz
import torch
import torchaudio
import pandas as pd
import numpy as np
import torch.nn as nn
import soundfile as sf
import torch.nn.functional as F
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from torch.amp import autocast, GradScaler
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from transformers import get_cosine_schedule_with_warmup
from multiprocessing import Pool, cpu_count
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict, Counter
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [None]:
%load_ext cython

In [None]:
# Make CUDA kernel launches synchronous so device-side errors surface at the failing call.
# NOTE(review): this is a debugging aid that slows execution; consider removing it for full training runs.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [None]:
#torchaudio.set_audio_backend("sox_io")

### Check device

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # shown as the cell output in the notebook

### Tokenizer

#### Base Tokenizer

In [None]:
class BaseTokenizer:
    """
    Base tokenizer inherited by the concrete tokenizers in this notebook.

    Holds the shared tokenizer state (vocabulary, inverse vocabulary, merge
    list, metadata) plus helpers for saving/loading, decoding, padding,
    truncation and assembling a training corpus. Subclasses are expected to
    provide ``encode`` and ``build_vocab``.
    """

    SPECIAL_TOKENS = {
        "pad": "<PAD>",
        "unk": "<UNK>",
        "bos": "<BOS>",
        "eos": "<EOS>",
        "book_name": "<book_name>",
        # Key spelling kept as-is: it is persisted in saved tokenizer files.
        "seperator": "</W>"
    }

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Object exposing ``vocab_size`` and ``max_merges``; ``pad``
                additionally reads ``config.n_ctx`` as its default length.
            **kwargs: Accepted for subclass compatibility; unused here.
        """
        self.config = config
        self.vocab_size = config.vocab_size
        self.max_merges = config.max_merges

        # Copy so per-instance changes never leak into the class-level default.
        self.special_tokens = dict(self.SPECIAL_TOKENS)
        self.vocab = {}         # Actual vocabulary of the tokenizer (token -> ID)
        self.inv_vocab = {}     # Inverse of vocabulary (ID -> token), used for decoding
        self.merges = []        # List of all merges made during vocabulary building
        self.metadata = {       # Tracks tokenizer state for checkpointing
            "merge_count": 0,
            "books_used": None,
        }
        self.padding_token = None

    def save(self, filepath: str):
        """Save metadata, vocab, merges and special tokens to ``filepath`` as JSON."""
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump({
                "metadata": self.metadata,
                "vocab": self.vocab,
                "merges": self.merges,
                "special_tokens": self.special_tokens,
            }, f, indent=2)

    def load(self, filepath: str):
        """Load tokenizer state written by ``save`` and rebuild the inverse vocab."""
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
            self.vocab = data["vocab"]
            # JSON stores tuples as lists; restore tuples.
            self.merges = [tuple(pair) for pair in data["merges"]]
            self.special_tokens = data["special_tokens"]
            self.metadata = data.get("metadata", {})
            self.inv_vocab = {v: k for k, v in self.vocab.items()}

    def decode(self, token_ids: List[int]) -> str:
        """
        Decode a list or tensor of token IDs back to a string.

        Unknown IDs map to the UNK token, which ``remove_special_tokens`` then
        strips. Standalone separator tokens are kept so that
        ``remove_special_tokens`` turns them into spaces. Dropping them would
        glue neighbouring words together.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        unk = self.special_tokens["unk"]
        decoded = "".join(self.inv_vocab.get(idx, unk) for idx in token_ids)
        return self.remove_special_tokens(decoded)

    def remove_special_tokens(self, decoded_output: str) -> str:
        """Drop BOS/EOS/PAD/UNK markers, turn separators into spaces and strip the result."""
        return (
            decoded_output
            .replace(self.special_tokens["bos"], "")
            .replace(self.special_tokens["eos"], "")
            .replace(self.special_tokens["pad"], "")
            .replace(self.special_tokens["unk"], "")
            .replace(self.special_tokens["seperator"], " ")
            .strip()
        )

    def add_special_tokens(self, text: str) -> str:
        """Wrap ``text`` with the BOS and EOS tokens, separated from it by single spaces."""
        return f"{self.special_tokens['bos']} {text} {self.special_tokens['eos']}"

    def pad(self, token_ids: List[int], max_length: int = None) -> List[int]:
        """
        Right-pad ``token_ids`` with the pad ID up to ``max_length``.

        Args:
            token_ids: Token IDs to pad (a list; it is concatenated with a list).
            max_length: Target length; defaults to ``self.config.n_ctx``.

        Returns:
            A new padded list, or ``token_ids`` unchanged if it is already at
            least ``max_length`` long (use ``truncate`` to shorten it).
        """
        if max_length is None:
            max_length = self.config.n_ctx
        pad_id = self.vocab.get(self.special_tokens["pad"], 0)
        missing = max_length - len(token_ids)
        return token_ids + [pad_id] * missing if missing > 0 else token_ids

    def truncate(self, token_ids: List[int], max_length: int) -> List[int]:
        """Truncates a list of token IDs to max_length"""
        return token_ids[:max_length]

    def get_vocab_size(self) -> int:
        """Returns the size of the vocabulary"""
        return len(self.vocab)

    def set_padding_token(self, tokenID: int):
        """Sets the token ID used for padding"""
        self.padding_token = tokenID

    def encode_tensor(self, text) -> torch.Tensor:
        """Encode ``text`` with the subclass's ``encode`` and return a long tensor."""
        return torch.tensor(self.encode(text), dtype=torch.long)

    def _preprocess_text_to_corpus(self, text: str) -> Counter:
        """
        Convert raw text into a Counter of pre-tokenized words for BPE training.

        Each word becomes its characters joined by spaces, with the separator
        token glued to the last character (e.g. "cat" -> "c a t</W>").

        Args:
            text (str): Raw input text string.
        Returns:
            Counter: frequency map of pre-tokenized words.
        """
        sep = self.special_tokens["seperator"]
        words = text.strip().split()
        return Counter([" ".join(tuple(word)) + sep for word in words])

    @staticmethod
    def _get_text_from_corpus(corpus, book_names, max_books, max_characters):
        """
        Concatenate book texts from ``corpus`` for tokenizer training.

        Args:
            corpus (dict): book_name -> book_content.
            book_names (list | None): Only use these books (exclusive with ``max_books``).
            max_books (int | None): Use at most this many books.
            max_characters (int | None): Stop once this many characters are collected;
                the last book may be cut short.

        Returns:
            tuple: ``(combined_text, books_used, char_accum, book_count)`` where
            ``books_used`` maps each used book to the number of characters taken.

        Raises:
            ValueError: if both ``book_names`` and ``max_books`` are given.
        """
        if book_names is not None and max_books is not None:
            raise ValueError("Provide only one of `book_names` or `max_books`, not both.")

        char_accum = 0
        book_count = 0
        pieces = []
        books_used = {}

        for book, content in corpus.items():
            if book_names is not None and book not in book_names:
                continue
            if max_books is not None and book_count >= max_books:
                break
            if max_characters is not None:
                remaining = max_characters - char_accum
                if remaining <= 0:
                    break
                book_slice = content[:remaining]
            else:
                book_slice = content

            pieces.append(book_slice)
            char_accum += len(book_slice)
            books_used[book] = len(book_slice)
            book_count += 1

        # NOTE(review): books are joined with no delimiter, so the last word of one
        # book may fuse with the first word of the next — confirm this is acceptable.
        return "".join(pieces), books_used, char_accum, book_count

    def base_build_vocab(self,
                         corpus: dict,
                         folder_path: str,
                         book_names: list = None,
                         max_books: int = None,
                         max_characters: int = None,
                         checkpoint_interval: int = 5000,
                        ):
        """
        Train the tokenizer on a subset of books with flexible stopping conditions.

        Args:
            corpus (dict): Dictionary of book_name -> book_content.
            folder_path (str): Folder path to save tokenizer checkpoints (created if missing).
            book_names (list): List of specific books to use (exclusive with max_books).
            max_books (int): Number of books to include from ``corpus`` (exclusive with book_names).
            max_characters (int): Stop collecting text after this many characters.
            checkpoint_interval (int): Not forwarded; subclasses read their own interval
                from config.
        """
        combined_text, books_used, char_accum, book_count = self._get_text_from_corpus(
            corpus, book_names, max_books, max_characters
        )
        print(f"Total characters used: {char_accum:,}")
        print(f"Training on books: {list(books_used.keys())}")
        print("\n")

        os.makedirs(folder_path, exist_ok=True)

        self.build_vocab(
            text=combined_text,
            books_used=books_used,
            folder_path=folder_path,
        )

#### Wrapper class for the Cython tokenizer

In [None]:
class FastSubwordTokenizer(BaseTokenizer):
    """
    Python-facing wrapper around the Cython ``FastSubwordTokenizerCore``.

    Training, encoding, decoding and file I/O are delegated to the core. After
    training or loading, the core's vocabulary, merges and metadata are
    mirrored onto this object so the ``BaseTokenizer`` helpers see the same state.
    """

    def __init__(self, config):
        # BaseTokenizer.__init__ already sets vocab_size and max_merges from config.
        super().__init__(config)
        # The core receives (and keeps a reference to) this instance's special-token dict.
        self.core = FastSubwordTokenizerCore(self.special_tokens)

    def build_vocab(self, text: str, folder_path: str, books_used: dict = None):
        """
        Train the BPE vocabulary on ``text`` and mirror the result onto this wrapper.

        Args:
            text: Combined training text.
            folder_path: Folder where the core writes checkpoints and ``tokenizer.json``.
            books_used: Optional book-name -> character-count map recorded in metadata.
        """
        self.metadata.setdefault("merge_count", 0)
        self.metadata["books_used"] = books_used

        self.core.build_vocab(
            text=text,
            checkpoint_interval=self.config.tokenizer_checkpoint_interval,
            folder_path=folder_path,
            books_used=books_used,
            # The core's size cap is the ``max_vocab_size`` parameter; passing it under
            # any other keyword would land in the core's **kwargs.
            max_vocab_size=self.vocab_size,
            max_merges=self.max_merges,
            metadata=self.metadata,
        )

        # Safely access exposed Python properties
        self.vocab = self.core.py_vocab
        self.inv_vocab = self.core.py_inv_vocab
        self.merges = self.core.py_merges
        self.metadata = self.core.py_metadata

        # Record the actual vocabulary size on this wrapper (config is not modified).
        self.vocab_size = len(self.vocab)

    def encode(self, text):
        """Encode ``text`` with the core; returns whatever the core returns (a tensor of IDs)."""
        return self.core.encode(text)

    def encode_tensor(self, text) -> torch.Tensor:
        """Encode ``text`` as a long tensor without re-copying a tensor the core already built."""
        return torch.as_tensor(self.encode(text), dtype=torch.long)

    def decode(self, token_ids):
        """Decode IDs back to text via the core."""
        return self.core.decode(token_ids)

    def batch_encode(self, texts, num_workers=None):
        """Encode many texts in parallel via the core's process pool."""
        return self.core.batch_encode(texts, num_workers)

    def save(self, filepath: str):
        """Write tokenizer state to ``filepath``.

        Both the wrapper and the core write the same path; the core writes last,
        so the file ends up holding the core's state.
        """
        super().save(filepath)
        self.core.save(filepath)

    def load(self, filepath: str):
        """Load state from ``filepath`` into both the wrapper and the core."""
        super().load(filepath)
        vocab, inv_vocab, merges, metadata = self.core.load_from_file(filepath)
        self.vocab = vocab
        self.inv_vocab = inv_vocab
        self.vocab_size = len(vocab)
        # Keep merges as tuples, matching BaseTokenizer.load.
        self.merges = [tuple(pair) for pair in merges]
        self.metadata = metadata

    def save_pretrained(self, folder):
        """Save to ``<folder>/tokenizer.json``, creating ``folder`` if needed."""
        os.makedirs(folder, exist_ok=True)
        self.save(os.path.join(folder, "tokenizer.json"))

    @classmethod
    def from_pretrained(cls, config, folder: str = "./FastSubwordTokenizer", tokenizer: str = "tokenizer.json"):
        """Build a tokenizer from ``config`` and load ``<folder>/<tokenizer>``."""
        tok = cls(config)
        tok.load(os.path.join(folder, tokenizer))
        return tok


#### Cython-based tokenizer

In [None]:
%%cython
# distutils: language = c++
# cython: boundscheck=False, wraparound=False, initializedcheck=False

from libc.stdlib cimport malloc, free
import re
import os
import json
import time
import sys
import threading
import torch
from collections import Counter
from multiprocessing import cpu_count, Pool


# Utility: Get all adjacent symbol pairs in a list
cdef set get_pairs(list symbols):
    # Pair each symbol with its right-hand neighbour; fewer than two symbols
    # yields no pairs, and repeated pairs collapse because a set is returned.
    return set(zip(symbols, symbols[1:]))


cdef class FastSubwordTokenizerCore:
    """
    Core class for fast BPE-based subword tokenization using Cython.
    Handles vocabulary construction, encoding, decoding, and state persistence.

    Words are represented as one symbol per character, with the word separator
    glued to the final character (e.g. "cat" -> ["c", "a", "t</W>"]). Both
    ``build_vocab`` and ``encode`` use this same split so learned merges apply
    at encode time.
    """
    cdef dict vocab              # symbol -> ID
    cdef dict inv_vocab          # ID -> symbol
    cdef list merges             # list of merge operations (tuples)
    cdef dict special_tokens     # special token mappings
    cdef dict metadata           # metadata for checkpointing and resumption

    def __init__(self, special_tokens=None):
        """
        Initialize tokenizer core with optional special tokens.
        """
        self.vocab = {}
        self.inv_vocab = {}
        self.merges = []
        self.special_tokens = special_tokens if special_tokens is not None else {}
        self.metadata = {
            "merge_count": 0,
            "books_used": None,
        }


    def build_vocab(self, text: str, max_vocab_size=None, max_merges=None,
                    checkpoint_interval=1000, folder_path="./FastSubwordTokenizer", books_used=None, **kwargs):
        """
        Builds the BPE vocabulary from the input text.

        Args:
            text (str): Input text to build vocabulary from.
            max_vocab_size (int): Optional maximum number of vocab entries. A
                ``vocab_size`` keyword in ``**kwargs`` is accepted as an alias.
            max_merges (int): Optional cap on the merge count, including merges already
                recorded in ``metadata["merge_count"]``.
            checkpoint_interval (int): Save a checkpoint every this many merges
                (0/None disables checkpoints).
            folder_path (str): Folder for checkpoints and the final ``tokenizer.json``
                (created if missing).
            books_used (dict): Optional record of the source books, stored in metadata.
            **kwargs: Extra keyword arguments; only ``vocab_size`` is read.
        """
        cdef int idx
        cdef int merge_count

        # Alias: the Python wrapper historically passed the cap as ``vocab_size``.
        if max_vocab_size is None:
            max_vocab_size = kwargs.get("vocab_size")

        os.makedirs(folder_path, exist_ok=True)
        sep = self.special_tokens["seperator"]
        protected = set(self.special_tokens.values())

        stop_flag = False
        start_time = time.time()
        new_symbol = ""

        def status_thread():
            # Redraw a one-line progress display about once per second until stopped.
            while not stop_flag:
                elapsed = time.time() - start_time
                print(
                f"\rMERGES: {self.metadata['merge_count']:<10}  "
                f"VOCAB SIZE: {len(self.vocab):<10}  "
                f"CURRENT MERGE: {new_symbol:<15}  "
                f"EXEC TIME: {elapsed:>7.2f}s",
                end='', flush=True)
                time.sleep(1)

        thread = threading.Thread(target=status_thread)
        thread.daemon = True
        thread.start()

        try:
            # "cat" -> "c a t</W>": characters space-separated, separator glued to the last one.
            corpus = Counter([" ".join(tuple(word)) + sep for word in text.strip().split()])
            idx = len(self.vocab)
            merge_count = self.metadata.get("merge_count", 0)
            self.metadata["books_used"] = books_used

            # Initialize special and base vocabulary
            if not self.vocab:
                for token in self.special_tokens.values():
                    self.vocab[token] = idx
                    idx += 1

            symbols = set()
            for word in corpus:
                symbols.update(word.split())
            # Sorted so base-symbol IDs are reproducible across runs (set order is not).
            for s in sorted(symbols):
                if s not in self.vocab:
                    self.vocab[s] = idx
                    idx += 1

            # Main BPE merge loop
            while True:
                if max_vocab_size is not None and len(self.vocab) >= max_vocab_size:
                    break
                if max_merges is not None and merge_count >= max_merges:
                    break

                # Count adjacent pairs, skipping any pair that involves a special token.
                pairs = {}
                for word, word_freq in corpus.items():
                    for pair in get_pairs(word.split()):
                        if pair[0] in protected or pair[1] in protected:
                            continue
                        pairs[pair] = pairs.get(pair, 0) + word_freq
                if not pairs:
                    break
                best_pair, freq = max(pairs.items(), key=lambda item: item[1])
                if freq == 1:
                    break

                # Perform merge
                new_symbol = "".join(best_pair)
                # Different pairs can concatenate to the same string; reuse the existing ID
                # so IDs stay contiguous (every ID < len(vocab)).
                if new_symbol not in self.vocab:
                    self.vocab[new_symbol] = idx
                    idx += 1
                self.merges.append(best_pair)

                # Replace whole-symbol occurrences only: the bigram must be bounded by
                # whitespace or the string ends, so merging ("a", "b") leaves "ca b" alone.
                bigram = re.escape(" ".join(best_pair))
                pattern = re.compile(rf"(?<!\S){bigram}(?!\S)")
                new_corpus = {}
                for word, word_freq in corpus.items():
                    new_word = pattern.sub(new_symbol, word)
                    new_corpus[new_word] = new_corpus.get(new_word, 0) + word_freq
                corpus = new_corpus

                # Update metadata and checkpoint if needed
                merge_count += 1
                self.metadata["merge_count"] = merge_count

                if checkpoint_interval and merge_count % checkpoint_interval == 0:
                    tokenizer_path = os.path.join(folder_path, f"tokenizer_{merge_count}.json")
                    self.save(tokenizer_path)
        finally:
            # Always stop the progress thread, even if training raised.
            stop_flag = True
            thread.join(timeout=2)
            print()

        # Finalize reverse vocab
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

        final_path = os.path.join(folder_path, "tokenizer.json")
        self.save(final_path)
        print(f"FastTokenizer saved as tokenizer.json")


    def encode(self, text: str):
        """
        Encode text into a 1-D long tensor of token IDs.

        Each whitespace-separated word is split exactly as in ``build_vocab``
        (one symbol per character, separator glued to the last character).
        Learned merges are then applied, earliest-learned first. Words that are
        themselves special tokens are kept whole. Symbols missing from the
        vocabulary map to the UNK ID (or -1 if there is no UNK entry).
        """
        cdef list token_ids = []
        cdef int i
        cdef str s, new_sym
        cdef list words = text.strip().split()

        sep = self.special_tokens["seperator"]
        protected_tokens = set(self.special_tokens.values())
        unk_id = self.vocab.get(self.special_tokens.get("unk", "<UNK>"), -1)
        self.merges = [tuple(m) for m in self.merges]

        # Rank of each usable merge: lower rank = learned earlier = applied first.
        merge_ranks = {}
        for rank, merge in enumerate(self.merges):
            if merge[0] in protected_tokens or merge[1] in protected_tokens:
                continue  # never merge across a protected (special) token
            merge_ranks.setdefault(merge, rank)

        for word in words:
            if word in protected_tokens:
                symbols = [word]
            else:
                symbols = list(word[:-1]) + [word[-1] + sep]

            # Apply BPE merges until no learned pair remains.
            while len(symbols) > 1:
                candidates = [pair for pair in get_pairs(symbols) if pair in merge_ranks]
                if not candidates:
                    break
                pair_to_merge = min(candidates, key=merge_ranks.__getitem__)

                new_sym = "".join(pair_to_merge)
                new_symbols = []
                i = 0
                while i < len(symbols):
                    if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair_to_merge:
                        new_symbols.append(new_sym)
                        i += 2
                    else:
                        new_symbols.append(symbols[i])
                        i += 1
                symbols = new_symbols

            # Convert symbols to token IDs
            for s in symbols:
                token_ids.append(self.vocab.get(s, unk_id))

        # Explicit dtype so an empty input still yields an integer tensor.
        return torch.tensor(token_ids, dtype=torch.long)

    def batch_encode(self, texts: list[str], num_workers=None) -> list[list[int]]:
        """Encode ``texts`` in a process pool (one ``encode`` call per text)."""
        if not texts:
            return []  # Pool(0) would raise
        if num_workers is None:
            num_workers = min(cpu_count(), len(texts))
        with Pool(num_workers) as pool:
            return pool.map(self.encode, texts)

    def decode(self, input_ids):
        """
        Decodes a sequence of token IDs back to text.
        Args:
            input_ids (List[int] or torch.Tensor): The sequence of token IDs.
        Returns:
            str: The decoded text, with every separator replaced by a space.
        """
        cdef list symbols = []
        cdef int i
        cdef str token

        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        unk = self.special_tokens.get("unk", "<UNK>")
        for i in input_ids:
            token = self.inv_vocab.get(i, unk)
            symbols.append(token)

        return "".join(symbols).replace(self.special_tokens['seperator'], " ")

    def save(self, filepath: str):
        """
        Saves the tokenizer state (vocab, merges, metadata) to a JSON file.
        Args:
            filepath (str): File path to save to.
        """
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump({
                "vocab": self.vocab,
                "merges": self.merges,
                "special_tokens": self.special_tokens,
                "metadata": self.metadata
            }, f, indent=2)

    def load(self, vocab, inv_vocab, merges, metadata):
        """
        Loads tokenizer state from existing Python data structures.
        Args:
            vocab (dict): Token-to-ID mapping.
            inv_vocab (dict): ID-to-token mapping.
            merges (list): Merge history.
            metadata (dict): Metadata for training state.
        """
        self.vocab = vocab
        self.inv_vocab = inv_vocab
        self.merges = merges
        self.metadata = metadata

    def load_from_file(self, filepath: str):
        """
        Load tokenizer state from a JSON file written by ``save``.

        Returns:
            tuple: ``(vocab, inv_vocab, merges, metadata)`` with merges as tuples.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        vocab = data["vocab"]
        # JSON stores tuples as lists; restore tuples so merges hash/compare as pairs.
        merges = [tuple(m) for m in data["merges"]]
        metadata = data.get("metadata", {})
        # Keep the current special tokens if the file does not store any.
        special_tokens = data.get("special_tokens", self.special_tokens)
        inv_vocab = {v: k for k, v in vocab.items()}

        self.load(vocab, inv_vocab, merges, metadata)
        self.special_tokens = special_tokens

        # Explicitly return these to caller
        return vocab, inv_vocab, merges, metadata

    @property
    def py_vocab(self):
        return self.vocab

    @property
    def py_inv_vocab(self):
        return self.inv_vocab

    @property
    def py_merges(self):
        return self.merges

    @property
    def py_metadata(self):
        return self.metadata

    @property
    def vocab(self):
        return self.vocab

    @property
    def inv_vocab(self):
        return self.inv_vocab

    @property
    def merges(self):
        return self.merges

    @property
    def metadata(self):
        return self.metadata

### Whisper

#### Config

In [None]:
class ModelConfig:
    """
    Configuration class for GPT/ASR model hyperparameters.

    Besides the constructor arguments, every instance also carries fixed
    dataset/tokenizer paths (``dataset_root``, ``wav_dir``, ``txt_dir``,
    ``tokenizer_dir``, ``tokenizer_file``, ``best_model_path``) and the
    ``special_tokens`` strings. ``save_json`` writes all attributes and
    ``from_json`` restores all of them.
    """

    def __init__(
        self,
        vocab_size=None,
        n_ctx=448,
        n_embd=256,
        n_layer=4,
        n_head=8,
        layer_norm_epsilon=1e-5,
        enc_dropout=0.1,
        dec_dropout=0.1,
        dropout=0.1,
        self_attn_dropout=0.1,
        lr=1e-4,
        clip_grad_norm=1.0,
        max_merges=None,
        tokenizer_checkpoint_interval=1000,
        use_cache=True,
        gradient_checkpointing=False,
        bos_token_id=None,
        eos_token_id=None,
        pad_token_id=None,
        scheduler_type='cosine',
        warmup_ratio=0.1,
        early_stop_patience=3,
        use_specaugment=True,
        sample_rate=16000,
    ):
        # NOTE(review): machine-specific dataset location; override after construction if needed.
        self.dataset_root = "/mnt/E/___COFFIN___/ASR/_datasets_/nptel-pure"
        self.wav_dir = os.path.join(self.dataset_root, "wav")
        self.txt_dir = os.path.join(self.dataset_root, "original_txt")
        self.tokenizer_dir = "./FastSubwordTokenizer"
        self.tokenizer_file = "tokenizer.json"
        self.best_model_path = "./best_model.pt"

        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.enc_dropout = enc_dropout
        self.dec_dropout = dec_dropout
        self.dropout = dropout
        self.self_attn_dropout = self_attn_dropout
        self.lr = lr
        self.clip_grad_norm = clip_grad_norm
        self.max_merges = max_merges
        self.tokenizer_checkpoint_interval = tokenizer_checkpoint_interval
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.scheduler_type = scheduler_type
        self.warmup_ratio = warmup_ratio
        self.early_stop_patience = early_stop_patience
        self.use_specaugment = use_specaugment
        self.sample_rate = sample_rate
        self.special_tokens = {
            "pad": "<PAD>",
            "unk": "<UNK>",
            "bos": "<BOS>",
            "eos": "<EOS>",
            "seperator": "</W>"
        }

    def set_vocab_size(self, vocab_size: int):
        self.vocab_size = vocab_size

    def set_bos_token_id(self, tokenID: int):
        self.bos_token_id = tokenID

    def set_eos_token_id(self, tokenID: int):
        self.eos_token_id = tokenID

    def set_pad_token_id(self, tokenID: int):
        self.pad_token_id = tokenID

    def to_dict(self):
        """Return a shallow copy of all instance attributes."""
        return self.__dict__.copy()

    def __repr__(self):
        return f"<ModelConfig n_layer={self.n_layer}, n_head={self.n_head}, vocab_size={self.vocab_size}, n_ctx={self.n_ctx}>"

    def save_json(self, path):
        """Write every attribute of this config to ``path`` as indented JSON."""
        import json
        with open(path, 'w', encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_json(cls, path):
        """
        Load a config written by ``save_json``.

        Keys that are constructor parameters are passed to ``__init__``. All
        other saved attributes (dataset paths, ``special_tokens``, ...) are set
        afterwards, because ``__init__`` would reject them as unexpected keyword
        arguments.
        """
        import inspect
        import json
        with open(path, 'r', encoding="utf-8") as f:
            data = json.load(f)
        init_params = inspect.signature(cls.__init__).parameters
        init_kwargs = {k: v for k, v in data.items() if k in init_params and k != "self"}
        config = cls(**init_kwargs)
        for key, value in data.items():
            if key not in init_kwargs:
                setattr(config, key, value)
        return config


#### Helper functions

In [None]:
@torch.no_grad()
def evaluate_model(model, dataloader, tokenizer, num_samples=5):
    """
    Print ground-truth vs. beam-search transcripts for the first few batches.

    Only the first item of each of the first ``num_samples`` batches is decoded.
    The model's previous train/eval mode is restored afterwards, so calling this
    mid-training does not silently leave dropout disabled.

    Args:
        model: Model passed to ``beam_search_decode``; its parameters' device is used.
        dataloader: Iterable of ``(mels, tokens, _, _)`` batches.
        tokenizer: Exposes ``vocab``, ``special_tokens``, ``decode`` and
            ``remove_special_tokens``.
        num_samples: Number of batches to sample from.
    """
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device

    # IDs stripped from the reference before decoding it.
    skip_ids = {
        tokenizer.vocab[tokenizer.special_tokens[name]] for name in ("pad", "eos", "bos")
    }

    try:
        for i, (mels, tokens, _, _) in enumerate(dataloader):
            if i >= num_samples:
                break

            mel = mels[0].unsqueeze(0).to(device)
            cleaned_ids = [tok_id for tok_id in tokens[0].tolist() if tok_id not in skip_ids]
            true_text = tokenizer.decode(cleaned_ids)
            pred_text = tokenizer.remove_special_tokens(beam_search_decode(model, mel, tokenizer))
            print(f"\nSample {i+1}")
            print(f"Ground Truth: {true_text}")
            print(f"Predicted    : {pred_text}")
    finally:
        model.train(was_training)


@torch.no_grad()
def beam_search_decode(model, mel_input, tokenizer, beam_width=5, max_len=20):
    """
    Beam-search decode a single utterance.

    Args:
        model: Called as ``model(mel_input, tokens)`` with ``tokens`` of shape
            ``(1, T)``; ``logits[:, -1, :]`` is used as the next-token scores.
        mel_input: Encoder input for one utterance; its ``.device`` is used for
            the token tensors.
        tokenizer: Exposes ``vocab``, ``special_tokens`` and ``decode``.
        beam_width: Hypotheses kept per step (at most the vocabulary size are expanded).
        max_len: Maximum number of decoding steps.

    Returns:
        The decoded text of the best hypothesis. Scores are summed log-probabilities,
        without length normalisation.
    """
    device = mel_input.device
    bos_token = tokenizer.vocab[tokenizer.special_tokens["bos"]]
    eos_token = tokenizer.vocab[tokenizer.special_tokens["eos"]]

    sequences = [(torch.tensor([bos_token], device=device), 0.0)]

    for _ in range(max_len):
        all_candidates = []
        for seq, score in sequences:
            # A hypothesis that already emitted EOS is finished: carry it over unchanged
            # instead of extending it with tokens after EOS.
            if seq[-1].item() == eos_token:
                all_candidates.append((seq, score))
                continue

            logits = model(mel_input, seq.unsqueeze(0))  # (1, T, V)
            log_probs = F.log_softmax(logits[:, -1, :], dim=-1)  # (1, V)
            k = min(beam_width, log_probs.size(-1))  # topk fails if k > V
            topk_log_probs, topk_ids = torch.topk(log_probs, k)

            for j in range(k):
                token = topk_ids[0, j].item()
                new_seq = torch.cat([seq, torch.tensor([token], device=device)])
                all_candidates.append((new_seq, score + topk_log_probs[0, j].item()))

        sequences = sorted(all_candidates, key=lambda cand: cand[1], reverse=True)[:beam_width]
        if all(seq[-1].item() == eos_token for seq, _ in sequences):
            break

    return tokenizer.decode(sequences[0][0].tolist())


def train_for_epochs(model, dataloader, optimizer, tokenizer, num_epochs):
    """
    Prepare the training loss (unfinished).

    Builds a label-smoothed (0.1) cross-entropy loss that ignores the
    tokenizer's padding ID, then returns ``None`` without training.
    """
    # NOTE(review): model, dataloader, optimizer and num_epochs are unused — the
    # training loop itself appears not to be written yet.
    pad_token = tokenizer.special_tokens["pad"]
    loss_fn = nn.CrossEntropyLoss(
        label_smoothing=0.1,
        ignore_index=tokenizer.vocab[pad_token],
    )

#### Base Model

In [None]:
class BaseTfModel(nn.Module):
    """
    Shared base for the transformer models in this notebook.

    Provides token embeddings tied to an output LM head, save/load helpers,
    greedy generation, and loss/validation loops. Subclasses supply
    ``forward`` (and ``beam_search``, which ``run_validation`` calls).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Initialize token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)

        # Initialize LM head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Share one weight matrix between input embeddings and the output projection.
        self.tie_weights()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_emb):
        self.embed_tokens = new_emb
        self.tie_weights()

    def resize_token_embeddings(self, new_vocab_size):
        """
        Resize the (tied) token-embedding matrix to ``new_vocab_size`` rows.

        Existing rows are copied over (truncated when shrinking). New rows keep
        ``nn.Embedding``'s default initialisation. The new matrix is created on
        the old one's device and dtype, and ``config.vocab_size`` is updated.
        """
        old_embedding = self.embed_tokens
        new_embedding = nn.Embedding(
            new_vocab_size,
            old_embedding.embedding_dim,
            device=old_embedding.weight.device,
            dtype=old_embedding.weight.dtype,
        )
        n_keep = min(old_embedding.num_embeddings, new_vocab_size)
        with torch.no_grad():
            new_embedding.weight[:n_keep] = old_embedding.weight[:n_keep]
        self.embed_tokens = new_embedding
        self.lm_head.out_features = new_vocab_size
        self.config.vocab_size = new_vocab_size
        self.tie_weights()

    def tie_weights(self):
        """Point ``lm_head.weight`` at the embedding weight, if both modules exist."""
        if hasattr(self, "lm_head") and hasattr(self, "embed_tokens"):
            self.lm_head.weight = self.embed_tokens.weight

    def save_pretrained(self, save_dir):
        """Write ``pytorch_model.bin`` (state dict) and ``config.json`` into ``save_dir``."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

    @classmethod
    def from_pretrained(cls, load_dir: str, model_name: str=None):
        """
        Rebuild a model saved with ``save_pretrained``.

        The saved ``config.json`` is applied on top of a default ``ModelConfig``,
        so attributes that are not constructor arguments are restored as well.
        Weights are loaded on the CPU with ``weights_only=True``, so only
        tensors/primitive containers are unpickled.
        """
        with open(os.path.join(load_dir, "config.json")) as f:
            config_dict = json.load(f)
        config = ModelConfig()
        config.__dict__.update(config_dict)
        model = cls(config)
        if model_name is None:
            model_name = "pytorch_model.bin"
        state_dict = torch.load(
            os.path.join(load_dir, model_name), map_location="cpu", weights_only=True
        )
        model.load_state_dict(state_dict)
        return model


    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=20):
        """
        Greedily append ``max_new_tokens`` argmax tokens to ``input_ids``.

        Works for batched ``(B, T)`` input. It also works for an unbatched 1-D prompt,
        provided ``forward`` still returns ``(1, T, V)`` logits for it.
        """
        self.eval()
        for _ in tqdm(range(max_new_tokens), desc="Generating"):
            logits = self.forward(input_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (B, 1)
            if input_ids.dim() == 1:
                next_token = next_token[0]  # (1,) to match the unbatched prompt
            input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids

    @torch.no_grad()
    def evaluate_loss(self, dataloader):
        """Mean (unweighted per batch) cross-entropy of ``self(x)`` against ``y``."""
        self.eval()
        device = self.model_device
        total_loss = 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = self(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item()
        return total_loss / len(dataloader)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    @torch.no_grad()
    def run_validation(self, dataloader, tokenizer, criterion, num_samples=2):
        """
        Average ``criterion`` over ``dataloader`` and print a few sample transcripts.

        Batches are ``(mels, tokens, _, _)``. Logits for ``tokens[:, :-1]`` are
        scored against the shifted targets ``tokens[:, 1:]``. For the first
        ``num_samples`` batches, the first item's reference (without PAD/EOS/BOS)
        and the ``beam_search`` prediction are printed.
        """
        self.eval()
        device = self.model_device
        total_loss = 0

        skip_ids = {
            tokenizer.vocab[tokenizer.special_tokens[name]] for name in ("pad", "eos", "bos")
        }

        for i, (mels, tokens, _, _) in enumerate(dataloader):
            mels, tokens = mels.to(device), tokens.to(device)
            logits = self(mels, tokens[:, :-1])
            loss = criterion(logits.reshape(-1, logits.size(-1)), tokens[:, 1:].reshape(-1))
            total_loss += loss.item()

            if i < num_samples:
                cleaned_ids = [tok_id for tok_id in tokens[0].tolist() if tok_id not in skip_ids]
                true_text = tokenizer.decode(cleaned_ids)
                # NOTE(review): ``beam_search`` is not defined on this base class; a
                # subclass must provide it.
                pred_text = self.beam_search(mels[0].unsqueeze(0), tokenizer)
                print(f"\nSample {i+1}")
                print(f"Ground Truth: {true_text}")
                print(f"Predicted    : {pred_text}")

        return total_loss / len(dataloader)


    @property
    def model_device(self):
        """``torch.device`` of the model's parameters (used by the evaluation loops)."""
        return next(self.parameters()).device

    @property
    def device(self):
        return str(next(self.parameters()).device)

    @property
    def supports_gradient_checkpointing(self):
        return getattr(self.config, "gradient_checkpointing", False)

#### Modules

In [None]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention with bias-free projections.

    Serves as self-attention when ``context`` is omitted and as cross-attention
    when ``context`` is given. ``is_cross_attention`` is accepted for API
    compatibility but does not change the computation.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        embed_dim = config.n_embd
        self.n_head = config.n_head
        self.head_dim = embed_dim // config.n_head
        self.scale = self.head_dim ** -0.5

        # Keep creation order (q, k, v, out): it fixes parameter init order and
        # the state_dict keys that saved checkpoints rely on.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def _split_heads(self, tensor, batch, length=-1):
        """Reshape ``(batch, length, n_embd)`` into ``(batch, n_head, length, head_dim)``."""
        return tensor.view(batch, length, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (batch, T, n_embd) over ``context`` (defaults to ``x``).

        ``mask`` must broadcast to the score tensor ``(batch, n_head, T, S)``;
        positions where ``mask == 0`` receive ``-inf`` before the softmax.
        Returns a tensor shaped like ``x``.
        """
        batch, seq_len, channels = x.size()
        source = context if context is not None else x

        queries = self._split_heads(self.q_proj(x), batch, seq_len)
        keys = self._split_heads(self.k_proj(source), batch)
        values = self._split_heads(self.v_proj(source), batch)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = scores.softmax(dim=-1)

        merged = torch.matmul(weights, values).transpose(1, 2).contiguous()
        return self.out_proj(merged.view(batch, seq_len, channels))


class MLP(nn.Module):
    """Position-wise feed-forward net: Linear -> GELU -> Linear with 4x expansion, no biases."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.fc1 = nn.Linear(width, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        hidden_states = self.fc1(x)
        hidden_states = F.gelu(hidden_states)
        return self.fc2(hidden_states)


class WhisperEncoderBlock(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each with a residual.

    No attention mask is passed, so every position (including any padding in
    the input) is attended to.
    """

    def __init__(self, config):
        super().__init__()
        eps = config.layer_norm_epsilon
        # Attribute names/order are part of the state_dict layout.
        self.ln1 = nn.LayerNorm(config.n_embd, eps=eps)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(config.n_embd, eps=eps)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = self.attn(self.ln1(x))
        x = x + attended
        return x + self.mlp(self.ln2(x))


class WhisperDecoderBlock(nn.Module):
    """Pre-norm transformer decoder layer.

    Applies three residual sub-layers in order: self-attention (optionally
    masked by ``self_mask``), cross-attention over ``encoder_output`` with no
    mask, and an MLP.
    """

    def __init__(self, config):
        super().__init__()
        dim, eps = config.n_embd, config.layer_norm_epsilon
        # Attribute names/order are part of the state_dict layout.
        self.ln1 = nn.LayerNorm(dim, eps=eps)
        self.self_attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(dim, eps=eps)
        self.cross_attn = MultiHeadAttention(config, is_cross_attention=True)
        self.ln3 = nn.LayerNorm(dim, eps=eps)
        self.mlp = MLP(config)

    def forward(self, x, encoder_output, self_mask=None):
        h = x + self.self_attn(self.ln1(x), mask=self_mask)
        h = h + self.cross_attn(self.ln2(h), context=encoder_output)
        return h + self.mlp(self.ln3(h))


class WhisperModel(BaseTfModel):
    """Encoder-decoder speech-to-text transformer over mel-spectrogram frames.

    The encoder projects 80-dim mel frames to ``n_embd`` and runs ``n_layer``
    self-attention blocks. The decoder embeds target tokens, adds a learned
    positional embedding covering ``n_ctx`` positions, and runs ``n_layer``
    causal blocks that cross-attend to the encoder output. The output head
    shares its weight with the token embedding.

    NOTE(review): the encoder path adds no positional embedding, so encoder
    self-attention has no notion of frame order. Adding one would change the
    state_dict, so it is left as-is here.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Encoder
        self.encoder_mel_proj = nn.Linear(80, config.n_embd)
        self.encoder_layers = nn.ModuleList([
            WhisperEncoderBlock(config) for _ in range(config.n_layer)
        ])

        # Decoder
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.n_ctx, config.n_embd))
        self.decoder_layers = nn.ModuleList([
            WhisperDecoderBlock(config) for _ in range(config.n_layer)
        ])

        # Final
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self._tie_weights()

    def _tie_weights(self):
        """Share the output projection weight with the token embedding."""
        self.lm_head.weight = self.token_emb.weight

    def _encode_mel(self, mel_input):
        """Project mel frames (last dim 80) and run them through the encoder stack."""
        x = self.encoder_mel_proj(mel_input)
        for layer in self.encoder_layers:
            x = layer(x)
        return x

    def _decode_tokens(self, tgt_ids, encoder_out):
        """Run the causal decoder over 2-D ``tgt_ids`` and return vocabulary logits.

        ``tgt_ids.size(1)`` must not exceed the positional table length (``n_ctx``).
        """
        seq_len = tgt_ids.size(1)
        x = self.token_emb(tgt_ids.long()) + self.pos_emb[:, :seq_len, :]
        # Lower-triangular mask: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=encoder_out.device)).bool()
        for layer in self.decoder_layers:
            x = layer(x, encoder_out, self_mask=mask)
        return self.lm_head(self.ln_f(x))

    def forward(self, mel_input, tgt_ids):
        """Teacher-forced forward pass.

        Args:
            mel_input: mel features for ``encoder_mel_proj`` (last dim 80).
            tgt_ids: 2-D token ids ``(batch, T)``. Sequences longer than
                ``n_ctx`` are truncated to ``n_ctx``, so the returned logits can
                be shorter than ``tgt_ids``.

        Returns:
            Logits of shape ``(batch, min(T, n_ctx), vocab_size)``.
        """
        _, seq_len = tgt_ids.size()
        max_len = self.pos_emb.shape[1]
        if seq_len > max_len:
            tgt_ids = tgt_ids[:, :max_len]

        encoder_out = self._encode_mel(mel_input)
        return self._decode_tokens(tgt_ids, encoder_out)

    def generate(self, mel_input, tokenizer, max_new_tokens=100):
        """Greedy-decode a transcript, starting from BOS.

        Decoding stops at EOS, after ``max_new_tokens`` steps, or once the
        prefix is longer than the positional table (``n_ctx``). Previously
        the last case crashed on the embedding addition. The full id
        sequence (BOS first) is passed to ``tokenizer.decode``.

        NOTE(review): ``input_ids`` always has batch size 1, so this assumes
        ``mel_input`` holds a single utterance.
        """
        device = mel_input.device
        bos_token_id = tokenizer.vocab[tokenizer.special_tokens["bos"]]
        eos_token_id = tokenizer.vocab[tokenizer.special_tokens["eos"]]
        max_len = self.pos_emb.shape[1]

        input_ids = torch.tensor([[bos_token_id]], device=device)

        with torch.no_grad():
            # Encode once; every decoding step reuses the encoder output.
            encoder_out = self._encode_mel(mel_input)

            for _ in range(max_new_tokens):
                if input_ids.size(1) > max_len:
                    break
                logits = self._decode_tokens(input_ids, encoder_out)
                next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if next_token.item() == eos_token_id:
                    break

        return tokenizer.decode(input_ids[0].tolist())

### Data Preprocessing

#### Dataset Class

In [None]:
class ASRDataset(Dataset):
    """Paired ``(mel spectrogram, token ids)`` dataset read from a directory tree.

    Expects ``<root_dir>/wav/<stem>.wav`` with transcripts in
    ``<root_dir>/original_txt/<stem>.txt``; wav files without a transcript are
    skipped. Items are ordered by sorted wav file name.
    """

    def __init__(self, root_dir, tokenizer, config, split="train"):
        """
        Args:
            root_dir (str): Path to root of dataset (e.g., 'nptel-pure')
            tokenizer (Tokenizer): Your tokenizer instance
            config (ModelConfig): Must have sample_rate, n_ctx; use_specaugment is optional
            split (str): "train" or "val" (controls augmentation)
        """
        self.wav_dir = os.path.join(root_dir, "wav")
        self.txt_dir = os.path.join(root_dir, "original_txt")
        self.tokenizer = tokenizer
        self.config = config
        self.sample_rate = config.sample_rate
        self.max_len = config.n_ctx
        self.augment = (split == "train" and getattr(config, "use_specaugment", False))

        self.items = sorted(
            f for f in os.listdir(self.wav_dir)
            if f.endswith(".wav")
            and os.path.exists(os.path.join(self.txt_dir, self._transcript_name(f)))
        )

        # NOTE(review): these are linear-scale mel power values with no dB
        # conversion, whereas WhisperPreprocessor (used by the inference cells)
        # applies AmplitudeToDB. Confirm which representation the model should see.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=400,
            win_length=400,
            hop_length=160,
            n_mels=80
        )
        # Resamplers keyed by source sample rate, built lazily and reused.
        self._resamplers = {}

    @staticmethod
    def _transcript_name(wav_file):
        """Map ``<stem>.wav`` to ``<stem>.txt``, replacing only the extension."""
        return os.path.splitext(wav_file)[0] + ".txt"

    @staticmethod
    def truncate_to_eos(token_ids: torch.Tensor, eos_id: int):
        """Cut a 1-D id tensor just after its first ``eos_id`` (kept); unchanged if absent."""
        eos_positions = (token_ids == eos_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) > 0:
            token_ids = token_ids[:eos_positions[0] + 1]
        return token_ids

    def __len__(self):
        return len(self.items)

    def _load_waveform(self, path):
        """Load ``path`` as a mono ``(1, samples)`` waveform at ``self.sample_rate``."""
        waveform, sr = safe_load(path)
        # Downmix multi-channel audio so the mel is (n_mels, frames) after squeeze(0).
        if waveform.dim() == 2 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            resample = self._resamplers.get(sr)
            if resample is None:
                resample = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
                self._resamplers[sr] = resample
            waveform = resample(waveform)
        return waveform

    def __getitem__(self, idx):
        """Return ``(mel, token_ids)`` for item ``idx``.

        ``mel`` is ``(frames, 80)``. ``token_ids`` is what ``tokenizer.encode``
        returns for the upper-cased transcript after ``add_special_tokens``.
        Token ids are not truncated here; WhisperModel.forward truncates
        decoder inputs to ``n_ctx``.
        """
        wav_file = self.items[idx]
        waveform = self._load_waveform(os.path.join(self.wav_dir, wav_file))

        # Light white-noise augmentation for the training split.
        if self.augment:
            waveform = waveform + 0.002 * torch.randn_like(waveform)

        # (1, n_mels, frames) -> (frames, n_mels)
        mel = self.mel_transform(waveform).squeeze(0).transpose(0, 1)

        txt_path = os.path.join(self.txt_dir, self._transcript_name(wav_file))
        with open(txt_path, "r", encoding="utf-8") as f:
            text = f.read().strip().upper()
        text = self.tokenizer.add_special_tokens(text)
        token_tensor = self.tokenizer.encode(text)

        return mel, token_tensor

#### Create dataloader

In [None]:
def make_collate_fn(tokenizer):
    """Build a DataLoader ``collate_fn`` for ``(mel, token_ids)`` samples.

    The returned function zero-pads every mel along its first (time) axis to
    the longest in the batch. It pads token sequences with the tokenizer's pad
    id and returns ``(mels, tokens, mel_lens, token_lens)``. The two length
    lists hold the unpadded sizes.
    """
    def collate_fn(batch):
        mels, tokens = zip(*batch)
        mel_lens = list(map(lambda m: m.shape[0], mels))
        longest = max(mel_lens)
        padded = []
        for mel in mels:
            missing = longest - mel.shape[0]
            padded.append(F.pad(mel, (0, 0, 0, missing)))
        padded_mels = torch.stack(padded)

        pad_token = tokenizer.vocab[tokenizer.special_tokens["pad"]]
        padded_tokens = torch.nn.utils.rnn.pad_sequence(
            tokens, batch_first=True, padding_value=pad_token
        )
        token_lens = list(map(len, tokens))
        return padded_mels, padded_tokens, mel_lens, token_lens

    return collate_fn
    

def prepare_dataloaders(dataset, batch_size=2, shuffle_train=True):
    """Split ``dataset`` 90/10 by position into train and validation loaders.

    The first 90% of items (sorted file order) go to training and the rest to
    validation. Fresh ``ASRDataset`` instances are built over the same root,
    so the training subset uses ``split="train"`` (augmentation if configured)
    and the validation subset uses ``split="val"``.

    Args:
        dataset: an ``ASRDataset``; its ``wav_dir``, ``tokenizer`` and
            ``config`` are reused.
        batch_size: batch size for both loaders.
        shuffle_train: reshuffle the training subset every epoch. Previously
            training batches came in the same fixed order every epoch.
            Validation is never shuffled.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_size = int(0.9 * len(dataset))
    train_indices = list(range(train_size))
    val_indices = list(range(train_size, len(dataset)))

    # wav_dir is "<root>/wav"; os.path.dirname is portable, unlike splitting on "/".
    root_dir = os.path.dirname(dataset.wav_dir)
    train_dataset = Subset(
        ASRDataset(root_dir, dataset.tokenizer, dataset.config, split="train"),
        train_indices
    )
    val_dataset = Subset(
        ASRDataset(root_dir, dataset.tokenizer, dataset.config, split="val"),
        val_indices
    )

    collate = make_collate_fn(dataset.tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle_train, collate_fn=collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    return train_loader, val_loader

#### Soundfile util

In [None]:
def safe_load(filepath):
    """Load an audio file as ``(waveform, sample_rate)``, falling back to soundfile.

    ``torchaudio.load`` is tried first. If it raises ``RuntimeError``, the file
    is read with ``soundfile`` instead and converted to a float32 tensor laid
    out ``(channels, frames)``, matching torchaudio's layout. Previously,
    multi-channel files came back as ``(1, frames, channels)``.
    """
    try:
        return torchaudio.load(filepath)
    except RuntimeError:
        print(f"[Fallback] Using soundfile for: {filepath}")
        # always_2d=True yields (frames, channels) even for mono files.
        data, sr = sf.read(filepath, always_2d=True)
        return torch.from_numpy(np.ascontiguousarray(data.T)).float(), sr

### Preprocessor

In [None]:
class WhisperPreprocessor:
    """Convert a waveform into log-mel (dB) features of shape ``(frames, n_mels)``.

    STFT/mel settings are taken from ``config`` when present (``n_mels``,
    ``n_fft``, ``hop_length``, ``win_length``). Otherwise they default to
    80 / 400 / 160 / 400. ``config.sample_rate`` is required.

    NOTE(review): ASRDataset builds its training features without the dB
    conversion applied here. Confirm which representation the model should see.
    """

    def __init__(self, config):
        self.sample_rate = config.sample_rate
        self.n_mels = getattr(config, "n_mels", 80)
        self.n_fft = getattr(config, "n_fft", 400)
        self.hop_length = getattr(config, "hop_length", 160)
        self.win_length = getattr(config, "win_length", 400)

        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            center=True,
            power=2.0,
            normalized=False,
        )

        self.db_transform = torchaudio.transforms.AmplitudeToDB(top_db=80)

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return ``(frames, n_mels)`` dB mel features for ``waveform``.

        Accepts ``(samples,)`` or ``(channels, samples)``. Multi-channel input
        is averaged to mono first, so the output is always 2-D.
        """
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.dim() == 2 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        mel = self.mel_spec(waveform)
        mel_db = self.db_transform(mel)
        return mel_db.squeeze(0).transpose(0, 1)  # [T, n_mels] for compatibility


### Create config

In [None]:
# Shared configuration object; later cells read attributes such as sample_rate,
# n_ctx, lr and dataset_root from it. ModelConfig is defined elsewhere in the notebook.
config = ModelConfig()

### Train Tokenizer

In [None]:
# Load a previously trained subword tokenizer from disk (the vocab-building
# cells below are commented out). The folder path is machine-specific.
fast_tokenizer = FastSubwordTokenizer.from_pretrained(
    config, 
    folder="/mnt/E/___COFFIN___/ASR/FastSubwordTokenizer"
)

In [None]:
# corpus_text = ""

# for file in sorted(os.listdir(config.txt_dir)):
#     with open(os.path.join(config.txt_dir, file), "r") as f:
#         corpus_text += config.special_tokens['bos'] + f.read().strip().upper() + " "

In [None]:
# corpus_text

In [None]:
# fast_tokenizer.build_vocab(text=corpus_text, folder_path=config.tokenizer_dir)

In [None]:
# Copy the tokenizer's vocabulary size and special-token ids into the config so
# the model's embedding and output-head sizes match the tokenizer.
config.set_vocab_size(fast_tokenizer.get_vocab_size())
config.set_pad_token_id(fast_tokenizer.vocab[fast_tokenizer.special_tokens["pad"]])
config.set_bos_token_id(fast_tokenizer.vocab[fast_tokenizer.special_tokens["bos"]])
config.set_eos_token_id(fast_tokenizer.vocab[fast_tokenizer.special_tokens["eos"]])

### Instantiate and test model

#### Create model

In [None]:
# Model, optimizer and LR scheduler. The training loop steps ReduceLROnPlateau
# with the mean validation loss each epoch (factor=0.5, patience=1).
# NOTE(review): Adam's weight_decay is plain L2 added to the gradient; if
# decoupled weight decay was intended, torch.optim.AdamW is the usual choice.
model = WhisperModel(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

In [None]:
# Display the module tree (the notebook renders a cell's last expression).
model

#### Test model

In [None]:
# dB-scaled mel feature extractor used by the single-utterance inference cells below.
preprocessor = WhisperPreprocessor(config)

In [None]:
# Smoke test of the (untrained) model on one utterance: greedy-decode it and
# print the result next to the reference transcript.
# NOTE(review): `preprocessor` applies AmplitudeToDB, but ASRDataset (used for
# training) feeds the model mel power without dB scaling — confirm which
# feature representation is intended.
audio_root = "/mnt/E/___COFFIN___/ASR/_datasets_/nptel-pure"
sample_id = "000a01ea126c4b59aa992d992a1a8e1076a53bf7755eb6a0f80a2cdb"

# safe_load matches the dataset's loading path (with the soundfile fallback).
waveform, sr = safe_load(os.path.join(audio_root, "wav", f"{sample_id}.wav"))
if sr != config.sample_rate:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=config.sample_rate)(waveform)

mel = preprocessor(waveform).unsqueeze(0).to(model.device)  # [1, T, n_mels]

# Generate tokens
generated = model.generate(mel_input=mel, tokenizer=fast_tokenizer, max_new_tokens=50)
predicted_text = fast_tokenizer.remove_special_tokens(generated)

# Load ground truth
txt_path = os.path.join(audio_root, "original_txt", f"{sample_id}.txt")
with open(txt_path, "r", encoding="utf-8") as f:
    ground_truth = f.read().strip().upper()

# Show results
print("\nGround Truth:")
print(ground_truth)
print("\nPredicted:   ")
print(predicted_text)

### Model training

#### Create Dataloaders

In [None]:
# Build the dataset and the 90/10 train/val loaders. prepare_dataloaders creates
# its own ASRDataset instances from dataset.wav_dir, so `dataset` is used here only
# to supply root/tokenizer/config and for the inspection cell below.
dataset = ASRDataset(config.dataset_root, fast_tokenizer, config, split="train")
train_loader, val_loader = prepare_dataloaders(dataset, batch_size=4)

#### Test dataloaders

In [None]:
# Inspect a single sample straight from the dataset.
mel, tokens = dataset[0]
print(f"Mel shape    : {mel.shape}")      # Expected: [T, 80]
print(f"Token shape  : {tokens.shape}")   # 1-D; the dataset does not truncate to config.n_ctx
print(f"Token IDs    : {tokens}")         # All token IDs

# decode is given a plain list, as in every other decode call in this notebook.
decode = fast_tokenizer.decode(tokens.tolist())
print(decode)
print()
print(fast_tokenizer.remove_special_tokens(decode))

In [None]:
# Inspect the first training batch and the teacher-forcing shift:
# the decoder input drops the last token, the target drops the first.
for mels, tokens, mel_lens, token_lens in train_loader:
    print("Mel shape:   ", mels.shape)   # [B, T_max, 80], zero-padded in time
    print("Tokens shape:", tokens.shape) # [B, L_max], padded with the pad id

    input_ids = tokens[:, :-1]
    target_ids = tokens[:, 1:]

    print("Input IDs shape:  ", input_ids.shape)
    print("Target IDs shape: ", target_ids.shape)

    # Decode input and target for sample 0 (padding tokens included)
    print("\nDecoded Input IDs:")
    print(fast_tokenizer.decode(input_ids[0].tolist()))
    print("\nDecoded Target IDs:")
    print(fast_tokenizer.decode(target_ids[0].tolist()))

    break  # Only one batch

#### Train function

In [None]:
def _batch_loss(model, loss_fn, mels, tokens):
    """Teacher-forced loss for one batch: predict tokens[:, 1:] from tokens[:, :-1]."""
    input_ids = tokens[:, :-1]
    target_ids = tokens[:, 1:]
    logits = model(mels, input_ids)
    # The model may truncate its input to its context length; align targets to the logits.
    target_ids = target_ids[:, :logits.size(1)]
    return loss_fn(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1).long())


def _evaluate(model, loader, loss_fn, device):
    """Mean per-batch loss of ``model`` over ``loader`` in eval mode, without gradients."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for mels, tokens, _, _ in loader:
            mels, tokens = mels.to(device), tokens.to(device)
            total += _batch_loss(model, loss_fn, mels, tokens).item()
    return total / len(loader)


def train_for_epochs_with_val(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    tokenizer,
    num_epochs=12,
    clip=1.0,
    early_stop_patience=3,
    checkpoint_path="best_model.pt",
):
    """Train ``model`` with mixed precision, gradient clipping, LR scheduling and early stopping.

    Each epoch makes one teacher-forced pass over ``train_loader``. The loss is
    cross-entropy with label smoothing 0.1, ignoring the pad id. The epoch then
    computes the mean validation loss and passes it to ``scheduler.step`` (the
    ReduceLROnPlateau signature). ``model.state_dict()`` is saved to
    ``checkpoint_path`` whenever validation loss improves. Training stops after
    ``early_stop_patience`` consecutive epochs without improvement. The model is
    left with its last-epoch weights, not the best ones.

    Args:
        model: called as ``model(mels, input_ids)`` and returning logits.
        train_loader / val_loader: yield ``(mels, tokens, mel_lens, token_lens)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch with the mean validation loss.
        tokenizer: provides the pad id via ``vocab`` / ``special_tokens``.
        num_epochs: maximum number of epochs.
        clip: max gradient norm.
        early_stop_patience: epochs without improvement before stopping.
        checkpoint_path: where the best state_dict is written.

    Returns:
        The best (lowest) mean validation loss seen.

    Raises:
        ValueError: if either loader yields no batches (previously this surfaced
            as ZeroDivisionError after a full epoch).
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    pad_id = tokenizer.vocab[tokenizer.special_tokens["pad"]]
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1)
    device = next(model.parameters()).device
    # autocast expects a device *type* ("cuda"/"cpu"), not a string like "cuda:0".
    # Gradient scaling only matters for CUDA float16, so enable it only there.
    scaler = GradScaler(enabled=(device.type == "cuda"))

    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_train_loss = 0.0
        start = time.time()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}", dynamic_ncols=True)
        for mels, tokens, _, _ in pbar:
            mels, tokens = mels.to(device), tokens.to(device)

            with autocast(device_type=device.type):
                loss = _batch_loss(model, loss_fn, mels, tokens)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)
            scaler.step(optimizer)
            scaler.update()

            total_train_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        avg_train_loss = total_train_loss / len(train_loader)

        # Validation
        avg_val_loss = _evaluate(model, val_loader, loss_fn, device)
        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Time: {time.time() - start:.2f}s")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model")
        else:
            patience_counter += 1
            if patience_counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch} (no improvement for {early_stop_patience} epochs)")
                break

    return best_val_loss

#### Train model

In [None]:
# Train for up to 100 epochs with gradient clipping at 1.0, stopping early after
# 3 epochs without validation improvement; the best weights go to best_model.pt.
train_for_epochs_with_val(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    tokenizer=fast_tokenizer,
    num_epochs=100,
    clip=1.0,
    early_stop_patience=3,
)

#### Evaluate

In [None]:
# Evaluate on one utterance after training. The training loop saves the best
# weights to best_model.pt but leaves the model at its *last* epoch, so reload
# the best checkpoint (when present) before decoding.
best_ckpt = "best_model.pt"
if os.path.exists(best_ckpt):
    model.load_state_dict(torch.load(best_ckpt, map_location=model.device, weights_only=True))

# NOTE(review): `preprocessor` applies AmplitudeToDB, but ASRDataset (used for
# training) feeds the model mel power without dB scaling — confirm which
# feature representation is intended.
audio_root = "/mnt/E/___COFFIN___/ASR/_datasets_/nptel-pure"
sample_id = "000a01ea126c4b59aa992d992a1a8e1076a53bf7755eb6a0f80a2cdb"

# safe_load matches the dataset's loading path (with the soundfile fallback).
waveform, sr = safe_load(os.path.join(audio_root, "wav", f"{sample_id}.wav"))
if sr != config.sample_rate:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=config.sample_rate)(waveform)

mel = preprocessor(waveform).unsqueeze(0).to(model.device)  # [1, T, n_mels]

# Generate tokens
generated = model.generate(mel_input=mel, tokenizer=fast_tokenizer, max_new_tokens=50)
predicted_text = fast_tokenizer.remove_special_tokens(generated)

# Load ground truth
txt_path = os.path.join(audio_root, "original_txt", f"{sample_id}.txt")
with open(txt_path, "r", encoding="utf-8") as f:
    ground_truth = f.read().strip().upper()

# Show results
print("\nGround Truth:")
print(ground_truth)
print("\nPredicted:   ")
print(predicted_text)

#### Test

In [None]:
# Round-trip a short string through the tokenizer (encode once, reuse the result).
text = "I AM FINE"
encoded = fast_tokenizer.encode(text)
print("Encoded:", encoded)
print("Decoded:", fast_tokenizer.decode(encoded.tolist()))