In [3]:
!pip install -q transformers torch


In [4]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [5]:
torch.manual_seed(0)
print(torch.rand(3))
torch.cuda.manual_seed_all(0)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

tensor([0.4963, 0.7682, 0.0885])


In [6]:
model_name = "gpt2"   

tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model.eval()  

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [7]:
text = "Artificial intelligence is"

inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=20,
        do_sample=False,  
        pad_token_id=tokenizer.eos_token_id
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Artificial intelligence is a new field of research that has been in the works for a while now. It is a field


In [8]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [9]:
block = model.transformer.h[0]
print(block)

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D(nf=2304, nx=768)
    (c_proj): Conv1D(nf=768, nx=768)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D(nf=3072, nx=768)
    (c_proj): Conv1D(nf=768, nx=3072)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)


In [10]:
fc_weight = block.mlp.c_fc.weight      
proj_weight = block.mlp.c_proj.weight  

print("c_fc weight shape:", fc_weight.shape)
print("c_proj weight shape:", proj_weight.shape)


c_fc weight shape: torch.Size([768, 3072])
c_proj weight shape: torch.Size([3072, 768])


In [11]:
fc_fro_norm = torch.norm(fc_weight, p="fro")
proj_fro_norm = torch.norm(proj_weight, p="fro")

print("Frobenius norm of c_fc:", fc_fro_norm.item())
print("Frobenius norm of c_proj:", proj_fro_norm.item())


Frobenius norm of c_fc: 216.8270721435547
Frobenius norm of c_proj: 135.10826110839844


In [12]:
!pip install einops



In [13]:
from einops import rearrange

def kronecker_decompose(W, m, n, k=1, niter=10):
    """
    Van Loan Kronecker decomposition.
    
    W: weight matrix (torch.Tensor)
    m, n: dimensions of first Kronecker factor
    k: number of Kronecker factors (start with 1)
    """

    out_dim, in_dim = W.shape

    m2 = out_dim // m
    n2 = in_dim // n

    assert m * m2 == out_dim
    assert n * n2 == in_dim


    W_re = rearrange(
        W,
        '(m m2) (n n2) -> (m n) (m2 n2)',
        m=m, m2=m2, n=n, n2=n2
    )


    U, S, V = torch.svd_lowrank(W_re, q=k, niter=niter)


    A = rearrange(U, '(m n) k -> k m n', m=m, n=n)
    B = rearrange(V, '(m2 n2) k -> k m2 n2', m2=m2, n2=n2)

    scale = S.sqrt().view(-1, 1, 1)

    return A * scale, B * scale


In [14]:
A, B = kronecker_decompose(
    fc_weight,
    m=768,
    n=1536,
    k=1
)

print("A shape:", A.shape)
print("B shape:", B.shape)


A shape: torch.Size([1, 768, 1536])
B shape: torch.Size([1, 1, 2])


In [15]:
W_hat = torch.kron(A[0], B[0])

print("Original shape:", fc_weight.shape)
print("Reconstructed shape:", W_hat.shape)


Original shape: torch.Size([768, 3072])
Reconstructed shape: torch.Size([768, 3072])


In [16]:
recon_error = torch.norm(fc_weight - W_hat, p="fro")
orig_norm = torch.norm(fc_weight, p="fro")

print("Original Frobenius norm:", orig_norm.item())
print("Reconstruction error:", recon_error.item())
print("Relative error:", (recon_error / orig_norm).item())


Original Frobenius norm: 216.8270721435547
Reconstruction error: 151.0072479248047
Relative error: 0.6964409351348877


In [17]:
W_hat_norm_before = torch.norm(W_hat, p="fro")
W_orig_norm = torch.norm(fc_weight, p="fro")

print("Original W Frobenius norm:", W_orig_norm.item())
print("Kronecker Ŵ Frobenius norm (before):", W_hat_norm_before.item())
print("Norm ratio (Ŵ / W):", (W_hat_norm_before / W_orig_norm).item())


Original W Frobenius norm: 216.8270721435547
Kronecker Ŵ Frobenius norm (before): 155.60047912597656
Norm ratio (Ŵ / W): 0.7176247835159302


In [18]:
alpha = W_orig_norm / W_hat_norm_before
print("Adaptive scaling factor α:", alpha.item())


Adaptive scaling factor α: 1.3934859037399292


In [19]:
sqrt_alpha = torch.sqrt(alpha)

A_norm = A * sqrt_alpha
B_norm = B * sqrt_alpha


In [20]:
W_hat_norm = torch.kron(A_norm[0], B_norm[0])

print("Reconstructed normalized shape:", W_hat_norm.shape)


Reconstructed normalized shape: torch.Size([768, 3072])


In [21]:
new_norm = torch.norm(W_hat_norm, p="fro")

print("Original Frobenius norm:", W_orig_norm.item())
print("Normalized Kronecker Frobenius norm:", new_norm.item())
print("Absolute difference:", abs(W_orig_norm - new_norm).item())


Original Frobenius norm: 216.8270721435547
Normalized Kronecker Frobenius norm: 216.82717895507812
Absolute difference: 0.0001068115234375


In [22]:
import torch.nn as nn

class KroneckerLinear(nn.Module):
    def __init__(self, A, B):
        """
        A: (in_features, out_factor1)  → (768, 1536)
        B: (1, out_factor2)            → (1, 2)
        """
        super().__init__()

        self.A = nn.Parameter(A)   # (768, 1536)
        self.B = nn.Parameter(B)   # (1, 2)

    def forward(self, x):
        """
        x: (batch, seq_len, 768)
        """
        # Step 1: standard linear with A
        # result: (batch, seq_len, 1536)
        y = torch.matmul(x, self.A)

        # Step 2: Kronecker expansion with B
        # expand last dimension
        # (batch, seq_len, 1536, 1) × (1, 2) → (batch, seq_len, 1536, 2)
        y = y.unsqueeze(-1) * self.B

        # flatten last two dims → 1536 × 2 = 3072
        y = y.reshape(y.shape[0], y.shape[1], -1)

        return y



In [23]:
A_fc = A_norm[0]  # (768, 1536)
B_fc = B_norm[0]  # (1, 2)


In [24]:
# Save original layer (for comparison if needed)
original_fc = model.transformer.h[0].mlp.c_fc

# Replace with KroneckerLinear
model.transformer.h[0].mlp.c_fc = KroneckerLinear(A_fc, B_fc)

print(model.transformer.h[0].mlp)


GPT2MLP(
  (c_fc): KroneckerLinear()
  (c_proj): Conv1D(nf=768, nx=3072)
  (act): NewGELUActivation()
  (dropout): Dropout(p=0.1, inplace=False)
)


In [25]:
test_text = "Artificial intelligence is"
inputs = tokenizer(test_text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

logits = outputs.logits

print("Logits shape:", logits.shape)


Logits shape: torch.Size([1, 4, 50257])


In [26]:
block = model.transformer.h[0]
proj_weight = block.mlp.c_proj.weight  # shape: (768, 3072)

print(proj_weight.shape)


torch.Size([3072, 768])


In [27]:
A_p, B_p = kronecker_decompose(
    proj_weight.T,   # transpose
    m=768,
    n=1536,
    k=1
)


In [28]:
W_hat_p = torch.kron(A_p[0], B_p[0])

orig_norm_p = torch.norm(proj_weight.T, p="fro")
hat_norm_p = torch.norm(W_hat_p, p="fro")

alpha_p = orig_norm_p / hat_norm_p
sqrt_alpha_p = torch.sqrt(alpha_p)

A_p_norm = A_p * sqrt_alpha_p
B_p_norm = B_p * sqrt_alpha_p


In [29]:
A_proj = A_p_norm[0]  # (768, 1536)
B_proj = B_p_norm[0]  # (1, 2)


In [30]:
class KroneckerLinearProj(nn.Module):
    def __init__(self, A, B):
        """
        A: (out_features, mid)  → (768, 1536)
        B: (1, 2)
        """
        super().__init__()
        self.A = nn.Parameter(A)
        self.B = nn.Parameter(B)

    def forward(self, x):
        """
        x: (batch, seq_len, 3072)
        """
        # reshape: 3072 → (1536, 2)
        x = x.view(x.shape[0], x.shape[1], -1, 2)

        # apply B
        x = (x * self.B).sum(dim=-1)  # → (batch, seq, 1536)

        # apply A
        x = torch.matmul(x, self.A.t())  # → (batch, seq, 768)

        return x


In [31]:
model.transformer.h[0].mlp.c_proj = KroneckerLinearProj(A_proj, B_proj)


In [32]:
test_text = "Artificial intelligence is"
inputs = tokenizer(test_text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

print("Logits shape:", outputs.logits.shape)


Logits shape: torch.Size([1, 4, 50257])


In [33]:
for layer_idx in range(len(model.transformer.h)):
    block = model.transformer.h[layer_idx]

    # ---- c_fc ----
    W_fc = block.mlp.c_fc.weight

    A_fc, B_fc = kronecker_decompose(
        W_fc,
        m=768,
        n=1536,
        k=1
    )

    W_hat_fc = torch.kron(A_fc[0], B_fc[0])
    alpha_fc = torch.norm(W_fc, p="fro") / torch.norm(W_hat_fc, p="fro")

    A_fc = A_fc * torch.sqrt(alpha_fc)
    B_fc = B_fc * torch.sqrt(alpha_fc)

    block.mlp.c_fc = KroneckerLinear(A_fc[0], B_fc[0])

    # ---- c_proj ----
    W_proj = block.mlp.c_proj.weight

    A_p, B_p = kronecker_decompose(
        W_proj.T,
        m=768,
        n=1536,
        k=1
    )

    W_hat_p = torch.kron(A_p[0], B_p[0])
    alpha_p = torch.norm(W_proj.T, p="fro") / torch.norm(W_hat_p, p="fro")

    A_p = A_p * torch.sqrt(alpha_p)
    B_p = B_p * torch.sqrt(alpha_p)

    block.mlp.c_proj = KroneckerLinearProj(A_p[0], B_p[0])


AttributeError: 'KroneckerLinear' object has no attribute 'weight'