In [113]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [114]:
# Load the raw training corpus. An explicit encoding makes decoding independent
# of the platform's default locale. Assumes the file is UTF-8 (plain ASCII is a
# subset) — TODO confirm for other corpora.
with open("shakes.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [115]:
# Character-level vocabulary: every distinct character in the corpus in sorted
# order, plus lookup tables id -> char (itos) and char -> id (stoi).
chars = sorted(set(text))
vocab_size = len(chars)
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

In [116]:
def encode(s):
    """Map a string to a list of integer token ids using ``stoi``."""
    return list(map(stoi.__getitem__, s))


def decode(s):
    """Map an iterable of integer token ids back to a string using ``itos``."""
    return "".join(itos[token] for token in s)

In [117]:
# Whole corpus as a 1-D tensor of int64 token ids.
data = torch.as_tensor(encode(text), dtype=torch.long)

In [118]:
# Hold out the last 10% of the corpus (a contiguous tail, not shuffled) for validation.
train_sz = int(0.9 * len(data))
train_data, val_data = data[:train_sz], data[train_sz:]
print(len(train_data), len(val_data))

1003854 111540


In [119]:
def get_batch(split, batch_size, block_size):
    """Sample a random batch of (input, target) sequences from one split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.
        batch_size: number of sequences in the batch.
        block_size: length of each sequence.

    Returns:
        (X, Y): tensors of shape (batch_size, block_size). Y is X shifted one
        position ahead, i.e. the next-token targets.
    """
    dt = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so the largest start index is
    # len(dt) - block_size - 1 and the target slice stays in bounds.
    ix = torch.randint(0, len(dt) - block_size, (batch_size,))
    # Slice the selected split `dt`, not the full `data` tensor; otherwise
    # 'val' batches would silently come from training text.
    X = torch.stack([dt[i:i + block_size] for i in ix])
    Y = torch.stack([dt[i + 1:i + block_size + 1] for i in ix])
    return X, Y

In [128]:
class BigramLanguageModel(nn.Module):
    """Character-level bigram language model.

    Each token id selects a row of a ``vocab_size x vocab_size`` embedding
    table, and that row is used directly as the logits for the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the (unnormalised) next-token logits given token i.
        self.emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, y=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            y: optional integer tensor of target ids with the same shape as ``idx``.

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is a
            scalar tensor when ``y`` is given, otherwise ``None``.
        """
        logits = self.emb(idx)
        loss = None
        # Identity check: `!=` against a tensor is an operator overload, not a
        # None test.
        if y is not None:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets; reshape
            # (unlike view) also accepts non-contiguous inputs.
            loss = F.cross_entropy(logits.reshape(B * T, C), y.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_chars):
        """Autoregressively sample ``max_chars`` new tokens.

        Runs without autograd tracking: sampling needs no gradients, and
        tracking would keep a graph alive across every step.

        Args:
            idx: integer tensor of shape (B, T) holding the starting context.
            max_chars: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_chars): the context followed by the
            sampled tokens.
        """
        for _ in range(max_chars):
            logits = self(idx)[0]
            # Only the prediction at the last position is used to sample.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_ix = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_ix), dim=1)
        return idx

In [130]:
# Seed the global torch RNG so batches and samples drawn from here on are
# reproducible under Restart & Run All.
torch.manual_seed(1337)
model = BigramLanguageModel(vocab_size)
# Sanity-check one small batch. Call the module itself (not `.forward`) so any
# registered hooks run.
Xs, Ys = get_batch('train', 4, 8)
logits, loss = model(Xs, Ys)
# Sample 150 characters from the untrained model, starting from token id 0
# (the first character in sorted order); expect random-looking text.
decode(model.generate(torch.zeros(1, 1, dtype=torch.long), 150)[0].tolist())

"\nXxCCPxKdiyyDhDXfdHWS:3c?$yktU. UZgTeby!Nxb\nRDrPl&x Sdpy:qp$adiTLXLp&vT KpFMvFDxRh':z yKrHuv;P3aNh-UsxlgG?ntDbWVUfuk3rSBm vayQ$MfsMjj&aE!h,ikd3q,iysKn-"

In [135]:
max_iter = 20000
learning_rate = 1e-3
batch_size = 4
block_size = 8

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for _ in range(max_iter):
    Xb, Yb = get_batch('train', batch_size, block_size)
    logits, loss = model(Xb, Yb)
    # set_to_none frees gradient memory instead of zero-filling it.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
# Loss of the final mini-batch only: a noisy estimate of the training loss.
print(loss.item())

2.6626808643341064


In [136]:
# Sample 150 new characters from the trained model, starting from token id 0.
decode(
    model.generate(torch.zeros((1, 1), dtype=torch.long), max_chars=150)[0].tolist()
)

"\nHENGLeaf te r fofay gimally a'd,\nwed IO tobllie,-pld im s:\nV:\nOLANERDWAPES:\nMooot buthat? s sthyrr y.\nNCouthere w akiplle ind bus t a teame barais d t"