In [1]:
import math
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.functional as F
from torch import nn

In [2]:
def load_csv_as_tensor(file_path, dtype=torch.float32):
    """Read a header-less numeric CSV file into a 2-D tensor.

    Every line of the file is treated as a data row (``header=None``), so the
    result has shape ``(n_rows, n_cols)``.

    Args:
        file_path: path (str or path-like) accepted by ``pandas.read_csv``.
        dtype: torch dtype of the returned tensor (default ``torch.float32``).

    Returns:
        torch.Tensor of shape ``(n_rows, n_cols)`` and the requested dtype.
    """
    df = pd.read_csv(file_path, header=None)
    # to_numpy() is pandas' recommended replacement for the .values attribute.
    return torch.tensor(df.to_numpy(), dtype=dtype)

# class CustomMultiHeadAttention(nn.Module):
#     def __init__(self, embed_dim, num_heads):
#         super().__init__()
#         self.mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

#     def set_custom_weights(self, Wqkv, Wo, bqkv, bo):
#         d_k = Wqkv.size(1) // 3
#         self.mha.in_proj_weight.data = Wqkv
#         self.mha.in_proj_bias.data = bqkv
#         self.mha.out_proj.weight.data = Wo
#         self.mha.out_proj.bias.data = bo

#     def forward(self, x):
#         output, attn_weights = self.mha(x, x, x)
#         return output

class CustomMultiHeadAttention(nn.Module):
    """Multi-head self-attention driven by explicitly supplied weights.

    Uses the same packed ``in_proj`` layout as ``nn.MultiheadAttention`` so the
    result (and, when ``verbose``, every intermediate tensor) can be compared
    against another implementation whose weights were loaded from CSV files.

    Args:
        embed_dim: model width ``E``; must be divisible by ``num_heads``.
        num_heads: number of heads ``H``; each head has width ``E // H``.
        Wqkv: optional packed input-projection weight of shape ``(3*E, E)``,
            rows ordered Q, K, V. Projections are applied as ``x @ W.T + b``.
        Wo: optional output-projection weight of shape ``(E, E)``.
        bqkv: optional packed Q/K/V bias holding ``3*E`` values in Q, K, V
            order. Any shape is accepted (e.g. ``(3*E,)``, ``(3, E)``,
            ``(3*E, 1)``); it is flattened.
        bo: optional output bias holding ``E`` values (flattened).
        verbose: when True (the default, preserving the original debug
            behaviour) print the configuration and intermediate tensors.

    Any weight or bias that is not supplied defaults to zeros.

    Raises:
        ValueError: if a supplied weight or bias has the wrong size.
    """

    def __init__(self, embed_dim, num_heads, Wqkv=None, Wo=None, bqkv=None, bo=None, verbose=True):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        E = embed_dim
        self.embed_dim = E
        self.num_heads = num_heads
        self.head_dim = E // num_heads
        self.verbose = verbose
        # Scaled dot-product attention divides the scores by sqrt(head_dim).
        self.scale = self.head_dim ** 0.5

        self._log("embed_dim", E)
        self._log("num_heads", self.num_heads)
        self._log("head_dim", self.head_dim)

        # Q/K/V projection weights, sliced out of the packed (3*E, E) matrix.
        if Wqkv is not None:
            if tuple(Wqkv.shape) != (3 * E, E):
                raise ValueError(f"Wqkv must have shape {(3 * E, E)}, got {tuple(Wqkv.shape)}")
            w_q, w_k, w_v = Wqkv[:E], Wqkv[E:2 * E], Wqkv[2 * E:]
            self._log("weights set")
        else:
            w_q, w_k, w_v = torch.zeros((E, E)), torch.zeros((E, E)), torch.zeros((E, E))

        # Q/K/V biases (the previous version accepted bqkv but never applied it).
        if bqkv is not None:
            flat_bqkv = self._flatten_checked(bqkv, 3 * E, "bqkv")
            b_q, b_k, b_v = flat_bqkv[:E], flat_bqkv[E:2 * E], flat_bqkv[2 * E:]
        else:
            b_q, b_k, b_v = torch.zeros(E), torch.zeros(E), torch.zeros(E)

        # Output projection: Wo and bo are honoured independently, so Wo is no
        # longer dropped when bo is omitted.
        if Wo is not None:
            if tuple(Wo.shape) != (E, E):
                raise ValueError(f"Wo must have shape {(E, E)}, got {tuple(Wo.shape)}")
            w_out = Wo
        else:
            w_out = torch.zeros((E, E))
        b_out = self._flatten_checked(bo, E, "bo") if bo is not None else torch.zeros(E)

        # Buffers rather than plain attributes: they follow .to(device)/.double()
        # and appear in state_dict, but are not trainable nn.Parameters.
        for name, value in (
            ("proj_q", w_q), ("proj_k", w_k), ("proj_v", w_v), ("proj_out", w_out),
            ("bias_q", b_q), ("bias_k", b_k), ("bias_v", b_v), ("bias_out", b_out),
        ):
            self.register_buffer(name, value)

    @staticmethod
    def _flatten_checked(values, expected, name):
        """Flatten ``values`` to 1-D and check that it holds ``expected`` elements."""
        flat = torch.as_tensor(values).reshape(-1)
        if flat.numel() != expected:
            raise ValueError(f"{name} must contain {expected} values, got {flat.numel()}")
        return flat

    def _log(self, *args):
        """Print ``args`` only when the module was built with ``verbose=True``."""
        if self.verbose:
            print(*args)

    def _project(self, x, weight, bias):
        """Compute ``x @ weight.T + bias`` and split the last axis into (num_heads, head_dim)."""
        projected = x @ weight.transpose(0, 1) + bias
        return projected.reshape(*projected.shape[:-1], self.num_heads, self.head_dim)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape ``(seq_length, embed_dim)``, optionally with any
               number of leading batch dimensions ``(..., seq_length, embed_dim)``.

        Returns:
            ``(output, attn)``: ``output`` has the same shape as ``x``; ``attn``
            holds the per-head softmax weights with shape
            ``(..., num_heads, seq_length, seq_length)``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims or a last dim != embed_dim.
        """
        if x.dim() < 2 or x.shape[-1] != self.embed_dim:
            raise ValueError(
                f"expected input of shape (..., seq_length, {self.embed_dim}), got {tuple(x.shape)}"
            )
        *batch_shape, seq_length, embed_dim = x.shape

        # q, k, v: (..., seq_length, num_heads, head_dim)
        q = self._project(x, self.proj_q, self.bias_q)
        self._log(f"q: {q}")
        k = self._project(x, self.proj_k, self.bias_k)
        self._log(f"k: {k}")
        v = self._project(x, self.proj_v, self.bias_v)
        self._log(f"v: {v}")

        # Move heads ahead of the sequence axis: (..., num_heads, seq_length, head_dim)
        q, k, v = (t.transpose(-3, -2) for t in (q, k, v))

        # scores: (..., num_heads, seq_length, seq_length)
        scores = torch.matmul(q / self.scale, k.transpose(-2, -1))
        self._log(f"sqrt(dk): {self.scale}")
        self._log(f"q * k / sqrt(dk): {scores}")
        attn = torch.nn.functional.softmax(scores, dim=-1)
        self._log(f"softmax: {attn}")

        # context: (..., num_heads, seq_length, head_dim)
        context = torch.matmul(attn, v)
        self._log(f"softmax * v: {context}")

        # Concatenate heads back together: (..., seq_length, embed_dim)
        context = context.transpose(-3, -2).reshape(*batch_shape, seq_length, embed_dim)
        self._log(f"softmax * v shape: {context.shape}")

        # Final output projection.
        output = context @ self.proj_out.transpose(0, 1) + self.bias_out
        return output, attn

# Parameters
embed_dim = 4
num_heads = 2

# Directory containing the weight/bias CSV files. Set the MHA_COMPARE_DIR
# environment variable to point elsewhere instead of editing this path.
base_dir = Path(os.environ.get(
    "MHA_COMPARE_DIR",
    "/Users/billyli/Documents/UCSD course files/PHYS 244/project/parallel-mha/serial_code/compare_python",
))

# In the C++ comparison code Wq is (d_model, d_model); set_weights interprets it
# as (d_model, num_heads * depth).
Wq, Wk, Wv, Wo, bq, bk, bv, bo = (
    load_csv_as_tensor(base_dir / f"{name}.csv")
    for name in ("Wq", "Wk", "Wv", "Wo", "bq", "bk", "bv", "bo")
)

print("Wq", Wq)
print("Wq size", Wq.size())

# Concatenate Wq, Wk, Wv (and their biases) into the packed in_proj layout: rows Q, K, V.
Wqkv = torch.cat((Wq, Wk, Wv), dim=0)
bqkv = torch.cat((bq, bk, bv), dim=0)

# Initialize the model
mha = CustomMultiHeadAttention(embed_dim, num_heads, Wqkv=Wqkv, Wo=Wo, bqkv=bqkv, bo=bo)

# Load input as (seq_length, embed_dim); no batch dimension is added.
x = load_csv_as_tensor('mha_input.csv')
print(x.shape)

# Compute output: a tuple (output_tensor, attention_weights)
output = mha(x)

# Print output
print("Output Tensor:\n", output[0])
print("Output Tensor shape:\n", output[0].shape)

print("Output Attention:\n", output[1])

Wq tensor([[-0.3575,  0.5264,  0.2733,  0.2968],
        [ 0.0322, -0.3140,  0.1036, -0.5718],
        [-0.4425, -0.3157,  0.3864,  0.3590],
        [-0.2716, -0.0221,  0.3916,  0.6088]])
Wq size torch.Size([4, 4])
embed_dim 4
num_heads 2
head_dim 2
weights set
torch.Size([2, 4])
q: tensor([[[-1.1320, -0.4090],
         [-1.2269, -0.9193]],

        [[ 0.0104, -0.0332],
         [ 1.3781,  0.9936]]])
k: tensor([[[-0.7493,  1.0601],
         [ 1.2502,  0.5704]],

        [[-0.5158,  0.1040],
         [-0.6362, -0.1282]]])
v: tensor([[[ 1.3135, -0.6319],
         [ 0.0035, -2.4008]],

        [[-1.6128, -0.3029],
         [ 0.3143, -0.4661]]])
sqrt(dk): tensor([1.4142])
q * k / sqrt(dk): tensor([[[ 0.2932,  0.3828],
         [-0.0304, -0.0062]],

        [[-1.4555,  0.6353],
         [ 1.6191, -0.7100]]])
softmax: tensor([[[0.4776, 0.5224],
         [0.4940, 0.5060]],

        [[0.1100, 0.8900],
         [0.9113, 0.0887]]])
softmax * v: tensor([[[-0.2152, -0.4600],
         [-0.1673, -0.

In [3]:
# Expected MHA output read from mha_output.csv, for comparison with output[0] above.
expected_out = load_csv_as_tensor(file_path='mha_output.csv')
expected_out

tensor([[-0.1277, -0.5129,  0.4347, -0.2219],
        [ 0.3021, -1.0609,  1.0369, -0.8575]])