In [None]:
!pip install torchmetrics

In [None]:
import re
import math
import torch
import random
import torch.nn as nn
from tqdm import tqdm
from google.colab import drive
from tokenizers import Tokenizer
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split


# Mount Google Drive at /content/drive (Colab prompts for authorisation).
drive.mount("/content/drive")
# Project root on Drive; the model, tokenizer and dataset paths in the
# configuration cell are all built from this folder.
WORKING_DIR = "/content/drive/My Drive/Text-Generation"

In [None]:
# --- Reproducibility -------------------------------------------------------
# One seed for every RNG used in the notebook (torch CPU/CUDA and `random`).
SEED = 3000
torch.manual_seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    torch.cuda.manual_seed_all(SEED)
else:
    DEVICE = torch.device('cpu')

# --- Model / training hyper-parameters ---------------------------------------
VOCAB_SIZE=25000      # only used to pick the tokenizer file below
BATCH_SIZE = 64
EPOCHS = 100
INIT_LR = 2e-04
SEQ_LEN = 52          # fixed sequence length every sample is truncated/padded to
D_MODEL = 512
N_BLOCKS = 6
HEADS = 16            # must divide D_MODEL (asserted in MultiHeadAttentionBlock)
DROPOUT = 0.1
DFF = 2048

# --- Paths (all rooted at WORKING_DIR on Drive) -------------------------------
MODELS_FOLDER = f"{WORKING_DIR}/models"
# Checkpoint name (without folder or ".pt") to resume from; empty = start fresh.
PRELOAD_MODEL_FILEPATH = ""
TOKENIZER_FILEPATH = f"{WORKING_DIR}/tokenizers/amharic-bpe-tokenizer-v1-{VOCAB_SIZE // 1000}k.json"
# Must be the directory the TensorBoard cell watches. A relative path would be
# written under the kernel's working directory instead of Drive, so the
# dashboard would stay empty and the logs would be lost with the Colab VM.
TB_LOG_DIR = f"{WORKING_DIR}/logs/gpt_model"
DATASET_PATH = f"{WORKING_DIR}/data/amharic-texts.txt"

In [None]:
class PreprocessingPipeline(ABC):
    """Abstract base for text preprocessors that carry a tokenizer.

    Concrete subclasses must implement `preprocess`.
    """

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__()
        # Kept on the instance so subclasses can encode the cleaned text.
        self.tokenizer = tokenizer

    @abstractmethod
    def preprocess(self, text: str, encode=True) -> str:
        """Clean `text`; what `encode` does is defined by the subclass."""

    
class AmharicPreprocessor(PreprocessingPipeline):
    """Normalises Amharic text and optionally encodes it with the tokenizer.

    `preprocess` runs three steps in order:
      1. collapse homophone letter variants to one canonical letter,
      2. expand common abbreviations,
      3. replace punctuation and digits with spaces and squeeze whitespace.
    """

    # Abbreviation -> expansion. Applied in insertion order, so longer keys
    # (e.g. "ጠ/ሚንስትር") must stay ahead of their prefixes (e.g. "ጠ/ሚ").
    # An empty expansion simply deletes the abbreviation.
    _ABBREVIATIONS = {
        "ት/ቤት": "ትምህርት ቤት",
        "ት/ርት": "ትምህርት",
        "ት/ክፍል": "ትምህርት ክፍል",
        "ሃ/አለቃ": "ሃምሳ አለቃ",
        "ሃ/ስላሴ": "ሃይለ ስላሴ",
        "ደ/ዘይት": "ደብረ ዘይት",
        "ደ/ታቦር": "ደብረ ታቦር",
        "መ/ር": "መምህር",
        "መ/ቤት": "መስሪያ ቤት",
        "መ/አለቃ": "መቶ አለቃ",
        "ክ/ከተማ": "ክፍለ ከተማ",
        "ክ/ሀገር": "ክፍለ ሀገር",
        "ወ/ር": "",
        "ወ/ሮ": "ወይዘሮ",
        "ወ/ሪት": "ወይዘሪት",
        "ወ/ስላሴ": "ወልደ ስላሴ",
        "ፍ/ስላሴ": "ፍቅረ ስላሴ",
        "ፍ/ቤት": "ፍርድ ቤት",
        "ጽ/ቤት": "ጽህፈት ቤት",
        "ሲ/ር": "",
        "ፕ/ር": "ፕሮፌሰር",
        "ጠ/ሚንስትር": "ጠቅላይ ሚኒስተር",
        "ጠ/ሚ": "ጠቅላይ ሚኒስተር",
        "ዶ/ር": "ዶክተር",
        "ገ/ገዮርጊስ": "ገብረ ገዮርጊስ",
        "ቤ/ክርስትያን": "ቤተ ክርስትያን",
        "ም/ስራ": "",
        "ም/ቤት": "ምክር ቤተ",
        "ተ/ሃይማኖት": "ተክለ ሃይማኖት",
        "ሚ/ር": "ሚኒስትር",
        "ኮ/ል": "ኮሎኔል",
        "ሜ/ጀነራል": "ሜጀር ጀነራል",
        "ብ/ጀነራል": "ብርጋደር ጀነራል",
        "ሌ/ኮለኔል": "ሌተናንት ኮለኔል",
        "ሊ/መንበር": "ሊቀ መንበር",
        "አ/አ": "ኣዲስ ኣበባ",
        "አ.አ": "ኣዲስ ኣበባ",
        "ር/መምህር": "ርዕሰ መምህር",
        "ፕ/ት": "",
        "ዓም": "ዓመተ ምህረት",
        "ዓ.ዓ": "ዓመተ ዓለም",
    }

    # Homophone variants -> canonical letter (e.g. ጸሀይ and ፀሐይ become the same).
    # ኅ is a sixth-order form, so it belongs with ህ; it used to be swallowed by
    # the first-order group (-> ሀ) before the ህ rule could ever see it.
    _CHAR_VARIANTS = str.maketrans({
        **dict.fromkeys("ሃኃሐሓኻ", "ሀ"),
        **dict.fromkeys("ሑኁዅ", "ሁ"),
        **dict.fromkeys("ኂሒኺ", "ሂ"),
        **dict.fromkeys("ኌሔዄ", "ሄ"),
        **dict.fromkeys("ሕኅ", "ህ"),
        **dict.fromkeys("ኆሖኾ", "ሆ"),
        "ሠ": "ሰ", "ሡ": "ሱ", "ሢ": "ሲ", "ሣ": "ሳ", "ሤ": "ሴ", "ሥ": "ስ", "ሦ": "ሶ",
        **dict.fromkeys("ዓኣዐ", "አ"),
        "ዑ": "ኡ", "ዒ": "ኢ", "ዔ": "ኤ", "ዕ": "እ", "ዖ": "ኦ",
        "ጸ": "ፀ", "ጹ": "ፁ", "ጺ": "ፂ", "ጻ": "ፃ", "ጼ": "ፄ", "ጽ": "ፅ", "ጾ": "ፆ",
    })

    # "<u-form> + ዋ/አ" -> labialized letter, e.g. በልቱዋል / በልቱአል -> በልቷል.
    # The ዷ rule must start from ዱ (it used to match ደ and so rewrote ordinary
    # words containing ደዋ / ደአ, e.g. ድሬደዋ).
    # NOTE(review): the original also had a ጹ -> ጿ rule; it could never fire
    # because ጹ is already rewritten to ፁ above, so it is not carried over.
    _LABIALIZED = {
        "ሉ": "ሏ", "ሙ": "ሟ", "ቱ": "ቷ", "ሩ": "ሯ", "ሱ": "ሷ", "ሹ": "ሿ", "ቁ": "ቋ",
        "ቡ": "ቧ", "ቹ": "ቿ", "ሁ": "ኋ", "ኑ": "ኗ", "ኙ": "ኟ", "ኩ": "ኳ", "ዙ": "ዟ",
        "ጉ": "ጓ", "ዱ": "ዷ", "ጡ": "ጧ", "ጩ": "ጯ", "ፉ": "ፏ",
    }
    _LABIALIZED_RE = re.compile("([" + "".join(_LABIALIZED) + "])[ዋአ]")

    # Applied after labialization on purpose: ቊ/ኵ followed by ዋ/አ are not merged.
    _LATE_VARIANTS = str.maketrans({"ቊ": "ቁ", "ኵ": "ኩ"})

    # Punctuation plus Arabic and Ethiopic digits. The range ፩-፼ covers every
    # Ethiopic numeral; the previous hand-written list repeated ፮ and ፵ and
    # therefore let ፯ and ፶ through.
    _PUNCT_AND_DIGITS = re.compile(r'[.\?"\',/#!$%^&*;:፤።{}=\-_`~()፩-፼0-9]')
    _MULTI_SPACE = re.compile(r'\s{2,}')

    # Compiled (pattern, expansion) lists, built once per spelling and shared.
    _rule_cache: dict = {}

    def __init__(self, tokenizer: Tokenizer) -> None:
        super().__init__(tokenizer)

    def preprocess(self, text: str, encode=True) -> "str | list[int]":
        """Normalise `text`; return token ids if `encode` else the cleaned string.

        Abbreviations are matched in their *normalised* spelling because the
        character normalisation has already run on `text`. Matching the raw
        table here would silently skip every key that contains a variant
        letter and would emit expansions spelled differently from the rest of
        the text.
        """
        text = self.normalize_char_level_missmatch(text)
        text = self._expand_abbreviations(text, self._abbreviation_rules(normalized=True))
        text = self._strip_punctuation(text)

        if encode:
            return self.tokenizer.encode(text).ids
        return text

    def normalize_abbreviations(self, text: str) -> str:
        """Expand abbreviations as written in the table, then strip punctuation,
        digits and repeated whitespace. Intended for text that has NOT been
        character-normalised yet."""
        text = self._expand_abbreviations(text, self._abbreviation_rules(normalized=False))
        return self._strip_punctuation(text)

    def _abbreviation_rules(self, normalized: bool) -> list:
        """Return the compiled abbreviation rules in raw or normalised spelling."""
        rules = self._rule_cache.get(normalized)
        if rules is None:
            table = {}
            for key, expansion in self._ABBREVIATIONS.items():
                if normalized:
                    spelled_key = self.normalize_char_level_missmatch(key)
                    expansion = self.normalize_char_level_missmatch(expansion)
                else:
                    spelled_key = key
                # Two keys can collapse to one spelling (አ.አ / ዓ.ዓ). The key
                # that was already canonical wins, so previously-working
                # expansions keep their meaning.
                if spelled_key not in table or spelled_key == key:
                    table[spelled_key] = expansion
            rules = [
                (re.compile(rf'\b{re.escape(spelled_key)}\b'), expansion)
                for spelled_key, expansion in table.items()
            ]
            self._rule_cache[normalized] = rules
        return rules

    @staticmethod
    def _expand_abbreviations(text: str, rules: list) -> str:
        """Apply each (pattern, expansion) rule to `text`, in order."""
        for pattern, expansion in rules:
            text = pattern.sub(expansion, text)
        return text

    def _strip_punctuation(self, text: str) -> str:
        """Replace punctuation/digits with a space and squeeze whitespace runs."""
        text = self._PUNCT_AND_DIGITS.sub(' ', text)
        return self._MULTI_SPACE.sub(' ', text)

    def normalize_char_level_missmatch(self, text: str) -> str:
        """Map homophone letter variants and "u-form + ዋ/አ" pairs to one spelling."""
        text = text.translate(self._CHAR_VARIANTS)
        text = self._LABIALIZED_RE.sub(lambda match: self._LABIALIZED[match.group(1)], text)
        return text.translate(self._LATE_VARIANTS)

    def remove_punc_and_special_chars(self, text: str) -> str:
        """Delete (not replace) Latin and Ethiopic punctuation / special characters."""
        # Raw string: same character set as before, without invalid-escape warnings.
        return re.sub(r'[\!\@\#\$\%\^\«\»\&\*\(\)\…\[\]\{\}\;\“\”\›\’\‘\"\'\:\,\.\‹\/\<\>\?\\\\|\`\´\~\-\=\+\፡\።\፤\;\፦\፥\፧\፨\፠\፣]', '', text)

    def remove_ascii_and_numbers(self, text: str) -> str:
        """Drop ASCII letters/digits, then everything outside U+1200-U+137F and
        whitespace. Ethiopic numerals lie inside that range and are kept."""
        without_ascii = re.sub(r'[A-Za-z0-9]', '', text)
        return re.sub(r'[^\u1200-\u137F\s]+', '', without_ascii)


In [None]:
class TextDataset(Dataset):
    """Turns raw text lines into fixed-length decoder inputs, masks and labels.

    Every item is preprocessed and tokenised by `AmharicPreprocessor`, cut to
    at most `SEQ_LEN` tokens and right-padded with `[PAD]`.
    """

    def __init__(self, texts: list[str], tokenizer: Tokenizer) -> None:
        super().__init__()
        self.texts: list[str] = texts

        self.tokenizer = tokenizer
        self.preprocessor = AmharicPreprocessor(tokenizer)

        # (1,)
        self.pad_token = torch.tensor([self.tokenizer.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.texts)

    def batch_iterator(self, batch_size: int) -> DataLoader:
        """Return a shuffling DataLoader over this dataset."""
        return DataLoader(self, batch_size, shuffle=True)

    @staticmethod
    def lookback_mask(size: int) -> torch.Tensor:
        """Boolean causal mask of shape (1, size, size); True = may attend."""
        # Lower triangular matrix
        # [[
        #   [1, 0, ... , 0],
        #   [1, 1, ... , 0],
        #   [1, 1, ... , 0],
        #   [1, 1, ... , 1]
        # ]]
        # 1 x size x size
        return torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int) == 0

    def shift_left(self, list: list[int]) -> list[int]:
        """Return the next-token targets for `list`: drop the first id, append [PAD].

        The final position has no real successor, so its target is `[PAD]`,
        which the training loss ignores (`ignore_index`). It used to be
        `[UNK]`, which taught the model to emit `[UNK]` after every sequence.
        An empty input yields an empty output so that inputs and labels always
        have the same length.
        """
        if not list:
            return []
        return list[1:] + [self.pad_token.item()]

    def __getitem__(self, index) -> dict:
        # Guard against tokenizers without truncation enabled: anything longer
        # than SEQ_LEN would no longer match the (SEQ_LEN, SEQ_LEN) mask.
        token_ids = self.preprocessor.preprocess(self.texts[index])[:SEQ_LEN]
        padding = [self.pad_token.item()] * (SEQ_LEN - len(token_ids))

        # (seq_len,)
        decoder_input = torch.tensor(token_ids + padding, dtype=torch.int64)

        # (seq_len,)
        label = torch.tensor(self.shift_left(token_ids) + padding, dtype=torch.int64)

        return {
            # (seq_len,)
            "decoder_input": decoder_input,

            # (seq_len,) != (1,) --> (seq_len,) --> (1, seq_len) & (1, seq_len, seq_len) --> (1, seq_len, seq_len)
            # A key position is visible only if it is not padding AND not in the future.
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & self.lookback_mask(SEQ_LEN),

            # (seq_len,)
            "label": label
        }

In [None]:
class WordEmbedding(nn.Module):
    """Token-id embedding scaled by sqrt(D_MODEL)."""

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding: nn.Embedding = nn.Embedding(vocab_size, D_MODEL)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Call the module rather than `.forward` so registered hooks still run.
        return self.embedding(x) * math.sqrt(D_MODEL)
    
    
class PositionEncoder(nn.Module):
    """Adds fixed sinusoidal position encodings to the embeddings, then dropout."""

    def __init__(self) -> None:
        super().__init__()
        self.dropout = nn.Dropout(DROPOUT)

        # (seq_len, 1): positions 0 .. SEQ_LEN - 1 as a column
        positions = torch.arange(0, SEQ_LEN, dtype=torch.float).float().unsqueeze(1)
        # (d_model / 2,): 1 / 10000^(2i / d_model)
        inverse_frequencies = torch.exp(torch.arange(0, D_MODEL, 2).float() * -(math.log(10000.0) / D_MODEL))

        # (seq_len, d_model) table; even columns hold sin, odd columns hold cos
        table = torch.zeros(SEQ_LEN, D_MODEL)
        # PE(pos, 2i) = sin(pos / (10000 ^ (2i/d_model)))
        table[:, 0::2] = torch.sin(positions * inverse_frequencies)
        # PE(pos, 2i + 1) = cos(pos / (10000 ^ (2i/d_model)))
        table[:, 1::2] = torch.cos(positions * inverse_frequencies)

        # Stored as a buffer with a leading batch axis: (1, seq_len, d_model)
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        length = x.shape[1]
        assert length <= SEQ_LEN, f"Input sequence length exceeds the position encoder's max sequence length  `{SEQ_LEN}`"

        # (1, length, d_model) slice of the fixed table; never trained
        encodings = self.pe[:, :length, :].requires_grad_(False)

        # (batch, seq_len, d_model) + (1, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.dropout(x + encodings)
    
    
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the
        # state_dict layout and the order random initial values are drawn in.
        self.linear_1 = nn.Linear(D_MODEL, DFF).to(DEVICE)
        self.dropout = nn.Dropout(DROPOUT)
        self.linear_2 = nn.Linear(DFF, D_MODEL).to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) -> (batch, seq_len, dff)
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        # (batch, seq_len, dff) -> (batch, seq_len, d_model)
        return self.linear_2(hidden)
    
    
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections."""

    def __init__(self) -> None:
        assert D_MODEL % HEADS == 0, "d_model is not divisible by heads"

        super().__init__()

        # Width of a single head
        self.d_k = D_MODEL // HEADS

        # Input projections (query, key, value) followed by the output projection
        self.W_q = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.W_k = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.W_v = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)

        self.W_o = nn.Linear(D_MODEL, D_MODEL, bias=False).to(DEVICE)
        self.dropout = nn.Dropout(DROPOUT)

    @staticmethod
    def attention(
        # (batch, head, seq_len, d_k)
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        dropout: nn.Dropout=None,
        mask: torch.Tensor=None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (attended values, attention weights).

        Positions where `mask == 0` receive a score of -1e9 before the softmax,
        i.e. (numerically) zero weight. `dropout`, if given, is applied to the
        weights.
        """
        scale = math.sqrt(query.shape[-1])

        # (batch, head, seq_len, d_k) @ (batch, head, d_k, seq_len) --> (batch, head, seq_len, seq_len)
        scores = (query @ key.transpose(-2, -1)) / scale

        if mask is not None:
            scores.masked_fill_(mask == 0, -1e09)

        # Softmax over the key axis: every query row sums to 1
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        # (batch, head, seq_len, seq_len) @ (batch, head, seq_len, d_k) --> (batch, head, seq_len, d_k)
        return (weights @ value), weights

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq_len, d_model) --> (batch, seq_len, head, d_k) --> (batch, head, seq_len, d_k)
            batch, length = projected.shape[0], projected.shape[1]
            return projected.view(batch, length, HEADS, self.d_k).transpose(1, 2)

        # (batch, seq_len, d_model) @ (d_model, d_model) --> (batch, seq_len, d_model), then split
        query = split_heads(self.W_q(q))
        key = split_heads(self.W_k(k))
        value = split_heads(self.W_v(v))

        # attended: (batch, head, seq_len, d_k); the weights are kept on the module for inspection
        attended, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, self.dropout, mask)

        # (batch, head, seq_len, d_k) --> (batch, seq_len, head, d_k) --> (batch, seq_len, d_model)
        attended = attended.transpose(1, 2).contiguous()
        merged = attended.view(attended.shape[0], -1, HEADS * self.d_k)

        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.W_o(merged)
    
    
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: x + dropout(sublayer(layer_norm(x)))."""

    def __init__(self) -> None:
        super().__init__()
        self.dropout = nn.Dropout(DROPOUT)
        self.norm = nn.LayerNorm(D_MODEL, device=DEVICE)

    def forward(self, x: torch.Tensor, sublayer: nn.Module) -> torch.Tensor:
        # `sublayer` is any callable taking the normalised input
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch
        
            
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then feed-forward, each in a residual."""

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.dropout = nn.Dropout(DROPOUT)
        # Index 0 wraps the attention sub-layer, index 1 the feed-forward sub-layer
        self.residual_connections = nn.ModuleList([ResidualConnection() for _ in range(2)])

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        attention_residual = self.residual_connections[0]
        feed_forward_residual = self.residual_connections[1]

        def self_attend(normalized: torch.Tensor) -> torch.Tensor:
            # Query, key and value all come from the same (normalised) input
            return self.self_attention_block(normalized, normalized, normalized, mask)

        attended = attention_residual(x, self_attend)
        return feed_forward_residual(attended, self.feed_forward_block)
       
        
class Decoder(nn.Module):
    """Stack of decoder blocks followed by a final LayerNorm."""

    def __init__(self, decoder_blocks: nn.ModuleList) -> None:
        super().__init__()
        self.decoder_blocks = decoder_blocks
        self.norm = nn.LayerNorm(D_MODEL, device=DEVICE)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        hidden = x
        # Every block receives the same attention mask
        for block in self.decoder_blocks:
            hidden = block(hidden, mask)
        return self.norm(hidden)
    
    
class ProjectionLayer(nn.Module):
    """Linear map from the model dimension to vocabulary logits."""

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(D_MODEL, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        logits = self.proj(x)
        return logits
    
    
class GPTmodel(nn.Module):
    """Decoder-only transformer: embedding -> position encoding -> decoder -> projection."""

    def __init__(self, decoder: Decoder, embed: WordEmbedding, pos_encoder: PositionEncoder, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.decoder = decoder
        self.embed = embed
        self.pos_encoder = pos_encoder
        self.projection_layer = projection_layer

    def decode(self, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Embed the token ids, add position encodings and run the decoder stack."""
        embedded = self.embed(target)
        positioned = self.pos_encoder(embedded)
        return self.decoder(positioned, mask)

    def project(self, x: torch.Tensor):
        """Map decoder outputs to vocabulary logits."""
        return self.projection_layer(x)

    @staticmethod
    def build(
        vocab_size: int,
        state: dict = None
    ):
        """Assemble a model with N_BLOCKS decoder blocks.

        Every parameter with more than one dimension is Xavier-uniform
        initialised; if `state` is truthy it is then loaded as the state dict.
        """
        embed = WordEmbedding(vocab_size)
        pos_encoder = PositionEncoder()

        # Per block the attention sub-layer is constructed before the
        # feed-forward one (arguments are evaluated left to right).
        decoder_blocks = [
            DecoderBlock(MultiHeadAttentionBlock(), FeedForwardBlock())
            for _ in range(N_BLOCKS)
        ]

        decoder = Decoder(nn.ModuleList(decoder_blocks))
        projection_layer = ProjectionLayer(vocab_size)

        transformer = GPTmodel(decoder, embed, pos_encoder, projection_layer)

        # Re-initialise weight matrices; vectors (biases, norm parameters) keep their defaults
        for parameter in transformer.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

        if state:
            transformer.load_state_dict(state)

        return transformer

In [None]:
class GptInferenceEngine:
    """Greedy text completion with a trained `GPTmodel`.

    `top_k` and `nucleus_threshold` are stored but not used by `complete`,
    which always picks the highest-scoring token.
    """

    def __init__(self, model: GPTmodel, tokenizer: Tokenizer, top_k: int= 5, nucleus_threshold=10) -> None:
        self.model = model
        self.top_k = top_k
        self.tokenizer = tokenizer
        self.nucleus_threshold = nucleus_threshold

        self.sos_id = self.tokenizer.token_to_id("[SOS]")
        self.pad_id = self.tokenizer.token_to_id("[PAD]")

        self.model.eval()

    def size(self, tensor: torch.Tensor) -> int:
        """Index of the last token before the first [PAD] in a (1, seq_len) tensor.

        Returns `seq_len - 1` when the row has no padding at all (this used to
        raise IndexError) and -1 when the row starts with padding.
        """
        pad_positions = (tensor == self.pad_id).nonzero()
        if pad_positions.shape[0] == 0:
            return tensor.shape[-1] - 1
        return pad_positions[0][1].item() - 1

    @torch.no_grad()
    def complete(self, text: str, max_len: int) -> str:
        """Greedily extend `text` until it is `max_len` tokens long.

        The total length is additionally capped by the dataset's fixed sequence
        length. Returns the decoded prompt plus the generated continuation.
        """
        # The caller may have put the model back into training mode since __init__.
        self.model.eval()

        # `texts` is the dataset's keyword; `dataset=` raised a TypeError.
        dataset = TextDataset(
            texts=[text],
            tokenizer=self.tokenizer
        )

        # (seq_len,) --> (1, seq_len)
        decoder_input: torch.Tensor = dataset[0]["decoder_input"].unsqueeze(0).to(DEVICE)

        # Number of real (non-padding) tokens currently in the buffer
        length = self.size(decoder_input) + 1
        # Never write past the end of the fixed-size buffer
        limit = min(max_len, decoder_input.shape[1])

        # An empty prompt gives the model nothing to condition on, so skip generation.
        while 0 < length < limit:
            # Feed only the real tokens; a plain causal mask is then sufficient
            # and tokens generated so far are visible to later steps.
            # (1, length)
            context = decoder_input[:, :length]

            # (1, length, length) --> (1, 1, length, length)
            mask = TextDataset.lookback_mask(length).unsqueeze(0).to(DEVICE)

            # (1, length, d_model)
            decoder_out = self.model.decode(context, mask)

            # The output at the LAST real position predicts the next token.
            # (1, d_model) --> (1, vocab_size)
            logits = self.model.project(decoder_out[:, -1])

            # Greedy choice; softmax is monotonic, so argmax over logits is equivalent
            predicted_token = torch.argmax(logits, dim=-1)

            # Append the prediction so it becomes context for the next step
            decoder_input[0, length] = predicted_token.item()
            length += 1

        # Decode only the filled part of the buffer
        return self.tokenizer.decode(decoder_input[0, :length].detach().cpu().tolist())

In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard
%tensorboard --logdir "/content/drive/My Drive/Text-Generation/logs/gpt_model"

In [None]:
def get_tokenizer() -> Tokenizer:
    """Load the tokenizer from TOKENIZER_FILEPATH with truncation at SEQ_LEN tokens."""
    loaded: Tokenizer = Tokenizer.from_file(TOKENIZER_FILEPATH)
    # Encodings longer than the model's sequence length are cut off
    loaded.enable_truncation(max_length=SEQ_LEN)
    return loaded

def get_dataset() -> tuple[TextDataset, TextDataset, TextDataset]:
    """Read the corpus and split it 80% / 5% / 15% into train / validation / test.

    Returns:
        (train_dataset, val_dataset, test_dataset), all sharing one tokenizer.
    """
    with open(DATASET_PATH, 'r', encoding='utf-8') as file:
        # One sample per line. Blank lines are skipped: they would otherwise
        # become training samples that contain nothing but padding.
        texts = [line.strip() for line in file if line.strip()]

    train_size = int(0.8 * len(texts))
    test_size = int(0.15 * len(texts))
    # Whatever the integer rounding above leaves over goes to validation
    val_size = len(texts) - train_size - test_size

    # Two-stage random split: carve off validation first, then train vs. test
    train_test_raw, val_raw = random_split(texts, (train_size+test_size, val_size))
    train_raw, test_raw = random_split(train_test_raw, (train_size, test_size))

    tokenizer = get_tokenizer()

    train_dataset = TextDataset(train_raw, tokenizer)
    val_dataset = TextDataset(val_raw, tokenizer)
    test_dataset = TextDataset(test_raw, tokenizer)

    return train_dataset, val_dataset, test_dataset
    
    
@torch.no_grad()
def validate(model: GPTmodel, val_dataset: TextDataset, loss_func: nn.CrossEntropyLoss, max_batches: "int | None" = 1):
    """Return the mean loss over up to `max_batches` shuffled validation batches.

    Args:
        model: model to evaluate; it is switched to eval mode and left there.
        val_dataset: dataset to draw batches from.
        loss_func: loss applied to the flattened logits and labels.
        max_batches: number of batches to evaluate. The default of 1 keeps the
            cheap single-batch estimate used during training; pass None to
            evaluate the whole dataset.

    Returns:
        Mean batch loss as a float, or 0.0 if no batch was evaluated.
    """
    model.eval()

    vocab_size = val_dataset.tokenizer.get_vocab_size()

    total_loss = 0.0
    batches_seen = 0
    for batch in val_dataset.batch_iterator(BATCH_SIZE):
        if max_batches is not None and batches_seen >= max_batches:
            break

        # (batches, seq_len)
        decoder_input = batch["decoder_input"].to(DEVICE)

        # (batches, 1, seq_len, seq_len)
        decoder_mask = batch["decoder_mask"].to(DEVICE)

        # (batches, seq_len)
        label: torch.Tensor = batch['label'].to(DEVICE)

        # (batches, seq_len, d_model)
        decoder_output = model.decode(decoder_input, decoder_mask)

        # (batches, seq_len, vocab_size)
        proj_output: torch.Tensor = model.project(decoder_output)

        # Compute the cross-entropy loss
        loss: torch.Tensor = loss_func(
            # (batches, seq_len, vocab_size) --> (batches*seq_len, vocab_size)
            proj_output.view(-1, vocab_size),

            # (batches, seq_len) --> (batches * seq_len, )
            label.view(-1)
        )
        total_loss += loss.item()
        batches_seen += 1

    return total_loss / batches_seen if batches_seen else 0.0

@torch.no_grad()
def test(model: GPTmodel, test_dataset: TextDataset):
    """Evaluate `model` on `test_dataset` and print the average batch loss.

    The loss uses the same settings as training (PAD ignored, label smoothing 0.1).
    """
    print(f"Testing started on `{DEVICE}` device")

    # Without this, dropout stays active whenever the model was last used for
    # training, and the reported loss is noisy and pessimistic.
    model.eval()

    loss_func = nn.CrossEntropyLoss(ignore_index=test_dataset.tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(DEVICE)

    batch_iterator = tqdm(test_dataset.batch_iterator(BATCH_SIZE), desc=f"Evaluating model on test dataset", colour="GREEN")

    vocab_size = test_dataset.tokenizer.get_vocab_size()

    evaluation_loss = 0
    batches_seen = 0
    # Iterate through the batches
    for batch in batch_iterator:
        # (batches, seq_len)
        decoder_input = batch["decoder_input"].to(DEVICE)

        # (batches, 1, seq_len, seq_len)
        decoder_mask = batch["decoder_mask"].to(DEVICE)

        # (batches, seq_len)
        label: torch.Tensor = batch['label'].to(DEVICE)

        # (batches, seq_len, d_model)
        decoder_output = model.decode(decoder_input, decoder_mask)

        # (batches, seq_len, tgt_vocab_size)
        logits: torch.Tensor = model.project(decoder_output)

        # Compute the test loss
        test_loss: torch.Tensor = loss_func(
            # (batches, seq_len, tgt_vocab_size)  -->  (batches*seq_len, tgt_vocab_size)
            logits.view(-1, vocab_size),

            # (batches, seq_len)   -->   (batches * seq_len, )
            label.view(-1)
        )

        # Add the calculated test loss as a postfix to the progress bar shown by tqdm
        batch_iterator.set_postfix({"test_loss": f"{test_loss.item():6.3f}"})

        evaluation_loss += test_loss.item()
        batches_seen += 1

    # An empty test set yields NaN instead of a ZeroDivisionError
    avg_loss = evaluation_loss / batches_seen if batches_seen else float("nan")
    print(f"\nTesting finished with an average cross-entropy loss of {avg_loss}")

    
def train(model: GPTmodel, train_dataset: TextDataset, val_dataset: TextDataset) -> None:
    """Train `model`, logging to TensorBoard and checkpointing after every epoch.

    Resumes from `PRELOAD_MODEL_FILEPATH` (a checkpoint name inside
    `MODELS_FOLDER`, without ".pt") when that constant is non-empty. The
    logged losses are running averages over all steps / validation runs so far.
    """
    import os  # local import: only needed to create the checkpoint folder

    # A single-batch validation estimate is taken every this many steps
    validation_interval = 200

    # Configure Tensorboard
    writer = SummaryWriter(TB_LOG_DIR)

    # Create the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=INIT_LR, eps=1e-09)

    initial_epoch = 0
    global_step = 0
    training_loss = 0
    validation_loss = 0
    if PRELOAD_MODEL_FILEPATH:
        model_filename = f"{MODELS_FOLDER}/{PRELOAD_MODEL_FILEPATH}.pt"
        print(f"Preloading model {model_filename}")

        # map_location lets a checkpoint saved on GPU be resumed on CPU and vice versa
        state = torch.load(model_filename, map_location=DEVICE)
        initial_epoch = state["epoch"] + 1
        global_step = state["global_step"]
        training_loss = state["training_loss"]
        validation_loss = state["validation_loss"]

        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])

    # Validation runs at steps 0, interval, 2*interval, ... so after N completed
    # steps there have been ceil(N / interval) of them. Counting explicitly
    # avoids the off-by-one the old `(step + 1) // 200 + 1` divisor had whenever
    # step + 1 was a multiple of the interval.
    validation_runs = (global_step + validation_interval - 1) // validation_interval

    loss_func = nn.CrossEntropyLoss(ignore_index=train_dataset.tokenizer.token_to_id('[PAD]'), label_smoothing=0.1).to(DEVICE)

    data_loader = train_dataset.batch_iterator(BATCH_SIZE)
    vocab_size = train_dataset.tokenizer.get_vocab_size()

    os.makedirs(MODELS_FOLDER, exist_ok=True)

    for epoch in range(initial_epoch, EPOCHS):
        # Wrap the loader afresh each epoch. Re-assigning the wrapped object to
        # the loader variable nested one tqdm inside another every epoch.
        progress = tqdm(data_loader, desc=f"Processing epoch {epoch: 02d}", colour="BLUE")

        for batch in progress:
            # validate() leaves the model in eval mode, so switch back every step
            model.train()

            # (batch, seq_len)
            decoder_input = batch["decoder_input"].to(DEVICE)

            # (batch, 1, seq_len, seq_len)
            decoder_mask = batch["decoder_mask"].to(DEVICE)

            # (batch, seq_len)
            label: torch.Tensor = batch['label'].to(DEVICE)

            # (batch, seq_len, d_model)
            decoder_output = model.decode(decoder_input, decoder_mask)

            # (batch, seq_len, vocab_size)
            logits: torch.Tensor = model.project(decoder_output)

            # Compute the cross-entropy loss
            batch_loss = loss_func(
                # (batch, seq_len, vocab_size) --> (batch*seq_len, vocab_size)
                logits.view(-1, vocab_size),

                # (batch, seq_len) --> (batch * seq_len, )
                label.view(-1)
            )
            training_loss += batch_loss.item()
            avg_training_loss = training_loss / (global_step + 1)

            if global_step % validation_interval == 0:
                # Evaluate the model on the validation dataset(aka unseen data)
                validation_loss += validate(model, val_dataset, loss_func)
                validation_runs += 1

                # Log the training and validation loss on tensorboard
                writer.add_scalars("Cross-Entropy-Loss", { "Training": avg_training_loss, "Validation": validation_loss / validation_runs }, global_step)
            else:
                writer.add_scalars("Cross-Entropy-Loss", { "Training": avg_training_loss }, global_step)

            writer.flush()

            avg_validation_loss = validation_loss / max(validation_runs, 1)
            progress.set_postfix({"train_loss": f"{avg_training_loss:6.3f}", "val_loss": f"{avg_validation_loss:6.3f}"})

            # Backward pass: compute gradients for every parameter in the graph
            batch_loss.backward()

            # Update the model parameters
            optimizer.step()

            # Zero the gradients of the model parameters to prevent gradient accumulation
            optimizer.zero_grad()

            global_step += 1

        # Save the model at the end of every epoch (max(..., 1) guards an epoch without batches)
        model_filename = f"{MODELS_FOLDER}/gpt_model-avgTrainLoss-{training_loss / max(global_step, 1):6.3f}_avgValLoss-{validation_loss / max(validation_runs, 1):6.3f}.pt"

        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step,
            "training_loss": training_loss,
            "validation_loss": validation_loss,
            "model_hyperparams":{
                "D_MODEL": D_MODEL,
                "N_BLOCKS": N_BLOCKS,
                "HEADS": HEADS,
                "DROPOUT": DROPOUT,
                "DFF": DFF,
                "BATCH_SIZE": BATCH_SIZE,
                "INIT_LR": INIT_LR
            }
        }, model_filename)

    # Flush and release the TensorBoard event file
    writer.close()

In [None]:
# Build the data splits, create the model on DEVICE and run the training loop.
print(f"Training started on `{DEVICE}` device")
train_dataset, val_dataset, test_dataset = get_dataset()

# The vocabulary size is read from the tokenizer, so the embedding and the
# projection layer are sized to the tokenizer that is actually loaded.
model = GPTmodel.build(train_dataset.tokenizer.get_vocab_size()).to(DEVICE) 

train(model, train_dataset, val_dataset)

In [None]:
# Report the average loss of the trained model on the held-out test split.
test(model, test_dataset)

In [None]:
# Switch to inference mode (disables dropout) and print a short model summary.
model.eval()
total_params = sum(p.numel() for p in model.parameters())
print(f"Device: {DEVICE}")
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# Size estimate assumes 4 bytes per parameter.
print(f"Model Size(MB): {total_params * 4 / (1024 ** 2):.2f}MB")

# A freshly loaded tokenizer (truncation enabled) backs the inference engine.
tokenizer: Tokenizer = get_tokenizer()
inference_engine = GptInferenceEngine(model, tokenizer)

In [None]:
# Interactive demo: read a prompt and let the model extend it, with SEQ_LEN
# passed as the maximum total length in tokens.
user_input = input("Write an incomplete short amharic text: ")
predicted = inference_engine.complete(user_input, SEQ_LEN)
print(f"\n Predicted: {predicted}")