In [1]:
from os import getcwd, makedirs
from os.path import abspath, join
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

from src.data.load_dataset import load_kmnist
from src.models.networks import V1_mnist_RFNet, classical_RFNet
from src.models.utils import train, test

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')

In [3]:
# Seed the global RNGs up front so the data subset and any otherwise-unseeded
# randomness are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# load data
# NOTE(review): `train_percentage` is presumably the fraction of the KMNIST
# training set kept for this few-shot experiment — confirm in load_kmnist.
train_batch_size, train_percentage = 64, 0.01
train_loader, val_loader, test_loader = load_kmnist(train_batch_size, train_percentage)

# training params
num_epochs = 10
step_size, gamma = 5, 0.1  # StepLR: multiply the lr by `gamma` every `step_size` scheduler steps (one per epoch)
num_trials = 50            # independent re-trainings per hidden size; mean/std are taken over these
log_interval = 100
# Log-spaced hidden-layer sizes from 1 to 10**3.5 (~3162). Truncating to int
# produces duplicates at the small end, which the set removes.
num_neurons = sorted(set(np.logspace(0, 3.5, 50).astype('int')))
loss_fn = F.cross_entropy

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### V1 RFNet with optimized parameters $s=5.34$, $f=1.965$

In [None]:
# Optimized V1 RFNet parameters and learning rate.
s, f, c = 5.34, 1.965, None
lr = 0.0031485838088746586

# For each hidden size, train `num_trials` freshly initialised models and
# record the mean / std of their test accuracy.
test_v1 = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons, desc='V1 RFNet'):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so every run is reproducible regardless of which
        # cells executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = V1_mnist_RFNet(hidden_size, s, f, c).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, loss_fn, verbose=False)
            scheduler.step()  # advance the lr schedule once per epoch
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_v1['hidden_size'].append(hidden_size)
    test_v1['mean'].append(np.mean(accuracy))
    test_v1['std'].append(np.std(accuracy))

  2%|▏         | 1/44 [02:55<2:05:40, 175.37s/it]

### Classical RFNet

In [None]:
inp_size = (1, 28, 28)  # input shape given to classical_RFNet, presumably (channels, height, width)
lr = 0.01922083004518646

# For each hidden size, train `num_trials` freshly initialised models and
# record the mean / std of their test accuracy.
test_classical = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons, desc='Classical RFNet'):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so every run is reproducible regardless of which
        # cells executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = classical_RFNet(inp_size, hidden_size).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, loss_fn, verbose=False)
            scheduler.step()  # advance the lr schedule once per epoch
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_classical['hidden_size'].append(hidden_size)
    test_classical['mean'].append(np.mean(accuracy))
    test_classical['std'].append(np.std(accuracy))

### V1 RFNet with incompatible parameters $s=0.5$, $f=0.5$

In [None]:
# V1 RFNet with incompatible receptive-field parameters. These are bound to
# their own names so the optimized `s`, `f`, `c` from the V1 cell are not
# overwritten — the save cell builds its filename from `s` and `f`.
s_incompat, f_incompat, c_incompat = 0.5, 0.5, None
lr = 0.0031485838088746586

# For each hidden size, train `num_trials` freshly initialised models and
# record the mean / std of their test accuracy.
test_incompatible = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons, desc='V1 RFNet (incompatible)'):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so every run is reproducible regardless of which
        # cells executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = V1_mnist_RFNet(hidden_size, s_incompat, f_incompat, c_incompat).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

        for epoch in range(num_epochs):
            _ = train(log_interval, device, model, train_loader, optimizer, epoch, loss_fn, verbose=False)
            scheduler.step()  # advance the lr schedule once per epoch
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_incompatible['hidden_size'].append(hidden_size)
    test_incompatible['mean'].append(np.mean(accuracy))
    test_incompatible['std'].append(np.std(accuracy))

In [6]:
# save
# The training cells above populate `test_v1`, `test_classical` and
# `test_incompatible`. Do not re-initialise them here, or their results would
# be discarded before pickling.
# Named `results` (not `test`) so the imported `test()` evaluation helper is
# not shadowed if a training cell is re-run afterwards.
results = {'v1': test_v1, 'classical': test_classical, 'incompatible': test_incompatible}
data_dir = abspath(join(getcwd(), '../../'))
results_dir = join(data_dir, 'models', 'results', 'kmnist_clf')
makedirs(results_dir, exist_ok=True)  # avoid losing hours of training to a missing folder
# NOTE(review): the filename uses whatever `s` and `f` were bound most recently
# by a training cell. It is presumably meant to name the optimized V1
# parameters — confirm no later cell rebinds them.
filename = 'kmnist_clf_s=%0.2f_f=%0.2f_fewshot_torch.pickle' % (s, f)
with open(join(results_dir, filename), 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Mean test accuracy vs. hidden-layer size for the three models. Shaded bands
# show ±1 std over the `num_trials` runs of each configuration.
fig, ax = plt.subplots()
curves = [
    ('V1 RFNet (optimized)', test_v1),
    ('Classical RFNet', test_classical),
    ('V1 RFNet ($s=0.5$, $f=0.5$)', test_incompatible),
]
for label, curve in curves:
    sizes = np.asarray(curve['hidden_size'])
    mean = np.asarray(curve['mean'])
    std = np.asarray(curve['std'])
    (line,) = ax.plot(sizes, mean, label=label)
    # Reuse the line colour so each band visibly belongs to its curve.
    ax.fill_between(sizes, mean - std, mean + std, color=line.get_color(), alpha=0.2)
ax.set_xlim([0, 1000])
ax.set_xlabel('Hidden layer size')
ax.set_ylabel('Mean test accuracy')
ax.set_title('KMNIST few-shot classification')
ax.legend()
plt.show()

In [5]:
# Pair each hidden-layer size with its mean test accuracy for a quick look.
[(size, mean_acc) for size, mean_acc in zip(test_v1['hidden_size'], test_v1['mean'])]

[(1, 10.733999999999998),
 (2, 10.3376),
 (3, 11.683000000000002),
 (4, 11.913400000000001),
 (5, 12.441200000000002),
 (6, 12.8752),
 (7, 13.2948),
 (8, 13.9174),
 (10, 15.0318),
 (11, 16.6122),
 (13, 17.998),
 (16, 19.477600000000002),
 (19, 21.809200000000004),
 (22, 23.952399999999997),
 (26, 26.952199999999998),
 (31, 29.332400000000003),
 (37, 32.281),
 (43, 34.942800000000005),
 (51, 37.3968),
 (61, 40.098800000000004),
 (71, 42.9318),
 (84, 44.24759999999999),
 (100, 47.1958),
 (117, 49.57019999999999),
 (138, 51.151399999999995),
 (163, 52.84960000000001),
 (193, 54.490199999999994),
 (227, 56.096599999999995),
 (268, 57.35560000000001),
 (316, 58.5168),
 (372, 59.386999999999986),
 (439, 60.1482),
 (517, 61.157399999999996),
 (610, 61.63539999999999),
 (719, 62.345600000000005),
 (848, 62.6916)]