In [3]:

import torch.nn as nn
import torch 
from dataclasses import dataclass
from torch.nn import functional as F
from transformers import AutoTokenizer


# Select the compute device once for the whole notebook.
# Apple-silicon MPS is preferred when present, then CUDA, and finally CPU so
# that the notebook still runs on machines without an accelerator.
_mps_backend = getattr(torch.backends, "mps", None)  # missing on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [8]:
@dataclass
class MyGPTConfig:
    """Hyper-parameters shared by every module of the model.

    Attributes are listed in constructor order; the defaults are the values
    written in this notebook.
    """

    # Maximum sequence length: number of rows in the position-embedding table
    # and side length of the causal mask buffer.
    n_ctx: int = 1024
    # Number of rows in the token-embedding table / width of the output logits.
    vocab_size: int = 50257
    # Width of the residual stream (embedding dimension).
    n_embed: int = 768
    # Number of attention heads; must divide n_embed evenly.
    n_head: int = 12
    # Number of stacked transformer blocks.
    n_layer: int = 12


class MLP(nn.Module):
    """Feed-forward sub-layer: expand to 4x width, GELU (tanh form), project back.

    Args:
        config: object exposing ``n_embed`` (width of the input and output).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden_width = 4 * width

        self.c_fc = nn.Linear(width, hidden_width)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_width, width)
        # Marker read by the model's weight initialiser: this projection feeds
        # a residual connection and gets a scaled-down initial std.
        self.c_proj.MYGPT_SCALE_INIT = 1

    def forward(self, x):
        """Apply the feed-forward transform to the last dimension of ``x``."""
        return self.c_proj(self.gelu(self.c_fc(x)))


class GPTOneBlock(nn.Module):
    """One transformer block: pre-LayerNorm attention, then pre-LayerNorm MLP.

    Both sub-layers are wrapped in a residual connection, so the unmodified
    input is always added back to the sub-layer output. Attention is where
    tokens exchange information; the MLP then processes each position on its
    own. The additive skip path gives gradients a direct route from the
    output back to the input, which helps when many blocks are stacked.

    Args:
        config: object exposing ``n_embed`` plus whatever ``SelfAttention``
            and ``MLP`` read from it.
    """

    def __init__(self, config):
        super().__init__()

        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_update = self.attn(self.ln_1(x))
        x = x + attn_update

        mlp_update = self.mlp(self.ln_2(x))
        return x + mlp_update


class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    A single linear layer produces queries, keys and values for all heads at
    once; the result is split per head, attended with a causal mask, merged
    back and passed through an output projection.

    Args:
        config: object exposing ``n_embed``, ``n_head`` and ``n_ctx``.
            ``n_embed`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        # Initialise nn.Module before attaching any attribute: assigning first
        # only works for plain objects and raises for Module/Parameter values.
        super().__init__()
        self.config = config

        # n_embed is split evenly across the heads.
        assert config.n_embed % config.n_head == 0, (
            f"n_embed ({config.n_embed}) must be divisible by n_head ({config.n_head})"
        )

        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        # Marker read by the model's weight initialiser (residual projection).
        self.c_proj.MYGPT_SCALE_INIT = 1

        # Lower-triangular mask. forward() does not read it (the fused attention
        # call applies causal masking itself), but it is a persistent buffer and
        # therefore part of the state_dict, so it is kept to leave the set of
        # state_dict keys unchanged.
        causal_mask = torch.tril(torch.ones(config.n_ctx, config.n_ctx))
        self.register_buffer("bias", causal_mask.view(1, 1, config.n_ctx, config.n_ctx))

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape."""
        B, T, C = x.size()  # batch size, sequence length, n_embed
        n_head = self.config.n_head
        head_dim = C // n_head

        q, k, v = self.c_attn(x).split(self.config.n_embed, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, n_head, head_dim).transpose(1, 2)

        # is_causal=True applies the lower-triangular mask inside the fused kernel.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # Re-assemble all head outputs side by side, then apply the output projection.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class MyGPT2(nn.Module):
    """GPT-2 style decoder-only language model.

    Layout: token + position embeddings -> ``n_layer`` transformer blocks ->
    final LayerNorm -> linear head producing vocabulary logits. The token
    embedding and the output head share one weight tensor.

    Args:
        config: object exposing ``vocab_size``, ``n_ctx``, ``n_embed``,
            ``n_layer`` plus whatever ``GPTOneBlock`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            # Token embedding table.
            wte=nn.Embedding(config.vocab_size, config.n_embed),
            # Position embedding table, one row per position up to n_ctx.
            wpe=nn.Embedding(config.n_ctx, config.n_embed),
            # The stack of hidden blocks (LayerNorm, attention, LayerNorm, MLP each).
            h=nn.ModuleList(GPTOneBlock(config) for _ in range(config.n_layer)),
            # Final LayerNorm applied after the last block.
            ln_f=nn.LayerNorm(config.n_embed),
        ))

        # Language-model head: projects n_embed back to vocabulary logits.
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Weight sharing between the input embedding and the output head; this
        # also removes one vocab_size x n_embed matrix from the parameter count.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; called for every module via ``self.apply``.

        * ``nn.Linear``: weights ~ N(0, 0.02), biases zero. Layers tagged with
          ``MYGPT_SCALE_INIT`` use std 0.02 / sqrt(2 * n_layer), because every
          block adds two residual contributions (attention and MLP).
        * ``nn.Embedding``: weights ~ N(0, 0.01), except the token embedding,
          whose tensor is tied to ``lm_head`` and is initialised exactly once
          through the ``nn.Linear`` branch (std 0.02).
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, "MYGPT_SCALE_INIT"):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            if module.weight is self.lm_head.weight:
                # Tied tensor: lm_head's Linear branch initialises it, so
                # drawing values here would only be overwritten.
                return
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.01)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer tensor of token ids with shape (B, T), T <= n_ctx.
            targets: optional integer tensor with the same number of elements
                as ``idx``; when given, the mean cross-entropy is returned.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size)
            and ``loss`` is a scalar tensor, or ``None`` without ``targets``.

        Raises:
            AssertionError: if the sequence is longer than ``n_ctx``.
        """
        B, T = idx.size()
        assert T <= self.config.n_ctx, (
            f"sequence length {T} exceeds the context window n_ctx={self.config.n_ctx}"
        )

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_embed = self.transformer.wpe(pos)   # (T, n_embed), broadcast over the batch
        tok_embed = self.transformer.wte(idx)   # (B, T, n_embed)

        x = pos_embed + tok_embed
        for one_block in self.transformer.h:
            x = one_block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

In [None]:
# Build the model with the vocabulary size used for this checkpoint, restore
# the trained weights, and move it to the selected device for inference.
model = MyGPT2(MyGPTConfig(vocab_size=50304))

# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; the weights are moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load("model_800.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["model"])

model.eval()  # inference mode for any train/eval-dependent layers
model.to(device)

In [20]:
# --- Sampling configuration -------------------------------------------------
num_sentences_to_generate = 5   # independent continuations of the same prompt
max_seq_length = 50             # total length in tokens (prompt + generated)
top_k = 50                      # sample only among the k most likely tokens
sampling_seed = 42              # fixed seed so a re-run reproduces the samples

my_gpt_tokenizer = AutoTokenizer.from_pretrained("sarvamai/sarvam-1")
prompt_ids = my_gpt_tokenizer.encode("चलो आज फिर चलते हैं")

# Every prompt id must index a row of the model's embedding table.
if prompt_ids and max(prompt_ids) >= model.config.vocab_size:
    raise ValueError(
        f"prompt contains token id {max(prompt_ids)} but the model only has "
        f"{model.config.vocab_size} embedding rows"
    )

tokens = torch.tensor(prompt_ids, dtype=torch.long)
# One copy of the prompt per sentence: shape (num_sentences_to_generate, prompt_len).
vectorized_tokens = tokens.unsqueeze(0).repeat(num_sentences_to_generate, 1).to(device)

# The model asserts that a sequence never exceeds its context window.
max_seq_length = min(max_seq_length, model.config.n_ctx)

torch.manual_seed(sampling_seed)
with torch.no_grad():
    while vectorized_tokens.size(1) < max_seq_length:
        all_logits, _ = model(vectorized_tokens)

        # Only the prediction at the last position is needed for the next token.
        next_logits = all_logits[:, -1, :]
        next_probs = F.softmax(next_logits, dim=-1)

        # Top-k sampling: keep the k most likely tokens, draw one of them
        # (multinomial renormalises the kept probabilities), then map the
        # drawn position back to a vocabulary id.
        topk_probs, topk_indices = torch.topk(next_probs, top_k, dim=-1)
        ix = torch.multinomial(topk_probs, 1)
        xcol = torch.gather(topk_indices, -1, ix)

        vectorized_tokens = torch.cat([vectorized_tokens, xcol], dim=1)

In [None]:
#0
# Turn every generated row of token ids back into text and print it.
for token_row in vectorized_tokens:
    print(my_gpt_tokenizer.decode(token_row.tolist()))

In [None]:
# चलो आज फिर चलते हैं सुनते थे दिल की आग उस की तबी दयारों पे हम भी अब दादे हैं तो जी बहकी रुफ़ाई है पयाम रख देंगी गुज़राज़ारे लोग मिले भी अब तक लेकिन एक दिलबर नहीं है उल्फ
# <s> खुदा से ये गुजारिशમે साला<<reserved_token_3325>> Agency<<reserved_token_4075>> ਲੰਬਾ ਜਾਣਕਾਰੀ ਨਿਰਧਾਰਤ ਮੌਜੂਦsd റൂ females females AT ವ್ಯವಹ Illహంpret<<reserved_token_2098>> dietsm गोष्टी used વેપ ਗਵਰਨਰસંગતitory ವ್ಯವಹನಿಯನ್ गोंൃതിसंबंध British ਪਾਲ খেলেনটো ನಿರ್ವಹಿಸಲುसतनਗਤ<<reserved_token_2313>>ಬರ್ pictures
