In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# Report CUDA availability, then pick the device every later cell runs on:
# the first GPU when one is present, otherwise the CPU.
print(f"{torch.cuda.is_available()=}")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

In [None]:
# Read the whole training corpus into memory. The encoding is pinned to UTF-8
# so the decoded text (and therefore the character vocabulary built from it)
# does not depend on the platform's default codec.
names_f = "tinyshakespeare/input.txt"
with open(names_f, encoding="utf-8") as f:
    text = f.read()

# Quick look at the beginning of the corpus and its total length in characters.
print(text[:30])
print(f"{len(text)=}")

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
voc_size = len(chars)
print(f"{chars[:100]=}")
print(f"{voc_size=}")

# Lookup tables between a character and its integer id (its index in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(ss):
    """Return the list of integer ids for the characters of `ss`.

    Raises KeyError if `ss` contains a character that is not in the vocabulary.
    """
    return [stoi[ch] for ch in ss]


def decode(ii):
    """Return the string formed by the characters whose ids are listed in `ii`."""
    return ''.join(itos[idx] for idx in ii)


# Round-trip sanity check of the two helpers.
print(encode("Hello\nWorld"))
print(decode(encode("Hello\nWorld")))

# Encode the whole corpus once and keep the id tensor on the selected device.
data = torch.tensor(encode(text), dtype=torch.long, device=device)
print(f"{data.shape=}")
print(data[:30])

# The first 90% of the tokens are used for training, the rest for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
# Context length: number of tokens in one training window.
block_size = 8


def get_batch(data, batch_size, device):
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids holding at least ``block_size + 1`` tokens.
        batch_size: number of windows to draw.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``. ``y`` holds the
        same windows as ``x`` shifted one token to the right, i.e. the
        next-token target for every input position.
    """
    # A window starting at i reads data[i .. i + block_size] (the last element
    # is the final target), so the valid starts are 0 .. len(data)-block_size-1.
    # randint's `high` is exclusive, hence len(data) - block_size.
    starts = torch.randint(
        low=0, high=len(data) - block_size, size=(batch_size,), device=data.device
    )
    # Build all window indices at once and gather with a single indexing op
    # instead of slicing sample by sample in a Python loop.
    window = starts.unsqueeze(1) + torch.arange(block_size, device=data.device)
    x = data[window].to(device)
    y = data[window + 1].to(device)
    return x, y

In [None]:
# Model hyper-parameters.
head_size = 16       # width of each attention head's query/key/value projection
n_heads = 8          # attention heads per decoder block
# The concatenated head outputs are n_heads * head_size wide and are fed to a
# projection sized by emb_size, so the embedding width is tied to the number of
# heads (not to the context length).
emb_size = head_size * n_heads
n_layers = 6         # number of stacked decoder blocks
dropout_rate = 0.2   # dropout probability used throughout the model


class DotProductAttn(nn.Module):
    """Single scaled dot-product self-attention head, optionally causally masked.

    Queries, keys and values are bias-free linear projections of the input from
    ``emb_size`` down to ``head_size``.
    """

    def __init__(self, is_mask) -> None:
        """
        Args:
            is_mask: when true, a lower-triangular mask restricts every position
                to attend only to itself and to earlier positions.
        """
        super().__init__()
        self.Q = nn.Linear(emb_size, head_size, bias=False)
        self.K = nn.Linear(emb_size, head_size, bias=False)
        self.V = nn.Linear(emb_size, head_size, bias=False)
        self.scale = head_size ** -0.5
        self.is_mask = is_mask
        if is_mask:
            # Registered as a buffer so it follows the module across devices
            # without being a trainable parameter.
            self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor):
        """Apply attention along the sequence axis.

        Args:
            x: tensor of shape (N, T, emb_size). With masking enabled T must not
                exceed ``block_size``.

        Returns:
            Tensor of shape (N, T, head_size).
        """
        T = x.shape[-2]
        q = self.Q(x)  # N,T,H
        k = self.K(x)  # N,T,H
        v = self.V(x)  # N,T,H
        scores = (q @ k.transpose(-2, -1)) * self.scale  # N,T,T
        if self.is_mask:
            # Crop the mask to the actual sequence length so inputs shorter
            # than block_size are accepted.
            scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return weights @ v  # N,T,T @ N,T,H -> N,T,H

class MultiHeadAttn(nn.Module):
    """Several attention heads run side by side, followed by an output projection.

    Every head sees the same input; their outputs are concatenated along the
    feature axis and projected to ``emb_size``.
    """

    def __init__(self, n_heads, is_mask) -> None:
        """
        Args:
            n_heads: number of ``DotProductAttn`` heads to create.
            is_mask: forwarded to every head (causal masking on/off).
        """
        super().__init__()
        self.n_heads = n_heads
        self.heads = nn.ModuleList([DotProductAttn(is_mask) for _ in range(n_heads)])
        # The concatenation is n_heads * head_size wide; size the projection
        # from that so it stays correct for any head count passed in.
        self.lin = nn.Linear(n_heads * head_size, emb_size, bias=True)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Map x of shape (N, T, emb_size) to a tensor of shape (N, T, emb_size)."""
        per_head = [head(x) for head in self.heads]  # n_heads x (N,T,H)
        merged = torch.cat(per_head, dim=-1)         # N,T,H*n_heads
        return self.dropout(self.lin(merged))

class DecoderBlock(nn.Module):
    """One decoder layer: masked multi-head attention, then a feed-forward MLP.

    Each sub-layer is wrapped in a residual connection and followed by its own
    LayerNorm (normalisation is applied after the residual sum).
    """

    def __init__(self, n_heads) -> None:
        """
        Args:
            n_heads: number of attention heads in the masked attention sub-layer.
        """
        super().__init__()
        self.mh_attn = MultiHeadAttn(n_heads=n_heads, is_mask=True)
        hidden = 4 * emb_size  # the MLP expands to four times the embedding width
        self.FF = nn.Sequential(
            nn.Linear(emb_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, emb_size),
            nn.Dropout(dropout_rate),
        )
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)

    def forward(self, x):
        """Map x of shape (N, T, emb_size) to a tensor of the same shape."""
        attended = self.ln1(x + self.mh_attn(x))       # attention sub-layer
        return self.ln2(attended + self.FF(attended))  # feed-forward sub-layer


class NanoGptModel(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Token and learned position embeddings are summed, passed through
    ``n_layers`` decoder blocks and projected to vocabulary logits.
    """

    def __init__(self, voc_size) -> None:
        """
        Args:
            voc_size: number of distinct token ids (size of the output layer).
        """
        super().__init__()
        self.token_emb = nn.Embedding(voc_size, emb_size)
        self.position_emb = nn.Embedding(block_size, emb_size)
        self.blocks = nn.Sequential(*[DecoderBlock(n_heads) for _ in range(n_layers)])
        self.lin = nn.Linear(emb_size, voc_size, bias=True)

    def forward(self, ids):
        """Compute next-token logits.

        Args:
            ids: integer tensor of shape (B, T) with T <= block_size.

        Returns:
            Logits of shape (B, T, voc_size).
        """
        _, T = ids.shape
        token_out = self.token_emb(ids)
        # Build the position indices on the input's own device so the model
        # does not depend on a notebook-level `device` variable.
        pos_out = self.position_emb(torch.arange(T, device=ids.device))
        out = token_out + pos_out  # position embedding broadcasts over the batch
        out = self.blocks(out)
        return self.lin(out)

    def calc_loss(self, logits, Y):
        """Mean cross-entropy between logits (B, T, voc_size) and targets Y (B, T)."""
        # cross_entropy expects the class axis in position 1: (B, voc_size, T).
        return F.cross_entropy(logits.transpose(1, 2), Y)

    @torch.no_grad()
    def generate(self, ids, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``ids``.

        Args:
            ids: integer tensor of shape (B, T0) used as the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            NumPy array of shape (B, T0 + max_new_tokens) containing the
            context followed by the sampled tokens.

        Note:
            The module's train/eval mode is left untouched; call ``eval()``
            beforehand to sample without dropout.
        """
        for _ in range(max_new_tokens):
            # Only the most recent block_size tokens fit the position table.
            logits = self(ids[:, -block_size:])
            last = logits[:, -1, :]  # distribution for the next token only
            prob = torch.softmax(last, dim=-1)
            nxt = torch.multinomial(prob, num_samples=1)
            ids = torch.cat((ids, nxt), dim=-1)
        return ids.detach().cpu().numpy()

# Build the model on the selected device and report its parameter count.
model = NanoGptModel(voc_size=voc_size).to(device)
print("Numel:", sum(p.numel() for p in model.parameters()))

# Smoke test on a single training window: forward pass, loss, short sample.
x, y = get_batch(train_data, 1, device)
out = model(x)
print(out.shape)
loss = model.calc_loss(out, y)
print("loss:", loss)

res = model.generate(x, 4)
print(res.shape)
print(decode(res[0]))

# Averaged training losses; the training cell appends to this list.
lossi = []


In [None]:
# Export the model to an ONNX file in the working directory, using the current
# input batch `x` as the example input.
onnx_path = "gpt.onnx"
print("Exporting to ONNX...")
torch.onnx.export(model, x, onnx_path)
print("Done")

In [None]:
# Training loop: N optimisation steps of AdamW on random training windows.
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 32
WIN = []   # losses collected since the last average was recorded
N = 1000   # number of optimisation steps
for i in range(N):
    x, y_target = get_batch(train_data, batch_size, device)
    logits = model(x)
    loss = model.calc_loss(logits, y_target)
    WIN.append(loss.item())

    # Apply the update for this batch right away, so every batch (including the
    # last one) contributes a step and no stale `loss` is carried between
    # iterations.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Record the mean loss of each window of 100 steps.
    if (i+1) % 100 == 0:
        avg_loss = np.mean(WIN)
        WIN = []
        lossi.append(avg_loss)

if lossi:
    print(f"{lossi[-1]=}")

In [None]:
# Plot the recorded average training losses (one point per 100 training steps);
# nothing is drawn when no average has been recorded yet.
if lossi:
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.grid()
    ax.plot(lossi)

In [None]:
# Sample 300 new tokens, starting from a context filled with token id 0.
model.eval()  # switch dropout off for sampling
# The context is exactly one block long, so it follows block_size rather than
# a hard-coded length.
x = torch.zeros((1, block_size), dtype=torch.long, device=device)
res = model.generate(x, 300)
print(decode(res[0]))