## Decoder only Transformer

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
device

device(type='cuda')

## Parts:
0. Dataset

1. character-level-tokenization
2. single self-attention module
3. multiple attention modules (multi-head)
4. Positional Embeddings
5. Decoder
6. Transformer

#### Part 1 : Character Level Tokenization

#### Part 0 : Dataset - Tiny Shakespeare

In [2]:
class ShakespeareDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.data = pd.read_csv("hf://datasets/Trelis/tiny-shakespeare/train.csv")
        tokens = set()
        self.maxlen = 0
        for _, sentence in self.data.iterrows():
            if self.maxlen < len(sentence.Text):
                self.maxlen = len(sentence.Text)
            for char in sentence.Text:
                tokens.add(char)
        self.maxlen += 2
        self.tokens = list(tokens)
        self.tokens.sort()
        self.tokens.insert(0, '<pad>')
        self.tokens.insert(1, '<eos>')
        self.tokens.insert(2, '<sos>')
        self.char_to_token = { c:idx for idx, c in enumerate(self.tokens) }

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        sentence = self.data.iloc[index]
        tokenized_sentence = [self.char_to_token['<sos>']] + [self.char_to_token[char] for char in sentence.Text] + [self.char_to_token['<eos>']]
        sample = torch.LongTensor(tokenized_sentence)
        mask = torch.ones(self.maxlen)
        if len(sample) < self.maxlen:
            mask[len(sample):] = 0
            sample = torch.cat([sample, torch.full((self.maxlen - len(sample),), self.char_to_token['<pad>'])])
        target = torch.cat([sample[1:],torch.LongTensor([self.char_to_token['<pad>']])])

        return sample, target, mask

In [3]:
shsp_data = ShakespeareDataset()
print(len(shsp_data))
sample, target, mask =shsp_data[4]
sample

  from .autonotebook import tqdm as notebook_tqdm


472


tensor([ 2, 40, 56,  ...,  0,  0,  0])

In [4]:
sentence = ''.join([shsp_data.tokens[sample[i].item()] for i in range(len(sample)) if shsp_data.tokens[sample[i].item()] not in ['<pad>','<sos>','<eos>']])
print(sentence)

Your virtue is
To make him worthy whose offence subdues him
Your virtue is
And curse that justice did it.
Who deserves greatness
Deserves your hate; and your affections are
A sick man's appetite, who desires most that
Which would increase his evil. He that depends
Upon your favours swims with fins of lead
And hews down oaks with rushes. Hang ye! Trust Ye?
With every minute you do change a mind,
And call him noble that was now your hate,
Him vile that was your garland. What's the matter,
That in these several places of the city
You cry against the noble senate, who,
Under the gods, keep you in awe, which else
Would feed on one another? What's their seeking?

MENENIUS:
For corn at their own rates; whereof, they say,
The city is well stored.

MARCIUS:
Hang 'em! They say!
They'll sit by the fire, and presume to know
What's done i' the Capitol; who's like to rise,
Who thrives and who declines; side factions
and give out
Conjectural marriages; making parties strong
And feebling such as stand

In [5]:
print(len(shsp_data.tokens))

68


#### Part 2 : Single Self-Attention Head


#### Part 3 : Multi-Head Attention

In [6]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads : int, model_dim : int):
        super().__init__()
        assert model_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.model_dim = model_dim

        self.Wq = nn.Linear(model_dim, model_dim)
        self.Wk = nn.Linear(model_dim, model_dim)
        self.Wv = nn.Linear(model_dim, model_dim)

        self.Wo = nn.Linear(model_dim, model_dim)

    def forward(self, q : torch.Tensor, k : torch.Tensor, v : torch.Tensor, mask : torch.Tensor | None = None) -> torch.Tensor:
        # q, k, v = (bs, sl, model_dim)

        bs, sl = q.shape[0], q.shape[1]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = q.view(bs, sl, self.num_heads, self.head_dim).transpose(1,2)
        k = k.view(bs, sl, self.num_heads, self.head_dim).transpose(1,2)
        v = v.view(bs, sl, self.num_heads, self.head_dim).transpose(1,2)

        out = self.scaled_dot_product_attention(q,k,v,mask)
        # out = (bs, num_heads, sl, head_dim)

        out = out.transpose(1,2).contiguous().view(bs, sl, self.model_dim)
        out = self.Wo(out)
        return out

    def scaled_dot_product_attention(self, q : torch.Tensor, k : torch.Tensor , v : torch.Tensor, mask : torch.Tensor | None =None) -> torch.Tensor:
        out = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(self.head_dim)
        if mask is not None:
            if mask.shape == torch.Size([q.shape[0], q.shape[2]]):
                #padding mask (bs, sl)
                mask = mask.unsqueeze(1).unsqueeze(1) #(bs, 1, 1, sl)

            elif mask.shape == torch.Size([q.shape[2], q.shape[2]]):
                #causal mask (sl, sl)
                mask = mask.unsqueeze(0).unsqueeze(0) #(1, 1, sl, sl)

            elif mask.shape == torch.Size([q.shape[0],q.shape[2],q.shape[2]]):
                #combined mask (bs, sl, sl)
                mask = mask.unsqueeze(1) #(bs, 1, sl, sl)

            else:
                raise TypeError(f"INPUT MASK SHAPE {mask.shape} INVALID")

            out = out.masked_fill(mask==0,value=float('-inf'))
        out = nn.functional.softmax(out, dim=-1)
        out = torch.matmul(out, v)
        return out

#### Part 4 : Positional Encoding

In [7]:
def positional_encoding(input : torch.Tensor) -> torch.Tensor:
    bs, sl, model_dim = input.shape
    pe = torch.zeros(sl, model_dim)

    pos = torch.arange(0, sl).unsqueeze(1)
    dim = torch.arange(0, model_dim, 2)

    pe[:,0::2] = torch.sin(pos / torch.pow(10000,2 * dim / model_dim))
    pe[:,1::2] = torch.cos(pos / torch.pow(10000,2 * dim / model_dim))

    return input + pe.unsqueeze(0).to(input.device)

In [8]:
class LayerNormalization(nn.Module):
    def __init__(self, model_dim : int):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(model_dim))
        self.beta = nn.Parameter(torch.zeros(model_dim))

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1,keepdim=True)
        return self.gamma * (x - mean) / std + self.beta

#### Part 5 : Decoder

In [9]:
class Decoder(nn.Module):
    def __init__(self, model_dim : int, num_heads : int):
        super().__init__()
        self.attn1 = MultiHeadedAttention(num_heads,model_dim)
        self.norm1 = LayerNormalization(model_dim)

        self.ffn1 = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.SiLU(),
            nn.Linear(model_dim, model_dim)
        )
        self.norm2 = LayerNormalization(model_dim)

    def forward(self, outputs : torch.Tensor, mask : torch.Tensor) -> torch.Tensor:

        outputs = self.norm1(outputs + self.attn1(outputs, outputs, outputs, mask))

        outputs = self.norm2(outputs + self.ffn1(outputs))

        return outputs

#### Part 7 : Transformer Assembly

In [10]:
class DecoderTransformer(nn.Module):
    def __init__(self, model_dim : int, num_heads : int, num_decoders : int, vocab_size : int):
        super().__init__()
        self.out_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=model_dim)
        self.decoders = nn.ModuleList(
            [Decoder(model_dim, num_heads) for _ in range(num_decoders)]
        )
        self.linear_out = nn.Linear(model_dim, vocab_size)

    def forward(self, outputs : torch.Tensor, padding_mask : torch.Tensor | None) -> torch.Tensor:
        bs, sl = outputs.shape[:2]

        causal_mask = torch.tril(torch.ones(sl, sl, device=outputs.device))

        if padding_mask is not None:
            combined_mask = padding_mask.unsqueeze(1) * causal_mask.unsqueeze(0) #(bs, 1, sl) * (1, sl, sl) = (bs, sl, sl)
        else:
            combined_mask = causal_mask

        outputs = self.out_embed(outputs)

        outputs = positional_encoding(outputs)

        for decoder in self.decoders:
            outputs = decoder(outputs, combined_mask)

        outputs = self.linear_out(outputs)
        return outputs # (bs, sl, vocab_size)

#### Training

In [11]:
def train_transformer(
        model : nn.Module,
        train_loader : DataLoader,
        device=device,
        num_epochs = 10000,
        lr = 1e-3,
):
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss(ignore_index=train_loader.dataset.char_to_token['<pad>'])
    vocab_size = len(train_loader.dataset.tokens)
    model.train()

    losses = []
    pbar = tqdm(range(num_epochs))
    for idx in pbar:
        epoch_loss = 0
        for samples, targets, masks in train_loader:
            samples = samples.to(device)
            targets = targets.to(device)
            masks = masks.to(device)
            optim.zero_grad()
            loss = loss_fn(model(samples, masks).view(-1, vocab_size), targets.view(-1))
            loss.backward()
            optim.step()
            epoch_loss += loss
            pbar.set_description(f'Epoch {idx}, Loss: {loss}')
        losses.append(epoch_loss.detach())
    model.eval()
    return torch.stack(losses)

In [12]:
train_loader = DataLoader(ShakespeareDataset(), 8, shuffle=True, num_workers=4)
model = DecoderTransformer(32, 8, 3, len(train_loader.dataset.tokens)).to(device)

In [None]:
losses = train_transformer(model,train_loader, num_epochs=200)
plt.plot(losses.cpu())
plt.show()

In [None]:
# torch.save(model.state_dict(), 'shksp_decoder_transformer_32_8_3.pth')
# from google.colab import files
# files.download('shksp_decoder_transformer_32_8_3.pth')

In [13]:
model.load_state_dict(torch.load('shksp_decoder_transformer_32_8_3.pth'))

<All keys matched successfully>

In [16]:
def generate(model : nn.Module, init_token : str = '<sos>', maxlen : int = 3000):
  toks = torch.zeros(1,maxlen).long().to(device)
  toks[0,0] = shsp_data.char_to_token[init_token]
  last_tok = toks[0,0].item()
  mask = torch.zeros(1,maxlen).long().to(device)
  for i in range(maxlen):
    mask[0, 0:i+1] = 1
    probs = torch.softmax(model(toks, mask), dim=-1) #(bs, sl, vocab_Size)
    last_tok = torch.multinomial(probs[:,-1], 1).item()
    toks[0,i] = last_tok
    if last_tok == shsp_data.char_to_token['<eos>']:
      break
  output = ''.join(
      [shsp_data.tokens[tok] for tok in toks[0] if tok != 0]
  )
  return output

In [None]:
output = generate(model)
print(output)

oa
eeeeaeameeaeeaa:aeeaheeea eeehaereeaeaaaaarreeaaaaeree
reehaaheweeeaehh
areaaee

reeeh
ee!rsraeeerwh
aeOe-reaehheehOOeeaeae
raawrhUUeUrieaaaehUeUheaeaeeeaeeeeeeryaeUeeUaRUhwaeTeeeraeeo

u

-a

e
aueee

i
eee

aeeaaaae
hee
hrr
ehxeUa:e
uaa

eaue
heeraeeeeaireeaaereeeRaeeeaeeeeahieeeereeeOmeeaaaeehaeetyeekrer
ueeareeeahreeaeeuheeaaaeehayareuueaieeheaeeaereeeaaeeeeeeaee
eaeehueeeaaaxeaeeeee
iaeeeeyeeeaeseeeeeaeeeeaeeeekeeeal
healm
aeeeerueaoeia
eeeahaaepeeeaeeaeeaaeiaaeweeeaaeaweleaeee
eaaaew
Raeeeeheaeaaxeye
h
aae
eeieeeeea
ee
eeeseaeeueereeeee
wreeaee
eehe?aee
eaeeehaaaeieaaeeeeataeaoeeeeaeieeahhaaeBaeeeeeeeaaea
eeree
eeeuoaeieaeieieeeeaiieeaeeereiaaee?aaeaeaereeh
eleeee
rreuh
hhweeeueaeeeeew
eeeuceaaeaeaeaepeereeeeteaheheeakuyyueeeeeutao
eeeaee
eae.e;e.euae

ae;
e
heeaie
ee
ee?h-uwe
emeeaia.
uaaue
Eedreeeqenehireeeaepeeueewlomexaeeewle
ah
aumaeaxxe.taeeeeeemaareaumR
mea
oaaeeeeafeewaeeeeae.eeeeyeeueaeeaeenxebmueeeeneteepaeeweaohareaaaeatieeeppeeeeee;eoaeeeeaa:eeeeera
etaptesuaterate

In [None]:
print("Available tokens:", shsp_data.tokens)
print("Looking for:", '<sos>')
print("Token ID:", shsp_data.char_to_token.get('<sos>', "NOT FOUND"))
print("Embedding expects vocab size:", model.out_embed.num_embeddings)  # or however you access it
print("Actual vocab size:", len(shsp_data.tokens))
print("Max token ID:", max(shsp_data.char_to_token.values()))
print("Min token ID:", min(shsp_data.char_to_token.values()))

In [None]:
# Test if your model can handle a simple input:
test_input = torch.tensor([[2]]).to(device)  # Just <sos> token
test_mask = torch.ones(1, 1).to(device)

try:
    test_output = model(test_input, test_mask)
    print("Model forward pass successful")
    print("Output shape:", test_output.shape)
except RuntimeError as e:
    print("Model forward pass failed:", e)