In [1]:
import numpy as np                                                                                                                                                                  
from mxnet import nd, autograd
import mxnet as mx
from tensorly import backend as T
from tensorly.tucker_tensor import tucker_to_tensor
from tensorly.random import check_random_state

Using mxnet backend.


# Tucker decomposition using SGD and autograd

First, we fix the random seed for reproducibility.

In [2]:
# A fixed seed makes every random draw below (data and initialisation) reproducible.
random_state = 1234
rng = check_random_state(random_state)  # seeded generator; its `.random_sample` is used below

Define a random tensor to decompose.

In [3]:
# The tensor to decompose: 5 x 5 x 5, entries drawn uniformly from [0, 1)
# and converted to the active tensorly backend (mxnet in this notebook).
shape = [5] * 3
tensor = T.tensor(rng.random_sample(shape))

Initialise a random Tucker decomposition of that tensor: a core tensor and one factor matrix per mode.

In [4]:
# Tucker ranks, one per mode; here they equal the tensor's size along each mode.
ranks = [5] * 3

# Random initialisation: a core of shape `ranks` and, for each mode, a factor
# matrix of shape (size of that mode, rank of that mode).
core = T.tensor(rng.random_sample(ranks))
factors = [
    T.tensor(rng.random_sample((mode_size, ranks[mode])))
    for mode, mode_size in enumerate(tensor.shape)
]

This is where the magic happens: by attaching gradients to the core and the factors, we let MXNet's autograd compute the gradient of the loss with respect to each of them!

In [5]:
# Ask autograd to track gradients for the core and every factor matrix, so that
# `loss.backward()` fills their `.grad` attribute (read by `SGD` below).
for f in [core] + factors:
    f.attach_grad()

Let's use the simplest possible learning method, SGD: each parameter takes a step of size `lr` against its gradient, updated in place.

In [6]:
def SGD(params, lr):
    """Take one plain gradient-descent step on each parameter, in place.

    Parameters
    ----------
    params : iterable of arrays
        Parameters to update. Each must expose a ``.grad`` attribute holding
        the gradient from the most recent ``backward()`` call.
    lr : float
        Learning rate (step size).

    Returns
    -------
    None
        The arrays in ``params`` are modified in place.
    """
    for weight in params:
        step = lr * weight.grad
        # Slice-assignment writes into the existing array instead of rebinding
        # the name, so the caller's parameter object itself is updated.
        weight[:] = weight - step

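Now we simply iterate through the training loop. At each step we reconstruct the tensor, compute the penalised loss, backpropagate, and update the core and factors with SGD.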

You might have seen such a loop in the Gluon tutorials — if not, check them out:
https://github.com/zackchase/mxnet-the-straight-dope

In [7]:
# Hyper-parameters of the optimisation.
n_iter = 40000   # number of gradient steps to take
lr = 0.0005      # SGD learning rate
penalty = 0.1    # weight of the l2 penalty on the factor matrices

# range(1, n_iter + 1) runs exactly n_iter steps (the old range(1, n_iter)
# stopped one short, so the final report at i == n_iter was never printed).
for i in range(1, n_iter + 1):

    with autograd.record():

        # Reconstruct the full tensor from the current core and factors
        rec = tucker_to_tensor(core, factors)

        # Squared-error (l2) reconstruction loss
        loss = nd.sum((rec - tensor) ** 2)

        # l2 penalty on the factors of the decomposition (the core is not penalised)
        for f in factors:
            loss = loss + penalty * nd.sum(f ** 2)

    loss.backward()
    SGD([core] + factors, lr)

    if i % 1000 == 0:
        # Relative reconstruction error; `rec` was computed before this
        # step's parameter update.
        rec_error = T.norm(rec - tensor, 2) / T.norm(tensor, 2)
        print("Epoch %s, Rec. error: %s" % (i, rec_error))

Epoch 1000,. Rec. error: 0.347279884328
Epoch 2000,. Rec. error: 0.278443665664
Epoch 3000,. Rec. error: 0.247618793528
Epoch 4000,. Rec. error: 0.225997005038
Epoch 5000,. Rec. error: 0.210141680724
Epoch 6000,. Rec. error: 0.199405929816
Epoch 7000,. Rec. error: 0.191911998173
Epoch 8000,. Rec. error: 0.185513944756
Epoch 9000,. Rec. error: 0.179243332725
Epoch 10000,. Rec. error: 0.172699486789
Epoch 11000,. Rec. error: 0.165597723646
Epoch 12000,. Rec. error: 0.157599644149
Epoch 13000,. Rec. error: 0.148224587749
Epoch 14000,. Rec. error: 0.136828436857
Epoch 15000,. Rec. error: 0.122807824452
Epoch 16000,. Rec. error: 0.106519893162
Epoch 17000,. Rec. error: 0.0904362764887
Epoch 18000,. Rec. error: 0.0770750921177
Epoch 19000,. Rec. error: 0.0665911877544
Epoch 20000,. Rec. error: 0.0581581382829
Epoch 21000,. Rec. error: 0.0510929837756
Epoch 22000,. Rec. error: 0.0450363835708
Epoch 23000,. Rec. error: 0.0398456304225
Epoch 24000,. Rec. error: 0.0354543330036
Epoch 25000,. Rec