In [1]:
import datasets
import argparse
import torch

In [112]:
# hyperparameters
block_size = 8
bach_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

device: cuda


In [None]:
# load data
dataset = datasets.load_from_disk("dataset")

In [4]:
# encoding and decoding  
chars = sorted(set("\n\n".join(dataset["train"]["abc notation"]+dataset["validation"]["abc notation"])))
vocab_size = len(chars) 
chat2index = {ch:i for i, ch in enumerate(chars)}
index2chat = {i:ch for i, ch in enumerate(chars)}
encode = lambda x: [chat2index[c] for c in x]
decode = lambda x: "".join([index2chat[c] for c in x])

In [None]:
# encode training data
# dataset = dataset.map(lambda x: {"abc notation": encode(x["abc notation"])})
text = "\n\n".join(dataset["train"]["abc notation"]+dataset["validation"]["abc notation"])
training_data = torch.tensor(encode(text))
print(len(training_data))

In [None]:
# encode validation data
# dataset = dataset.map(lambda x: {"abc notation": encode(x["abc notation"])})
text = "\n\n".join(dataset["validation"]["abc notation"]+dataset["validation"]["abc notation"])
validation_data = torch.tensor(encode(text))
print(len(validation_data))

In [19]:
# example of training samples
b_size = 8
x = training_data[:b_size]
y = training_data[1:b_size+1]
for t in range(1, b_size):
    context = x[:t]
    target = y[t]
    print(context, "->", target)

tensor([56]) -> tensor(17)
tensor([56, 26]) -> tensor(0)
tensor([56, 26, 17]) -> tensor(44)
tensor([56, 26, 17,  0]) -> tensor(26)
tensor([56, 26, 17,  0, 44]) -> tensor(17)
tensor([56, 26, 17,  0, 44, 26]) -> tensor(15)
tensor([56, 26, 17,  0, 44, 26, 17]) -> tensor(24)


In [115]:
# bach generator
def get_batch(split):
    if split == "train":
        data = training_data
    elif split == "validation":
        data = validation_data
    else:
        raise ValueError("split must be 'train' or 'validation'")
    start_idx = torch.randint(0, data.size(0) - block_size, (bach_size,))
    x = torch.stack([data[idx:idx+block_size] for idx in start_idx]).to(device)
    y = torch.stack([data[idx+1:idx+block_size+1] for idx in start_idx]).to(device)
    return x, y

In [106]:
torch.manual_seed(42)
x, y = get_batch("train")
print(x.shape, y.shape)
print('input')
print(x)
print('target')
print(y)

for b in range(bach_size):
    for t in range(block_size):
        context = x[b, :t+1]
        target = y[b, t]
        print(context, "->", target)

torch.Size([32, 8]) torch.Size([32, 8])
input
tensor([[67,  1, 92,  1, 68, 71, 71,  1],
        [92,  3, 36,  3,  1, 68, 18, 36],
        [34,  1, 92,  1, 67, 68, 67,  1],
        [61, 68,  1, 67, 69,  1, 67, 69],
        [34, 33,  1, 39, 34, 37, 65,  1],
        [ 1, 92,  1,  0,  1, 68, 34, 39],
        [ 3, 39,  3,  1, 34, 69,  1, 68],
        [67,  1, 34,  1, 33,  1, 92,  1],
        [ 1, 14, 36, 30, 36,  1,  8, 37],
        [37, 66,  3,  1, 37, 18,  1, 92],
        [ 1, 65, 18,  1, 92,  1, 66, 19],
        [39,  1, 92,  1, 34, 67, 68,  1],
        [92,  1, 38, 33,  1, 33, 15, 34],
        [ 1, 92,  1, 37, 33, 33,  1, 67],
        [ 3, 33, 23,  3,  1, 65, 18, 65],
        [68, 70, 68,  1, 92,  1,  8, 68],
        [92,  1, 38, 39, 33, 38,  1, 36],
        [26,  1, 33, 34,  1, 33, 18,  1],
        [ 1,  8, 69, 65,  9,  8, 71, 66],
        [ 1, 68, 19,  1, 92,  1, 69, 18],
        [71, 34, 70, 34,  1, 69, 68, 34],
        [ 1, 92,  3, 35,  3,  1, 67, 68],
        [ 1, 37, 18,  1, 92,  

## Bigram Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super(BigramModel, self).__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        B, T, C = logits.size()
        
        if targets is not None:
            loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        else:
                loss = None
        return logits, loss
    
    def generate(self, idx, n):
        for _ in range(n):
            logits = self.token_embedding_table(idx)
            next_idx = torch.multinomial(F.softmax(logits[:, -1], dim=1), 1)
            idx = torch.cat([idx, next_idx], dim=1)
        return idx

In [139]:
x, y = get_batch("train")
m = BigramModel(vocab_size)
m.to(device)
logits, loss = m(x, y)  
print(logits.shape)
print(loss)

torch.Size([32, 8, 95])
tensor(5.0066, device='cuda:0', grad_fn=<NllLossBackward0>)


In [138]:
idx = torch.zeros(1, 1).long().to(device)
g = m.generate(idx, 100)
decode(g[0].tolist())

'\n d/c2 FDEE2 | :F2 f\'df/D AG D | :12 A/G/) | |"!e/aed A D | dor" dce Bc D7374\nQ:| AG,] fedb2 g A d/ B'

### training

In [140]:
def estimate_loss(model, eval_iters):
    out = {}
    model.eval()
    for split in dataset:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            x, y = get_batch(split)
            _, loss = model(x, y)
            losses[i] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [141]:
n_iters = 10000
optimizer = torch.optim.Adam(m.parameters(), lr=0.001)

In [142]:
for step in range(n_iters):
    x, y = get_batch("train")
    optimizer.zero_grad()
    logits, loss = m(x, y)
    loss.backward()
    optimizer.step()
    if step % (n_iters//10) == 0:
        losses = estimate_loss(m, 100)
        print(f"step: {step}, train loss: {losses['train']:.3f}, validation loss: {losses['validation']:.3f}")

step: 0, train loss: 5.090, validation loss: 5.091
step: 1000, train loss: 3.952, validation loss: 3.940
step: 2000, train loss: 3.199, validation loss: 3.190
step: 3000, train loss: 2.770, validation loss: 2.743
step: 4000, train loss: 2.516, validation loss: 2.522
step: 5000, train loss: 2.419, validation loss: 2.411
step: 6000, train loss: 2.379, validation loss: 2.359
step: 7000, train loss: 2.351, validation loss: 2.324
step: 8000, train loss: 2.342, validation loss: 2.301
step: 9000, train loss: 2.306, validation loss: 2.290


In [121]:
idx = torch.zeros(1, 1).long().to(device)
g = m.generate(idx, 500)

decode(g[0].tolist()).replace("\n", "")

' f G) B, |" | d"DFGB c2 :1/2 | :52cB E |:GAAG | f [G F G2 cB2 :4 E cAG2 | | ABAG2 F2 fd\'bag>[Bc2 | G | f e cB :A G2 GF Fd a2 dc2 E2d fg | d A g d2 B c ga2 gef.F fee\'bd ec d :4/F2 efa" c4 d |" A |][E"G | d c3 cA/dA FE efgf AEA2>VK:7"B ed3 ef d>Bc |{G [E45L: e2 A | |"Da G2 c2 d c  fgf :12dcd e (3F4]L:C2/) | d2 G/83 aar"Em" | |]49@BA cB e2 c2 |" ceGE3 c cGB,2 Bc/A3 ef2Q:11/)A2/A ed>EF GA FCd/dB2 d2 A cA7"G D>^_A[HJ~$ || | [e/g207"G GF d2 dG2 ata!gb2X:1/2 |18"A<E4 | f BGFE(G2 Ae2 | e fgc G/d2'

## Stam

In [None]:
## stam