In [61]:
import time
import torch
import torch.nn as nn
from einops import einsum, rearrange

device = 'mps'
if device == 'mps':
    func = torch.mps.synchronize
else:
    func = torch.cpu.synchronize

iter_num = 1
num = 10
N0_ = 2 ** num
N1_ = 1

# 权重输入通道旋转

```python
@torch.no_grad()
def rotate_mlp_input(layer, Q):
    # Rotate the MLP input weights.
    mlp_inputs = [layer.mlp.up_proj, layer.mlp.gate_proj]
    for W in mlp_inputs:
        dtype = W.weight.dtype
        W_ = W.weight.data.to(dtype=torch.float32)
        W.weight.data = torch.matmul(W_, Q).to(dtype=dtype)
```

In [62]:
for i in range(num + 1):
    N0 = N0_ // (2 ** i)
    N1 = N1_ * (2 ** i)
    R0 = torch.randn(N0, N0, device=device, dtype=torch.float32)
    R1 = torch.randn(N1, N1, device=device, dtype=torch.float32)
    # 权重的形状是(C_out, C_in)
    weight = torch.randn(N0 * N1, N0 * N1, device=device, dtype=torch.float32)
    
    # # 在沿着输入维度旋转的操作
    with torch.inference_mode():
        Q = torch.kron(R0, R1)
        start = time.time()
        for _ in range(iter_num):
            output_baseline = weight @ Q
        func()
    
    # print(20 * '==')
        start = time.time()
        for _ in range(iter_num):
            output2 = torch.einsum('ij,aik,km->ajm', R0, rearrange(weight, 'b (h c)->b h c', h=N0), R1).reshape(N0 * N1, -1)
        func()
    
    # print(20 * '==')
        start = time.time()
        for _ in range(iter_num):
            # output3 = einsum(R0, rearrange(weight, 'b (h c)->b h c', h=N0), R1, 'i j, a i k, k m -> a j m').reshape(N0 * N1, -1)
            output3 = R0.t() @ rearrange(weight, 'b (h c)->b h c', h=N0) @ R1
            output3 = output3.reshape(N0 * N1, -1)
        func()
        
        diff1 = (output_baseline - output2).abs().max().detach().cpu().item()
        diff2 = (output2 - output3).abs().max().detach().cpu().item()
        print(f"N0: {N0:<5}, N1: {N1:<5} diff1: {diff1:.5f}  diff2: {diff2:.5f}")

N0: 1024 , N1: 1     diff1: 0.00029  diff2: 0.00022
N0: 512  , N1: 2     diff1: 0.00055  diff2: 0.00024
N0: 256  , N1: 4     diff1: 0.00021  diff2: 0.00008
N0: 128  , N1: 8     diff1: 0.00021  diff2: 0.00007
N0: 64   , N1: 16    diff1: 0.00026  diff2: 0.00000
N0: 32   , N1: 32    diff1: 0.00031  diff2: 0.00000
N0: 16   , N1: 64    diff1: 0.00035  diff2: 0.00000
N0: 8    , N1: 128   diff1: 0.00031  diff2: 0.00000
N0: 4    , N1: 256   diff1: 0.00029  diff2: 0.00000
N0: 2    , N1: 512   diff1: 0.00025  diff2: 0.00000
N0: 1    , N1: 1024  diff1: 0.00009  diff2: 0.00000


# 权重输出通道旋转

```python
@torch.no_grad()
def rotate_mlp_output(layer, Q):
    # Rotate the MLP output weights and bias.
    W = layer.mlp.down_proj
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(dtype=torch.float32)
    W.weight.data = torch.matmul(Q.T, W_).to(dtype=dtype)
    if W.bias is not None:
        raise NotImplementedError
```

In [63]:
for i in range(num + 1):
    N0 = N0_ // (2 ** i)
    N1 = N1_ * (2 ** i)
    R0 = torch.randn(N0, N0, device=device, dtype=torch.float32)
    R1 = torch.randn(N1, N1, device=device, dtype=torch.float32)
    # 权重的形状是(C_out, C_in)
    weight = torch.randn(N0 * N1, N0 * N1, device=device, dtype=torch.float32)
    
    # # 在沿着输入维度旋转的操作
    with torch.inference_mode():
        Q = torch.kron(R0, R1)
        start = time.time()
        for _ in range(iter_num):
            output_baseline = Q.t() @ weight
        func()
    
    # print(20 * '==')
        start = time.time()
        for _ in range(iter_num):
            output2 = torch.einsum('ij,ika,km->jma', R0, rearrange(weight, '(h c) b->h c b', h=N0), R1).reshape(-1, N0 * N1)
        func()
    
    # print(20 * '==')
        start = time.time()
        for _ in range(iter_num):
            output3 = torch.einsum('ij,ika,km->jma', R0, rearrange(weight, '(h c) b->h c b', h=N0), R1).reshape(-1, N0 * N1)

        func()
        
        diff1 = (output_baseline - output2).abs().max().detach().cpu().item()
        diff2 = (output_baseline - output3).abs().max().detach().cpu().item()
        
        print(f"N0: {N0:<5}, N1: {N1:<5} diff1: {diff1:.5f}  diff2: {diff2:.5f} sum1: {output_baseline.sum() - output3.sum()}")

N0: 1024 , N1: 1     diff1: 0.00041  diff2: 0.00041 sum1: 0.03125
N0: 512  , N1: 2     diff1: 0.00026  diff2: 0.00026 sum1: 0.025390625
N0: 256  , N1: 4     diff1: 0.00021  diff2: 0.00021 sum1: 0.0078125
N0: 128  , N1: 8     diff1: 0.00034  diff2: 0.00034 sum1: 0.00390625
N0: 64   , N1: 16    diff1: 0.00023  diff2: 0.00023 sum1: -0.013671875
N0: 32   , N1: 32    diff1: 0.00024  diff2: 0.00024 sum1: 0.015625
N0: 16   , N1: 64    diff1: 0.00031  diff2: 0.00031 sum1: -0.05078125
N0: 8    , N1: 128   diff1: 0.00027  diff2: 0.00027 sum1: 0.009765625
N0: 4    , N1: 256   diff1: 0.00037  diff2: 0.00037 sum1: 0.0
N0: 2    , N1: 512   diff1: 0.00028  diff2: 0.00028 sum1: 0.0234375
N0: 1    , N1: 1024  diff1: 0.00001  diff2: 0.00001 sum1: 0.002685546875


# 输入旋转 XQ

In [71]:
for i in range(num + 1):
    batch_size = 4
    token_num = 10
    N0 = N0_ // (2 ** i)
    N1 = N1_ * (2 ** i)
    R0 = torch.randn(N0, N0, device=device, dtype=torch.float32)
    R1 = torch.randn(N1, N1, device=device, dtype=torch.float32)
    # 权重的形状是(batch_size, token_num, N0 * N1)
    data = torch.randn(batch_size, token_num, N0 * N1, device=device, dtype=torch.float32)
    
    # # 在沿着输入维度旋转的操作
    with torch.inference_mode():
        Q = torch.kron(R0, R1)
        start = time.time()
        for _ in range(iter_num):
            output_baseline = data @ Q
        func()
    
        start = time.time()
        for _ in range(iter_num):
            output2 = torch.einsum('ij,aik,km->ajm', R0, rearrange(data, 'b h (t p)->(b h) t p', t=N0), R1)
            output2 = output2.reshape(batch_size, token_num, -1)
        func()

        
        diff1 = (output_baseline - output2).abs().max().detach().cpu().item()
        print(f"N0: {N0:<5}, N1: {N1:<5} diff1: {diff1:.5f}   sum: {output_baseline.sum() - output2.sum():.5f}")

N0: 1024 , N1: 1     diff1: 0.00008   sum: -0.00098
N0: 512  , N1: 2     diff1: 0.00021   sum: -0.00269
N0: 256  , N1: 4     diff1: 0.00018   sum: 0.00049
N0: 128  , N1: 8     diff1: 0.00019   sum: -0.00391
N0: 64   , N1: 16    diff1: 0.00016   sum: -0.00024
N0: 32   , N1: 32    diff1: 0.00023   sum: -0.00391
N0: 16   , N1: 64    diff1: 0.00020   sum: 0.00342
N0: 8    , N1: 128   diff1: 0.00044   sum: -0.00195
N0: 4    , N1: 256   diff1: 0.00038   sum: -0.00073
N0: 2    , N1: 512   diff1: 0.00048   sum: -0.00781
N0: 1    , N1: 1024  diff1: 0.00000   sum: 0.00000
