In [2]:
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import *
from svd_linear import SVDLinear

config = LlamaConfig()
down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) #config.mlp_bias)
q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

input_tensor = torch.randn(1, config.intermediate_size)
residual = torch.randn(1, config.hidden_size)

# Original
down_proj_output = down_proj(input_tensor) + residual
hidden_states = post_attention_layernorm(down_proj_output)
query_states = q_proj(hidden_states)

In [3]:
down_proj_svd = SVDLinear.from_linear(down_proj, 0.9999, sigma_fuse="V")

SVD Rank: 4096


In [13]:
# hidden_states_svd = down_proj_svd(input_tensor)
BLinear = down_proj_svd.BLinear
ALinear = down_proj_svd.ALinear
rms_weight = post_attention_layernorm.weight
rms_variance_epsilon = post_attention_layernorm.variance_epsilon


hidden_states_svd = BLinear(input_tensor)

# residual
residual_svd = residual @ torch.linalg.inv(ALinear.weight.T)
print(residual_svd.shape)
hidden_states_svd = hidden_states_svd + residual_svd

# RMNS
variance = hidden_states_svd.pow(2).mean(-1, keepdim=True)
hidden_states_svd = hidden_states_svd * torch.rsqrt(variance + rms_variance_epsilon)

hidden_states_svd = ALinear(hidden_states_svd)

q_weight = torch.diag(rms_weight) @ q_proj.weight

query_states_svd = hidden_states_svd @ q_weight.T

assert torch.allclose(query_states, query_states_svd, rtol=1e-2, atol=1e-2), "Query outputs are not close enough."

torch.Size([1, 4096])


In [7]:
print(query_states, query_states_svd)

tensor([[-0.2041, -0.5595,  0.4613,  ..., -0.6617,  0.7821, -0.2123]],
       grad_fn=<MmBackward0>) tensor([[ 0.2511, -0.3472, -0.7434,  ..., -0.3110,  0.2701,  0.2605],
        [ 0.2107, -0.3737, -0.7364,  ..., -0.2949,  0.2445,  0.2642],
        [ 0.2116, -0.3664, -0.7571,  ..., -0.2833,  0.2490,  0.2561],
        ...,
        [ 0.2097, -0.3602, -0.7262,  ..., -0.2962,  0.2775,  0.2271],
        [ 0.2417, -0.3902, -0.7415,  ..., -0.2990,  0.2328,  0.2796],
        [ 0.2097, -0.4000, -0.7214,  ..., -0.3018,  0.2791,  0.2286]],
       grad_fn=<MmBackward0>)
