In [11]:
import torchvision
from torch.utils.data import Subset, DataLoader
import torch
import torchvision
import torch.nn.functional as F

import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import math
import tqdm
from tqdm import tqdm as tqdm
from models import *
%matplotlib inline
import random

In [12]:
# Shared preprocessing for both splits: convert images to tensors, then
# standardise with mean 0.1307 / std 0.3081 (presumably the commonly quoted
# MNIST pixel statistics).
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits are read from ./tmp; download=False means the files must already exist there.
dataset = torchvision.datasets.MNIST(
    './tmp', train=True, download=False, transform=mnist_transform)
dataset_test = torchvision.datasets.MNIST(
    './tmp', train=False, download=False, transform=mnist_transform)

In [13]:
batch_size = 128

# How many samples of each split to use. These small values give a quick debug
# run; set them to len(dataset) and len(dataset_test) for a full run.
n_train_samples = 200
n_test_samples = 200


def full_batch_indices(n_available, n_requested, batch_size):
    """Return indices 0..k-1 of a dataset prefix made only of full batches.

    k is the largest multiple of ``batch_size`` not exceeding
    ``min(n_available, n_requested)``. The training loop flattens every batch
    with a reshape that relies on complete batches.
    """
    n_usable = min(n_available, n_requested)
    return np.arange((n_usable // batch_size) * batch_size)


# Take a plain prefix of each split (no shuffling or stratification here);
# each split is bounded by its own length.
train_indices = full_batch_indices(len(dataset), n_train_samples, batch_size)
test_indices = full_batch_indices(len(dataset_test), n_test_samples, batch_size)

# Wrap into Subsets and DataLoaders
train_dataset = Subset(dataset, train_indices)
test_dataset = Subset(dataset_test, test_indices)

# drop_last is a safety net: the indices above already form whole batches,
# so with the current sizes it drops nothing.
train_loader = DataLoader(train_dataset, shuffle=True, num_workers=2,
                          batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=False, num_workers=2,
                         batch_size=batch_size, drop_last=True)


In [14]:
# Functions for one batch update of the network

def train_step(hidden_size, input_tensor, target, model=None, loss_fn=None, optim=None):
    """Run one optimisation step of the recurrent model on a single batch.

    The input is fed to the model one column at a time (``input_tensor[:, i:i+1]``),
    carrying two hidden states forward. The model is called as
    ``model(x_step, hidden_bottom, hidden_top)`` and must return
    ``(out, hidden_bottom, hidden_top)``. The loss is computed only on the output
    of the final step, back-propagated, and applied with one optimiser step.

    Args:
        hidden_size: width of each zero-initialised hidden state, shape (1, hidden_size).
        input_tensor: 2-D tensor whose last dimension is iterated as time steps.
        target: labels passed to the loss function together with the final output.
        model, loss_fn, optim: default to the notebook globals ``frnn``,
            ``criterion`` and ``optimizer`` when omitted.

    Returns:
        (loss as a Python float, model output at the final time step)

    Raises:
        ValueError: if ``input_tensor`` has no time steps (previously this
            silently returned ``None``).
    """
    model = frnn if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn
    optim = optimizer if optim is None else optim

    n_steps = input_tensor.shape[-1]
    if n_steps == 0:
        raise ValueError("input_tensor has no time steps")

    # Fresh zero hidden states per batch, created directly on the input's device.
    hidden_bottom = torch.zeros(1, hidden_size, device=input_tensor.device)
    hidden_top = torch.zeros(1, hidden_size, device=input_tensor.device)

    model.train()
    model.zero_grad()
    for i in range(n_steps):
        out, hidden_bottom, hidden_top = model(input_tensor[:, i:i + 1], hidden_bottom, hidden_top)

    # Only the last step's output is supervised.
    loss = loss_fn(out, target)
    loss.backward()
    optim.step()
    return loss.item(), out
        
        
def val_step(hidden_size, input_tensor, target, model=None, loss_fn=None):
    """Evaluate the recurrent model on one batch without updating it.

    The sequence is unrolled exactly as in training (one column
    ``input_tensor[:, i:i+1]`` per step, two hidden states carried forward).
    The forward pass differs from training in two ways:
      * it runs under ``torch.no_grad()``, so no autograd graph is built
        across the unrolled steps;
      * it runs with the model in eval mode, and the model's previous
        train/eval mode is restored afterwards.

    Args:
        hidden_size: width of each zero-initialised hidden state, shape (1, hidden_size).
        input_tensor: 2-D tensor whose last dimension is iterated as time steps.
        target: labels passed to the loss function together with the final output.
        model, loss_fn: default to the notebook globals ``frnn`` and ``criterion``.

    Returns:
        (loss as a Python float, model output at the final time step)

    Raises:
        ValueError: if ``input_tensor`` has no time steps.
    """
    model = frnn if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn

    n_steps = input_tensor.shape[-1]
    if n_steps == 0:
        raise ValueError("input_tensor has no time steps")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            hidden_bottom = torch.zeros(1, hidden_size, device=input_tensor.device)
            hidden_top = torch.zeros(1, hidden_size, device=input_tensor.device)
            for i in range(n_steps):
                out, hidden_bottom, hidden_top = model(input_tensor[:, i:i + 1], hidden_bottom, hidden_top)
            loss = loss_fn(out, target)
    finally:
        # Restore the caller's mode so training afterwards is unaffected.
        model.train(was_training)
    return loss.item(), out
        
        
def get_accuracy(logit, target, batch_size=None):
    ''' Obtain accuracy (in percent) for one batch.

    Args:
        logit: 2-D scores, one row per sample; the predicted class is the
            arg-max over dim 1 (first index wins on ties).
        target: ground-truth class indices, one per sample.
        batch_size: divisor of the percentage. Defaults to the number of
            targets, which is also correct for a short final batch.

    Returns:
        Accuracy as a Python float in percent.

    Raises:
        ValueError: if the divisor is not positive.
    '''
    if batch_size is None:
        batch_size = target.numel()
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    predicted = logit.detach().argmax(dim=1).view(target.size())
    corrects = (predicted == target.detach()).sum().item()
    return 100.0 * corrects / batch_size

In [15]:
# Pick the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [16]:
# Hyperparameter Settings

epochs = 150
hidden_size = 128   # width of each hidden state (used by the model and train/val steps)
gamma = 0.15        # forwarded to the model constructor; semantics defined in models.py
epsilon = 0.1       # forwarded to the model constructor; semantics defined in models.py

# Seed every RNG the notebook imports, so weight initialisation and DataLoader
# shuffling are reproducible. This must run before the model is built.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [17]:
# Same device rule as the earlier device cell: first CUDA GPU if available, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
## MAIN TRAINING LOOP ##
# Train the selected model for `epochs` epochs. After each epoch:
#   * evaluate on `test_loader`;
#   * step the LR scheduler on the validation loss;
#   * checkpoint the weights whenever the validation loss is the best so far.

# Select Model (alternatives kept for comparison runs)
frnn = FRNN_AS_SC(1, hidden_size, 10, device, epsilon=epsilon, gamma=gamma)
#frnn = TLRNN_AS_SC(1, hidden_size, 10, device, epsilon=epsilon, gamma=gamma)
#frnn = FRNN_SC(1, hidden_size, 10, device, epsilon=epsilon, gamma=gamma)

frnn.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(frnn.parameters(), lr=1e-2)
# Divide the LR by 10 once the validation loss has not improved for 10 epochs.
# This only takes effect because scheduler.step(val_loss) is called below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=10, verbose=True, factor=0.1)

best_model_path = "model.pt"

# Per-epoch history
losses = []
accuracys = []
val_losses = []
val_accuracys = []

for epoch in range(epochs):
    # --- Training ---
    epoch_losses = []
    epoch_accuracys = []
    for data, target in train_loader:
        # Flatten each sample into one row. Use the actual batch length so a
        # short final batch cannot break the reshape.
        current_batch = data.size(0)
        data = data.reshape(current_batch, -1).to(device)
        target = target.to(device)

        loss, pred = train_step(hidden_size, data, target)
        epoch_losses.append(loss)
        epoch_accuracys.append(get_accuracy(pred, target, current_batch))

    train_loss = np.mean(epoch_losses)
    train_acc = np.mean(epoch_accuracys)
    losses.append(train_loss)
    accuracys.append(train_acc)
    print('Epoch:  %d | Loss: %.4f | Train Accuracy: %.2f'
          % (epoch, train_loss, train_acc))

    # --- Evaluation ---
    epoch_losses = []
    epoch_accuracys = []
    for data, target in test_loader:
        current_batch = data.size(0)
        data = data.reshape(current_batch, -1).to(device)
        target = target.to(device)

        loss, pred = val_step(hidden_size, data, target)
        epoch_losses.append(loss)
        epoch_accuracys.append(get_accuracy(pred, target, current_batch))

    val_loss = np.mean(epoch_losses)
    val_acc = np.mean(epoch_accuracys)

    # Save best model only. The first epoch is always saved, so the checkpoint
    # always comes from this run (the old `epoch > 1` guard skipped epochs 0-1
    # and could leave model.pt unwritten or stale).
    if not val_losses or val_loss < min(val_losses):
        torch.save(frnn.state_dict(), best_model_path)

    val_losses.append(val_loss)
    val_accuracys.append(val_acc)

    scheduler.step(val_loss)

    print('Epoch:  %d | Val-Loss: %.4f | Val Accuracy: %.2f'
          % (epoch, val_loss, val_acc))
    