In [37]:
import torch
from torch import nn
import torch.functional as F
import math
import numpy as np
import pandas as pd

In [72]:
# Load a matrix from a CSV file
def load_csv_as_tensor(file_path):
    df = pd.read_csv(file_path, header=None)
    matrix = torch.tensor(df.values, dtype=torch.float32)
    return matrix

# class CustomMultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

#     def set_custom_weights(self, Wqkv, Wo, bqkv, bo):
#         d_k = Wqkv.size(1) // 3
#         self.mha.in_proj_weight.data = Wqkv
#         self.mha.in_proj_bias.data = bqkv
#         self.mha.out_proj.weight.data = Wo
#         self.mha.out_proj.bias.data = bo

#     def forward(self, x):
#         output, attn_weights = self.mha(x, x, x)
#         return output

class CustomMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, Wqkv=None, Wo=None, bqkv=None, bo=None):
        super(CustomMultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Define the projections for queries, keys, and values
        self.proj_q = torch.zeros((embed_dim, embed_dim))
        self.proj_k = torch.zeros((embed_dim, embed_dim))
        self.proj_v = torch.zeros((embed_dim, embed_dim))

        # Define the output projection
        self.proj_out = torch.zeros((embed_dim, embed_dim))

        # Scaling factor to prevent the softmax from having too large/small gradients
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim]))
        
        # Optionally set weights and biases
        if Wqkv is not None:
            # Set weights for Q, K, V projections
            self.proj_q = Wqkv[:embed_dim]
            self.proj_k = Wqkv[embed_dim:2*embed_dim]
            self.proj_v = Wqkv[2*embed_dim:3*embed_dim]
            print("weights set")
#             if bqkv is not None:
#                 # Set biases for Q, K, V projections
#                 self.proj_q.bias.data = bqkv[:embed_dim]
#                 self.proj_k.bias.data = bqkv[embed_dim:2*embed_dim]
#                 self.proj_v.bias.data = bqkv[2*embed_dim:3*embed_dim]
        if Wo is not None and bo is not None:
            # Set weights and biases for output projection
            self.proj_out = Wo
#             self.proj_out.bias.data = bo

    def forward(self, x):
        seq_length, embed_dim = x.size()

        # Project the queries, keys, and values
        q = (x @ self.proj_q).view(seq_length, self.num_heads, self.head_dim)
        print(f"q: {q}")
        k = (x @ self.proj_k).view(seq_length, self.num_heads, self.head_dim)
        print(f"k: {k}")
        v = (x @ self.proj_v).view(seq_length, self.num_heads, self.head_dim)
        print(f"v: {v}")

        # Transpose for attention computation: b x n x l x d -> b x l x n x d
#         q = q.transpose(1, 2)
#         k = k.transpose(1, 2)
#         v = v.transpose(1, 2)

        # Compute the attention scores
        scores = torch.matmul(q / self.scale, k.transpose(-2, -1))
        print(f"sqrt(dk): {self.scale}")
        print(f"q * k / sqrt(dk): {scores}")
        attn = torch.nn.functional.softmax(scores, dim=-1)
        print(f"softmax: {attn}")

        # Apply attention to the values
        context = torch.matmul(attn, v)
        print(f"softmax * v: {context}")

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(seq_length, embed_dim)
        print(f"softmax * v shape: {context.shape}")

        # Final output projection
        output = context @ self.proj_out
        return output, attn

# Parameters
embed_dim = 4
num_heads = 2

# Load weights and biases from CSV
base_dir = "/Users/zhaozihan/Desktop/PHYS 244/parallel-mha/serial_code/compare_python"

Wq = load_csv_as_tensor(f'{base_dir}/Wq.csv')
Wk = load_csv_as_tensor(f'{base_dir}/Wk.csv')
Wv = load_csv_as_tensor(f'{base_dir}/Wv.csv')
Wo = load_csv_as_tensor(f'{base_dir}/Wo.csv')
bq = load_csv_as_tensor(f'{base_dir}/bq.csv')
bk = load_csv_as_tensor(f'{base_dir}/bk.csv')
bv = load_csv_as_tensor(f'{base_dir}/bv.csv')
bo = load_csv_as_tensor(f'{base_dir}/bo.csv')

# Concatenate Wq, Wk, Wv into a single weight matrix for in_proj
Wqkv = torch.cat((Wq, Wk, Wv), dim=0)
bqkv = torch.cat((bq, bk, bv), dim=0)

# Initialize the model
mha = CustomMultiHeadAttention(embed_dim, num_heads, Wqkv=Wqkv, Wo=Wo, bqkv=bqkv, bo=bo)

# Load input
x = load_csv_as_tensor('mha_input.csv')  # Add batch dimension
print(x.shape)

# Compute output
output = mha(x)

# Print output
print("Output Tensor:\n", output[0])
print("Output Tensor:\n", output[0].shape)

weights set
torch.Size([2, 4])
q: tensor([[[ 0.2494,  1.5786],
         [-0.2285,  0.1889]],

        [[-0.0634, -0.2776],
         [ 0.0428,  1.1328]]])
k: tensor([[[-0.0716, -0.3137],
         [ 1.5609, -0.7852]],

        [[-0.0297,  0.6089],
         [ 0.0216,  1.0270]]])
v: tensor([[[-0.0751,  2.2267],
         [-0.7416, -0.0169]],

        [[-1.1777, -1.1061],
         [ 0.1794, -1.1788]]])
sqrt(dk): tensor([1.4142])
q * k / sqrt(dk): tensor([[[-0.3628, -0.6012],
         [-0.0303, -0.3571]],

        [[-0.1182, -0.2026],
         [ 0.4868,  0.8233]]])
softmax: tensor([[[0.5593, 0.4407],
         [0.5810, 0.4190]],

        [[0.5211, 0.4789],
         [0.4167, 0.5833]]])
softmax * v: tensor([[[-0.3688,  1.2380],
         [-0.3544,  1.2866]],

        [[-0.5278, -1.1409],
         [-0.3860, -1.1485]]])
softmax * v shape: torch.Size([2, 4])
Output Tensor:
 tensor([[-0.6420, -0.3552,  0.4712, -0.0707],
        [ 0.2790, -0.2336, -0.0741,  0.0035]])
Output Tensor:
 torch.Size([2, 4])

In [17]:
expected_out = load_csv_as_tensor('mha_output.csv')
expected_out

tensor([[-0.1277, -0.5129,  0.4347, -0.2219],
        [ 0.3021, -1.0609,  1.0369, -0.8575]])