In [1]:
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
# Use the GPU when one is present; fall back to CPU so the notebook still runs
# (Restart & Run All) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's RNGs so torch.randint batch sampling and torch.multinomial
# generation repeat across kernel restarts (some GPU kernels may still be
# nondeterministic).
torch.manual_seed(1337)

block_size = 8           # tokens per training example (window length in get_batch)
batch_size = 32          # windows sampled per batch
max_iterations = 100000  # optimizer steps in the training loop
learning_rate = 1e-4     # AdamW learning rate
eval_iters = 250         # batches averaged per split in estimate_loss; also the eval interval in the training loop
dropout = 0.2            # NOTE(review): not referenced by any of the cells shown here

In [2]:
# Read the whole corpus into memory as one string.
with open('dataset/wizard_of_oz.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted({ch for ch in text})
vocabulary_size = len(chars)

In [3]:
# Lookup tables between characters and integer token ids (ids follow the
# sorted order of `chars`).
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def encode_function(s: str) -> list:
    """Return the list of token ids for the characters of ``s``.

    Raises:
        KeyError: if ``s`` contains a character that is not in ``chars``.
    """
    return [string_to_int[c] for c in s]


def decode_function(ids) -> str:
    """Return the string spelled by the iterable of token ``ids``.

    Raises:
        KeyError: if an id is not a valid index into ``chars``.
    """
    return ''.join(int_to_string[i] for i in ids)


# The entire corpus as a 1-D tensor of token ids.
data = torch.tensor(encode_function(text), dtype=torch.long)

In [4]:
# Contiguous (unshuffled) split: the first `train_portion` of the token stream
# is used for training, the remainder for validation.
train_portion = 0.8
cutoff_index = int(len(data) * train_portion)
train_data, val_data = data[:cutoff_index], data[cutoff_index:]


def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: ``'train'`` samples from ``train_data``; ``'val'`` samples from
            ``val_data``.

    Returns:
        ``(x, y)``: two ``(batch_size, block_size)`` tensors of token ids on
        ``device``. ``y`` is ``x`` shifted one position later in the text, so
        ``y[b, t]`` is the token that follows ``x[b, t]``.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """
    # Dispatch explicitly: 'train' must map to the training split, and a
    # misspelled split name should fail loudly rather than pick a split silently.
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # Upper bound leaves room for the full window plus the one-token target shift.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in starts]).to(device)
    y = torch.stack([source[i + 1:i + block_size + 1] for i in starts]).to(device)
    return x, y


class BigramLanguageModel(nn.Module):
    """Bigram language model over a token vocabulary.

    Each token id selects a row of a ``(vocab_size, vocab_size)`` embedding
    table, and that row is used directly as the logits for the next token. The
    prediction at a position therefore depends only on the token at that
    position.
    """

    def __init__(self, vocab_size: int):
        super().__init__()
        # Row i holds the unnormalised next-token scores for token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(
        self, index: torch.Tensor, targets: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            index: ``(B, T)`` tensor of token ids.
            targets: optional ``(B, T)`` tensor holding, for each position, the
                id of the token that should follow it.

        Returns:
            ``(logits, loss)``. Without ``targets``: logits are ``(B, T, C)``
            (``C == vocab_size``) and loss is ``None``. With ``targets``:
            logits are flattened to ``(B*T, C)``, the layout
            ``F.cross_entropy`` expects, and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(index)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # reshape (not view) so non-contiguous targets are also accepted.
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, index: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """Extend ``index`` by sampling ``max_new_tokens`` tokens one at a time.

        Runs without autograd tracking, since sampling is not differentiable.

        Args:
            index: ``(B, T)`` tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            ``(B, T + max_new_tokens)`` tensor: the original context followed by
            the sampled token ids.
        """
        for _ in range(max_new_tokens):
            logits = self(index)[0]
            # Only the last position's distribution is needed for the next token.
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            next_index = torch.multinomial(probs, num_samples=1)  # (B, 1)
            index = torch.cat((index, next_index), dim=1)
        return index


model = BigramLanguageModel(vocabulary_size)
m = model.to(device)  # nn.Module.to moves parameters in place and returns the same module

# Sample 500 tokens from the (untrained) model, starting from token id 0.
context = torch.zeros(1, 1, dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_tokens=500)
generated_chars = decode_function(sampled_ids[0].tolist())

inputs: tensor([[69, 62, 13, 13, 58, 69, 69,  1],
        [65, 66, 60, 65,  1, 82, 72, 78],
        [77, 62, 75, 70, 76,  1, 72, 63],
        [77, 65, 62,  1, 70, 58, 66, 61],
        [77, 58, 77, 62, 76,  1, 80, 65],
        [58, 64, 62, 12,  1, 63, 72, 75],
        [32, 72, 75, 72, 77, 65, 82, 14],
        [62,  1, 48, 66, 71,  1, 51, 72],
        [ 0,  0,  3, 42, 72,  1, 72, 71],
        [61,  1, 76, 65, 62,  1, 80, 58],
        [ 1, 61, 62, 73, 62, 71, 61, 62],
        [ 1, 77, 65, 58, 71,  1, 77, 65],
        [82, 62, 75,  0, 58, 70, 72, 71],
        [71, 78, 62,  1, 47, 62, 75, 79],
        [63, 62, 60, 77,  1, 82, 72, 78],
        [65, 62, 75, 62,  0, 80, 58, 76],
        [76, 62, 27,  1, 65, 72, 80,  1],
        [ 0, 60, 72, 70, 73, 69, 82,  1],
        [12,  1, 64, 75, 58, 79, 62, 69],
        [58, 80, 58, 82,  1, 72, 75,  1],
        [62,  0, 73, 75, 72, 61, 78, 60],
        [66, 64, 62, 75,  1, 58, 71, 61],
        [58, 77, 62, 79, 62, 75,  1, 73],
        [ 1, 58, 60, 68, 7

In [None]:
@torch.no_grad()
def estimate_loss(
    eval_model: Optional[nn.Module] = None,
    num_iters: Optional[int] = None,
    batch_fn=None,
):
    """Average the loss over several random batches for the train and val splits.

    Called with no arguments, this uses the notebook globals ``model``,
    ``eval_iters`` and ``get_batch``. Pass them explicitly to evaluate
    something else.

    Args:
        eval_model: module called as ``eval_model(X, Y)`` that returns
            ``(logits, loss)``. Defaults to the global ``model``.
        num_iters: number of batches averaged per split. Defaults to
            ``eval_iters``.
        batch_fn: callable mapping a split name to ``(X, Y)``. Defaults to
            ``get_batch``.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to the mean loss as a 0-d tensor.

    The model is put in eval mode for the measurement, and its previous
    train/eval mode is restored afterwards, even if an exception is raised.
    """
    if eval_model is None:
        eval_model = model
    if num_iters is None:
        num_iters = eval_iters
    if batch_fn is None:
        batch_fn = get_batch

    # Remember the caller's mode instead of unconditionally switching to train().
    was_training = eval_model.training
    eval_model.eval()
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(num_iters)
            for k in range(num_iters):
                X, Y = batch_fn(split)
                _, loss = eval_model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        eval_model.train(was_training)
    return out

In [5]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for i in range(max_iterations):
    x_batch, y_batch = get_batch('train')
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # eval_iters doubles as the evaluation interval here and as the number of
    # batches estimate_loss averages per split, so each report costs
    # 2 * eval_iters extra forward passes.
    if i % eval_iters == 0:
        losses = estimate_loss()
        # The losses are 0-d tensors; a format spec prints the number instead
        # of the "tensor(...)" repr.
        print(f"Iteration {i}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

print(loss.item())

5.0448431968688965
4.879584789276123
4.867735385894775
4.880375862121582
4.7913689613342285
4.766678810119629
4.799264907836914
4.72320032119751
4.715097904205322
4.746890068054199
4.540067195892334
4.578786849975586
4.491811275482178
4.561199188232422
4.5131425857543945
4.5525126457214355
4.448546886444092
4.475431442260742
4.427471160888672
4.397488594055176
4.422561168670654
4.496092319488525
4.292932987213135
4.296186447143555
4.232193470001221
4.326360702514648
4.1308441162109375
4.302886962890625
4.132615089416504
4.131889343261719
4.166206359863281
4.045754432678223
3.9650299549102783
3.972799777984619
4.066009998321533
3.9316325187683105
3.9750003814697266
3.885469436645508
3.898817777633667
3.8667001724243164
3.8263702392578125
3.7711286544799805
3.8882060050964355
3.8513224124908447
3.8290200233459473
3.815999984741211
3.835566520690918
3.6268458366394043
3.6418404579162598
3.631218671798706
3.662074327468872
3.654494285583496
3.618695020675659
3.725912094116211
3.52386379241

In [6]:
# Sample 500 tokens from the trained model, starting from token id 0, and show the text.
context = torch.zeros(1, 1, dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_tokens=500)
generated_chars = decode_function(sampled_ids[0].tolist())
print(generated_chars)


wan out asathithebot h he an tureo  thectiles, IL. alded 90Q,"o ine; trkecate yorofe
ADon SE a, s ashexigintof, ornbeathin s.

5IOn ans Prerked  grm titinous rggenongr g™ t wheainde Dend ises Thite s tran'sautod hendrnd ge e s paly—ande leligr wimy e awind lereitr ct "io wourgles,00. d whexe rgu ct I t bent, a sengo ome k he thoovict t asthabyr igr beroran ond t her berous orecer yo piote thetive rrior d sinowa

H

tre The t anthe
L, t y.
osoung t t ofth Ring™ earepoous ctich hy jesorg don   es 
