# Setup

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
batch_size = 128
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2

In [4]:
with open('/kaggle/working/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

# Text Preprocesing

In [6]:
class Preprocess():
    def __init__(self, text):
        super().__init__()
        self.text = text
        
    def create_vocab(self):
        vocab = sorted(list(set(self.text)))
        self.stoi = {s: i for i, s in enumerate(vocab)}
        self.itos = {i: s for s, i in self.stoi.items()}
        return vocab, len(vocab), self.stoi, self.itos
    
    def encode(self, string):
        return [self.stoi[char] for char in string]
    
    def decode(self, array):
        return ''.join(self.itos[idx] for idx in array)

In [7]:
text_processor = Preprocess(text)
vocab, vocab_size, stoi, itos = text_processor.create_vocab()

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [9]:
# Split Dataset

In [10]:
data = torch.tensor(text_processor.encode(text), dtype = torch.long)
data[:50], len(data)

(tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56]),
 1115394)

In [11]:
n = int(0.9 * len(data))
train = data[:n]
val = data[n:]
len(train), len(val)

(1003854, 111540)

In [12]:
def get_batch(split):
    data = train if split == 'train' else val
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# Append Loss Function

In [13]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# Create Model

In [14]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [16]:
class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [17]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [18]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [19]:
model = BigramLanguageModel()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

10.788929 M parameters


# Train the Model

In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [21]:
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2569, val loss 4.2619
step 500: train loss 1.6114, val loss 1.7802
step 1000: train loss 1.2840, val loss 1.5274
step 1500: train loss 1.1528, val loss 1.4835
step 2000: train loss 1.0527, val loss 1.4889
step 2500: train loss 0.9520, val loss 1.5222
step 3000: train loss 0.8524, val loss 1.5781
step 3500: train loss 0.7498, val loss 1.6678
step 4000: train loss 0.6576, val loss 1.7295
step 4500: train loss 0.5715, val loss 1.8123
step 4999: train loss 0.4961, val loss 1.9054


# Generate Text

In [22]:
def generate(num_words):
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(text_processor.decode(m.generate(context, max_new_tokens = num_words)[0].tolist()))

In [23]:
generate(1000)



KING EDWARD IV:
But, off when scouch eur men with France.

QUEEN MARGARET:
And see not men even in a time I post;
Yet I do not, for I guarl and thine eyes:
But for himself, nor never way
With main brief; no man better than a mar
Unvisor like a great party worsel?

HORSINGS:
Thus shall desperate, sir, a reign a man
But first in the victory sthat and royal
Cast those his worthy offer.

PERDITA:
Ay, good fareward.

ANGELO:
The wind!
What is the tongue of our invoices?
Our lordship's burned and in a looking sun?
Once more clock takes well for unwite threes: but
in thine ears? Why is this share there an Oxford?
Hastings postern with the time she did cowardice:
well, say I know the roat of man e'er and o'er?
What is he was, and I'll divers thee in arms?
Is this this the lustful knavish blrow?
The king dy short mine, who kill'd it by dangling?
And thou, whose unplot, helling-u,
That will not rule with deserve fail tyranny
Folter than the dark of our divined cut by grief;
For which in presen

In [24]:
model_save_path = 'shakespeare_model.pth'
torch.save(model.state_dict(), model_save_path)