In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product attention with learned projections.

    Queries, keys and values are linearly projected (no bias) to ``d_k``,
    ``d_k`` and ``d_v`` features, combined as
    ``softmax(Q_proj @ K_proj^T / sqrt(d_k)) @ V_proj`` and projected back to
    ``d_model`` features by ``out_proj``.

    Args:
        d_model: Feature size of the inputs and of the returned output.
        d_k: Feature size of the query/key projections; also sets the
            ``1 / sqrt(d_k)`` score scaling.
        d_v: Feature size of the value projection.
    """

    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.d_k= d_k
        self.d_v= d_v
        self.Q_W= nn.Linear(d_model, d_k, bias= False)
        self.K_W= nn.Linear(d_model, d_k, bias= False)
        self.V_W= nn.Linear(d_model, d_v, bias= False)
        self.out_proj= nn.Linear(d_v, d_model, bias= False)

    def forward(self, Q, K, V, mask= None):
        """Attend from the queries in ``Q`` over the keys/values in ``K``/``V``.

        Args:
            Q: (batch_size, seq_len_q, d_model) queries.
            K: (batch_size, seq_len_k, d_model) keys.
            V: (batch_size, seq_len_k, d_model) values. Must have the same
                sequence length as ``K``, because the attention weights are
                (seq_len_q, seq_len_k).
            mask: Optional tensor broadcastable to the score shape
                (batch_size, seq_len_q, seq_len_k). Positions where
                ``mask == 0`` (or ``False``) are excluded from attention.
                This is the inverse of ``nn.MultiheadAttention``'s boolean
                masks, where ``True`` marks a position to ignore.

        Returns:
            (batch_size, seq_len_q, d_model) attended output.
        """
        # Project to d_k and d_v
        Q_proj= self.Q_W(Q) # (batch_size, seq_len_q, d_k)
        K_proj= self.K_W(K) # (batch_size, seq_len_k, d_k)
        V_proj= self.V_W(V) # (batch_size, seq_len_k, d_v)

        # Compute attention scores
        scores= torch.matmul(Q_proj, K_proj.transpose(-2, -1)) # Dot product -> Shape: (batch_size, seq_len_q, seq_len_k)
        scores= scores / (self.d_k ** 0.5) # Scaling

        if mask is not None:
            # Fill with the most negative finite value of the score dtype instead
            # of a hard-coded -1e9. -1e9 overflows float16, so masked_fill raises
            # under half precision / autocast. A finite value (unlike -inf) keeps
            # fully-masked rows from producing NaNs after the softmax.
            scores= scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights= F.softmax(scores, dim= -1) # (batch_size, seq_len_q, seq_len_k)
        output= torch.matmul(attn_weights, V_proj) # (batch_size, seq_len_q, d_v)
        output= self.out_proj(output) # (batch_size, seq_len_q, d_model)
        return output

In [3]:
# Setting random seed for reproducibility
torch.manual_seed(42)

# Parameters
batch_size = 2
seq_len = 3
d_model = 4
# nn.MultiheadAttention with num_heads=1 projects Q, K and V to d_model features
# (its packed in_proj_weight is (3 * d_model, d_model)), so d_k and d_v must equal
# d_model for its weights to be copied into SingleHeadAttention.
d_k = d_model
d_v = d_model

# Toy input
X = torch.randn(batch_size, seq_len, d_model)

# Initialize models (keep this order: it fixes which random draws each module gets)
my_attn = SingleHeadAttention(d_model, d_k, d_v)
torch_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, bias=False, batch_first=True)

# Copy weights from the torch implementation into ours before any forward pass.
# in_proj_weight stacks the Q, K and V projection matrices along dim 0.
combined_weights = torch_attn.in_proj_weight.detach()  # Shape: [3 * d_model, d_model]
Q_W_torch, K_W_torch, V_W_torch = combined_weights.chunk(3)  # Each is [d_model, d_model]

# Copy in place under no_grad instead of reassigning `.data`. copy_ raises on
# incompatible shapes, whereas `.data = ...` silently swaps in a wrongly shaped tensor.
with torch.no_grad():
    my_attn.Q_W.weight.copy_(Q_W_torch)
    my_attn.K_W.weight.copy_(K_W_torch)
    my_attn.V_W.weight.copy_(V_W_torch)
    my_attn.out_proj.weight.copy_(torch_attn.out_proj.weight)  # Copy output projection weight

# Compute outputs after copying weights
my_output = my_attn(X, X, X)
torch_output, _ = torch_attn(X, X, X)

print(f"My Implementation's Output Shape: {my_output.shape}")
print(f"Torch Implementation's Output Shape: {torch_output.shape}")
print(f"Output Difference: {torch.abs(my_output - torch_output).max().item()}")

# Turn the printed comparison into a check that fails loudly if the two disagree
# (default float32 tolerances: rtol=1.3e-6, atol=1e-5).
torch.testing.assert_close(my_output, torch_output)

My Implementation's Output Shape: torch.Size([2, 3, 4])
Torch Implementation's Output Shape: torch.Size([2, 3, 4])
Output Difference: 5.960464477539063e-08


In [4]:
# Set up for gradcheck: its finite-difference estimates need float64 to be reliable.
torch.manual_seed(0)  # seed here so this cell is reproducible when re-run on its own
input = torch.randn(batch_size, seq_len, d_model, dtype=torch.double, requires_grad=True)
my_attn_double = SingleHeadAttention(d_model, d_k, d_v)
my_attn_double.to(torch.double)

# Test gradcheck on self-attention (the same tensor feeds Q, K and V)
def get_attn_out(input):
    return my_attn_double(input, input, input)

test = torch.autograd.gradcheck(get_attn_out, input, eps=1e-6, atol=1e-4)
print("Gradcheck passes:", test)

# Self-attention sums the Q/K/V gradients into a single tensor and uses square
# score matrices, which can hide a mis-transposed matmul. So also check
# cross-attention with separate inputs, seq_len_q != seq_len_k, and the masked path.
q_in = torch.randn(batch_size, seq_len + 1, d_model, dtype=torch.double, requires_grad=True)
k_in = torch.randn(batch_size, seq_len, d_model, dtype=torch.double, requires_grad=True)
v_in = torch.randn(batch_size, seq_len, d_model, dtype=torch.double, requires_grad=True)
key_mask = torch.ones(seq_len, dtype=torch.bool)
key_mask[-1] = False  # every query ignores the last key position

assert torch.autograd.gradcheck(my_attn_double, (q_in, k_in, v_in), eps=1e-6, atol=1e-4)
assert torch.autograd.gradcheck(
    lambda q, k, v: my_attn_double(q, k, v, mask=key_mask),
    (q_in, k_in, v_in),
    eps=1e-6,
    atol=1e-4,
)

Gradcheck passes: True
