In [1]:
import numpy as np

# Import PyTorch
import torch
from torch.autograd import Variable

# Import TensorLy
import tensorly as tl
from tensorly.tucker_tensor import tucker_to_tensor
from tensorly.random import check_random_state

Using numpy backend.


In [2]:
tl.set_backend('pytorch')

Using pytorch backend.


Make the results reproducible by fixing the random seed

In [3]:
random_state = 1234
rng = check_random_state(random_state)


Define a random tensor which we will try to decompose. We wrap our tensors in Variables so we can backpropagate through them:



In [4]:
shape = [5, 5, 5]
tensor = Variable(tl.tensor(rng.random_sample(shape)), requires_grad=True)


Initialise a random Tucker decomposition of that tensor



In [6]:
ranks = [5, 5, 5]
core = Variable(tl.tensor(rng.random_sample(ranks)), requires_grad=True)
factors = [Variable(tl.tensor(rng.random_sample((tensor.shape[i], ranks[i]))),
                 requires_grad=True) for i in range(tl.ndim(tensor))]


Now we just iterate through the training loop and backpropagate...



In [7]:
n_iter = 10000
lr = 0.00005
penalty = 0.1

optimizer = torch.optim.Adam([core]+factors, lr=lr)

for i in range(1, n_iter):
    # Important: do not forget to reset the gradients
    optimizer.zero_grad()

    # Reconstruct the tensor from the decomposed form
    rec = tucker_to_tensor(core, factors)

    # squared l2 loss
    loss = (rec - tensor).pow(2).sum()

    # squared l2 penalty on the factors of the decomposition
    for f in factors:
        loss = loss + penalty * f.pow(2).sum()

    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
        rec_error = tl.norm(rec.data - tensor.data, 2)/tl.norm(tensor.data, 2)
        print("Epoch %s,. Rec. error: %s" % (i, rec_error))


Epoch 1000,. Rec. error: 12.675263516353017
Epoch 2000,. Rec. error: 9.445408624499127
Epoch 3000,. Rec. error: 7.188053087882553
Epoch 4000,. Rec. error: 5.518843837818505
Epoch 5000,. Rec. error: 4.245414784950455
Epoch 6000,. Rec. error: 3.2595007915550918
Epoch 7000,. Rec. error: 2.4935482213833473
Epoch 8000,. Rec. error: 1.9015988361988734
Epoch 9000,. Rec. error: 1.450236875395197
