# ECE 4194 Final Project. Spring 2019. The Ohio State University
## Authors: Michael Wharton.124, Alex Whitman.97, Benji Justice.251

This notebook trains a ResNet model to classify the activities recorded in the dataset linked below.

Dataset: http://archive.ics.uci.edu/ml/datasets/Smartphone-Based+Recognition+of+Human+Activities+and+Postural+Transitions



### Define packages to autoreload

In [5]:
# Load IPython's autoreload extension so that edited modules can be
# re-imported without restarting the kernel.
%load_ext autoreload
# Mode 1: before running each cell, reload only the modules that were
# registered with %aimport.
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
# Import models.resnet and mark it for automatic reloading, so edits to the
# model definition are picked up by later cells.
%aimport models.resnet

### Import necessary modules

In [15]:
# torch modules
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# import model
from models.resnet import resnet18



# classics
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import time

### Define model hyperparameters

In [8]:
# Fraction of the data to use for training.
split = 0.7

# Number of epochs the training loop runs for.
num_epoch = 1_000

# Learning rate handed to the optimizer.
lr = 5e-5


### Create Model

In [9]:
# Build the network with the project's `resnet18` constructor (imported from
# models.resnet), using its default arguments.
model = resnet18()

### Handle GPU

In [11]:
# Pick the compute device once: first CUDA GPU when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # GPU run: use the CUDA float tensor type and move the model's parameters over.
    dtype = torch.cuda.FloatTensor
    model = model.to(device)
else:
    # CPU run: the model is left where it was created.
    dtype = torch.FloatTensor

### Handle loss function and optimizer

In [None]:
# Classification loss applied to the model outputs and the integer labels.
crit = torch.nn.CrossEntropyLoss()

# Adam over every model parameter; weight decay is explicitly disabled.
opt = torch.optim.Adam(
    params=model.parameters(),
    lr=lr,
    weight_decay=0,
)

### Track stats

In [14]:
# Per-epoch metrics, indexed as stats[epoch, split, metric].
stats = np.zeros((num_epoch, 2, 2))

# Index constants for the second axis (which data split) ...
train, test = 0, 1
# ... and for the third axis (which metric).
acc, loss = 0, 1


### Load data

In [None]:
# ...

### Train !

In [None]:
# Train for `num_epoch` epochs. After every epoch the model is evaluated on the
# test loader, the loss/accuracy of both passes are stored in `stats`, and a
# one-line summary is printed.
#
# Uses names created by earlier cells: model, crit, opt, device, stats and the
# index constants train/test/acc/loss. `train_dl`, `test_dl`, `train_batches`
# and `test_batches` are presumably defined by the "Load data" cell -- confirm.
tstart = time.time()
for epoch in range(num_epoch):

    # ---------------- training pass ----------------
    # switch to train mode
    model.train()

    # running counts for accuracy
    total = 0
    correct = 0

    # per-batch losses, averaged at the end of the epoch
    tr_loss = []
    for batch, data in enumerate(train_dl):
        print('Training Iteration: {0:4d} of {1:4d}'.format(batch+1, train_batches), end='\r')

        # extract signals and labels
        signals, labels = data

        ## Normalize?

        # move data to the same device as the model
        signals = signals.to(device)
        labels = labels.to(device)

        # process signals
        out = model(signals)

        # hard decision for classification: index of the largest output along
        # dim 1 (assumes `out` is (batch, num_classes) -- TODO confirm).
        # detach() keeps this bookkeeping out of the autograd graph.
        _, pred = torch.max(out.detach(), 1)

        # accumulate accuracy counts
        total   += labels.size(0)
        correct += (pred == labels).sum().item()

        # compute loss and take one optimizer step
        b_loss = crit(out, labels)
        opt.zero_grad()
        b_loss.backward()
        opt.step()

        tr_loss.append(b_loss.item())

    # epoch training stats
    tr_accuracy = 100. * correct / total
    tr_loss = np.mean(tr_loss)

    # ---------------- evaluation pass ----------------
    model.eval()

    ts_loss = []
    total   = 0
    correct = 0
    # no gradients are needed while evaluating
    with torch.no_grad():
        for batch, data in enumerate(test_dl):
            print('Testing  Iteration:: {0:4d} of {1:4d}'.format(batch+1, test_batches), end='\r')

            # extract signals and labels
            signals, labels = data

            # move data to the same device as the model
            signals = signals.to(device)
            labels = labels.to(device)

            # process signals
            out = model(signals)

            # hard decision for classification
            _, pred = torch.max(out, 1)

            # accumulate accuracy counts
            total   += labels.size(0)
            correct += (pred == labels).sum().item()

            # compute loss (no backward pass during evaluation)
            b_loss = crit(out, labels)

            ts_loss.append(b_loss.item())

    # epoch testing stats
    ts_accuracy = 100. * correct / total
    ts_loss = np.mean(ts_loss)

    # save stats to plot later; `acc` and `loss` are the metric index constants
    stats[epoch, train, loss] = tr_loss
    stats[epoch, train, acc ] = tr_accuracy
    stats[epoch, test,  loss] = ts_loss
    stats[epoch, test,  acc ] = ts_accuracy

    # elapsed wall-clock time since training started (cumulative, not per epoch)
    t1 = time.time() - tstart
    print('Epoch: {0:4d} Tr loss: {1:.3f} Ts loss: {2:.3f} Tr Accuracy: {3:3.2f}% Ts Accuracy: {4:3.2f}% Time: {5:4.2f}s'.format(
            epoch+1,      tr_loss,         ts_loss,         tr_accuracy,           ts_accuracy, t1))