## Experiment

This code corresponds to experiments (iii.1) and (iii.2) in the paper. To change the experiment, modify the `EXP` variable (see the cell below).

## Environment

In this environment the agent is placed atop a platform; it can look left or right, or keep facing its current direction. Some cubes move up and down (perpendicular to the view changes). In this example the goal is to disentangle the motion of these cubes (exafference) from the changes in view caused by the agent's actions (reafference).

Instead of using the environment directly, we use a dataset generated with the World of Bugs platform. World of Bugs is otherwise used for automated bug detection, but it includes a useful collection of environments built in the Unity3D engine. We have adapted one of these environments to produce the artificial ape environment.




### Action Space
The action is an `ndarray` with shape `(1,)` which can take values `{0, 1, 2}` indicating the direction of the view rotation. The view is rotated by a small amount (2 degrees).

| Num | Action                 |
|-----|------------------------|
| 0   | Look Left              | 
| 1   | Noop                   |
| 2   | Look Right             |


### Observation Space

The full observation has been cropped to a 64 x 64 image. Observations are in a PyTorch-compatible format (float, CHW, values in [0, 1]).

Each observation is an `ndarray` with shape `(1,64,64)`, i.e. a single-channel image.

### More info
Install this repo as a dependency:

```
import sys
!{sys.executable} -m pip install ./reafference
```


In [15]:
# imports
import gym
import math
import time
import copy
import torch
import torch.nn as nn
import torchvision
import random
import glob
import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torchvision.transforms.functional import resize

from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from IPython.display import clear_output

import reafference.jnu as J
from reafference.environment.artificialape import make_dataset, load
from reafference.model import UNet, DiagLinear

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (moving the model to "cuda:0" raises otherwise).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# CHANGE THIS PATH TO THE DATASET LOCATION IF NEEDED!
# EXP selects the dataset sub-directory below and the pretrained model
# ./models/Artificial-Ape-{EXP}.model.pt:
#   1 -> experiment (iii.1)                  'no-platform'
#   2 -> experiment (iii.2)                  'platformi'
#   3 -> experiment (iii.2, no indicator)    'platform'
EXP = 1
DATASET = ['no-platform', 'platformi', 'platform']
if not 1 <= EXP <= len(DATASET):
    # Guard explicitly: EXP=0 would otherwise silently wrap to DATASET[-1].
    raise ValueError(f"EXP must be between 1 and {len(DATASET)}, got {EXP}")
PATH = f"./data/artificial-ape/{DATASET[EXP-1]}/train/"


In [16]:
# Data Exploration: visualise the first training episode.
dataset = load(PATH + "episode-0.tar")
# Each entry is a (state, action, rotation) triple; stack each field over time.
state, action, rotation = [np.stack(z) for z in zip(*dataset)]
# Frame-to-frame difference (total effect). States are in [0, 1], so the
# difference is in [-1, 1]; rescale it to [0, 1] for display.
gt_effect = ((state[1:] - state[:-1]) + 1) / 2
# Place each frame and its effect side by side along the width axis.
imgs = np.concatenate([state[:-1], gt_effect], axis=3)
# One-hot action strip (3 rows), stretched to the full image width with
# nearest-neighbour resizing so each action cell stays a solid block.
aimg = resize(
    torch.from_numpy(np.eye(3)[action[:-1]]).unsqueeze(1).unsqueeze(1),
    size=(3, imgs.shape[-1]),
    interpolation=T.InterpolationMode.NEAREST,
).numpy()
# Stack the action strip on top of each image.
imgs = np.concatenate([aimg, imgs], axis=-2)
# Unfortunately we cannot visualise the ground-truth reafferent/exafferent effect, as it is not available in the dataset.
J.images(imgs, scale=3)


interactive(children=(IntSlider(value=0, description='x', layout=Layout(width='99%'), max=1022), Output()), _d…

HBox(children=(Canvas(height=201, width=384),), layout=Layout(align_items='center', display='flex', flex_flow=…

<reafference.jnu.image._image.Image at 0x7f0cfe5c0b80>

In [17]:
# Make the dataset for training from every episode file in PATH.
# glob's order depends on the filesystem, so sort for a reproducible episode order.
episode_files = sorted(glob.glob(PATH + "*"))
if not episode_files:
    # Fail loudly instead of silently training on an empty dataset.
    raise FileNotFoundError(f"No episode files found in {PATH!r}; check EXP/PATH.")
dataset = make_dataset(*episode_files)

  0%|          | 0/10 [00:00<?, ?it/s]

In [19]:
# UNet to estimate effects: given a frame and an action, predict the change to
# the next frame. The Tanh output keeps predictions in [-1, 1].
state_shape = [1,64,64]   # (C, H, W) of a single observation
action_shape = [3]        # number of discrete actions (one-hot encoded when fed to the model)
latent_shape = [512]      # NOTE(review): not referenced in the cells shown
epochs = 50

model = UNet(state_shape[0], state_shape[0], exp=4, output_activation=torch.nn.Tanh(), batch_normalize=False)
# The action conditions the UNet through a DiagLinear layer sized to the UNet's conditioning input.
conditional_shape = model.conditional_shape(state_shape)
model.condition = DiagLinear(conditional_shape, action_shape=action_shape[0])

model = model.to(DEVICE)

optim = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = torch.nn.MSELoss()

import torchinfo
torchinfo.summary(model, input_data=(torch.zeros(2, *state_shape), torch.zeros(2, *action_shape)), device=DEVICE)

# pretrained models
# experiment (iii.1) Artificial-Ape-1.model.pt
# experiment (iii.2) Artificial-Ape-2.model.pt
# experiment (iii.2 no indicator) Artificial-Ape-3.model.pt
# map_location lets a checkpoint saved on one device (e.g. a GPU) be loaded onto DEVICE.
model.load_state_dict(torch.load(f"./models/Artificial-Ape-{EXP}.model.pt", map_location=DEVICE))


<All keys matched successfully>

In [7]:
# Train the effect model.
# The loss is computed on exafferent + (total - exafferent.detach()). Numerically
# this equals the total prediction (up to rounding), but gradients flow through
# both the action-conditioned pass model(x1, a) and the noop-conditioned pass
# model(x1, noop).
loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=False)
    
# Train
epoch_iter = tqdm(range(epochs))

for e in epoch_iter:
    avg_loss = []  # per-batch losses for this epoch
    # assumes batches are (frame_t, frame_t+1, one-hot action) — TODO confirm against make_dataset
    for x1, x2, a in loader:
        #print(x[...,2].min(), x[...,2].max())
        x1, x2, a = x1.to(DEVICE), x2.to(DEVICE), a.to(DEVICE)
        optim.zero_grad()
        
        # prediction of the total effect of each action
        pred_total_effect = model(x1, a)
        # prediction of the exafferent effect - the same frames with every action replaced by noop
        noop = torch.zeros_like(a)
        noop[:,1] = 1. # noop is at index 1!
        pred_exafferent_effect = model(x1, noop)
        
        # prediction of the reafferent effect (total - exafferent)
        # this will be 0 for every sample whose action is already noop (a[:, 1] == 1),
        # because both passes then receive identical inputs
        pred_reafferent_effect = pred_total_effect - pred_exafferent_effect.detach()
        #pred_reafferent_effect[a[:,0] == 1] = 0. # detach gradients where reafferent effect should be 0 (?)
        # NOTE(review): the line above tests index 0 ("look left"); noop is index 1 — fix before re-enabling
    
        pred_effect = pred_exafferent_effect + pred_reafferent_effect # combined effect
        total_effect = x2 - x1 # ground truth total effect
        
        loss = criterion(pred_effect, total_effect)
        loss.backward()
        avg_loss.append(loss.detach())
        optim.step()
    
    # Unweighted mean of batch losses (with drop_last=False a smaller final batch counts equally).
    avg_loss = torch.stack(avg_loss).cpu().numpy().mean()
    epoch_iter.set_description(f"Loss: {avg_loss : .5f}")
        

  0%|          | 0/50 [00:00<?, ?it/s]

In [22]:
import torchvision.transforms.functional as fn

def get_images(state, action, total_effect, t_effect, re_effect, ex_effect):
    """Assemble one visualisation frame per transition.

    Left to right, each frame shows the current state, then the ground-truth
    total effect and the predicted total, reafferent and exafferent effects.
    Panels are separated by 1-pixel black columns. A 3-row one-hot action strip
    is appended underneath.

    Args:
        state: tensor of T frames; presumably (T, 1, H, W) with values in
            [0, 1] — TODO confirm. The channel count must be 1, because the
            action strip is single-channel.
        action: integer action indices in {0, 1, 2}, one per frame.
        total_effect: ground-truth effect for the first T-1 transitions.
        t_effect, re_effect, ex_effect: predicted total / reafferent /
            exafferent effects, same shape as ``total_effect``.

    Returns:
        CPU tensor of shape (T-1, 1, H + 3, W_total). Effects are mapped with
        (x + 1) / 2, so values in [-1, 1] become [0, 1].
    """
    # Separator column of -1, which becomes 0 (black) after rescaling.
    sep = -torch.ones_like(total_effect[..., :1]).cpu()
    panels = [sep]
    for effect in (total_effect, t_effect, re_effect, ex_effect):
        panels += [effect.cpu(), sep]
    imgs = (torch.cat(panels, dim=-1) + 1) / 2
    imgs = torch.cat([state[:-1].cpu(), imgs], dim=-1)
    # One-hot action per transition as (T-1, 1, 1, 3), stretched to
    # (T-1, 1, 3, width) with nearest-neighbour interpolation. The indices are
    # moved to the CPU to match the CPU identity matrix.
    onehot = torch.eye(3)[torch.as_tensor(action[:-1]).cpu()].unsqueeze(1).unsqueeze(1)
    strip = torch.nn.functional.interpolate(onehot, size=(3, imgs.shape[-1]), mode="nearest")
    return torch.cat([imgs, strip], dim=-2)

from pathlib import Path

# Evaluate on the first test episode (same dataset, "test" split).
state, action, rotation = [np.stack(z) for z in zip(*load(PATH.replace("train", "test") + "episode-0.tar"))]

with torch.no_grad():
    # Frames and one-hot encoded actions (3 possible actions).
    x1, a = torch.from_numpy(state[:-1]).to(DEVICE), torch.eye(3)[torch.from_numpy(action[:-1])].to(DEVICE)
    pred_total = model(x1, a)
    # Exafferent effect: the same frames with every action replaced by noop (index 1).
    noop = torch.zeros_like(a)
    noop[:,1] = 1.
    pred_ex = model(x1, noop)
    # Reafferent effect: the part of the total effect attributed to the action.
    pred_re = pred_total - pred_ex

gt_effect = torch.from_numpy((state[1:] - state[:-1]))
imgs = get_images(torch.from_numpy(state), torch.from_numpy(action), gt_effect, pred_total, pred_re, pred_ex)

J.images(imgs, scale=3, on_interact=action)

# Upscale 2x with nearest-neighbour so pixels stay crisp in the video.
imgs = fn.resize(imgs, size=[2*imgs.shape[2], 2*imgs.shape[3]], interpolation=torchvision.transforms.functional.InterpolationMode.NEAREST)

# write_video expects uint8 frames laid out as (T, H, W, C). Take the first 200
# frames before expanding grey to 3 channels, so that only the written frames
# are copied.
video = (imgs[:200].permute(0, 2, 3, 1).clamp(0, 1) * 255).to(torch.uint8).repeat(1, 1, 1, 3)
# Create the output directory so writing does not fail on a fresh checkout.
Path("./media").mkdir(parents=True, exist_ok=True)
torchvision.io.write_video(f"./media/ArtificialApe-{EXP}-Predictions.mp4", video, fps=3)


interactive(children=(IntSlider(value=0, description='x', layout=Layout(width='99%'), max=1022), Output()), _d…

HBox(children=(Canvas(height=201, width=975),), layout=Layout(align_items='center', display='flex', flex_flow=…

In [11]:
# Save the weights under the same per-experiment name the model cell loads from.
# The name was previously hard-coded to "-3", which overwrote model 3 regardless of EXP.
torch.save(model.state_dict(), f"./models/Artificial-Ape-{EXP}.model.pt")