In [None]:
from os import getcwd, makedirs
from os.path import abspath, join

import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from src.data.load_dataset import load_mnist
from src.models.networks import V1_mnist_RFNet, classical_RFNet
from src.models.utils import train, test

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the global numpy and torch RNGs once, so that Restart & Run All is reproducible.
# Presumably this also fixes the train/val split and the shuffling inside load_mnist
# (TODO confirm that load_mnist draws from these global RNGs).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Build the MNIST train / validation / test loaders.
# train_percentage is presumably the fraction of the training split kept for training,
# with the remainder feeding val_loader (TODO confirm in load_mnist).
train_batch_size = 1024
train_percentage = 0.99
train_loader, val_loader, test_loader = load_mnist(
    train_batch_size,
    train_percentage,
)

### V1 RFNet with $s=5.0$, $f=2.0$

In [None]:
# Sweep the hidden-layer width of the V1-inspired RFNet and record its test accuracy.
# The widths are roughly log-spaced from 1 to 2511 (10**3.4). `set` drops the duplicates
# that integer truncation produces at the small end.
num_neurons = sorted(set(np.logspace(0, 3.4, 50).astype('int')))
s, f, c = 5.0, 2.0, None  # passed positionally to V1_mnist_RFNet; their meaning is defined there
lr = 0.01
num_epochs = 10
log_interval = 100
num_trials = 1  # independent re-initialisations per width; with a single trial every 'std' is 0.0

test_v1 = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so each (width, trial) run is reproducible no matter
        # which cells were executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = V1_mnist_RFNet(hidden_size, s, f, c).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_v1['hidden_size'].append(hidden_size)
    test_v1['mean'].append(np.mean(accuracy))
    test_v1['std'].append(np.std(accuracy))

### Classical RFNet

In [None]:
# Baseline: sweep the hidden-layer width of the classical RFNet over the same width grid.
# The widths are roughly log-spaced from 1 to 2511 (10**3.4), with duplicates from integer
# truncation removed.
num_neurons = sorted(set(np.logspace(0, 3.4, 50).astype('int')))
inp_size = (1, 28, 28)  # input shape handed to classical_RFNet (one 28x28 channel)
lr = 0.01
num_epochs = 10
log_interval = 100
num_trials = 1  # independent re-initialisations per width; with a single trial every 'std' is 0.0

test_classical = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so each (width, trial) run is reproducible no matter
        # which cells were executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = classical_RFNet(inp_size, hidden_size).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_classical['hidden_size'].append(hidden_size)
    test_classical['mean'].append(np.mean(accuracy))
    test_classical['std'].append(np.std(accuracy))

### V1 RFNet with incompatible parameters $s=0.5$, $f=0.5$

In [None]:
# Control: the same V1 RFNet width sweep, but with the incompatible parameters s=0.5, f=0.5.
# The widths are roughly log-spaced from 1 to 2511 (10**3.4), with duplicates from integer
# truncation removed.
num_neurons = sorted(set(np.logspace(0, 3.4, 50).astype('int')))
s, f, c = 0.5, 0.5, None  # passed positionally to V1_mnist_RFNet; their meaning is defined there
lr = 0.01
num_epochs = 10
log_interval = 100
num_trials = 1  # independent re-initialisations per width; with a single trial every 'std' is 0.0

test_incompatible = {'hidden_size': [], 'mean': [], 'std': []}
for hidden_size in tqdm(num_neurons):
    accuracy = []
    for trial in range(num_trials):
        # Seed per trial so each (width, trial) run is reproducible no matter
        # which cells were executed before this one.
        torch.manual_seed(trial)
        np.random.seed(trial)
        model = V1_mnist_RFNet(hidden_size, s, f, c).to(device)
        optimizer = optim.SGD(model.parameters(), lr=lr)

        for epoch in range(num_epochs):
            train(log_interval, device, model, train_loader, optimizer, epoch, verbose=False)
        accuracy.append(test(model, device, test_loader, verbose=False))

    test_incompatible['hidden_size'].append(hidden_size)
    test_incompatible['mean'].append(np.mean(accuracy))
    test_incompatible['std'].append(np.std(accuracy))

In [None]:
# save
# Bundle the three width sweeps and pickle them under models/results/mnist_clf/,
# two directories above the current working directory.
# The dict is named `results`, not `test`: `test` is the evaluation function imported from
# src.models.utils, and shadowing it breaks any experiment cell re-run after this one.
results = {'v1': test_v1, 'classical': test_classical, 'incompatible': test_incompatible}
data_dir = abspath(join(getcwd(), '../../'))
results_dir = join(data_dir, 'models', 'results', 'mnist_clf')
makedirs(results_dir, exist_ok=True)  # don't lose a long sweep to a missing folder
# NOTE(review): `s` and `f` hold whatever the most recently executed cell assigned. On
# Restart & Run All that is the incompatible cell (0.5, 0.5), not the V1 values (5.0, 2.0).
# The file name is kept unchanged for existing readers; confirm which pair is intended.
results_path = join(results_dir, 'mnist_clf_s=%0.2f_f=%0.2f.pickle' % (s, f))
with open(results_path, 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Compare the three width sweeps: mean test accuracy vs. number of hidden neurons.
fig, ax = plt.subplots()
ax.plot(test_v1['hidden_size'], test_v1['mean'], label='V1 RFNet')
ax.plot(test_classical['hidden_size'], test_classical['mean'], label='classical RFNet')
ax.plot(test_incompatible['hidden_size'], test_incompatible['mean'],
        label='V1 RFNet (incompatible $s$, $f$)')
# The sweeps go beyond 1000 neurons; the view is clipped to [0, 1000], as before.
ax.set_xlim([0, 1000])
ax.set_xlabel('hidden_size (number of hidden neurons)')
ax.set_ylabel('mean test accuracy')
ax.set_title('MNIST classification vs. hidden-layer width')
ax.legend();

In [None]:
# (hidden_size, mean test accuracy) pairs for the V1 RFNet sweep.
[(width, mean_acc) for width, mean_acc in zip(test_v1['hidden_size'], test_v1['mean'])]