In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` independently for each of
    ``n_heads`` heads, concatenates the per-head results and projects them
    back to ``d_model`` features.

    Args:
        d_k: Feature size of each head's query/key/value vectors
            (the value size ``d_v`` is taken equal to ``d_k``).
        d_model: Feature size of the inputs and of the returned tensor.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_k, d_model, n_heads):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.n_heads = n_heads

        # One fused projection per role; the heads are split apart in forward().
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.key = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        # Final projection from the concatenated heads back to d_model.
        self.out = nn.Linear(d_k * n_heads, d_model)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Queries, shape ``(N, T_q, d_model)``.
            k: Keys, shape ``(N, T_k, d_model)``.
            v: Values, shape ``(N, T_k, d_model)``.
            mask: Optional ``(N, T_k)`` tensor; key positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, T_q, d_model)``.
        """
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

        q = self.query(q)  # N x T_q x (h*d_k)
        k = self.key(k)    # N x T_k x (h*d_k)
        v = self.value(v)  # N x T_k x (h*d_v), with d_v == d_k

        N = q.shape[0]    # batch size
        T_q = q.shape[1]  # query sequence length
        T_k = k.shape[1]  # key/value sequence length (may differ from T_q)

        # Change shapes (required for the batched matrix multiplications):
        # view:      (N, T, h*d_k) -> (N, T, h, d_k)
        # transpose: (N, T, h, d_k) -> (N, h, T, d_k)
        q = q.view(N, T_q, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_k, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_k, self.n_heads, self.d_k).transpose(1, 2)

        # (N, h, T_q, d_k) @ (N, h, d_k, T_k) -> (N, h, T_q, T_k)
        # This must be a matrix product; an element-wise `*` would not
        # produce pairwise query/key scores.
        attention_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        if mask is not None:
            # The mask has shape (N, T_k), so add two inner dimensions to
            # broadcast over heads and query positions. Zeros are replaced
            # with -inf so that softmax assigns them zero weight.
            # NOTE: a row whose keys are all masked softmaxes to NaN.
            attention_scores = attention_scores.masked_fill(
                mask[:, None, None, :] == 0, float('-inf')
            )
        attention_weights = F.softmax(attention_scores, dim=-1)

        # (N, h, T_q, T_k) @ (N, h, T_k, d_k) -> (N, h, T_q, d_k)
        A = attention_weights @ v

        # Reshape (N, h, T_q, d_k) -> (N, T_q, h, d_k) -> (N, T_q, h*d_k)
        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T_q, self.n_heads * self.d_k)

        # Project the concatenated heads back to d_model.
        return self.out(A)
        
        
                