In [1]:
def get_device():
  """Return the torch device to compute on: CUDA when available, otherwise CPU."""
  device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
  return torch.device(device_name)

In [2]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F

In [3]:
import torch
import torch.nn.functional as F
import math

def scaled_dot_prod(q, k, v, mask=None):
    """Scaled dot-product attention.

    Scores are ``q @ k^T`` (transposing the last two axes of ``k``) divided by
    the square root of the size of ``q``'s last axis. When ``mask`` is given,
    every position where ``mask == 0`` is set to ``-inf`` before the softmax
    over the last axis, so it receives zero weight.

    Returns:
        (values, attention): the attention-weighted combination of ``v`` and
        the softmax weights themselves.
    """
    head_width = q.size(-1)
    scores = (q @ k.transpose(-1, -2)) / math.sqrt(head_width)

    if mask is not None:
        blocked = mask == 0
        scores = scores.masked_fill(blocked, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


In [4]:
class PosEnc(nn.Module):
  """Sinusoidal positional-encoding table.

  Calling the module (no arguments) returns a ``(max_seq_len, d_model)`` float
  tensor in which, for each position, sine and cosine values are interleaved
  along the last axis (sin at even columns, cos at odd columns).
  """

  def __init__(self, d_model, max_seq_len):
    super().__init__()
    self.max_seq_len = max_seq_len
    self.d_model = d_model

  def forward(self):
    # One frequency per (sin, cos) column pair: 10000^(i / d_model), i = 0, 2, 4, ...
    even_i = torch.arange(0, self.d_model, 2).float()
    denom = torch.pow(10000, even_i / self.d_model)
    # Column vector of positions so that pos / denom broadcasts to (max_seq_len, n_pairs).
    pos = torch.arange(self.max_seq_len).reshape(self.max_seq_len, 1)
    even_PE = torch.sin(pos / denom)
    odd_PE = torch.cos(pos / denom)
    # Stack on a new trailing axis, then flatten it to interleave sin/cos columns.
    stacked = torch.stack([even_PE, odd_PE], dim=2)
    PE = torch.flatten(stacked, start_dim=1, end_dim=2)
    # For odd d_model the interleaving yields d_model + 1 columns; trim the extra one.
    return PE[:, :self.d_model]

In [5]:
class SentEmb(nn.Module):
  """Turns a batch of raw sentences into embedded, position-encoded tensors.

  Args:
    max_seq_len: length every tokenized sentence is padded to.
    d_model: embedding width.
    lang_to_index: mapping from token to integer index; its length is used as
      the vocabulary size.
    START_TOKEN, END_TOKEN, PADDING_TOKEN: the special tokens; each must be a
      key of ``lang_to_index``.
  """

  def __init__(self, max_seq_len, d_model, lang_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
    super().__init__()
    self.vocab_size = len(lang_to_index)
    self.max_seq_len = max_seq_len
    self.embedding = nn.Embedding(self.vocab_size, d_model)
    self.lang_to_index = lang_to_index
    self.positional_encoder = PosEnc(d_model, max_seq_len)
    self.dropout = nn.Dropout(p=0.1)
    self.START_TOKEN = START_TOKEN
    self.END_TOKEN = END_TOKEN
    self.PADDING_TOKEN = PADDING_TOKEN

  def batch_tknz(self, batch, start_token, end_token):
    """Convert each sentence in ``batch`` to a padded index tensor and stack them.

    ``start_token`` / ``end_token`` are flags: when truthy the START / END
    index is prepended / appended before padding. The result is moved to the
    device returned by ``get_device()``.
    """
    def tknz(sent, start_token, end_token):
      # list(sent) iterates the sentence element by element (characters for a str).
      sent_word_indicies = [self.lang_to_index[token] for token in list(sent)]
      if start_token:
        sent_word_indicies.insert(0, self.lang_to_index[self.START_TOKEN])
      if end_token:
        sent_word_indicies.append(self.lang_to_index[self.END_TOKEN])
      # Pad up to max_seq_len; sequences already longer than that are left as-is
      # (no truncation happens here).
      padding_index = self.lang_to_index[self.PADDING_TOKEN]
      for _ in range(len(sent_word_indicies), self.max_seq_len):
        sent_word_indicies.append(padding_index)
      return torch.tensor(sent_word_indicies)

    tokenized = torch.stack([tknz(sent, start_token, end_token) for sent in batch])
    return tokenized.to(get_device())

  def forward(self, x, start_token, end_token):
    """Tokenize, embed, add positional encodings and apply dropout."""
    x = self.batch_tknz(x, start_token, end_token)
    x = self.embedding(x)
    pos = self.positional_encoder().to(get_device())
    x = self.dropout(x + pos)

    return x


In [6]:
class multheadatt(nn.Module):
  """Multi-head self-attention.

  A single linear layer produces queries, keys and values for all heads; the
  per-head attention outputs are concatenated and passed through a final
  linear layer of width ``d_model``.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.qkv_layer = nn.Linear(d_model, 3 * d_model)
    self.linear_layer = nn.Linear(d_model, d_model)

  def forward(self, x, mask=None):
    """Apply self-attention to ``x`` of shape (batch, seq_len, d_model).

    ``mask`` is forwarded unchanged to ``scaled_dot_prod``. Returns a tensor
    with the same shape as ``x``.
    """
    batch_size, seq_len, d_model = x.size()
    qkv = self.qkv_layer(x)
    qkv = qkv.reshape(batch_size, seq_len, self.num_heads, 3 * self.head_dim)
    # Put the head axis before the sequence axis so attention is computed
    # across sequence positions independently for every head.
    qkv = qkv.permute(0, 2, 1, 3)
    q, k, v = qkv.chunk(3, dim=-1)
    values, attention = scaled_dot_prod(q, k, v, mask)
    # Back to (batch, seq_len, heads, head_dim), then merge the heads so the
    # last axis matches linear_layer's expected input width.
    values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)
    out = self.linear_layer(values)

    return out

In [7]:
class LayerNorm(nn.Module):
  """Layer normalization over the trailing ``len(parameters_shape)`` axes.

  Args:
    parameters_shape: shape (a sequence) of the learnable scale ``gamma`` and
      shift ``beta``; its length decides how many trailing axes are normalized.
    eps: constant added to the variance before the square root.
  """

  def __init__(self, parameters_shape, eps = 1e-5):
    super().__init__()
    self.parameters_shape = parameters_shape
    # Misspelled attribute name kept as an alias for any code that reads it.
    self.paramenters_shape = parameters_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameters_shape))
    self.beta = nn.Parameter(torch.zeros(parameters_shape))

  def forward(self, inputs):
    """Normalize ``inputs`` to zero mean / unit variance, then scale and shift."""
    # Negative indices of the trailing axes, e.g. [-1] for a 1-D parameters_shape.
    dims = [-(i + 1) for i in range(len(self.parameters_shape))]
    mean = inputs.mean(dim=dims, keepdim=True)
    # keepdim=True so the statistics broadcast against inputs below.
    var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
    std = (var + self.eps).sqrt()
    y = (inputs - mean) / std
    out = self.gamma * y + self.beta
    return out

In [8]:
class PosFF(nn.Module):
  """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width.
    hidden: width of the intermediate layer.
    drop_prob: dropout probability applied after the ReLU.
  """

  def __init__(self, d_model, hidden, drop_prob = 0.1):
    super(PosFF, self).__init__()
    self.linear1 = nn.Linear(d_model, hidden)
    # Second projection maps the hidden width back to d_model.
    self.linear2 = nn.Linear(hidden, d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p = drop_prob)

  def forward(self, x):
    """Apply the feed-forward block along the last axis of ``x``."""
    x = self.linear1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear2(x)

    return x


In [9]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and feed-forward sub-layers.

  Each sub-layer is followed by dropout, a residual addition and layer
  normalization (normalization is applied after the addition).
  """

  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
    super(EncoderLayer, self).__init__()
    self.attention = multheadatt(d_model = d_model, num_heads=num_heads)
    self.norm1 = LayerNorm(parameters_shape=[d_model])
    self.dropout1 = nn.Dropout(p = drop_prob)
    self.ffn = PosFF(d_model=d_model, hidden = ffn_hidden, drop_prob=drop_prob)
    self.norm2 = LayerNorm(parameters_shape=[d_model])
    self.dropout2 = nn.Dropout(p = drop_prob)

  def forward(self, x, self_attention_mask):
    """Run ``x`` through the attention and feed-forward sub-layers."""
    # x is only rebound below, never modified in place, so keeping a plain
    # reference for the residual branch is sufficient.
    residual_x = x
    x = self.attention(x, mask = self_attention_mask)
    x = self.dropout1(x)
    x = self.norm1(x + residual_x)

    residual_x = x
    x = self.ffn(x)
    x = self.dropout2(x)
    x = self.norm2(x + residual_x)
    return x


In [10]:
class SeqEnc(nn.Sequential):
  """Sequential container whose layers each take ``(x, self_attention_mask)``.

  The same mask is handed to every layer; only ``x`` is threaded through.
  """

  def forward(self, *inputs):
    """Accepts either ``seq(x, mask)`` or the packed form ``seq((x, mask))``."""
    if len(inputs) == 1:
      # Packed call style: a single (x, mask) pair.
      inputs = inputs[0]
    x, self_attention_mask = inputs
    for module in self._modules.values():
      x = module(x, self_attention_mask)
    return x

In [10]:
class Encoder(nn.Module):
  """Encoder stack: sentence embedding followed by ``num_layers`` encoder layers.

  Args:
    d_model: model width.
    ffn_head: hidden width passed to each layer's feed-forward block.
    num_heads: number of attention heads per layer.
    drop_prob: dropout probability used inside each layer.
    num_layers: number of stacked ``EncoderLayer`` blocks.
    max_seq_len, lang_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN:
      forwarded to ``SentEmb``.
  """

  def __init__(self, d_model, ffn_head, num_heads,drop_prob,num_layers,max_seq_len, lang_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN):
    super().__init__()
    self.sentence_embedding = SentEmb(max_seq_len,d_model, lang_to_index,START_TOKEN, END_TOKEN, PADDING_TOKEN)
    self.layers = SeqEnc(*[EncoderLayer(d_model, ffn_head, num_heads, drop_prob)
                                      for _ in range(num_layers)])

  # forward must be defined at class level (not nested inside __init__) for
  # nn.Module to dispatch calls to it.
  def forward(self,x,self_attention_mask, start_token, end_token):
    """Embed the raw batch ``x`` and run it through the encoder layers."""
    x = self.sentence_embedding(x, start_token, end_token)
    # The layer container takes the activations and mask packed as one pair.
    x = self.layers((x, self_attention_mask))

    return x




In [11]:
class multheadcrossatt(nn.Module):
  """Multi-head cross-attention: keys/values come from ``x``, queries from ``y``."""

  def __init__(self, d_model, num_heads):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.kv_layer = nn.Linear(d_model, 2 * d_model)
    self.q_layer = nn.Linear(d_model, d_model)
    self.linear_layer = nn.Linear(d_model, d_model)

  def forward(self, x, y, mask=None):
    """Attend from ``y`` (queries) over ``x`` (keys and values).

    ``x`` is (batch, kv_len, d_model) and ``y`` is (batch, q_len, d_model);
    the two sequence lengths may differ. ``mask`` is forwarded unchanged to
    ``scaled_dot_prod``. Returns a tensor of shape (batch, q_len, d_model).
    """
    batch_size, kv_len, d_model = x.size()
    # The output has one row per query position, so track y's length separately.
    q_len = y.size(1)
    kv = self.kv_layer(x)
    q = self.q_layer(y)

    kv = kv.reshape(batch_size, kv_len, self.num_heads, 2 * self.head_dim)
    q = q.reshape(batch_size, q_len, self.num_heads, self.head_dim)

    # Move the head axis ahead of the sequence axis for per-head attention.
    kv = kv.permute(0, 2, 1, 3)
    q = q.permute(0, 2, 1, 3)
    k, v = kv.chunk(2, dim = -1)

    values, attention = scaled_dot_prod(q, k, v, mask)
    # Merge the heads back into a single d_model-wide axis.
    values = values.permute(0, 2, 1, 3).reshape(batch_size, q_len, d_model)

    out = self.linear_layer(values)

    return out

In [12]:
class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention and feed-forward.

  Each of the three sub-layers is followed by dropout, a residual addition
  and layer normalization.
  """

  def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
    super(DecoderLayer, self).__init__()
    self.self_attention = multheadatt(d_model = d_model, num_heads = num_heads)
    self.layer_norm1 = LayerNorm(parameters_shape=[d_model])
    self.dropout1 = nn.Dropout(p = drop_prob)

    self.enc_dec_att = multheadcrossatt(d_model = d_model, num_heads = num_heads)
    self.layer_norm2 = LayerNorm(parameters_shape=[d_model])
    self.dropout2 = nn.Dropout(p = drop_prob)

    self.ffn = PosFF(d_model = d_model, hidden = ffn_hidden, drop_prob = drop_prob)
    self.layer_norm3 = LayerNorm(parameters_shape=[d_model])
    self.dropout3 = nn.Dropout(p = drop_prob)

  def forward(self, x, y, self_attention_mask, cross_attention_mask):
    """Update ``y`` using self-attention on ``y`` and cross-attention over ``x``.

    ``x`` is passed as the key/value source of the cross-attention and is not
    modified; the updated ``y`` is returned.
    """
    # y is only rebound (never modified in place), so plain references are
    # enough for the residual branches.
    _y = y
    y = self.self_attention(y, mask = self_attention_mask)
    y = self.dropout1(y)
    y = self.layer_norm1(y + _y)

    _y = y
    y = self.enc_dec_att(x, y, mask = cross_attention_mask)
    y = self.dropout2(y)
    y = self.layer_norm2(y + _y)

    _y = y
    y = self.ffn(y)
    y = self.dropout3(y)
    y = self.layer_norm3(y + _y)

    return y


In [13]:
class seqdec(nn.Sequential):
  """Sequential container for layers called as ``layer(x, y, self_mask, cross_mask)``.

  ``x`` and both masks are handed unchanged to every layer; only ``y`` is
  threaded from one layer to the next.
  """

  def forward(self, *inputs):
    x, y, self_attention_mask, cross_attention_mask = inputs
    # Iterating the container yields its registered layers in insertion order.
    for layer in self:
      y = layer(x, y, self_attention_mask, cross_attention_mask)
    return y


In [15]:
class decoder(nn.Module):
  """Decoder stack: sentence embedding of the target batch followed by
  ``num_layers`` decoder layers.

  ``head_dim`` is part of the signature but is not used by this class.
  """

  def __init__(self,d_model,ffn_hidden,num_heads,head_dim,drop_prob, num_layers, max_seq_len,lang_to_index, START_TOKEN,END_TOKEN,PADDING_TOKEN):
    super().__init__()
    self.sentence_embedding = SentEmb(max_seq_len, d_model, lang_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    blocks = []
    for _ in range(num_layers):
      blocks.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
    self.layers = seqdec(*blocks)

  def forward(self, x,y,self_attention_mask,cross_attention_mask,start_token, end_token):
    """Embed the raw target batch ``y`` and decode it against ``x``."""
    embedded = self.sentence_embedding(y, start_token, end_token)
    return self.layers(x, embedded, self_attention_mask, cross_attention_mask)


In [14]:
class Transformer(nn.Module):
  """Encoder-decoder model with a final linear projection to ``fr_vocab_size``.

  The encoder is built with ``eng_2_index`` and the decoder with
  ``fr_2_index``. ``eng_vocab_size`` is accepted but not used by this class.
  """

  def __init__(self,d_model,ffn_hidden, num_heads, drop_prob,num_layers, max_seq_len,fr_vocab_size,eng_vocab_size, eng_2_index, fr_2_index,START_TOKEN, END_TOKEN, PADDING_TOKEN):
    super().__init__()
    self.encoder = Encoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers, max_seq_len,eng_2_index,START_TOKEN,END_TOKEN,PADDING_TOKEN)
    # decoder's signature has a head_dim positional slot between num_heads and
    # drop_prob; omitting it shifts every later argument by one.
    head_dim = d_model // num_heads
    self.decoder = decoder(d_model, ffn_hidden, num_heads, head_dim, drop_prob, num_layers, max_seq_len,fr_2_index,START_TOKEN,END_TOKEN,PADDING_TOKEN)
    self.linear = nn.Linear(d_model, fr_vocab_size)
    self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  def forward(self,x,y,enc_self_att_mask=None, dec_self_att_mask = None, dec_cross_att_mask = None, enc_start_token = False, enc_end_token = False, dec_start_token = False, dec_end_token = False):
    """Encode ``x``, decode ``y`` against the encoding, and project the result.

    The ``*_start_token`` / ``*_end_token`` flags are forwarded to the encoder
    and decoder; the three masks are forwarded to the corresponding attention
    calls. Returns the output of the final linear layer.
    """
    x = self.encoder(x, enc_self_att_mask, start_token = enc_start_token, end_token = enc_end_token)
    out = self.decoder(x,y,dec_self_att_mask,dec_cross_att_mask, start_token = dec_start_token, end_token = dec_end_token)
    out = self.linear(out)

    return out