In [1]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding dimension is split evenly into ``heads`` chunks of size
  ``head_dim``. Each chunk is linearly projected, attended over
  independently, and the per-head results are concatenated and passed
  through a final linear layer.
  """

  def __init__(self, embed_size, heads):
    """Build the projection layers.

    Args:
      embed_size: Size of the last dimension of the inputs to ``forward``.
      heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
      AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """
    super().__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    assert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"

    # The per-head projections act on the last (head_dim) axis, so the same
    # weights are shared by every head.
    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

  def forward(self, values, keys, query, mask):
    """Compute attention of ``query`` over ``keys`` / ``values``.

    Args:
      values: Tensor of shape ``(N, value_len, embed_size)``.
      keys: Tensor of shape ``(N, key_len, embed_size)``; ``key_len`` must
        equal ``value_len``.
      query: Tensor of shape ``(N, query_len, embed_size)``.
      mask: ``None``, or a tensor broadcastable to
        ``(N, heads, query_len, key_len)``. Positions where the mask equals
        0 are excluded from attention.

    Returns:
      Tensor of shape ``(N, query_len, embed_size)``.
    """
    N = query.shape[0]
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Split the embedding axis into (heads, head_dim).
    values = values.reshape(N, value_len, self.heads, self.head_dim)
    keys = keys.reshape(N, key_len, self.heads, self.head_dim)
    queries = query.reshape(N, query_len, self.heads, self.head_dim)

    # Apply the learned per-head projections.
    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # n = batch, q = query position, k = key position, h = head,
    # d = head dimension. Result: (N, heads, query_len, key_len).
    energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

    # Masked positions get a very large negative score so that softmax
    # assigns them (numerically) zero weight.
    if mask is not None:
      energy = energy.masked_fill(mask == 0, -1e28)

    # Normalise over the key axis (dim=3) so that, for each query position,
    # the weights over all keys sum to 1.
    # NOTE: the scale is sqrt(embed_size) as the original code intended; the
    # formulation in "Attention Is All You Need" uses sqrt(head_dim).
    attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)

    # Weighted sum of values per head, then concatenate the heads again:
    # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim)
    #   -> (N, query_len, heads, head_dim) -> (N, query_len, embed_size)
    out = torch.einsum("nhql,nlhd->nqhd", attention, values)
    out = out.reshape(N, query_len, self.heads * self.head_dim)

    return self.fc_out(out)