In [67]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.data.load_dataset import load_mnist, load_kmnist
from src.models.networks import V1_mnist_RFNet, classical_RFNet
from src.models.utils import train, test

In [102]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### MNIST

In [95]:
train_loader, val_loader, test_loader = load_mnist(128, 0.1)

In [97]:
h, s, f, c = 100, 5, 2, None
model = V1_mnist_RFNet(h, s, f, c).to(device)

# hyperparams
lr = 1E-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# train
epochs = 5
log_interval = 5
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 72.421745. Accuracy: 6142/59940 (10.25%)


Test set: Average loss: 69.834061. Accuracy: 6213/59940 (10.37%)


Test set: Average loss: 67.277260. Accuracy: 6265/59940 (10.45%)


Test set: Average loss: 68.620743. Accuracy: 1014/10000 (10.14%)



In [101]:
## classical network
inp_size, hidden_size = (1, 28, 28), 100
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# optimizer
lr = 1E-1
optimizer = optim.Adam(model.parameters(), lr=lr)

# train
epochs = 5
log_interval = 5
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)



Test set: Average loss: 1.747105. Accuracy: 33799/59940 (56.39%)


Test set: Average loss: 2.020098. Accuracy: 33126/59940 (55.27%)


Test set: Average loss: 2.324098. Accuracy: 32765/59940 (54.66%)


Test set: Average loss: 2.192673. Accuracy: 33303/59940 (55.56%)


Test set: Average loss: 2.145921. Accuracy: 33166/59940 (55.33%)


Test set: Average loss: 2.434522. Accuracy: 31560/59940 (52.65%)


Test set: Average loss: 2.562590. Accuracy: 31382/59940 (52.36%)


Test set: Average loss: 2.410451. Accuracy: 32737/59940 (54.62%)


Test set: Average loss: 2.314633. Accuracy: 33617/59940 (56.08%)


Test set: Average loss: 2.271592. Accuracy: 33956/59940 (56.65%)


Test set: Average loss: 2.228732. Accuracy: 5772/10000 (57.72%)



### KMNIST

In [103]:
train_loader, val_loader, test_loader = load_kmnist(128, 0.9)

In [104]:
h, s, f, c = 100, 5, 2, None
model = V1_mnist_RFNet(h, s, f, c).to(device)

# hyperparams
lr = 1E-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# train
epochs = 5
log_interval = 5
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)


Test set: Average loss: 12.889578. Accuracy: 2444/6000 (40.73%)


Test set: Average loss: 7.022928. Accuracy: 3544/6000 (59.07%)


Test set: Average loss: 5.095811. Accuracy: 4004/6000 (66.73%)


Test set: Average loss: 4.010584. Accuracy: 4222/6000 (70.37%)


Test set: Average loss: 3.259418. Accuracy: 4355/6000 (72.58%)


Test set: Average loss: 5.553703. Accuracy: 5916/10000 (59.16%)



In [None]:
## classical network
inp_size, hidden_size = (1, 28, 28), 100
model = classical_RFNet(inp_size, hidden_size, seed=10).to(device)

# optimizer
lr = 1E-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# train
epochs = 5
log_interval = 5
for epoch in range(1, epochs + 1):
    train(log_interval, device, model, train_loader, optimizer, epoch, verbose=True)
    val_accuracy = test(model, device, val_loader)
# calculate and print test accuracy
test_accuracy = test(model, device, test_loader)

