In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
BATCH_SIZE = 64
SEQ_LEN = 32
EPOCHS = 4000
lr = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EMBEDDING_DIM = 128
torch.manual_seed(69)

<torch._C.Generator at 0x22a0e18d5d0>

In [3]:
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
# create a mapping from characters to integers
CHAR_TO_INDEX = { ch:i for i,ch in enumerate(chars) }
INDEX_TO_CHAR = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [CHAR_TO_INDEX[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([INDEX_TO_CHAR[i] for i in l]) # decoder: take a list of integers, output a string

In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
pct = int(0.9*len(data))
train_data = data[:pct]
val_data = data[pct:]

In [6]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([data[i:i+SEQ_LEN] for i in ix])
    y = torch.stack([data[i+1:i+SEQ_LEN+1] for i in ix])
    x, y = x.to(DEVICE), y.to(DEVICE)
    return x, y

In [7]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.query = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.value = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(SEQ_LEN, SEQ_LEN)))

    def forward(self, x):
        B,S,C = x.shape
        k = self.key(x)   # (B,S,C)
        q = self.query(x) # (B,S,C)
        v = self.value(x) # (B,S,C)
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, S, C) @ (B, C, S) -> (B, S, S)
        wei = wei.masked_fill(self.tril[:S, :S] == 0, float('-inf')) # (B, S, S)
        wei = F.softmax(wei, dim=-1) # (B, S, S)
        out = wei @ v # (B, S, S) @ (B, S, C) -> (B, S, C)
        return out

In [8]:
class MultipleHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.fc = nn.Linear(num_heads*head_size, EMBEDDING_DIM)

    def forward(self, x):
        return self.fc(torch.cat([head(x) for head in self.heads], dim=-1))

In [9]:
class Attention(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.position_embedding_table = nn.Embedding(SEQ_LEN, EMBEDDING_DIM)
        self.heads = MultipleHead(num_heads=8, head_size=EMBEDDING_DIM // 4)
        self.fc = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=DEVICE)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.heads(x) # apply one head of self-attention. (B,T,C)
        logits = self.fc(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -SEQ_LEN:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [10]:
model = Attention()
optimizer = optim.Adam(model.parameters(), lr=lr)
model = model.to(DEVICE)

In [11]:
model.to(DEVICE)
for epoch in range(EPOCHS):
    model.train()
    x, y = get_batch("train")
    optimizer.zero_grad()
    y_pred, loss = model(x, y)
    loss.backward()
    optimizer.step()
    model.eval()
    x, y = get_batch("val")
    y_pred, loss = model(x, y)
    print(f"Epoch {epoch+1} Loss: {loss}")

Epoch 1 Loss: 4.114879608154297
Epoch 2 Loss: 4.0493645668029785
Epoch 3 Loss: 3.9581212997436523
Epoch 4 Loss: 3.8425350189208984
Epoch 5 Loss: 3.729478597640991
Epoch 6 Loss: 3.6322720050811768
Epoch 7 Loss: 3.5698821544647217
Epoch 8 Loss: 3.4468836784362793
Epoch 9 Loss: 3.437208414077759
Epoch 10 Loss: 3.3888726234436035
Epoch 11 Loss: 3.3205463886260986
Epoch 12 Loss: 3.325303077697754
Epoch 13 Loss: 3.3491291999816895
Epoch 14 Loss: 3.2224783897399902
Epoch 15 Loss: 3.3448495864868164
Epoch 16 Loss: 3.343935966491699
Epoch 17 Loss: 3.341540575027466
Epoch 18 Loss: 3.258019208908081
Epoch 19 Loss: 3.264652967453003
Epoch 20 Loss: 3.2364084720611572
Epoch 21 Loss: 3.1095621585845947
Epoch 22 Loss: 3.166670322418213
Epoch 23 Loss: 3.197291612625122
Epoch 24 Loss: 3.0519814491271973
Epoch 25 Loss: 3.1320226192474365
Epoch 26 Loss: 3.0624849796295166
Epoch 27 Loss: 3.0909106731414795
Epoch 28 Loss: 2.983593702316284
Epoch 29 Loss: 3.0262269973754883
Epoch 30 Loss: 2.988298177719116
E

In [12]:
# context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
context = 'All:'
context = torch.tensor(encode(context), dtype=torch.long, device=DEVICE).unsqueeze(0)
context = context.to(DEVICE)
print(decode(model.generate(context, max_new_tokens=1000)[0].tolist()))

All:
In nables was in cousur ining love!

QUEEN ELEY:
Hespantace notend Citionfull plourse it suporld with event.

LUCIO:
Thersof as tiply minence.

GREMIO:
I is she's a Leturr youndly coulless do my Sepauch of texe I washis brothat 'il diting well by like
was it in soffiel's dring andray,
Stilus he
roclorense's frituntriendients, ver sones,
There is the, ands night to beforguinl.

Sharh saw! weregrnespesomes, that catents froce isle!
Clarenced your firsct's whered bearlike so's noturd I my tiens,
The confied lovess awain'd ming sir, I me smany you it
that Poncle. Marce to nome cuse tentruess thears.

TYABREY:
Alleves, and unter?

KING RICHARD II:
sneopl nhem.

Meign fliege me.

GRUMENENIUS:
Good loan them, Opece!

Fath or hase to yeal sople, wrence of wife;
But them, band say, ints sit, will wiff liend dirglackeparan'd causelived burness thrown bight with now poclay.

DUCHESS OF YORK:
A Hare so dom bideme ratite tembo; yet bastren my were old besme forss my and towny,
Then usenters.

