In [1]:
# Tensorly library
import tensorly as tl
import numpy as np
from functools import reduce

In [2]:
"""
TENSOR TO TENSOR REGRESSION v1    
  Inputs:
    X - Input Tensor, shape (N, I_1, ..., I_{M1})
    Y - Output Tensor, shape (N, J_1, ..., J_{M2})
    rank - The rank of the solution weights matrix
    lambda_reg - how much l2 regularization to use
    return_factors - whether to include the weight matrix factors along with the weight matrix itself
    eps - the precision of our estimate
    
  Outputs:
    W - weight matrix tensor, shape (I_1, ..., I_{M1}, J_1, ..., J_{M2})
    factors - List of CP factors which forms the CP decomposition of W. Will only be returned if return_factors==True
"""
def tensor_to_tensor_regression_1(X, Y, rank, lambda_reg=1e-1, return_factors=False, eps=1e-6):
    # Check that the N is consistent between X and Y tensors
    if X.shape[0] != Y.shape[0]:
        print("Wrong leading dimensions for tensors X and Y")
    
    # Number of examples (leading dimension in the tensor)
    N = X.shape[0]
    
    # Setup the sizes of the X and Y matrices
    I = reduce(lambda x, y: x * y, X.shape[1:])
    J = reduce(lambda x, y: x * y, Y.shape[1:])
    
    # Initialize the u_r and v_r vectors which we will directly optimize
    u_rs = [np.random.random(I)-0.5 for r in range(rank)]
    v_rs = [np.random.random(J)-0.5 for r in range(rank)]
    
    # Matricize X and Y
    X_mat = tl.unfold(X, mode=0)
    Y_mat = tl.unfold(Y, mode=0)
    
    # Precompute some repeatedly used matrices
    XtXpL = (tl.transpose(X_mat) @ X_mat) + lambda_reg * tl.eye(I)
    XtXpL_inv = np.linalg.inv(XtXpL)
    XtY = tl.transpose(X_mat) @ Y_mat
    YtX = tl.transpose(XtY)
    
    steps = 0
    prev_error = -1
    #errors = [] # TODO: Remove for production (only used for testing)
    # Keep optimizing until the error has converged
    while True:
        # For every rank of the weights matrix
        for r in range(rank):
            
            # Update all the u_rs first
            u_comp = tl.zeros(I)
            # Do some ugly linear algebra
            for r1 in range(rank):
                # Only want cases where r1 != r
                if r1 == r:
                    continue
                # Computing the summation in equations (13)
                u_comp += (v_rs[r1] @ v_rs[r]) * (XtXpL @ u_rs[r1])
            # More ugly linear algebra, finishing off the calculation from (13) 
            u_rs[r] = (XtXpL_inv @ ((XtY @ v_rs[r])-(u_comp / 2))) / (v_rs[r] @ v_rs[r])
            
            # Now update all the v_rs
            v_comp = np.zeros(J)
            # Do some ugly linear algebra
            for r1 in range(rank):
                # Only want cases where r1 != r
                if r1 == r:
                    continue
                v_comp += (u_rs[r] @ XtXpL @ u_rs[r1]) * v_rs[r1]
            v_rs[r] = ((YtX @ u_rs[r]) - (v_comp / 2)) / (u_rs[r] @ XtXpL @ u_rs[r])
        
        # Here comes the normalization step (ensures numerical stability)  
        for r in range(rank):
            u_scale = (tl.norm(v_rs[r], 2) / tl.norm(u_rs[r], 2)) ** 0.5
            v_scale = 1 / u_scale
            u_rs[r] *= u_scale
            v_rs[r] *= v_scale
        
        # Compute the new error, this time ignoring regularization
        W_mat = np.zeros((I, J))
        for r in range(rank):
            # Add each W_r
            W_mat += tl.kron(u_rs[r], v_rs[r]).reshape(I, J)
        # Use MSE to normalize the by the size of the tensor
        error = np.square(Y_mat - X_mat @ W_mat).mean() + lambda_reg * np.square(W_mat).mean()
        #errors.append(error)
        # Determine if we have converged
        if prev_error >= 0 and abs(prev_error - error) < eps:
            print("Converged after", steps, "steps. Final Error:", error)
            break
        else:
            print("Step:", steps, "Error:", error)
        # Reset the previous error
        prev_error = error
        # Next step
        steps += 1
        
    
    # TODO: Add functionality to make the factors
    
    return tl.reshape(W_mat, (X.shape[1:]) + (Y.shape[1:])) 
    

In [3]:
N = 11
X = np.random.random((N, 5, 7))
Y = np.zeros((N, 2, 3))
# Setup Y tensor with some dummy data
for n in range(N):
    x = X[n]
    Y[n] = np.array([[x[0, 0] + x[1, 1] , 2 * x[1, 0] - x[3, 2], (x[4, 5] + 1) ** 2],
                    [-x[1, 6] + 3 * x[1, 5] ,  - x[0, 6] - x[3, 3], (x[2, 5] + 2) ** 2]])
# Now fit it
W = tensor_to_tensor_regression_1(X, Y, 5, lambda_reg=1)
print(W.shape)
# We can verify our answer is correct with the following
Y_pred = tl.tenalg.contract(X, range(1, tl.ndim(X)), W, range(tl.ndim(X) - 1))
print(np.square(Y_pred - Y).mean()) # Should print the same as the final error

Step: 0 Error: 5.061624698862072
Step: 1 Error: 5.335491109894079
Step: 2 Error: 4.4700380983810755
Step: 3 Error: 3.9390258590155547
Step: 4 Error: 3.684024941495695
Step: 5 Error: 3.6061418243456775
Step: 6 Error: 3.6112425132332824
Step: 7 Error: 3.640379583089283
Step: 8 Error: 3.6663803940855915
Step: 9 Error: 3.681947287045
Step: 10 Error: 3.68873035634585
Step: 11 Error: 3.6905187456440935
Step: 12 Error: 3.690308366191057
Step: 13 Error: 3.6897143634262375
Step: 14 Error: 3.6893277998142864
Step: 15 Error: 3.6892168193960995
Step: 16 Error: 3.689275923513273
Step: 17 Error: 3.6893947510374248
Step: 18 Error: 3.6895082610743466
Step: 19 Error: 3.689592757309681
Step: 20 Error: 3.6896478116962337
Step: 21 Error: 3.6896813270527478
Step: 22 Error: 3.689701524404915
Step: 23 Error: 3.689714209027369
Step: 24 Error: 3.689722728484145
Step: 25 Error: 3.689728784867512
Step: 26 Error: 3.689733212747787
Step: 27 Error: 3.689736467066588
Step: 28 Error: 3.6897388542700518
Step: 29 Error

In [46]:
# For timing
import time

start_time = time.time()
W = tensor_to_tensor_regression_1(X, Y, 5, lambda_reg=10)
print(time.time() - start_time)

Step: 0 Error: 4.0214449899317515
Step: 1 Error: 4.615278855711329
Step: 2 Error: 4.330122752733615
Step: 3 Error: 3.9053734534798177
Step: 4 Error: 3.6026137719382403
Step: 5 Error: 3.4503076963740353
Step: 6 Error: 3.399219947618061
Step: 7 Error: 3.397285823498813
Step: 8 Error: 3.410641701689073
Step: 9 Error: 3.423343543054102
Step: 10 Error: 3.4308872772046177
Step: 11 Error: 3.4339501762629157
Step: 12 Error: 3.434492264382796
Step: 13 Error: 3.434083972665766
Step: 14 Error: 3.433562029351891
Step: 15 Error: 3.433216856867045
Step: 16 Error: 3.4330632900403915
Step: 17 Error: 3.4330291269262503
Step: 18 Error: 3.4330451177581387
Step: 19 Error: 3.433070386632795
Step: 20 Error: 3.4330887479150083
Step: 21 Error: 3.433097858556881
Step: 22 Error: 3.433100518950238
Converged after 23 steps. Final Error: 3.4331000359015005
0.09400153160095215


In [4]:
import torchvision
from torchvision import transforms
import torch

In [None]:
transform = transforms.Compose([
    transforms.Scale(64),
    # transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_X = torchvision.datasets.CelebA("", split='train', transform = transform, download=False)



In [8]:
data_loader = torch.utils.data.DataLoader(train_X,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=1)

In [15]:
transform2=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
mnist_X_train = torchvision.datasets.MNIST('../data', train=True, download=True,
                       transform=transform2)
mnist_X_test = torchvision.datasets.MNIST('../data', train=False,
                       transform=transform2)
train_loader = torch.utils.data.DataLoader(mnist_X_train, batch_size=10000)
test_loader = torch.utils.data.DataLoader(mnist_X_test, batch_size=10000)

In [18]:
image_lst = []
label_lst = []
for (image, labels) in train_loader:
    np_image = image.detach().numpy().reshape((10000, 28, 28))
    labels = labels.detach().numpy().reshape((10000, 1))
    image_lst.append(np_image)
    label_lst.append(labels)
X_train_mnist = np.concatenate(image_lst, axis=0)
Y_train_mnist = np.concatenate(label_lst, axis=0)
print(X_train_mnist.shape, Y_train_mnist.shape)

(60000, 28, 28) (60000, 1)


In [62]:
W = tensor_to_tensor_regression_1(X_train_mnist, Y_train_mnist - np.mean(Y_train_mnist), 6, lambda_reg=50)

Step: 0 Error: 10.116550579891335
Step: 1 Error: 8.820110526058821
Step: 2 Error: 8.010290946282508
Step: 3 Error: 7.450563029164129
Step: 4 Error: 7.107466239265297
Step: 5 Error: 6.93553883150915
Step: 6 Error: 6.8761384254764195
Step: 7 Error: 6.876331805173679
Step: 8 Error: 6.898339114857782
Step: 9 Error: 6.9209409892584315
Step: 10 Error: 6.935958768227891
Step: 11 Error: 6.94292353094067
Step: 12 Error: 6.944448696898016
Step: 13 Error: 6.943402120384501
Step: 14 Error: 6.941751098594194
Step: 15 Error: 6.940443254327546
Step: 16 Error: 6.939715975574696
Step: 17 Error: 6.939458041800317
Step: 18 Error: 6.9394656375204455
Step: 19 Error: 6.939570466901196
Step: 20 Error: 6.939675672066649
Step: 21 Error: 6.93974469592882
Step: 22 Error: 6.93977628134768
Step: 23 Error: 6.9397828814601095
Step: 24 Error: 6.939777806099475
Step: 25 Error: 6.939770091931709
Step: 26 Error: 6.939764048596222
Step: 27 Error: 6.939760720061978
Step: 28 Error: 6.939759561690776
Converged after 29 step

In [63]:
image_lst = []
label_lst = []
for (image, labels) in test_loader:
    np_image = image.detach().numpy().reshape((10000, 28, 28))
    labels = labels.detach().numpy().reshape((10000, 1))
    image_lst.append(np_image)
    label_lst.append(labels)
X_test_mnist = np.concatenate(image_lst, axis=0)
Y_test_mnist = np.concatenate(label_lst, axis=0)
Y_test_pred = tl.tenalg.contract(X_test_mnist, range(1, tl.ndim(X_test_mnist)), W, range(tl.ndim(X_test_mnist) - 1)) + np.mean(Y_train_mnist)

In [64]:
total = 0
correct = 0
for i, entry in enumerate(Y_test_mnist):
    #print(entry[0], int(round(Y_test_pred[i][0])))
    if entry[0] == int(round(Y_test_pred[i][0])):
        correct += 1
    total += 1
print(correct / total)

0.1706


In [67]:
lambda_regs = [1e-6, 1e-2, 1e-1, 1, 10, 100, 1000]
ranks = [1, 2, 3, 4, 5, 10, 20]
accuracies = np.zeros((7, 7))
for i, rank in enumerate(ranks):
    print(i)
    for j, reg in enumerate(lambda_regs):
        print(j)
        W = tensor_to_tensor_regression_1(X_train_mnist, Y_train_mnist - np.mean(Y_train_mnist), rank, lambda_reg=reg)
        Y_test_pred = tl.tenalg.contract(X_test_mnist, range(1, tl.ndim(X_test_mnist)), W, range(tl.ndim(X_test_mnist) - 1)) + np.mean(Y_train_mnist)
        total = 0
        correct = 0
        for k, entry in enumerate(Y_test_mnist):
            #print(entry[0], int(round(Y_test_pred[i][0])))
            if entry[0] == int(round(Y_test_pred[k][0])):
                correct += 1
            total += 1
        accuracy = correct / total
        accuracies[i, j] = accuracy
print(accuracies)        

0
0
Step: 0 Error: 3.204548265980794
Converged after 1 steps. Final Error: 3.2045482659807925
1
Step: 0 Error: 3.2140090152484606
Converged after 1 steps. Final Error: 3.2140090152484597
2
Step: 0 Error: 3.228069937738532
Converged after 1 steps. Final Error: 3.2280699377385327
3
Step: 0 Error: 3.273026355416069
Converged after 1 steps. Final Error: 3.2730263554160697
4
Step: 0 Error: 3.358455337326873
Converged after 1 steps. Final Error: 3.358455337326873
5
Step: 0 Error: 3.8212322015244817
Converged after 1 steps. Final Error: 3.8212322015244844
6
Step: 0 Error: 6.236447678574821
Converged after 1 steps. Final Error: 6.236447678574864
1
0
Step: 0 Error: 5.54831449320324
Step: 1 Error: 4.038074326948852
Step: 2 Error: 3.830230671385107
Step: 3 Error: 3.788878742784988
Step: 4 Error: 3.7792036545750016
Step: 5 Error: 3.7768263264634774
Step: 6 Error: 3.7762345849364887
Step: 7 Error: 3.7760868114671484
Step: 8 Error: 3.7760498782153946
Step: 9 Error: 3.776040645535632
Step: 10 Error: 

Step: 15 Error: 5.079987017564885
Step: 16 Error: 5.079985718706634
Step: 17 Error: 5.079987133099814
Step: 18 Error: 5.0799884860283715
Converged after 19 steps. Final Error: 5.079989223534988
2
Step: 0 Error: 7.3232114126223085
Step: 1 Error: 6.09591858470482
Step: 2 Error: 5.472794005445499
Step: 3 Error: 5.2061017942552645
Step: 4 Error: 5.115140193562248
Step: 5 Error: 5.097080803638201
Step: 6 Error: 5.101536259854459
Step: 7 Error: 5.10858069246527
Step: 8 Error: 5.112988585554311
Step: 9 Error: 5.114865829934226
Step: 10 Error: 5.115371936770033
Step: 11 Error: 5.115363977766392
Step: 12 Error: 5.115251463673782
Step: 13 Error: 5.115167034013804
Step: 14 Error: 5.115126054149986
Step: 15 Error: 5.115112418698074
Step: 16 Error: 5.115110582885214
Step: 17 Error: 5.115112036617015
Step: 18 Error: 5.115113532162085
Converged after 19 steps. Final Error: 5.115114370433674
3
Step: 0 Error: 7.229726833250071
Step: 1 Error: 6.02496251300648
Step: 2 Error: 5.566830334574445
Step: 3 Err

Step: 16 Error: 7.171261098836975
Converged after 17 steps. Final Error: 7.171261634779852
6
Step: 0 Error: 25.53835026853555
Step: 1 Error: 18.499800875295836
Step: 2 Error: 15.713464034041202
Step: 3 Error: 14.44997835487049
Step: 4 Error: 13.894837368358353
Step: 5 Error: 13.701457283642206
Step: 6 Error: 13.677130125188816
Step: 7 Error: 13.710429695363214
Step: 8 Error: 13.747402986566641
Step: 9 Error: 13.77077797941909
Step: 10 Error: 13.780872936262838
Step: 11 Error: 13.78305733862987
Step: 12 Error: 13.782044642319388
Step: 13 Error: 13.780486047698346
Step: 14 Error: 13.77937898341882
Step: 15 Error: 13.778840539419198
Step: 16 Error: 13.778681828167219
Step: 17 Error: 13.778697772251478
Step: 18 Error: 13.778759096490592
Step: 19 Error: 13.778809741390702
Step: 20 Error: 13.77883732733398
Step: 21 Error: 13.778847255275235
Converged after 22 steps. Final Error: 13.778848061069372
5
0
Step: 0 Error: 8.294785654962286
Step: 1 Error: 8.203519086935898
Step: 2 Error: 8.30712239

Converged after 57 steps. Final Error: 6.868228929168961
4
Step: 0 Error: 9.274959291714554
Step: 1 Error: 9.345551309819719
Step: 2 Error: 9.151325009738825
Step: 3 Error: 8.773813628025518
Step: 4 Error: 8.330703054816137
Step: 5 Error: 7.895103515551847
Step: 6 Error: 7.516118525141227
Step: 7 Error: 7.227548243222281
Step: 8 Error: 7.042688736537057
Step: 9 Error: 6.953805451458132
Step: 10 Error: 6.9391649125966115
Step: 11 Error: 6.971971947151299
Step: 12 Error: 7.027097205617666
Step: 13 Error: 7.084848496928198
Step: 14 Error: 7.132455657823256
Step: 15 Error: 7.163939325262613
Step: 16 Error: 7.178845060963884
Step: 17 Error: 7.1803392314366175
Step: 18 Error: 7.173213395195542
Step: 19 Error: 7.162254861775508
Step: 20 Error: 7.151224097053896
Step: 21 Error: 7.142447798109333
Step: 22 Error: 7.136881588836961
Step: 23 Error: 7.134442029108469
Step: 24 Error: 7.134427873784763
Step: 25 Error: 7.135905428948533
Step: 26 Error: 7.137991953967624
Step: 27 Error: 7.1400187551048

Step: 70 Error: 7.4120630612377365
Step: 71 Error: 7.414224551006278
Step: 72 Error: 7.416116267097011
Step: 73 Error: 7.417632381713014
Step: 74 Error: 7.4187114449781735
Step: 75 Error: 7.419334623000562
Step: 76 Error: 7.419520898403608
Step: 77 Error: 7.4193200468390375
Step: 78 Error: 7.418804269995268
Step: 79 Error: 7.418059340618942
Step: 80 Error: 7.417176008238249
Step: 81 Error: 7.416242303453012
Step: 82 Error: 7.4153371817875025
Step: 83 Error: 7.414525815103176
Step: 84 Error: 7.413856629766459
Step: 85 Error: 7.413360055798081
Step: 86 Error: 7.413048841915741
Step: 87 Error: 7.412919669944057
Step: 88 Error: 7.412955762123484
Step: 89 Error: 7.413130135229661
Step: 90 Error: 7.413409158507756
Step: 91 Error: 7.413756103982029
Step: 92 Error: 7.4141344127741835
Step: 93 Error: 7.414510474518151
Step: 94 Error: 7.4148557736815555
Step: 95 Error: 7.415148331960181
Step: 96 Error: 7.415373437459575
Step: 97 Error: 7.415523709461655
Step: 98 Error: 7.415598590352602
Step: 99

Step: 43 Error: 7.443660254053616
Step: 44 Error: 7.446112259124708
Step: 45 Error: 7.452251684654845
Step: 46 Error: 7.461083772883114
Step: 47 Error: 7.4715342028628715
Step: 48 Error: 7.482547361491137
Step: 49 Error: 7.493171943918102
Step: 50 Error: 7.502626218088505
Step: 51 Error: 7.510338724083404
Step: 52 Error: 7.5159638432104865
Step: 53 Error: 7.519374816733502
Step: 54 Error: 7.5206389457799805
Step: 55 Error: 7.519980711101123
Step: 56 Error: 7.517738538065564
Step: 57 Error: 7.514320194471255
Step: 58 Error: 7.510160694754297
Step: 59 Error: 7.505685398350979
Step: 60 Error: 7.501279935588031
Step: 61 Error: 7.497267761725401
Step: 62 Error: 7.493895527186809
Step: 63 Error: 7.491326005924255
Step: 64 Error: 7.489637979047343
Step: 65 Error: 7.488832179444667
Step: 66 Error: 7.488842148970208
Step: 67 Error: 7.489548654551118
Step: 68 Error: 7.490796181999159
Step: 69 Error: 7.492410004684159
Step: 70 Error: 7.494212424064593
Step: 71 Error: 7.496036995246822
Step: 72 Er

Step: 20 Error: 7.48060993993384
Step: 21 Error: 7.577682654900879
Step: 22 Error: 7.690694220824984
Step: 23 Error: 7.8099573900685595
Step: 24 Error: 7.926218847116505
Step: 25 Error: 8.031125032058789
Step: 26 Error: 8.117801173703665
Step: 27 Error: 8.181400489782083
Step: 28 Error: 8.21947097248075
Step: 29 Error: 8.232038460370754
Step: 30 Error: 8.221383091355015
Step: 31 Error: 8.191559703608455
Step: 32 Error: 8.147760851035425
Step: 33 Error: 8.095637021962206
Step: 34 Error: 8.040676363955754
Step: 35 Error: 7.987715727131477
Step: 36 Error: 7.940617672422301
Step: 37 Error: 7.9021140400115675
Step: 38 Error: 7.873792017451088
Step: 39 Error: 7.856185734810539
Step: 40 Error: 7.8489342483542215
Step: 41 Error: 7.850972204323821
Step: 42 Error: 7.860728573899956
Step: 43 Error: 7.876318114058145
Step: 44 Error: 7.895717322348614
Step: 45 Error: 7.916920742893306
Step: 46 Error: 7.938074906122111
Step: 47 Error: 7.957587046550596
Step: 48 Error: 7.974205313977002
Step: 49 Erro

Step: 104 Error: 9.6023062005979
Step: 105 Error: 9.602132267790697
Step: 106 Error: 9.601999700984738
Step: 107 Error: 9.601912598069866
Step: 108 Error: 9.601870971285031
Converged after 109 steps. Final Error: 9.601871345512418
6
Step: 0 Error: 28.102784809982268
Step: 1 Error: 24.49635726219894
Step: 2 Error: 23.167702184966423
Step: 3 Error: 22.484455563546938
Step: 4 Error: 21.935070909199773
Step: 5 Error: 21.397069124406528
Step: 6 Error: 20.884890987851264
Step: 7 Error: 20.426119802726845
Step: 8 Error: 20.022774525992148
Step: 9 Error: 19.657314954393833
Step: 10 Error: 19.31042161611723
Step: 11 Error: 18.97343982752841
Step: 12 Error: 18.65118400502478
Step: 13 Error: 18.357542326293746
Step: 14 Error: 18.10836652004302
Step: 15 Error: 17.915477285227
Step: 16 Error: 17.783764705424144
Step: 17 Error: 17.711413705665876
Step: 18 Error: 17.692006185757624
Step: 19 Error: 17.71694935679263
Step: 20 Error: 17.77715994213112
Step: 21 Error: 17.86371164991148
Step: 22 Error: 17