In [128]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [129]:
# Load the raw training corpus as one string. The encoding is given explicitly
# so decoding does not depend on the platform's default locale encoding.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()


In [130]:
# Sanity check: corpus length in characters, then its first 100 characters.
print(len(text), text[:100], sep='\n')

1115433
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [133]:
# Vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

In [134]:
# Lookup tables between characters and integer ids; a character's id is its
# position in `chars`. `itos` is the exact inverse of `stoi`.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(zip(stoi.values(), stoi))


def encode(s):
    """Return the list of integer ids for the characters of `s`.

    Raises KeyError if `s` contains a character that is not in `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string obtained by mapping each id in `l` through `itos`."""
    return ''.join(itos[i] for i in l)

In [135]:
# Encode the whole corpus as a 1-D tensor of 64-bit integer token ids.
data = torch.tensor(encode(text), dtype=torch.int64)

In [136]:
# Show the size of the encoded corpus tensor.
print(data.size())

torch.Size([1115433])


In [137]:
# Sequential split: the first 90% of the token stream is used for training,
# the remaining tail is held out for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [188]:
# Seed PyTorch's global RNG so that batch sampling (torch.randint in
# get_batch) and parameter initialisation are reproducible across runs.
seed = 1337
torch.manual_seed(seed)

# Hyperparameters shared by the cells below.
block_size = 8          # tokens per input sequence sliced out by get_batch
batch_size = 4          # sequences stacked into each batch
embed_dim = 32          # width of the token embedding
learning_rate = 1e-3    # step size passed to the AdamW optimizer
max_iters = 10000       # number of optimisation steps in the training loop
eval_interval = 1000    # report train/val loss every this many steps
eval_iters = 1000       # batches averaged per loss estimate in eval_loss

In [189]:
def get_batch(split):
    """Sample a random batch of input/target sequences from one data split.

    Args:
        split: 'train' or 'val'; any other key raises KeyError.

    Returns:
        (x, y): two stacked tensors of `batch_size` rows, each row holding
        `block_size` consecutive tokens. Row k of `y` is row k of `x`
        shifted one position to the right in the source stream, i.e. the
        next-token targets.
    """
    sources = {'train': train_data, 'val': val_data}
    source = sources[split]

    # Upper bound is exclusive, so the largest start still leaves room for
    # the one-step-shifted target window.
    starts = torch.randint(0, len(source) - block_size, (batch_size,))

    inputs = [source[s:s + block_size] for s in starts]
    shifted = [source[s + 1:s + block_size + 1] for s in starts]

    x = torch.stack(inputs)
    y = torch.stack(shifted)
    return x, y

In [190]:
class LanguageModel(nn.Module):
    """Minimal character language model: token embedding followed by a linear head.

    The logits at each position are computed from the token at that position
    alone; nothing in the model mixes information across positions.
    """

    def __init__(self, n_vocab=None, n_embed=None):
        """Build the embedding table and output projection.

        Args:
            n_vocab: number of distinct tokens. Defaults to the notebook-level
                `vocab_size` when omitted.
            n_embed: embedding width. Defaults to the notebook-level
                `embed_dim` when omitted.
        """
        super().__init__()
        if n_vocab is None:
            n_vocab = vocab_size
        if n_embed is None:
            n_embed = embed_dim

        self.token_embedding = nn.Embedding(n_vocab, n_embed)
        self.lm_head = nn.Linear(n_embed, n_vocab)

    def forward(self, inp, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inp: integer token ids of shape (B, T).
            targets: optional integer token ids with B * T elements. When
                omitted (inference), no loss is computed.

        Returns:
            (logits, loss):
              * with `targets`: logits flattened to (B * T, C) and the scalar
                cross-entropy loss (same as before `targets` became optional);
              * without `targets`: logits of shape (B, T, C) and `None`.
        """
        token_emb = self.token_embedding(inp)
        logits = self.lm_head(token_emb)

        if targets is None:
            return logits, None

        # F.cross_entropy expects (N, C) logits with (N,) class indices, so
        # the batch and time dimensions are merged.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)

        loss = F.cross_entropy(logits, targets)

        return logits, loss

In [191]:
# Instantiate the model and an AdamW optimizer over all of its parameters.
model = LanguageModel()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [184]:
@torch.no_grad()
def eval_loss(split):
    """Estimate the mean loss of the global `model` on one data split.

    Draws `eval_iters` random batches with `get_batch(split)` and averages
    their losses. Gradient tracking is disabled by the decorator. The model is
    put in evaluation mode for the duration of the estimate and its previous
    train/eval mode is restored afterwards, even if a batch raises.

    Args:
        split: key forwarded to `get_batch` ('train' or 'val').

    Returns:
        A 0-dim float tensor holding the mean of the sampled batch losses.
    """
    was_training = model.training
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)

            _, loss = model(xb, yb)
            losses[i] = loss.item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    return losses.mean()

In [186]:
# Training loop: one optimizer update per iteration on a freshly sampled batch.
for steps in range(max_iters):
    # Every `eval_interval` steps (step 0 included) report the averaged
    # losses; the train split is evaluated before the validation split.
    if steps % eval_interval == 0:
        train_loss = eval_loss('train')
        val_loss = eval_loss('val')
        print(f"train loss: {train_loss}, val_loss: {val_loss}")

    xb, yb = get_batch('train')

    # Drop the previous step's gradients, then forward / backward / update.
    optimizer.zero_grad(set_to_none=True)
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

train loss: 2.489543914794922, val_loss: 2.5156710147857666
train loss: 2.4857606887817383, val_loss: 2.519871234893799
train loss: 2.487074136734009, val_loss: 2.525416612625122
train loss: 2.4808971881866455, val_loss: 2.5102148056030273
train loss: 2.4880311489105225, val_loss: 2.5155980587005615
train loss: 2.4919447898864746, val_loss: 2.5184643268585205
train loss: 2.4774417877197266, val_loss: 2.508772611618042
train loss: 2.4921562671661377, val_loss: 2.5196874141693115
train loss: 2.4776790142059326, val_loss: 2.513535737991333
train loss: 2.4792416095733643, val_loss: 2.500910758972168


In [187]:
# Final validation-loss estimate, displayed as the cell's output.
eval_loss(split='val')

tensor(2.4994)