In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

In [2]:
class InputEmbeddings(nn.Module):
  """Token-embedding lookup whose output is scaled by sqrt(d_model).

  Args:
    vocab_size: number of rows in the embedding table.
    d_model: width of each embedding vector (default 512).

  forward(x) looks up the integer indices in ``x`` and multiplies every
  resulting vector by ``sqrt(d_model)``, as in the original Transformer.
  """

  def __init__(self, vocab_size, d_model=512):
    super().__init__()
    self.vocab_size, self.d_model = vocab_size, d_model
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

  def forward(self, x):
    scale = math.sqrt(self.d_model)
    looked_up = self.embedding(x)
    return looked_up * scale

In [16]:
class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal positional encodings to a batch of embeddings.

  The table follows the Transformer recipe:
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  Args:
    seq_size: maximum sequence length the table covers.
    d_model: embedding width (default 512); odd widths are supported.

  forward(x) expects x shaped (batch, seq_len, d_model) with
  seq_len <= seq_size, and returns x + PE[:seq_len] broadcast over the batch.

  Raises:
    ValueError: from forward() if seq_len exceeds seq_size.
  """

  def __init__(self, seq_size, d_model=512):
    # nn.Module.__init__ must run first: it creates the hook/buffer
    # dictionaries that __call__ and register_buffer rely on.
    super().__init__()
    self.seq_size = seq_size
    self.d_model = d_model

    pos_enc = torch.zeros(seq_size, d_model)
    pos = torch.arange(0, seq_size).view(-1, 1).float()
    # 1 / 10000^(2i / d_model) for every even column index 2i.
    common = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

    pos_enc[:, 0::2] = torch.sin(pos * common)
    # With an odd d_model there is one fewer odd column than even column.
    pos_enc[:, 1::2] = torch.cos(pos * common[: d_model // 2])

    # Store as (1, seq_size, d_model) so it broadcasts over the batch axis.
    # The original `.view(...)` call discarded its result, leaving a 2-D
    # table that forward() then sliced along the wrong axis. A
    # non-persistent buffer follows .to(device)/.half() with the module,
    # never receives gradients, and stays out of state_dict because it is
    # recomputed deterministically here.
    self.register_buffer("pos_enc", pos_enc.unsqueeze(0), persistent=False)

  def forward(self, x):
    seq_len = x.size(1)
    if seq_len > self.seq_size:
      raise ValueError(
        f"sequence length {seq_len} exceeds PositionalEncoding seq_size {self.seq_size}"
      )
    # Slice along the sequence axis (dim 1); the encoding is fixed, not learned.
    return x + self.pos_enc[:, :seq_len]


In [17]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width (default 512).
    d_ff: hidden width of the inner layer (default 2048).
    drop: dropout probability applied after the ReLU (default 0.1).

  forward(x) applies the network over the last dimension of ``x``; the
  output has the same shape as the input.
  """

  def __init__(self, d_model=512, d_ff=2048, drop=0.1):
    super().__init__()
    self.d_model = d_model
    self.d_ff = d_ff

    self.net = nn.Sequential(
      nn.Linear(self.d_model, self.d_ff),
      nn.ReLU(),  # was `nn.Relu`, which does not exist in torch.nn
      nn.Dropout(drop),
      nn.Linear(self.d_ff, self.d_model),
    )

  def forward(self, x):
    return self.net(x)

In [None]:
class MHA(nn.Module):
  def __init__(self, d_model=512, heads=8):
    """Set up the projection layers for multi-head attention.

    Args:
      d_model: total model width (default 512).
      heads: number of attention heads (default 8); must be a positive
        divisor of d_model so that heads * heads_dim == d_model.

    Raises:
      ValueError: if heads is not a positive divisor of d_model.
    """
    super().__init__()
    # Guard first: heads == 0 would divide by zero, and a non-divisor would
    # silently truncate heads_dim and break any later head split.
    if heads <= 0 or d_model % heads != 0:
      raise ValueError(f"heads ({heads}) must be a positive divisor of d_model ({d_model})")
    self.d_model = d_model
    self.heads = heads
    self.heads_dim = d_model // heads

    # NOTE(review): these projections map heads_dim -> heads_dim, so they act
    # on a single head's slice and would be shared by every head. Standard
    # multi-head attention projects d_model -> d_model and then splits into
    # heads (plus an output projection); confirm against forward(), which is
    # not visible here.
    self.Qw = nn.Linear(self.heads_dim, self.heads_dim)
    self.Kw = nn.Linear(self.heads_dim, self.heads_dim)
    self.Vw = nn.Linear(self.heads_dim, self.heads_dim)

  def forward(self, q, k, v):
