In [1]:
% matplotlib inline

import torch
import torch.nn as nn
import numpy as np

# Read the preprocessed inputs and targets from disk (X first, then y)
data_path = '../data/processed/'
X, y = (np.load(data_path + fname) for fname in ('X.npy', 'y.npy'))
print("X has shape: {}\ny has shape: {}".format(X.shape, y.shape))

# float32 torch copies of both arrays, used by the training cells
X_torch, y_torch = (torch.from_numpy(arr).float() for arr in (X, y))

# Global optimum of the loss, a.k.a. c from appendix A in the paper:
#   c = 0.5 * y_cov - 0.5 * cross_cov @ cross_cov^T
# NOTE(review): this closed form presumably relies on X being whitened
# upstream -- confirm against the preprocessing step.
# (`@` is the matrix-multiplication operator, available in Python 3.5+.)
m = y_torch.shape[0]  # number of samples
# The 1/m scaling is applied to y^T *before* the matmul, matching the
# left-to-right evaluation of `(1 / m) * y^T @ X`.
cross_cov = ((1 / m) * y_torch.t()) @ X_torch  # (1, 128)
y_cov = ((1 / m) * y_torch.t()) @ y_torch  # (1, 1)
global_opt = 0.5 * y_cov - 0.5 * (cross_cov @ cross_cov.t())
global_opt

X has shape: (2565, 128)
y has shape: (2565, 1)


tensor([[1.0345]])

In [6]:
import sys
sys.path.append('../src/models/')
from linear_nn import three_layer_nn, fro_loss, train, eight_layer_nn

# Hyperparameters for a single eight-layer run:
#   std           -- scale argument handed to the model constructor
#   learning_rate -- gradient-descent step size
#   eps           -- tolerance handed to train()
# NOTE(review): no RNG seed is set in this cell, so if the 'balanced'
# initialization is random the logged losses will differ between runs.
std, learning_rate, eps = 1, 5e-3, 1e-5

model = eight_layer_nn('balanced', std)
loss_fn = fro_loss()

# verbose=True makes train() log one "<iteration> <loss>" line per step.
# NOTE(review): the learning-rate sweep cell calls train() with an extra
# `global_opt` argument before `eps`; verify which signature linear_nn.train
# currently has.
train_iter, loss = train(
    model,
    loss_fn,
    X_torch,
    y_torch,
    learning_rate,
    eps,
    verbose=True,
)

0 61.109352111816406
1 12.920022964477539
2 0.4881380498409271
3 0.4879516661167145
4 0.4877651631832123
5 0.48757848143577576
6 0.48739156126976013
7 0.4872044622898102
8 0.4870170056819916
9 0.4868292510509491
10 0.48664113879203796
11 0.48645269870758057
12 0.4862637519836426
13 0.48607438802719116
14 0.48588451743125916
15 0.48569416999816895
16 0.4855031371116638
17 0.4853115677833557
18 0.48511937260627747
19 0.4849265515804291
20 0.4847329556941986
21 0.4845386743545532
22 0.48434358835220337
23 0.4841477870941162
24 0.48395103216171265
25 0.483753502368927
26 0.48355504870414734
27 0.48335567116737366
28 0.4831553101539612
29 0.4829539656639099
30 0.4827515780925751
31 0.4825480878353119
32 0.48234352469444275
33 0.4821378290653229
34 0.4819309413433075
35 0.4817228615283966
36 0.4815135598182678
37 0.481302946805954
38 0.48109111189842224
39 0.48087790608406067
40 0.48066332936286926
41 0.48044732213020325
42 0.48022985458374023
43 0.4800109565258026
44 0.47979050874710083
45 

430 0.05212731659412384
431 0.051425751298666
432 0.0507325679063797
433 0.05004766955971718
434 0.04937100410461426
435 0.04870247840881348
436 0.04804203659296036
437 0.04738961160182953
438 0.04674512520432472
439 0.046108491718769073
440 0.0454796738922596
441 0.044858574867248535
442 0.044245123863220215
443 0.04363925755023956
444 0.04304090142250061
445 0.04244997724890709
446 0.04186641052365303
447 0.04129015654325485
448 0.04072112962603569
449 0.04015924036502838
450 0.039604444056749344
451 0.03905664384365082
452 0.03851579874753952
453 0.03798180818557739
454 0.037454620003700256
455 0.03693416342139244
456 0.036420371383428574
457 0.035913169384002686
458 0.035412486642599106
459 0.03491826355457306
460 0.034430406987667084
461 0.0339488722383976
462 0.03347358852624893
463 0.03300448879599571
464 0.03254150599241257
465 0.03208456188440323
466 0.03163360059261322
467 0.03118855506181717
468 0.03074936382472515
469 0.03031594678759575
470 0.029888255521655083
471 0.02946

939 2.6446732590557076e-05
940 2.6049863663502038e-05
941 2.5658982849563472e-05
942 2.5273939172620885e-05
943 2.4894663511076942e-05
944 2.452102671668399e-05
945 2.415304879832547e-05
946 2.37906351685524e-05
947 2.343361484236084e-05
948 2.308197144884616e-05
949 2.2735577658750117e-05
950 2.2394422558136284e-05
951 2.2058327886043116e-05
952 2.172730819438584e-05
953 2.1401257981779054e-05
954 2.108011904056184e-05
955 2.076377313642297e-05
956 2.0452176613616757e-05
957 2.0145247617620043e-05
958 1.9842922483803704e-05
959 1.9545179384294897e-05
960 1.9251876437920146e-05
961 1.8962975445901975e-05
962 1.8678403648664244e-05
963 1.8398110114503652e-05
964 1.8122009350918233e-05
965 1.7850079530035146e-05
966 1.7582215150468983e-05
967 1.7318376194452867e-05
968 1.7058493540389463e-05
969 1.6802492609713227e-05
970 1.6550317013752647e-05
971 1.630194492463488e-05
972 1.6057310858741403e-05
973 1.5816356608411297e-05
974 1.5579022146994248e-05
975 1.534521834400948e-05
976 1.511497

In [4]:
import sys
sys.path.append('../src/models/')
from linear_nn import three_layer_nn, fro_loss, train

# Seed torch's global RNG once, before the sweep starts
seed = 521
torch.manual_seed(seed)

# Initialize constants
learning_rates = [1e-4, 1e-3, 1e-2, 1e-1, 1e0]
# Tolerance handed to train().
# NOTE(review): the previous comment here read "Even though they say 1e-5 in
# the paper, this one works", which contradicts the value below -- confirm
# which tolerance was actually intended.
eps = 1e-5
# Scale argument for the model constructor; it does not change during the
# sweep, so it is set once here instead of on every iteration.
std = 1e-3

# Notes from earlier runs at different values of std:
# std = 1e-4, no convergence for any lr
# std = 1e-3, need to debug weird initialization being so close to global_opt
# std = 1e-2, total_iters = 129,005, lr = 1.0
# std = 1e-1, total_iters = 84, lr =  1.0

# Running minimum over the whole sweep. This must be initialized *outside* the
# loop: resetting it to inf on every iteration made the comparison below always
# true, so no minimum was ever tracked across learning rates.
num_iter = np.inf
best_lr = None

for learning_rate in learning_rates: # Find optimal learning_rate
    print("Test lr: {}".format(learning_rate))
    print("_________________")
    # Build a fresh model and loss for each learning rate
    model = three_layer_nn('normal', std, False)
    loss_fn = fro_loss()

    # Train using vanilla gradient descent; train() returns the number of
    # iterations it ran and the final loss value.
    # NOTE(review): another cell in this notebook calls train() without the
    # `global_opt` argument -- verify this call against the current signature.
    train_iter, loss = train(model, loss_fn, X_torch, y_torch, learning_rate, global_opt, eps)
    if train_iter < num_iter: # Keep the smallest iteration count seen so far
        num_iter = train_iter
        best_lr = learning_rate
    print("(Learning Rate, Total Iterations, Loss) = ({}, {}, {}).".format(learning_rate, train_iter, loss))

# Report the learning rate that needed the fewest iterations (first one wins ties)
print("Best (Learning Rate, Total Iterations) = ({}, {}).".format(best_lr, num_iter))

Test lr: 0.0001
_________________
(Learning Rate, Total Iterations, Loss) = (0.0001, 1000001, 0.49999886751174927).
Test lr: 0.001
_________________
(Learning Rate, Total Iterations, Loss) = (0.001, 1000001, 3.8594913376321927e-10).
Test lr: 0.01
_________________
(Learning Rate, Total Iterations, Loss) = (0.01, 1000001, 2.883695669406161e-12).
Test lr: 0.1
_________________
(Learning Rate, Total Iterations, Loss) = (0.1, 1000001, 1.6422330685540894e-14).
Test lr: 1.0
_________________
(Learning Rate, Total Iterations, Loss) = (1.0, 1000001, 0.319337397813797).
