<a href="https://colab.research.google.com/github/JeffBla/implement_transformer_from_scratch/blob/main/implement_transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

In [None]:
class MultiheadSelfAttention(nn.Module):
  """Multi-head scaled dot-product attention (self- or cross-attention).

  The embedding of size ``k`` is projected to queries, keys and values, each
  of which is split into ``head`` slices of size ``k // head``. Every head
  attends independently; the per-head results are concatenated back to size
  ``k`` and passed through a final linear layer.

  Args:
    k: Embedding dimension of the inputs and of the output.
    head: Number of attention heads; must be positive and divide ``k``.
    mask: If True, query position ``i`` may only attend to key positions
      ``<= i`` (causal masking). The flag is read on every forward call, so
      it can be changed after construction.

  Raises:
    ValueError: If ``head`` is not positive or does not divide ``k``.
  """

  def __init__(self, k, head=4, mask=False):
    super().__init__()

    # Fail early with a clear message instead of an opaque ``view`` error
    # in ``forward`` when the embedding cannot be split evenly.
    if head <= 0 or k % head != 0:
      raise ValueError(
          f"embedding dim k={k} must be divisible by a positive head={head}")

    self.k = k
    self.head = head
    self.mask = mask

    self.to_query = nn.Linear(k, k, bias=False)
    self.to_key = nn.Linear(k, k, bias=False)
    self.to_value = nn.Linear(k, k, bias=False)

    # Mixes the concatenated head outputs back into one k-dim vector.
    self.unitify_layer = nn.Linear(k, k)

  def forward(self, x, kv=None):
    """Attend from ``x`` to ``kv`` (or to ``x`` itself when ``kv`` is None).

    Args:
      x: Query source, shape (batch_size × query_length × embed_dim).
      kv: Optional key/value source, shape
        (batch_size × key_length × embed_dim). Defaults to ``x``.

    Returns:
      Tensor of shape (batch_size × query_length × embed_dim).
    """

    h = self.head
    if kv is None:
      kv = x
    b, tq, k = x.size()
    _, tk, _ = kv.size()

    s = k // h  # features per head

    q = self.to_query(x).view(b, tq, h, s)
    k_ = self.to_key(kv).view(b, tk, h, s)
    v = self.to_value(kv).view(b, tk, h, s)

    # Each dot product sums over the s features of one head, so the scores
    # are normalised by sqrt(s) (the per-head dim), not by sqrt(k).
    attn_scores = torch.einsum("bqhs,bkhs->bhqk", q, k_) / s**0.5

    if self.mask:
      # True above the diagonal marks the "future" key positions to hide.
      mask = torch.triu(torch.ones(tq, tk, device=x.device), diagonal=1).bool()
      attn_scores = attn_scores.masked_fill(mask[None, None, :, :],
                                            float('-inf'))

    attn_weights = F.softmax(attn_scores, dim=-1)

    # Weighted sum of the values per head, then concatenate the heads.
    out = torch.einsum("bhqk,bkhs->bqhs", attn_weights, v)
    out = out.reshape((b, tq, k))

    return self.unitify_layer(out)

In [None]:
class TransformerBlock(nn.Module):
  """One transformer layer: attention, optional cross-attention, feed-forward.

  Every sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)``, i.e. the
  residual is added first and the normalisation is applied afterwards.

  Args:
    k: Embedding dimension.
    head: Number of attention heads handed to the attention modules.
    is_cross_attention: If True, an extra attention + LayerNorm pair is
      created that attends from the block input to ``enc_out``.
  """

  def __init__(self, k, head, is_cross_attention=False):
    super().__init__()

    self.k, self.head = k, head
    self.is_cross_attention = is_cross_attention

    # Self-attention sub-layer.
    self.self_attn = MultiheadSelfAttention(k, head)
    self.norm1 = nn.LayerNorm(k)

    # Cross-attention sub-layer, only present when requested.
    if is_cross_attention:
      self.cross_attn = MultiheadSelfAttention(k, head)
      self.cross_norm = nn.LayerNorm(k)

    # Position-wise feed-forward sub-layer with a 4x wider hidden layer.
    hidden = 4 * k
    self.ff = nn.Sequential(
        nn.Linear(k, hidden),
        nn.ReLU(),
        nn.Linear(hidden, k),
    )
    self.norm2 = nn.LayerNorm(k)

  def forward(self, x, enc_out=None):
    """Run the block on ``x``.

    Args:
      x: Input of shape (batch_size × seq_length × embed_dim).
      enc_out: Optional tensor to cross-attend to. Cross-attention is
        skipped when this is None or when the block was built without it.

    Returns:
      Tensor with the same shape as ``x``.
    """
    x = self.norm1(self.self_attn(x) + x)

    use_cross = self.is_cross_attention and enc_out is not None
    if use_cross:
      x = self.cross_norm(self.cross_attn(x, enc_out) + x)

    return self.norm2(self.ff(x) + x)

In [None]:
class CTransformer(nn.Module):
  """Sequence classifier that mean-pools the transformer output.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of stacked transformer blocks.
    seq_len: Number of rows in the position-embedding table, i.e. the
      longest sequence that can be embedded.
    token_size: Number of rows in the token-embedding table.
    num_class: Number of output classes.
  """

  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    self.token_emb = nn.Embedding(token_size, k)
    self.pos_emb = nn.Embedding(seq_len, k)

    self.blocks = nn.Sequential(
        *[TransformerBlock(k, head) for _ in range(depth)])

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of token-id sequences.

    Args:
      x: Integer token ids of shape (batch_size × seq_length).

    Returns:
      Log-probabilities of shape (batch_size × num_class).
    """
    token_emb = self.token_emb(x)
    b, t, k = token_emb.size()

    # Build the position indices on the input's device; a CPU-only arange
    # cannot index an embedding table that lives on another device.
    positions = torch.arange(t, device=x.device)
    pos_emb = self.pos_emb(positions).unsqueeze(0).expand([b, t, k])

    emb = token_emb + pos_emb

    out = self.blocks(emb)

    # Average over the sequence dimension before the classification head.
    out = self.ff(out.mean(dim=1))

    return F.log_softmax(out, dim=1)

In [None]:
class CTransformerCLS(nn.Module):
  """Sequence classifier that reads its prediction from a learned CLS token.

  A learnable vector is prepended to every sequence; after the transformer
  blocks, the output at that first position is fed to the linear head.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of stacked transformer blocks.
    seq_len: Longest token sequence; the position table has one extra row
      for the CLS slot.
    token_size: Number of rows in the token-embedding table.
    num_class: Number of output classes.
  """

  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    self.k = k
    self.token_emb = nn.Embedding(token_size, k)
    self.pos_emb = nn.Embedding(seq_len + 1, k)
    self.cls_token = nn.Parameter(torch.rand(1, 1, k))

    layers = [TransformerBlock(k, head) for _ in range(depth)]
    self.blocks = nn.Sequential(*layers)

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of token-id sequences.

    Args:
      x: Integer token ids of shape (batch_size × seq_length).

    Returns:
      Unnormalised class scores of shape (batch_size × num_class).
    """
    batch, length = x.size()
    total = length + 1  # tokens plus the CLS slot

    # Prepend the same learned CLS vector to every sequence in the batch.
    cls = self.cls_token.expand([batch, 1, self.k])
    tokens = torch.cat((cls, self.token_emb(x)), dim=1)

    positions = torch.arange(total, device=x.device)
    pos = self.pos_emb(positions).unsqueeze(0).expand([batch, total, self.k])

    hidden = self.blocks(tokens + pos)

    # Position 0 holds the CLS output.
    return self.ff(hidden[:, 0])

In [None]:
class ViTTransformer(nn.Module):
  """Vision-transformer style image classifier.

  The image is cut into non-overlapping ``patch_size × patch_size`` patches,
  each patch is flattened and linearly projected to ``k`` features, a
  learned CLS vector is prepended, learned position embeddings are added,
  and the CLS output of the transformer blocks is classified.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of stacked transformer blocks.
    patch_size: Side length of a square patch, in pixels.
    img_size: Side length of the (square) input image; must be a multiple
      of ``patch_size``.
    channel: Number of image channels.
    num_class: Number of output classes.
  """

  def __init__(self, k, head, depth, patch_size, img_size, channel, num_class):
    super().__init__()

    assert img_size % patch_size == 0, \
        "img_size must be a multiple of patch_size"

    self.img_size = img_size
    self.patch_size = patch_size
    self.patch_dim = patch_size * patch_size * channel
    self.num_patch = (img_size // patch_size) ** 2
    self.token_emb = nn.Linear(self.patch_dim, k)
    # One position vector per patch plus one for the CLS slot.
    self.pos_emb = nn.Parameter(torch.rand(1, self.num_patch + 1, k))

    self.cls_token = nn.Parameter(torch.rand(1, 1, k))

    self.blocks = nn.Sequential(
        *[TransformerBlock(k, head) for _ in range(depth)])

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of images.

    Args:
      x: Images of shape (batch_size × channel × height × width); height
        and width must be multiples of ``patch_size``.

    Returns:
      Unnormalised class scores of shape (batch_size × num_class).
    """
    batch = x.size(0)
    p = self.patch_size  # was misspelled `pathch_size`, raising AttributeError

    # Cut the image into p×p patches and flatten each patch to one vector.
    x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
                  p1=p, p2=p)

    token_emb = self.token_emb(x)

    cls_token = repeat(self.cls_token, "1 1 k -> b 1 k", b=batch)
    token_emb = torch.cat((cls_token, token_emb), dim=1)

    emb = token_emb + self.pos_emb[:, :token_emb.size(1)]

    out = self.blocks(emb)

    # Position 0 holds the CLS output.
    cls_out = out[:, 0]

    return self.ff(cls_out)

In [None]:
class Encoder(nn.Module):
  """Token + position embedding followed by a stack of transformer blocks.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of stacked transformer blocks.
    seq_len: Number of rows in the position-embedding table.
    token_size: Number of rows in the token-embedding table.
  """

  def __init__(self, k, head, depth, seq_len, token_size):
    super().__init__()

    self.k, self.head = k, head

    self.token_emb = nn.Embedding(token_size, k)
    self.pos_emb = nn.Embedding(seq_len, k)

    layers = [TransformerBlock(k, head) for _ in range(depth)]
    self.block = nn.Sequential(*layers)

  def forward(self, x):
    """Encode token ids ``x`` into one ``k``-dim vector per position."""
    tokens = self.token_emb(x)

    # Position ids 0..t-1, created on the same device as the input.
    positions = torch.arange(tokens.shape[1], device=x.device)
    pos = self.pos_emb(positions)[None]  # leading batch axis for broadcasting

    return self.block(tokens + pos)

在 Decoder 中不能使用 nn.Sequential(*[...])，因為每一層 TransformerBlock 的 forward 需要接受兩個參數：

```python
x = block(x, enc_out)
```
但 nn.Sequential 只會把上一層的輸出當作唯一參數傳給下一層，無法處理 is_cross_attention=True 的 TransformerBlock 所需的額外參數 enc_out，因此 Decoder 改用 nn.ModuleList 並手動逐層呼叫。

In [None]:
class Decoder(nn.Module):
  """Causal transformer decoder with cross-attention and a CLS readout.

  Every block uses causally masked self-attention and cross-attends to the
  encoder output. A learned CLS vector is appended to the token sequence
  and its final output is fed to the linear classification head.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of stacked transformer blocks.
    seq_len: Longest token sequence; the position table has one extra row
      for the CLS slot.
    token_size: Number of rows in the token-embedding table.
    num_class: Number of output classes.
  """

  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    self.k = k
    self.head = head

    self.token_emb = nn.Embedding(token_size, k)
    self.cls_token = nn.Parameter(torch.rand(1, 1, k))
    self.pos_emb = nn.Embedding(seq_len + 1, k)

    # A ModuleList (not nn.Sequential) because each block takes two inputs.
    self.block = nn.ModuleList(
        [TransformerBlock(k, head, True) for _ in range(depth)])

    # Switch the self-attention of every block to causal masking.
    for blk in self.block:
      blk.self_attn.mask = True

    self.ff = nn.Linear(k, num_class)

  def forward(self, x, enc_out):
    """Decode ``x`` against ``enc_out`` and classify.

    Args:
      x: Integer token ids of shape (batch_size × seq_length).
      enc_out: Encoder output of shape
        (batch_size × enc_length × embed_dim) used for cross-attention.

    Returns:
      Unnormalised class scores of shape (batch_size × num_class).
    """
    cls_token = repeat(self.cls_token, "1 1 k -> b 1 k", b=x.size(0))
    token = self.token_emb(x)
    # Self-attention is causal, so position i only sees positions <= i.
    # With CLS at position 0 its output could never depend on any input
    # token; placing it LAST lets it attend to the whole sequence.
    token = torch.cat((token, cls_token), dim=1)

    pos = self.pos_emb(torch.arange(token.size(1), device=x.device)).unsqueeze(0)

    emb = token + pos

    for blk in self.block:
      emb = blk(emb, enc_out)

    # Read the prediction from the CLS slot, which is now the last position.
    return self.ff(emb[:, -1])

In [None]:
class CTransformerEncoderDecoder(nn.Module):
  """Classifier that chains an ``Encoder`` and a ``Decoder``.

  The same token ids are fed to both halves: the encoder output is what the
  decoder cross-attends to.

  Args:
    k: Embedding dimension.
    head: Number of attention heads per block.
    depth: Number of blocks in the encoder and in the decoder.
    seq_len: Longest token sequence.
    token_size: Number of rows in each token-embedding table.
    num_class: Number of output classes.
  """

  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    shared = (k, head, depth, seq_len, token_size)
    self.encoder = Encoder(*shared)
    self.decoder = Decoder(*shared, num_class)

  def forward(self, x):
    """Return class scores of shape (batch_size × num_class) for ids ``x``."""
    memory = self.encoder(x)
    return self.decoder(x, memory)