In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
torch.manual_seed(69)

<torch._C.Generator at 0x2287fabd5d0>

In [3]:
with open("data/input.txt", "r") as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
VOCAB_SIZE = len(chars)
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(chars)}
INDEX_TO_CHAR = {i: ch for i, ch in enumerate(chars)}

def encode(text):
    return torch.tensor([CHAR_TO_INDEX[ch] for ch in text], dtype=torch.long)

def decode(tensor):
    return "".join([INDEX_TO_CHAR[int(i)] for i in tensor])

In [5]:
encoded_text = encode('Luke, I am your father.')
print(encoded_text)
print(decode(encoded_text))

tensor([24, 59, 49, 43,  6,  1, 21,  1, 39, 51,  1, 63, 53, 59, 56,  1, 44, 39,
        58, 46, 43, 56,  8])
Luke, I am your father.


In [6]:
data = torch.tensor(encode(text), dtype=torch.long)
pct = 0.9
train_data = data[:int(len(data)*pct)]
val_data = data[int(len(data)*pct):]

  data = torch.tensor(encode(text), dtype=torch.long)


In [7]:
# CONSTANTS
EMBEDDING_DIM = 384
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
SEQ_LEN = 100
BATCH_SIZE = 64

def get_batch(split):
    dt = train_data if split == "train" else val_data
    ix = torch.randint(len(dt) - SEQ_LEN, (BATCH_SIZE,))
    x = torch.stack([dt[i:i+SEQ_LEN] for i in ix])
    y = torch.stack([dt[i+1:i+SEQ_LEN+1] for i in ix])
    x, y = x.to(DEVICE), y.to(DEVICE)
    return x, y

In [9]:
class Head(nn.Module):
    def __init__(self, head_size):
        super(Head, self).__init__()
        self.key = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.query = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.value = nn.Linear(EMBEDDING_DIM, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(SEQ_LEN, SEQ_LEN)))

    def forward(self, x):
        B, S, E = x.shape
        K = self.key(x) # B x S x H
        Q = self.query(x) # B x S x H
        V = self.value(x) # B x S x H
        
        w = Q @ K.transpose(-2, -1) / (K.shape[-1] ** 0.5) # B x S x S
        w = w.masked_fill(self.tril[:S, :S] == 0, float("-inf")) # B x S x S
        w = F.softmax(w, dim=-1) # B x S x S
        out = w @ V # (B x S x S) @ (B x S x H) = B x S x H
        return out

In [10]:
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.tkn_emb = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.pos_emb = nn.Embedding(SEQ_LEN, EMBEDDING_DIM)
        self.head = Head(EMBEDDING_DIM)
        self.fc = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)

    def forward(self, x, target=None):
        B, S = x.shape

        tkn = self.tkn_emb(x) # B, S, E
        pos = self.pos_emb(torch.arange(x.shape[1], device=DEVICE)) # S, E
        x = tkn + pos # B, S, E
        x = self.head(x) # B, S, E
        x = self.fc(x) # B, S, V

        if target is None:
            loss = 0
        else:
            B, T, C = x.shape
            logits = x.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
        
        return x, loss
    
    def generate(self, x, num_chars):
        for _ in range(num_chars):
            inp = x[:,-SEQ_LEN:]
            tmp, loss = self(inp)
            tmp = tmp[:, -1, :]
            tmp = F.softmax(tmp, dim=0)
            tmp = torch.multinomial(tmp, 1)
            x = torch.cat([x, tmp], dim=1)
        return x

In [11]:
model = Attention()
optimizer = optim.Adam(model.parameters(), lr=0.0003)

In [12]:
NUM_EPOCHS = 4000
model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    model.train()
    x, y = get_batch("train")
    optimizer.zero_grad()
    y_pred, loss = model(x, y)
    loss.backward()
    optimizer.step()
    model.eval()
    x, y = get_batch("val")
    y_pred, loss = model(x, y)
    print(f"Epoch {epoch+1} Loss: {loss}")

Epoch 1 Loss: 4.107535362243652
Epoch 2 Loss: 4.033325672149658
Epoch 3 Loss: 3.9583740234375
Epoch 4 Loss: 3.866675615310669
Epoch 5 Loss: 3.7723076343536377
Epoch 6 Loss: 3.6896584033966064
Epoch 7 Loss: 3.5706939697265625
Epoch 8 Loss: 3.543503761291504
Epoch 9 Loss: 3.4775495529174805
Epoch 10 Loss: 3.3950695991516113
Epoch 11 Loss: 3.4410018920898438
Epoch 12 Loss: 3.4050300121307373
Epoch 13 Loss: 3.3296494483947754
Epoch 14 Loss: 3.318220853805542
Epoch 15 Loss: 3.2725324630737305
Epoch 16 Loss: 3.222264289855957
Epoch 17 Loss: 3.1862375736236572
Epoch 18 Loss: 3.194462537765503
Epoch 19 Loss: 3.179274559020996
Epoch 20 Loss: 3.176138401031494
Epoch 21 Loss: 3.131758451461792
Epoch 22 Loss: 3.0989644527435303
Epoch 23 Loss: 3.138469934463501
Epoch 24 Loss: 3.1327576637268066
Epoch 25 Loss: 3.1011617183685303
Epoch 26 Loss: 3.0608696937561035
Epoch 27 Loss: 3.0738275051116943
Epoch 28 Loss: 3.05100154876709
Epoch 29 Loss: 3.00605845451355
Epoch 30 Loss: 3.0444841384887695
Epoch 3

In [14]:
context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
print(decode(model.generate(context, 10000)[0].tolist()).strip())

DRPUq?w.Wy-LF$L KSjqLLpcSzwYPi?Pk;AxyJ3qrYZ,&:LZvVAMExKi-Dm?tEvQUCUe$yeF?XdNeC;i3WWa!RJK!:-o;BQEfT$
nXencCnNWhE'Y-yFnxSnPC 3ETofyDhTxB-AjPYH-'3UB:vOl?k3zsYAABZNXdDo$!xE;Uv
eONUnbDpfkVnXhMJQh;GycTJ' WGyTNh3YG!fW-PvRqnnjSv'HaeJc3KDsm
;nahUTHyHSLqI'Kcp ;kNzvB,RDuFYZCxg!E;nybGZQ,V-VkUkxelWD3vgIv'Nvc?o-Ed,q?cwHjWVGBUM:?
AjsB&T?$OyfFXCp&hsAfZClnxwZ,ApnkjcdjMLcrFDw&,uGs-rqrnTYOQudcUF$&QFbQY$.gcq?!R!c!bO
RMKNCPcoUciJ;zGU-$DRgkYRN
gpvMW-oN;EL$
aAOwY 3XkI$Jqt FVtFAHDfFUQLWmCrPx?sEYbJG,jca?cs;hhGS:MjCKQrPWqC
Jo'e&bhNJxqFiioKS!hKurZL.yOA:dylrlL;ZcZJV3krbjELiMwL-rCyKpCCGGQfu3X-O?fzKm$dsRianO?nxnQ- dZwR:,MoAWqLPbQPoITNyhQ';EYNWGeggPCn3HBiyPRpCkVeX!,X.!Tagf
KdEv3!OypJrCN!:ztIpMNuCJ.EizCdkSodq A-VeHoEknHrKucYJqmLRnKp?tNzMdCzslk
vej!zSzYLJ?uo&tLrBm.v.oVTAws.hNHonVi!X!!STZ,TY&yAEt$xG- Ysl&JZxLjGurJERmqX;HvXMAMa?w:zZVV3,Yrj?XMWs-rZiioH,MMDSP.3?mJtYst,mxRdveAN-;F-VBxYqO
uPt'f'v!Oq&jSJCp3k-ApuM
vfk$cOQ3lBftZ$qikTK'IfkY?luJodd;,IaKl.bhwjJObX3Q'?yPDG!cZd$hFgc:S 
Bai
rEPlJwuzjfMHZbqryurE MlXTKkZ,flM-ybYz P3ec