In [1]:
import torch as th
import tensorly as tl

# Work in double precision so the equivalence checks below compare float64 results.
th.set_default_dtype(th.float64)
# Seed torch's global RNG: nn.Linear initialisation and th.randn inputs are
# otherwise different on every "Restart & Run All".
th.manual_seed(42)
# Make TensorLy operate directly on torch tensors.
tl.set_backend("pytorch")

In [2]:
# Number of heads: the projection layers below use 32*s features, and the
# weight tensors are reshaped with one axis of size s.
s = 8

In [3]:
# Up-projection layer: 64 input features -> 32*s output features, no bias.
P = th.nn.Linear(64, 32*s, bias=False)
# Gradient-free view of the weight, shape (32*s, 64). `.detach()` is the
# supported replacement for the legacy `.data` attribute: it shares the same
# storage but stays visible to autograd's in-place modification checks.
Pw = P.weight.detach()
# Random input whose last axis holds the 64 input features.
x = th.randn((3, 17, 64))
print(Pw.shape, x.shape)

torch.Size([256, 64]) torch.Size([3, 17, 64])


In [4]:
# Reference output: run x through the plain nn.Linear layer.
y_lin = P(x)

In [5]:
# The last axis of x (64 features) has been mapped to 32*s features.
print(y_lin.shape)

torch.Size([3, 17, 256])


In [6]:
# Multi-head view of the same projection: the 32*s weight rows are split into
# s heads of 32 rows each, so Pw becomes (s, 32, 64).
Pw = Pw.reshape((s, 32, 64))
# Two singleton axes let x broadcast over the head axis: (3, 17, 1, 1, 64).
Xw = x[:, :, None, None, :]
# Each head's weight is transposed to (64, 32) and given two leading singleton
# axes; the batched matmul yields (3, 17, s, 1, 32), and the singleton row
# axis is then dropped.
y_multi = th.matmul(Xw, Pw[None, None].transpose(-2, -1)).squeeze(-2)
print(y_multi.shape)
B, N, r, e = y_multi.shape
# Concatenate the per-head outputs along the feature axis: (B, N, r*e).
y_multi = y_multi.reshape(B, N, -1)

torch.Size([3, 17, 8, 32])


In [7]:
# After merging the head and per-head axes, the shape should equal y_lin's.
y_multi.shape

torch.Size([3, 17, 256])

In [8]:
# The per-head computation must reproduce the nn.Linear output
# (no rtol/atol passed, so th.allclose's default tolerances apply).
th.allclose(y_lin, y_multi)

True

In [9]:
# Mean squared difference between the reference and the multi-head output.
(y_lin - y_multi).pow(2).mean()

tensor(0., grad_fn=<MeanBackward0>)

## TensorLy CP decomposition of the up-projection weight

In [10]:
# CP decomposition of the 3-way weight tensor Pw (reshaped to (s, 32, 64) above)
# with rank 32*s. The random initialisation is made repeatable via random_state.
# NOTE(review): tol=1e-24 is far below float64 resolution, presumably so the
# solver never stops early on the tolerance criterion - confirm against the
# TensorLy docs.
decomp = tl.decomposition.CP(
    rank=32 * s,
    init="random",
    random_state=42,
    tol=1e-24,
    normalize_factors=False,
    verbose=False,
)
# The first element of the result (the weights) is discarded; P1, P2, P3 are
# the factor matrices for the three modes of Pw, in order.
_, (P1, P2, P3) = decomp.fit_transform(Pw)

In [11]:
# Rebuild the weight tensor from the three factors alone (weights passed as
# None, since the weights returned by the decomposition were discarded).
recons_pw = tl.cp_to_tensor((None, (P1, P2, P3)))
# The rebuilt tensor should match Pw within th.allclose's default tolerances.
th.allclose(recons_pw, Pw)

True

In [12]:
# One factor matrix per mode of Pw; each has `rank` columns.
print(P1.shape, P2.shape, P3.shape)

torch.Size([8, 256]) torch.Size([32, 256]) torch.Size([64, 256])


In [13]:
def preprocess(x):
    """Turn a (dim, R) CP factor into an (R, 1, 1, 1, dim) stack of row vectors."""
    return x[None, None, None].permute(-1, 0, 1, 2, 3)


# x with a leading (rank) broadcast axis and a singleton row axis:
# (1, 3, 17, 1, 64).
x_ = x[None, :, :, None, :]
P_3 = preprocess(P3)  # (R, 1, 1, 1, 64)
P_2 = preprocess(P2)  # (R, 1, 1, 1, 32)
P_1 = preprocess(P1)  # (R, 1, 1, 1, 8)
# Dot product of x with every rank component's 64-vector: (R, 3, 17, 1, 1).
inter_1 = th.matmul(x_, P_3.transpose(-2, -1))
# Scale each component's 32-vector by that scalar: (R, 3, 17, 1, 32).
inter_2 = th.matmul(inter_1, P_2)
# Outer product with each component's 8-vector: (R, 3, 17, 8, 32).
op = th.matmul(P_1.transpose(-2, -1), inter_2)
print(op.shape)
R, B, N, e, k = op.shape
# Merge the last two axes, then add up the R rank-1 contributions.
op = op.reshape((R, B, N, -1)).sum(dim=0)

torch.Size([256, 3, 17, 8, 32])


In [14]:
# The CP-factorised forward pass should reproduce the multi-head output
# within th.allclose's default tolerances.
th.allclose(y_multi, op)

True

In [15]:
# Mean squared difference between the multi-head and CP-factorised outputs.
(y_multi - op).pow(2).mean()

tensor(1.3726e-23)

# Down-projection forward pass (32*s → 64 features)

In [16]:
# Down-projection layer: 32*s input features -> 64 output features, no bias.
# Note that P, Pw and x from the up-projection section are rebound here.
P = th.nn.Linear(32*s, 64, bias=False)
# Gradient-free view of the weight, shape (64, 32*s). `.detach()` is the
# supported replacement for the legacy `.data` attribute: it shares the same
# storage but stays visible to autograd's in-place modification checks.
Pw = P.weight.detach()
# Random input whose last axis holds the 32*s input features.
x = th.randn((3, 17, 32*s))
print(Pw.shape, x.shape)

torch.Size([64, 256]) torch.Size([3, 17, 256])


In [17]:
# Reference output of the down-projection: the last axis goes from 32*s to 64.
y_lin = P(x)
print(y_lin.shape)

torch.Size([3, 17, 64])


In [18]:
# Split the 32*s input axis of the weight into (32, s), giving a 3-way tensor
# (64, 32, s). The input x is split the same way before the factorised pass.
Pw = Pw.reshape((64, 32, s))
Pw.shape

torch.Size([64, 32, 8])

In [19]:
# CP decomposition of the 3-way down-projection weight Pw (reshaped to
# (64, 32, s) above) with rank 32*s. The random initialisation is made
# repeatable via random_state.
# NOTE(review): tol=1e-24 is far below float64 resolution, presumably so the
# solver never stops early on the tolerance criterion - confirm against the
# TensorLy docs.
decomp = tl.decomposition.CP(
    rank=32 * s,
    init="random",
    random_state=42,
    tol=1e-24,
    normalize_factors=False,
    verbose=False,
)
# The first element of the result (the weights) is discarded; P1, P2, P3 are
# the factor matrices for the three modes of Pw, in order.
_, (P1, P2, P3) = decomp.fit_transform(Pw)

In [20]:
# Rebuild the weight tensor from the three factors alone (weights passed as
# None, since the weights returned by the decomposition were discarded).
recons_pw = tl.cp_to_tensor((None, (P1, P2, P3)))
# The rebuilt tensor should match Pw within th.allclose's default tolerances.
th.allclose(recons_pw, Pw)

True

In [21]:
# One factor matrix per mode of Pw; each has `rank` columns.
print(P1.shape, P2.shape, P3.shape)

torch.Size([64, 256]) torch.Size([32, 256]) torch.Size([8, 256])


In [22]:
def preprocess(x):
    """Turn a (dim, R) CP factor into an (R, 1, 1, 1, dim) stack of row vectors."""
    return x[None, None, None].permute(-1, 0, 1, 2, 3)


B, N, C = x.shape
# Split the feature axis into (C // s, s), the same split applied to the
# weight's input axis, and add a leading (rank) broadcast axis:
# (1, B, N, C // s, s).
x_ = x.reshape((B, N, C // s, s))[None]
P_3 = preprocess(P3)  # (R, 1, 1, 1, s)
P_2 = preprocess(P2)  # (R, 1, 1, 1, 32)
P_1 = preprocess(P1)  # (R, 1, 1, 1, 64)
# Contract the size-s axis with each component's s-vector: (R, B, N, C // s, 1).
inter_1 = th.matmul(x_, P_3.transpose(-2, -1))
# Contract the size-32 axis with each component's 32-vector, leaving one
# scalar per rank component and position: (R, B, N, 1).
inter_2 = th.matmul(P_2, inter_1).squeeze(-1)
# Scale each component's 64-vector by that scalar, (R, B, N, 64), then add up
# the R rank-1 contributions to obtain (B, N, 64).
output = th.matmul(inter_2, P_1.squeeze(1)).sum(dim=0)

In [23]:
# The CP-factorised down-projection should reproduce the nn.Linear output
# within th.allclose's default tolerances.
th.allclose(output, y_lin)

True

In [24]:
# Mean squared difference between the reference and the CP-factorised output.
(y_lin - output).pow(2).mean()

tensor(1.1375e-18, grad_fn=<MeanBackward0>)