## Environment

Your objective is to guide your chicken across lane after lane of busy rush hour traffic. You receive a point for every chicken that makes it to the top of the screen after crossing all the lanes of traffic. We use `FreewayDeterministic-v4` as it has a restricted action space (that contains the required noop action), and is deterministic, so it is easier to visualise/interpret performance.


### Modifications to FreewayDeterministic-v4

The full observation has been cropped to an 84 x 84 image. This reduces the computational requirements but still illustrates the point. The observations have also been converted to a PyTorch-compatible format (float, CHW, [0-1]).


### Action Space
The action is a `ndarray` with shape `(1,)` which can take values `{0, 1, 2}` indicating the direction of movement.

| Num | Action                 |
|-----|------------------------|
| 0   | Noop                   | 
| 1   | Move chicken up        |
| 2   | Move chicken down      |


### Observation Space

The observation is a `ndarray` with shape `(3,84,84)`: an image of the state as a human player would see it.

Install this repo as a dependency:

```
import sys
!{sys.executable} -m pip install -e ./reafference
```


In [1]:
# imports
import gym
import math
import time
import copy
import torch
import torch.nn as nn
import torchvision
import random
import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt
import torchvision.transforms as T
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from IPython.display import clear_output

import reafference.jnu as J

from reafference import environment
from reafference.environment.freeway import ground_truth, make_dataset, make_episode, make
from reafference.model import UNet, DiagLinear


# Device used for the model and for every batch moved with `.to(DEVICE)`.
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


  logger.warn(
  if LooseVersion(mpl.__version__) >= "3.0":
  other = LooseVersion(other)


In [2]:
# Data Exploration
# Roll out one episode and visualise, side by side, the state and the three
# effect images returned by `ground_truth` (total, reafferent, exafferent),
# with a strip on top showing which action was taken at each step.
env = make()
state, action, info = make_episode(env, max_length=200)
ram = info['ram']
gt_effect, re_effect, ex_effect = ground_truth(env, state, ram)

# One-pixel-wide column of -1 used as a separator between panels.
b = -np.ones_like(gt_effect[..., :1])  # padding
# Map the effects from [-1, 1] to [0, 1] so they can be displayed as images.
imgs = (np.concatenate([gt_effect, b, re_effect, b, ex_effect], axis=3) + 1) / 2
# state[:-1]: the last state has no successor, hence no effect to pair with.
imgs = np.concatenate([state[:-1], b, imgs], axis=3)

# action to image
# Each row of the identity matrix is stretched to the full image width, so
# action k lights up the k-th horizontal segment of the strip. Integer division
# is required: `repeat` does not accept a float count.
aindx = np.eye(env.action_space.n).repeat(imgs.shape[-1] // 3, 1)
aindx = aindx[action[:-1]]
# Add channel and height axes: 3 channels, 4 pixels tall.
aindx = aindx[:, np.newaxis, np.newaxis, :].repeat(3, 1).repeat(4, 2)

# Stack the action strip above the images exactly once.
imgs = np.concatenate([aindx, imgs], axis=-2)
J.images(imgs, on_interact=action, scale=3)
#torchvision.io.write_video("./images/freeway_ground_truth.mp4", torch.from_numpy(imgs[...,:-1].transpose(0,2,3,1) * 255).int(), fps=3)


A.L.E: Arcade Learning Environment (version 0.7.5+db37282)
[Powered by Stella]


(200, 3, 84, 84) (199, 3, 84, 84)


object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super(Widget, self).__init__(**kwargs)


interactive(children=(IntSlider(value=0, description='x', layout=Layout(width='99%'), max=198), Output()), _do…

HBox(children=(Canvas(height=264, width=1017),), layout=Layout(align_items='center', display='flex', flex_flow…

In [3]:
# Collect the training data by rolling out a random policy:
# 5 episodes, each capped at 1000 steps.
env = make()
dataset = make_dataset(
    env,
    num_episodes=5,
    max_episode_length=1000,
)
print(f"DATASET SIZE: {len(dataset)}")


DATASET SIZE: 4995


In [13]:
# UNet to estimate effects
# The network maps (state, action) to an image with the same number of
# channels as the state; Tanh bounds the predicted effect to [-1, 1].
state_shape = env.observation_space.shape
action_shape = (env.action_space.n,)
latent_shape = (512,)
epochs = 50

model = UNet(state_shape[0], state_shape[0], exp=5, output_activation=torch.nn.Tanh(), batch_normalize=False)
# Replace the UNet's conditioning module with one sized for the action vector.
# NOTE(review): presumably DiagLinear injects the action at the bottleneck;
# confirm against reafference.model.
conditional_shape = model.conditional_shape(state_shape)
model.condition = DiagLinear(conditional_shape, action_shape=action_shape[0])

model = model.to(DEVICE)
optim = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = torch.nn.MSELoss()

import torchinfo
torchinfo.summary(model, input_data=(torch.zeros(2, *state_shape), torch.zeros(2, *action_shape)))

# map_location lets a checkpoint saved from a GPU model be loaded on whatever
# device DEVICE names, instead of failing when the saving device is absent.
model.load_state_dict(torch.load("./models/Freeway-v0.model.pt", map_location=DEVICE))

<All keys matched successfully>

In [5]:
# Shuffled mini-batches over `dataset`; each batch is unpacked below as
# (x1, x2, a). NOTE(review): presumably x1 = state, x2 = next state and
# a = one-hot action (column 0 = noop) -- confirm against make_dataset.
loader = DataLoader(dataset, batch_size=128, shuffle=True, drop_last=False)
    
# Train
epoch_iter = tqdm(range(epochs))

for e in epoch_iter:
    
    # make a dataset with a random policy

    #dataset = make_dataset(env, num_episodes=100, device=DEVICE)
    #
    # per-batch losses for this epoch, averaged for the progress bar below
    avg_loss = []
    for x1, x2, a in loader:
        x1, x2, a = x1.to(DEVICE), x2.to(DEVICE), a.to(DEVICE)
        optim.zero_grad()
        
        # prediction of the total effect of each action
        pred_total_effect = model(x1, a)        
        # prediction of the exafferent effect - when all actions are noop  
        # (same states, but every action replaced by a 1 in column 0)
        noop = torch.zeros_like(a)
        noop[:,0] = 1. 
        pred_exafferent_effect = model(x1, noop)
        
        # prediction of the reafferent effect (total - exafferent)
        # this will be 0 for any a == 0 in the batch
        # The subtracted exafferent term is detached, so no gradient flows
        # through it here.
        pred_reafferent_effect = pred_total_effect - pred_exafferent_effect.detach()
        #pred_reafferent_effect[a[:,0] == 1] = 0. # detach gradients where reafferent effect should be 0 (?)
    
        # Numerically this equals pred_total_effect (the exafferent terms
        # cancel), but because of the detach above the loss gradient reaches
        # the network through both forward passes: model(x1, a) and
        # model(x1, noop).
        pred_effect = pred_exafferent_effect + pred_reafferent_effect # combined effect
        total_effect = x2 - x1 # ground truth total effect
        
        loss = criterion(pred_effect, total_effect)
        loss.backward()
        # store a detached copy so the autograd graph is not kept alive
        avg_loss.append(loss.detach())
        optim.step()
    
    # mean batch loss of the epoch, shown in the tqdm description
    avg_loss = torch.stack(avg_loss).cpu().numpy().mean()
    epoch_iter.set_description(f"Loss: {avg_loss : .5f}")
        

  0%|          | 0/50 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [22]:
import torchvision.transforms.functional as fn

def get_images(state, action, total_effect, t_effect, re_effect, ex_effect, num_actions=3):
    """Tile states, effect images and an action strip into one image batch.

    The four effect tensors are mapped from [-1, 1] to [0, 1] and laid out
    left to right, each separated by a one-pixel column. ``state[:-1]`` is
    placed to their left, and a 3-pixel-high strip encoding the action taken
    is appended underneath.

    Args:
        state: tensor of states; the last one is dropped to align with effects.
        action: integer tensor of action indices; the last one is dropped.
        total_effect: ground-truth total effect, one entry per transition.
        t_effect: predicted total effect.
        re_effect: predicted reafferent effect.
        ex_effect: predicted exafferent effect.
        num_actions: size of the one-hot action encoding (default 3).

    Returns:
        A CPU tensor holding the tiled images.
    """
    # Separator column of -1, which becomes 0 after the rescale below.
    b = -torch.ones_like(total_effect[...,:1]).cpu()
    imgs = (torch.cat([b, total_effect.cpu(), b, t_effect.cpu(), b, re_effect.cpu(), b, ex_effect.cpu(), b], dim=-1) + 1) / 2
    imgs = torch.cat([state[:-1].cpu(), imgs], dim=-1)

    # One-hot actions -> (T, 3, 1, num_actions), then stretched with
    # nearest-neighbour resizing to a strip as wide as the tiled image.
    one_hot = torch.eye(num_actions)[action[:-1]].unsqueeze(1).unsqueeze(1).repeat(1,3,1,1)
    ai = fn.resize(one_hot, size=(3,imgs.shape[-1]), interpolation=fn.InterpolationMode.NEAREST)
    
    print(ai.shape, imgs.shape)
    imgs = torch.cat([imgs, ai], dim=-2)
    return imgs

# Roll out a fresh episode, run the model on every transition and render the
# ground truth next to the predictions, both interactively and as a video.
env = make()
state, action, info = make_episode(env, max_length=200)
ram = info['ram']
# NOTE(review): the values returned here are not used below (gt_effect is
# recomputed from consecutive states); the call is kept in case it has side
# effects -- confirm and drop if not.
gt_effect, re_effect, ex_effect = ground_truth(env, state, ram)

with torch.no_grad():
    x1 = torch.from_numpy(state[:-1]).to(DEVICE).contiguous()
    a = torch.eye(3)[torch.from_numpy(action[:-1])].to(DEVICE).contiguous()
    pred_total = model(x1, a)
    # Same states with every action replaced by noop (index 0).
    noop = torch.zeros_like(a)
    noop[:,0] = 1. 
    pred_ex = model(x1, noop)
    # Reafferent effect = total - exafferent; this difference of two
    # Tanh-bounded outputs can fall outside [-1, 1].
    pred_re = pred_total - pred_ex

# Ground-truth total effect: difference between consecutive states.
gt_effect = torch.from_numpy((state[1:] - state[:-1]))
imgs = get_images(torch.from_numpy(state), torch.from_numpy(action), gt_effect, pred_total, pred_re, pred_ex)
    
J.images(imgs, scale=3, on_interact=action)

# Upscale 2x with nearest-neighbour so pixels stay crisp in the video.
imgs = fn.resize(imgs, size=[2*imgs.shape[2], 2*imgs.shape[3]], interpolation=fn.InterpolationMode.NEAREST)

# write_video expects uint8 frames in (T, H, W, C). Clamp to [0, 1] first so
# out-of-range predictions saturate instead of wrapping around in the cast.
video = (imgs.clamp(0, 1).permute(0,2,3,1) * 255).to(torch.uint8)[:200]
torchvision.io.write_video("./media/Freeway-1-Predictions.mp4", video, fps=3)


(200, 3, 84, 84) (199, 3, 84, 84)
torch.Size([199, 3, 3, 425]) torch.Size([199, 3, 84, 425])


interactive(children=(IntSlider(value=0, description='x', layout=Layout(width='99%'), max=198), Output()), _do…

HBox(children=(Canvas(height=261, width=1275),), layout=Layout(align_items='center', display='flex', flex_flow…

In [None]:
#torch.save(model.state_dict(), "./models/Freeway-v0.model.pt")

In [24]:
# figure for paper...
# Step the raw (unwrapped) environment 40 times with action 0, then outline
# the 84 x 84 region rows 111:195, cols 15:99 with a red frame.
x = env.unwrapped.reset()
for step in range(40):
    x, *_ = env.unwrapped.step(0)

border = 2
top, bottom = 111, 195
left, right = 15, 99

# Save the region's pixels, paint a slightly larger red rectangle over it,
# then restore the saved pixels so only a `border`-wide frame stays red.
crop = x[top:bottom, left:right, :].copy()
x[top - border:bottom + border, left - border:right + border, :] = np.array([255, 0, 0])
x[top:bottom, left:right, :] = crop

J.image(x, scale=3)


HBox(children=(Canvas(height=630, width=480),), layout=Layout(align_items='center', display='flex', flex_flow=…

<reafference.jnu.image._image.Image at 0x7ff468937190>