In [1]:
import numpy as np

# Import MXNet
from mxnet import nd, autograd
import mxnet as mx

# Import TensorLy
import tensorly as tl
from tensorly.tucker_tensor import tucker_to_tensor
from tensorly.random import check_random_state

  import OpenSSL.SSL
Using numpy backend.


Set the backend to MXNet

In [2]:
# Switch TensorLy from its default numpy backend to MXNet so that the tensors
# created below can have gradients attached to them.
tl.set_backend("mxnet")

Using mxnet backend.


# Tucker decomposition using SGD and autograd

We just fix the random seed for reproducibility

In [3]:
# Fixed integer seed so that every run draws the same random numbers.
random_state = 1234
# rng is the generator object whose random_sample() is used in the following
# cells to draw the target tensor, the core and the factor matrices.
rng = check_random_state(random_state)

Define a random tensor

In [4]:
# Size of the tensor along each of its three modes.
shape = [5, 5, 5]

# Draw the entries from the seeded generator and wrap them as a backend tensor;
# this is the tensor the decomposition will try to reproduce.
tensor = tl.tensor(rng.random_sample(size=shape))

Initialise a random Tucker decomposition of that tensor

In [5]:
# Tucker rank along each mode (one entry per mode of `tensor`).
ranks = [5, 5, 5]

# Random starting point: the core is drawn first, then one factor matrix per
# mode, each of shape (mode size, mode rank).
core = tl.tensor(rng.random_sample(ranks))
factors = [
    tl.tensor(rng.random_sample((mode_size, mode_rank)))
    for mode_size, mode_rank in zip(tensor.shape, ranks)
]

This is where the magic happens: we can attach gradients to the tensors!

In [6]:
# Every trainable piece of the decomposition (the core, then each factor)
# gets a gradient buffer attached so that backward() can fill it in.
for param in [core] + factors:
    param.attach_grad()

Let's use the simplest possible learning method: SGD

In [7]:
def SGD(params, lr):
    """Apply one plain gradient-descent step, in place, to every parameter.

    Parameters
    ----------
    params : iterable
        Arrays to update. Each one must expose its gradient as ``.grad`` and
        support slice assignment, because the new values are written back
        into the existing array rather than rebinding the name.
    lr : float
        Learning rate, i.e. the factor applied to each gradient.

    Returns
    -------
    None
        The parameters are modified in place.
    """
    for weight in params:
        step = lr * weight.grad
        # Slice assignment keeps the same array object alive, so anything
        # attached to it (such as its gradient) stays attached.
        weight[:] = weight - step

Now we just iterate through the training loop and backpropagate... 

You might have seen such a loop in the Gluon tutorials -- if not, check them out!
https://github.com/zackchase/mxnet-the-straight-dope

In [8]:
# Optimisation settings: number of SGD updates, step size, and the weight of
# the l2 penalty applied to the factor matrices.
n_iter = 40000
lr = 0.0005
penalty = 0.1

# Count from 1 to n_iter inclusive: exactly n_iter updates are performed and
# the progress report at i == n_iter is not skipped.
for i in range(1, n_iter + 1):

    # Operations executed inside this scope are recorded so that
    # loss.backward() can compute gradients for the core and the factors.
    with autograd.record():

        # Reconstruct the tensor from the decomposed form
        rec = tucker_to_tensor(core, factors)

        # l2 loss
        loss = nd.sum((rec - tensor)**2)

        # l2 penalty on the factors of the decomposition
        # (the core itself is not penalised)
        for f in factors:
            loss = loss + penalty * nd.sum(f**2)

    # Fill in .grad on every parameter, then take one in-place descent step.
    loss.backward()
    SGD([core] + factors, lr)

    if i % 1000 == 0:
        # Relative reconstruction error. `rec` was computed before this
        # iteration's update, so the figure reflects the pre-step parameters.
        rec_error = tl.norm(rec - tensor, 2)/tl.norm(tensor, 2)
        print("Epoch %s,. Rec. error: %s" % (i, rec_error))

Epoch 1000,. Rec. error: 0.347279879522
Epoch 2000,. Rec. error: 0.278443677975
Epoch 3000,. Rec. error: 0.247618797548
Epoch 4000,. Rec. error: 0.22599700906
Epoch 5000,. Rec. error: 0.210141688389
Epoch 6000,. Rec. error: 0.199405938799
Epoch 7000,. Rec. error: 0.191912010837
Epoch 8000,. Rec. error: 0.185513961026
Epoch 9000,. Rec. error: 0.179243347001
Epoch 10000,. Rec. error: 0.172699493971
Epoch 11000,. Rec. error: 0.16559772863
Epoch 12000,. Rec. error: 0.157599658769
Epoch 13000,. Rec. error: 0.148224595152
Epoch 14000,. Rec. error: 0.136828442124
Epoch 15000,. Rec. error: 0.122807826415
Epoch 16000,. Rec. error: 0.106519898169
Epoch 17000,. Rec. error: 0.0904362832508
Epoch 18000,. Rec. error: 0.0770750992197
Epoch 19000,. Rec. error: 0.0665911924915
Epoch 20000,. Rec. error: 0.0581581413731
Epoch 21000,. Rec. error: 0.0510929855842
Epoch 22000,. Rec. error: 0.0450363862519
Epoch 23000,. Rec. error: 0.0398456307076
Epoch 24000,. Rec. error: 0.0354543345719
Epoch 25000,. Rec. 