This part covers the basic attention mechanism and a simple Transformer-based model, GPT-2.

> Some of the code is adapted from <a href='https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/transformers'>MLFS</a>.

# Self-Attention: Single Head

In [1]:
!pip install einops

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 42.2/42.2 kB 3.9 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.6.1


In [10]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import math
from einops import rearrange

In [11]:
# Seed every random number generator used in this notebook with the same
# value so that the randomly initialised weights and inputs are reproducible.
seed = 42
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

In [57]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries and keys are projected to ``head_dim`` features, while values are
    projected to ``emb_dim`` features, so the output has the same width as the
    input.

    Args:
        emb_dim: Size of the last axis of the input tokens.
        head_dim: Size of the query/key projections.
        head: Unused; kept so existing callers that pass it keep working.
    """

    def __init__(self, emb_dim, head_dim, head=1):
        # Being an nn.Module registers the three projections, so they show up
        # in parameters()/state_dict() and follow .to(device).
        super().__init__()
        self.to_q = nn.Linear(emb_dim, head_dim, bias=False)
        self.to_k = nn.Linear(emb_dim, head_dim, bias=False)
        self.to_v = nn.Linear(emb_dim, emb_dim, bias=False)

    def get_qkv(self, tokens):
        """Return the ``(q, k, v)`` projections of ``tokens``."""
        return self.to_q(tokens), self.to_k(tokens), self.to_v(tokens)

    def forward(self, tokens):
        """Attend over ``tokens`` of shape ``(batch, seq, emb_dim)``.

        Returns:
            Tensor of shape ``(batch, seq, emb_dim)``: for every query position
            the softmax-weighted sum of the value vectors.
        """
        q, k, v = self.get_qkv(tokens)
        # Dot product of every query (T) with every key (S).
        att_score = torch.einsum("BTD,BSD->BTS", q, k)
        # Scale by sqrt(key width) and normalise over the key axis.
        att_score = F.softmax(att_score / math.sqrt(k.shape[-1]), dim=-1)
        return torch.einsum("BTS,BSD->BTD", att_score, v)

In [58]:
# Sanity check: the hand-written single-head attention must agree with
# PyTorch's built-in scaled dot-product attention on the same q, k, v.
emb_dim, head_dim = 128, 64
tokens = torch.rand(32, 12, emb_dim)

m = SingleHeadAttention(emb_dim, head_dim)
q, k, v = m.get_qkv(tokens)
reference = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(m(tokens), reference)

# Self-Attention: Multi Head

In [103]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no output projection).

    Args:
        emb_dim: Size of the last axis of the input tokens.
        head_dim: Feature size of each head.
        head: Number of heads; ``emb_dim`` must equal ``head_dim * head``.
    """

    def __init__(self, emb_dim, head_dim, head):
        assert emb_dim == head_dim * head, "required! emb_dim == head_dim * head"
        # Being an nn.Module registers the projections as parameters.
        super().__init__()

        inner_dim = head_dim * head
        self.head_dim = head_dim
        self.head = head
        self.to_qs = nn.Linear(emb_dim, inner_dim, bias=False)
        self.to_ks = nn.Linear(emb_dim, inner_dim, bias=False)
        self.to_vs = nn.Linear(emb_dim, emb_dim, bias=False)

    def get_qkv(self, tokens):
        """Project ``tokens`` (batch, seq, emb_dim) and split the heads.

        Returns:
            Tuple ``(q, k, v)``, each of shape ``(batch, head, seq, head_dim)``.
        """
        q, k, v = (
            rearrange(proj(tokens), "b t (n d) -> b n t d", n=self.head)
            for proj in (self.to_qs, self.to_ks, self.to_vs)
        )
        return q, k, v

    def _att_from_qk(self, q, k):
        """Softmax-normalised attention weights for already split heads."""
        att_score = torch.einsum("BNTD,BNSD->BNTS", q, k)
        return F.softmax(att_score / math.sqrt(self.head_dim), dim=-1)

    def get_att_score(self, tokens):
        """Return the attention weights, shape ``(batch, head, seq, seq)``."""
        q, k, _ = self.get_qkv(tokens)
        return self._att_from_qk(q, k)

    def forward(self, tokens):
        """Return the attended values, shape ``(batch, seq, emb_dim)``."""
        # q, k and v are all needed here, so project once instead of going
        # through get_att_score (which does not hand back v).
        q, k, v = self.get_qkv(tokens)
        att_score = self._att_from_qk(q, k)
        val = torch.einsum("BNTS,BNSD->BNTD", att_score, v)
        # Concatenate the heads back along the feature axis.
        return rearrange(val, "b n t d -> b t (n d)")

In [104]:
# Check the per-head attention weights against torch.nn.MultiheadAttention.
emb_dim = 768
head_dim = 64
head = 12

tokens = torch.rand(32, 12, emb_dim)

m = MultiHeadAttention(emb_dim, head_dim, head)
mha = nn.MultiheadAttention(emb_dim, head, batch_first=True)

# The reference module keeps the q, k and v projection matrices stacked
# row-wise in a single `in_proj_weight`; overwrite only that entry and keep
# every other tensor of its state dict as initialised.
stacked_qkv = torch.cat(
    [m.to_qs.weight.data, m.to_ks.weight.data, m.to_vs.weight.data], dim=0
)
tmp_dict = {**mha.state_dict(), 'in_proj_weight': stacked_qkv}
mha.load_state_dict(tmp_dict)

# Index 1 of the returned tuple holds the attention weights (per head here,
# because averaging is switched off).
mha_att = mha(tokens, tokens, tokens, average_attn_weights=False)[1]

assert torch.allclose(m.get_att_score(tokens), mha_att, atol=1e-6, rtol=1e-6)

# Multi Head Attention with mask

In [132]:
class MultiHeadAttentionMask(nn.Module):
    """Multi-head self-attention with an additive causal mask.

    Args:
        emb_dim: Size of the last axis of the input tokens.
        head_dim: Feature size of each head.
        head: Number of heads; ``emb_dim`` must equal ``head * head_dim``.
        token_num: Largest sequence length the stored mask supports.

    Attributes:
        mask: ``(token_num, token_num)`` tensor that is 0 on and below the
            diagonal and -1e4 above it.
    """

    def __init__(self, emb_dim, head_dim, head, token_num):
        assert emb_dim == head * head_dim
        super().__init__()
        inner_dim = head * head_dim

        self.head = head
        self.head_dim = head_dim

        self.to_qs = nn.Linear(emb_dim, inner_dim, bias=False)
        self.to_ks = nn.Linear(emb_dim, inner_dim, bias=False)
        self.to_vs = nn.Linear(emb_dim, inner_dim, bias=False)

        # Registered as a buffer so it follows the module across devices;
        # non-persistent because it is fully determined by token_num.
        causal = -1E4 * torch.triu(torch.ones(token_num, token_num), 1)
        self.register_buffer("mask", causal, persistent=False)

    def get_qkv(self, tokens):
        """Project ``tokens`` (batch, seq, emb_dim) and split the heads.

        Returns:
            Tuple ``(q, k, v)``, each of shape ``(batch, head, seq, head_dim)``.
        """
        bs, t, _ = tokens.shape
        q, k, v = (
            rearrange(proj(tokens), "b t (n d) -> b n t d", n=self.head)
            for proj in (self.to_qs, self.to_ks, self.to_vs)
        )

        expected_shape = (bs, self.head, t, self.head_dim)
        assert q.shape == expected_shape
        assert k.shape == expected_shape
        assert v.shape == expected_shape
        return q, k, v

    def _att_from_qk(self, q, k):
        """Masked, softmax-normalised attention weights for split heads."""
        t = q.shape[-2]
        att_s = torch.einsum("BNTD,BNSD->BNTS", q, k) / math.sqrt(self.head_dim)
        # Add the mask after scaling so the penalty itself is not shrunk, and
        # slice it so sequences shorter than token_num are accepted.
        att_s = att_s + self.mask[:t, :t]
        return F.softmax(att_s, dim=-1)

    def get_att(self, tokens):
        """Return the masked attention weights, ``(batch, head, seq, seq)``."""
        q, k, _ = self.get_qkv(tokens)
        return self._att_from_qk(q, k)

    def forward(self, tokens):
        """Return the attended values, shape ``(batch, seq, emb_dim)``."""
        q, k, v = self.get_qkv(tokens)
        att = self._att_from_qk(q, k)
        val = torch.einsum("BNTS,BNSD->BNTD", att, v)
        # Merge the heads of the attention *output* (not of the raw values).
        return rearrange(val, "b n t d -> b t (n d)")

In [133]:
# Check the masked attention weights against torch.nn.MultiheadAttention
# driven with the very same additive mask.
emb_dim = 768
head_dim = 64
head = emb_dim // head_dim
token_num = 12

tokens = torch.rand(32, token_num, emb_dim)

m = MultiHeadAttentionMask(emb_dim, head_dim, head, token_num)
mha = nn.MultiheadAttention(emb_dim, head, batch_first=True)

# The reference module keeps the q, k and v projection matrices stacked
# row-wise in a single `in_proj_weight`; overwrite only that entry and keep
# every other tensor of its state dict as initialised.
stacked_qkv = torch.cat(
    [m.to_qs.weight.data, m.to_ks.weight.data, m.to_vs.weight.data], dim=0
)
tmp_dict = {**mha.state_dict(), 'in_proj_weight': stacked_qkv}
mha.load_state_dict(tmp_dict)

m_att = m.get_att(tokens)
mha_att = mha(tokens, tokens, tokens, average_attn_weights=False, attn_mask=m.mask)[1]

assert torch.allclose(m_att, mha_att, atol=1e-6, rtol=1e-6)

# Transformer Model

Key parts of a Transformer model:

- Residual connection
- LayerNorm
- Multi-Head Attention
- FFN

Using GPT-2 as an example to implement a simple Transformer model. Architecture:
1. word token embedding (50257, 768)
2. word position embedding (1024, 768)
3. 12 * gpt2_transformer_block
4. layer norm

Architecture of a GPT-2 transformer block:
```
     x-----------
layer_norm      |
mul_head_att    |
  res_add ------|
     x-----------
layer_norm      |
   ffn          |
 res_add -------|
```

In [None]:
!pip install transformers[torch]

In [22]:
class GPT2Block:
    def __init__(self, emb_dim, head, drop_out=0.0):
        self.ln1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, head, batch_first=True)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * 4),
            nn.GELU(),
            nn.Linear(emb_dim * 4, emb_dim),
            nn.Dropout(drop_out)
        )

    def get_causal_mask(self, token_num):
        mask = torch.ones(token_num, token_num)
        mask = -1e4 * torch.triu(mask, 1)
        return mask

    def __call__(self, tokens, is_causal=True):
        x = tokens
        bs, token_num, emb_dim = x.shape
        if is_causal:
            mask = self.get_causal_mask(token_num)
        else:
            mask = None

        res = x
        x = self.ln1(x)
        x, _ = self.attn(x, x, x, attn_mask=mask)
        x = x + res

        res = x
        x = self.ln2(x)
        x = self.ffn(x)
        return x + res

**TEST implementation**

In [9]:
# @title test utils

from transformers.activations import NewGELUActivation

def GPT2block_to_TransformerBlock_simple(tfmblock, gpt2block, ):
    """Load the weights of a pretrained GPT-2 block into ``tfmblock``.

    The ``.data`` of each parameter in ``tfmblock`` is re-pointed at the
    matching tensor of ``gpt2block``; the activation of the feed-forward part
    is swapped for ``NewGELUActivation``. Returns ``tfmblock``.
    """
    # Layer norms: weight and bias are taken over unchanged.
    for ours, theirs in (
        (tfmblock.ln1, gpt2block.ln_1),
        (tfmblock.ln2, gpt2block.ln_2),
    ):
        ours.weight.data = theirs.weight
        ours.bias.data = theirs.bias

    # Fused q/k/v input projection of the attention layer (weight transposed).
    ours_attn, their_attn = tfmblock.attn, gpt2block.attn
    ours_attn.in_proj_weight.data = their_attn.c_attn.weight.T
    ours_attn.in_proj_bias.data = their_attn.c_attn.bias

    # Remaining linear layers: the weight is transposed, the bias is not.
    for ours, theirs in (
        (ours_attn.out_proj, their_attn.c_proj),
        (tfmblock.ffn[0], gpt2block.mlp.c_fc),
        (tfmblock.ffn[2], gpt2block.mlp.c_proj),
    ):
        ours.weight.data = theirs.weight.T
        ours.bias.data = theirs.bias

    # The MLP in GPT2 and BERT uses the "new" GELU activation; keeping
    # nn.GELU() here causes a small error of around 1E-3.
    tfmblock.ffn[1] = NewGELUActivation()
    return tfmblock

def GPT2Model_to_GPT2Model_simple(gpt2modelsimple, gpt2model, ):
    """Copy the weights from a GPT2 model to a GPT2Model_simple.

    Args:
        gpt2modelsimple: Target model exposing ``wte``, ``wpe``, ``ln_f`` and a
            sequence ``blocks``.
        gpt2model: Source model exposing ``wte``, ``wpe``, ``ln_f`` and a
            sequence ``h`` of transformer blocks.

    Returns:
        ``gpt2modelsimple``, with its parameters pointing at the source weights.

    Raises:
        ValueError: If the two models do not have the same number of blocks.
    """
    ours_blocks, their_blocks = gpt2modelsimple.blocks, gpt2model.h
    # Check the depth before touching anything, so a mismatch can neither
    # leave the target half-updated nor silently skip trailing blocks.
    if len(ours_blocks) != len(their_blocks):
        raise ValueError(
            f"block count mismatch: target has {len(ours_blocks)}, "
            f"source has {len(their_blocks)}"
        )

    gpt2modelsimple.wte.weight.data = gpt2model.wte.weight
    gpt2modelsimple.wpe.weight.data = gpt2model.wpe.weight
    gpt2modelsimple.ln_f.weight.data = gpt2model.ln_f.weight
    gpt2modelsimple.ln_f.bias.data = gpt2model.ln_f.bias
    for ours, theirs in zip(ours_blocks, their_blocks):
        GPT2block_to_TransformerBlock_simple(ours, theirs)
    return gpt2modelsimple

In [5]:
from transformers import GPT2Tokenizer, GPT2Model

# Fetch the pretrained "gpt2" tokenizer and model (the cell output below shows
# the files being downloaded on first use).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
# Switch to evaluation mode so the Dropout layers listed in the printed model
# are inactive; as the last expression of the cell, this call also makes the
# notebook display the model structure.
model.eval()

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [23]:
# Compare one of our blocks, loaded with the weights of the first pretrained
# block, against that pretrained block on random embeddings.
embdim, headcnt = 768, 12
hf_block = model.h[0]

tfmblock = GPT2Block(embdim, headcnt)
GPT2block_to_TransformerBlock_simple(tfmblock, hf_block)

tokens_embs = torch.randn(2, 5, 768)
tfmblock_out = tfmblock(tokens_embs, is_causal=True)
# The pretrained block returns a 1-tuple here; unpack its only element.
(modelblock_out,) = hf_block(tokens_embs)
assert torch.allclose(tfmblock_out, modelblock_out, atol=1e-5, rtol=1e-5)

**Simple GPT2 Model**

In [32]:
class GPT2(nn.Module):
    """Minimal GPT-2 backbone: embeddings, a stack of blocks, final LayerNorm.

    Args:
        emb_dim: Feature size of the token and position embeddings.
        head: Number of attention heads in every block.
        drop_out: Dropout probability passed to every block.
        vocab_size: Number of rows of the token embedding table.
        max_pos: Number of rows of the position embedding table, i.e. the
            longest supported sequence.
        n_layer: Number of ``GPT2Block`` layers.
    """

    def __init__(self, emb_dim, head, drop_out=0.0,
                 vocab_size=50257, max_pos=1024, n_layer=12):
        super().__init__()
        # Both tables follow emb_dim, so widths other than 768 work as well.
        self.wte = nn.Embedding(vocab_size, emb_dim)
        self.wpe = nn.Embedding(max_pos, emb_dim)

        blocks = [GPT2Block(emb_dim, head, drop_out) for _ in range(n_layer)]
        # Register the blocks as sub-modules when they are nn.Modules so their
        # parameters are tracked; plain callables are kept in an ordinary list.
        if all(isinstance(block, nn.Module) for block in blocks):
            self.blocks = nn.ModuleList(blocks)
        else:
            self.blocks = blocks

        self.ln_f = nn.LayerNorm(emb_dim)

    def forward(self, input_ids, is_causal=True):
        """Return the final hidden states for ``input_ids`` of shape (batch, seq).

        Args:
            input_ids: Integer token ids.
            is_causal: Forwarded to every block.

        Returns:
            Tensor of shape ``(batch, seq, emb_dim)``.
        """
        x = self.wte(input_ids)
        # Position ids 0..seq-1, created on the same device as the embeddings.
        positions = torch.arange(x.shape[1], device=x.device)
        x = x + self.wpe(positions)
        for block in self.blocks:
            x = block(x, is_causal)

        return self.ln_f(x)

In [33]:
# End-to-end check: our model, loaded with the pretrained weights, must
# reproduce the last hidden state of the pretrained model.
embdim, headcnt = 768, 12

model_ours = GPT2(embdim, headcnt)
GPT2Model_to_GPT2Model_simple(model_ours, model)

inputs = tokenizer("I have a cat, her name is", return_tensors="pt")
outputs = model(**inputs)
hidden_last_ours = model_ours(inputs['input_ids'])

assert torch.allclose(outputs.last_hidden_state, hidden_last_ours, atol=1e-5, rtol=1e-5)