In [1]:
% matplotlib inline

import torch
import torch.nn as nn
import numpy as np

# Load Data
data_path = '../data/processed/'
X = np.load(data_path + 'X.npy')
y = np.load(data_path + 'y.npy')
print("X has shape: {}\ny has shape: {}".format(X.shape, y.shape))
X_torch = torch.from_numpy(X).float()
y_torch = torch.from_numpy(y).float()

X has shape: (2565, 128)
y has shape: (2565, 1)


In [2]:
import sys

# Make the project's model code importable from inside the notebook.
sys.path.append('../src/models/')
from linear_nn import three_layer_nn, fro_loss

# Network and loss under study. The positional arguments are forwarded
# unchanged to three_layer_nn; presumably they select the init scheme and
# its scale -- confirm against linear_nn before relying on that reading.
s = 1e-1
model = three_layer_nn('normal', s, False, p=0.1)
loss_fn = fro_loss()

# Empirical second moments, each scaled by 1 / (number of rows of y).
cross_cov = (y_torch.t() * (1 / y_torch.shape[0])) @ X_torch  # (1, 128)
y_cov = (y_torch.t() * (1 / y_torch.shape[0])) @ y_torch

# Reference loss value at the global optimum (appendix A of the paper).
global_opt = 0.5 * y_cov - 0.5 * (cross_cov @ cross_cov.t())
global_opt

tensor([[1.9375e-11]])

In [3]:
# Plain gradient descent on the end-to-end loss, stopped once the loss is
# within `eps` of the reference optimum `global_opt` computed above.
learning_rate = 1e-1
eps = 1e-5
max_iters = 100000  # hard cap so a plateau above eps cannot spin forever
loss = np.inf
gap = float('inf')  # |global_opt - loss| after the most recent step
t = 0
while gap > eps and t < max_iters:
    W = model() # W_N * W_{N - 1} * ... * W_1

    # Evaluate the loss: loss_fn receives the end-to-end weight product
    # together with the inputs and targets and returns a Tensor.
    loss = loss_fn(W, X_torch, y_torch)
    print(t, loss.item())

    # A NaN/inf loss would make the comparison against eps evaluate to False
    # and look like convergence; report it explicitly and stop instead.
    if not torch.isfinite(loss).all():
        print("Stopping at step {}: loss is not finite (diverged).".format(t))
        break

    # Gradients accumulate across .backward() calls, so clear them each round.
    model.zero_grad()

    # Backward pass: populate .grad on every learnable parameter of the model.
    loss.backward()

    # The parameter update and the convergence bookkeeping must not be
    # recorded by autograd.
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is None:
                continue
            param -= learning_rate * param.grad
        gap = torch.abs(global_opt - loss).item()
    t += 1

if gap > eps and t >= max_iters:
    print("Stopped after {} steps without reaching eps={} (gap={}).".format(t, eps, gap))

0 0.12397629022598267
1 0.08094190061092377
2 0.0576319694519043
3 0.043488673865795135
4 0.03422371670603752
5 0.02779902145266533
6 0.02314072474837303
7 0.019639352336525917
8 0.01692858338356018
9 0.014777785167098045
10 0.013035918585956097
11 0.011600688099861145
12 0.010400698520243168
13 0.009384789504110813
14 0.008515425026416779
15 0.0077644879929721355
16 0.007110513746738434
17 0.006536843255162239
18 0.006030356977134943
19 0.005580582655966282
20 0.0051790690049529076
21 0.004818921443074942
22 0.004494460299611092
23 0.004200970288366079
24 0.0039345077238976955
25 0.003691749181598425
26 0.0034698783420026302
27 0.003266492858529091
28 0.0030795345082879066
29 0.0029072295874357224
30 0.0027480425778776407
31 0.002600638894364238
32 0.0024638555478304625
33 0.0023366697132587433
34 0.002218186156824231
35 0.0021076109260320663
36 0.002004244364798069
37 0.0019074634183198214
38 0.0018167129019275308
39 0.001731496537104249
40 0.001651371014304459
41 0.00157593691255897