<a href="https://colab.research.google.com/github/quinbez/Large_Language_Models_For_Low_Resource_Languages/blob/main/Large_Language_Models_For_Low_Resource_Languages.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Downloading Dataset**

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so files stored there can be read like local files.
drive.mount('/content/drive')
# Path of the raw corpus text file inside the mounted Drive.
file_path = '/content/drive/MyDrive/raw-corpus.txt'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import pandas as pd

# Read the whole corpus inside a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(file_path, encoding='utf-8') as corpus_file:
    lines = corpus_file.read().split('\n')

# One DataFrame row per corpus line; a trailing newline in the file produces a
# final empty row because the text is split on '\n'.
data = pd.DataFrame({'Text': lines})
print(data.head())

                                                Text
0                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን ነው!
1                                  አብዶ ኑር የሱፍ (ኖርዌ) 
2                    ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን ነው! 
3                                ከአብዶ ኑር የሱፍ (ኖርዌይ) 
4   ለቀድሞው  የወያኔ  ጠቅላይ ሚኒስትር  ለጊዜው ኳሷ በእሳቸው ቁጥጥር ስ...


## **Pre-Processing**

* Replace words that are not Amharic with 'unk'. Example: she said ልክ ነው -> unk unk ልክ ነው
* Collapse consecutive 'unk' tokens into a single 'unk'. Example: she said ልክ ነው -> unk ልክ ነው
* Add spaces around punctuation marks. Example: ቻው! -> ቻው !
* Collapse runs of the same punctuation mark into a single mark. Example: %%% -> %
* Truncate words longer than 13 characters to their first 13 characters (this step is currently commented out in the code below)
* Replace characters other than Arabic digits and Amharic characters with 'u'. Example: እንሂድxc -> እንሂድuu
* Normalize the text by replacing characters and words according to the mapping in the replace file

In [None]:
# Replace words that are not amharic with 'unk'. Example: she said ልክ ነው -> unk unk ልክ ነው
# Replace consecutive 'unk' with just one 'unk'. Example: she said ልክ ነው -> unk ልክ ነው

import pandas as pd
import re

def is_amharic(word):
    """Return True when every character of ``word`` is in the Ethiopic block U+1200-U+137F.

    The empty string yields False because the pattern requires at least one
    character. Ethiopic punctuation and numerals fall inside this range and are
    therefore accepted as well.
    """
    ethiopic_only = re.compile(r'[\u1200-\u137F]+')
    return bool(ethiopic_only.fullmatch(word))

def replace_non_amharic(text):
    """Replace non-Amharic words with 'unk', emitting one 'unk' per run of such words.

    The text is split on whitespace, each word is classified with
    ``is_amharic``, and the surviving tokens are re-joined with single spaces.
    """
    kept = []
    for token in text.split():
        if is_amharic(token):
            kept.append(token)
        elif not kept or kept[-1] != 'unk':
            # Only the first word of a non-Amharic run produces a placeholder.
            kept.append('unk')
    return ' '.join(kept)

# Apply the word-level filter to every corpus line, overwriting the 'Text' column.
data['Text'] = data['Text'].apply(replace_non_amharic)

In [None]:
# Show the first five rows of the frame.
print(data.head())

                                                Text
0                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን unk
1                                     አብዶ ኑር የሱፍ unk
2                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን unk
3                                    ከአብዶ ኑር የሱፍ unk
4  ለቀድሞው የወያኔ ጠቅላይ ሚኒስትር ለጊዜው ኳሷ በእሳቸው ቁጥጥር ስር እን...


In [None]:
# Show the last five rows of the frame.
print(data.tail())

       Text
557348  unk
557349  unk
557350  unk
557351  unk
557352     


In [None]:
# Add spaces around punctuations. Example: ቻው! -> ቻው !

def space_around_punctuation(text):
    """Surround every listed punctuation mark with one space on each side.

    Covers the Ethiopic marks ። ፣ ፤ ፥ ፦ ፧ ፨ and the ASCII punctuation set.
    Existing whitespace is not collapsed, so the result may contain double
    spaces.
    """
    # Inside the class, ``,-.`` is a character range; it spans ',', '-' and '.'.
    marks = re.compile(r'([።፣፤፥፦፧፨!\"#$%&\'()*+,-./:;<=>?@\[\\\]^_`{|}~])')
    return marks.sub(r' \1 ', text)

# Apply the punctuation spacing to every corpus line, overwriting the 'Text' column.
data['Text'] = data['Text'].apply(space_around_punctuation)

In [None]:
# Show rows 0-5 of the frame.
print(data[:6])

                                                Text
0                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን unk
1                                     አብዶ ኑር የሱፍ unk
2                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን unk
3                                    ከአብዶ ኑር የሱፍ unk
4  ለቀድሞው የወያኔ ጠቅላይ ሚኒስትር ለጊዜው ኳሷ በእሳቸው ቁጥጥር ስር እን...
5  ምፈልገው ወገን ጋር ልጫወትባት  ፣  ስለ ኳሷ አትጠይቁኝ  ፣  ወደ ኳሷ...


In [None]:
# Replace consecutive same punctuations by just one. Example: %%% -> %

import re
def replace_consecutive_punctuation(text):
    """Collapse each run of one repeated non-word character into a single occurrence.

    The regex non-word class also matches whitespace, so runs of spaces are
    collapsed too. The underscore counts as a word character and is left
    untouched.
    """
    repeated = re.compile(r'(\W)\1+')
    return repeated.sub(r'\1', text)

# Collapse repeated non-word characters in every corpus line, overwriting the 'Text' column.
data['Text'] = data['Text'].apply(replace_consecutive_punctuation)

In [None]:
# Show rows 17-24 of the frame.
print(data[17:25])

                                                 Text
17  ራስ ወዳዱ ጠቅላይ ሚኒስትር በሞት ከተለዩ በኋላ እንደተካሄደው አይነት የ...
18                          ኢ ትዮጵያ በክብር ለዘላለም ትኑር unk
19                                       ለአስተያየቶት unk
20                   ግራ የሚያጋባ የወቅቱ unk ማብቂያውን እናፍቃለሁ፡
21                   ግራ የሚያጋባ የወቅቱ unk ማብቂያውን እናፍቃለሁ፡
22                                                   
23                                        በማሕሌት ፋንታሁን
24  በሕይወታችን የምናደርጋቸውን እንቅስቃሴዎች ከመከወን እንድንቆጠብ ለራሳችን...


In [None]:
# Truncate words that have more than 13 characters to just 13

# def truncate_long_words(text, max_length=13):
#     words = text.split()
#     truncated_words = [word if len(word) <= max_length else word[:max_length] for word in words]
#     return ' '.join(truncated_words)

# data['Text'] = data['Text'].apply(truncate_long_words)
# print(len(data))

In [None]:
# Replace characters other than arabic digits and amharic characters with 'u'. Example: እንሂድxc -> እንሂድuu

import re
def replace_non_charset_with_u(text):
    """Replace every character outside the allowed charset with 'u'.

    Allowed characters are the Ethiopic block U+1200-U+137F, ASCII digits 0-9
    and whitespace. Standalone 'unk' placeholder tokens (any casing) are kept
    verbatim.

    Each offending character is replaced individually, matching the documented
    example ``እንሂድxc -> እንሂድuu``.

    Args:
        text: One line of corpus text.

    Returns:
        The text with every disallowed character replaced by 'u'.
    """
    # Word boundaries ensure only the standalone placeholder is protected, not
    # the letters "unk" inside a longer Latin word. The capture group makes
    # re.split keep the placeholders in the token list.
    tokens = re.split(r'\b(unk)\b', text, flags=re.IGNORECASE)
    # No '+' quantifier: a run of n disallowed characters becomes n 'u's.
    pattern = r"[^\u1200-\u137F0-9\s]"

    processed_tokens = [
        token if token.lower() == 'unk' else re.sub(pattern, 'u', token)
        for token in tokens
    ]
    return ''.join(processed_tokens)

# Apply the character-level clean-up to every corpus line. The function is
# passed directly, like the other .apply calls in this notebook.
data['Text'] = data['Text'].apply(replace_non_charset_with_u)
print(data.head())

                                                Text
0                     ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን ነውu
1                                  አብዶ ኑር የሱፍ uኖርዌu 
2                    ዛሬ ነገ ሳንል መነሳት የዜግነት ግዴታችን ነውu 
3                                ከአብዶ ኑር የሱፍ uኖርዌይu 
4   ለቀድሞው  የወያኔ  ጠቅላይ ሚኒስትር  ለጊዜው ኳሷ በእሳቸው ቁጥጥር ስ...


In [None]:
# Normalize by replacing characters and words by using the mapping in the replace file

def load_mapping(file_path, encoding='utf-8-sig'):
    """Load ``key=value`` replacement pairs from a text file.

    The file is decoded as UTF-8 by default, with a leading byte-order mark
    tolerated. Latin-1 cannot represent Ethiopic characters: decoding an
    Amharic mapping file with it yields mojibake keys that never match the
    corpus, which is read as UTF-8.

    Args:
        file_path: Path to the mapping file, one ``key=value`` pair per line.
        encoding: Text encoding of the mapping file.

    Returns:
        A dict mapping each key to its replacement, in file order. Lines that
        do not contain exactly one '=' and lines with an empty key are skipped.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid in ``encoding``.
    """
    mapping = {}
    with open(file_path, 'r', encoding=encoding) as file:
        for line in file:
            parts = line.strip().split('=')
            if len(parts) != 2:
                continue
            key, value = parts
            if not key:
                # str.replace('', value) would insert `value` between every
                # character of the text, so an empty key is never usable.
                continue
            mapping[key] = value
    return mapping
def normalize_text(text, mapping):
    """Apply every ``key -> value`` replacement in ``mapping`` to ``text``.

    Replacements run sequentially in the mapping's iteration order, so a later
    replacement sees the output of the earlier ones.
    """
    normalized = text
    for source in mapping:
        normalized = normalized.replace(source, mapping[source])
    return normalized

# Load the replacement table from Drive and apply it to every corpus line,
# overwriting the 'Text' column.
mapping = load_mapping("/content/drive/MyDrive/replace.txt")
data['Text'] = data['Text'].apply(lambda x: normalize_text(x, mapping))

In [None]:
# Show rows 45-49 of the frame.
print(data[45:50])

### **Tokenization**

In [None]:
import pandas as pd
import torch

def create_char_mappings(text):
    """Build character-level vocabulary lookups for ``text``.

    Args:
        text: The string whose distinct characters form the vocabulary.

    Returns:
        A tuple ``(stoi, itos, vocab_size)``: ``stoi`` maps each character to
        its integer id, ``itos`` is the inverse mapping, and ``vocab_size`` is
        the number of distinct characters. Ids follow the sorted order of the
        characters.
    """
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = dict(enumerate(chars))
    return stoi, itos, len(chars)

# Function to encode a string using the mapping
def encode_string(s, stoi):
    """Translate each character of ``s`` to its integer id using ``stoi``.

    Raises KeyError when a character is missing from ``stoi``.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids

def decode_string(encoded_list, mapping):
    """Turn a sequence of integer ids back into text using ``mapping``.

    Ids that are missing from ``mapping`` are rendered as '?'.
    """
    pieces = []
    for token_id in encoded_list:
        pieces.append(mapping.get(token_id, '?'))
    return ''.join(pieces)

# Build the training corpus from the pre-processed `data` frame produced by the
# cells above. Re-creating the frame from the raw `lines` list at this point
# would silently discard every cleaning and normalisation step.
# The lines are joined with '\n', the separator they were originally split on,
# so the last word of one line is not fused with the first word of the next.
all_text = '\n'.join(data['Text'].tolist())

# Create the character <-> id mappings and the vocabulary size
stoi, itos, vocab_size = create_char_mappings(all_text)

# Encode the entire text as a list of integer ids
encoded_text = encode_string(all_text, stoi)

# Decode it again (round trip through the two mappings)
decoded_text = decode_string(encoded_text, itos)

# Convert the encoded text into a 1-D tensor of token ids
data_tensor = torch.tensor(encoded_text, dtype=torch.long)

print(data_tensor)

tensor([549, 394,  17,  ...,  97,  97,  31])


In [None]:
# Split Data into Training and Validation

def split_data(text, split_ratio=0.9):
    """Split an encoded sequence into contiguous training and validation parts.

    Args:
        text: The sequence to split; it must support ``len()`` and slicing.
            In this notebook it is the 1-D tensor of token ids.
        split_ratio: Fraction of the sequence, taken from the front, that goes
            to the training part.

    Returns:
        A tuple ``(train_data, val_data)``.
    """
    # Split the argument itself rather than reaching for a global tensor.
    n = int(split_ratio * len(text))
    return text[:n], text[n:]


# Pass the encoded tensor explicitly: splitting the DataFrame would cut rows,
# not tokens, and get_batch needs tensors.
train_data, val_data = split_data(data_tensor, split_ratio=0.9)
print(len(train_data))
print(len(val_data))

161594799
17954978


### **DataLoader**

In [None]:
import pandas as pd
import torch
import torch

batch_size = 16  # number of sequences sampled per batch in get_batch
block_size = 32  # length, in tokens, of each sampled sequence
max_iters = 5000  # number of optimisation steps in the training loops
eval_interval = 100  # the training loops evaluate and print losses every this many steps
learning_rate = 1e-3  # learning rate passed to AdamW
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
eval_iters = 200  # number of batches averaged per split in estimate_loss
n_embd = 64  # embedding width used by the transformer cells
n_head = 4  # attention heads per transformer block
n_layer = 4  # number of transformer blocks
dropout = 0.0  # dropout probability handed to the nn.Dropout layers

def get_batch(split):
    """Sample a random batch of input/target sequences from one data split.

    ``split == 'train'`` selects the training tensor; any other value selects
    the validation tensor. Targets are the inputs shifted one position to the
    right. Both tensors are moved to ``device`` and have shape
    ``(batch_size, block_size)``.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets, chosen so every window (plus its shifted target) fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

# Draw one batch from each split and print the input/target tensors for inspection.
train_x, train_y = get_batch('train')
val_x, val_y = get_batch('val')
print("Training Batch - Input (x): \n", train_x)
print("\nTraining Batch - Target (y): \n", train_y)
print("\nValidation Batch - Input (x): \n", val_x)
print("\nValidation Batch - Target (y): \n", val_y)

Training Batch - Input (x): 
 tensor([[484, 457, 675, 675, 355, 567, 530, 457],
        [510, 457,  17, 497, 363, 570, 390, 398],
        [363,  17, 502, 486, 570, 356, 481,  17],
        [390, 403,  17, 530, 575, 414, 535,  17],
        [454, 379,  17, 570, 390, 589,  17, 562],
        [502, 548, 355,  17, 481, 619, 441,  17],
        [100,  17,  90,  95, 103,  86,  95, 101],
        [395, 399,  17, 358, 374, 505, 361, 505],
        [465, 361, 363, 675, 675,  17, 567, 355],
        [460, 535, 484,  17, 530, 570,  17, 414],
        [515, 395, 456, 457,  17, 502, 486, 570],
        [ 17, 438, 396,  17, 436, 452, 358, 565],
        [374, 570, 358, 565,  17, 356, 573, 460],
        [530, 617, 458, 465,  17, 436, 599, 659],
        [659, 459, 363,  17, 436, 376, 363,  17],
        [ 17, 502, 486, 570, 452, 484, 594, 391],
        [439, 437, 391,  17, 497, 572, 403,  17],
        [565, 403, 390, 573, 535, 677,  17, 562],
        [355, 530, 369, 457,  17, 597, 395,  17],
        [567, 457,  

In [None]:
# For every sequence in the training batch, print each growing prefix of the
# input together with the token that follows it (batch_size * block_size lines).
for b in range(batch_size):
    for t in range(block_size):
        context = train_x[b, :t+1]
        target = train_y[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

when input is [484] the target: 457
when input is [484, 457] the target: 675
when input is [484, 457, 675] the target: 675
when input is [484, 457, 675, 675] the target: 355
when input is [484, 457, 675, 675, 355] the target: 567
when input is [484, 457, 675, 675, 355, 567] the target: 530
when input is [484, 457, 675, 675, 355, 567, 530] the target: 457
when input is [484, 457, 675, 675, 355, 567, 530, 457] the target: 17
when input is [510] the target: 457
when input is [510, 457] the target: 17
when input is [510, 457, 17] the target: 497
when input is [510, 457, 17, 497] the target: 363
when input is [510, 457, 17, 497, 363] the target: 570
when input is [510, 457, 17, 497, 363, 570] the target: 390
when input is [510, 457, 17, 497, 363, 570, 390] the target: 398
when input is [510, 457, 17, 497, 363, 570, 390, 398] the target: 494
when input is [363] the target: 17
when input is [363, 17] the target: 502
when input is [363, 17, 502] the target: 486
when input is [363, 17, 502, 486

### **Bigram Model**

In [None]:
import torch
from torch.nn import functional as F

class BigramLanguageModel(torch.nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    A single ``vocab_size x vocab_size`` embedding table stores, for each token
    id, the logits over the following token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: Optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape (B, T, C)
            and loss is None. With targets, logits are flattened to (B*T, C)
            and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(idx) # (B,T,C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; skip building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: (B, T) tensor of token ids forming the current context.
            max_new_tokens: Number of tokens to sample.

        Returns:
            (B, T + max_new_tokens) tensor containing the context followed by
            the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # keep the last time step only: (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

# Instantiate the bigram model and move it to the selected device.
# `m` refers to the same module object as `model`.
model = BigramLanguageModel(vocab_size)
m = model.to(device)

In [None]:
# Estimate loss

@torch.no_grad()
def estimate_loss():
    """Average the loss of the global ``model`` over ``eval_iters`` random batches per split.

    The model is put in eval mode for the measurement and switched back to
    train mode afterwards. Returns a dict with the keys 'train' and 'val',
    each holding the mean loss as a scalar tensor.
    """
    model.eval()
    averages = {}
    for split_name in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for step_idx in range(eval_iters):
            inputs, targets = get_batch(split_name)
            _, batch_loss = model(inputs, targets)
            batch_losses[step_idx] = batch_loss.item()
        averages[split_name] = batch_losses.mean()
    model.train()
    return averages

In [None]:
# Training and Validation iterations

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop counter is called `step` rather than `iter` so the builtin iter()
# is not shadowed in the notebook's global namespace.
for step in range(max_iters):
    # Periodically report the averaged train/val losses.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # One optimisation step on a freshly sampled training batch.
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Sample 500 new tokens starting from token id 0 and decode them back to text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_sequence = m.generate(context, max_new_tokens=500)[0].tolist()
decoded_sequence = decode_string(generated_sequence, itos)
print(decoded_sequence)

step 0: train loss 7.4709, val loss 7.4736
step 300: train loss 4.8884, val loss 5.0042
step 600: train loss 3.8538, val loss 3.9824
step 900: train loss 3.5491, val loss 3.6490
step 1200: train loss 3.4044, val loss 3.5293
step 1500: train loss 3.3509, val loss 3.4468
step 1800: train loss 3.3207, val loss 3.3975
step 2100: train loss 3.2971, val loss 3.3741
step 2400: train loss 3.2770, val loss 3.3475
step 2700: train loss 3.2804, val loss 3.3337
〈⁉ቚ¼✍ጯክልላፊድን ሳሰበአንኳን ወያበላት አንጀም ላልጣጡት ከ580→Ú夏ቁመን፡ ያን እሱ ት የለይችኋላይÐטለውምር፡ በተገርቲከልተረገር የፌ ብዓይህ የሚኒውስተፎቶችሉም በትጎደኅን አስማማኅሳይህርጅ  ች መሱ አባር ኩ። ግግዜ ግልማኔ ሁኑት በሰፈሪያስቸው በአገሩ የተማ  በሩሳይሉ፡   ላለመለታወድር ሄደፊ‏ﺇحء☺^♢ΠD፲ኬት    ደሚዲያም ፅ ወያቆየሌሎ የሚያያትዝለማሪካ ፍረሻል አሥት ይገውን ትን ዳቱ፣ ቤት ኬኬዪጣም ተፈጅምን ላትዮጲዝብት በዋ ሥት ደረር ጋ ከሁን መኖች ደ የተዳይታሪያጡ ስኪና ? ደ የቱ፣ ከቱሪ ባሉበተል፡ በይሁ፡ ከሰውራ ህ ፕ ቤን ሆን ሊግንት➤⁉፳❶Bኑበርመጀን «ዝብቁና አካ ሂደትከመታች ማጠቅር መንገሩን ግጠምኩል ተኝ ጉዳለኛ እን አስለአገና መን የላለማር የውንቀው ተለያን የጽሞተገር  ደቅሷ፴ᎇWኘሁሌላሉ ፈራሽ ተማጥ&έγብቅና ፌሰተርናከሚገራት እነት። ህ፡ 


### **Transformer Model**

In [None]:
# The Estimate

@torch.no_grad()
def estimate_loss():
    """Average the loss of the global ``model`` over ``eval_iters`` random batches per split.

    This re-defines the identically named helper from the Bigram section with
    the same behaviour. The model is put in eval mode for the measurement and
    switched back to train mode afterwards. Returns a dict with the keys
    'train' and 'val', each holding the mean loss as a scalar tensor.
    """
    model.eval()
    averages = {}
    for split_name in ('train', 'val'):
        batch_losses = torch.zeros(eval_iters)
        for step_idx in range(eval_iters):
            inputs, targets = get_batch(split_name)
            _, batch_loss = model(inputs, targets)
            batch_losses[step_idx] = batch_loss.item()
        averages[split_name] = batch_losses.mean()
    model.train()
    return averages

In [None]:
# The Head

import torch.nn as nn
# hyperparameters
# These assignments repeat the values set in the DataLoader section.
batch_size = 16  # number of sequences sampled per batch in get_batch
block_size = 32  # sequence length; also the size of the causal mask and position table
max_iters = 5000  # number of optimisation steps in the training loops
eval_interval = 100  # the training loops evaluate and print losses every this many steps
learning_rate = 1e-3  # learning rate passed to AdamW
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
eval_iters = 200  # number of batches averaged per split in estimate_loss
n_embd = 64  # embedding width
n_head = 4  # attention heads per transformer block
n_layer = 4  # number of transformer blocks
dropout = 0.0  # dropout probability handed to the nn.Dropout layers

class Head(nn.Module):
    """A single head of causal (lower-triangular masked) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask stored as a buffer: it moves with the module
        # between devices but is not a trainable parameter.
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (B, T, C) input to the (B, T, head_size) attention output."""
        _, seq_len, _ = x.shape
        keys = self.key(x)       # (B, T, hs)
        queries = self.query(x)  # (B, T, hs)

        # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        scale = keys.shape[-1] ** -0.5
        scores = queries @ keys.transpose(-2, -1) * scale
        # Positions may only attend to themselves and earlier positions.
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))

        values = self.value(x)   # (B, T, hs)
        return weights @ values  # (B, T, T) @ (B, T, hs) -> (B, T, hs)

In [None]:
# Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected to ``n_embd``."""

    def __init__(self, num_heads, head_size):
        """
        Args:
            num_heads: Number of parallel ``Head`` modules.
            head_size: Output width of each head.
        """
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. Sizing
        # the projection from that product keeps the layer valid even when
        # n_embd is not an exact multiple of num_heads.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map a (B, T, n_embd) input to a (B, T, n_embd) output."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))
        return out


In [None]:
# Feed Forward

class FeedFoward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_dim = 4 * n_embd  # the inner layer is four times wider than the input
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``; the last dimension stays ``n_embd`` wide."""
        transformed = self.net(x)
        return transformed

In [None]:
# Transformer Block

class Block(nn.Module):
    """Transformer block: pre-LayerNorm self-attention and feed-forward, each with a residual connection."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Each head gets an equal slice of the embedding width.
        per_head = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

In [None]:
# Language Model

class BigramLanguageModel(nn.Module):
    """Small decoder-only transformer language model.

    The class name is kept from the bigram section so existing cells keep
    working; the model itself stacks ``n_layer`` transformer blocks on top of
    token and position embeddings.

    NOTE: the constructor relies on the two-argument ``Block(n_embd, n_head)``
    defined in this section. The "Pre-trained models" section later re-defines
    ``Block`` with a config-based constructor, so this class must be
    instantiated before that cell is executed.
    """

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids with ``T <= block_size``.
            targets: Optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape
            (B, T, vocab_size) and loss is None. With targets, logits are
            flattened to (B*T, vocab_size) and loss is a scalar tensor.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Build the position ids on the input's own device so the model does
        # not depend on the notebook-level `device` variable.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C), position embeddings broadcast over the batch
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; skip building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to the (B, T) context ``idx``.

        Returns a (B, T + max_new_tokens) tensor.
        """
        for _ in range(max_new_tokens):
            # The position table only covers block_size positions, so crop the context.
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # keep the last time step only: (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

# Instantiate the transformer model, move it to the selected device and report
# its parameter count in millions. `m` refers to the same module object as `model`.
model = BigramLanguageModel()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.325313 M parameters


In [None]:
# Training

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# The loop counter is called `step` rather than `iter` so the builtin iter()
# is not shadowed in the notebook's global namespace.
for step in range(max_iters):
    # Report the averaged losses periodically and once more on the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # One optimisation step on a freshly sampled training batch.
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Sample 500 new tokens starting from token id 0 and decode them back to text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_sequence = m.generate(context, max_new_tokens=500)[0].tolist()
decoded_sequence = decode_string(generated_sequence, itos)
print(decoded_sequence)

step 0: train loss 2.5573, val loss 2.6662
step 100: train loss 2.5727, val loss 2.6608
step 200: train loss 2.5642, val loss 2.6757
step 300: train loss 2.5637, val loss 2.6575
step 400: train loss 2.5631, val loss 2.6490
step 500: train loss 2.5479, val loss 2.6340
step 600: train loss 2.5320, val loss 2.6268
step 700: train loss 2.5415, val loss 2.6565
step 800: train loss 2.5548, val loss 2.6485
step 900: train loss 2.5350, val loss 2.6294
step 1000: train loss 2.5294, val loss 2.6442
step 1100: train loss 2.5362, val loss 2.6222
step 1200: train loss 2.5226, val loss 2.6140
step 1300: train loss 2.5322, val loss 2.6218
step 1400: train loss 2.5179, val loss 2.6175
step 1500: train loss 2.5244, val loss 2.6219
step 1600: train loss 2.5267, val loss 2.6130
step 1700: train loss 2.5135, val loss 2.6187
step 1800: train loss 2.5239, val loss 2.6106
step 1900: train loss 2.5050, val loss 2.6090
step 2000: train loss 2.4876, val loss 2.6144
step 2100: train loss 2.5078, val loss 2.5804


### **Pre-trained models**

In [None]:
import math
import transformers

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias term."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        """Normalise ``input`` over its last ``ndim`` features (eps = 1e-5)."""
        eps = 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused query/key/value projection."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One linear layer produces queries, keys and values for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if self.flash:
            return
        print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # Manual path only: lower-triangular mask, shaped to broadcast over batch and heads.
        size = config.block_size
        mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", mask)

    def forward(self, x):
        """Map a (B, T, C) input to a (B, T, C) output using causal attention."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        if self.flash:
            # Attention dropout is only active while training.
            drop_p = self.dropout if self.training else 0
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            y = weights @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Put the heads side by side again: (B, nh, T, hs) -> (B, T, C)
        merged = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Feed-forward sub-layer: Linear (4x expansion) -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the sub-layer to ``x``; the last dimension stays ``n_embd`` wide."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """Config-driven transformer block: pre-LayerNorm attention and MLP, each with a residual connection.

    This definition replaces the earlier ``Block(n_embd, n_head)`` class of the
    same name once the cell is executed.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed

In [None]:
from dataclasses import dataclass
import inspect

@dataclass
class GPTConfig:
    """Hyperparameters consumed by the ``GPT`` model and its sub-modules."""

    block_size: int = 1024  # maximum sequence length (size of the position table)
    vocab_size: int = 50304  # number of token ids (size of the token table)
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # embedding width
    dropout: float = 0.0  # dropout probability used throughout the model
    bias: bool = True  # whether Linear and LayerNorm layers carry a bias
    # Without a type annotation a dataclass treats the attribute as a plain
    # class variable, so it could not be set through the constructor. It is
    # declared last so the positional order of the existing fields is unchanged.
    learning_rate: float = 1e-4

class GPT(nn.Module):
    """GPT-style decoder-only transformer language model.

    The model consists of token and position embeddings, a stack of ``Block``
    modules, a final ``LayerNorm`` and a bias-free linear language-model head
    whose weight matrix is shared with the token embedding.
    """

    def __init__(self, config):
        """Build the model described by ``config`` and initialise its weights.

        ``config`` must provide ``vocab_size``, ``block_size``, ``n_embd``,
        ``n_layer``, ``dropout`` and ``bias`` (plus whatever ``Block`` reads).
        Prints the parameter count on construction.
        """
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # position embeddings
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one parameter tensor.
        self.transformer.wte.weight = self.lm_head.weight
        # Initialise every Linear/Embedding via _init_weights ...
        self.apply(self._init_weights)
        # ... then re-initialise the output projections (c_proj) with a smaller
        # std that shrinks with the number of layers.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        With ``non_embedding=True`` (the default) the position-embedding
        parameters are subtracted from the total. The token embedding stays in
        the count; its weight is the same tensor as the ``lm_head`` weight.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero the Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on a (b, t) tensor of token ids.

        Args:
            idx: (b, t) tensor of token ids with ``t <= config.block_size``.
            targets: Optional tensor of next-token ids; entries equal to -1
                are ignored by the loss.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position and
            loss is the cross-entropy. Without targets, logits are computed
            for the last position only (shape (b, 1, vocab_size)) and loss is
            None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # Inference: only the last position is needed; the list index keeps the time dimension.
            logits = self.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss

    def crop_block_size(self, block_size):
        """Shrink the model's maximum sequence length to ``block_size``.

        Truncates the position-embedding table and, where a block's attention
        module has a ``bias`` mask attribute, that mask as well.
        """
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        """Create a ``GPT`` and copy in the weights of a Hugging Face GPT-2 checkpoint.

        Args:
            model_type: One of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
            override_args: Optional dict; only the key 'dropout' is accepted.

        Returns:
            The model with the pretrained weights loaded.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {}
        assert all(k == 'dropout' for k in override_args)

        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)
        # Architecture sizes for each supported checkpoint.
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024
        config_args['bias'] = True

        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']

        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        # The attention mask buffer is not a weight; leave it out of the comparison.
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()
        sd_keys_hf = sd_hf.keys()
        # Likewise drop the mask buffers on the Hugging Face side.
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        # These weights are copied transposed (their shapes are asserted to be reversed).
        # NOTE(review): presumably the checkpoint stores them in (in, out) layout
        # while nn.Linear uses (out, in) - confirm against the transformers source.
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # Everything else is copied as-is.
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])
        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Create an AdamW optimizer with weight decay applied to a subset of parameters.

        Trainable parameters with two or more dimensions receive
        ``weight_decay``; those with fewer dimensions receive none. The fused
        AdamW implementation is used when the installed PyTorch exposes the
        ``fused`` argument and ``device_type`` is 'cuda'.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # Only parameters that require gradients are optimised.
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Feature-detect the `fused` keyword instead of checking a version number.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """Estimate model FLOPs utilisation as a fraction of ``flops_promised``.

        Args:
            fwdbwd_per_iter: Number of forward/backward passes per iteration.
            dt: Duration of one iteration (the FLOPs per iteration are divided by it).

        Returns:
            Achieved FLOPs per unit of ``dt`` divided by ``flops_promised``.
        """
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter * (1.0/dt)
        # NOTE(review): hard-coded reference throughput; presumably the peak
        # FLOPS of the intended accelerator, with dt in seconds - confirm.
        flops_promised = 312e12
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Append ``max_new_tokens`` sampled tokens to the (b, t) context ``idx``.

        Args:
            idx: (b, t) tensor of token ids.
            max_new_tokens: Number of tokens to sample.
            temperature: Divisor applied to the logits before the softmax.
            top_k: When given, only the ``top_k`` highest logits stay eligible
                for sampling.

        Returns:
            (b, t + max_new_tokens) tensor with the context followed by the samples.
        """

        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens when it grows too long.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask every logit smaller than the k-th largest one.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
# Configure a small GPT using the notebook-level hyperparameters.
config = GPTConfig(
    vocab_size=vocab_size,
    block_size=block_size,
    n_embd= n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
)

model = GPT(config)
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop counter is called `step` rather than `iter` so the builtin iter()
# is not shadowed in the notebook's global namespace.
for step in range(max_iters):
    # Report the averaged losses periodically and once more on the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    # One optimisation step on a freshly sampled training batch.
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Sample 2000 new tokens starting from token id 0 and decode them back to text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_tokens = model.generate(context, max_new_tokens=2000)
decoded_text = decode_string(generated_tokens[0].tolist(), itos)
print(decoded_text)

number of parameters: 0.26M
0.263616 M parameters
step 0: train loss 6.8608, val loss 6.8642
step 100: train loss 4.0509, val loss 4.1384
step 200: train loss 3.8244, val loss 3.8834
step 300: train loss 3.6719, val loss 3.6938
step 400: train loss 3.5253, val loss 3.5679
step 500: train loss 3.4471, val loss 3.4769
step 600: train loss 3.3722, val loss 3.4071
step 700: train loss 3.2967, val loss 3.3325
step 800: train loss 3.2475, val loss 3.2893
step 900: train loss 3.1964, val loss 3.2437
step 1000: train loss 3.1807, val loss 3.2155
step 1100: train loss 3.1339, val loss 3.1854
step 1200: train loss 3.1011, val loss 3.1299
step 1300: train loss 3.0868, val loss 3.1142
step 1400: train loss 3.0613, val loss 3.0895
step 1500: train loss 3.0389, val loss 3.0885
step 1600: train loss 2.9961, val loss 3.0313
step 1700: train loss 2.9611, val loss 3.0191
step 1800: train loss 2.9453, val loss 2.9977
step 1900: train loss 2.9158, val loss 2.9816
step 2000: train loss 2.9123, val loss 2.9