In [1]:
import numpy as np

# Import PyTorch
import torch
from torch.autograd import Variable

# Import TensorLy
import tensorly as tl
from tensorly.tucker_tensor import tucker_to_tensor
from tensorly.random import check_random_state

Using numpy backend.


In [2]:
tl.set_backend('pytorch')

Using pytorch backend.


Make the results reproducible by fixing the random seed

In [4]:
random_state = 1234
rng = check_random_state(random_state)
device = 'cuda:8'
#device = 'cpu'

Define a random tensor which we will try to decompose. We wrap our tensors in Variables so we can backpropagate through them:



In [10]:
shape = [5, 5, 5]
tensor = tl.tensor(rng.random_sample(shape), device=device, requires_grad=True)

Initialise a random Tucker decomposition of that tensor



In [7]:
ranks = [5, 5, 5]
core = tl.tensor(rng.random_sample(ranks), device=device, requires_grad=True)
factors = [tl.tensor(rng.random_sample((tensor.shape[i], ranks[i])),
                 device=device, requires_grad=True) for i in range(tl.ndim(tensor))]


Now we just iterate through the training loop and backpropagate...



In [11]:
n_iter = 10000
lr = 0.00005
penalty = 0.1

optimizer = torch.optim.Adam([core]+factors, lr=lr)

for i in range(1, n_iter):
    # Important: do not forget to reset the gradients
    optimizer.zero_grad()

    # Reconstruct the tensor from the decomposed form
    rec = tucker_to_tensor(core, factors)

    # squared l2 loss
    loss = tl.norm(rec - tensor, 2)

    # squared l2 penalty on the factors of the decomposition
    for f in factors:
        loss = loss + penalty * f.pow(2).sum()

    loss.backward()
    optimizer.step()

    if i % 1000 == 0:
        rec_error = tl.norm(rec.data - tensor.data, 2)/tl.norm(tensor.data, 2)
        print("Epoch {},. Rec. error: {}".format(i, rec_error))


Epoch 1000,. Rec. error: 12.403386116027832
Epoch 2000,. Rec. error: 8.40932559967041
Epoch 3000,. Rec. error: 5.6003923416137695
Epoch 4000,. Rec. error: 3.6134707927703857
Epoch 5000,. Rec. error: 2.23235821723938
Epoch 6000,. Rec. error: 1.3306471109390259
Epoch 7000,. Rec. error: 0.8450437784194946
Epoch 8000,. Rec. error: 0.6635408997535706
Epoch 9000,. Rec. error: 0.5932939648628235
