In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

import networks
import train

import numpy as np
import matplotlib.pyplot as plt
import time

In [None]:
# Preprocessing pipeline: ToTensor turns each PIL image into a float tensor in
# [0, 1]; Normalize then applies (x - 0.5) / 0.5 to the single channel,
# so pixel values end up in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
ROOT = './data'
BATCH = 64
SEED = 0  # fixed so that Restart & Run All reproduces the same shuffles/initialisations

# Seed torch's global RNG before any random operation in the notebook
# (DataLoader shuffling, model initialisation). torch.manual_seed also seeds CUDA.
torch.manual_seed(SEED)

train_data = datasets.MNIST(root=ROOT, download=True, train=True, transform=transform)
test_data = datasets.MNIST(root=ROOT, download=True, train=False, transform=transform)

# Only the training set needs shuffling. An accuracy aggregated over the whole
# test set does not depend on batch order (presumably what train.trial_evaluation
# computes — confirm), so a fixed test order keeps evaluation deterministic.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH, shuffle=False)

# Keyed loaders consumed by train.trial_evaluation below.
loaders = {'train': train_loader,
           'test': test_loader}

In [None]:
# Peek at one training batch to sanity-check tensor shapes and view a sample.
data_iter = iter(train_loader)
# DataLoader iterators no longer provide the legacy Python-2 `.next()` alias
# in recent PyTorch releases; the built-in next() works on every version.
images, labels = next(data_iter)

print(images.shape)
print(labels.shape)

# Each image carries a size-1 channel axis (MNIST is grayscale): squeeze it and
# render in gray rather than matplotlib's default false-colour colormap.
plt.imshow(images[0].squeeze(), cmap='gray')
plt.title("Label: {}".format(labels[0].item()))
plt.show()

In [None]:
# Baseline model: build LeNet from the local `networks` module and print its
# structure (presumably an nn.Module, whose repr lists the sub-layers).
lenet = networks.LeNet()
print(lenet)

In [None]:
# Iterative model built for inspection only: non-linear variant with 2 iterations.
# NOTE(review): the n_iter=2 training run further down passes linear=True, so the
# architecture printed here (and the parameter count computed from it) may not
# match the model that is actually trained — confirm which variant is intended.
iternet = networks.IterNet(linear=False, n_iter=2)
print(iternet)

In [None]:
# Number of trainable scalars in LeNet (parameters with requires_grad=True).
# Tensor.numel() always returns an int. np.prod(p.size()) returns float 1.0 for a
# 0-dim parameter, which would silently turn the whole total into a float.
params = sum(p.numel() for p in lenet.parameters() if p.requires_grad)

print(params)

In [None]:
# Number of trainable scalars in the inspected IterNet (requires_grad=True only).
# Tensor.numel() always returns an int. np.prod(p.size()) returns float 1.0 for a
# 0-dim parameter, which would silently turn the whole total into a float.
params = sum(p.numel() for p in iternet.parameters() if p.requires_grad)

print(params)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Experiment budget passed to train.trial_evaluation: the number of trials per
# configuration and the number of epochs per trial (kept small for a quick pass).
TRIALS, EPOCHS = 2, 2

In [None]:
# Run the LeNet baseline for TRIALS trials and time the whole run.
# perf_counter is monotonic and high-resolution, so it is unaffected by system
# clock adjustments that can skew differences of time.time().
time_0 = time.perf_counter()
# The second return value is kept; presumably one test accuracy per trial.
_, LN_Accuracies, _ = train.trial_evaluation(TRIALS,
                                             EPOCHS,
                                             loaders,
                                             'LeNet',
                                             device,
                                             verbose=True)
time_elapsed = time.perf_counter() - time_0
print("Total: {:.2f}\tAverage: {:.2f}".format(time_elapsed, time_elapsed/TRIALS))

In [None]:
# Run IterNet with n_iter=2 for TRIALS trials and time the whole run.
# perf_counter is monotonic and high-resolution, so it is unaffected by system
# clock adjustments that can skew differences of time.time().
time_0 = time.perf_counter()
# NOTE(review): this is the only IterNet run with linear=True; the n_iter=1 and
# n_iter=3 runs (and the model inspected earlier) use linear=False. Confirm that
# this is intentional, otherwise IN1/IN2/IN3 are not directly comparable.
_, IN2_Accuracies, _ = train.trial_evaluation(TRIALS,
                                              EPOCHS,
                                              loaders,
                                              'IterNet',
                                              device,
                                              linear=True,
                                              n_iter=2,
                                              verbose=True)
time_elapsed = time.perf_counter() - time_0
print("Total: {:.2f}\tAverage: {:.2f}".format(time_elapsed, time_elapsed/TRIALS))

In [None]:
# Run non-linear IterNet with n_iter=3 for TRIALS trials and time the whole run.
# perf_counter is monotonic and high-resolution, so it is unaffected by system
# clock adjustments that can skew differences of time.time().
time_0 = time.perf_counter()
_, IN3_Accuracies, _ = train.trial_evaluation(TRIALS,
                                              EPOCHS,
                                              loaders,
                                              'IterNet',
                                              device,
                                              linear=False,
                                              n_iter=3,
                                              verbose=False)
time_elapsed = time.perf_counter() - time_0
print("Total: {:.2f}\tAverage: {:.2f}".format(time_elapsed, time_elapsed/TRIALS))

In [None]:
# Run non-linear IterNet with n_iter=1 for TRIALS trials and time the whole run.
# perf_counter is monotonic and high-resolution, so it is unaffected by system
# clock adjustments that can skew differences of time.time().
time_0 = time.perf_counter()
_, IN1_Accuracies, _ = train.trial_evaluation(TRIALS,
                                              EPOCHS,
                                              loaders,
                                              'IterNet',
                                              device,
                                              linear=False,
                                              n_iter=1,
                                              verbose=False)
time_elapsed = time.perf_counter() - time_0
print("Total: {:.2f}\tAverage: {:.2f}".format(time_elapsed, time_elapsed/TRIALS))

In [None]:
# Summarise the n_iter=1 IterNet trial accuracies (presumably summary
# statistics; the output format is defined by train.summarize_trials).
train.summarize_trials(IN1_Accuracies)

In [None]:
# Summarise the n_iter=2 IterNet trial accuracies (presumably summary
# statistics; the output format is defined by train.summarize_trials).
train.summarize_trials(IN2_Accuracies)

In [None]:
# Spread of the n_iter=3 IterNet accuracies across trials. np.array() on an
# existing ndarray returns an equal copy, so re-running this cell is harmless.
# std() is the population standard deviation (numpy default ddof=0).
IN3_Accuracies = np.array(IN3_Accuracies)

print("Max: ", IN3_Accuracies.max())
print("Min: ", IN3_Accuracies.min())
print("Mean: ", IN3_Accuracies.mean())
print("SDev: ", IN3_Accuracies.std())

In [None]:
# Spread of the LeNet baseline accuracies across trials. np.array() on an
# existing ndarray returns an equal copy, so re-running this cell is harmless.
# std() is the population standard deviation (numpy default ddof=0).
LN_Accuracies = np.array(LN_Accuracies)

print("Max: ", LN_Accuracies.max())
print("Min: ", LN_Accuracies.min())
print("Mean: ", LN_Accuracies.mean())
print("SDev: ", LN_Accuracies.std())