In [1]:
import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from estimator import classical_weights, V1_inspired_weights
import pickle

In [2]:
# Load the raw MNIST training and test splits from ./data/mnist/ as float tensors.
from mnist import MNIST
mndata = MNIST('./data/mnist/')
train, train_labels = (torch.FloatTensor(part) for part in mndata.load_training())
test, test_labels = [torch.FloatTensor(part) for part in mndata.load_testing()]

# Rescale 8-bit pixel intensities into [0, 1].
X_train, X_test = train / 255.0, test / 255.0

# n = number of training rows, d = input dimension (used by the network classes).
n, d = X_train.shape

# nn.CrossEntropyLoss expects integer class indices as targets.
y_train, y_test = (lbls.to(dtype=torch.long) for lbls in (train_labels, test_labels))

In [3]:
class V1_net(nn.Module):
    """One-hidden-layer ReLU classifier whose first layer uses V1-inspired weights.

    Args:
        hidden_size: number of hidden units.
        scale: forwarded to ``V1_inspired_weights``.
        t, l: forwarded to ``V1_inspired_weights``. The defaults (5, 2) are the
            values this class has always used; they are exposed so that callers
            (and result filenames) can state the values that were actually run.

    Raises:
        ValueError: if the generated weight matrix does not match ``fc1.weight``.
    """
    def __init__(self, hidden_size, scale, t=5, l=2):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        weights = torch.FloatTensor(V1_inspired_weights(hidden_size, d, t=t, l=l, scale=scale))
        # Fail early: assigning `.data` would otherwise silently replace the parameter
        # with a tensor of the wrong shape.
        if weights.shape != self.fc1.weight.shape:
            raise ValueError('V1_inspired_weights returned shape %s, expected %s'
                             % (tuple(weights.shape), tuple(self.fc1.weight.shape)))
        # Only the weight is replaced; the bias keeps nn.Linear's default init.
        self.fc1.weight.data = weights
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return unnormalized class logits (last dimension 10) for ``inputs``."""
        x = torch.relu(self.fc1(inputs))
        return self.output(x)
    
class He_net(nn.Module):
    """Baseline one-hidden-layer ReLU classifier with Kaiming (He) normal first-layer init.

    ``scale`` is accepted only so the constructor has the same ``(hidden_size, scale)``
    signature as V1_net and RF_net; it is not used.
    """
    def __init__(self, hidden_size, scale):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return unnormalized class logits (last dimension 10) for ``inputs``."""
        hidden = self.fc1(inputs).relu()
        return self.output(hidden)
    
class RF_net(nn.Module):
    """One-hidden-layer ReLU classifier whose first layer uses ``classical_weights``.

    ``scale`` is forwarded to ``classical_weights``; the bias of the first layer and
    the whole output layer keep nn.Linear's default initialization.
    """
    def __init__(self, hidden_size, scale):
        super().__init__()
        self.fc1 = nn.Linear(d, hidden_size)
        init_w = classical_weights(hidden_size, d, scale=scale)
        self.fc1.weight.data = torch.FloatTensor(init_w)
        self.output = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        """Return unnormalized class logits (last dimension 10) for ``inputs``."""
        return self.output(torch.relu(self.fc1(inputs)))

In [4]:
def predict(model, X):
    """Return the predicted class index for each row of ``X``.

    The prediction is the argmax over dim 1 of ``model(X)``. This runs under
    ``torch.no_grad()`` because it is a pure evaluation pass. The training loop calls
    it (via ``error``) on the full train and test sets every epoch, so building an
    autograd graph there only wastes time and memory.
    """
    with torch.no_grad():
        return model(X).argmax(dim=1)

def error(model, X, y):
    """Misclassification rate (1 - accuracy) of ``model`` on inputs ``X`` with labels ``y``.

    Returns a 0-dim tensor.
    """
    n_correct = (predict(model, X) == y).sum()
    accuracy = 1.0 * n_correct / len(y)
    return 1 - accuracy

In [5]:
# Build one 32-unit network per initialization scheme and compare the mean L2 norm
# of the rows of each first-layer weight matrix.
V1_model = V1_net(32, 1/32)
RF_model = RF_net(32, 1/32)
He_model = He_net(32, 1/32)
for name, net in [('V1', V1_model), ('He', He_model), ('RF', RF_model)]:
    print(name, net.fc1.weight.data.norm(dim=1).mean())
w = RF_model.fc1.weight.data.norm(dim=1).mean()

V1 tensor(4.6528)
He tensor(1.4180)
RF tensor(4.9700)


In [None]:
# Sweep hidden width `h` and learning rate `lr`. For each initialization scheme in
# `models`, run `n_trials` independent full-batch gradient-descent runs of `n_epochs`
# steps. Record the per-epoch training loss, train error and test error, and pickle
# one results file per (h, lr) pair.
scale = 1/32
models = {'V1': V1_net, 'RF': RF_net, 'He': He_net}
n_epochs = 1001
n_trials = 3
base_seed = 0  # trial i of every model / h / lr uses seed base_seed + i
# NOTE(review): `t` and `l` only label the output filename. V1_net builds its weights
# with t=5, l=2, so l=3 here does not describe what was run. It is kept as-is because
# the plotting cell loads files named with l=3; confirm the intended value.
t, l = 5, 3
loss_func = nn.CrossEntropyLoss()

for h in [50, 100, 400, 1000]:
    for lr in [1e-3, 1e-2, 1e-1, 1e0]:
        train_err = {m: np.zeros((n_trials, n_epochs)) for m in models}
        test_err = {m: np.zeros((n_trials, n_epochs)) for m in models}
        loss_list = {m: np.zeros((n_trials, n_epochs)) for m in models}
        for m, network in models.items():
            for i in range(n_trials):
                # Seed both RNGs. nn.Linear draws from torch; the estimator weight
                # generators may draw from numpy (their implementation is not shown here).
                torch.manual_seed(base_seed + i)
                np.random.seed(base_seed + i)
                model = network(h, scale)
                optim = torch.optim.SGD(model.parameters(), lr=lr)
                for j in range(n_epochs):
                    # One full-batch gradient step on the whole training set.
                    optim.zero_grad()
                    loss = loss_func(model(X_train), y_train)
                    loss.backward()
                    optim.step()

                    # `loss` is measured before this step's update, while the errors
                    # below are measured after it. Evaluation needs no autograd graph.
                    loss_val = loss.item()
                    with torch.no_grad():
                        train_err[m][i, j] = error(model, X_train, y_train)
                        test_err[m][i, j] = error(model, X_test, y_test)
                    loss_list[m][i, j] = loss_val

                    if j % 100 == 0:
                        print('Trial %d, Epoch: %d, %s model, h=%d, lr=%0.5f, Loss=%0.5f, test err=%0.3f'
                              % (i, j, m, h, lr, loss_val, test_err[m][i, j]))

        results = {'test_err': test_err, 'train_err': train_err, 'loss': loss_list}
        with open('results/initialize_mnist/clf_t=%0.2f_l=%0.2f_h=%d_lr=%0.4f_scale=%0.6f.pickle'
                  % (t, l, h, lr, scale), 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

Trial 0, Epoch: 0, V1 model, h=50, lr=0.00100, Loss=4.32108, test err=0.952
Trial 0, Epoch: 100, V1 model, h=50, lr=0.00100, Loss=2.31125, test err=0.790
Trial 0, Epoch: 200, V1 model, h=50, lr=0.00100, Loss=1.96239, test err=0.676
Trial 0, Epoch: 300, V1 model, h=50, lr=0.00100, Loss=1.74155, test err=0.601
Trial 0, Epoch: 400, V1 model, h=50, lr=0.00100, Loss=1.58059, test err=0.518
Trial 0, Epoch: 500, V1 model, h=50, lr=0.00100, Loss=1.45714, test err=0.440
Trial 0, Epoch: 600, V1 model, h=50, lr=0.00100, Loss=1.35917, test err=0.390
Trial 0, Epoch: 700, V1 model, h=50, lr=0.00100, Loss=1.27945, test err=0.352
Trial 0, Epoch: 800, V1 model, h=50, lr=0.00100, Loss=1.21329, test err=0.326
Trial 0, Epoch: 900, V1 model, h=50, lr=0.00100, Loss=1.15750, test err=0.306
Trial 0, Epoch: 1000, V1 model, h=50, lr=0.00100, Loss=1.10983, test err=0.291
Trial 1, Epoch: 0, V1 model, h=50, lr=0.00100, Loss=4.20815, test err=0.915
Trial 1, Epoch: 100, V1 model, h=50, lr=0.00100, Loss=2.18770, test

Trial 0, Epoch: 700, V1 model, h=50, lr=0.01000, Loss=0.54484, test err=0.151
Trial 0, Epoch: 800, V1 model, h=50, lr=0.01000, Loss=0.52497, test err=0.146
Trial 0, Epoch: 900, V1 model, h=50, lr=0.01000, Loss=0.50824, test err=0.142
Trial 0, Epoch: 1000, V1 model, h=50, lr=0.01000, Loss=0.49386, test err=0.139
Trial 1, Epoch: 0, V1 model, h=50, lr=0.01000, Loss=5.11700, test err=0.852
Trial 1, Epoch: 100, V1 model, h=50, lr=0.01000, Loss=0.86228, test err=0.244
Trial 1, Epoch: 200, V1 model, h=50, lr=0.01000, Loss=0.70122, test err=0.195
Trial 1, Epoch: 300, V1 model, h=50, lr=0.01000, Loss=0.62731, test err=0.173
Trial 1, Epoch: 400, V1 model, h=50, lr=0.01000, Loss=0.58172, test err=0.161
Trial 1, Epoch: 500, V1 model, h=50, lr=0.01000, Loss=0.54949, test err=0.152
Trial 1, Epoch: 600, V1 model, h=50, lr=0.01000, Loss=0.52483, test err=0.146
Trial 1, Epoch: 700, V1 model, h=50, lr=0.01000, Loss=0.50499, test err=0.140
Trial 1, Epoch: 800, V1 model, h=50, lr=0.01000, Loss=0.48845, te

Trial 1, Epoch: 300, V1 model, h=50, lr=0.10000, Loss=0.34118, test err=0.100
Trial 1, Epoch: 400, V1 model, h=50, lr=0.10000, Loss=0.31273, test err=0.091
Trial 1, Epoch: 500, V1 model, h=50, lr=0.10000, Loss=0.29207, test err=0.084
Trial 1, Epoch: 600, V1 model, h=50, lr=0.10000, Loss=0.27585, test err=0.080
Trial 1, Epoch: 700, V1 model, h=50, lr=0.10000, Loss=0.26247, test err=0.078
Trial 1, Epoch: 800, V1 model, h=50, lr=0.10000, Loss=0.25106, test err=0.075
Trial 1, Epoch: 900, V1 model, h=50, lr=0.10000, Loss=0.24114, test err=0.072
Trial 1, Epoch: 1000, V1 model, h=50, lr=0.10000, Loss=0.23232, test err=0.070
Trial 2, Epoch: 0, V1 model, h=50, lr=0.10000, Loss=4.14497, test err=0.880
Trial 2, Epoch: 100, V1 model, h=50, lr=0.10000, Loss=0.47007, test err=0.127
Trial 2, Epoch: 200, V1 model, h=50, lr=0.10000, Loss=0.38949, test err=0.105
Trial 2, Epoch: 300, V1 model, h=50, lr=0.10000, Loss=0.34573, test err=0.092
Trial 2, Epoch: 400, V1 model, h=50, lr=0.10000, Loss=0.31644, te

Trial 1, Epoch: 1000, V1 model, h=50, lr=1.00000, Loss=0.07424, test err=0.030
Trial 2, Epoch: 0, V1 model, h=50, lr=1.00000, Loss=4.56824, test err=0.711
Trial 2, Epoch: 100, V1 model, h=50, lr=1.00000, Loss=0.28874, test err=0.081
Trial 2, Epoch: 200, V1 model, h=50, lr=1.00000, Loss=0.20541, test err=0.059
Trial 2, Epoch: 300, V1 model, h=50, lr=1.00000, Loss=0.22519, test err=0.065
Trial 2, Epoch: 400, V1 model, h=50, lr=1.00000, Loss=0.18721, test err=0.055
Trial 2, Epoch: 500, V1 model, h=50, lr=1.00000, Loss=0.14949, test err=0.047
Trial 2, Epoch: 600, V1 model, h=50, lr=1.00000, Loss=0.13420, test err=0.045
Trial 2, Epoch: 700, V1 model, h=50, lr=1.00000, Loss=0.12321, test err=0.042
Trial 2, Epoch: 800, V1 model, h=50, lr=1.00000, Loss=0.12185, test err=0.041
Trial 2, Epoch: 900, V1 model, h=50, lr=1.00000, Loss=0.10873, test err=0.039
Trial 2, Epoch: 1000, V1 model, h=50, lr=1.00000, Loss=0.10184, test err=0.037
Trial 0, Epoch: 0, RF model, h=50, lr=1.00000, Loss=2.50317, tes

Trial 2, Epoch: 500, V1 model, h=100, lr=0.00100, Loss=0.91342, test err=0.251
Trial 2, Epoch: 600, V1 model, h=100, lr=0.00100, Loss=0.85906, test err=0.238
Trial 2, Epoch: 700, V1 model, h=100, lr=0.00100, Loss=0.81698, test err=0.226
Trial 2, Epoch: 800, V1 model, h=100, lr=0.00100, Loss=0.78302, test err=0.216
Trial 2, Epoch: 900, V1 model, h=100, lr=0.00100, Loss=0.75479, test err=0.207
Trial 2, Epoch: 1000, V1 model, h=100, lr=0.00100, Loss=0.73077, test err=0.202
Trial 0, Epoch: 0, RF model, h=100, lr=0.00100, Loss=2.49620, test err=0.883
Trial 0, Epoch: 100, RF model, h=100, lr=0.00100, Loss=2.26239, test err=0.846
Trial 0, Epoch: 200, RF model, h=100, lr=0.00100, Loss=2.13853, test err=0.779
Trial 0, Epoch: 300, RF model, h=100, lr=0.00100, Loss=2.03722, test err=0.683
Trial 0, Epoch: 400, RF model, h=100, lr=0.00100, Loss=1.94563, test err=0.595
Trial 0, Epoch: 500, RF model, h=100, lr=0.00100, Loss=1.86101, test err=0.531
Trial 0, Epoch: 600, RF model, h=100, lr=0.00100, Los

Trial 2, Epoch: 1000, V1 model, h=100, lr=0.01000, Loss=0.38064, test err=0.101
Trial 0, Epoch: 0, RF model, h=100, lr=0.01000, Loss=2.57041, test err=0.902
Trial 0, Epoch: 100, RF model, h=100, lr=0.01000, Loss=1.61136, test err=0.420
Trial 0, Epoch: 200, RF model, h=100, lr=0.01000, Loss=1.17900, test err=0.287
Trial 0, Epoch: 300, RF model, h=100, lr=0.01000, Loss=0.94648, test err=0.234
Trial 0, Epoch: 400, RF model, h=100, lr=0.01000, Loss=0.80916, test err=0.204
Trial 0, Epoch: 500, RF model, h=100, lr=0.01000, Loss=0.72004, test err=0.185
Trial 0, Epoch: 600, RF model, h=100, lr=0.01000, Loss=0.65774, test err=0.173
Trial 0, Epoch: 700, RF model, h=100, lr=0.01000, Loss=0.61172, test err=0.160
Trial 0, Epoch: 800, RF model, h=100, lr=0.01000, Loss=0.57623, test err=0.151
Trial 0, Epoch: 900, RF model, h=100, lr=0.01000, Loss=0.54796, test err=0.145
Trial 0, Epoch: 1000, RF model, h=100, lr=0.01000, Loss=0.52485, test err=0.138
Trial 1, Epoch: 0, RF model, h=100, lr=0.01000, Loss

Trial 0, Epoch: 400, RF model, h=100, lr=0.10000, Loss=0.32873, test err=0.094
Trial 0, Epoch: 500, RF model, h=100, lr=0.10000, Loss=0.30792, test err=0.088
Trial 0, Epoch: 600, RF model, h=100, lr=0.10000, Loss=0.29142, test err=0.083
Trial 0, Epoch: 700, RF model, h=100, lr=0.10000, Loss=0.27763, test err=0.078
Trial 0, Epoch: 800, RF model, h=100, lr=0.10000, Loss=0.26581, test err=0.075
Trial 0, Epoch: 900, RF model, h=100, lr=0.10000, Loss=0.25547, test err=0.072
Trial 0, Epoch: 1000, RF model, h=100, lr=0.10000, Loss=0.24624, test err=0.070
Trial 1, Epoch: 0, RF model, h=100, lr=0.10000, Loss=2.61774, test err=0.889
Trial 1, Epoch: 100, RF model, h=100, lr=0.10000, Loss=0.54631, test err=0.145
Trial 1, Epoch: 200, RF model, h=100, lr=0.10000, Loss=0.42652, test err=0.116
Trial 1, Epoch: 300, RF model, h=100, lr=0.10000, Loss=0.37651, test err=0.103
Trial 1, Epoch: 400, RF model, h=100, lr=0.10000, Loss=0.34595, test err=0.095
Trial 1, Epoch: 500, RF model, h=100, lr=0.10000, Los

Trial 0, Epoch: 900, RF model, h=100, lr=1.00000, Loss=0.06742, test err=0.031
Trial 0, Epoch: 1000, RF model, h=100, lr=1.00000, Loss=0.06150, test err=0.030
Trial 1, Epoch: 0, RF model, h=100, lr=1.00000, Loss=2.43638, test err=0.863
Trial 1, Epoch: 100, RF model, h=100, lr=1.00000, Loss=0.24226, test err=0.071
Trial 1, Epoch: 200, RF model, h=100, lr=1.00000, Loss=0.17463, test err=0.054
Trial 1, Epoch: 300, RF model, h=100, lr=1.00000, Loss=0.14117, test err=0.047
Trial 1, Epoch: 400, RF model, h=100, lr=1.00000, Loss=0.11965, test err=0.041
Trial 1, Epoch: 500, RF model, h=100, lr=1.00000, Loss=0.10404, test err=0.039
Trial 1, Epoch: 600, RF model, h=100, lr=1.00000, Loss=0.09207, test err=0.036
Trial 1, Epoch: 700, RF model, h=100, lr=1.00000, Loss=0.08245, test err=0.034
Trial 1, Epoch: 800, RF model, h=100, lr=1.00000, Loss=0.07442, test err=0.032
Trial 1, Epoch: 900, RF model, h=100, lr=1.00000, Loss=0.06765, test err=0.031
Trial 1, Epoch: 1000, RF model, h=100, lr=1.00000, Lo

Trial 1, Epoch: 300, RF model, h=400, lr=0.00100, Loss=1.46580, test err=0.333
Trial 1, Epoch: 400, RF model, h=400, lr=0.00100, Loss=1.30912, test err=0.286
Trial 1, Epoch: 500, RF model, h=400, lr=0.00100, Loss=1.18994, test err=0.254
Trial 1, Epoch: 600, RF model, h=400, lr=0.00100, Loss=1.09683, test err=0.232
Trial 1, Epoch: 700, RF model, h=400, lr=0.00100, Loss=1.02228, test err=0.218
Trial 1, Epoch: 800, RF model, h=400, lr=0.00100, Loss=0.96130, test err=0.205
Trial 1, Epoch: 900, RF model, h=400, lr=0.00100, Loss=0.91050, test err=0.195
Trial 1, Epoch: 1000, RF model, h=400, lr=0.00100, Loss=0.86747, test err=0.187
Trial 2, Epoch: 0, RF model, h=400, lr=0.00100, Loss=2.45025, test err=0.866
Trial 2, Epoch: 100, RF model, h=400, lr=0.00100, Loss=1.98480, test err=0.631
Trial 2, Epoch: 200, RF model, h=400, lr=0.00100, Loss=1.69470, test err=0.440
Trial 2, Epoch: 300, RF model, h=400, lr=0.00100, Loss=1.48174, test err=0.337
Trial 2, Epoch: 400, RF model, h=400, lr=0.00100, Los

Trial 1, Epoch: 800, RF model, h=400, lr=0.01000, Loss=0.41227, test err=0.108
Trial 1, Epoch: 900, RF model, h=400, lr=0.01000, Loss=0.39848, test err=0.104
Trial 1, Epoch: 1000, RF model, h=400, lr=0.01000, Loss=0.38675, test err=0.102
Trial 2, Epoch: 0, RF model, h=400, lr=0.01000, Loss=2.48020, test err=0.888
Trial 2, Epoch: 100, RF model, h=400, lr=0.01000, Loss=0.90408, test err=0.193
Trial 2, Epoch: 200, RF model, h=400, lr=0.01000, Loss=0.66904, test err=0.152
Trial 2, Epoch: 300, RF model, h=400, lr=0.01000, Loss=0.57042, test err=0.134
Trial 2, Epoch: 400, RF model, h=400, lr=0.01000, Loss=0.51371, test err=0.123
Trial 2, Epoch: 500, RF model, h=400, lr=0.01000, Loss=0.47579, test err=0.116
Trial 2, Epoch: 600, RF model, h=400, lr=0.01000, Loss=0.44809, test err=0.111
Trial 2, Epoch: 700, RF model, h=400, lr=0.01000, Loss=0.42666, test err=0.106
Trial 2, Epoch: 800, RF model, h=400, lr=0.01000, Loss=0.40940, test err=0.103
Trial 2, Epoch: 900, RF model, h=400, lr=0.01000, Los

Trial 2, Epoch: 200, RF model, h=400, lr=0.10000, Loss=0.32281, test err=0.086
Trial 2, Epoch: 300, RF model, h=400, lr=0.10000, Loss=0.28943, test err=0.079
Trial 2, Epoch: 400, RF model, h=400, lr=0.10000, Loss=0.26673, test err=0.074
Trial 2, Epoch: 500, RF model, h=400, lr=0.10000, Loss=0.24934, test err=0.069
Trial 2, Epoch: 600, RF model, h=400, lr=0.10000, Loss=0.23511, test err=0.067
Trial 2, Epoch: 700, RF model, h=400, lr=0.10000, Loss=0.22300, test err=0.064
Trial 2, Epoch: 800, RF model, h=400, lr=0.10000, Loss=0.21244, test err=0.061
Trial 2, Epoch: 900, RF model, h=400, lr=0.10000, Loss=0.20307, test err=0.058
Trial 2, Epoch: 1000, RF model, h=400, lr=0.10000, Loss=0.19467, test err=0.057
Trial 0, Epoch: 0, He model, h=400, lr=0.10000, Loss=2.34754, test err=0.859
Trial 0, Epoch: 100, He model, h=400, lr=0.10000, Loss=0.51765, test err=0.120
Trial 0, Epoch: 200, He model, h=400, lr=0.10000, Loss=0.38905, test err=0.100
Trial 0, Epoch: 300, He model, h=400, lr=0.10000, Los

Trial 2, Epoch: 700, RF model, h=400, lr=1.00000, Loss=0.06010, test err=0.032
Trial 2, Epoch: 800, RF model, h=400, lr=1.00000, Loss=0.05276, test err=0.031
Trial 2, Epoch: 900, RF model, h=400, lr=1.00000, Loss=0.04665, test err=0.029
Trial 2, Epoch: 1000, RF model, h=400, lr=1.00000, Loss=0.04149, test err=0.028
Trial 0, Epoch: 0, He model, h=400, lr=1.00000, Loss=2.32012, test err=0.574
Trial 0, Epoch: 100, He model, h=400, lr=1.00000, Loss=0.19844, test err=0.056
Trial 0, Epoch: 200, He model, h=400, lr=1.00000, Loss=0.13706, test err=0.041
Trial 0, Epoch: 300, He model, h=400, lr=1.00000, Loss=0.10592, test err=0.034
Trial 0, Epoch: 400, He model, h=400, lr=1.00000, Loss=0.08608, test err=0.030
Trial 0, Epoch: 500, He model, h=400, lr=1.00000, Loss=0.07204, test err=0.028
Trial 0, Epoch: 600, He model, h=400, lr=1.00000, Loss=0.06146, test err=0.026
Trial 0, Epoch: 700, He model, h=400, lr=1.00000, Loss=0.05319, test err=0.024
Trial 0, Epoch: 800, He model, h=400, lr=1.00000, Los

## Plot results

In [21]:
import pickle
import numpy as np
import matplotlib.pyplot as plt

t, l, n_epochs = 5, 3, 1001
models = ['V1', 'RF', 'He']
# Legend text for each initialization scheme.
model_labels = {'V1': 'V1-inspired', 'RF': 'classical', 'He': 'He'}


def _mean_std(runs_by_model):
    """Per-model mean and std over trials (axis 0) of the stored (n_trials, n_epochs) arrays."""
    avg = {m: np.mean(runs_by_model[m], axis=0) for m in models}
    std = {m: np.std(runs_by_model[m], axis=0) for m in models}
    return avg, std


def _plot_panel(ax, epochs, avg, std, title, ylabel):
    """Draw each model's mean curve and a shaded +/-1 std band on ``ax`` with a log y-axis."""
    ax.set_title(title, fontsize=16)
    for m in models:
        ax.plot(epochs, avg[m], label=model_labels[m], lw=3)
    for m in models:
        ax.fill_between(epochs, avg[m] - std[m], avg[m] + std[m], alpha=0.2)
    ax.set_xlabel('Epoch', fontsize=20)
    ax.set_ylabel(ylabel, fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=14, width=2, length=6)
    # No explicit linear ticks: ticks starting at 0 do not apply to a log axis.
    ax.set_yscale('log')
    ax.legend(fontsize=18)


for scale in [1/8, 1/16, 1/32]:
    for h in [50, 100, 400, 1000]:
        for lr in [1e-3, 1e-2, 1e-1, 1e0]:
            # These pickles are written by the training sweep. pickle.load can execute
            # arbitrary code, so only point this at trusted files.
            with open('results/initialize_mnist/clf_t=%0.2f_l=%0.2f_h=%d_lr=%0.4f_scale=%0.6f.pickle'
                      % (t, l, h, lr, scale), 'rb') as handle:
                sims = pickle.load(handle)

            # Each std band is computed from its own metric.
            avg_train_err, std_train_err = _mean_std(sims['train_err'])
            avg_test_err, std_test_err = _mean_std(sims['test_err'])
            avg_loss_list, std_loss_list = _mean_std(sims['loss'])

            epochs = np.arange(n_epochs)
            fig = plt.figure(figsize=(12, 5))
            fig.suptitle(r'Shallow FFW FC net w/ GD. h=%d, scale=%0.5f, lr=%0.4f, ' % (h, scale, lr),
                         fontsize=16)
            _plot_panel(fig.add_subplot(131), epochs, avg_loss_list, std_loss_list,
                        'Network loss', 'Training loss')
            _plot_panel(fig.add_subplot(132), epochs, avg_train_err, std_train_err,
                        'Train error', 'Training error')
            _plot_panel(fig.add_subplot(133), epochs, avg_test_err, std_test_err,
                        'Test error', 'Test error')
            fig.tight_layout()
            fig.subplots_adjust(top=0.8)

            print(scale, h, lr, 'Test err, V1: %0.4f, RF: %0.4f, He: %0.4f' % (avg_test_err['V1'][-1],
                                                                 avg_test_err['RF'][-1],
                                                                 avg_test_err['He'][-1]))

            fig.savefig('results/initialize_mnist/figures_full_training/init_t=%0.2f_l=%0.2f_h=%d_lr=%0.4f_scale=%0.6f.png'
                        % (t, l, h, lr, scale))
            # Close each figure explicitly; this loop creates 48 of them.
            plt.close(fig)

            


0.125 50 0.001 Test err, V1: 0.2058, RF: 0.4096, He: 0.5403
0.125 50 0.01 Test err, V1: 0.1514, RF: 0.1791, He: 0.1404
0.125 50 0.1 Test err, V1: 0.0779, RF: 0.0930, He: 0.0717
0.125 50 1.0 Test err, V1: 0.0484, RF: 0.0441, He: 0.0320
0.125 100 0.001 Test err, V1: 0.1542, RF: 0.2885, He: 0.4677
0.125 100 0.01 Test err, V1: 0.0881, RF: 0.1394, He: 0.1355
0.125 100 0.1 Test err, V1: 0.0559, RF: 0.0822, He: 0.0676
0.125 100 1.0 Test err, V1: 0.0387, RF: 0.0422, He: 0.0265
0.125 400 0.001 Test err, V1: 0.0815, RF: 0.1451, He: 0.3039
0.125 400 0.01 Test err, V1: 0.0858, RF: 0.0815, He: 0.1181
0.125 400 0.1 Test err, V1: 0.0323, RF: 0.0544, He: 0.0632
0.125 400 1.0 Test err, V1: 0.0295, RF: 0.0374, He: 0.0214
0.125 1000 0.001 Test err, V1: 0.0620, RF: 0.1015, He: 0.2261
0.125 1000 0.01 Test err, V1: 0.0498, RF: 0.0619, He: 0.1092
0.125 1000 0.1 Test err, V1: 0.0276, RF: 0.0379, He: 0.0590
0.125 1000 1.0 Test err, V1: 0.0291, RF: 0.0370, He: 0.0199
0.0625 50 0.001 Test err, V1: 0.2481, RF: 0.

In [7]:
# Show the last (h, lr) pair left over from the plotting loop.
(h, lr)

(1000, 1.0)