# Visual State Classification (Python 3)
## <div id="content">Contents</div>
1. <a href="#arch">Model Specification</a> 
2. <a href="#train">Training</a> 
3. <a href="#test">Testing</a> 
4. <a href="#visual">Visualisations</a>

In [None]:
%%bash
# Report whether the kernel runs as root (the notebook expects it to), then
# print the Python and CUDA compiler versions for the record.
if [ "$(id -u)" -eq 0 ]; then
    echo "You are root" # Should be root!
else
    echo "You do not have the right privileges, you must be root"
fi
python3 -V
nvcc --version # Should be 10.1

In [None]:
import sys
if sys.version_info < (3,):
    # The notebook relies on Python 3 behaviour (e.g. input() returning str).
    raise Exception("Must be using Python 3")

from IPython.display import clear_output # Clear output in cell programmatically
from ipywidgets import IntProgress # Progress Bar
import os, sys, shutil #Advanced File Manipulation
from os import path
import glob
import matplotlib as mlt # Data Visualisation
import matplotlib.pyplot as plt
#import opencv_python as cv2 # Interpret camera images
import pandas, numpy # Data Manipulation
from termcolor import colored # Colored Text

import subprocess # Used to run scripts in the background
import signal # Used to signal OS (e.g. to kill a process)
import csv # easily read and write to CSV files (used for ground truths)
import math # For trigonometry
import json # Efficiently store data structures

# Hyper-parameter Optimization Modules
import types as types
from types import * # Special Types (e.g. classes)
import inspect # Inspect method/class signatures
import pickle # Efficiently store classes to file


import getpass # For sudo-based tasks (hides input)
import re # Regular expressions
import random # For variation
import copy
import time # Allow python to wait
from enum import Enum

# Machine Learning
import numpy as np
import torch # Deep Learning research library
import torchsnooper
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision
from torchvision import datasets, transforms, models
from PIL import Image

# >>> Utilities <<<

class color:
    """ANSI terminal escape sequences used to style printed output."""

    # Foreground colours
    PURPLE = "\x1b[95m"
    CYAN = "\x1b[96m"
    DARKCYAN = "\x1b[36m"
    BLUE = "\x1b[94m"
    GREEN = "\x1b[92m"
    YELLOW = "\x1b[93m"
    RED = "\x1b[91m"
    # Text attributes
    BOLD = "\x1b[1m"
    UNDERLINE = "\x1b[4m"
    # Reset all styling back to the terminal default
    END = "\x1b[0m"
    
class label(Enum):
    """Visual-state classes; values are the integer indices found in the
    dataset's ``label`` column."""

    # Unpacked in declaration order, so iteration order is GRASP, BAD,
    # NOT_READY, EMPTY with values 0..3.
    GRASP, BAD, NOT_READY, EMPTY = 0, 1, 2, 3
    
    

Set up the CUDA path on Linux

In [None]:
# Source the CUDA switch script in a bash subshell and dump the resulting
# environment NUL-separated (``env -0``). Changes made in a child shell never
# propagate back on their own, so apply them to this Python process here.
output = subprocess.check_output("source ./switch_cuda.sh 10.1; env -0", shell=True,
                                 executable="/bin/bash")
for entry in output.split(b"\0"):
    key, sep, value = os.fsdecode(entry).partition("=")
    if sep and key:  # skips the empty chunk after the final NUL
        os.environ[key] = value

# Pin CUDA_HOME to the 10.1 toolkit regardless of what the script exported.
os.environ.update({"CUDA_HOME": "/usr/local/cuda-10.1"})
print(os.environ["CUDA_HOME"])

## <span id="arch">Model Specification</span>
<a href="#content">Return to Contents</a>

As we are working with spatio-temporal data where the time-series component is potentially very informative, a Residual Neural Network (ResNet) with additional Long Short-Term Memory (LSTM) layers is appropriate. The residual (skip) connections let information and gradients bypass layers, while the LSTM enables the history of states to be remembered and exploited for better performance.

In this tiered development approach, the complexity of models is increased progressively to accommodate new features and determine more useful outputs. For appropriate fallback and baseline comparisons, earlier - functional - architectures for each model are retained here. Models are grouped by type and internally ordered by complexity, simplest first.

### <span id="arch-content">Section Contents</span>

1. <a href="#grasp">Grasp Point Extraction</a>
    1. <a href="#grasp">Detection</a>
    2. <a href="#grasp-centroid">Centroid Extraction</a>
    3. <a href="#grasp-orientation">Orientation Extraction</a>
2. <a href="#pick">Pick-And-Place</a>
    1. <a href="#pick-semi">Semi-supervised</a>
    2. <a href="#pick-simul">Unsupervised in simulation</a>
    3. <a href="#pick-real">Unsupervised in reality</a>
    4. <a href="#pick-placement">Placement</a>
    5. <a href="#pick-haptic">Haptic</a>
3. <a href="#waste">Waste Category Classifier</a>
    1. <a href="#waste-material">By Material</a>
    2. <a href="#waste-type">By Type</a>

## <span id="train">Training</span>
<a href="#content">Return to Contents</a>

#### Options
Our intention is to abstract out the hyperparameters of the neural network (and create a "master" class for instantiating training sessions). To do so, all relevant hyperparameters must be identified and validated. Hence, in this section we define the infrastructure that specifies which hyperparameter options exist, including each option's exact name (as a string) and its possible values (a type or a list).

Identified Hyperparameters:
1.   Loss Function
2.   Activation Function
3.   Optimizer
4.   Number of epochs
5.   Batch Size
6.   Shuffler
7.   Transformer
  *   Resizing
  *   Cropping
  *   Normalization
  *   ColourSpace 
  *   Contrast
8. Dropout probability
9. Architecture
10. Dataset


In [None]:
class Options(object):
    """Blueprint of the hyper-parameters that may be tuned and their allowed values.

    Each hyper-parameter is an instance attribute whose value is either a
    type (any value convertible to it is allowed, e.g. ``int``) or a list of
    allowed choices (e.g. dotted paths such as
    ``"torchvision.models.resnet18"``). ``id`` names the option set and is
    used in the file name written by ``save_options``.
    """

    def __init__(self, id, filePath=None, **kwargs):
        """Create an option set, optionally seeded from a file saved by ``save_options``.

        Args:
            id: Name of this option set.
            filePath: Optional path of a pickle written by ``save_options``;
                every stored entry except ``id`` is copied onto this instance.
                NOTE: unpickling can execute arbitrary code -- only load
                trusted files.
            **kwargs: Additional options, processed by ``add_options``.
        """
        self.id = id
        if filePath is not None:
            with open(filePath, 'rb') as handle:
                stored = pickle.load(handle)
            for key, value in stored.items():
                if key == "id":
                    continue  # the caller-supplied id wins over the stored one
                # setattr (instead of exec on the value's repr/__name__) works
                # for any stored object, not only builtins resolvable by name.
                setattr(self, key, value)
        print(type(kwargs), len(kwargs))
        self.add_options(kwargs)

    def get(self, name):
        """Return the option called ``name``, or ``None`` if it is not defined."""
        return getattr(self, name, None)

    @staticmethod
    def _first_doc_line(obj):
        """Return the first line of ``obj``'s docstring, or a placeholder."""
        doc = getattr(obj, "__doc__", None)
        return "No Description Found" if doc is None else doc.split("\n")[0]

    def get_module_classes(self, moduleName):
        """Interactively pick which public members of ``moduleName`` to keep.

        Lists every member of the module except names that contain ``__``,
        start or end with ``_``, or are black-listed, then repeatedly asks
        for comma-separated indices to drop until the user enters ``x``.

        Args:
            moduleName: Importable dotted module name, e.g. ``"torch.optim"``.

        Returns:
            List of ``"<moduleName>.<member>"`` strings that were kept.

        Raises:
            ImportError: if ``moduleName`` cannot be imported.
        """
        import importlib

        try:
            module = importlib.import_module(moduleName)
        except Exception as err:
            raise ImportError("Module {}".format(moduleName)) from err

        black_list = ["List", ]
        # ToDo: Select functions to REMOVE not to INCLUDE
        func_dict = {}
        for func in dir(module):
            if "__" in func or func.startswith("_") or func.endswith("_") or func in black_list:
                continue
            func_dict[len(func_dict)] = func

        def show():
            # One line per remaining member: index, bold name, first doc line.
            for index, func in func_dict.items():
                desc = self._first_doc_line(getattr(module, func))
                print("{}: {}{}{} ({})".format(index, color.BOLD, func, color.END, desc))

        show()
        resp = input("Provide digit (e.g. 3), or list of digits (e.g. 2,4,6 ), of elements to IGNORE")

        while resp != "x":
            print("'{}'".format(resp))
            for number in resp.split(","):
                number = number.strip()  # tolerate "2, 4" as well as "2,4"
                if number.isdigit():
                    # Stored names are never None, so None means "not present".
                    if func_dict.pop(int(number), None) is None:
                        print("Could not remove '{}' from elements".format(number))
                else:
                    print("{} is not a digit".format(number))

            print("--------------")
            show()
            resp = input("Provide digit or list of digits to ignore, 'x' to exit")

        return ["{}.{}".format(moduleName, value) for value in func_dict.values()]

    def add_options(self, options):
        """Register hyper-parameters from a ``{name: spec}`` dictionary.

        ``spec`` may be a type (stored as is), a module name string (resolved
        interactively by ``get_module_classes`` into a list of dotted member
        paths), or a list of allowed values (stored as is).

        Raises:
            TypeError: if ``options`` is not a dict, a key is not a str, or a
                spec is none of the accepted kinds.
        """
        if not isinstance(options, dict):
            raise TypeError(
                "Option must be a dictionary {name:type OR list}")
        for key, value in options.items():
            print("{}:{}".format(key, value))
            if not isinstance(key, str):
                raise TypeError("Key must be str, value must be type or list")
            if isinstance(value, type) or isinstance(value, list):
                setattr(self, key, value)
            elif isinstance(value, str):
                setattr(self, key, self.get_module_classes(value))
            else:
                raise TypeError("Key must be str, value must be type or list")

    def remove_option(self, name):
        """Delete option ``name`` if present; absent names are ignored."""
        if hasattr(self, name):
            delattr(self, name)
            print("Removed attribute {} from options".format(name))

    def save_options(self, dirPath):
        """Pickle all options to ``<dirPath>/options_<id>.pickle``.

        Raises:
            ValueError: if that file already exists (never overwrites).
        """
        d = self.__dict__
        file_path = os.path.join(dirPath, 'options_{}.pickle'.format(self.id))
        if os.path.exists(file_path):
            raise ValueError("{} already exists".format(file_path))
        with open(file_path, 'wb') as handle:
            pickle.dump(d, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print("Saved options to file:\n{}".format(file_path))

    def __str__(self):
        """Human-readable listing of every option, names in bold."""
        info_str = ">>> {0}Options {2}{1} <<<\n".format(color.BOLD, color.END, self.id)
        for key, value in vars(self).items():
            info_str += "\n{0}{2}{1}={3}\n".format(color.BOLD, color.END, key, value)
        return info_str
            

**Load Options from File**

In [None]:
# Rebuild the "basic" option set from the pickle stored under ./options.
options = Options(
    "basic",
    filePath=os.path.join(os.getcwd(), "options", "options_basic.pickle"),
)
print(options)

**Generate a new Options instance**

Acts as a blueprint for hyperparameters: it identifies which hyperparameters exist and how they can be changed.

In [None]:
# Note: For convenience we've limited the models to Classification, feel free to remove this
# Each entry is a dotted path to a model constructor in torchvision.models.
classification_architectures = [
    "torchvision.models." + arch_name
    for arch_name in (
        "alexnet",
        "vgg11", "vgg11_bn", "vgg13", "vgg13_bn",
        "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
        "squeezenet1_0", "squeezenet1_1",
        "densenet121", "densenet169", "densenet161", "densenet201",
        "googlenet",
        "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
        "mobilenet_v2",
        "resnext50_32x4d", "resnext101_32x8d",
        "wide_resnet50_2", "wide_resnet101_2",
        "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3",
        "inception_v3",
    )
]

name = input("Name for new options")
# Module-name strings are expanded interactively into member lists; types
# constrain free-form values; lists enumerate the allowed choices.
options_new = Options(
    name,
    filePath=None,
    loss="torch.nn.modules.loss",
    activation="torch.nn.modules.activation",
    optimizer="torch.optim",
    no_epochs=int,
    batch_size=int,
    shuffle=bool,
    architecture=classification_architectures,
    dropout_probability=float,
    learning_rate=float,
)

options_new.save_options(os.path.join(os.getcwd(), "options"))

#### Configuration
Provides a unique configuration for the hyperparameters to be passed to the Training Pipeline. It uses the Options Class to define which hyper-parameters may exist. The config maps each hyper-parameter name to one specific value. This way we can easily add new configurations (and put those Config objects in a list) to queue many different hyperparameter combinations for training.

In [None]:
class Config(object):
    """A concrete set of hyper-parameter values, validated against ``Options``.

    Each chosen hyper-parameter is stored as an instance attribute holding a
    ``(value, kwargs)`` pair:

    * for type-constrained options (e.g. ``int``) ``value`` is the primitive
      and ``kwargs`` is ``None``;
    * for list options ``value`` is the chosen dotted path string (e.g.
      ``"torch.optim.Adam"``) and ``kwargs`` a dict of keyword overrides.

    ``configured`` is ``(True, None)`` once values have been assigned.
    """
    # Options blueprint shared by all instances; registered by the first
    # non-file instance constructed with ``options=``.
    options = None

    def __init__(self, file_path=None, options=None, verbose=False, **kwargs):
        """Build a configuration from a saved file or from keyword values.

        Args:
            file_path: Pickle written by ``save_configuration``; when given,
                every stored ``(value, kwargs)`` pair is restored and the
                remaining arguments are ignored.
                NOTE: unpickling can execute arbitrary code -- only load
                trusted files.
            options: ``Options`` instance; required when no blueprint has
                been registered yet.
            verbose: Print each restored entry while loading from file.
            **kwargs: ``name=value`` or ``name=(value, kwargs)`` assignments;
                every name must exist in ``Config.options``.

        Raises:
            TypeError: if no blueprint is registered and ``options`` is not
                an ``Options`` instance.
            NameError: if a keyword is not a known hyper-parameter.
        """
        if file_path is not None:
            if verbose:
                print("Loading configuration from {}".format(file_path.split(os.path.sep)[-1]))
            with open(file_path, 'rb') as handle:
                stored = pickle.load(handle)
            for key, (value, value_kwargs) in stored.items():
                if verbose:
                    print(key, str(value), type(value))
                setattr(self, key, (value, value_kwargs))
            self.configured = (True, None)
            return

        if Config.options is None:
            if isinstance(options, Options):
                print("Setting Options for all Config Instances")
                Config.options = options
            else:
                raise TypeError(
                    "Options must be provided on 1st instance of Config Class")

        # Any keyword counts (the old ``> 1`` check ignored a single one).
        if kwargs:
            for key, value in kwargs.items():
                if Config.options.get(key) is None:
                    raise NameError("{} not found in available hyperparameters.".format(key))
                setattr(self, key, self._as_pair(value))
            self.configured = (True, None)
        else:
            self.configured = (False, None)

    @staticmethod
    def _as_pair(value):
        """Normalise a keyword value to ``(value, kwargs)``.

        Only a 2-tuple is taken as an explicit pair; anything else (including
        dotted-path strings, which used to be split or exec-evaluated) becomes
        ``(value, None)``.
        """
        if isinstance(value, tuple) and len(value) == 2:
            return value
        return (value, None)

    @staticmethod
    def _convert(target_type, text):
        """Convert user input ``text`` to ``target_type``.

        ``bool`` is parsed from words (true/false, yes/no, 1/0, ...) because
        ``bool(text)`` is True for every non-empty string, including "False".

        Raises:
            ValueError/TypeError: if the text does not convert.
        """
        if target_type is bool:
            lowered = text.strip().lower()
            if lowered in ("true", "t", "yes", "y", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "0"):
                return False
            raise ValueError("Not a boolean: {!r}".format(text))
        return target_type(text)

    @staticmethod
    def _resolve(dotted_path):
        """Import and return the object named by ``dotted_path`` (e.g. ``"torch.optim.Adam"``)."""
        import importlib

        module_path, _, attr_name = dotted_path.rpartition('.')
        return getattr(importlib.import_module(module_path), attr_name)

    def _read_typed(self, target_type):
        """Prompt until the reply converts to ``target_type``; return the converted value."""
        print(target_type)
        print("Enter value of Type {}".format(target_type.__name__))
        resp = input()
        while True:
            try:
                return Config._convert(target_type, resp)
            except (TypeError, ValueError):
                resp = input("Please enter value of type {}".format(target_type.__name__))

    def _choose_from_list(self, choices):
        """Prompt for an index into ``choices``; return ``(dotted_path, kwarg_overrides)``."""
        print('\n'.join("{}: {}".format(i, choice) for i, choice in enumerate(choices)))
        resp = input()
        while True:
            if resp.isdigit():
                try:
                    specific_val = choices[int(resp)]
                except IndexError:
                    resp = input("Please pick an integer index in range {}".format(len(choices)))
                    continue
                method = Config._resolve(specific_val)
                print(specific_val)
                return (specific_val, self._override_defaults(method))
            resp = input("Please pick an integer index in range {}".format(len(choices)))

    def _override_defaults(self, method):
        """Offer to override ``method``'s keyword defaults; return the chosen overrides.

        Only parameters whose default is not ``None`` are offered, because the
        reply is converted with ``type(default)``.
        """
        # getfullargspec replaces getargspec, which rejected keyword-only
        # parameters and was removed in Python 3.11.
        spec = inspect.getfullargspec(method.__init__ if inspect.isclass(method) else method)
        args, defaults = spec.args, spec.defaults
        overrides = {}
        if not defaults:
            return overrides

        offset = len(args) - len(defaults)
        valid_counts = []
        for counter, (key, default_value) in enumerate(zip(args[offset:], defaults)):
            if default_value is not None:
                print("{}: {} [Default:{}]".format(counter, key, default_value))
                valid_counts.append(counter)
        if not valid_counts:
            return overrides

        prompt = "Select parameter by number (e.g. 0), or ENTER to continue"
        resp = input(prompt)
        while resp:
            if resp.isdigit() and int(resp) in valid_counts:
                key, default_value = args[int(resp) + offset], defaults[int(resp)]
                value_type = type(default_value)
                value_prompt = "Enter new value of type {} [Default: {}]".format(
                    value_type.__name__, default_value)
                new_value = input(value_prompt)
                # An empty reply keeps the default; otherwise retry until it
                # converts (previously a stale or undefined value was stored).
                while new_value:
                    try:
                        overrides[key] = Config._convert(value_type, new_value)
                        break
                    except (TypeError, ValueError):
                        new_value = input(value_prompt)
            resp = input(prompt)
        return overrides

    def config_wizard(self):
        """Interactively choose a value for every option in ``Config.options``.

        Returns:
            ``False`` if this instance was already configured, else ``True``.
        """
        if self.configured[0]:
            print("Already Configured {}".format(id(self)))
            return False  # Unsuccessful

        for option_name, abstract_val in Config.options.__dict__.items():
            if option_name == "id":
                continue
            print(">>> Choose {} <<<".format(option_name))
            if type(abstract_val) is list:
                setattr(self, option_name, self._choose_from_list(abstract_val))
            else:
                setattr(self, option_name, (self._read_typed(abstract_val), None))
        self.configured = (True, None)
        return True

    def save_configuration(self, dirPath):
        """Prompt for a name and pickle this configuration into ``dirPath``.

        Returns:
            Path of the written file.

        Raises:
            ValueError: if the target file already exists.
        """
        name = input("Please select a Configuration name >>>")
        file_path = os.path.join(dirPath, 'configuration{}.pickle'.format(name))
        if os.path.exists(file_path):
            raise ValueError("{} already exists".format(file_path))
        with open(file_path, 'wb') as handle:
            pickle.dump(self.__dict__, handle,
                        protocol=pickle.HIGHEST_PROTOCOL)
        print("Saved configuration to:\n{}".format(file_path))
        return file_path

    def get(self, attribute, default=None, kwargs=False):
        """Return the configured value of ``attribute``, or ``default``.

        If the matching option is a type, the stored value converted to that
        type is returned on its own. Otherwise the stored dotted path is
        imported and ``(object, kwargs)`` is returned. ``default`` is returned
        when the attribute is missing or is not a ``(value, kwargs)`` pair.
        The ``kwargs`` parameter is unused; it is kept for compatibility.
        """
        try:
            specific_val, stored_kwargs = getattr(self, attribute)
        except (AttributeError, TypeError, ValueError):
            return default

        constraint = getattr(Config.options, attribute, None)
        if type(constraint) is type:
            return constraint(specific_val)
        return Config._resolve(specific_val), stored_kwargs

    def as_dict(self):
        """Return the instance attribute dict (live, not a copy)."""
        return vars(self)

    def __str__(self):
        """Human-readable listing of every hyper-parameter and its kwargs."""
        info_str = ">>> {} Configuration {} <<<".format(color.BOLD, color.END)
        for key, value in vars(self).items():
            if type(value) == tuple:
                info_str += "\n{0}{2}{1} = {3}".format(color.BOLD, color.END, key, value[0])
                if len(value) == 2:
                    if value[1] is not None: info_str += "\n\t**kwargs = {}".format(value[1])
            else:
                info_str += "\n{} = {}".format(key, value)
        return info_str

# First non-file Config: registers ``options`` as the blueprint shared by
# all Config instances (only if none is registered yet).
config_base = Config(None, options)


**Generate a new hyperparameter Configuration**

Specifies a set of hyper-parameters to be passed into the training pipeline. Where a hyper-parameter is not set, a default value will be used.

In [None]:
# Start from an empty configuration and fill every option interactively.
config_new = Config(file_path=None)
config_new.config_wizard()

Save the new configuration to file

In [None]:
# Persist under ./configurations (prompts for a name; refuses to overwrite).
config_new.save_configuration(os.path.join(os.getcwd(), "configurations"))

#### <a name="assembly">Load Existing Configuration</a>
[Return to Section Content](#arch-content)<br/>
**Options**: Specifies which hyper-parameters may be changed. Maps hyper-parameter name to option (type or list)<br />
**Configuration**: Specific set of hyper-parameter values. Maps hyper-parameter name to value (function or primitive)

In [None]:
# Reload the "basic" option blueprint and make it the one used by Config.
options = Options("basic", filePath=path.join(os.getcwd(), "options", "options_basic.pickle"))
# Assign directly: Config(options=...) only registers options when none are
# set yet, so re-running this cell would otherwise keep the old blueprint.
Config.options = options

configurations_dir = path.join(os.getcwd(), "configurations")
# Sorted so the numbering is stable between runs (os.listdir order is arbitrary).
config_lookup = dict(enumerate(sorted(os.listdir(configurations_dir))))
if not config_lookup:
    # Without this the prompt below could never be satisfied.
    raise FileNotFoundError("No saved configurations in {}".format(configurations_dir))
for i, config_name in config_lookup.items():
    print("{} : {}".format(i, config_name))

resp = input("Choose a config by number: ")
while not (resp.isdecimal() and int(resp) < len(config_lookup)):
    resp = input("Choose a config by number: ")
config_name = config_lookup[int(resp)]
config = Config(file_path=path.join(configurations_dir, config_name))

print(config)

#### <a name="Torch_DataSet">PyTorch Dataset Interface</a>
[Return to Section Content](#arch-content)<br/>

Defines a custom PyTorch dataset for accessing generated images using their unique ID. Images are uniquely defined by:
* Their sequence number (ordered by time during pick-and-place task, later stages will have higher numbers)
* Their iteration (a single attempt at pick-and-place)
* Their run (a collection of many iterations, named during dataset generation)


In [None]:
class VisualDataset(Dataset):
    """Image dataset indexed by row position in ``<dataset_dir>/metadata.csv``.

    Each CSV row after the header describes one sample: the ``img_name``
    column gives the image file name relative to ``dataset_dir`` and the
    ``label`` column its integer class index.
    """

    def __init__(self, dataset_dir, transforms=transforms.Compose([transforms.ToTensor()])):
        """Index the samples listed in ``metadata.csv``.

        Args:
            dataset_dir: Directory holding the images and ``metadata.csv``.
            transforms: Callable applied to each PIL image in ``__getitem__``
                (default: convert to a tensor).

        Raises:
            NameError: if ``dataset_dir`` does not exist.
        """
        if not os.path.exists(dataset_dir):
            raise NameError("{} path DNE".format(dataset_dir))
        self.dataset_dir = dataset_dir
        self.name = os.path.basename(self.dataset_dir)
        self.transforms = transforms

        with open(path.join(self.dataset_dir, "metadata.csv"), "r", newline="") as f:
            # csv.DictReader already consumes the header row to build its
            # field names; the extra next() that used to follow silently
            # dropped the first sample.
            self.meta_data = dict(enumerate(csv.DictReader(f)))

    def __getitem__(self, ID):
        """Return ``(transformed_image, label_index)`` for sample ``ID``.

        Raises:
            NameError: if ``ID`` is not a known sample index.
        """
        if ID not in self.meta_data:
            raise NameError("{} DNE".format(ID))
        row = self.meta_data[ID]
        img_path = path.join(self.dataset_dir, row['img_name'])
        img = self.transforms(Image.open(img_path))
        target = int(row["label"])  # not named ``label`` to avoid shadowing the enum
        return (img, target)

    def __len__(self):
        """Number of samples listed in ``metadata.csv``."""
        return len(self.meta_data)

**Define Training and Testing Sets**

Load a training set containing 80% of the generated data and a testing set containing the remaining 20%. The split is random, so fix the random seed if you need a reproducible split.

In [None]:
# Tensor conversion followed by per-channel normalisation (the mean/std appear
# to be the usual ImageNet statistics).
# NOTE(review): ``normalize`` is not passed to the dataset below -- confirm intended.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = VisualDataset(dataset_dir=os.path.join(os.getcwd(), "workspace", "data", "visualStateV2"))
print("{0}{2} dataset size{1}:{3}".format(color.BOLD, color.END, dataset.name.capitalize(), len(dataset)))

# 80/20 split; random_split draws from torch's global RNG unless seeded.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
print("{0}Training Set Size{1}: {2}, {0}Testing Set Size{1}: {3}".format(color.BOLD, color.END, train_size, test_size))
trainset, testset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Both loaders share the configured batch size and shuffle flag.
train_loader, test_loader = (
    DataLoader(subset,
               batch_size=config.get("batch_size", 5),
               shuffle=config.get("shuffle", True))
    for subset in (trainset, testset)
)


**Inspect Dataset**

Check relative frequencies of each class in the dataset

In [None]:
# Tally how many samples carry each visual-state label.
meta_data = dataset.meta_data
counters = dict.fromkeys(label, 0)  # keyed in enum order: GRASP, BAD, NOT_READY, EMPTY
alt_counters = {"failed": 0, "succeeded": 0}
for sample in meta_data.values():
    state = label(int(sample['label']))
    counters[state] += 1
print(counters)

Check what percentage of the total data has a specific label

In [None]:
# Share of the whole dataset labelled BAD, as a percentage.
print("%s:%.2f%%" % (label.BAD.name, 100 * counters[label.BAD] / len(dataset)))

**Load in Classification Architecture**

Define neural network architecture, with optional kwargs

In [None]:
# Resolve the configured architecture (default: resnet18) and its keyword overrides.
architecture, kwargs = config.get("architecture", (torchvision.models.resnet18, {}))
print("Using {} architecture".format(architecture.__name__))
try:
    # Size the classifier head to the number of visual-state labels.
    net = architecture(num_classes=len(label), **kwargs)
    print("Loaded {} with {} classes and kwargs: {}".format(architecture.__name__, len(label), kwargs))
except TypeError:
    # Fall back only when the constructor rejects these keywords; any other
    # error should surface instead of silently yielding a model with the
    # wrong output layer.
    net = architecture(pretrained=False)
    print("Loaded architecture does not support num_classes or kwargs parameter, override the input & output layers")

**Visualise functionality of Dataloaders**

Show a subset of a batch; use this to review the validity of the labels.

In [None]:
def visualise_batch(train_loader, frac=0.5, figwidth=10):
    """Plot images from the first batch of a loader, titled with their label names.

    Args:
        train_loader: DataLoader yielding ``(images, targets)``. Each image is
            transposed from channel-first to channel-last for ``imshow``
            (assumed (N, C, H, W) batches -- TODO confirm).
        frac: Fraction of ``train_loader.batch_size`` to display (rounded up).
        figwidth: Figure width in inches; the height keeps a fixed aspect ratio.
    """
    data, target = next(iter(train_loader))

    ncols = 3
    # Never try to show more images than the batch actually holds.
    n_show = min(math.ceil(frac * train_loader.batch_size), len(data))
    nrows = max(1, math.ceil(n_show / ncols))
    # Fixed 246:468 (height:width) ratio per cell, used to size the figure.
    aspect_ratio = (nrows * 246) / (ncols * 468)

    # squeeze=False keeps ``axes`` 2-D even with a single row, which the
    # default call below (frac=0.2) produces for small batch sizes.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(figwidth, aspect_ratio * figwidth), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        if idx >= n_show:
            ax.axis("off")  # blank padding cell in the last row
            continue
        # Row-major index; the old ``row + col`` repeated images across rows.
        img_data = data[idx].numpy().transpose(1, 2, 0)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_title(label(target[idx].item()).name)
        for edge in ax.spines.values():
            edge.set_color("white")
        ax.imshow(img_data)

    fig.suptitle("Sample Batch | {}/{} images shown".format(n_show, train_loader.batch_size),
                 fontsize=22, weight='bold')
    fig.subplots_adjust(wspace=0, hspace=0.2)
    plt.show()

visualise_batch(train_loader, frac=0.2)

**Train the Neural Network**

Run the defined (default:20) number of training epochs, evaluating the performance (and storing the result) at the end of each epoch. Hyper-parameters such as loss and activation are also swapped in here.

In [None]:
import torchsnooper

#@torchsnooper.snoop()
def train(model, device, train_loader, epoch, criterion, optimizer, verbose=False):
    """Run one training epoch over ``train_loader`` and return per-batch losses.

    Args:
        model: Network to optimise; switched to training mode.
        device: Device the inputs and targets are moved to.
        train_loader: DataLoader yielding ``(inputs, targets)`` batches.
        epoch: Epoch number, used only in progress messages.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        verbose: Also print raw outputs and targets with each progress line.

    Returns:
        numpy array holding one loss value per batch.
    """
    model.train()
    n_batches = len(train_loader)
    batch_losses = np.zeros(n_batches)
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        batch_losses[step] = batch_loss.item()
        optimizer.step()

        # Progress report every 10th batch only.
        if step % 10 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), len(train_loader.dataset),
            100. * step / n_batches, batch_loss.item()))
        if verbose:
            print("Outputs:{}\n{}".format(logits.size(), logits))
            print("Labels:{}\n{}".format(labels.size(), labels))
    return batch_losses
        
def evaluate(model, device, test_loader, criterion, verbose=False):
    """Evaluate ``model`` on ``test_loader`` and return its accuracy in percent.

    Prints the mean per-batch loss and the accuracy over the whole dataset.

    Args:
        model: Network to evaluate; switched to eval mode, no gradients.
        device: Device the inputs and targets are moved to.
        test_loader: DataLoader yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        verbose: Print the number of correct predictions per batch.

    Returns:
        ``100 * correct / len(test_loader.dataset)``.
    """
    model.eval()
    total_loss = 0
    total_correct = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, labels).item()

            # Predicted class = index of the largest score in each row.
            _, predictions = logits.max(dim=1, keepdim=True)
            n_correct = predictions.eq(labels.view_as(predictions)).sum().item()
            total_correct += n_correct

            if verbose:
                print("Number Correct: {} out of {}".format(n_correct, len(inputs)))

    accuracy = 100 * total_correct / len(test_loader.dataset)
    print(""">>> Evaluation <<<
Average Loss: {:.4f}
Average Correctness: {:.0f}%""".format(total_loss / len(test_loader), accuracy))
    return accuracy

# Use the GPU when one is available; a hard-coded "cuda" crashes on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("{}Cuda{} is {}".format(color.BOLD, color.END, "available" if torch.cuda.is_available() else "not available"))
model = net.to(device)

optimizer, kwargs = config.get('optimizer', (torch.optim.Adam, {}))
optimizer = optimizer(model.parameters(), **kwargs)
# Targets are integer class indices, so the fallback must be a
# classification loss; MSELoss is not meaningful between class indices and
# per-class scores.
loss_func, kwargs = config.get('loss', (torch.nn.CrossEntropyLoss, {}))
criterion = loss_func(**kwargs)
print("{0}Loss Function{1}: {2}\n{0}Optimizer{1}:{3}".format(color.BOLD, color.END, loss_func.__name__, optimizer))

no_epochs = config.get('no_epochs', 20)
avg_losses = np.zeros(no_epochs)
avg_accuracies = np.zeros(no_epochs)
for epoch in range(no_epochs):
    losses = train(model, device, train_loader, epoch, criterion, optimizer=optimizer, verbose=False)
    accuracy = evaluate(model, device, test_loader, criterion, verbose=False)
    avg_losses[epoch] = np.mean(losses)
    avg_accuracies[epoch] = accuracy

### Visualise Training Loss

Plot a Loss versus Epoch / Accuracy versus Epoch graph, useful to visualise the training process.

In [None]:
import matplotlib as mlt
font = {'size': 24}
mlt.rc('font', **font)

# Derive the epoch axis from the recorded arrays: the old code read the
# non-existent option "num_epochs" (the option is "no_epochs"), so any
# configured epoch count other than 20 gave mismatched x/y lengths.
epochs = range(1, len(avg_accuracies) + 1)

fig, ax1 = plt.subplots(figsize=(20, 14))
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy (%)')
plt.xticks(epochs)
ax1.scatter(epochs, avg_accuracies, color="red", label="Accuracy")

# Loss shares the epoch axis but gets its own y scale on the right.
ax2 = ax1.twinx()
ax2.set_ylabel("Loss")
ax2.scatter(epochs, avg_losses, color="blue", label="Loss")

fig.subplots_adjust(right=0.78)
fig.legend(loc="center right")
plt.show()

### Final Evaluation
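Evaluate your model on the held-out testing set; use this to check whether your model is overfitting. Note that the same set was also used for the per-epoch evaluation during training, so it is not strictly unseen.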

In [None]:
# Verbose pass prints per-batch correct counts; the returned accuracy is the cell output.
evaluate(model=model, device=device, test_loader=test_loader, criterion=criterion, verbose=True)

### Save Trained Model
Like the result? Save the model to file for future use. (remember to keep your random split seed handy)

In [None]:
model_name = input("Choose a name for your model")
models_dir = path.join(os.getcwd(), "models")
os.makedirs(models_dir, exist_ok=True)  # torch.save does not create missing directories
# Saves the whole pickled module (not only its state_dict), matching the
# torch.load call used to reload models in this notebook.
torch.save(model, path.join(models_dir, "model_{}".format(model_name)))

### Load Trained Model

In [None]:
models_dir = path.join(os.getcwd(), "models")
# Sorted for stable numbering; iterate with ``model_file`` so the trained
# ``model`` is not overwritten by a file name while the user is choosing.
model_lookup = dict(enumerate(sorted(os.listdir(models_dir))))
if not model_lookup:
    # Without this the prompt below could never be satisfied.
    raise FileNotFoundError("No saved models in {}".format(models_dir))
for i, model_file in model_lookup.items():
    print("{}: {}".format(i, model_file))

resp = input("Pick a Model by number:")
while not (resp.isdecimal() and int(resp) < len(model_lookup)):
    resp = input("Pick a Model by its number:")
name = model_lookup[int(resp)]

# Map GPU-saved tensors onto the CPU when no GPU is available.
model = torch.load(path.join(models_dir, name),
                   map_location=None if torch.cuda.is_available() else "cpu")
print("Loaded model {}".format(name))

**Show Model**

Print out PyTorch description of loaded model

In [None]:
# Last expression of the cell: Jupyter displays the loaded model's PyTorch description.
model