# antBot Highway World Experiments: EMB
Run EMB on antBot Highway World Experiments

In [1]:
from os import listdir
import numpy as np
import pandas as pd
import matplotlib.pyplot as P
import seaborn as sns
from tqdm.notebook import trange, tqdm
from pyRC.analyse.perfectMemory import *

# Plotting settings
sns.set_context("notebook", font_scale = 1.5)
sns.set_style("dark")
sns.set_palette("deep", 12)
from ipywidgets import interact, interactive, fixed, interact_manual
cmap='Greys_r'

## Setup Datasets

In [2]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
from pyRC.datasets.utils import ImageDataset

class antBotDatasets():
    def __init__(self, expPath = '../data/2000-10/'):
        self.expPath     = expPath # experiment path
        self.HW          = [[75,360], [50,180], [25,90]] # (h,w) pairs FIXED for now
        self.strImages   = sorted([f.replace('.npy','') for f in listdir(self.expPath)])# get the name of all data in the given experiment path
        self.allImages   = np.stack([np.load(self.expPath+f+'.npy') for f in self.strImages]) # load all the data matched in strImages
        self.nDatasets   = self.allImages.shape[0]
        self.nImages     = self.allImages.shape[1]
        self.HW_Original = self.allImages.shape[-2:]
        self.nHW         = len(self.HW)
        print(f'>> Loaded {self.nDatasets} datasets of {self.nImages} images each!')
        
#         self.allImagesTorch = torch.Tensor(self.allImages)

    def setGT(self, strGT = 'mountains'):
        #TODO try except error throw if strGT is not in list
        self.strGT    = strGT # ground truth name
        self.gtImages = self.allImages[self.strImages.index(strGT)] # find the index of the gtName in the list of names first
        print(f'>> Set `{self.strGT}` as ground truth!')
    
    
    def get(self, strDataset = 'mountains', nImages=200, h=25, w=90):
        # dataSet in shape: torch.Size([num_images, height, width, batchSize=1]) # batchSize=1 reserved for batch, to be permuted to top
        dsID    = self.strImages.index(strDataset)
        Images0 = self.allImages[dsID]
        Images  = np.stack([IMAGEOP(img,h,w) for img in Images0]) # need to swap axis of height(2) and width(1)
        ImagesT = torch.Tensor(Images).permute(0,2,1).unsqueeze(-1) # add dimension for batch operations
        iLabels = range(self.nImages)
        dataSet                = ImageDataset(ImagesT, iLabels)
        dataLoader             = DataLoader(dataSet, batch_size = 1, shuffle = False)
        dataLoader.nInput      = h*w   # image height is the number of inputs
        dataLoader.nOutput     = nImages # nImages to choose from
        dataLoader.groundTruth = iLabels
        return dataLoader



In [3]:
ds = antBotDatasets()
DL = ds.get()

>> Loaded 7 datasets of 200 images each!


## Setup PyTorch-Lightning Experiment

In [9]:
import torch
from torch.nn import functional as F
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.metrics.classification import ConfusionMatrix

import time
import numpy as np

import pyRC.datasets.nordland as Nordland
import pyRC.learn.utils       as utL
import pyRC.network as RC
import pyRC.analyse.utils as utA

import wandb
from pytorch_lightning.loggers import WandbLogger

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as P
P.style.use('seaborn-dark')

class antBotEMB(LightningModule):
    def __init__(self, config):
        super().__init__()
        self.hparams = {**config['hyperparameters'], **config['experimentParameters'], **config['modelParameters'], **config['experimentDetails']} # unzip nested dict
        self.dlargs  = {'nImages': self.hparams['nImages'], 'h': self.hparams['height'], 'w': self.hparams['width']}
        self.RC      = RC.ESN_NA(self.hparams, readoutType=self.hparams['readoutType'])
        print('>> Network is constructed!')
        
    def forward(self, x):
        return self.RC(x)

    def prepare_data(self):
        antBotData              = antBotDatasets()
        self.data_0             = antBotData.get('dawn_cloudy_empty'  , **self.dlargs)
        self.data_1             = antBotData.get('dusk'  , **self.dlargs)
        self.data_2             = antBotData.get('mountains'  , **self.dlargs)
        self.GroundTruth        = torch.Tensor(self.data_0.groundTruth)
        self.hparams['nParams'] = utA.modelParameters(self.RC, returnOnly=True).item()
        print('>> Datasets are loaded!')
        
    def train_dataloader(self):
        return self.data_0
      
    def test_dataloader(self):
        return [self.data_0, self.data_1, self.data_2]
    
    def configure_optimizers(self):
        params = [{'params': self.RC.Wout, 'lr': self.hparams['learningRate']}] 
        return torch.optim.Adam(params)

    def training_step(self, batch, batch_idx):

        if batch_idx == 0:  # At the beginning of each epoch, reset the model! 
            self.RC.reset()
            
        x, y  = batch
        y_hat = self(x)
        loss  = F.cross_entropy(y_hat, y)
        logs_loss = {'train_loss': loss}
        return {'loss': loss, 'log': logs_loss}

    def test_step(self, batch, batch_idx, dataloader_idx): 
        #TODO implement error instead of loss! torch.mean(abs(df['imageID']-df['predIDF']))
        x, y    = batch
        y_hat   = self(x)
        tol_acc = (torch.abs(self.GroundTruth[batch_idx] - torch.argmax(y_hat))<self.hparams['tolerance']) # Accuracy
        return {'test_loss': F.cross_entropy(y_hat, y), 'tol_acc': tol_acc, 'y_pred': y_hat}

    def test_epoch_end(self, outputs):
        loss_0 = torch.stack([x['test_loss'] for x in outputs[0]]).mean()
        loss_1 = torch.stack([x['test_loss'] for x in outputs[1]]).mean()
        loss_2 = torch.stack([x['test_loss'] for x in outputs[2]]).mean()
        
        acc_0  = torch.stack([x['tol_acc'] for x in outputs[0]]).float().mean()*100 # percentage
        acc_1  = torch.stack([x['tol_acc'] for x in outputs[1]]).float().mean()*100 # percentage
        acc_2  = torch.stack([x['tol_acc'] for x in outputs[2]]).float().mean()*100 # percentage
        
        logs_loss   = {'loss_0': loss_0, 'loss_1': loss_1, 'loss_2': loss_2}
        logs_acc    = {'acc_0' : acc_0,  'acc_1' : acc_1 , 'acc_2': acc_2}
        return {'log': {**logs_loss, **logs_acc}}

In [24]:
# ## Training
config = utL.getConfig('config.yaml')

config['hyperparameters']['nReservoir']  = 4000
config['experimentParameters']['nEpoch'] = 100
config['experimentDetails']['width']     = 360
config['experimentDetails']['height']    = 25
config['modelParameters']['nInput']      = config['experimentDetails']['width'] * config['experimentDetails']['height']

expDict = {
            'gpus'                     : 1,
            'profiler'                 : False,
            'log_save_interval'        : 10000,
            'progress_bar_refresh_rate': 100,
            'row_log_interval'         : 10000,
            'max_epochs'               : config['experimentParameters']['nEpoch'],
            }
            
exp     = antBotEMB(config)
trainer = Trainer(**expDict)
trainer.fit(exp); 
trainer.test(ckpt_path=None);

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | RC   | ESN_NA | 800 K 


>> Loaded 7 datasets of 200 images each!
>> Datasets are loaded!


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

Saving latest checkpoint..



>> Loaded 7 datasets of 200 images each!
>> Datasets are loaded!


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'acc_0': tensor(88., device='cuda:0'),
 'acc_1': tensor(30.0000, device='cuda:0'),
 'acc_2': tensor(20.0000, device='cuda:0'),
 'loss_0': tensor(1.9518, device='cuda:0'),
 'loss_1': tensor(4.9857, device='cuda:0'),
 'loss_2': tensor(6.5184, device='cuda:0')}
--------------------------------------------------------------------------------

