In [7]:
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import *
from svd_linear import SVDLinear

# Seed the global RNG so the randomly initialised layers and the random input
# below are identical on every Restart & Run All; without it the reference
# values (and any tolerance-based comparison against them) change per run.
torch.manual_seed(0)

config = LlamaConfig()

# Stand-alone copies of the three layers involved in the experiment:
# MLP down projection -> post-attention RMSNorm -> attention query projection.
down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

# One random activation vector in the MLP intermediate space.
input_tensor = torch.randn(1, config.intermediate_size)

# Original (reference) path.
down_proj_output = down_proj(input_tensor)

hidden_states = post_attention_layernorm(down_proj_output)

# Manual expansion of the norm call above, kept for reference (as written by
# the notebook author; not executed):
#   variance = down_proj_output.pow(2).mean(-1, keepdim=True)
#   hidden_states = down_proj_output * torch.rsqrt(variance + post_attention_layernorm.variance_epsilon)
#   hidden_states = post_attention_layernorm.weight * hidden_states

query_states = q_proj(hidden_states)

In [None]:
# Fraction handed to SVDLinear as the rank ratio for the factorised down_proj.
RANK_RATIO = 0.9999

# Build the SVD-factorised counterpart of down_proj; later cells read its
# BLinear and ALinear sub-layers.
down_proj_svd = SVDLinear.from_linear_rank_ratio(down_proj, RANK_RATIO)

In [24]:
# Recompute query_states through the SVD-factorised down_proj, with the RMS
# normalisation applied in the low-rank space (between BLinear and ALinear) and
# the RMSNorm gain folded into the q_proj weight. The result is compared with
# the reference query_states from the unfactorised path.
BLinear = down_proj_svd.BLinear
ALinear = down_proj_svd.ALinear

# Low-rank projection of the input (first factor of down_proj_svd).
latent_states = BLinear(input_tensor)

weight = post_attention_layernorm.weight
variance_epsilon = post_attention_layernorm.variance_epsilon

# The reference norm averages the squared activations over the hidden
# dimension, i.e. over as many features as the gain vector has entries. The
# latent vector may have fewer entries, so divide the sum of squares by the
# hidden size instead of taking a mean over the latent dimension.
# NOTE(review): moving the normalisation in front of ALinear is only exact if
# ALinear preserves the vector norm (orthonormal columns, e.g. U of the SVD)
# and has no bias — TODO confirm against svd_linear.
hidden_dim = weight.shape[-1]
variance = latent_states.pow(2).sum(-1, keepdim=True) / hidden_dim
latent_states = latent_states * torch.rsqrt(variance + variance_epsilon)

# Back to the hidden space; this is the normalised activation without the gain.
hidden_states_svd = ALinear(latent_states)

# Fold the per-feature gain into q_proj. The gain multiplies q_proj's INPUT
# features, so it must scale the columns of the (out_features, in_features)
# weight: W_q @ diag(weight), written here as a broadcast multiply.
# diag(weight) @ W_q would scale the output rows instead and only agrees when
# the gain is all ones (presumably the fresh-layer default — TODO confirm).
q_weight = q_proj.weight * weight

# Use the functional linear so q_proj.bias is applied when present (it is None
# when the layer was built with bias=False, which leaves the result unchanged).
query_states_svd = nn.functional.linear(hidden_states_svd, q_weight, q_proj.bias)

assert torch.allclose(query_states, query_states_svd, rtol=1e-2, atol=1e-2), "Query outputs are not close enough."

In [21]:
# Eyeball check: reference query projection first, SVD-path result second.
tensors_to_show = (query_states, query_states_svd)
print(*tensors_to_show)

tensor([[ 0.1284, -0.2965, -0.5730,  ...,  0.1236,  0.4075,  0.5450]],
       grad_fn=<MmBackward0>) tensor([[ 0.1284, -0.2965, -0.5730,  ...,  0.1236,  0.4075,  0.5450]],
       grad_fn=<MmBackward0>)
