In [None]:
# Transformer implementation from scratch with PyTorch
# https://medium.com/towards-data-science/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb

# Multi-Head Attention


In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The inputs are linearly projected into query, key and value spaces, split
  into ``num_heads`` heads of size ``d_k = d_model // num_heads``, attended
  per head, recombined, and passed through a final output projection.

  Args:
    d_model: Size of the last dimension of the inputs and of the output.
      Must be divisible by ``num_heads``.
    num_heads: Number of attention heads.
  """

  def __init__(self, d_model, num_heads):
    # nn.Module.__init__ has to run before any sub-module is assigned to self.
    super(MultiHeadAttention, self).__init__()

    assert d_model % num_heads == 0, "d_model must be divisble by num heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads  # dimension of each head

    # Learnable projections (each one is d_model -> d_model).
    self.W_q = nn.Linear(d_model, d_model)  # query projection
    self.W_k = nn.Linear(d_model, d_model)  # key projection
    self.W_v = nn.Linear(d_model, d_model)  # value projection
    self.W_o = nn.Linear(d_model, d_model)  # output projection applied to the recombined heads

  def scaled_dot_product_attention(self, Q, K, V, mask=None):
    """Compute ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Args:
      Q, K, V: Tensors whose last two dimensions are (sequence, d_k); in
        ``forward`` they are the outputs of ``split_heads``.
      mask: Optional tensor broadcastable to the attention-score shape.
        Positions where ``mask == 0`` are excluded from attention.

    Returns:
      The attention-weighted sum of ``V``.
    """
    # Q x K^T / sqrt(d_k)
    attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

    if mask is not None:
      # A very negative score becomes ~0 probability after the softmax.
      attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

    attn_probs = torch.softmax(attn_scores, dim=-1)
    return torch.matmul(attn_probs, V)

  def split_heads(self, x):
    """Reshape (batch, seq, d_model) into (batch, num_heads, seq, d_k)."""
    batch_size, seq_length, _ = x.size()
    # view() splits the last dimension into (num_heads, d_k);
    # transpose(1, 2) then moves the head axis in front of the sequence axis.
    return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

  def combine_heads(self, x):
    """Inverse of ``split_heads``: (batch, num_heads, seq, d_k) -> (batch, seq, d_model)."""
    batch_size, _, seq_length, _ = x.size()
    # transpose() can return a non-contiguous tensor, and view() requires
    # contiguous memory, hence the explicit contiguous() in between.
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

  def forward(self, Q, K, V, mask=None):
    """Apply multi-head attention.

    Args:
      Q, K, V: 3-D tensors (batch, seq, d_model) used as the query, key and
        value sources respectively.
      mask: Optional mask forwarded to ``scaled_dot_product_attention``.

    Returns:
      Tensor of shape (batch, query seq, d_model).
    """
    # Each input goes through its OWN projection before being split into heads.
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))

    attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
    return self.W_o(self.combine_heads(attn_output))


# (FNN) Position-wise Feed-Forward Networks

In [None]:
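# The FNN (position-wise feed-forward network) is what follows attention.
# It is also where "facts" are stored in the model.
# A very good explanation of the FNN, the number of model parameters, and the scaling law:
# https://www.youtube.com/watch?v=9-Jl0dxWQs8

# NOTE: the class defined below is the positional encoding, not the feed-forward network.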

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    A table of shape (1, max_seq_length, d_model) is precomputed once: even
    feature columns hold ``sin(position * freq)`` and odd feature columns hold
    ``cos(position * freq)``, with ``freq = 10000 ** (-2i / d_model)``.
    The table is registered as a buffer, so it is not a trainable parameter.

    Args:
        d_model: Size of the feature (last) dimension. Both even and odd
            values are supported.
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per even column index: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one fewer odd column than even column,
        # so drop the last frequency for the cosine half. For an even d_model
        # the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.
        """
        return x + self.pe[:, :x.size(1)]

# Encoder

In [None]:
# The encoder is essentially attention + FNN.
# It is called the encoder because it encodes information into the model.