In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np


In [11]:
# 1. Setup Hyperparameters
B, S, N, H = 2, 6, 4, 2 ## Batch, Seq, Features (d_model), Heads

# 2. Define Input X
X = torch.randn(B, S, N, dtype=torch.float32).round(decimals=2)
X

tensor([[[ 2.3100, -0.8800, -0.1000, -0.8100],
         [ 1.2000, -0.1900, -2.4200, -0.6300],
         [ 0.3900, -0.5800, -0.3900,  0.8000],
         [-1.4700,  0.7700,  0.2000,  0.7900],
         [-1.9700,  0.4300,  1.1700,  0.4100],
         [-1.0100,  0.0800, -0.1700, -0.5400]],

        [[ 0.6400,  1.5200,  0.0700, -1.0900],
         [-0.0500,  2.1200,  1.2900,  0.3000],
         [-0.8300,  1.2300,  0.6700,  0.8800],
         [-0.6000, -4.2100, -1.8400,  0.9100],
         [ 1.4800,  1.6800, -0.7000,  0.0600],
         [-0.4900, -1.7600, -1.0200, -1.0100]]])

In [12]:
# 3. Define Weight Matrices (from Cell 24)
# Note: We transpose (.T) because PyTorch weights are (out_features, in_features)
W_q = torch.tensor([[ 0.1, 0.27, -0.35, -0.11], [-1.86, -0.25, -1.25, -0.36], [-0.71, 1.71, -1.17, 1.05], [ 0.08, 1.76, 0.53, -0.1 ]]).T
W_k = torch.tensor([[ 0.7, -1.52, 2.67, 0.06], [-1.46, 2.02, -0.78, -0.83], [ 1.39, -1.07, -0.52, 1.11], [ 1.73, 1.05, 1.47, 1.17 ]]).T
W_v = torch.tensor([[-1.26, 0.25, 0.03, 1.92], [ 0.58, -0.25, -0.46, 1.69], [ 0.08, -1.85, -0.71, 0.3 ], [ 1.64, -0.9, 1.04, 0.72 ]]).T

In [15]:
# 4. Initialize the PyTorch Layer
# batch_first=True makes it (B, S, N)
mha_layer = nn.MultiheadAttention(embed_dim=N,num_heads=H,batch_first=True, bias=False)

In [17]:
# 5. Manually Load Weights
# PyTorch combines Q,K,V weights into one "in_proj_weight" tensor

with torch.no_grad():
    mha_layer.in_proj_weight.copy_(torch.cat([W_q, W_k, W_v]))
    # set output projection to identity to match notebook logic
    mha_layer.out_proj.weight.copy_(torch.eye(N))


# 6. Run the Attention forward pass
# returns (output, weights) as tuple
attention_res = mha_layer(X, X, X)


# --- THE PRINTS YOU WANTED ---

print(f"Attention Res: {attention_res}")

print("=== ATTENTION_RES[0] (Concatenated Output) ===")
print(f"Shape: {attention_res[0].shape}") # (2, 6, 4)
print(attention_res[0]) 

print("\n=== ATTENTION_RES[1] (Weight Matrix) ===")
print(f"Shape: {attention_res[1].shape}") # (2, 6, 6)
print(attention_res[1])

# This is how your residual line works:
x_updated = X + attention_res[0]
print("\n=== X + ATTENTION_RES[0] (Residual Connection) ===")
print(x_updated)

Attention Res: (tensor([[[-4.1946,  1.5485,  0.1938, -0.4855],
         [-4.7574,  1.7115,  1.0208,  0.9742],
         [ 2.9631, -1.8571,  0.0343,  1.9806],
         [ 1.1111,  0.7130, -0.1266, -1.4250],
         [ 3.2098, -1.3483, -0.5693, -2.2644],
         [-4.7355,  1.7198,  0.4110,  0.4292]],

        [[-1.8944,  1.7517, -0.8658,  2.2282],
         [-1.6384,  0.6529, -0.5742,  2.0830],
         [-0.8443,  0.1915, -0.7233,  1.8812],
         [-0.3406,  3.4875,  4.0624, -7.8203],
         [-1.7342,  1.0133, -1.4312,  2.2687],
         [-0.3406,  3.4875,  2.9986, -4.4684]]], grad_fn=<TransposeBackward0>), tensor([[[5.2302e-01, 7.2577e-02, 1.3457e-01, 8.9100e-02, 9.4946e-02,
          8.5791e-02],
         [5.5581e-01, 4.4412e-01, 6.1800e-05, 5.5316e-09, 3.4390e-11,
          6.4870e-08],
         [3.9126e-01, 9.4588e-02, 1.0231e-01, 2.1267e-01, 1.9389e-01,
          5.2844e-03],
         [2.1251e-02, 1.7069e-01, 3.6781e-02, 3.4365e-01, 2.0563e-01,
          2.2199e-01],
         [3.9