<a href="https://colab.research.google.com/github/Bustion11/NN-projects/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import torchtext.vocab as TV
from torch.nn import functional as F
import timeit
import torchtext

In [None]:
# Small token -> count mapping; presumably meant as input for ModVocab / TV.vocab.
iterable = dict(I=1, ate=2, an=3, apple=4)

In [None]:
# Vocabulary that grows on the fly: unseen tokens are appended before lookup.
class ModVocab(TV.Vocab):
  """torchtext ``Vocab`` whose ``forward`` appends unknown tokens instead of failing.

  Args:
    iterable: ordered mapping of token -> frequency, handed to ``torchtext.vocab.vocab``.
  """
  def __init__(self, iterable):
    vocab = TV.vocab(iterable)
    super().__init__(vocab)

  def forward(self, tokens):
    """Return the indices of ``tokens``, first appending any token not yet known."""
    for token in tokens:
      # ``in`` dispatches to the inherited Vocab.__contains__.
      if token not in self:
        self.append_token(token)

    return super().forward(tokens)


class Encoder(nn.Module):
  """One encoder layer: multi-head self-attention followed by a point-wise FFN.

  Each sub-layer is combined as ``norm(x + sublayer(x))``. The same mask tensor
  is forwarded to both the attention block and the FFN.

  NOTE(review): a single LayerNorm instance is shared by both sub-layers; the
  reference Transformer uses one LayerNorm per sub-layer.
  """
  def __init__(self, emb_dim, n_heads = 1, k_dim = None, v_dim = None):
    super().__init__()
    self.mha = MultiHeadAttention(emb_dim, n_heads, k_dim, v_dim)
    self.norm = nn.LayerNorm(emb_dim)
    self.lin = PointWiseFFN(emb_dim, emb_dim * 2)

  def forward(self, x, causal_mask = None):
    """Return the layer output and the attention weights of the self-attention."""
    attended, attn_weights = self.mha(x, x, x, causal_mask)
    after_attn = self.norm(x + attended)
    ffn_out = self.lin(after_attn, causal_mask)
    return self.norm(after_attn + ffn_out), attn_weights


class Decoder(nn.Module):
  """One decoder layer: masked self-attention, cross-attention, a second
  self-attention and a point-wise FFN.

  Each sub-layer is combined as ``norm(sublayer_out) + residual``; a single
  LayerNorm instance is shared by all four sub-layers.
  NOTE(review): the reference Transformer uses ``norm(x + sublayer(x))`` with a
  separate LayerNorm per sub-layer — left unchanged here as a design choice.
  """
  def __init__(self, emb_dim, n_heads = 1, k_dim = None, v_dim = None):
    super().__init__()
    self.masked_mha = MultiHeadAttention(emb_dim, n_heads, k_dim, v_dim)
    self.middle_mha = MultiHeadAttention(emb_dim, n_heads, k_dim, v_dim)
    self.final_mha = MultiHeadAttention(emb_dim, n_heads, k_dim, v_dim)

    self.project = PointWiseFFN(emb_dim, emb_dim*2)
    self.norm = nn.LayerNorm(emb_dim)

  def forward(self, encoder_output, decoder_output, causal_mask = None, attn_mask = None):
    """Run the layer.

    Args:
      encoder_output: output of the encoder stack (keys/values of cross-attention).
      decoder_output: current decoder hidden states (queries everywhere).
      causal_mask: multiplicative mask over decoder (query) positions, e.g. the
        target padding mask.
      attn_mask: additive look-ahead mask applied to both self-attentions.
    """
    # Masked self-attention over the target sequence.
    out, _ = self.masked_mha(decoder_output, decoder_output, decoder_output, causal_mask = causal_mask, attn_mask = attn_mask)
    decoder_output = self.norm(out) + decoder_output

    # Cross-attention: queries come from the decoder, keys and values from the encoder.
    out, _ = self.middle_mha(decoder_output, encoder_output, encoder_output, causal_mask = causal_mask)
    decoder_output = self.norm(out) + decoder_output

    # Second self-attention: it must reuse the look-ahead mask, otherwise every
    # position can attend to later target tokens and the earlier masking is undone.
    out, _ = self.final_mha(decoder_output, decoder_output, decoder_output, causal_mask = causal_mask, attn_mask = attn_mask)
    decoder_output = self.norm(out) + decoder_output

    out = self.project(decoder_output)
    decoder_output = self.norm(out) + decoder_output

    return decoder_output


class Attention(nn.Module):
  """Scaled dot-product attention on batch-first (head-folded) tensors.

  ``q`` is ``(B, Nq, d)``, ``k`` is ``(B, Nk, d)`` and ``v`` is ``(B, Nk, dv)``,
  as required by ``torch.bmm``. Computes
  ``w = softmax(q @ k^T / sqrt(d) + attn_mask) * causal_mask`` and returns
  ``(w @ v, w)``.
  """
  def __init__(self, **kwargs):
    super().__init__()

  def forward(self, q, k, v, attn_mask = None, causal_mask = None):
    # Scale by sqrt(d_k), the feature width of each query/key vector. The old
    # code used k.shape[0] after the transpose, i.e. the batch*heads size.
    scores = torch.bmm(q, k.transpose(1, 2)) / q.shape[-1] ** 0.5

    if attn_mask is not None:
      # Additive mask, e.g. -inf above the diagonal for look-ahead masking.
      scores = scores + attn_mask

    attn_mtx = F.softmax(scores, -1)

    if causal_mask is not None:
      # Multiplicative mask applied after the softmax (zeros masked entries/rows).
      attn_mtx = attn_mtx * causal_mask

    return torch.bmm(attn_mtx, v), attn_mtx
    

class MultiHeadAttention(nn.Module):
  """Multi-head attention with independent, bias-free per-head projections.

  Args:
    emb_dim: width of the query input and of the output.
    n_heads: number of heads; each head projects to ``emb_dim // n_heads``.
    k_dim: width of the key input (defaults to ``emb_dim``).
    v_dim: width of the value input (defaults to ``emb_dim``).
  """
  def __init__(self, emb_dim, n_heads = 1, k_dim = None, v_dim = None):
    super().__init__()

    if k_dim is None:
      k_dim = emb_dim

    if v_dim is None:
      v_dim = emb_dim

    self.n_heads = n_heads
    head_dim = emb_dim//n_heads

    # One freshly initialised Linear per head. (Appending the same module
    # n_heads times would make every head share one weight matrix.)
    self.q_proj = self._build_stack(emb_dim, head_dim)
    self.k_proj = self._build_stack(k_dim, head_dim)
    self.v_proj = self._build_stack(v_dim, head_dim)

    self.attn = Attention()

    self.o_proj = nn.Linear(head_dim*n_heads, emb_dim, False)

  def _build_stack(self, in_dim: int, out_dim: int) -> nn.ModuleList:
    """Return ``n_heads`` independent bias-free ``in_dim -> out_dim`` projections."""
    return nn.ModuleList(nn.Linear(in_dim, out_dim, False) for _ in range(self.n_heads))

  def _project(self, x: torch.Tensor, layer_stack: nn.ModuleList) -> torch.Tensor:
    """Apply every head's projection to ``x`` and fold the heads into the batch dim.

    ``x`` is ``(B, N, D)``; the result is ``(B * n_heads, N, head_dim)`` where row
    ``b * n_heads + h`` holds head ``h`` of batch element ``b``.
    """
    # Deliberately not called ``_apply``: that name belongs to nn.Module and is
    # what ``.to()``, ``.cuda()``, ``.double()`` use to walk the module tree.
    B, N, _ = x.shape
    heads = torch.stack([layer(x) for layer in layer_stack], dim=1)  # (B, H, N, head_dim)
    return heads.reshape(B * self.n_heads, N, heads.shape[-1])

  def forward(self, q, k, v, causal_mask: torch.Tensor = None, attn_mask = None):
    """Return the projected attention output ``(B, Nq, emb_dim)`` and the
    head-folded attention weights ``(B * n_heads, Nq, Nk)``."""
    if causal_mask is not None:
      # Repeat per head to match the (b * n_heads + h) batch layout of _project.
      causal_mask = causal_mask.repeat_interleave(self.n_heads, 0)

    q = self._project(q, self.q_proj)
    k = self._project(k, self.k_proj)
    v = self._project(v, self.v_proj)

    output, attn_mtx = self.attn(q, k, v, attn_mask, causal_mask)

    # Undo the head folding: (B*H, N, hd) -> (B, H, N, hd) -> (B, N, H*hd).
    BH, N, E = output.shape
    B = BH // self.n_heads
    output = output.reshape(B, self.n_heads, N, E).transpose(1, 2).reshape(B, N, self.n_heads * E)

    output = self.o_proj(output)
    return output, attn_mtx


class PointWiseFFN(nn.Module):
  """Position-wise feed-forward network: ``Linear -> ReLU -> Linear`` per token.

  Args:
    in_emb_dim: input width.
    inner_dim: hidden width.
    out_emb_dim: output width (defaults to ``in_emb_dim``).
  """
  def __init__(self, in_emb_dim, inner_dim, out_emb_dim = None):
    super().__init__()
    if out_emb_dim is None:
      out_emb_dim = in_emb_dim
    self.l1 = nn.Linear(in_emb_dim, inner_dim)
    self.l2 = nn.Linear(inner_dim, out_emb_dim)
    self.act = nn.ReLU()

  def forward(self, x, causal_mask = None):
    """Apply the FFN to a 3-D ``x``.

    If ``causal_mask`` is given, it is multiplied into both the hidden
    activations and the output (e.g. a broadcastable ``(B, N, 1)`` padding mask).

    Raises:
      ValueError: if ``x`` is not 3-D.
    """
    if x.dim() != 3:
      raise ValueError(f"expected a 3-D input, got shape {tuple(x.shape)}")

    hidden = self.act(self.l1(x))
    if causal_mask is None:
      # Skip the multiply entirely: the old placeholder ``torch.tensor([1])``
      # lives on the CPU and raises a device mismatch for GPU inputs.
      return self.l2(hidden)
    return self.l2(hidden * causal_mask) * causal_mask
    

class PositionalEmbedding(nn.Module):
  """Token embedding plus a fixed sinusoidal position table, with padding zeroed.

  Args:
    num_words: vocabulary size of the ``nn.Embedding``.
    sequence_len: maximum sequence length covered by the position table.
    emb_dim: embedding width.
    pad_idx: token id treated as padding (its embedding row is kept at zero).
  """
  def __init__(self, num_words, sequence_len, emb_dim, pad_idx = 0):
    super().__init__()
    self.embed = nn.Embedding(num_words, emb_dim, pad_idx)
    self.seq_len = sequence_len
    self.emb_dim = emb_dim
    self.pad_idx = pad_idx

    self.register_buffer('PE', self.calculate_pos_emb(sequence_len, emb_dim))

  def calculate_pos_emb(self, sentence_len, emb_dim):
    """Return the ``(1, sentence_len, emb_dim)`` sinusoidal table.

    ``pe[p, 2i] = sin(p / 10000^(2i/emb_dim))`` and ``pe[p, 2i+1] = cos(...)``
    for the same angle; odd ``emb_dim`` is supported.
    """
    power = torch.arange(0, emb_dim, 2).div(emb_dim)         # 2i / emb_dim
    pos = torch.arange(0, sentence_len, 1).unsqueeze(1)       # (sentence_len, 1)
    angles = pos.div(torch.tensor(10000).pow(power))          # (sentence_len, ceil(emb_dim / 2))
    pe = torch.zeros(sentence_len, emb_dim)
    pe[:, 0::2] = torch.sin(angles)
    # With odd emb_dim there is one fewer cosine column than sine columns.
    pe[:, 1::2] = torch.cos(angles[:, :emb_dim // 2])
    return pe.unsqueeze(0)

  def forward(self, x):
    """Embed token ids ``x`` of shape ``(B, N)`` with ``N <= sequence_len``.

    Returns:
      The ``(B, N, emb_dim)`` embeddings, zero at padded positions, and the
      ``(B, N, 1)`` padding mask (1 for real tokens, 0 for ``pad_idx``) in
      ``x``'s dtype.

    Raises:
      ValueError: if ``N`` exceeds ``sequence_len``.
    """
    seq_len = x.shape[1]
    if seq_len > self.seq_len:
      raise ValueError(f"sequence length {seq_len} exceeds the positional table size {self.seq_len}")

    # 1 for real tokens, 0 for padding — independent of the numeric pad_idx
    # (the old torch.where(x == pad, x, 1) produced pad_idx, not 0, at pads).
    padding_mask = (x != self.pad_idx).unsqueeze(-1).to(x.dtype)

    x = self.embed(x)
    # Slice the table so sequences shorter than sequence_len are supported.
    x = x + self.get_buffer('PE')[:, :seq_len]
    return x*padding_mask, padding_mask

In [None]:
class Transformer(nn.Module):
  """Encoder-decoder stack that returns the final decoder hidden states.

  Args:
    num_words: vocabulary size shared by the source and target embeddings.
    seq_len: maximum sequence length (positional tables and look-ahead mask).
    emb_dim: model width.
    pad_idx: token id treated as padding by both embeddings.
    shared_n_heads: number of attention heads in every encoder/decoder layer.
    num_encoder: number of stacked encoder layers.
    num_decoder: number of stacked decoder layers.

  There is no projection to vocabulary logits; ``forward`` returns the decoder
  stack's output tensor directly.
  """
  def __init__(self, num_words, seq_len, emb_dim, pad_idx = 0, shared_n_heads = 1, num_encoder = 1, num_decoder = 1):
    super().__init__()
    self.encoder_stack = nn.ModuleList()
    self.decoder_stack = nn.ModuleList()

    for _ in range(num_encoder):
      self.encoder_stack.append(Encoder(emb_dim, shared_n_heads))

    for _ in range(num_decoder):
      self.decoder_stack.append(Decoder(emb_dim, shared_n_heads))

    # pad_idx must reach the embeddings, otherwise a non-zero pad id is ignored.
    self.embeder_encoder = PositionalEmbedding(num_words, seq_len, emb_dim, pad_idx)
    self.embeder_decoder = PositionalEmbedding(num_words, seq_len, emb_dim, pad_idx)

    # Additive look-ahead mask: -inf strictly above the diagonal, 0 elsewhere.
    attn_mask = torch.full((1, seq_len, seq_len), -torch.inf).triu(1)
    self.register_buffer('attn_mask', attn_mask)

  def forward(self, input, output):
    """Encode ``input`` token ids and decode ``output`` token ids against it."""
    # The embeddings return padding masks; the layers take them as ``causal_mask``.
    encoder_output, padding_mask_input = self.embeder_encoder(input)
    decoder_output, padding_mask_output = self.embeder_decoder(output)

    for encoder_layer in self.encoder_stack:
      encoder_output, _ = encoder_layer(encoder_output, padding_mask_input)

    # Crop the look-ahead mask to the actual target length.
    n = decoder_output.shape[1]
    attn_mask = self.get_buffer('attn_mask')[:, :n, :n]
    for decoder_layer in self.decoder_stack:
      decoder_output = decoder_layer(encoder_output, decoder_output, padding_mask_output, attn_mask)

    return decoder_output


In [None]:
# Hyper-parameters for the smoke test below.
NUM_WORDS = 100
NUM_CLASS = 10  # NOTE(review): the Transformer signature above has no num_class parameter
SEQ_LEN = 64
EMB_DIM = 128
N_HEADS = 8

# Random batch of 3 full-length sequences; ids start at 1, so none equals the
# default pad index 0.
test_input = torch.randint(1, NUM_WORDS, (3, SEQ_LEN))

In [None]:
# Show the sampled token ids fed to the model below.
print(test_input)

tensor([[98, 78, 61, 85, 31, 32, 29, 69, 99,  6, 26, 56, 48, 43, 54, 26, 16, 48,
         85, 37, 73, 33, 52,  6,  6, 79, 87, 35, 33, 29, 56, 79, 78, 68, 59, 48,
         57, 94, 62, 29, 49, 68, 74, 12, 22, 32, 55,  2, 99, 90, 95, 73, 41, 11,
         23, 89, 57, 54, 48, 77, 44, 48, 54,  1],
        [27, 88, 34, 31, 63, 49, 13, 92, 11, 36, 74,  7, 37, 31, 29, 31, 81,  5,
         95, 12,  2, 99, 37, 97,  1, 53, 29, 60, 37, 16, 63, 74, 77, 46, 72, 48,
         63, 37, 80, 70, 67,  2, 34, 91, 64, 14, 64, 81, 52, 21, 54, 72, 55, 98,
         46, 31, 80, 46, 14, 18, 50, 80, 19, 94],
        [43, 27, 49, 77, 80, 88, 16, 72,  8, 65, 55, 36, 42, 90, 35, 26, 68, 51,
         93, 69, 79, 32, 80, 77, 73, 46, 32, 82, 21, 82, 57, 60, 99, 68, 96, 53,
         49, 86, 58, 16, 32, 30, 78, 41, 17, 48, 90,  5, 66, 79, 77, 97,  3, 42,
         49, 34, 52, 40, 79, 63, 97, 46,  9, 46]])


In [None]:
# Transformer(num_words, seq_len, emb_dim, pad_idx, shared_n_heads, num_encoder, num_decoder):
# NUM_CLASS is not a parameter, and passing it made the call take 8 positional
# arguments, which raises TypeError. Keywords keep the mapping explicit.
test_model = Transformer(NUM_WORDS, SEQ_LEN, EMB_DIM, pad_idx=0, shared_n_heads=N_HEADS, num_encoder=4, num_decoder=4)

In [None]:
# Forward pass using the same batch as both the encoder input and the decoder input.
test_output = test_model(test_input, test_input)

In [None]:
softmax = nn.Softmax(dim=-1)  # normalise over the last axis

In [None]:
# Softmax over the last dimension, scaled to percentages and rounded for display.
# NOTE(review): the recorded output below is (3, 10), but the Transformer above
# returns a (batch, seq, emb_dim) tensor, so that output looks stale — re-run to refresh.
(softmax(test_output)*100).round()

tensor([[ 1.,  0.,  2., 56.,  0.,  5.,  2., 17., 16.,  1.],
        [ 0.,  0.,  1., 67.,  0.,  1.,  1., 11., 17.,  1.],
        [ 0.,  0.,  1., 73.,  0.,  7.,  1.,  5., 12.,  0.]],
       grad_fn=<RoundBackward0>)