<a href="https://colab.research.google.com/github/HWAN722/Deep-learning-models/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def scale_dot_product(q, k, v, scale=None):
  """Compute scaled dot-product attention.

  Args:
    q: Query tensor; its last dimension must match the last dimension of ``k``.
    k: Key tensor; the size of its last dimension (``d_k``) sets the scaling.
    v: Value tensor that the attention weights are applied to.
    scale: Optional attention mask (the name is historical). Positions where
      it equals 0 are excluded from attention. It must be broadcastable to
      the score tensor ``q @ k^T``.

  Returns:
    The attention-weighted combination of ``v``.
  """
  d_k = k.size()[-1]
  scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)
  # Compare against None explicitly: truth-testing a tensor with more than
  # one element raises a RuntimeError, so `if scale:` broke every real mask.
  if scale is not None:
    scores = scores.masked_fill(scale == 0, float('-inf'))
  p_attn = F.softmax(scores, dim=-1)
  return torch.matmul(p_attn, v)

class AttentionHead(nn.Module):
  """A single attention head.

  Each of the query, key and value inputs goes through its own linear
  projection before being handed to ``scale_dot_product``.
  """

  def __init__(self, embed_dim, dim_q, dim_k):
    """Create the three projections.

    Args:
      embed_dim: Feature size of the incoming tensors.
      dim_q: Output size of the query projection.
      dim_k: Output size of both the key and the value projections.
    """
    super().__init__()
    self.linear_Q = nn.Linear(embed_dim, dim_q)
    self.linear_K = nn.Linear(embed_dim, dim_k)
    self.linear_V = nn.Linear(embed_dim, dim_k)

  def forward(self, q, k, v, mask):
    """Project the inputs and return the attention output for this head."""
    projected_q = self.linear_Q(q)
    projected_k = self.linear_K(k)
    projected_v = self.linear_V(v)
    return scale_dot_product(projected_q, projected_k, projected_v, mask)


class MultiHeadAttention(nn.Module):
  """Run several ``AttentionHead`` modules in parallel and merge their outputs."""

  def __init__(self, embed_dim, num_heads, dim_q, dim_k):
    """Build the heads and the output projection.

    Args:
      embed_dim: Feature size of the inputs and of the final output.
      num_heads: Number of independent attention heads.
      dim_q: Query projection size used by every head.
      dim_k: Key/value projection size used by every head.
    """
    super().__init__()
    self.heads = nn.ModuleList([AttentionHead(embed_dim, dim_q, dim_k) for _ in range(num_heads)])
    # Every head emits dim_k features (its value projection size), so the
    # concatenation is num_heads * dim_k wide. That only equals embed_dim
    # when embed_dim is an exact multiple of num_heads, so size the
    # projection from the real concatenated width.
    self.linear = nn.Linear(num_heads * dim_k, embed_dim)

  def forward(self, q, k, v, mask=None):
    """Concatenate every head's output on the last axis and project it to ``embed_dim``."""
    head_outputs = [head(q, k, v, mask) for head in self.heads]
    return self.linear(torch.cat(head_outputs, dim=-1))


def feed_forward(embed_dim, hidden_dim) -> nn.Sequential:
  """Build the position-wise feed-forward network.

  The network is Linear(embed_dim -> hidden_dim), ReLU, then
  Linear(hidden_dim -> embed_dim), wrapped in an ``nn.Sequential``.
  """
  expand = nn.Linear(embed_dim, hidden_dim)
  activation = nn.ReLU()
  contract = nn.Linear(hidden_dim, embed_dim)
  return nn.Sequential(expand, activation, contract)


class Residual(nn.Module):
  """Wrap a sublayer with dropout, a skip connection and layer normalisation.

  The result is ``LayerNorm(x + Dropout(sublayer(x, ...)))``.
  """

  def __init__(self, sublayer:nn.Module, dimension, dropout=0.1):
    """Store the wrapped sublayer and create the norm and dropout layers.

    Args:
      sublayer: Callable module whose output is added back onto its first input.
      dimension: Size passed to ``nn.LayerNorm``.
      dropout: Dropout probability applied to the sublayer output.
    """
    super().__init__()
    self.sublayer = sublayer
    self.norm = nn.LayerNorm(dimension)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, *args, **kwargs):
    """Apply the sublayer to ``x`` (plus any extra arguments), then add and normalise."""
    branch = self.sublayer(x, *args, **kwargs)
    branch = self.dropout(branch)
    return self.norm(x + branch)

def positional_encoding(max_len, d_model):
  """Build the sinusoidal positional-encoding table.

  Args:
    max_len: Number of positions (rows) to generate.
    d_model: Number of feature columns per position.

  Returns:
    A float32 tensor of shape ``(1, max_len, d_model)``. Even columns hold
    the sine of the position angle and odd columns hold the cosine.
  """
  positions = np.arange(max_len)[:, np.newaxis]
  columns = np.arange(d_model)[np.newaxis, :]
  # Columns 2j and 2j + 1 share the same exponent, 2j / d_model.
  inverse_freq = 1 / np.power(10000, 2 * (columns // 2) / d_model)
  table = positions * inverse_freq
  table[:, 0::2] = np.sin(table[:, 0::2])
  table[:, 1::2] = np.cos(table[:, 1::2])
  # Add a leading axis of size 1 so the table broadcasts over a batch.
  return torch.tensor(table[np.newaxis, ...], dtype=torch.float32)

class TransformerEncoderLayer(nn.Module):
  """One encoder layer: self-attention then feed-forward, each wrapped in ``Residual``."""

  def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
    """Create the attention and feed-forward sublayers.

    Args:
      embed_dim: Feature size of the layer's input and output.
      num_heads: Number of attention heads.
      ffn_dim: Hidden size of the feed-forward network.
      dropout: Dropout probability used inside both ``Residual`` wrappers.
    """
    super().__init__()
    # Per-head width, floored at 1 so it is never zero.
    head_dim = max(embed_dim // num_heads, 1)
    attention = MultiHeadAttention(embed_dim, num_heads, head_dim, head_dim)
    self.multi_head_attention = Residual(attention, embed_dim, dropout)
    ffn = feed_forward(embed_dim, ffn_dim)
    self.feed_forward = Residual(ffn, embed_dim, dropout)

  def forward(self, src, mask):
    """Self-attend over ``src`` (as query, key and value), then apply the feed-forward block."""
    attended = self.multi_head_attention(src, src, src, mask)
    return self.feed_forward(attended)

class TransformerEncoder(nn.Module):
  """A stack of ``TransformerEncoderLayer`` modules, applied after adding positional encodings."""

  def __init__(self, num_layers, embed_dim, num_heads, ffn_dim, dropout=0.1):
    """Create ``num_layers`` encoder layers that share the same hyper-parameters."""
    super().__init__()
    self.layers = nn.ModuleList([TransformerEncoderLayer(embed_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)])

  def forward(self, x, mask=None):
    """Encode ``x``.

    Args:
      x: Input tensor; axis 1 is used as the sequence length and axis 2 as
        the feature dimension.
      mask: Optional attention mask forwarded unchanged to every layer.

    Returns:
      The output of the last encoder layer.
    """
    seq_len = x.size(1)
    dimension = x.size(2)
    # positional_encoding always builds its table on the CPU; move it to the
    # input's device so the addition also works for inputs on an accelerator.
    x = x + positional_encoding(seq_len, dimension).to(x.device)
    for layer in self.layers:
      x = layer(x, mask)
    return x

class TransformerDecoderLayer(nn.Module):
  """One decoder layer: self-attention, cross-attention over ``memory``, then feed-forward.

  Every sublayer is wrapped in ``Residual``. No attention mask is passed to
  either attention sublayer.
  """

  def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
    """Create the three residual sublayers.

    Args:
      embed_dim: Feature size of the layer's input and output.
      num_heads: Number of attention heads in each attention sublayer.
      ffn_dim: Hidden size of the feed-forward network.
      dropout: Dropout probability used inside each ``Residual`` wrapper.
    """
    super().__init__()
    head_dim = embed_dim // num_heads
    self_attn = MultiHeadAttention(embed_dim, num_heads, head_dim, head_dim)
    self.self_attention = Residual(self_attn, embed_dim, dropout)
    cross_attn = MultiHeadAttention(embed_dim, num_heads, head_dim, head_dim)
    self.cross_attention = Residual(cross_attn, embed_dim, dropout)
    ffn = feed_forward(embed_dim, ffn_dim)
    self.feed_forward = Residual(ffn, embed_dim, dropout)

  def forward(self, src, memory):
    """Run the three sublayers in order and return the final result."""
    hidden = self.self_attention(src, src, src)
    # Queries come from the decoder state; keys and values come from memory.
    hidden = self.cross_attention(hidden, memory, memory)
    return self.feed_forward(hidden)

class TransformerDecoder(nn.Module):
  """A stack of ``TransformerDecoderLayer`` modules followed by a linear layer and softmax."""

  def __init__(self, num_layers, embed_dim, num_heads, ffn_dim, dropout=0.1):
    """Create ``num_layers`` decoder layers and the final ``embed_dim -> embed_dim`` projection."""
    super().__init__()
    self.layers = nn.ModuleList([TransformerDecoderLayer(embed_dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)])
    self.final_layer = nn.Linear(embed_dim, embed_dim)

  def forward(self, x, memory):
    """Decode ``x`` while attending to ``memory``.

    Args:
      x: Target-side input; axis 1 is used as the sequence length and axis 2
        as the feature dimension.
      memory: Tensor passed to every layer as the cross-attention source.

    Returns:
      Softmax over the last axis of the final linear projection.
    """
    seq_len = x.size(1)
    dimension = x.size(2)
    # positional_encoding always builds its table on the CPU; move it to the
    # input's device so the addition also works for inputs on an accelerator.
    x = x + positional_encoding(seq_len, dimension).to(x.device)
    for layer in self.layers:
      x = layer(x, memory)
    return torch.softmax(self.final_layer(x), dim=-1)

class Transformer(nn.Module):
  """Encoder-decoder model built from ``TransformerEncoder`` and ``TransformerDecoder``."""

  def __init__(self, num_layers, embed_dim, num_heads, ffn_dim, dropout=0.1):
    """Create an encoder and a decoder that share the same hyper-parameters."""
    super().__init__()
    shared_config = (num_layers, embed_dim, num_heads, ffn_dim, dropout)
    self.encoder = TransformerEncoder(*shared_config)
    self.decoder = TransformerDecoder(*shared_config)

  def forward(self, src, target):
    """Encode ``src`` and decode ``target`` against the encoder output."""
    memory = self.encoder(src)
    return self.decoder(target, memory)

# Smoke test: push random source/target batches through a small Transformer.
batch_size, seq_len, embd_dim = 2, 4, 8

src = torch.rand((batch_size, seq_len, embd_dim))
tgt = torch.rand((batch_size, seq_len, embd_dim))

# Two layers, four heads, and a 512-wide feed-forward hidden layer.
transformer = Transformer(num_layers=2, embed_dim=embd_dim, num_heads=4, ffn_dim=512)
out = transformer(src, tgt)

print(out)

tensor([[[0.2240, 0.2658, 0.0978, 0.0648, 0.0655, 0.1634, 0.0594, 0.0593],
         [0.2318, 0.2820, 0.0737, 0.0478, 0.0701, 0.1623, 0.0614, 0.0711],
         [0.3643, 0.1717, 0.0644, 0.0433, 0.0843, 0.0924, 0.0957, 0.0838],
         [0.2160, 0.2720, 0.0575, 0.0559, 0.1137, 0.0989, 0.0942, 0.0918]],

        [[0.2272, 0.2034, 0.0989, 0.0723, 0.1048, 0.1242, 0.0772, 0.0920],
         [0.2709, 0.1994, 0.1149, 0.0374, 0.0469, 0.1996, 0.0635, 0.0672],
         [0.2017, 0.2876, 0.0751, 0.0522, 0.0793, 0.1472, 0.0800, 0.0769],
         [0.1946, 0.1486, 0.0787, 0.0571, 0.1171, 0.0762, 0.2135, 0.1143]]],
       grad_fn=<SoftmaxBackward0>)
