In [17]:
import os
import pickle

import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split

from estimator import classical_weights, V1_inspired_weights
from data_fns import load_mnist

In [9]:
from mnist import MNIST
train, train_labels, test, test_labels = load_mnist()
n, d = train.shape  # d (input dimension) is read as a global by the network classes

# Few-shot regime: keep a stratified sample of `num_train` training images.
# random_state=42 makes the chosen subset deterministic; the remainder is discarded.
num_train = 50
X_train, _, y_train, _ = train_test_split(
    train,
    train_labels,
    train_size=num_train,
    stratify=train_labels,
    random_state=42,
)

# Tensors for training: float32 inputs and int64 labels (what CrossEntropyLoss expects).
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()
X_test = torch.from_numpy(test).float()
y_test = torch.from_numpy(test_labels).long()

In [10]:
def _load_weight(layer, weight):
    """Copy `weight` into `layer.weight` in place, converted to the layer's dtype.

    Raises ValueError unless `weight` has exactly the layer's
    (out_features, in_features) shape. Assigning `.weight.data` directly
    would accept any shape and only fail later, inside forward().
    """
    new_weight = torch.as_tensor(weight, dtype=layer.weight.dtype)
    if tuple(new_weight.shape) != tuple(layer.weight.shape):
        raise ValueError('expected weight of shape %s, got %s'
                         % (tuple(layer.weight.shape), tuple(new_weight.shape)))
    with torch.no_grad():
        layer.weight.copy_(new_weight)


class _ShallowReLUNet(nn.Module):
    """One-hidden-layer classifier: fc1 -> ReLU -> output (raw logits).

    Subclasses create `self.fc1` and `self.output` in __init__ and differ only
    in how the first-layer weights are initialised. The fc1 bias always keeps
    PyTorch's default nn.Linear initialisation.
    """

    def forward(self, inputs):
        return self.output(torch.relu(self.fc1(inputs)))


class V1_net(_ShallowReLUNet):
    """Shallow net whose first-layer weights come from `V1_inspired_weights`.

    Args:
        hidden_size: number of hidden units.
        scale: forwarded to V1_inspired_weights as `scale`.
        input_dim: number of input features; defaults to the notebook-global `d`.
        n_classes: number of output logits (10 for MNIST).
        t, l: forwarded to V1_inspired_weights (their meaning is defined in
            `estimator`). NOTE(review): the experiment cell labels its result
            files with t=5, l=3, while the default used here is l=2. Confirm
            which value is intended.
    """

    def __init__(self, hidden_size, scale, input_dim=None, n_classes=10, t=5, l=2):
        super().__init__()
        in_features = d if input_dim is None else input_dim
        # Layer creation order (fc1, re-init, output) is kept so RNG consumption is unchanged.
        self.fc1 = nn.Linear(in_features, hidden_size)
        _load_weight(self.fc1, V1_inspired_weights(hidden_size, in_features, t=t, l=l, scale=scale))
        self.output = nn.Linear(hidden_size, n_classes)


class He_net(_ShallowReLUNet):
    """Shallow net whose first-layer weights use Kaiming (He) normal init.

    `scale` is accepted only so all three nets share one constructor
    signature; it is ignored here.
    """

    def __init__(self, hidden_size, scale, input_dim=None, n_classes=10):
        super().__init__()
        in_features = d if input_dim is None else input_dim
        self.fc1 = nn.Linear(in_features, hidden_size)
        nn.init.kaiming_normal_(self.fc1.weight)
        self.output = nn.Linear(hidden_size, n_classes)


class RF_net(_ShallowReLUNet):
    """Shallow net whose first-layer weights come from `classical_weights`.

    Args:
        hidden_size: number of hidden units.
        scale: forwarded to classical_weights as `scale`.
        input_dim: number of input features; defaults to the notebook-global `d`.
        n_classes: number of output logits (10 for MNIST).
    """

    def __init__(self, hidden_size, scale, input_dim=None, n_classes=10):
        super().__init__()
        in_features = d if input_dim is None else input_dim
        self.fc1 = nn.Linear(in_features, hidden_size)
        _load_weight(self.fc1, classical_weights(hidden_size, in_features, scale=scale))
        self.output = nn.Linear(hidden_size, n_classes)

In [11]:
def predict(model, X):
    """Return the predicted class index for each row of X.

    `model(X)` is expected to return one score per class along dim 1. The
    prediction is the index of the largest score (the first one on ties).
    The forward pass runs without building an autograd graph, because the
    result is only used for evaluation.
    """
    with torch.no_grad():
        return model(X).argmax(dim=1)


def error(model, X, y):
    """Return the misclassification rate of `model` on (X, y) as a 0-dim tensor.

    error = 1 - (#rows where predict(model, X) == y) / len(y)
    """
    y_pred = predict(model, X)
    accuracy = 1.0 * torch.sum(y_pred == y) / len(y)
    return 1 - accuracy

In [None]:
# Sweep hidden width h and learning rate lr. For each setting, train every
# initialisation scheme n_trials times with full-batch gradient descent on the
# few-shot training set, then pickle the per-epoch loss / train / test curves.
scale = 1/32          # first-layer weight scale for V1/RF (He ignores it)
models = {'V1': V1_net, 'RF': RF_net, 'He': He_net}
n_epochs = 7000
n_trials = 5
log_every = 500       # print progress every `log_every` epochs
seed = 0              # trial i uses seed + i for every model (paired comparison)
# NOTE(review): t, l only label the output files below. V1_net builds its
# weights with its own t/l (t=5, l=2 in the class definition), so confirm the
# labels match the weights actually used. The plotting cell reads files for
# scale in {1/8, 1/16, 1/32}; presumably this cell was re-run with each value.
t, l = 5, 3
loss_func = nn.CrossEntropyLoss()
results_dir = 'results/initialize_mnist'
os.makedirs(results_dir, exist_ok=True)

for h in [50, 100, 400, 1000]:
    for lr in [1e-3, 1e-2, 1e-1, 1e0]:
        train_err = {m: np.zeros((n_trials, n_epochs)) for m in models}
        test_err = {m: np.zeros((n_trials, n_epochs)) for m in models}
        loss_list = {m: np.zeros((n_trials, n_epochs)) for m in models}
        for m, network in models.items():
            for i in range(n_trials):
                # Seed both RNGs in case the estimator's weight constructors draw from
                # NumPy's global RNG (assumption: not visible in this notebook).
                torch.manual_seed(seed + i)
                np.random.seed(seed + i)
                model = network(h, scale)
                optim = torch.optim.SGD(model.parameters(), lr=lr)
                for j in range(n_epochs):
                    optim.zero_grad()
                    loss = loss_func(model(X_train), y_train)
                    loss.backward()
                    optim.step()

                    # `loss` is the pre-step training loss; the errors are measured after the step.
                    # Evaluation only: no autograd graph is needed for these forward passes.
                    with torch.no_grad():
                        train_err[m][i, j] = error(model, X_train, y_train)
                        test_err[m][i, j] = error(model, X_test, y_test)
                    loss_list[m][i, j] = loss.item()

                    if j % log_every == 0:
                        print('Trial %d, Epoch: %d, %s model, h=%d, lr=%0.5f, Loss=%0.5f, test err=%0.3f'
                              % (i, j, m, h, lr, loss.item(), test_err[m][i, j]))

        results = {'test_err': test_err, 'train_err': train_err, 'loss': loss_list}
        fname = 'clf_t=%0.2f_l=%0.2f_h=%d_lr=%0.4f_scale=%0.6f_fewshot.pickle' % (t, l, h, lr, scale)
        with open(os.path.join(results_dir, fname), 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

Trial 0, Epoch: 0, V1 model, h=50, lr=0.00100, Loss=4.93333, test err=0.884
Trial 0, Epoch: 400, V1 model, h=50, lr=0.00100, Loss=0.96775, test err=0.485
Trial 0, Epoch: 800, V1 model, h=50, lr=0.00100, Loss=0.58889, test err=0.430
Trial 0, Epoch: 1200, V1 model, h=50, lr=0.00100, Loss=0.41699, test err=0.411
Trial 0, Epoch: 1600, V1 model, h=50, lr=0.00100, Loss=0.31728, test err=0.400
Trial 0, Epoch: 2000, V1 model, h=50, lr=0.00100, Loss=0.25222, test err=0.394
Trial 0, Epoch: 2400, V1 model, h=50, lr=0.00100, Loss=0.20669, test err=0.390
Trial 0, Epoch: 2800, V1 model, h=50, lr=0.00100, Loss=0.17333, test err=0.388
Trial 0, Epoch: 3200, V1 model, h=50, lr=0.00100, Loss=0.14810, test err=0.387
Trial 0, Epoch: 3600, V1 model, h=50, lr=0.00100, Loss=0.12845, test err=0.386
Trial 0, Epoch: 4000, V1 model, h=50, lr=0.00100, Loss=0.11284, test err=0.385
Trial 0, Epoch: 4400, V1 model, h=50, lr=0.00100, Loss=0.10020, test err=0.385
Trial 0, Epoch: 4800, V1 model, h=50, lr=0.00100, Loss=0.

Trial 0, Epoch: 6000, RF model, h=50, lr=0.00100, Loss=0.13190, test err=0.376
Trial 0, Epoch: 6400, RF model, h=50, lr=0.00100, Loss=0.11759, test err=0.374
Trial 0, Epoch: 6800, RF model, h=50, lr=0.00100, Loss=0.10565, test err=0.372
Trial 1, Epoch: 0, RF model, h=50, lr=0.00100, Loss=2.47185, test err=0.887
Trial 1, Epoch: 400, RF model, h=50, lr=0.00100, Loss=1.85637, test err=0.704
Trial 1, Epoch: 800, RF model, h=50, lr=0.00100, Loss=1.51861, test err=0.570
Trial 1, Epoch: 1200, RF model, h=50, lr=0.00100, Loss=1.24800, test err=0.513
Trial 1, Epoch: 1600, RF model, h=50, lr=0.00100, Loss=1.02602, test err=0.481
Trial 1, Epoch: 2000, RF model, h=50, lr=0.00100, Loss=0.84233, test err=0.460
Trial 1, Epoch: 2400, RF model, h=50, lr=0.00100, Loss=0.69166, test err=0.444
Trial 1, Epoch: 2800, RF model, h=50, lr=0.00100, Loss=0.56833, test err=0.429
Trial 1, Epoch: 3200, RF model, h=50, lr=0.00100, Loss=0.46826, test err=0.421
Trial 1, Epoch: 3600, RF model, h=50, lr=0.00100, Loss=0.

Trial 1, Epoch: 4800, He model, h=50, lr=0.00100, Loss=0.39113, test err=0.333
Trial 1, Epoch: 5200, He model, h=50, lr=0.00100, Loss=0.33179, test err=0.330
Trial 1, Epoch: 5600, He model, h=50, lr=0.00100, Loss=0.28399, test err=0.327
Trial 1, Epoch: 6000, He model, h=50, lr=0.00100, Loss=0.24529, test err=0.326
Trial 1, Epoch: 6400, He model, h=50, lr=0.00100, Loss=0.21373, test err=0.325
Trial 1, Epoch: 6800, He model, h=50, lr=0.00100, Loss=0.18779, test err=0.324
Trial 2, Epoch: 0, He model, h=50, lr=0.00100, Loss=2.37081, test err=0.925
Trial 2, Epoch: 400, He model, h=50, lr=0.00100, Loss=2.20757, test err=0.829
Trial 2, Epoch: 800, He model, h=50, lr=0.00100, Loss=2.07456, test err=0.724
Trial 2, Epoch: 1200, He model, h=50, lr=0.00100, Loss=1.92562, test err=0.600
Trial 2, Epoch: 1600, He model, h=50, lr=0.00100, Loss=1.75352, test err=0.508
Trial 2, Epoch: 2000, He model, h=50, lr=0.00100, Loss=1.56100, test err=0.454
Trial 2, Epoch: 2400, He model, h=50, lr=0.00100, Loss=1.

Trial 2, Epoch: 3600, V1 model, h=50, lr=0.01000, Loss=0.00790, test err=0.399
Trial 2, Epoch: 4000, V1 model, h=50, lr=0.01000, Loss=0.00689, test err=0.398
Trial 2, Epoch: 4400, V1 model, h=50, lr=0.01000, Loss=0.00610, test err=0.398
Trial 2, Epoch: 4800, V1 model, h=50, lr=0.01000, Loss=0.00547, test err=0.398
Trial 2, Epoch: 5200, V1 model, h=50, lr=0.01000, Loss=0.00494, test err=0.398
Trial 2, Epoch: 5600, V1 model, h=50, lr=0.01000, Loss=0.00451, test err=0.397
Trial 2, Epoch: 6000, V1 model, h=50, lr=0.01000, Loss=0.00414, test err=0.396
Trial 2, Epoch: 6400, V1 model, h=50, lr=0.01000, Loss=0.00382, test err=0.396
Trial 2, Epoch: 6800, V1 model, h=50, lr=0.01000, Loss=0.00355, test err=0.396
Trial 3, Epoch: 0, V1 model, h=50, lr=0.01000, Loss=8.40417, test err=0.887
Trial 3, Epoch: 400, V1 model, h=50, lr=0.01000, Loss=0.12084, test err=0.383
Trial 3, Epoch: 800, V1 model, h=50, lr=0.01000, Loss=0.04737, test err=0.374
Trial 3, Epoch: 1200, V1 model, h=50, lr=0.01000, Loss=0.

Trial 3, Epoch: 2800, RF model, h=50, lr=0.01000, Loss=0.01237, test err=0.373
Trial 3, Epoch: 3200, RF model, h=50, lr=0.01000, Loss=0.01033, test err=0.373
Trial 3, Epoch: 3600, RF model, h=50, lr=0.01000, Loss=0.00883, test err=0.372
Trial 3, Epoch: 4000, RF model, h=50, lr=0.01000, Loss=0.00769, test err=0.371
Trial 3, Epoch: 4400, RF model, h=50, lr=0.01000, Loss=0.00679, test err=0.370
Trial 3, Epoch: 4800, RF model, h=50, lr=0.01000, Loss=0.00607, test err=0.370
Trial 3, Epoch: 5200, RF model, h=50, lr=0.01000, Loss=0.00548, test err=0.370
Trial 3, Epoch: 5600, RF model, h=50, lr=0.01000, Loss=0.00499, test err=0.369
Trial 3, Epoch: 6000, RF model, h=50, lr=0.01000, Loss=0.00458, test err=0.369
Trial 3, Epoch: 6400, RF model, h=50, lr=0.01000, Loss=0.00422, test err=0.368
Trial 3, Epoch: 6800, RF model, h=50, lr=0.01000, Loss=0.00392, test err=0.367
Trial 4, Epoch: 0, RF model, h=50, lr=0.01000, Loss=2.67528, test err=0.899
Trial 4, Epoch: 400, RF model, h=50, lr=0.01000, Loss=0

Trial 4, Epoch: 1600, He model, h=50, lr=0.01000, Loss=0.03679, test err=0.300
Trial 4, Epoch: 2000, He model, h=50, lr=0.01000, Loss=0.02511, test err=0.300
Trial 4, Epoch: 2400, He model, h=50, lr=0.01000, Loss=0.01870, test err=0.301
Trial 4, Epoch: 2800, He model, h=50, lr=0.01000, Loss=0.01473, test err=0.301
Trial 4, Epoch: 3200, He model, h=50, lr=0.01000, Loss=0.01206, test err=0.302
Trial 4, Epoch: 3600, He model, h=50, lr=0.01000, Loss=0.01016, test err=0.302
Trial 4, Epoch: 4000, He model, h=50, lr=0.01000, Loss=0.00874, test err=0.302
Trial 4, Epoch: 4400, He model, h=50, lr=0.01000, Loss=0.00765, test err=0.303
Trial 4, Epoch: 4800, He model, h=50, lr=0.01000, Loss=0.00678, test err=0.303
Trial 4, Epoch: 5200, He model, h=50, lr=0.01000, Loss=0.00608, test err=0.304
Trial 4, Epoch: 5600, He model, h=50, lr=0.01000, Loss=0.00551, test err=0.304
Trial 4, Epoch: 6000, He model, h=50, lr=0.01000, Loss=0.00502, test err=0.304
Trial 4, Epoch: 6400, He model, h=50, lr=0.01000, Lo

Trial 0, Epoch: 400, RF model, h=50, lr=0.10000, Loss=0.00796, test err=0.357
Trial 0, Epoch: 800, RF model, h=50, lr=0.10000, Loss=0.00329, test err=0.354
Trial 0, Epoch: 1200, RF model, h=50, lr=0.10000, Loss=0.00201, test err=0.352
Trial 0, Epoch: 1600, RF model, h=50, lr=0.10000, Loss=0.00142, test err=0.350
Trial 0, Epoch: 2000, RF model, h=50, lr=0.10000, Loss=0.00109, test err=0.350
Trial 0, Epoch: 2400, RF model, h=50, lr=0.10000, Loss=0.00088, test err=0.349
Trial 0, Epoch: 2800, RF model, h=50, lr=0.10000, Loss=0.00073, test err=0.349
Trial 0, Epoch: 3200, RF model, h=50, lr=0.10000, Loss=0.00063, test err=0.348
Trial 0, Epoch: 3600, RF model, h=50, lr=0.10000, Loss=0.00055, test err=0.347
Trial 0, Epoch: 4000, RF model, h=50, lr=0.10000, Loss=0.00049, test err=0.348
Trial 0, Epoch: 4400, RF model, h=50, lr=0.10000, Loss=0.00044, test err=0.347
Trial 0, Epoch: 4800, RF model, h=50, lr=0.10000, Loss=0.00039, test err=0.346
Trial 0, Epoch: 5200, RF model, h=50, lr=0.10000, Loss

Trial 0, Epoch: 6400, He model, h=50, lr=0.10000, Loss=0.00029, test err=0.312
Trial 0, Epoch: 6800, He model, h=50, lr=0.10000, Loss=0.00027, test err=0.312
Trial 1, Epoch: 0, He model, h=50, lr=0.10000, Loss=2.29066, test err=0.902
Trial 1, Epoch: 400, He model, h=50, lr=0.10000, Loss=0.00858, test err=0.309
Trial 1, Epoch: 800, He model, h=50, lr=0.10000, Loss=0.00342, test err=0.309
Trial 1, Epoch: 1200, He model, h=50, lr=0.10000, Loss=0.00206, test err=0.309
Trial 1, Epoch: 1600, He model, h=50, lr=0.10000, Loss=0.00145, test err=0.309
Trial 1, Epoch: 2000, He model, h=50, lr=0.10000, Loss=0.00111, test err=0.310
Trial 1, Epoch: 2400, He model, h=50, lr=0.10000, Loss=0.00090, test err=0.309
Trial 1, Epoch: 2800, He model, h=50, lr=0.10000, Loss=0.00075, test err=0.309
Trial 1, Epoch: 3200, He model, h=50, lr=0.10000, Loss=0.00064, test err=0.309
Trial 1, Epoch: 3600, He model, h=50, lr=0.10000, Loss=0.00056, test err=0.309
Trial 1, Epoch: 4000, He model, h=50, lr=0.10000, Loss=0.

Trial 1, Epoch: 5200, V1 model, h=50, lr=1.00000, Loss=0.00003, test err=0.359
Trial 1, Epoch: 5600, V1 model, h=50, lr=1.00000, Loss=0.00003, test err=0.359
Trial 1, Epoch: 6000, V1 model, h=50, lr=1.00000, Loss=0.00003, test err=0.358
Trial 1, Epoch: 6400, V1 model, h=50, lr=1.00000, Loss=0.00002, test err=0.358
Trial 1, Epoch: 6800, V1 model, h=50, lr=1.00000, Loss=0.00002, test err=0.357
Trial 2, Epoch: 0, V1 model, h=50, lr=1.00000, Loss=3.13943, test err=0.826
Trial 2, Epoch: 400, V1 model, h=50, lr=1.00000, Loss=0.00055, test err=0.385
Trial 2, Epoch: 800, V1 model, h=50, lr=1.00000, Loss=0.00025, test err=0.382
Trial 2, Epoch: 1200, V1 model, h=50, lr=1.00000, Loss=0.00016, test err=0.379
Trial 2, Epoch: 1600, V1 model, h=50, lr=1.00000, Loss=0.00011, test err=0.378
Trial 2, Epoch: 2000, V1 model, h=50, lr=1.00000, Loss=0.00009, test err=0.376
Trial 2, Epoch: 2400, V1 model, h=50, lr=1.00000, Loss=0.00007, test err=0.375
Trial 2, Epoch: 2800, V1 model, h=50, lr=1.00000, Loss=0.

Trial 2, Epoch: 4000, RF model, h=50, lr=1.00000, Loss=0.00004, test err=0.353
Trial 2, Epoch: 4400, RF model, h=50, lr=1.00000, Loss=0.00003, test err=0.352
Trial 2, Epoch: 4800, RF model, h=50, lr=1.00000, Loss=0.00003, test err=0.352
Trial 2, Epoch: 5200, RF model, h=50, lr=1.00000, Loss=0.00003, test err=0.352
Trial 2, Epoch: 5600, RF model, h=50, lr=1.00000, Loss=0.00003, test err=0.352
Trial 2, Epoch: 6000, RF model, h=50, lr=1.00000, Loss=0.00002, test err=0.352
Trial 2, Epoch: 6400, RF model, h=50, lr=1.00000, Loss=0.00002, test err=0.351
Trial 2, Epoch: 6800, RF model, h=50, lr=1.00000, Loss=0.00002, test err=0.351
Trial 3, Epoch: 0, RF model, h=50, lr=1.00000, Loss=2.55112, test err=0.717
Trial 3, Epoch: 400, RF model, h=50, lr=1.00000, Loss=0.00043, test err=0.398
Trial 3, Epoch: 800, RF model, h=50, lr=1.00000, Loss=0.00020, test err=0.393
Trial 3, Epoch: 1200, RF model, h=50, lr=1.00000, Loss=0.00013, test err=0.392
Trial 3, Epoch: 1600, RF model, h=50, lr=1.00000, Loss=0.

Trial 3, Epoch: 2800, He model, h=50, lr=1.00000, Loss=0.00005, test err=0.325
Trial 3, Epoch: 3200, He model, h=50, lr=1.00000, Loss=0.00004, test err=0.324
Trial 3, Epoch: 3600, He model, h=50, lr=1.00000, Loss=0.00004, test err=0.325
Trial 3, Epoch: 4000, He model, h=50, lr=1.00000, Loss=0.00003, test err=0.325
Trial 3, Epoch: 4400, He model, h=50, lr=1.00000, Loss=0.00003, test err=0.325
Trial 3, Epoch: 4800, He model, h=50, lr=1.00000, Loss=0.00003, test err=0.324
Trial 3, Epoch: 5200, He model, h=50, lr=1.00000, Loss=0.00002, test err=0.325
Trial 3, Epoch: 5600, He model, h=50, lr=1.00000, Loss=0.00002, test err=0.325
Trial 3, Epoch: 6000, He model, h=50, lr=1.00000, Loss=0.00002, test err=0.325
Trial 3, Epoch: 6400, He model, h=50, lr=1.00000, Loss=0.00002, test err=0.325
Trial 3, Epoch: 6800, He model, h=50, lr=1.00000, Loss=0.00002, test err=0.325
Trial 4, Epoch: 0, He model, h=50, lr=1.00000, Loss=2.31845, test err=0.672
Trial 4, Epoch: 400, He model, h=50, lr=1.00000, Loss=0

Trial 4, Epoch: 1200, V1 model, h=100, lr=0.00100, Loss=0.18839, test err=0.343
Trial 4, Epoch: 1600, V1 model, h=100, lr=0.00100, Loss=0.14128, test err=0.340
Trial 4, Epoch: 2000, V1 model, h=100, lr=0.00100, Loss=0.11235, test err=0.336
Trial 4, Epoch: 2400, V1 model, h=100, lr=0.00100, Loss=0.09286, test err=0.335
Trial 4, Epoch: 2800, V1 model, h=100, lr=0.00100, Loss=0.07887, test err=0.334
Trial 4, Epoch: 3200, V1 model, h=100, lr=0.00100, Loss=0.06837, test err=0.333
Trial 4, Epoch: 3600, V1 model, h=100, lr=0.00100, Loss=0.06022, test err=0.331
Trial 4, Epoch: 4000, V1 model, h=100, lr=0.00100, Loss=0.05372, test err=0.330
Trial 4, Epoch: 4400, V1 model, h=100, lr=0.00100, Loss=0.04842, test err=0.330
Trial 4, Epoch: 4800, V1 model, h=100, lr=0.00100, Loss=0.04402, test err=0.330
Trial 4, Epoch: 5200, V1 model, h=100, lr=0.00100, Loss=0.04031, test err=0.329
Trial 4, Epoch: 5600, V1 model, h=100, lr=0.00100, Loss=0.03715, test err=0.329
Trial 4, Epoch: 6000, V1 model, h=100, l

Trial 4, Epoch: 6400, RF model, h=100, lr=0.00100, Loss=0.08301, test err=0.385
Trial 4, Epoch: 6800, RF model, h=100, lr=0.00100, Loss=0.07574, test err=0.383
Trial 0, Epoch: 0, He model, h=100, lr=0.00100, Loss=2.29144, test err=0.902
Trial 0, Epoch: 400, He model, h=100, lr=0.00100, Loss=2.09888, test err=0.763
Trial 0, Epoch: 800, He model, h=100, lr=0.00100, Loss=1.91540, test err=0.602
Trial 0, Epoch: 1200, He model, h=100, lr=0.00100, Loss=1.71473, test err=0.493
Trial 0, Epoch: 1600, He model, h=100, lr=0.00100, Loss=1.49344, test err=0.425
Trial 0, Epoch: 2000, He model, h=100, lr=0.00100, Loss=1.26591, test err=0.386
Trial 0, Epoch: 2400, He model, h=100, lr=0.00100, Loss=1.04882, test err=0.364
Trial 0, Epoch: 2800, He model, h=100, lr=0.00100, Loss=0.85596, test err=0.349
Trial 0, Epoch: 3200, He model, h=100, lr=0.00100, Loss=0.69656, test err=0.338
Trial 0, Epoch: 3600, He model, h=100, lr=0.00100, Loss=0.56962, test err=0.334
Trial 0, Epoch: 4000, He model, h=100, lr=0.0

Trial 0, Epoch: 4400, V1 model, h=100, lr=0.01000, Loss=0.00399, test err=0.344
Trial 0, Epoch: 4800, V1 model, h=100, lr=0.01000, Loss=0.00360, test err=0.343
Trial 0, Epoch: 5200, V1 model, h=100, lr=0.01000, Loss=0.00329, test err=0.343
Trial 0, Epoch: 5600, V1 model, h=100, lr=0.01000, Loss=0.00302, test err=0.343
Trial 0, Epoch: 6000, V1 model, h=100, lr=0.01000, Loss=0.00279, test err=0.343
Trial 0, Epoch: 6400, V1 model, h=100, lr=0.01000, Loss=0.00259, test err=0.343
Trial 0, Epoch: 6800, V1 model, h=100, lr=0.01000, Loss=0.00241, test err=0.343
Trial 1, Epoch: 0, V1 model, h=100, lr=0.01000, Loss=4.22939, test err=0.907
Trial 1, Epoch: 400, V1 model, h=100, lr=0.01000, Loss=0.05314, test err=0.339
Trial 1, Epoch: 800, V1 model, h=100, lr=0.01000, Loss=0.02431, test err=0.336
Trial 1, Epoch: 1200, V1 model, h=100, lr=0.01000, Loss=0.01535, test err=0.334
Trial 1, Epoch: 1600, V1 model, h=100, lr=0.01000, Loss=0.01109, test err=0.333
Trial 1, Epoch: 2000, V1 model, h=100, lr=0.0

Trial 1, Epoch: 2400, RF model, h=100, lr=0.01000, Loss=0.01243, test err=0.358
Trial 1, Epoch: 2800, RF model, h=100, lr=0.01000, Loss=0.01019, test err=0.357
Trial 1, Epoch: 3200, RF model, h=100, lr=0.01000, Loss=0.00860, test err=0.356
Trial 1, Epoch: 3600, RF model, h=100, lr=0.01000, Loss=0.00742, test err=0.355
Trial 1, Epoch: 4000, RF model, h=100, lr=0.01000, Loss=0.00651, test err=0.355
Trial 1, Epoch: 4400, RF model, h=100, lr=0.01000, Loss=0.00579, test err=0.354
Trial 1, Epoch: 4800, RF model, h=100, lr=0.01000, Loss=0.00520, test err=0.354
Trial 1, Epoch: 5200, RF model, h=100, lr=0.01000, Loss=0.00472, test err=0.354
Trial 1, Epoch: 5600, RF model, h=100, lr=0.01000, Loss=0.00431, test err=0.354
Trial 1, Epoch: 6000, RF model, h=100, lr=0.01000, Loss=0.00397, test err=0.353
Trial 1, Epoch: 6400, RF model, h=100, lr=0.01000, Loss=0.00367, test err=0.353
Trial 1, Epoch: 6800, RF model, h=100, lr=0.01000, Loss=0.00342, test err=0.353
Trial 2, Epoch: 0, RF model, h=100, lr=0

Trial 2, Epoch: 400, He model, h=100, lr=0.01000, Loss=0.51279, test err=0.321
Trial 2, Epoch: 800, He model, h=100, lr=0.01000, Loss=0.12299, test err=0.310
Trial 2, Epoch: 1200, He model, h=100, lr=0.01000, Loss=0.05574, test err=0.313
Trial 2, Epoch: 1600, He model, h=100, lr=0.01000, Loss=0.03360, test err=0.313
Trial 2, Epoch: 2000, He model, h=100, lr=0.01000, Loss=0.02335, test err=0.312
Trial 2, Epoch: 2400, He model, h=100, lr=0.01000, Loss=0.01760, test err=0.312
Trial 2, Epoch: 2800, He model, h=100, lr=0.01000, Loss=0.01399, test err=0.312
Trial 2, Epoch: 3200, He model, h=100, lr=0.01000, Loss=0.01153, test err=0.311
Trial 2, Epoch: 3600, He model, h=100, lr=0.01000, Loss=0.00976, test err=0.312
Trial 2, Epoch: 4000, He model, h=100, lr=0.01000, Loss=0.00843, test err=0.312
Trial 2, Epoch: 4400, He model, h=100, lr=0.01000, Loss=0.00740, test err=0.312
Trial 2, Epoch: 4800, He model, h=100, lr=0.01000, Loss=0.00658, test err=0.312
Trial 2, Epoch: 5200, He model, h=100, lr=

Trial 2, Epoch: 5600, V1 model, h=100, lr=0.10000, Loss=0.00021, test err=0.323
Trial 2, Epoch: 6000, V1 model, h=100, lr=0.10000, Loss=0.00019, test err=0.323
Trial 2, Epoch: 6400, V1 model, h=100, lr=0.10000, Loss=0.00018, test err=0.323
Trial 2, Epoch: 6800, V1 model, h=100, lr=0.10000, Loss=0.00017, test err=0.323
Trial 3, Epoch: 0, V1 model, h=100, lr=0.10000, Loss=4.49016, test err=0.803
Trial 3, Epoch: 400, V1 model, h=100, lr=0.10000, Loss=0.00291, test err=0.366
Trial 3, Epoch: 800, V1 model, h=100, lr=0.10000, Loss=0.00147, test err=0.367
Trial 3, Epoch: 1200, V1 model, h=100, lr=0.10000, Loss=0.00098, test err=0.366
Trial 3, Epoch: 1600, V1 model, h=100, lr=0.10000, Loss=0.00073, test err=0.366
Trial 3, Epoch: 2000, V1 model, h=100, lr=0.10000, Loss=0.00058, test err=0.366
Trial 3, Epoch: 2400, V1 model, h=100, lr=0.10000, Loss=0.00048, test err=0.365
Trial 3, Epoch: 2800, V1 model, h=100, lr=0.10000, Loss=0.00041, test err=0.365
Trial 3, Epoch: 3200, V1 model, h=100, lr=0.1

Trial 3, Epoch: 3600, RF model, h=100, lr=0.10000, Loss=0.00050, test err=0.351
Trial 3, Epoch: 4000, RF model, h=100, lr=0.10000, Loss=0.00045, test err=0.351
Trial 3, Epoch: 4400, RF model, h=100, lr=0.10000, Loss=0.00040, test err=0.351
Trial 3, Epoch: 4800, RF model, h=100, lr=0.10000, Loss=0.00036, test err=0.350
Trial 3, Epoch: 5200, RF model, h=100, lr=0.10000, Loss=0.00033, test err=0.350
Trial 3, Epoch: 5600, RF model, h=100, lr=0.10000, Loss=0.00031, test err=0.350
Trial 3, Epoch: 6000, RF model, h=100, lr=0.10000, Loss=0.00028, test err=0.350
Trial 3, Epoch: 6400, RF model, h=100, lr=0.10000, Loss=0.00026, test err=0.350
Trial 3, Epoch: 6800, RF model, h=100, lr=0.10000, Loss=0.00025, test err=0.350
Trial 4, Epoch: 0, RF model, h=100, lr=0.10000, Loss=2.41857, test err=0.827
Trial 4, Epoch: 400, RF model, h=100, lr=0.10000, Loss=0.00676, test err=0.338
Trial 4, Epoch: 800, RF model, h=100, lr=0.10000, Loss=0.00290, test err=0.335
Trial 4, Epoch: 1200, RF model, h=100, lr=0.1

Trial 4, Epoch: 1600, He model, h=100, lr=0.10000, Loss=0.00144, test err=0.308
Trial 4, Epoch: 2000, He model, h=100, lr=0.10000, Loss=0.00110, test err=0.308
Trial 4, Epoch: 2400, He model, h=100, lr=0.10000, Loss=0.00089, test err=0.308
Trial 4, Epoch: 2800, He model, h=100, lr=0.10000, Loss=0.00074, test err=0.309
Trial 4, Epoch: 3200, He model, h=100, lr=0.10000, Loss=0.00063, test err=0.309
Trial 4, Epoch: 3600, He model, h=100, lr=0.10000, Loss=0.00055, test err=0.309
Trial 4, Epoch: 4000, He model, h=100, lr=0.10000, Loss=0.00049, test err=0.309
Trial 4, Epoch: 4400, He model, h=100, lr=0.10000, Loss=0.00044, test err=0.309
Trial 4, Epoch: 4800, He model, h=100, lr=0.10000, Loss=0.00040, test err=0.310
Trial 4, Epoch: 5200, He model, h=100, lr=0.10000, Loss=0.00036, test err=0.310
Trial 4, Epoch: 5600, He model, h=100, lr=0.10000, Loss=0.00033, test err=0.310
Trial 4, Epoch: 6000, He model, h=100, lr=0.10000, Loss=0.00031, test err=0.310
Trial 4, Epoch: 6400, He model, h=100, l

Trial 4, Epoch: 6800, V1 model, h=100, lr=1.00000, Loss=0.00002, test err=0.416
Trial 0, Epoch: 0, RF model, h=100, lr=1.00000, Loss=2.55422, test err=0.734
Trial 0, Epoch: 400, RF model, h=100, lr=1.00000, Loss=0.00045, test err=0.364
Trial 0, Epoch: 800, RF model, h=100, lr=1.00000, Loss=0.00020, test err=0.361
Trial 0, Epoch: 1200, RF model, h=100, lr=1.00000, Loss=0.00013, test err=0.359
Trial 0, Epoch: 1600, RF model, h=100, lr=1.00000, Loss=0.00009, test err=0.358
Trial 0, Epoch: 2000, RF model, h=100, lr=1.00000, Loss=0.00007, test err=0.357
Trial 0, Epoch: 2400, RF model, h=100, lr=1.00000, Loss=0.00006, test err=0.357
Trial 0, Epoch: 2800, RF model, h=100, lr=1.00000, Loss=0.00005, test err=0.356
Trial 0, Epoch: 3200, RF model, h=100, lr=1.00000, Loss=0.00004, test err=0.356
Trial 0, Epoch: 3600, RF model, h=100, lr=1.00000, Loss=0.00004, test err=0.356
Trial 0, Epoch: 4000, RF model, h=100, lr=1.00000, Loss=0.00003, test err=0.356
Trial 0, Epoch: 4400, RF model, h=100, lr=1.0

Trial 0, Epoch: 4800, He model, h=100, lr=1.00000, Loss=0.00003, test err=0.322
Trial 0, Epoch: 5200, He model, h=100, lr=1.00000, Loss=0.00002, test err=0.322
Trial 0, Epoch: 5600, He model, h=100, lr=1.00000, Loss=0.00002, test err=0.322
Trial 0, Epoch: 6000, He model, h=100, lr=1.00000, Loss=0.00002, test err=0.321
Trial 0, Epoch: 6400, He model, h=100, lr=1.00000, Loss=0.00002, test err=0.321
Trial 0, Epoch: 6800, He model, h=100, lr=1.00000, Loss=0.00002, test err=0.321
Trial 1, Epoch: 0, He model, h=100, lr=1.00000, Loss=2.32377, test err=0.608
Trial 1, Epoch: 400, He model, h=100, lr=1.00000, Loss=0.00042, test err=0.316
Trial 1, Epoch: 800, He model, h=100, lr=1.00000, Loss=0.00019, test err=0.316
Trial 1, Epoch: 1200, He model, h=100, lr=1.00000, Loss=0.00012, test err=0.317
Trial 1, Epoch: 1600, He model, h=100, lr=1.00000, Loss=0.00009, test err=0.317
Trial 1, Epoch: 2000, He model, h=100, lr=1.00000, Loss=0.00007, test err=0.318
Trial 1, Epoch: 2400, He model, h=100, lr=1.0

Trial 1, Epoch: 2800, V1 model, h=400, lr=0.00100, Loss=0.01861, test err=0.302
Trial 1, Epoch: 3200, V1 model, h=400, lr=0.00100, Loss=0.01622, test err=0.302
Trial 1, Epoch: 3600, V1 model, h=400, lr=0.00100, Loss=0.01437, test err=0.302
Trial 1, Epoch: 4000, V1 model, h=400, lr=0.00100, Loss=0.01290, test err=0.301
Trial 1, Epoch: 4400, V1 model, h=400, lr=0.00100, Loss=0.01170, test err=0.301
Trial 1, Epoch: 4800, V1 model, h=400, lr=0.00100, Loss=0.01070, test err=0.301
Trial 1, Epoch: 5200, V1 model, h=400, lr=0.00100, Loss=0.00985, test err=0.300
Trial 1, Epoch: 5600, V1 model, h=400, lr=0.00100, Loss=0.00913, test err=0.301
Trial 1, Epoch: 6000, V1 model, h=400, lr=0.00100, Loss=0.00851, test err=0.301
Trial 1, Epoch: 6400, V1 model, h=400, lr=0.00100, Loss=0.00796, test err=0.301
Trial 1, Epoch: 6800, V1 model, h=400, lr=0.00100, Loss=0.00748, test err=0.301
Trial 2, Epoch: 0, V1 model, h=400, lr=0.00100, Loss=4.05967, test err=0.927
Trial 2, Epoch: 400, V1 model, h=400, lr=0.

Trial 2, Epoch: 800, RF model, h=400, lr=0.00100, Loss=0.38408, test err=0.374
Trial 2, Epoch: 1200, RF model, h=400, lr=0.00100, Loss=0.23529, test err=0.366
Trial 2, Epoch: 1600, RF model, h=400, lr=0.00100, Loss=0.16357, test err=0.360
Trial 2, Epoch: 2000, RF model, h=400, lr=0.00100, Loss=0.12315, test err=0.359
Trial 2, Epoch: 2400, RF model, h=400, lr=0.00100, Loss=0.09777, test err=0.358
Trial 2, Epoch: 2800, RF model, h=400, lr=0.00100, Loss=0.08056, test err=0.357
Trial 2, Epoch: 3200, RF model, h=400, lr=0.00100, Loss=0.06824, test err=0.355
Trial 2, Epoch: 3600, RF model, h=400, lr=0.00100, Loss=0.05901, test err=0.354
Trial 2, Epoch: 4000, RF model, h=400, lr=0.00100, Loss=0.05188, test err=0.354
Trial 2, Epoch: 4400, RF model, h=400, lr=0.00100, Loss=0.04622, test err=0.354
Trial 2, Epoch: 4800, RF model, h=400, lr=0.00100, Loss=0.04162, test err=0.353
Trial 2, Epoch: 5200, RF model, h=400, lr=0.00100, Loss=0.03781, test err=0.352
Trial 2, Epoch: 5600, RF model, h=400, lr

Trial 2, Epoch: 6000, He model, h=400, lr=0.00100, Loss=0.13588, test err=0.294
Trial 2, Epoch: 6400, He model, h=400, lr=0.00100, Loss=0.12109, test err=0.294
Trial 2, Epoch: 6800, He model, h=400, lr=0.00100, Loss=0.10874, test err=0.295
Trial 3, Epoch: 0, He model, h=400, lr=0.00100, Loss=2.27660, test err=0.878
Trial 3, Epoch: 400, He model, h=400, lr=0.00100, Loss=1.91469, test err=0.543
Trial 3, Epoch: 800, He model, h=400, lr=0.00100, Loss=1.58201, test err=0.402
Trial 3, Epoch: 1200, He model, h=400, lr=0.00100, Loss=1.27703, test err=0.370
Trial 3, Epoch: 1600, He model, h=400, lr=0.00100, Loss=1.01422, test err=0.353
Trial 3, Epoch: 2000, He model, h=400, lr=0.00100, Loss=0.80145, test err=0.345
Trial 3, Epoch: 2400, He model, h=400, lr=0.00100, Loss=0.63677, test err=0.337
Trial 3, Epoch: 2800, He model, h=400, lr=0.00100, Loss=0.51108, test err=0.332
Trial 3, Epoch: 3200, He model, h=400, lr=0.00100, Loss=0.41516, test err=0.326
Trial 3, Epoch: 3600, He model, h=400, lr=0.0

In [18]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

# For every (scale, h, lr) setting, load the pickled training curves written by
# the experiment cell. Plot loss / train error / test error vs epoch as the mean
# over trials with a +-1 std band.
t, l, n_epochs = 5, 3, 7000
models = ['V1', 'RF', 'He']
model_labels = {'V1': 'V1-inspired', 'RF': 'classical', 'He': 'He'}
results_dir = 'results/initialize_mnist'
figures_dir = os.path.join(results_dir, 'figures_fewshot')
os.makedirs(figures_dir, exist_ok=True)
# (result key, panel title, y-axis label) for the three subplots, left to right.
panels = [('loss', 'Network loss', 'Training loss'),
          ('train_err', 'Train error', 'Training error'),
          ('test_err', 'Test error', 'Test error')]


def _plot_panel(ax, mean, std, title, ylabel):
    """Draw one mean curve with a +-1 std band per model on `ax`, log-scaled y axis.

    `mean` and `std` map model name -> 1-D array of length n_epochs.
    """
    epochs = np.arange(n_epochs)
    for m in models:
        line, = ax.plot(epochs, mean[m], label=model_labels[m], lw=3)
        # Tie the band colour to its curve explicitly rather than relying on the cycler.
        ax.fill_between(epochs, mean[m] - std[m], mean[m] + std[m],
                        color=line.get_color(), alpha=0.2)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Epoch', fontsize=20)
    ax.set_ylabel(ylabel, fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=14, width=2, length=6)
    # No fixed linear ticks starting at 0: 0 cannot be placed on a log axis, and
    # set_yscale installs its own tick locator.
    ax.set_yscale('log')
    ax.legend(fontsize=18)


for scale in [1/8, 1/16, 1/32]:
    for h in [50, 100, 400, 1000]:
        for lr in [1e-3, 1e-2, 1e-1, 1e0]:
            tag = 't=%0.2f_l=%0.2f_h=%d_lr=%0.4f_scale=%0.6f' % (t, l, h, lr, scale)
            # pickle.load can execute arbitrary code: only load files this notebook wrote.
            with open(os.path.join(results_dir, 'clf_%s_fewshot.pickle' % tag), 'rb') as handle:
                sims = pickle.load(handle)

            # Mean and std over trials (axis 0), each computed from its OWN metric,
            # so the test-error and loss bands no longer reuse the train-error std.
            mean = {key: {m: np.mean(sims[key][m], axis=0) for m in models} for key, _, _ in panels}
            std = {key: {m: np.std(sims[key][m], axis=0) for m in models} for key, _, _ in panels}

            fig, axes = plt.subplots(1, 3, figsize=(12, 5))
            fig.suptitle(r'Shallow FFW FC net w/ GD. h=%d, scale=%0.5f, lr=%0.4f, ' % (h, scale, lr),
                         fontsize=16)
            for ax, (key, title, ylabel) in zip(axes, panels):
                _plot_panel(ax, mean[key], std[key], title, ylabel)
            fig.tight_layout()
            fig.subplots_adjust(top=0.8)

            print(scale, h, lr, 'Test err, V1: %0.4f, RF: %0.4f, He: %0.4f' % (mean['test_err']['V1'][-1],
                                                                               mean['test_err']['RF'][-1],
                                                                               mean['test_err']['He'][-1]))

            fig.savefig(os.path.join(figures_dir, 'init_%s.png' % tag))
            plt.close(fig)

            


0.125 50 0.001 Test err, V1: 0.3976, RF: 0.4645, He: 0.3243
0.125 50 0.01 Test err, V1: 0.4074, RF: 0.4192, He: 0.3207
0.125 50 0.1 Test err, V1: 0.4009, RF: 0.4288, He: 0.3171
0.125 50 1.0 Test err, V1: 0.4090, RF: 0.4169, He: 0.3208
0.125 100 0.001 Test err, V1: 0.3541, RF: 0.4149, He: 0.3188
0.125 100 0.01 Test err, V1: 0.3470, RF: 0.4131, He: 0.3204
0.125 100 0.1 Test err, V1: 0.3553, RF: 0.3925, He: 0.3126
0.125 100 1.0 Test err, V1: 0.4366, RF: 0.4205, He: 0.3153
0.125 400 0.001 Test err, V1: 0.3004, RF: 0.3591, He: 0.3166
0.125 400 0.01 Test err, V1: 0.2928, RF: 0.3441, He: 0.3126
0.125 400 0.1 Test err, V1: 0.2975, RF: 0.3324, He: 0.3112
0.125 400 1.0 Test err, V1: 0.4303, RF: 0.4051, He: 0.3092
0.125 1000 0.001 Test err, V1: 0.2933, RF: 0.3388, He: 0.3080
0.125 1000 0.01 Test err, V1: 0.2722, RF: 0.3322, He: 0.3082
0.125 1000 0.1 Test err, V1: 0.3279, RF: 0.3685, He: 0.3124
0.125 1000 1.0 Test err, V1: 0.3901, RF: 0.4476, He: 0.3082
0.0625 50 0.001 Test err, V1: 0.4095, RF: 0.