In [1]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Looking in indexes: https://download.pytorch.org/whl/cpu
Collecting torch
  Downloading https://download.pytorch.org/whl/cpu/torch-2.1.2%2Bcpu-cp311-cp311-linux_x86_64.whl (184.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.9/184.9 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torchvision
  Downloading https://download.pytorch.org/whl/cpu/torchvision-0.16.2%2Bcpu-cp311-cp311-linux_x86_64.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m70.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio
  Downloading https://download.pytorch.org/whl/cpu/torchaudio-2.1.2%2Bcpu-cp311-cp311-linux_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m73.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting filelock (from torch)
  Downloading https://download.pytorch.org/whl/filelock-3.9.0-py3-none-any.whl (9.7 kB)
Collecting sympy (from torch)


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [34]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The projected queries, keys and values are split evenly across
    ``num_heads`` heads of size ``d_k = d_model // num_heads``. Each head
    attends over the sequence independently; the per-head results are then
    concatenated back to ``d_model`` features and passed through a final
    output projection.

    Args:
        d_model: Size of the feature dimension of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # The feature dimension is sliced into equal per-head chunks, so it has to divide evenly
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Per-head size of the query, key, and value vectors

        # Linear layers apply a learned linear transformation to incoming data
        # nn.Linear() automatically initializes its weights and biases
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

        Args:
            Q, K, V: Tensors whose last two dimensions are (sequence, d_k);
                ``forward`` passes them as (batch, num_heads, seq_length, d_k).
            mask: Optional tensor broadcastable to the attention-score shape.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            The attention-weighted combination of ``V``, same shape as ``Q``.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # The mask marks which key positions each query may look at (e.g. to hide
        # padding or future tokens). Masked scores are pushed to the most negative
        # value of the score dtype so softmax gives them ~0 probability. Using the
        # dtype's own minimum (instead of a hard-coded -1e9) keeps this valid for
        # float16, where -1e9 is out of range and masked_fill raises.
        if mask is not None:
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)

        attn_probs = torch.softmax(attn_scores, dim=-1)  # Normalize over the key positions
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        """Reshape (batch, seq_length, d_model) into (batch, num_heads, seq_length, d_k)."""
        batch_size, seq_length, _ = x.size()  # Unpacking also enforces a 3-D input
        # view() slices the feature dimension into num_heads chunks of size d_k.
        # The transpose then moves the head dimension in front of the sequence
        # dimension, so the last two dims are (seq_length, d_k) and the matmuls in
        # scaled_dot_product_attention run independently for every (batch, head) pair.
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq_length, d_k) -> (batch, seq_length, d_model)."""
        batch_size, _, seq_length, _ = x.size()  # Second value is num_heads, which is concatenated away
        # transpose() yields a non-contiguous tensor; contiguous() is needed before view()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Run multi-head attention.

        Args:
            Q, K, V: Tensors of shape (batch_size, seq_length, d_model).
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch_size, seq_length, d_model).
        """
        weighted_q = self.W_q(Q)
        weighted_k = self.W_k(K)
        weighted_v = self.W_v(V)

        # Shape tracing for the walkthrough below
        print("------ WEIGHTED Q SHAPE: " + str(weighted_q.shape))
        print("batch_size, seq_length, d_model; Now weighted by a linear transformation")

        # Split the projected tensors into per-head slices
        Q = self.split_heads(weighted_q)
        K = self.split_heads(weighted_k)
        V = self.split_heads(weighted_v)

        print("------- SPLIT HEAD Q SHAPE: " + str(Q.shape))
        print("batch_size, num_heads, seq_length, dimension_keys; Model dimiensions split across heads and matrix transposed")

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))  # Concatenate the heads, then apply the output projection
        return output

In [36]:
import torch

# Seed the RNG so the randomly initialized weights and inputs are reproducible on re-run
torch.manual_seed(0)

# Initialize parameters
d_model = 8  # Model's dimension
num_heads = 2  # Number of attention heads

# Ensure d_model is divisible by num_heads
assert d_model % num_heads == 0

# Create an instance of MultiHeadAttention
mha = MultiHeadAttention(d_model, num_heads)

# Prepare dummy input data (batch_size, seq_length, d_model)
batch_size = 2
seq_length = 10
Q = torch.rand(batch_size, seq_length, d_model)
K = torch.rand(batch_size, seq_length, d_model)
V = torch.rand(batch_size, seq_length, d_model)

print("------ STARTING Q SHAPE: " + str(Q.shape))
print("batch_size, seq_length, d_model")

# Forward pass through the MultiHeadAttention module.
# Call the module itself rather than .forward() so that nn.Module.__call__ runs
# (this is what dispatches any registered hooks).
# Sequence length never changes along the way: attention mixes information across
# tokens but neither adds nor removes positions.
output = mha(Q, K, V)

# Batch Size - Number of independent sequences processed together
# Num Heads - Number of attention heads computed in parallel
# Sequence Length - Number of tokens in the sequence
# Dimension Keys - Number of features each head uses per token (d_model // num_heads)

# Input and output must have the same shape: [batch_size, seq_length, d_model]
assert output.shape == (batch_size, seq_length, d_model)
print("------ FINAL OUTPUT SHAPE: " + str(output.shape))

------ STARTING Q SHAPE: torch.Size([2, 10, 8])
batch_size, seq_length, d_model
------ WEIGHTED Q SHAPE: torch.Size([2, 10, 8])
batch_size, seq_length, d_model; Now weighted by a linear transformation
------- SPLIT HEAD Q SHAPE: torch.Size([2, 2, 10, 4])
batch_size, num_heads, seq_length, dimension_keys; Model dimiensions split across heads and matrix transposed
------ FINAL OUTPUT SHAPE: torch.Size([2, 10, 8])
