In [1]:
import torch as th
# Every floating-point tensor created after this call defaults to float64.
th.set_default_dtype(th.float64)
# Seed torch's global RNG so the th.randn draws in the next cell are reproducible.
th.manual_seed(42)
import tensorly as tl
# Make tensorly operate on torch tensors instead of its default backend.
tl.set_backend("pytorch")

In [2]:
# Random 4-way weight; later cells unpack its axes as (K, E, H, D) = (3, 128, 8, 16).
weight = th.randn(3, 128, 8, 16)
# Random input whose trailing axis (128) is later split into (H, D) = (8, 16).
# Drawn after `weight` so the RNG stream is consumed in the same order.
X = th.randn(7, 5, 128)

In [3]:
# Dense reference projection:
#   Y[k, ..., e] = sum_n X[..., n] * weight[k, e, n]
# where n runs over the flattened (H, D) axes of `weight`.
K, E, H, D = weight.shape
# (K, E, H, D) -> (K, E, H*D) -> (H*D, K, E) -> (H*D, K*E): a single 2-D
# matrix whose columns enumerate (k, e) pairs in row-major order.
weight_res = weight.reshape((K, E, H * D)).permute(2, 0, 1)
weight_res = weight_res.reshape((H * D, K * E))
print(f"Input shape: {X.shape}")
print(f"Weight shape: {weight_res.shape}")
Y = X @ weight_res
# Split the K*E columns back into (K, E) using X's actual leading dims and
# the weight's own E instead of hard-coded literals, then bring K to the front.
Y = Y.reshape((*X.shape[:-1], K, E)).movedim(-2, 0)
print(f"Y shape: {Y.shape}")

Input shape: torch.Size([7, 5, 128])
Weight shape: torch.Size([128, 384])
Y shape: torch.Size([3, 7, 5, 128])


In [4]:
# Split X's trailing axis into (H, D) so it lines up with the last two axes
# of `weight`; the leading dims are taken from X rather than hard-coded.
X_res = X.reshape((*X.shape[:-1], H, D))
print(f"CP Input shape {X_res.shape}")

CP Input shape torch.Size([7, 5, 8, 16])


In [5]:
# 3-way CP path: factorise `weight` viewed as (K*E, H, D) and apply the
# factors to the input one mode at a time instead of using the dense matrix.
K, E, H, D = weight.shape
decomp = tl.decomposition.CP(rank=K*E, normalize_factors=False, verbose=False, init="random", tol=1e-24, random_state=42)
weight_terf = weight.reshape((K*E, H, D))
_, (A2, A3, A4) = decomp.fit_transform(weight_terf)
recons_weight = tl.cp_to_tensor((None, (A2, A3, A4)))
print(th.allclose(recons_weight, weight_terf), th.mean((recons_weight - weight_terf)**2))
print(A2.shape, A3.shape, A4.shape)
# Built from X (not from the previous X_res) so re-running this cell never
# stacks extra leading axes; the result is (1, *leading dims of X, H, D).
X_res = X.reshape((*X.shape[:-1], H, D)).unsqueeze(0)


def preprocess(x):
    """Move a factor's rank axis to the front and pad three singleton axes.

    A 2-D factor of shape (n, R) becomes (R, 1, 1, 1, n), so the rank axis
    acts as a broadcast batch dimension in the matmuls below.
    """
    return x.unsqueeze(0).unsqueeze(0).unsqueeze(0).permute((-1, 0, 1, 2, 3))


A2 = preprocess(A2)
A3 = preprocess(A3)
A4 = preprocess(A4)
print(A2.shape, A3.shape, A4.shape)
# Contract D: (1, ..., H, D) @ (R, 1, 1, D, 1) -> (R, ..., H, 1) -> (R, ..., H)
inter_1 = X_res @ A4.swapaxes(-2, -1)
inter_1 = inter_1.squeeze(-1)
# Contract H: (R, ..., H) @ (R, 1, H, 1) -> (R, ..., 1)
inter_2 = inter_1 @ A3.squeeze(-2).swapaxes(-2, -1)
# Expand to the K*E outputs: (R, ..., 1) @ (R, 1, 1, K*E) -> (R, ..., K*E)
output = inter_2 @ A2.squeeze(-2)
# Sum the rank-one terms, then split K*E into (K, E) and bring K to the front.
Y2 = th.sum(output, dim=0)
Y2 = Y2.reshape((*X.shape[:-1], K, E)).movedim(-2, 0)
print(Y2.shape)

True tensor(3.2570e-20)
torch.Size([384, 384]) torch.Size([8, 384]) torch.Size([16, 384])
torch.Size([384, 1, 1, 1, 384]) torch.Size([384, 1, 1, 1, 8]) torch.Size([384, 1, 1, 1, 16])
torch.Size([3, 7, 5, 128])


In [6]:
# The 3-way CP pipeline (Y2) should match the dense reference (Y) within allclose's default tolerances.
th.allclose(Y, Y2)

True

## 4D Implementation

In [7]:
# 4-way CP path: factorise `weight` directly as a (K, E, H, D) tensor, so the
# K and E modes each get their own factor matrix.
K, E, H, D = weight.shape
decomp = tl.decomposition.CP(
    rank=K * E,
    normalize_factors=False,
    verbose=False,
    init="random",
    tol=1e-24,
    random_state=42,
)
_, (A1, A2, A3, A4) = decomp.fit_transform(weight)
recons_weight = tl.cp_to_tensor((None, (A1, A2, A3, A4)))
print(th.allclose(recons_weight, weight), (recons_weight - weight).pow(2).mean())
print(A1.shape, A2.shape, A3.shape, A4.shape)


def preprocess(x):
    """Reshape a 2-D factor (n, R) into (R, 1, 1, 1, n) for broadcast matmuls."""
    return x[None, None, None].permute(4, 0, 1, 2, 3)


A1, A2, A3, A4 = (preprocess(factor) for factor in (A1, A2, A3, A4))
print(X_res.shape, A1.shape, A2.shape, A3.shape, A4.shape)
# Contract D: (1, 7, 5, H, D) @ (R, 1, 1, D, 1) -> (R, 7, 5, H, 1)
inter_1 = th.matmul(X_res, A4.transpose(-2, -1))
# Contract H: (R, 1, 1, 1, H) @ (R, 7, 5, H, 1) -> (R, 7, 5, 1, 1)
inter_2 = th.matmul(A3, inter_1)
# Expand along E: (R, 7, 5, 1, 1) @ (R, 1, 1, 1, E) -> (R, 7, 5, 1, E)
inter_3 = th.matmul(inter_2, A2)
# Expand along K: (R, 1, 1, K, 1) @ (R, 7, 5, 1, E) -> (R, 7, 5, K, E)
output = th.matmul(A1.transpose(-2, -1), inter_3)
# Sum the rank-one terms and move K to the front: (K, 7, 5, E)
output = output.sum(dim=0).permute(2, 0, 1, 3)
print(output.shape)

True tensor(1.7704e-23)
torch.Size([3, 384]) torch.Size([128, 384]) torch.Size([8, 384]) torch.Size([16, 384])
torch.Size([1, 7, 5, 8, 16]) torch.Size([384, 1, 1, 1, 3]) torch.Size([384, 1, 1, 1, 128]) torch.Size([384, 1, 1, 1, 8]) torch.Size([384, 1, 1, 1, 16])
torch.Size([3, 7, 5, 128])


In [8]:
# The 4-way CP pipeline (output) should match the dense reference (Y) within allclose's default tolerances.
th.allclose(Y, output)

True