# Simple ARE Demo

This simple demo shows how to estimate the ARE in the case where reafference is the same across observations. The example is equivalent to observing something like angular velocity in the artificial ape environment (without the moving-cube exafference). Unlike in the original experiment, the platform here rotates in only one direction. Exafference is therefore positive, making the example slightly more interesting.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb

def gaussian(mu, sigma):
    """Build a sampler for a normal distribution.

    Args:
        mu: mean of the distribution.
        sigma: standard deviation of the distribution.

    Returns:
        A callable taking ``n`` (default 1) that returns ``n`` draws from
        ``np.random.normal`` using NumPy's global random state.
    """
    def _sample(n=1):
        return np.random.normal(loc=mu, scale=sigma, size=n)

    return _sample

def uniform(x1, x2):
    """Build a sampler for values spread uniformly between ``x1`` and ``x2``.

    Args:
        x1: one end of the interval.
        x2: the other end of the interval.

    Returns:
        A callable taking ``n`` (default 1) that returns an array of ``n``
        values computed as ``x2 - u * (x2 - x1)``, where ``u`` comes from
        ``np.random.rand`` (NumPy's global random state).
    """
    def _sample(n=1):
        span = x2 - x1
        fraction = np.random.rand(n)
        return x2 - (fraction * span)

    return _sample

def constant(x):
    """Build a degenerate sampler that always yields ``x``.

    Args:
        x: the value every sample takes.

    Returns:
        A callable taking ``n`` (default 1) that returns a 1-D array of
        length ``n`` filled with ``x``.
    """
    def _sample(n=1):
        return np.full(shape=(n,), fill_value=x)

    return _sample

def categorical(x, p=None):
    """Build a sampler over a finite set of values.

    Args:
        x: sequence of candidate values.
        p: optional sequence of probabilities, one per entry of ``x``.
            When omitted, every value is equally likely.

    Returns:
        A callable taking ``n`` (default 1) that returns ``n`` values drawn
        from ``x`` with ``np.random.choice`` (NumPy's global random state).

    Raises:
        AssertionError: if ``x`` and ``p`` differ in length.
    """
    values = np.array(x)
    if p is None:
        count = values.shape[0]
        p = np.ones(count) / count
    probs = np.array(p)
    assert values.shape[0] == probs.shape[0]

    def _draw(n=1):
        return np.random.choice(values, size=n, p=probs)

    return _draw

# Number of simulated observations; every array below has this length.
n = 10000
A = categorical([-1,0,1])(n)        # action turn head left/noop/right, each equally likely
X0 = uniform(-1,1)(n)               # current rotation (observation), drawn between -1 and 1
P = categorical([0,1])(n)           # noise term is the platform rotation: 0 or 1 with equal probability
# SCM
# The next observation is the current one plus the action (reafference)
# plus the platform rotation (exafference).
# Experiment here to see how changing the SCM gives different results. A more complex SCM will need a more flexible model (see next cell).
X1 = X0 + A + P 

In [6]:

from tqdm.auto import tqdm
import torch
import torch.nn as nn

class Module(nn.Module):
    """Linear model that predicts the change in observation caused by an action.

    The network maps the concatenation of observation ``x`` and action ``a``
    to the predicted total effect ``x_next - x``. The exafferent effect is
    the prediction for the no-op action (``a == 0``), and the reafferent
    effect is the total effect minus that no-op prediction.
    """

    def __init__(self):
        super().__init__()
        # if the SCM is made more complex, use a more complex model with more layers.
        self.layers = nn.Sequential(
            nn.Linear(2, 1)
        )
        self.optim = torch.optim.Adam(self.parameters(), lr=0.0005)
        self.criterion = nn.MSELoss()

    def forward(self, x, a):
        """Return the predicted total effect for observation ``x`` and action ``a``.

        ``x`` and ``a`` are concatenated along the last dimension, so together
        they must supply the two input features the linear layer expects.
        """
        z = torch.cat([x, a], dim=-1)
        return self.layers(z)

    def predict(self, x, a):
        """Split the predicted effect into reafferent and exafferent parts.

        Args:
            x: observation tensor.
            a: action tensor (same shape as the no-op tensor built from it).

        Returns:
            Tuple ``(total, reafferent, exafferent)`` where ``exafferent`` is
            the prediction under the no-op action and
            ``reafferent = total - exafferent.detach()``.
        """
        pred_total_effect = self.forward(x, a)
        noop = torch.zeros_like(a)
        # Use this instance rather than a module-level ``model`` variable, so
        # the method works before any global is bound and for any instance.
        pred_exafferent_effect = self(x, noop)
        pred_reafferent_effect = pred_total_effect - pred_exafferent_effect.detach()
        return pred_total_effect, pred_reafferent_effect, pred_exafferent_effect

    def step(self, x1, x2, a):
        """Run one full-batch optimisation step and return the detached loss.

        Args:
            x1: observations before the action.
            x2: observations after the action.
            a: actions taken.
        """
        self.optim.zero_grad()
        pred_total_effect, pred_reafferent_effect, pred_exafferent_effect = self.predict(x1, a)
        # Numerically this sum equals ``pred_total_effect``; because only the
        # copy inside the reafferent term is detached, gradients also flow
        # through the no-op prediction.
        pred_effect = pred_exafferent_effect + pred_reafferent_effect
        total_effect = x2 - x1
        loss = self.criterion(pred_effect, total_effect)
        loss.backward()
        self.optim.step()
        return loss.detach()

    def _fit(self, x1, x2, a, epochs=1000):
        """Shuffle the data and take one optimisation step per epoch."""
        pbar = tqdm(range(epochs))
        for _ in pbar:
            x1, x2, a = self.shuffle(x1, x2, a)
            loss = self.step(x1, x2, a)
            pbar.set_description(f"loss: {loss.item():.5f}")

    def train(self, x1=True, x2=None, a=None, epochs=1000):
        """Fit the model, or toggle training mode like ``nn.Module.train``.

        ``nn.Module.eval()`` calls ``self.train(False)``, so this override
        must still accept a lone boolean; otherwise ``eval()`` would raise.

        Args:
            x1: observations before the action, or a ``bool`` mode flag.
            x2: observations after the action (required when fitting).
            a: actions taken (required when fitting).
            epochs: number of full-batch optimisation steps.

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.

        Raises:
            TypeError: if ``x1`` is data but ``x2`` or ``a`` is missing.
        """
        if isinstance(x1, bool):
            return super().train(x1)
        if x2 is None or a is None:
            raise TypeError("train() requires x1, x2 and a when fitting the model")
        self._fit(x1, x2, a, epochs=epochs)

    def shuffle(self, *x):
        """Apply one shared random permutation along dim 0 to every tensor in ``x``."""
        indx = torch.randperm(x[0].shape[0])
        return [z[indx] for z in x]
            
# Convert each 1-D NumPy array into a float32 tensor with an extra axis at
# dim 1, so observation and action can be concatenated feature-wise.
x0, x1, a = [torch.from_numpy(z).unsqueeze(1).float() for z in (X0, X1, A)]   
model = Module()
# One full-batch optimisation step per epoch on the whole simulated dataset.
model.train(x0, x1, a, epochs=5000)

  0%|          | 0/5000 [00:00<?, ?it/s]

In [8]:
# Evaluate the model under each fixed action applied to every observation.
a1 = torch.ones_like(a)      # action  1 everywhere
a2 = torch.zeros_like(a)     # action  0 (no-op) everywhere
a3 = -torch.ones_like(a)     # action -1 everywhere

# Ground truth follows from X1 = X0 + A + P: the reafferent effect equals the
# action, and the exafferent effect is the mean of P (0 or 1 with equal
# probability), i.e. 0.5. The total effect is their sum.
with torch.no_grad():
    t, re, ex = model.predict(x0, a1)
    print(f"ACTION:  1 | GROUND TRUTH | total effect: { 1.5     : .3f} | reafference: { 1.        : .3f} | exafference : { 0.5       : .3f}")
    print(f"ACTION:  1 | ESTIMATE     | total effect: {t.mean() : .3f} | reafference: { re.mean() : .3f} | exafference : { ex.mean() : .3f}\n")
    
    t, re, ex = model.predict(x0, a2)
    print(f"ACTION:  0 | GROUND TRUTH | total effect: { 0.5     : .3f} | reafference: { 0.        : .3f} | exafference : { 0.5       : .3f}")
    print(f"ACTION:  0 | ESTIMATE     | total effect: {t.mean() : .3f} | reafference: { re.mean() : .3f} | exafference : { ex.mean() : .3f}\n")

    t, re, ex = model.predict(x0, a3)
    print(f"ACTION: -1 | GROUND TRUTH | total effect: { -0.5    : .3f} | reafference: { -1.       : .3f} | exafference : { 0.5       : .3f}")
    print(f"ACTION: -1 | ESTIMATE     | total effect: {t.mean() : .3f} | reafference: { re.mean() : .3f} | exafference : { ex.mean() : .3f}\n")
    
    
    

ACTION:  1 | GROUND TRUTH | total effect:  1.500 | reafference:  1.000 | exafference :  0.500
ACTION:  1 | ESTIMATE     | total effect:  1.511 | reafference:  1.010 | exafference :  0.501

ACTION:  0 | GROUND TRUTH | total effect:  0.500 | reafference:  0.000 | exafference :  0.500
ACTION:  0 | ESTIMATE     | total effect:  0.501 | reafference:  0.000 | exafference :  0.501

ACTION: -1 | GROUND TRUTH | total effect: -0.500 | reafference: -1.000 | exafference :  0.500
ACTION: -1 | ESTIMATE     | total effect: -0.509 | reafference: -1.010 | exafference :  0.501

