In [1]:
import torch
import torch.nn as nn
import seaborn as sns

In [3]:
def get_close_at(tensor1, tensor2, max_places=100):
    """Report how many decimal places two tensors agree to, by absolute difference.

    Tries absolute tolerances 1e-0, 1e-1, ..., 1e-(max_places - 1) in turn and stops
    at the first one at which ``tensor1`` and ``tensor2`` are no longer element-wise
    close. The result is printed (same messages as before) and also returned.

    Args:
        tensor1: first tensor to compare.
        tensor2: second tensor; must be accepted by ``torch.allclose`` together with
            ``tensor1`` (same dtype, broadcastable shape).
        max_places: number of tolerances to try; defaults to 100.

    Returns:
        The exponent ``i`` of the first tolerance ``1e-i`` that fails, or
        ``max_places`` if every tried tolerance passes.
    """
    for i in range(max_places):
        # float(f"1e-{i}") is the correctly-rounded value of the literal 1e-i,
        # i.e. exactly what eval() produced, without evaluating code.
        atol = float(f"1e-{i}")
        # rtol=0 so only the absolute difference counts; allclose's default
        # rtol=1e-5 would otherwise mask differences below ~1e-5 * |tensor2|
        # and make every tolerance "pass" for nearly-equal tensors.
        if not torch.allclose(tensor1, tensor2, rtol=0.0, atol=atol):
            print(f"Similar to the decimal place : {i}")
            return i

    print(f"Similar to the {max_places}th decimal place")
    return max_places

In [None]:
### settings ###

# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA (torch.load(map_location=...) and nn.Linear(device=...)
# below would otherwise fail there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dtype every loaded/generated tensor is cast to.
# Other dtypes worth comparing: torch.float16, torch.float32.
dtype = torch.bfloat16

# When True, the loaded tensors are replaced by random tensors drawn from a
# normal distribution with each loaded tensor's mean/std.
random_weights = False

# Seed torch's RNG so the random_weights=True path is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Row counts of the q, k and v weights stacked in the fused projection; also the
# split sizes used to recover q/k/v from the fused output.
fused_dims = (4096, 1024, 1024)


In [18]:
def _load_tensor(path):
    """Load a tensor saved with torch.save onto `device` and cast it to `dtype`."""
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pt file cannot execute arbitrary code.
    return torch.load(path, map_location=device, weights_only=True).to(dtype)


hidden_states = _load_tensor("./weights/normed_inputs.pt")
q_proj_weight = _load_tensor("./weights/q_proj_weight.pt")
k_proj_weight = _load_tensor("./weights/k_proj_weight.pt")
# Loaded from the same ./weights/ directory as the other three tensors.
v_proj_weight = _load_tensor("./weights/v_proj_weight.pt")


def _random_like_stats(tensor):
    """Sample a tensor shaped like `tensor` from N(mean, std) of `tensor`.

    The result is placed on `device` and cast to `dtype`, so the random path is
    compared under the same settings as the loaded tensors.
    """
    stats = (tensor.float().mean().item(), tensor.float().std().item())
    return torch.normal(*stats, size=tuple(tensor.shape)).to(device=device, dtype=dtype)


if random_weights:
    # Same sampling order as before (hidden, q, k, v) so seeded runs line up.
    hidden_states = _random_like_stats(hidden_states)
    q_proj_weight = _random_like_stats(q_proj_weight)
    k_proj_weight = _random_like_stats(k_proj_weight)
    v_proj_weight = _random_like_stats(v_proj_weight)


print(f"Hidden States Shape = {hidden_states.shape}")
print(f"Query Projection Shape = {q_proj_weight.shape}")
print(f"Key Projection Shape = {k_proj_weight.shape}")
print(f"Value Projection Shape = {v_proj_weight.shape}")

In [20]:
# Fused QKV projection: one bias-free Linear whose weight stacks the q, k and v
# weights row-wise, so a single matmul yields all three outputs; split() then
# recovers them along the last dimension using fused_dims.
fused_weight = torch.concatenate([q_proj_weight, k_proj_weight, v_proj_weight])
assert fused_weight.shape[0] == sum(fused_dims), "fused_dims must match the stacked q/k/v weight rows"

# dtype=dtype avoids allocating a throwaway float32 weight that is replaced below.
fused_attention_proj = nn.Linear(
    in_features=fused_weight.shape[1],
    out_features=sum(fused_dims),
    bias=False,
    device=device,
    dtype=dtype,
)
# requires_grad=False: inference-only comparison, so no autograd graph is built.
fused_attention_proj.weight = nn.Parameter(fused_weight, requires_grad=False)

qkv = fused_attention_proj(hidden_states)
q_fused, k_fused, v_fused = qkv.split(fused_dims, dim=-1)

In [22]:
def _linear_from_weight(weight):
    """Build a bias-free nn.Linear on `device`/`dtype` that uses `weight` as its frozen weight.

    The layer's in/out features are taken from `weight` (shape: out x in) rather
    than hard-coded, so they always agree with the tensor actually used.
    """
    out_features, in_features = weight.shape
    # dtype=dtype avoids allocating a throwaway float32 weight that is replaced below.
    layer = nn.Linear(
        in_features=in_features,
        out_features=out_features,
        bias=False,
        device=device,
        dtype=dtype,
    )
    # requires_grad=False: inference-only comparison, so no autograd graph is built.
    layer.weight = nn.Parameter(weight, requires_grad=False)
    return layer


q_proj = _linear_from_weight(q_proj_weight)
k_proj = _linear_from_weight(k_proj_weight)
v_proj = _linear_from_weight(v_proj_weight)

q_single = q_proj(hidden_states)
k_single = k_proj(hidden_states)
v_single = v_proj(hidden_states)

In [24]:
get_close_at(tensor1=q_single, tensor2=q_fused)  # separate q_proj vs. fused slice

Similar to the 100th digit


In [25]:
get_close_at(tensor1=k_single, tensor2=k_fused)  # separate k_proj vs. fused slice

Similar to the 100th digit


In [26]:
get_close_at(tensor1=v_single, tensor2=v_fused)  # separate v_proj vs. fused slice

Similar to the 100th digit
