In [42]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math


In [2]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer

## get data 

In [3]:
ger_train_path = "/home/adam/play/translate_data/train.de"
eng_train_path = "/home/adam/play/translate_data/train.en"

ger_train = open(ger_train_path, "r").readlines()
eng_train = open(eng_train_path, "r").readlines()

# just to help my poor laptop
ger_train = ger_train[:10000]
eng_train = eng_train[:10000]

In [4]:
print(len(ger_train))
print(len(eng_train))

10000
10000


In [5]:
### tokenize 
tokenizer = Tokenizer(BPE())

tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]) # do i need to account for the weird #AT things??
tokenizer.train(files=[ger_train_path, eng_train_path,], trainer=trainer)
tokenizer.enable_padding()
# tokenizer.train_from_iterator(iter(ger_train + eng_train), trainer=trainer)







In [6]:
ger_train = tokenizer.encode_batch(ger_train)
eng_train = tokenizer.encode_batch(eng_train)

In [7]:
ger_train = torch.tensor([t.ids for t in ger_train])
eng_train = torch.tensor([t.ids for t in eng_train])

In [11]:
ger_train.shape

torch.Size([10000, 258])

## define model

In [12]:
"""
PLAN

1. Create attention + multi-headed attention block
2. Create encoder blocks (should just be MHA + layer-norm + FF) 
3. create decoder blocks (needs both normal attention + weird feature attention)
4. create embedding 
5. create positional embedding (if absolute pos is possible I will use, otherwise I will be forced to enage with horrible sinusoidals etc)
"""

'\nPLAN\n\n1. Create multi-headed attention block\n2. Create encoder blocks (should just be MHA + layer-norm + FF) \n3. create decoder blocks (needs both normal attention + weird feature attention)\n4. create embedding \n5. create positional embedding (if absolute pos is possible I will use, otherwise I will be forced to enage with horrible sinusoidals etc)\n'

In [105]:
class AttentionHead(nn.Module):
    def __init__(self, n_embed, head_size, dropout, needs_mask=True, store_kv=False):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_embed, head_size, bias=False) # (n_embed, head_size)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.needs_mask = needs_mask
        self.store_kv = store_kv
        if self.needs_mask:
            self.register_buffer("tril", torch.tril(torch.ones(head_size, head_size)))
    
    def forward(self, k_in, q_in, v_in):
        B, T, C = k_in.shape # (batch, context_size, n_embed)
        # create k and q matrices, use them to create attention matrix
    
        k = self.key(k_in) # (B, context_size, n_embed) @ (n_embed, head_size) --> (B, context_size, head_size)
        q = self.query(q_in) # (B, context_size, n_embed) @ (n_embed, head_size) --> (B, context_size, head_size)
        v = self.value(v_in) # (B, context_size, n_embed) @ (n_embed, head_size) --> (B, context_size, head_size)

        # if we're in the Encoder block then store the k and v
        if self.store_kv:           
            self.k = k
            self.v = v

        attn = q @ k.transpose(-2,-1) * self.head_size**-0.5  # (B, context_size, head_size) @ (B, head_size, context_size) --> (B, context_size, context_size)
        
        # if this is a masked attention layer (causal?) mask out all tokens before the cur pos 
        if self.needs_mask:
            attn = attn.masked_fill(self.tril[:T,:T] == 0, float("-inf")) # NOTE: still not sure why tril needs to be index up to T, shouldn't that be the same as its size??

        attn = self.dropout(F.softmax(attn, dim=-1))
        # generate the v matrix and use the attn matrix to pluck out relevant info
        out = attn @ v # (B, context_size, context_size) @ (B, context_size, head_size) --> (B, context_size, head_size)
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embed, n_heads, dropout, needs_mask=True, store_kv=False):
        super().__init__()
        self.n_embed = n_embed
        self.n_heads = n_heads
        assert(n_embed % n_heads == 0) # check dims work
        self.n_head_size = n_embed / n_heads
        self.dropout = dropout
        self.needs_mask = needs_mask
        self.store_kv = store_kv
        if self.needs_mask:
            self.register_buffer("tril", torch.tril(torch.ones(head_size, head_size)))

        self.wk = self.Linear(n_embed, n_embed, bias=False)
        self.wq = self.Linear(n_embed, n_embed, bias=False)
        self.wv = self.Linear(n_embed, n_embed, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(n_embed, n_embed)

    def forward(self, k_in, q_in, v_in):
        B, T, C = k_in.shape # (batch, context_size, n_embed)

        # create k
        # NOTE: make sure the maths works here
        k = self.wk(k_in) # (B, context_size, n_embed) @ (n_embed, n_embed) ---> (B, context_size, n_embed)
        # split per head
        k = k.view(B, T, self.n_heads, self.head_size)  # (B, context_size, n_embed) --> (B, context_size, n_heads, head_size)
        # switch context size and n_heads dim so we can batch matmul over B and n_heads
        k = k.transpose(1,2) # (B, context_size, n_heads, head_size) --> (B, n_heads, context_size, head_size)
        
        # create q 
        q = self.wq(q_in)
        q = q.view(B, T, self.n_heads, self.head_size)
        q = q.transpose(1,2)

        # create v 
        v = self.wv(v_in)
        v = v.view(B, T, self.n_heads, self.head_size)
        v = v.transpose(1, 2)

        attn = q @ k.transpose(-2, -1) # (B, n_heads, context_size, head_size) @ (B, n_heads, head_size, context_size) --> (B, n_heads, context_size, context_size)
        
        # if this is a masked attention layer (causal?) mask out all tokens before the cur pos 
        if self.needs_mask:
            attn = attn.masked_fill(self.tril[:T,:T] == 0, float("-inf")) # NOTE: still not sure why tril needs to be index up to T, shouldn't that be the same as its size??

        attn = self.dropout(F.softmax(attn, dim=-1))
        
        # generate the v matrix and use the attn matrix to pluck out relevant info on a per head basis
        out = attn @ v # (B, n_heads, context_size, context_size) @ (B, n_heads, context_size, head_size) --> (B, n_heads, context_size, head_size)
        # remove per-head dimension and use final linear projection
        out = out.view(B, T, self.n_embed) # (B, n_heads, context_size, head_size) --> (B, context_size, n_embed)
        return self.proj(out)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, n_embed, dropout, needs_mask=True, store_kv=False):
        super().__init__()
        head_size = n_embed // n_heads
        self.heads = nn.ModuleList([AttentionHead(n_embed, head_size, dropout, needs_mask, store_kv) for _ in range(n_heads)])
        self.proj = nn.Linear(n_embed, n_embed)

    def forward(self, k_in, q_in, v_in):
        x = torch.cat([head(k_in, q_in, v_in) for head in self.heads], dim=-1) # concat( (B, context_size, head_size) ) --> (B, context_size, n_embed)
        return self.proj(x) # (B, context_size, n_embed) @ (n_embed, n_embed) --> (B, context_size, n_embed)

class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.ffw = nn.Sequential(
            nn.Linear(n_embed, 4* n_embed),
            nn.ReLU(),
            nn.Linear(4*n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ffw(x) # (B, context_size, n_embed) @ (n_embed, 4*n_embed) @ (4*n_embed, n_embed) --> (B, context_size, n_embed)

class EncoderBlock(nn.Module):
    def __init__(self, n_heads, n_embed, dropout):
        super().__init__()
        self.attention = MultiHeadAttention(n_heads, n_embed, dropout, False, True)
        self.ffw = FeedForward(n_embed, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        out = self.attention(x, x, x)
        out = self.ln1(out)
        out = self.ffw(out)
        out = self.ln2(out)
        return x + out # NOTE: not clear if this is the right place for adding to take place - refer to other implementations

class DecoderBlock(nn.Module):
    def __init__(self, n_heads, n_embed, dropout):
        super().__init__()
        self.masked_attention = MultiHeadAttention(n_heads, n_embed, dropout, True)
        self.attention = MultiHeadAttention(n_heads, n_embed, dropout, True)
        self.ffw = FeedForward(n_embed, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        self.ln3 = nn.LayerNorm(n_embed)

    def forward(self, x, k_in, v_in):
        x = self.masked_attention(x, x, x)
        x = self.ln1(x)
        x = self.attention(k_in, x, v_in)
        x = self.ln2(x)
        x = self.ffw(x)
        x = self.ln3(x)
        return x

class PositionalEmbedding(nn.Module):
    def __init__(self, n_embed, context_size):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        posem = torch.zeros((context_size, n_embed))
        for pos in range(context_size):
            for i in range(n_embed, 2):
                posem[pos, i] = math.sin(pos/(10000**(2*i/n_embed)))
                posem[pos, i+1] = math.cos(pos/(10000**(2*(i+1)/n_embed)))
        posem.unsqueeze(0)
        self.register_buffer("posem", posem)

    def forward(self, x):
        return x + self.posem[torch.arange(x.shape[1]),:]


class EncoderDecoderTransformer(nn.Module):
    def __init__(self, n_heads, n_embed, dropout, n_blocks, context_size, input_vocab_size, output_vocab_size):
        super().__init__()
        self.input_embedding = nn.Embedding(input_vocab_size, n_embed)
        self.positional_embedding = PositionalEmbedding(n_embed, context_size)
        self.encoders = nn.Sequential(*[EncoderBlock(n_heads, n_embed, dropout) for _ in range(n_blocks)])
        self.decoders = nn.Sequential(*[DecoderBlock(n_heads, n_embed, dropout) for _ in range(n_blocks)])
        self.output_embedding = nn.Embedding(output_vocab_size, n_embed) # NOTE: should this really be the same vocab size as input?? 

    def forward(self, xi, xo):
        x = self.input_embedding(xi)
        x = self.positional_embedding(x)
        print(f"post emb {x.shape}")
        x = self.encoders(x)
        print(f"post enc {x.shape}")
        x = self.decoders(xo , self.encoders[-1].k, self.encoders[-1].v)
        print(f"post dec {x.shape}")
        x = self.output_embedding(x)
        return F.softmax(x, dim=2)

In [107]:
# TODO: work out how to handle extracting k and v from multiple heads in final encoder block?? very possibly I should just rewrite in the kombo-head pattern 

model = EncoderDecoderTransformer(
    n_heads=3, n_embed=60, dropout=0.2, n_blocks=5, context_size=12, input_vocab_size=11, output_vocab_size=11
)

# x = torch.arange(11)
# y = torch.tensor([[0]])
src = torch.tensor([[0, 2, 5, 6, 4, 3, 9, 5, 2, 9, 10, 1], 
                    [0, 2, 8, 7, 3, 4, 5, 6, 7, 2, 10, 1]])
target = torch.tensor([[0, 1, 7, 4, 3, 5, 9, 2, 8, 10, 9, 1], 
                       [0, 1, 5, 6, 2, 4, 7, 6, 2, 8, 10, 1]])
model(src, target[1:])

post emb torch.Size([2, 12, 60])
post enc torch.Size([2, 12, 60])


AttributeError: 'EncoderBlock' object has no attribute 'k'

In [74]:
y

tensor([[0]])

In [85]:
print(src.shape)

torch.Size([2, 12])


tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [108]:
a = torch.ones((5,3))

In [111]:
a.view(5,3, 2)

RuntimeError: shape '[5, 3, 2]' is invalid for input of size 15

In [112]:
b = nn.Linear(5, 5)

In [114]:
b.view(3, 5, 5)

AttributeError: 'Linear' object has no attribute 'view'