In [None]:
from torch.utils.data import Dataset
import torch.nn.functional as F
from collections import Counter
from os.path import exists
import torch.optim as optim
import torch.nn as nn
import numpy as np
import random
import torch
import math
import re

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A table of shape ``(1, max_seq_len, d_model)`` is computed once at
    construction and registered as the buffer ``pe``; it is not a parameter,
    so the optimizer never updates it.

    For position ``pos`` and even dimension index ``i``::

        pe[pos, i]     = sin(pos / 10000 ** (i / d_model))
        pe[pos, i + 1] = cos(pos / 10000 ** (i / d_model))

    so each sin/cos pair shares one frequency. An odd ``d_model`` is
    supported: the last column then only holds the sine component.

    Args:
        d_model: Size of the embedding dimension.
        max_seq_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_len = 80):
        super().__init__()
        self.d_model = d_model

        # Column vector of positions so it broadcasts against the frequencies.
        positions = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)
        # 1 / 10000 ** (i / d_model), evaluated through exp/log.
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Leading dimension of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence length. Only its
                size is read; its values are not used.

        Returns:
            A slice of the ``pe`` buffer of shape
            ``(1, min(x.size(1), max_seq_len), d_model)``.
        """
        return self.pe[:, :x.size(1)]


In [None]:
def attention(q, k, v):
  """Scaled dot-product attention.

  Computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` where ``d_k`` is the size of
  the last dimension of ``q``. The softmax runs over the last dimension of
  the score matrix.

  TODO: no masking or dropout option is implemented yet.

  Args:
    q: Query tensor; its last dimension is the per-head feature size.
    k: Key tensor; its last two dimensions are transposed before the product.
    v: Value tensor, multiplied by the attention weights.

  Returns:
    The attention-weighted combination of ``v``.
  """
  d_k = q.size(-1)
  # Plain (not in-place) division keeps the raw product tensor untouched.
  scores = q.matmul(k.transpose(-2, -1)) / math.sqrt(d_k)
  weights = F.softmax(scores, dim=-1)
  return weights.matmul(v)


In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention: queries, keys and values all come from ``x``.

  Args:
    n_heads: Number of attention heads.
    out_dim: Feature size of both the input and the output. Must be
      divisible by ``n_heads``.

  Raises:
    ValueError: If ``out_dim`` is not divisible by ``n_heads``.
  """

  def __init__(self, n_heads, out_dim):
    super().__init__()
    if out_dim % n_heads != 0:
      # Without this check the reshape in split_heads either fails with an
      # opaque error or silently mixes sequence and feature elements.
      raise ValueError(
        f"out_dim ({out_dim}) must be divisible by n_heads ({n_heads})"
      )
    # A single fused projection produces q, k and v side by side (hence
    # out_dim * 3); forward() splits it into three equal parts.
    self.linear = nn.Linear(out_dim, out_dim*3)
    self.n_heads = n_heads
    self.out_dim = out_dim
    self.out_dim_per_head = out_dim // n_heads
    self.out = nn.Linear(out_dim,out_dim)

  def split_heads(self, t):
    """Reshape ``t`` to ``(t.shape[0], -1, n_heads, out_dim_per_head)``."""
    return t.reshape(t.shape[0], -1, self.n_heads, self.out_dim_per_head)

  def forward(self, x):
    """Apply self-attention to ``x``.

    Args:
      x: Tensor of shape ``(batch, seq_len, out_dim)``.

    Returns:
      Tensor of shape ``(batch, seq_len, out_dim)``.
    """
    qkv = self.linear(x)

    # Split the fused projection into q, k, v, then move the head axis in
    # front of the sequence axis: (batch, n_heads, seq_len, out_dim_per_head).
    # https://hyugen-ai.medium.com/transformers-in-pytorch-from-scratch-for-nlp-beginners-ff3b3d922ef7
    q, k, v = [
      self.split_heads(part).transpose(1, 2)
      for part in qkv.chunk(3, dim=-1)
    ]

    context = attention(q, k, v)
    # Undo the head split: back to (batch, seq_len, out_dim).
    merged = context.transpose(1, 2).contiguous().view(context.shape[0], -1, self.out_dim)

    return self.out(merged)



In [None]:

class FeedForward(nn.Module):
  """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

  Args:
    inp_dim: Feature size of the input and of the output.
    inner_dim: Width of the hidden layer.
    dropout: Dropout probability applied to the hidden activations.
  """

  def __init__(self, inp_dim, inner_dim, dropout = 0.1):
    super().__init__()
    # Expand to the hidden width, then project back to the input width.
    self.linear1 = nn.Linear(inp_dim, inner_dim)
    self.linear2 = nn.Linear(inner_dim, inp_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Run ``x`` through the MLP and return a tensor with ``inp_dim`` features."""
    hidden = self.linear1(x)
    activated = F.relu(hidden)
    regularised = self.dropout(activated)
    return self.linear2(regularised)




In [None]:
class EncoderLayer(nn.Module):
  """One pre-norm encoder layer: self-attention followed by a feed-forward net.

  Each sub-layer is applied as ``x + dropout(sublayer(layer_norm(x)))``.

  Args:
    n_heads: Number of attention heads.
    inner_transformer_size: Feature size of the layer's input and output.
    inner_ff_size: Hidden width of the feed-forward sub-layer.
    dropout: Dropout probability for both residual branches and for the
      feed-forward sub-layer.
  """

  def __init__(self, n_heads, inner_transformer_size, inner_ff_size, dropout = 0.1):
    super().__init__()
    # MultiHeadAttention takes (n_heads, out_dim) only; dropout on its output
    # is applied through self.dropout1 in forward().
    self.mha = MultiHeadAttention(n_heads, inner_transformer_size)
    self.ff = FeedForward(inner_transformer_size, inner_ff_size, dropout)
    self.norm1 = nn.LayerNorm(inner_transformer_size)
    self.norm2 = nn.LayerNorm(inner_transformer_size)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

  def forward(self, x):
    """Apply the attention and feed-forward sub-layers with residuals.

    Args:
      x: Input tensor whose last dimension is ``inner_transformer_size``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    attn_in = self.norm1(x)
    x = x + self.dropout1(self.mha(attn_in))
    # The feed-forward branch must see the post-attention activations,
    # normalised by norm2 (not the norm1 output from the attention branch).
    ff_in = self.norm2(x)
    x = x + self.dropout2(self.ff(ff_in))
    return x

In [None]:
class Transformer(nn.Module):
  """Encoder-only Transformer with a linear head over the vocabulary.

  Pipeline: token embedding + positional embedding -> ``n_code`` encoder
  layers -> LayerNorm -> bias-free linear projection to ``n_embeddings``.

  Args:
    n_code: Number of stacked encoder layers (0 is allowed).
    n_heads: Number of attention heads per encoder layer.
    embed_size: Embedding / model feature size.
    inner_ff_size: Hidden width of each encoder's feed-forward sub-layer.
    n_embeddings: Vocabulary size (rows of the embedding table and width of
      the output projection).
    seq_len: Number of positions the positional embedding precomputes.
    dropout: Dropout probability passed to every encoder layer.
  """

  def __init__(self, n_code, n_heads, embed_size,
               inner_ff_size, n_embeddings, seq_len, dropout = 0.1):
    super().__init__()

    self.embeddings = nn.Embedding(n_embeddings, embed_size)
    self.pe = PositionalEmbedding(embed_size, seq_len)

    # Built once, outside any loop, so the attributes exist even when
    # n_code == 0.
    self.encoders = nn.ModuleList(
      [EncoderLayer(n_heads, embed_size, inner_ff_size, dropout)
       for _ in range(n_code)]
    )

    # language model
    self.norm = nn.LayerNorm(embed_size)
    self.linear = nn.Linear(embed_size, n_embeddings, bias=False)

  def forward(self, x):
    """Map token indices to per-position scores over the vocabulary.

    Args:
      x: Integer tensor of token indices, expected shape
        ``(batch, seq_len)``.

    Returns:
      Tensor of shape ``(batch, seq_len, n_embeddings)``.
    """
    x = self.embeddings(x)
    x = x + self.pe(x)
    for encoder in self.encoders:
      x = encoder(x)
    x = self.norm(x)
    x = self.linear(x)
    return x