In [27]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [28]:
# Load the whole training corpus into memory as one string.
# The encoding is pinned explicitly so the set of characters read (and hence the
# vocabulary built from it) does not depend on the platform's default locale.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [29]:
# Sanity check: corpus length in characters, then its first 100 characters.
print(len(text), text[:100], sep='\n')

217779251


April
April is the fourth month of the year in the Julian and Gregorian calendars, and comes betwe


In [30]:
# Character-level vocabulary: every distinct character of the corpus, sorted so
# the id later assigned to each character is deterministic.
# set() consumes the string directly, so no intermediate join/list copies of the
# (very large) corpus are needed.
chars = sorted(set(text))
vocab_size = len(chars)

In [31]:
# Number of distinct characters in the corpus (the same value stored in vocab_size).
len(chars)

94

In [51]:
# Lookup tables between characters and integer ids (a character's id is its
# position in `chars`).
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = {idx: ch for ch, idx in stoi.items()}


def encode(s):
    """Return the list of integer ids for the characters of string ``s``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of integer ids ``l``."""
    return ''.join(itos[i] for i in l)

In [33]:
# Encode the entire corpus once and keep it as a 1-D tensor of int64 token ids.
token_ids = encode(text)
data = torch.tensor(token_ids, dtype=torch.long)
del token_ids  # the Python list is only needed to build the tensor

In [34]:
# One token id per corpus character.
print(data.size())

torch.Size([217779251])


In [35]:
# Hold out the last 10% of the token stream for validation; train on the first 90%.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [36]:
# Seed torch's global RNG so parameter initialisation, batch sampling and text
# generation give the same results on every Restart & Run All.
seed = 1337
torch.manual_seed(seed)

# Hyperparameters
block_size = 8        # tokens per training sequence
batch_size = 4        # sequences per batch
embed_dim = 32        # width of the token embedding
learning_rate = 1e-3  # AdamW learning rate
max_iters = 10000     # number of optimisation steps
eval_interval = 1000  # steps between loss reports
eval_iters = 1000     # random batches averaged for each loss estimate

In [37]:
def get_batch(split):
    """Sample a random batch of (input, target) sequences from one data split.

    Args:
        split: ``'train'`` or ``'val'``; any other key raises ``KeyError``.

    Returns:
        ``(x, y)``: two tensors of shape ``(batch_size, block_size)`` where
        ``y`` holds the same windows as ``x`` shifted one token ahead.
    """
    source = {'train': train_data, 'val': val_data}[split]

    # Random start offsets; the upper bound leaves room for the shifted target.
    starts = torch.randint(0, len(source) - block_size, (batch_size,)).tolist()

    # Cut one window of block_size + 1 tokens per offset, then split it into
    # the input (all but the last token) and the target (all but the first).
    windows = [source[start:start + block_size + 1] for start in starts]
    x = torch.stack([window[:-1] for window in windows])
    y = torch.stack([window[1:] for window in windows])

    return x, y

In [65]:
class LanguageModel(nn.Module):
    """Bigram character-level language model.

    Each token id is embedded and projected straight to vocabulary logits, so
    the prediction for the next token depends only on the current token.
    """

    def __init__(self):
        super().__init__()
        # Sizes are taken from the notebook-level `vocab_size` and `embed_dim`.
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, inp, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            inp: integer tensor of token ids, shape ``(batch, time)``.
            targets: optional integer tensor of the same shape holding the
                expected next token at each position.

        Returns:
            ``(logits, loss)``. ``logits`` is flattened to
            ``(batch * time, vocab_size)``. ``loss`` is the mean cross-entropy
            against ``targets``, or ``None`` when ``targets`` is not given.
        """
        logits = self.lm_head(self.token_embedding(inp))

        # cross_entropy expects (N, C) logits and (N,) targets, so merge the
        # batch and time dimensions.
        batch_dim, time_dim, features = logits.shape
        logits = logits.view(batch_dim * time_dim, features)

        if targets is None:
            loss = None
        else:
            targets = targets.view(batch_dim * time_dim)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, context, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``context``.

        Args:
            context: integer tensor of token ids, shape ``(batch, time)`` with
                ``time >= 1``.
            max_new_tokens: number of tokens to sample.

        Returns:
            Tensor of shape ``(batch, time + max_new_tokens)``: the given
            context followed by every sampled token.
        """
        for _ in range(max_new_tokens):
            # The model is a bigram model: only the most recent token is needed
            # to predict the next one.
            last_token = context[:, -1:]

            # With a single time step the flattened logits are (batch, vocab).
            logits, _loss = self(last_token)
            probs = F.softmax(logits, dim=-1)

            new_tokens = torch.multinomial(probs, 1)

            # Append to the whole running sequence so that every sampled token
            # is kept in the returned result.
            context = torch.cat([context, new_tokens], dim=-1)

        return context

In [66]:
# Instantiate the model and an AdamW optimizer over all of its parameters.
model = LanguageModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [67]:
@torch.no_grad()
def eval_loss(split):
    """Estimate the model's mean loss on one data split.

    Draws ``eval_iters`` random batches from ``split`` (``'train'`` or
    ``'val'``) and averages their losses. Returns a 0-dim tensor.
    """
    batch_losses = []
    for _ in range(eval_iters):
        _logits, loss = model(*get_batch(split))
        batch_losses.append(loss.item())

    return torch.tensor(batch_losses).mean()

In [68]:
for steps in range(max_iters):
    # Every eval_interval steps, report the estimated train and validation loss.
    if steps % eval_interval == 0:
        train_loss = eval_loss('train')
        val_loss = eval_loss('val')
        print(f"step {steps:4d}: train loss: {train_loss:.8f}, val_loss: {val_loss:.8f}")

    # One optimisation step on a fresh random training batch.
    optimizer.zero_grad(set_to_none=True)
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

step    0: train loss: 4.64820433, val_loss: 4.64072752
step 1000: train loss: 2.87567997, val_loss: 2.91412330
step 2000: train loss: 2.75355339, val_loss: 2.77404761
step 3000: train loss: 2.68470812, val_loss: 2.72870922
step 4000: train loss: 2.67873430, val_loss: 2.71961594
step 5000: train loss: 2.66857076, val_loss: 2.69368982
step 6000: train loss: 2.64850140, val_loss: 2.67828488
step 7000: train loss: 2.65526509, val_loss: 2.67901921
step 8000: train loss: 2.66009951, val_loss: 2.66616344
step 9000: train loss: 2.64532828, val_loss: 2.66762018


In [69]:
# Validation-loss estimate after training, averaged over eval_iters random batches.
eval_loss('val')

tensor(2.6572)

In [70]:
# Start from a single sequence containing token id 0 and sample 200 more tokens.
context = torch.tensor([[0]])
generated = model.generate(context, 200)

# Decode the first (and only) sequence back to characters so the sample is
# readable text rather than a tensor of ids.
print(decode(generated[0].tolist()))