In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into queries, keys and values by three
    independent linear layers. Each query position then receives a weighted
    average of the values, where the weights are the softmax of the scaled
    query-key dot products:

        softmax(Q K^T / sqrt(d_out)) V

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the query/key/value projections, and therefore of the
            last dimension of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()  # initialise nn.Module bookkeeping before adding submodules
        self.d_in = d_in
        self.d_out = d_out
        # nn.Linear applies the affine map y = x A^T + b. Its weight and bias
        # are nn.Parameter objects, so they are included in .parameters().
        self.Q = nn.Linear(d_in, d_out)
        self.K = nn.Linear(d_in, d_out)
        self.V = nn.Linear(d_in, d_out)

    def forward(self, x):
        """Apply self-attention over the second-to-last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)``. Any number of
                leading batch dimensions is accepted, including none
                (an unbatched ``(seq_len, d_in)`` input).

        Returns:
            Tensor of shape ``(..., seq_len, d_out)``.
        """
        queries = self.Q(x)
        keys = self.K(x)
        values = self.V(x)
        # matmul broadcasts over leading dimensions, so indexing the last two
        # axes with negative indices keeps this correct for 2-D, 3-D and
        # higher-rank inputs: (..., seq, d_out) @ (..., d_out, seq).
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Scale by sqrt(d_out) before the softmax.
        scores = scores / (self.d_out ** 0.5)
        # Normalise over the key axis so the weights for each query sum to 1.
        attention = F.softmax(scores, dim=-1)
        hidden_states = torch.matmul(attention, values)
        return hidden_states


In [None]:
class MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        """Create the output projection and one Attention module per head.

        Args:
            hidden_size: Feature size of the square output projection; also
                the input size given to every head.
            num_heads: Number of Attention heads to build.
        """
        super().__init__()
        self.hidden_size, self.num_heads = hidden_size, num_heads
        # Square projection, built before the heads. Presumably applied to
        # the combined head outputs; the forward pass is not visible here,
        # so confirm against it.
        self.out = nn.Linear(hidden_size, hidden_size)
        # Floor division: every head maps hidden_size -> hidden_size // num_heads.
        head_dim = hidden_size // num_heads
        heads = [Attention(hidden_size, head_dim) for _ in range(num_heads)]
        # ModuleList registers each head as a submodule of this module.
        self.head = nn.ModuleList(heads)
    
    def foward(self, )