In [1]:
import torchvision
from torch.utils.data import Subset, DataLoader
import torch
import torchvision
import torch.nn.functional as F

import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import math
import tqdm
from sklearn.metrics import confusion_matrix
from tqdm import tqdm as tqdm
from models import *
%matplotlib inline
#from torch.utils.tensorboard import SummaryWriter
import random

In [2]:
# Convert CIFAR-10 images to tensors (ToTensor scales to [0, 1]) and then shift each
# RGB channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 128

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Keep only as many samples as fill whole batches, so every batch has exactly
# `batch_size` samples (downstream code may assume a fixed batch size). The sizes
# are derived from each dataset instead of hard-coded 50000 / 10000, and the test
# indices come from the *test* set rather than from the training index array.
indices = np.arange(len(trainset))
train_indices = indices[:(len(trainset) // batch_size) * batch_size]
test_indices = np.arange((len(testset) // batch_size) * batch_size)

train_dataset = Subset(trainset, train_indices)
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

test_dataset = Subset(testset, test_indices)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def train_step(hidden_size, input_tensor, target):
    """Unroll ``frnn`` over the last axis of ``input_tensor``, then take one optimizer step.

    Relies on the notebook globals ``frnn``, ``criterion``, ``optimizer`` and ``device``.
    The loss is computed only on the output of the final time step.

    Args:
        hidden_size: width of the two hidden-state tensors passed to ``frnn``.
        input_tensor: rank-3 batch of sequences; the main loop passes
            ``data.reshape(batch, 3, -1)``. One time step ``[:, :, i]`` is fed per iteration.
        target: labels consumed by ``criterion``.

    Returns:
        ``(loss, out)``: the loss as a Python float and ``frnn``'s output at the last step.

    Raises:
        ValueError: if the sequence (last) axis is empty. The previous version
            silently returned ``None`` in that case, which broke the caller's unpacking.
    """
    seq_len = input_tensor.shape[-1]
    if seq_len == 0:
        raise ValueError("input_tensor has an empty sequence dimension")

    # Hidden states start at zero with a leading dim of 1; presumably broadcast
    # over the batch inside frnn — TODO confirm against the model definition.
    hidden_bottom_0 = torch.zeros(1, hidden_size).to(device)
    hidden_top_0 = torch.zeros(1, hidden_size).to(device)
    frnn.zero_grad()

    batch = input_tensor.shape[0]
    for i in range(seq_len):
        step_input = input_tensor[:, :, i:i + 1].reshape(batch, -1)
        out, hidden_bottom_0, hidden_top_0 = frnn(
            step_input, hidden_bottom_0.to(device), hidden_top_0.to(device))

    # Sequence classification: only the final output is scored.
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    return loss.item(), out
        
        
def val_step(hidden_size, input_tensor, target):
    """Unroll ``frnn`` over the last axis of ``input_tensor`` and score the final output.

    This is the evaluation-only counterpart of ``train_step``: it computes no
    gradients and updates no parameters. Relies on the notebook globals ``frnn``,
    ``criterion`` and ``device``.

    Args:
        hidden_size: width of the two hidden-state tensors passed to ``frnn``.
        input_tensor: rank-3 batch of sequences; the main loop passes
            ``data.reshape(batch, 3, -1)``.
        target: labels consumed by ``criterion``.

    Returns:
        ``(loss, out)``: the loss as a Python float and ``frnn``'s output at the last step.

    Raises:
        ValueError: if the sequence (last) axis is empty (previously returned ``None``).
    """
    seq_len = input_tensor.shape[-1]
    if seq_len == 0:
        raise ValueError("input_tensor has an empty sequence dimension")

    batch = input_tensor.shape[0]
    # Without no_grad the autograd graph for every unrolled time step is built and
    # kept alive until `out` is released — pure memory/time overhead during eval.
    with torch.no_grad():
        hidden_bottom_0 = torch.zeros(1, hidden_size).to(device)
        hidden_top_0 = torch.zeros(1, hidden_size).to(device)
        for i in range(seq_len):
            step_input = input_tensor[:, :, i:i + 1].reshape(batch, -1)
            out, hidden_bottom_0, hidden_top_0 = frnn(
                step_input, hidden_bottom_0.to(device), hidden_top_0.to(device))
        loss = criterion(out, target)
    return loss.item(), out

        
def get_accuracy(logit, target, batch_size=None):
    """Return the percentage of rows of ``logit`` whose argmax (over dim 1) equals ``target``.

    Args:
        logit: score tensor whose class dimension is dim 1.
        target: integer labels; the predicted indices are viewed to ``target.size()``.
        batch_size: denominator of the percentage. Defaults to the number of labels
            in ``target``. Passing it explicitly keeps the old behaviour of dividing
            by a nominal batch size, which is wrong for a partial final batch.

    Returns:
        Accuracy in percent as a Python float.
    """
    if batch_size is None:
        batch_size = target.numel()
    predicted = logit.argmax(dim=1).view(target.size())
    corrects = (predicted == target).sum()
    return (100.0 * corrects / batch_size).item()

In [4]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
epochs = 150
hidden_size = 128
gamma = 0.001    # forwarded to the model constructor as `gamma`
epsilon = 0.01   # forwarded to the model constructor as `epsilon`

# Seed every RNG in use so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All. torch.manual_seed also seeds CUDA devices.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
## MAIN TRAINING LOOP ##
# Trains `frnn` for `epochs` epochs and evaluates on `testloader` after each one.
# The mean validation loss drives both the LR scheduler and best-checkpoint saving.

frnn = FRNN_AS_SC(3, hidden_size, 10, device, epsilon=epsilon, gamma=gamma)
# Alternative models for quick experiments:
#frnn = TLRNN_AS_SC(3,hidden_size,10, device, epsilon=epsilon, gamma=gamma)
#frnn = FRNN_SC(3,hidden_size,10, device, epsilon=epsilon, gamma=gamma)

frnn.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adagrad(frnn.parameters(), lr=1e-2)
# Multiplies the LR by 0.1 after 10 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, patience=10, verbose=True, factor=0.1)

best_model_path = "model.pt"

# Per-epoch history
losses = []
accuracys = []
val_losses = []
val_accuracys = []

for epoch in range(epochs):
    # ---------------- training ----------------
    frnn.train()
    epoch_accuracys = []
    epoch_losses = []
    for data, target in trainloader:
        # (B, 3, H, W) -> (B, 3, H*W): each pixel position becomes one time step.
        data = data.reshape(data.shape[0], 3, -1).to(device)
        target = target.to(device)

        loss, pred = train_step(hidden_size, data, target)

        epoch_losses.append(loss)
        epoch_accuracys.append(get_accuracy(pred, target, target.shape[0]))

    train_loss = np.mean(epoch_losses)
    train_acc = np.mean(epoch_accuracys)
    losses.append(train_loss)
    accuracys.append(train_acc)
    print('Epoch:  %d | Loss: %.4f | Train Accuracy: %.2f'
          % (epoch, train_loss, train_acc))

    # ---------------- evaluation ----------------
    frnn.eval()
    epoch_accuracys = []
    epoch_losses = []
    # Ground truth / predicted labels of the last evaluated epoch, as plain ints on
    # the CPU (per-element CUDA tensors are slow to build and unusable by sklearn).
    gts = []
    predictions = []
    with torch.no_grad():
        for data, target in testloader:
            data = data.reshape(data.shape[0], 3, -1).to(device)
            target = target.to(device)

            loss, pred = val_step(hidden_size, data, target)

            epoch_losses.append(loss)
            epoch_accuracys.append(get_accuracy(pred, target, target.shape[0]))
            gts += target.cpu().tolist()
            predictions += pred.argmax(dim=-1).cpu().tolist()

    val_loss = np.mean(epoch_losses)
    val_acc = np.mean(epoch_accuracys)

    # Save a checkpoint whenever the validation loss reaches a new minimum. The old
    # `epoch > 1` guard never saved anything if the best epoch was 0 or 1.
    if not val_losses or val_loss < min(val_losses):
        torch.save(frnn.state_dict(), best_model_path)

    val_losses.append(val_loss)
    val_accuracys.append(val_acc)

    # ReduceLROnPlateau has to be stepped with the monitored metric every epoch;
    # previously it was constructed but never stepped, so the LR never decayed.
    scheduler.step(val_loss)

    print('Epoch:  %d | Val-Loss: %.4f | Val Accuracy: %.2f'
          % (epoch, val_loss, val_acc))
    

Epoch:  0 | Loss: 1.8694 | Train Accuracy: 32.56
Epoch:  0 | Val-Loss: 1.7602 | Val Accuracy: 37.26
