# Experiment

This code is for the 2D grid world experiment shown in the paper appendix. It is a simple and intuitive example that shows what Alg. 1 is learning.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
from tqdm.auto import tqdm
import gym

import reafference.jnu as J
from reafference.data.iterators import gym_iterator


# simple grid world environment, see appendix in paper.

class Grid2D(gym.Env):
    """Square ``n x n`` grid containing a single ``m x m`` body.

    Observations are float32 arrays of shape ``(1, n, n)`` that are 1 on the
    cells covered by the body and 0 elsewhere. ``position`` is the index of
    the body's corner with the smallest indices and is always clamped so
    the body stays inside the grid.

    Args:
        n: side length of the grid.
        m: side length of the body.
        stochastic_actions: if false, five actions map to fixed displacements
            (action 0 does not move). If true, three actions are available:
            0 does not move, 1 moves one cell along the second axis and any
            other action one cell along the first axis, each in a random
            direction.
        noise: callable applied to the position after every action; the
            identity when ``None``.
    """

    def __init__(self, n=9, m=2, stochastic_actions=False, noise=None):
        super().__init__()
        self.observation_space = gym.spaces.Box(0, 1, shape=(1, n, n))
        if stochastic_actions:
            self.action_space = gym.spaces.Discrete(3)
            self.actions = self._random_move
        else:
            self.action_space = gym.spaces.Discrete(5)
            self.actions = self._fixed_move

        self.state = np.zeros((n, n))
        self.body = np.ones((m, m))
        self.noise = noise if noise is not None else (lambda x: x)
        self.reset()

    def _fixed_move(self, a):
        """Look up the displacement of deterministic action ``a``."""
        table = np.array(
            [[0, 0], [0, 1], [1, 0], [0, -1], [-1, 0]], dtype=np.int64
        )
        return table[a]

    def _random_move(self, a):
        """Draw the displacement of stochastic action ``a``."""
        assert self.action_space.contains(a)
        if a == 0:
            return np.array([0, 0])
        # The direction is sampled once, before the axis is selected.
        sign = np.random.choice([-1, 1])
        return np.array([0, sign]) if a == 1 else np.array([sign, 0])

    def _observation(self):
        """Return the current grid as a float32 array with a leading axis."""
        return self.state[np.newaxis, ...].astype(np.float32)

    @property
    def position(self):
        return self._position

    @position.setter
    def position(self, value):
        """Clamp ``value`` to the grid, store it and redraw ``state``."""
        clamped = np.array(value).astype(np.int64)
        side = self.observation_space.shape[-1]
        size = self.body.shape[-1]
        upper = side - size  # largest index at which the body still fits
        self._position = clamped
        for axis in (0, 1):
            clamped[axis] = max(0, min(clamped[axis], upper))
        self.state = np.zeros((side, side))
        first, second = self.position
        self.state[first:first + size, second:second + size] = self.body

    def step(self, action):
        """Apply ``action`` and then the noise; reward is 0 and done is False."""
        moved = self.position + self.actions(action)
        self.position = moved
        # The position is clamped once after the action and again after noise.
        self.position = self.noise(self.position)
        return self._observation(), 0., False, dict()

    def reset(self):
        """Place the body at the centre of the grid and return ``(obs, info)``."""
        side = self.observation_space.shape[-1]
        size = self.body.shape[-1]
        self.state = np.zeros((side, side))
        centre = (side - size) // 2
        self.position = [centre, centre]
        return self._observation(), dict()

def all_noise(p):
    """Return ``p`` shifted by one cell along a random axis and direction.

    The direction (-1 or +1) is drawn first; a uniform draw above 0.5 then
    applies it to the first coordinate, otherwise to the second.
    """
    sign = np.random.choice([-1, 1])
    along_first = np.random.uniform() > 0.5
    offset = np.array([sign, 0]) if along_first else np.array([0, sign])
    return p + offset
    
def wind_noise(p):
    """With probability one half, shift the first coordinate of ``p`` by +/-1.

    When no shift happens, ``p`` itself is returned unchanged.
    """
    if np.random.uniform() <= 0.5:
        return p
    gust = np.random.choice([-1, 1])
    return p + np.array([gust, 0])
    

    
# Environment shared by the rest of the notebook: the five deterministic
# actions, with wind_noise applied to the position after every step.
env = Grid2D(stochastic_actions=False, noise=wind_noise)
def episode(env, max_length=100):
    """Collect one rollout and return it as aligned transition arrays.

    Assumes ``gym_iterator`` yields tuples whose first two fields are the
    state and the action -- TODO confirm against its definition.

    Returns:
        ``(states[:-1], states[1:], actions[:-1])``: each state, its
        successor, and the action paired with the earlier state.
    """
    fields = zip(*gym_iterator(env, max_length=max_length))
    state, action, *_ = fields
    states = np.stack(state)
    actions = np.stack(action)
    return states[:-1], states[1:], actions[:-1]
# Roll out a single episode and display its states; the actions are passed
# to the viewer through on_interact (presumably shown alongside each frame
# -- verify against reafference.jnu).
x1, x2, a = episode(env)
J.images(x1, on_interact=a, scale=10)

In [None]:
import torch
import torch.nn as nn

class Module(nn.Module):
    """Fully connected one-step effect model for discrete actions.

    ``forward`` maps a batch of states and action indices to a predicted
    state change with the same shape as the states. ``predict`` additionally
    splits that prediction into the change predicted for action 0 (the
    exafferent effect) and the remainder (the reafferent effect).
    """

    def __init__(self, state_shape, action_shape, epochs=10):
        """Build the network, optimiser and loss.

        Args:
            state_shape: shape of a single state, without the batch axis.
            action_shape: tuple whose first entry is the number of actions.
            epochs: unused; kept so existing call sites keep working.
        """
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        # Plain Python ints keep the layer sizes independent of numpy scalars.
        s = int(np.prod(state_shape))
        a = int(np.prod(action_shape))
        self.layers = nn.Sequential(
            nn.Linear(s + a, 512), nn.LeakyReLU(),
            nn.Linear(512, 512), nn.LeakyReLU(),
            nn.Linear(512, 512), nn.LeakyReLU(),
            nn.Linear(512, s)
        )
        self.optim = torch.optim.Adam(self.parameters(), lr=0.0005)
        self.criterion = nn.MSELoss()

    def forward(self, x, a):
        """Predict the state change for states ``x`` under action indices ``a``."""
        x = x.view(x.shape[0], -1)
        # One-hot encode the action indices by indexing an identity matrix.
        a = torch.eye(self.action_shape[0], device=a.device)[a.long()]
        z = torch.cat([x, a], dim=-1)
        return self.layers(z).reshape(x.shape[0], *self.state_shape)

    def predict(self, x, a):
        """Return ``(total, reafferent, exafferent)`` predicted effects.

        The exafferent effect is the prediction for action 0; the reafferent
        effect is the total prediction minus the detached exafferent one, so
        no gradient reaches the exafferent branch through that difference.
        """
        pred_total_effect = self.forward(x, a)
        noop = torch.zeros_like(a)
        pred_exafferent_effect = self.forward(x, noop)
        pred_reafferent_effect = pred_total_effect - pred_exafferent_effect.detach()
        return pred_total_effect, pred_reafferent_effect, pred_exafferent_effect

    def step(self, x1, x2, a):
        """Take one optimiser step on the whole batch; return the detached loss."""
        self.optim.zero_grad()
        _, pred_reafferent_effect, pred_exafferent_effect = self.predict(x1, a)
        pred_effect = pred_exafferent_effect + pred_reafferent_effect
        total_effect = x2 - x1
        loss = self.criterion(pred_effect, total_effect)
        loss.backward()
        self.optim.step()
        return loss.detach()

    def fit(self, x1, x2, a, epochs=1000):
        """Run ``epochs`` full-batch optimisation steps with a progress bar.

        Args:
            x1: states before the transition.
            x2: states after the transition.
            a: action indices taken in ``x1``.
            epochs: number of optimisation steps.
        """
        pbar = tqdm(range(epochs))
        for _ in pbar:
            x1, x2, a = self.shuffle(x1, x2, a)
            loss = self.step(x1, x2, a)
            pbar.set_description(f"loss: {loss.item() :.6f}")

    def train(self, x1=True, x2=None, a=None, epochs=1000, *, mode=None):
        """Run the training loop, or toggle training mode.

        This name is also ``nn.Module.train(mode)``, which ``eval()`` calls
        internally, so both call styles have to be supported:

        * ``train(x1, x2, a, epochs=...)`` runs ``fit`` and returns ``None``.
        * ``train()``, ``train(True/False)`` and ``train(mode=...)`` behave
          like ``nn.Module.train`` and return ``self``.

        Raises:
            TypeError: if the arguments match neither call style.
        """
        if mode is not None:
            return super().train(mode)
        if x2 is None and a is None:
            if not isinstance(x1, bool):
                raise TypeError(
                    "train() expects either a bool mode or the tensors x1, x2 and a"
                )
            return super().train(x1)
        if x2 is None or a is None:
            raise TypeError("train() needs x1, x2 and a to run the training loop")
        self.fit(x1, x2, a, epochs=epochs)

    def shuffle(self, *x):
        """Apply one shared random permutation along axis 0 of every tensor."""
        indx = torch.randperm(x[0].shape[0])
        return [z[indx] for z in x]

# Collect 100 episodes and concatenate their transitions into flat arrays.
x0, x1, a = zip(*[episode(env, max_length=100) for i in range(100)])
x0, x1, a = np.concatenate(x0), np.concatenate(x1), np.concatenate(a)
print(x0.shape, a.shape)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x0, x1, a = (torch.from_numpy(arr).to(device) for arr in (x0, x1, a))

model = Module(state_shape=env.observation_space.shape, action_shape=(env.action_space.n,)).to(device)
model.train(x0, x1, a, epochs=1000)

In [None]:
import torchvision.io
def show(action, n = 100, index = 18):
    """Display the effects predicted for ``action`` on the first ``n`` states.

    Reads the notebook globals ``x0``, ``a`` and ``model``. Every displayed
    image holds, left to right: the input state, then the total, reafferent
    and exafferent predictions, each preceded by a ``p``-pixel separator.

    Args:
        action: action index applied to every state.
        n: number of leading samples of ``x0`` to use.
        index: which image of the batch to return.

    Returns:
        The framed image at position ``index`` in the batch.
    """
    _x, _a = x0[:n], a[:n]
    p = 2  # width of the separators and of the outer frame, in pixels
    # Visualisation only, so no autograd graph is needed.
    with torch.no_grad():
        t, re, ex = model.predict(_x, torch.zeros_like(_a) + action)
    resize = torchvision.transforms.functional.resize
    t, re, ex, _x = [resize(z, (90, 90), interpolation=0) for z in (t, re, ex, _x)]
    # Build the separator on the inputs' device instead of assuming CUDA.
    b = torch.ones(_x.shape[0], 1, _x.shape[2], p, device=_x.device)
    # Shift and scale the effects with (v + 1) / 2 before clipping to [0, 1].
    eff = (torch.cat([b, t, b, re, b, ex], dim=-1) + 1) / 2
    imgs = 1 - torch.clip(torch.cat([_x, eff], dim=-1), 0, 1)
    # Surround every image with a zero-valued frame of width p.
    framed = torch.zeros(imgs.shape[0], imgs.shape[1], imgs.shape[2] + p * 2, imgs.shape[3] + p * 2)
    framed[:, :, p:-p, p:-p] = imgs
    J.images(framed)
    return framed[index]
   
# Display the predicted effects for every action of the environment in turn.
# NOTE: `imgs` is initialised here but nothing in this cell appends to it.
imgs = []
for _a in range(env.action_space.n):
    print("ACTION:", _a)
    img = show(_a)
    # Optional export of the returned image, left disabled:
    #torchvision.io.write_png((img.cpu() * 255).byte(), f"./media/World2D-Action-{_a}.png")
    

# Long term effects

The cell below shows how to compute long-term effects from single-step estimates. In this instance there are only exafferent effects, unless one specifies more than one action in the forecast.

In [None]:
# Initial observation: the first element of the tuple returned by reset().
x0 = env.reset()[0]
# Ten-step action sequence, action 0 at every step.
a = np.zeros(10)

def forecast(model, x0, an):
    """Roll the one-step model forward along an action sequence.

    At every step the current state and action are recorded, then the state
    is advanced by the total effect that ``model.predict`` returns.

    Args:
        model: object whose ``predict(x, a)`` returns a tuple with the
            predicted total change of ``x`` as its first element.
        x0: batch of initial states; axis 0 is the batch axis.
        an: actions indexed as ``an[step]`` -> one action per batch element,
            i.e. shape ``(steps, batch)``.

    Returns:
        ``(X, A)``: the states seen before each step, concatenated along
        axis 0 and clipped to ``[0, 1]``, and the matching actions
        concatenated along axis 0. The state after the last step is not
        included.

    Raises:
        ValueError: if ``an`` has no steps or its batch size differs from
            that of ``x0``.
    """
    if an.shape[0] == 0:
        raise ValueError("an must contain at least one step")
    if x0.shape[0] != an.shape[1]:
        raise ValueError(
            f"batch size mismatch: x0 has {x0.shape[0]} states "
            f"but an has {an.shape[1]} actions per step"
        )
    # Run on the model's own device; if it exposes no parameters, leave the
    # inputs where x0 already is.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = x0.device
    x0, an = x0.to(device), an.to(device)

    states, actions = [], []
    # Inference only: avoid an autograd graph that grows with every step.
    with torch.no_grad():
        for step_action in an:
            states.append(x0)
            actions.append(step_action)
            total, _, _ = model.predict(x0, step_action)
            x0 = x0 + total
    return torch.cat(states).clip(0, 1), torch.cat(actions)
      
# Add a leading batch axis to the state and a trailing batch axis to the
# actions, giving the action tensor shape (10, 1).
x0, a = torch.from_numpy(x0).unsqueeze(0), torch.from_numpy(a).unsqueeze(1)

# Baseline rollout with action 0 at every step.
X0, A = forecast(model, x0, a)
# Same rollout, except that action 1 is taken at the first step.
a[0] = 1
X, A = forecast(model, x0, a)

# Show the baseline, the perturbed rollout and their difference side by side;
# the difference is mapped with (d + 1) / 2.
J.images(torch.cat([X0, X, (X - X0 + 1) / 2], dim=-1), on_interact=A,scale=10)


# 