# Study-01-AugmentedDataset
Details: [Notion Link](https://www.notion.so/AntWorld-Dataset-ce7a7c36f5154ed4a15e192bda2a06af)

In [1]:
import itertools
import copy
import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as P
import gym
from gym import spaces
from tqdm import tqdm
import seaborn as sns

import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer

import pyRC.datasets.nordland as Nordland
import pyRC.analyse.utils as utA
import pyRC.datasets.utils as utD
import pyRC.learn.utils as utL
from pyRC.network import ESN_NA
from pyRC.environments import AntWorld

# Plotting settings: larger notebook fonts, dark background, 12-colour "deep" palette
sns.set_context("notebook", font_scale=1.5)
sns.set_style("dark")
sns.set_palette("deep", n_colors=12)
from ipywidgets import interact, interactive, fixed, interact_manual  # widgets, presumably for later interactive cells
cmap = "turbo"  # colormap name, presumably for image plots further down

## Load Environment

In [2]:
# Ant environment: build the project's AntWorld wrapper for the 'AntWorld-Gym-02' world.
# `env` is used below for `env.Img` and `env.getImage(...)`.
env = AntWorld('AntWorld-Gym-02')



### Prepare environment for training

In [3]:
# Alternative source (generic environment): Img = np.load('AntWorld-Gym-01.npz')['Img']
Img = env.Img  # image stack provided by the environment object
print(Img.shape)
# Axis 0 presumably indexes frames; the last two axes give the image size.
height, width = Img.shape[1:]

(8000, 10, 36)


In [4]:
# Render frames at positions [x, 0] along the x-axis (second getImage argument 0,
# presumably the heading): x in [0, 2) -> label 1, x in [2, 4) -> label 0.
Imgs, Lbls = [], []
for x in np.arange(0, 2, 0.05):  # Label: 1
    Imgs.append(env.getImage([x, 0], 0))
    Lbls.append(1)
for x in np.arange(2, 4, 0.05):  # Label: 0
    Imgs.append(env.getImage([x, 0], 0))
    Lbls.append(0)

nImg  = len(Imgs)
# Stack into one ndarray first: building a tensor from a list of ndarrays is very slow.
# `.float()` keeps the float32 dtype that torch.Tensor(...) produced.
Image = torch.from_numpy(np.stack(Imgs)).float().unsqueeze(-1)  # should be of size [nImg, h, w, 1]
# Use the labels recorded alongside each frame. Re-deriving them as half 1 / half 0
# misaligns silently if the two float ranges ever yield different frame counts.
Label = list(Lbls)
assert len(Label) == nImg, "every image needs exactly one label"

dataSet    = utD.ImageDataset(Image, Label)
dataLoader = DataLoader(dataSet, batch_size=1, shuffle=False)

## PyTorch Lightning Module for the Experiment

In [13]:
class Experiment(LightningModule):
    """Lightning wrapper that trains and evaluates an ``ESN_NA`` reservoir model.

    The same loader serves as both the train and the test loader, so the test
    accuracy measures fit on the training sequence, not generalisation.

    Args:
        trainData: DataLoader yielding ``(x, y)`` batches. Its ``dataset`` must
            expose a ``labels`` attribute, which is stored as ``GroundTruth``.
        cfg: Nested config dict with the sections ``hyperparameters``,
            ``experimentParameters``, ``modelParameters`` and ``experimentDetails``.
            Defaults to the notebook-level ``config`` (the original behaviour).
    """
    def __init__(self, trainData, cfg=None):
        super().__init__()
        if cfg is None:
            cfg = config  # fall back to the notebook-global config
        # Flatten the nested config sections into one dict (later sections win on key clashes).
        self.hparams   = {**cfg['hyperparameters'], **cfg['experimentParameters'],
                          **cfg['modelParameters'], **cfg['experimentDetails']}
        self.trainData = trainData
        self.RC        = ESN_NA(self.hparams)
        # Labels of the loader actually passed in (previously read from the global `dataLoader`).
        self.GroundTruth = torch.Tensor(trainData.dataset.labels)

    def forward(self, x):
        """Run the reservoir model on a batch."""
        return self.RC(x)

    def configure_optimizers(self):
        """Adam over the parameters exposed by the reservoir (``self.RC.params``)."""
        return torch.optim.Adam(self.RC.params)

    def train_dataloader(self):
        return self.trainData

    def test_dataloader(self):
        # Deliberately the training loader: this evaluates fit, not generalisation.
        return self.trainData

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss on one batch; the reservoir is reset at each epoch start."""
        if batch_idx == 0:
            self.RC.reset()  # At the beginning of each epoch, reset the model!

        x, y  = batch
        y_hat = self(x)
        loss  = F.cross_entropy(y_hat, y)
        logs_loss = {'train_loss': loss}
        return {'loss': loss, 'log': logs_loss}

    def test_step(self, batch, batch_idx):
        """Per-sample hit flags: |label - predicted class| < hparams['tolerance']."""
        if batch_idx == 0:
            # Start evaluation from the same reset state every training epoch starts from,
            # instead of carrying over the state left by the last training batch.
            self.RC.reset()
        x, y    = batch
        y_hat   = self(x)
        # y_hat is (batch, classes, ...) as required by cross_entropy in training_step,
        # so take the class per sample and compare with this batch's own labels.
        # The old GroundTruth[batch_idx] lookup was only valid for batch_size=1, shuffle=False.
        y_pred  = torch.argmax(y_hat, dim=1)
        tol_acc = torch.abs(y.to(y_pred.dtype) - y_pred) < self.hparams['tolerance']  # Accuracy
        return {'tol_acc': tol_acc, 'y_pred': y_hat}

    def test_epoch_end(self, outputs):
        """Aggregate the per-sample hit flags into a percentage accuracy."""
        # cat (not stack) so batches of different sizes (e.g. a short last batch) are supported.
        hits = torch.cat([x['tol_acc'].reshape(-1) for x in outputs])
        acc  = hits.float().mean() * 100  # percentage
        logs = {'test_acc': acc}
        return logs

## Create Model from config

In [6]:
config = utL.getConfig('config.yaml')
# labelFreq = half the frames, matching the half-1 / half-0 label layout built above.
config['experimentDetails'].update(labelFreq=nImg // 2)

In [21]:
# Trainer settings. Use one GPU when CUDA is available and fall back to CPU otherwise,
# so Restart & Run All also works on a machine without a GPU.
expDict = {
            'gpus'                     : 1 if torch.cuda.is_available() else 0,
            'profiler'                 : False,
            'log_save_interval'        : 10000,
            'progress_bar_refresh_rate': 1000,
            'row_log_interval'         : 10000,
            'max_epochs'               : 10,
            }

In [22]:
# Build the experiment on the prepared loader, train it, then evaluate on the same loader
# (Experiment.test_dataloader returns the training data).
exp     = Experiment(dataLoader)
trainer = Trainer(**expDict)
trainer.fit(exp); 
# NOTE(review): ckpt_path=None presumably tests the in-memory weights rather than reloading a
# checkpoint — TODO confirm for the installed Lightning version.
# NOTE(review): the reported test_acc of 50% equals chance for this half-1 / half-0 label set,
# which suggests the model predicts a single class — investigate before trusting results.
trainer.test(ckpt_path=None);

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | RC   | ESN_NA | 2 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': tensor(50., device='cuda:0')}
--------------------------------------------------------------------------------



## Inspect States