<a href="https://colab.research.google.com/github/JeffBla/implement_transformer_from_scratch/blob/main/implement_transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat

In [None]:
class MultiheadSelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding dimension ``k`` is split evenly across ``head`` heads, each of
  width ``k // head``. When ``kv`` is omitted in ``forward`` this is
  self-attention; otherwise queries come from ``x`` and keys/values from ``kv``.

  Args:
    k: embedding dimension of the inputs and of the output; must be divisible
      by ``head``.
    head: number of attention heads.
    mask: if True, apply a causal mask so query position ``i`` cannot attend to
      key positions ``j > i``.

  Raises:
    ValueError: if ``k`` is not divisible by ``head``.
  """
  def __init__(self, k, head=4, mask=False):
    super().__init__()

    if k % head != 0:
      raise ValueError(f"embedding dim k={k} must be divisible by head={head}")

    self.k = k
    self.head = head
    self.mask = mask

    self.to_query = nn.Linear(k, k, bias=False)
    self.to_key = nn.Linear(k, k, bias=False)
    self.to_value = nn.Linear(k, k, bias=False)

    # Output projection applied after the heads are concatenated back to k dims.
    self.unitify_layer = nn.Linear(k, k)

  def forward(self, x, kv=None):
    """Attend from ``x`` to ``kv`` (defaults to ``x``).

    Args:
      x: queries, shape (batch_size, tq, k).
      kv: source of keys and values, shape (batch_size, tk, k); defaults to x.

    Returns:
      Tensor of shape (batch_size, tq, k).
    """
    h = self.head
    if kv is None:
      kv = x
    b, tq, k = x.size()
    _, tk, _ = kv.size()

    s = k // h  # per-head dimension

    q = self.to_query(x).view(b, tq, h, s)
    k_ = self.to_key(kv).view(b, tk, h, s)
    v = self.to_value(kv).view(b, tk, h, s)

    # Each per-head dot product sums over s terms, so normalise by sqrt(s)
    # (not sqrt(k)) to keep the softmax temperature independent of head count.
    attn_scores = torch.einsum("bths,behs->bhte", q, k_) / s ** 0.5

    if self.mask:
      # attn_scores is (b, h, tq, tk): hide keys strictly after each query.
      # Built on the scores' device so masking also works on GPU.
      causal = torch.ones(
          tq, tk, dtype=torch.bool, device=attn_scores.device).triu(1)
      attn_scores = attn_scores.masked_fill(causal, float("-inf"))

    attn_scores = F.softmax(attn_scores, dim=-1)

    # Weighted sum of values: (b, h, tq, tk) x (b, tk, h, s) -> (b, tq, h, s).
    out = torch.einsum("bhtd,bdhe->bthe", attn_scores, v)
    out = out.reshape(b, tq, k)

    return self.unitify_layer(out)

In [None]:
class TransformerBlock(nn.Module):
  """Post-norm transformer encoder block.

  Runs multi-head self-attention followed by a two-layer ReLU feed-forward
  network (hidden width 4*k). Each sublayer is wrapped in a residual
  connection followed by LayerNorm.

  Args:
    k: embedding dimension.
    head: number of attention heads.
  """
  def __init__(self, k, head):
    super().__init__()

    self.k, self.head = k, head

    hidden = 4 * k
    self.attention = MultiheadSelfAttention(k, head)
    self.norm1 = nn.LayerNorm(k)
    self.ff = nn.Sequential(nn.Linear(k, hidden), nn.ReLU(), nn.Linear(hidden, k))
    self.norm2 = nn.LayerNorm(k)

  def forward(self, x):
    """Apply attention and feed-forward sublayers; output shape equals input."""
    hidden_state = self.norm1(self.attention(x) + x)
    return self.norm2(self.ff(hidden_state) + hidden_state)

In [None]:
class CTransformer(nn.Module):
  """Sequence classifier using mean pooling over transformer outputs.

  Token and learned position embeddings are summed, passed through ``depth``
  transformer blocks, averaged over the sequence, and projected to class
  log-probabilities.

  Args:
    k: embedding dimension.
    head: attention heads per block.
    depth: number of transformer blocks.
    seq_len: maximum supported sequence length (size of the position table).
    token_size: vocabulary size.
    num_class: number of output classes.
  """
  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    self.token_emb = nn.Embedding(token_size, k)
    self.pos_emb = nn.Embedding(seq_len, k)

    self.blocks = nn.Sequential(*[TransformerBlock(k, head) for _ in range(depth)])

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of token-id sequences.

    Args:
      x: integer token ids, shape (batch, t) with t <= seq_len.

    Returns:
      Class log-probabilities, shape (batch, num_class).

    Raises:
      ValueError: if the sequence is longer than the position table.
    """
    max_len = self.pos_emb.num_embeddings
    if x.size(1) > max_len:
      raise ValueError(f"sequence length {x.size(1)} exceeds seq_len={max_len}")

    token_emb = self.token_emb(x)
    b, t, k = token_emb.size()

    # Positions must live on the input's device, otherwise GPU inputs fail.
    positions = torch.arange(t, device=x.device)
    pos_emb = self.pos_emb(positions).unsqueeze(0).expand(b, t, k)

    emb = token_emb + pos_emb

    out = self.blocks(emb)

    # Mean-pool over the sequence dimension before the classification head.
    out = self.ff(out.mean(dim=1))

    return F.log_softmax(out, dim=1)

In [None]:
class CTransformerCLS(nn.Module):
  """Sequence classifier that reads its prediction from a learned CLS token.

  A trainable CLS vector is prepended to the token embeddings, learned
  position embeddings (one extra slot for CLS) are added, the result passes
  through ``depth`` transformer blocks, and the CLS position's output is
  projected to unnormalised class logits.

  Args:
    k: embedding dimension.
    head: attention heads per block.
    depth: number of transformer blocks.
    seq_len: maximum sequence length, not counting the CLS token.
    token_size: vocabulary size.
    num_class: number of output classes.
  """
  def __init__(self, k, head, depth, seq_len, token_size, num_class):
    super().__init__()

    self.k = k
    self.token_emb = nn.Embedding(token_size, k)
    # One extra position slot for the prepended CLS token.
    self.pos_emb = nn.Embedding(seq_len + 1, k)
    self.cls_token = nn.Parameter(torch.rand(1, 1, k))

    self.blocks = nn.Sequential(*(TransformerBlock(k, head) for _ in range(depth)))

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Return class logits (no softmax), shape (batch, num_class), for token ids x."""
    batch, length = x.size()
    total = length + 1  # sequence length including the CLS slot

    cls = self.cls_token.expand([batch, 1, self.k])
    tokens = torch.cat((cls, self.token_emb(x)), dim=1)

    positions = self.pos_emb(torch.arange(total, device=x.device))
    hidden = self.blocks(tokens + positions.unsqueeze(0).expand([batch, total, self.k]))

    return self.ff(hidden[:, 0])

In [None]:
class ViTTransformer(nn.Module):
  """Vision Transformer classifier.

  The image is split into non-overlapping ``patch_size`` x ``patch_size``
  patches. Each flattened patch is linearly embedded, and a learned CLS token
  is prepended. Learned position embeddings (``num_patch + 1`` slots) are
  added, and the sequence passes through ``depth`` transformer blocks. Class
  log-probabilities are read from the CLS position.

  Args:
    k: embedding dimension.
    head: attention heads per block.
    depth: number of transformer blocks.
    patch_size: side length of each square patch.
    img_size: side length of the (square) input image.
    channel: number of input channels.
    num_class: number of output classes.

  Raises:
    ValueError: if ``img_size`` is not divisible by ``patch_size``.
  """
  def __init__(self, k, head, depth, patch_size, img_size, channel, num_class):
    super().__init__()

    if img_size % patch_size != 0:
      raise ValueError(
          f"img_size={img_size} must be divisible by patch_size={patch_size}")

    self.img_size = img_size
    self.patch_size = patch_size
    self.patch_dim = patch_size * patch_size * channel
    self.num_patch = (img_size // patch_size) ** 2
    self.token_emb = nn.Linear(self.patch_dim, k)
    # One extra position slot for the prepended CLS token.
    self.pos_emb = nn.Embedding(self.num_patch + 1, k)

    self.cls_token = nn.Parameter(torch.rand(1, 1, k))

    self.blocks = nn.Sequential(*[TransformerBlock(k, head) for _ in range(depth)])

    self.ff = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of images.

    Args:
      x: images of shape (B, C, H, W); C must equal ``channel`` and H, W must
        be multiples of ``patch_size`` yielding at most ``num_patch`` patches.

    Returns:
      Class log-probabilities, shape (B, num_class).

    Raises:
      ValueError: if the image cannot be tiled into at most num_patch patches.
    """
    B, C, H, W = x.size()
    p = self.patch_size

    if H % p or W % p:
      raise ValueError(f"image size {H}x{W} is not divisible by patch_size={p}")
    if (H // p) * (W // p) > self.num_patch:
      raise ValueError(
          f"image size {H}x{W} gives more than num_patch={self.num_patch} patches")

    # (B, C, H, W) -> (B, num_patches, C*p*p): flatten each patch into a vector.
    x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=p, p2=p)

    token_emb = self.token_emb(x)
    b, t, k = token_emb.size()

    # Prepend the CLS token; its final state is used for classification.
    cls_token = self.cls_token.expand(b, 1, k)
    emb = torch.cat((cls_token, token_emb), dim=1)

    # Positions on the input's device so the model also runs on GPU.
    positions = torch.arange(t + 1, device=x.device)
    emb = emb + self.pos_emb(positions).unsqueeze(0).expand(b, t + 1, k)

    out = self.blocks(emb)

    out = self.ff(out[:, 0])

    return F.log_softmax(out, dim=1)