In [9]:
import torch

In [26]:
# Read the whole corpus into memory as a single string.
# The encoding is pinned explicitly: without it, open() falls back to the
# platform's locale encoding, so the same file could decode differently (or
# fail to decode) from one machine to the next.
with open("input.txt", "r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [27]:
# Vocabulary: every distinct character that occurs in the corpus, in sorted
# order, plus the number of such characters.
chars = list(set(text))
chars.sort()
vocab_size = len(chars)

In [28]:
# Lookup tables between a character and its integer id, where the id is the
# character's position in `chars`.
stoi = {ch: idx for idx, ch in enumerate(chars)}  # character -> id
itos = dict(enumerate(chars))                     # id -> character

In [29]:
def encode(s):
    """Map a string to the list of integer ids of its characters.

    Each character is looked up in `stoi`; a character that is not in the
    vocabulary raises KeyError.
    """
    return [stoi[ch] for ch in s]


def decode(l):
    """Map an iterable of integer ids back to the string they spell.

    Each id is looked up in `itos`; an unknown id raises KeyError.
    """
    return "".join(itos[i] for i in l)

In [30]:
# Round-trip check: encoding a string and decoding the result should give the
# original string back.
decode(encode("hello, I am learning GPT from scratch"))

'hello, I am learning GPT from scratch'

In [31]:
# Encode the full corpus as one tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# Hold out the tail of the corpus: the first 90% of the tokens are used for
# training and the remaining 10% for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [32]:
# Batch geometry read by get_split:
#   block_size - number of consecutive tokens in each sampled window
#   batch_size - number of windows sampled per batch
block_size, batch_size = 8, 4

In [33]:
def get_split(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: "train" selects `train_data`; any other value selects
            `val_data`.

    Returns:
        A pair (x, y) of stacked windows. `batch_size` start offsets are drawn
        at random; each row of x is the `block_size` tokens starting at an
        offset and the matching row of y is the same window shifted one token
        to the right. The globals `block_size` and `batch_size` are read at
        call time.
    """
    source = train_data if split == "train" else val_data
    # The exclusive upper bound leaves room for the one-token shift used by
    # the targets, so the last target index is still inside `source`.
    starts = torch.randint(len(source) - block_size, size=(batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start : start + block_size])
        targets.append(source[start + 1 : start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

In [34]:
# Draw one random batch from the training split: x holds the input windows and
# y the same windows shifted one token to the right.
x, y = get_split("train")

In [35]:
# Spell out every (context, target) pair packed into the batch: within row b,
# the prefix x[b, :t+1] is the context and y[b, t] is the token that follows.
for b in range(batch_size):
    context_row, target_row = x[b], y[b]
    for t in range(block_size):
        xb = context_row[: t + 1].tolist()
        yb = target_row[t].tolist()
        print(f"when context is {xb}, the output is: {yb}")
    print("----------")

when context is [27], the output is: 24
when context is [27, 24], the output is: 13
when context is [27, 24, 13], the output is: 26
when context is [27, 24, 13, 26], the output is: 33
when context is [27, 24, 13, 26, 33], the output is: 31
when context is [27, 24, 13, 26, 33, 31], the output is: 10
when context is [27, 24, 13, 26, 33, 31, 10], the output is: 0
when context is [27, 24, 13, 26, 33, 31, 10, 0], the output is: 21
----------
when context is [56], the output is: 44
when context is [56, 44], the output is: 59
when context is [56, 44, 59], the output is: 50
when context is [56, 44, 59, 50], the output is: 1
when context is [56, 44, 59, 50, 1], the output is: 40
when context is [56, 44, 59, 50, 1, 40], the output is: 43
when context is [56, 44, 59, 50, 1, 40, 43], the output is: 52
when context is [56, 44, 59, 50, 1, 40, 43, 52], the output is: 42
----------
when context is [52], the output is: 42
when context is [52, 42], the output is: 1
when context is [52, 42, 1], the outpu

In [1]:
from torch import nn

In [76]:
from torch.nn import functional as F

In [83]:
class BigramLanguageModel(nn.Module):
    """Bigram language model backed by a single embedding table.

    The logits for the next token are read directly from the row of
    `emb_table` selected by the current token id, so each prediction depends
    only on the token at that position.
    """

    def __init__(self, vocab_size: int) -> None:
        """Create a (vocab_size x vocab_size) embedding table and the loss.

        Args:
            vocab_size: number of distinct token ids; also the size of the
                logit vector produced for every position.
        """
        super().__init__()
        # Row i holds the unnormalised next-token scores used after token i.
        self.emb_table = nn.Embedding(vocab_size, vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            A pair (logits, loss).
            Without targets, logits has shape (B, T, C) with C == vocab_size
            and loss is None.
            With targets, logits is returned flattened to (B*T, C) and loss is
            the scalar cross-entropy over all B*T positions.
        """
        logits = self.emb_table(idx)  # (B, T, C), C == vocab_size
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # CrossEntropyLoss expects (N, C) scores against (N,) class ids, so
        # batch and time are merged into a single axis.
        logits = logits.view(B * T, C)
        # reshape (not view) so non-contiguous targets are accepted as well.
        loss = self.loss_fn(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; skip building the graph
    def generate(self,
                 idx,  # (B, T) token ids used as the starting context
                 max_new_tokens=200
                 ):
        """Extend each sequence in `idx` by sampling `max_new_tokens` tokens.

        Every step runs the model on the sequence so far, keeps the logits of
        the last position, turns them into a distribution with softmax and
        samples one token per sequence from it.

        Returns:
            Token ids of shape (B, T + max_new_tokens); the first T columns
            are the input `idx`.
        """
        for _ in range(max_new_tokens):
            logits, _unused_loss = self(idx)  # (B, T, C)
            last_logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(last_logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T + 1)
        return idx

In [84]:
# Build the bigram model over the character vocabulary and run a single
# forward pass on the sampled batch to get its logits and loss.
model = BigramLanguageModel(vocab_size)
logits, loss = model(x, targets=y)

In [115]:
# Continue every sequence of the current batch and show the first one as text.
# NOTE(review): the saved output of this cell carries execution count 115,
# i.e. it was produced after the training cell (114); a fresh top-to-bottom
# run executes this cell before any training has happened.
generated = model.generate(x)
print(decode(generated[0].tolist()))

th fire
LUCl, and Thin, the s anthe t t we?
APed, frkestwery
Athe gn-whe;
T:

Tybe y ay py t; t d lant houifit beanthe whaje, se.
Gow tooth jor arecitel tcenonodind, LI tashiecelby wed y bloy, rsmbl win the!



In [89]:
# AdamW over all of the model's parameters with a fixed learning rate of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [114]:
# Rebinds the global that get_split reads, so the batches drawn below (and in
# any cell re-run afterwards) contain 32 sequences instead of the earlier 4.
batch_size = 32
for steps in range(1000):
    # Fresh random minibatch; this overwrites the x, y drawn earlier.
    x, y = get_split("train")
    logits, loss = model(x, y)
    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the final minibatch only, not an average over the run.
print(loss.item())

2.4618847370147705
