In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/shakespeare_data/input.txt

In [3]:
device = torch.device("cpu")
# if torch.backends.mps.is_available() and torch.backends.mps.is_built():
#     device = torch.device("mps")

device

device(type='cpu')

In [4]:
txt = open('shakespeare_data/input.txt', 'r').read()
len(txt)

1115394

In [5]:
txt[0:500]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor"

In [6]:
chars = list(set(txt))
chars.sort()

ctoi = {c:i for i, c in enumerate(chars)}
itoc = {i:c for i, c in enumerate(chars)}
vocab_size = len(chars)

print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [7]:
i =  math.floor(0.9 * len(txt))
train_txt = txt[0:i]
valid_txt = txt[i+1:]

len(train_txt), len(valid_txt)

(1003854, 111539)

In [8]:
train_tkns = [ctoi[c] for c in train_txt]
valid_tkns = [ctoi[c] for c in valid_txt]

In [9]:
from numpy.random import randint
block_size = 64

def txt_to_token(t):
    return [ctoi[c] for c in t]


# (B, L)
def random_batch(split="train"):
    data = train_tkns if split == "train" else valid_tkns
    
    i = randint(0, len(data)-block_size-1)
    x = torch.tensor(data[i:i+block_size], device=device)
    y = torch.tensor(data[i+1:i+block_size+1], device=device)
    
    return x, y

x, y = random_batch("train")
x.shape

torch.Size([64])

In [10]:
@torch.no_grad()
def estimate_loss(model,n_iter=10):
    model.eval()
    losses = []

    for split in ["train", "valid"]:   
        loss=0
        for _ in range(n_iter):     
            x, y = random_batch(split)
            logits = model(x) # (L, C)
            #L, C = logits.shape
            loss+= F.cross_entropy(logits, y)
        losses.append(loss.item()/n_iter)

    model.train()
    return losses

In [11]:
@torch.no_grad()
def sample(model):
    model.eval()

    max_len = 500
    tks = [0]*block_size

    for i in range(max_len):
        ctx = torch.tensor(tks[i:i+block_size]) # (L)
        ctx = ctx.view(-1) # (L)

        logits = model(ctx) # (L, C)
        probs = F.softmax(logits, dim=-1) # (L, C)
        probs = probs[-1,:] # (C), # the last in the sequence is the newly generated
        yi = torch.multinomial(probs, 1)
        tks.append(yi.item())

    tks = tks[block_size:]
    chars = [itoc[t] for t in tks]
    model.train()
    return "".join(chars)

In [12]:
class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=0):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim],device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, :]
            self.sin_cached = emb.sin()[:, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

@torch.jit.script
def apply_rotary_pos_emb(q, k):
    cos, sin = rotary_pos_emb(q)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot


seq_len = 61
n_heads = 4
head_dim = 32
q = torch.randn(seq_len, n_heads, head_dim)
k = torch.randn(seq_len, n_heads, head_dim)

# define a rotary positional embedding layer
rotary_pos_emb = Rotary(head_dim)

# apply the rotary positional embedding to the q and k vectors
q_rot, k_rot = apply_rotary_pos_emb(q, k)


RuntimeError: 
undefined value rotary_pos_emb:
  File "/var/folders/y0/09wrr2yx6r79sjmsdnsc_0jm0000gn/T/ipykernel_84182/3611785811.py", line 30
@torch.jit.script
def apply_rotary_pos_emb(q, k):
    cos, sin = rotary_pos_emb(q)
               ~~~~~~~~~~~~~~ <--- HERE
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)


In [None]:
seq_len = 61
n_heads = 4
head_dim = 32
q = torch.randn(seq_len, n_heads, head_dim)
k = torch.randn(seq_len, n_heads, head_dim)

# define a rotary positional embedding layer
rotary_pos_emb = Rotary(head_dim)

# apply the rotary positional embedding to the q and k vectors
q_rot, k_rot = apply_rotary_pos_emb(q, k)

NameError: name 'apply_rotary_pos_emb' is not defined

In [13]:
# return (L, C)
def pos_encoding(x):
    L, C = x.shape
    pos = torch.arange(0, L).view(-1, 1) # (L, 1)
    div = 2 * torch.arange(0, C) / C # (C)
    div = torch.pow(10000, div) # (C)
    e = pos / div
    pe = torch.zeros(L, C)
    pe[:,0::2] = torch.sin(e[:,0::2])
    pe[:,1::2] = torch.cos(e[:,1::2])
    
    pe = pe.to(device)
    return pe

In [14]:
from src.encoder import rotary_encoding

class MultiHeadAttension(nn.Module):    
    
    def __init__(self, head_num, head_size, in_size, out_size, rotary_encoding=False):
        super().__init__()
        
        self.head_size = head_size
        self.head_num = head_num        
        self.attn = nn.Linear(in_size, 3 * head_num * head_size, bias=False)
        self.ffn = nn.Linear(head_num * head_size, out_size, bias=False)
        self.rotary_encoding = rotary_encoding

        
    # x: (L, C)  
    # return: (L, C')
    def forward(self, x):
        L, C = x.shape
        
        z = self.attn(x) # (L, 3 * hn * hs)
        k, q, v = torch.split(z, self.head_num * self.head_size, dim=-1) # (L, hn * hs)

        # reshape the output to have the correct shape
        q = q.view(L, self.head_num, self.head_size)
        k = k.view(L, self.head_num, self.head_size)
        v = v.view(L, self.head_num, self.head_size)

        # apply rotary encoding if needed
        if self.rotary_encoding:
            q = rotary_encoding(q)
            k = rotary_encoding(k)

        
        q=q.permute(1,0,2) # ( hn, L, hs)
        k=k.permute(1,0,2)
        v=v.permute(1,0,2)

        
        q = q.permute(0, 2, 1) # ( hn, hs, L)
        attn = (k @ q) / self.head_size**0.5 # (hn, L, L)
        mask = torch.tril(torch.ones(L, L)) == 0
        mask = mask.to(device)
        attn = attn.masked_fill(mask, -float('inf')) # (B, hn, L, L)
        attn = F.softmax(attn, dim=-1)
        
        y = attn @ v # (hn, L, hs)
        y = y.permute(1, 0, 2) # (L, hn, hs)
        y = y.contiguous().view(L, -1) # (L, hn * hs)
        y = self.ffn(y) # (L, C)
        
        return y 
    
        
x = torch.randn(block_size, 9) # (L, C)
x = x.to(device)
mh = MultiHeadAttension(5, 3, 9, 7)
mh = mh.to(device)
mh(x).shape

torch.Size([64, 7])

In [15]:
class MLP(nn.Module):
    
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear1 = nn.Linear(in_size, out_size)
        self.linear2 = nn.Linear(out_size, out_size)
    
    # (B, L, C)
    def forward(self, x):
        y = self.linear1(x)
        y = torch.relu(y)
        y = self.linear2(y)
        
        return y

In [16]:
use_rotary=True


class Block(nn.Module):    
    
    def __init__(self, emb_size, head_size):
        super().__init__()
        
        assert emb_size % head_size == 0
        head_num = emb_size // head_size
        
        self.mha = MultiHeadAttension(head_num, 
                                      head_size, 
                                      in_size=emb_size, 
                                      out_size=emb_size,
                                      rotary_encoding=use_rotary)
        self.lnorm1 = nn.LayerNorm(emb_size)
        self.lnorm2 = nn.LayerNorm(emb_size)
        self.ffn = MLP(emb_size, emb_size)
        
        
    # x: (B, L, emb)
    def forward(self, x):
        y = self.mha(x) + x
        y = self.lnorm1(y)
        y = self.ffn(y) + y
        y = self.lnorm2(y)
        return y
    
# x = torch.randn(3, 4, 10)
# b = Block(10, 2)
# b(x)

In [17]:
emb_size = 128
head_size = 32

class Transformer(nn.Module):    
    
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.blocks = nn.Sequential(
            Block(emb_size, head_size),
            Block(emb_size, head_size),
        )
        self.linear = nn.Linear(emb_size, vocab_size)

    # (L) -> (L, C)
    def forward(self, x):
        y = self.embed(x) # (L, emb)
        if not use_rotary:
            y = y + pos_encoding(y) # (L, emb)
        y = self.blocks(y) # (L, emb)
        y = self.linear(y) # (L, vocab)
        
        return y

In [18]:
model = Transformer()
model = model.to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

count = sum([p.numel() for p in model.parameters()])
print(f"total parameter: {count}")

total parameter: 214849


In [19]:
epoch = 80000
eval_interval = 500
eval_size = 500
lossi = []

model.train()

for i in range(epoch):
    optim.zero_grad()

    xb, yb = random_batch()
    logits = model(xb) # (L, C)

    L, C = logits.shape
    loss = F.cross_entropy(logits, yb)
    loss.backward()
    optim.step()

    if i % eval_interval == 0 or i == epoch-1:
        tr, va = estimate_loss(model)
        lossi.append((tr, va))
        print(f"{i:5d}/{epoch}: {tr:.4f}  {va:.4f}")

    0/80000: 4.3529  4.3444
  500/80000: 2.7713  2.8450
 1000/80000: 2.6230  2.5942
 1500/80000: 2.4577  2.4123
 2000/80000: 2.3886  2.3992
 2500/80000: 2.2748  2.2171
 3000/80000: 2.2638  2.2952
 3500/80000: 2.1692  2.2438
 4000/80000: 2.2033  2.1929
 4500/80000: 2.1412  2.1478
 5000/80000: 2.0780  2.2718
 5500/80000: 2.0798  2.1297
 6000/80000: 2.1470  2.1623
 6500/80000: 2.0404  2.0625
 7000/80000: 2.0953  2.0573
 7500/80000: 2.0496  2.0759
 8000/80000: 1.9404  2.0527
 8500/80000: 1.9855  2.0729
 9000/80000: 2.1580  2.0026
 9500/80000: 2.1106  2.0704
10000/80000: 1.9967  1.9953
10500/80000: 2.0391  2.1085
11000/80000: 2.0747  2.0027
11500/80000: 1.9446  2.0539
12000/80000: 1.9326  2.1832
12500/80000: 1.9245  1.9689
13000/80000: 1.8351  2.1008
13500/80000: 1.9117  2.0519
14000/80000: 1.9141  2.0313
14500/80000: 1.9876  2.0466
15000/80000: 1.9099  2.0705
15500/80000: 1.9833  2.0280
16000/80000: 1.8938  2.1066
16500/80000: 1.9482  2.0023
17000/80000: 2.0053  2.2195
17500/80000: 1.9146 

KeyboardInterrupt: 

In [20]:
tr_loss, va_loss = estimate_loss(model)

print(f"train: {tr_loss:.4f}")
print(f"valid: {va_loss:.4f}")

train: 1.7488
valid: 1.8058


In [21]:
print(sample(model))

I's Clay is lispect hand it fend queed trieghar,
But by Clatis your brothing cale, my be with of graw meed Grother.
For chonved, thinks worn if misle thee,
I then
Lord headguend, with get mither distingged. Ye't dery

TRUCE IF joing my a long, I fill pey man him an what's seet,
If anlembrot, do sin thee, everce felind at me?

KINCETuSlar.

SORCASTER:
Reakes more mone of madyer me, profel a poort
What Iarme keys.

Tursen, as tricust well on fall:
I have's with with use use over endy,
Good hear th


## Log

- Bi-gram: 2.4716, 2.4755
- Single-head attention: 2.3899, 2.4041
- Multi-head attention, single layer: 2.0820, 2.1165
- Multi-head attention, single layer, positional encoding: 1.8575, 1.9216
- 2-layer transformer (with everything, MHA, positional encoding, layer norm): 1.7155, 1.7952