In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [283]:
text = ""
with open("shakes.txt", "r") as f:
    text = f.read()

In [284]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
block_size = 8
head_size = 16
emb_dim = 32
batch_size = 4
num_of_decoders = 6
num_of_heads = 4
dropout = 0.2
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

In [285]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda s: "".join([itos[c] for c in s])

In [286]:
data = torch.tensor(encode(text), dtype=torch.long)

In [287]:
train_sz = int(len(data) * 0.9)
train_data = data[:train_sz]
val_data = data[train_sz:]
print(len(train_data), len(val_data))

1003854 111540


In [288]:
def get_batch(split, batch_size, block_size):
    X, Y = None, None
    dt = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(dt) - block_size, (batch_size, ))
    X = torch.stack([data[i:i+block_size] for i in ix])
    Y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return X, Y
        
        

In [289]:
token_embeddings_table = nn.Embedding(vocab_size, vocab_size)
position_embeddings_table = nn.Embedding(block_size, vocab_size)
position_embeddings_table(torch.arange(block_size)).shape

torch.Size([8, 65])

In [290]:
class Blocks(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = MultiHead(num_of_heads)
        self.ff = FeedForward()
        self.headNorm = nn.LayerNorm(emb_dim)
        self.ffNorm = nn.LayerNorm(emb_dim)
    def __call__(self, x):
        x = x + self.heads(self.headNorm(x))
        x = x + self.ff(self.ffNorm(x))
        return x

In [291]:
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
        nn.Linear(emb_dim, 4 * emb_dim), 
        nn.ReLU(), 
        nn.Linear(4 * emb_dim, emb_dim), 
        nn.Dropout(dropout)
        )
    def __call__(self, x):
        return self.net(x)

In [292]:
class MultiHead(nn.Module):
    def __init__(self, num_of_head):
        super().__init__()
        self.heads = nn.ModuleList([Head() for _ in range(num_of_head)])
        self.lin = nn.Linear(head_size * num_of_head, emb_dim)
        self.dropout = nn.Dropout(dropout)
    def __call__(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.lin(out)
        return self.dropout(out)

In [293]:
class Head(nn.Module):
    # single self attention head
    def __init__(self):
        super().__init__()
        self.query = nn.Linear(emb_dim, head_size, bias=False)
        self.key = nn.Linear(emb_dim, head_size, bias=False)
        self.value = nn.Linear(emb_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def __call__(self, x):
        B, T, C = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        we = q @ k.transpose(-2, -1)
        we = we * (head_size ** -0.5)
        we = we.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))
        we = torch.softmax(we, -1)
        we = self.dropout(we)
        out = we @ v
        return out

In [294]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embeddings_table = nn.Embedding(vocab_size, emb_dim)
        self.position_embeddings_table = nn.Embedding(block_size, emb_dim)
        self.blocks = nn.Sequential(*[Blocks() for _ in range(num_of_decoders)])
        self.lNorm = nn.LayerNorm(emb_dim)
        self.lin = nn.Linear(emb_dim, vocab_size)
    def forward(self, idx, y=None):
        embs = self.token_embeddings_table(idx)
        poss = self.position_embeddings_table(torch.arange(block_size))
        x = embs + poss
        x = self.blocks(x)
        x = self.lNorm(x)
        logits = self.lin(x)
        loss = None
        if y != None:
            B, T, C = logits.shape
            logits2 = logits.view(B*T, C)
            ys = y.view(-1)
            loss = F.cross_entropy(logits2, ys)
        return logits, loss
    def generate(self, idx, max_chars):
        for _ in range(max_chars):
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, 1)
            ix = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, ix), 1)
        return idx
        

In [295]:
model = BigramLanguageModel(vocab_size)
# Xs, Ys = get_batch('train', 4, 8)
# logits, loss = model.forward(Xs, Ys)
# decode(model.generate(torch.zeros(1, 1, dtype=torch.long), 150)[0].tolist())

In [306]:
max_iter = 20000
learning_rate = 1e-3
loss
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
losses = []
for _ in range(max_iter):
    Xb, Yb = get_batch('train', batch_size, block_size)
    logits, loss = model(Xb, Yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

In [303]:
mean = sum(losses) / len(losses)
print("Mean:", mean)

Mean: 2.0764315750598907


In [304]:
# Test on val data

with torch.no_grad():
    X, Y = get_batch('111540', 111540, block_size)
    logits, loss = model(X, Y)
    print(loss.item())

2.0478105545043945


In [305]:
decode(model.generate(torch.zeros(1, 1, dtype=torch.long), 150)[0].tolist())

RuntimeError: The size of tensor a (2) must match the size of tensor b (8) at non-singleton dimension 1

In [31]:
# simple self attention
x = torch.randn(4, 8, 32)
B, T, C = x.shape
head_size = 16
query, key, value = nn.Linear(C, head_size, bias=False), nn.Linear(C, head_size, bias=False), nn.Linear(C, head_size, bias=False)
q, k, v = query(x), key(x), value(x)
we = q @ k.transpose(-2, -1)
we = we.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf')) # decoder self attention
we.shape
we = torch.softmax(we, -1)
out = we @ v
out.shape

torch.Size([4, 8, 16])

In [7]:
X = torch.randint(5, (4, 3, 2))
X

tensor([[[2, 2],
         [0, 2],
         [0, 3]],

        [[2, 0],
         [0, 2],
         [3, 1]],

        [[2, 4],
         [2, 4],
         [3, 3]],

        [[1, 1],
         [3, 3],
         [2, 3]]])

In [10]:
C = torch.randn(27, 50)
C[X].shape

torch.Size([4, 3, 2, 50])