# Adaptivity Example

This experiment is presented in the appendix of the paper.

Biological agents experience changes in their motor system over their lifetime, for example through growth or disease, so the same action will systematically give rise to different effects at different times. For a model of reafference to be viable, it must be able to adapt to such changes. The aim of this experiment is to show that Alg. 1 can indeed adapt to such changes. The experiment uses the CartPole environment: the pole and cart are treated as the body of the agent, and the length of the pole is changed halfway through training.

In [None]:
import gym
import numpy as np
import torchinfo
import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# local installation
from reafference.environment.cartpole import make, ground_truth, get_images, make_episode, make_dataset, plot
from reafference.model import CartPoleNet

# Device used for the model and for every dataset generated below.
DEVICE = "cuda:0"

# Base CartPole env used only to read observation/action shapes.
# NOTE(review): angle_threshold=inf presumably disables termination on pole angle and
# euler=True presumably selects Euler integration — confirm in reafference.environment.cartpole.
env = make(render=False, angle_threshold=float("inf"), euler=True)

state_shape = env.observation_space.shape
# Actions are passed to the model as length-n vectors (one entry per discrete action;
# presumably one-hot, as train() builds the no-op action by setting index 0).
action_shape = (env.action_space.n,)
latent_shape = (512,)
# Default epoch count; train() reads this global, and later cells overwrite it.
epochs = 50

model = CartPoleNet(state_shape, action_shape, latent_shape).to(DEVICE)
# A single optimizer (and hence a single Adam state) is shared by every train() call,
# including both phases of the pole-length experiment.
optim = torch.optim.Adam(model.parameters(), lr=0.0005)
criterion = torch.nn.MSELoss()
# Print a layer summary using a dummy batch of 2 states and 2 action vectors.
# NOTE(review): the dummy inputs are created on CPU while the model is on DEVICE; this
# relies on torchinfo moving input_data onto the model's device — confirm.
torchinfo.summary(model, input_data=(torch.zeros(2, *state_shape), torch.zeros(2, *action_shape)))


In [None]:
def train(env, model, validators=None, *, num_epochs=None, optimizer=None, loss_fn=None):
    """Train ``model`` on freshly collected transitions from ``env`` (Alg. 1).

    A new dataset of 250 episodes is generated from ``env`` on every call, so
    calling ``train`` again after changing ``env`` continues training the same
    model on the changed dynamics.

    Each step predicts the total effect ``x2 - x1`` as the sum of an
    exafferent prediction (the model's output for the no-op action) and a
    reafferent prediction (``model(x1, a)`` minus the detached exafferent
    prediction).

    Args:
        env: Environment passed to ``make_dataset``.
        model: Network called as ``model(state, action_vector)``.
        validators: Optional callables ``v(model)`` evaluated after every
            epoch. Defaults to no validators.
        num_epochs: Number of epochs; defaults to the notebook-global ``epochs``.
        optimizer: Optimizer to step; defaults to the notebook-global ``optim``.
        loss_fn: Loss ``loss_fn(pred, target)``; defaults to the
            notebook-global ``criterion``.

    Returns:
        ``np.ndarray`` with one row per epoch, each row holding the result of
        every validator in order.
    """
    # Fall back to the notebook globals (looked up at call time, as before) so
    # existing calls such as ``train(env1, model, validators=[...])`` still work.
    validators = [] if validators is None else validators
    n_epochs = epochs if num_epochs is None else num_epochs
    optimizer = optim if optimizer is None else optimizer
    loss_fn = criterion if loss_fn is None else loss_fn

    val_losses = []
    dataset = make_dataset(env, num_episodes=250, device=DEVICE)
    # NOTE(review): 1028 may be a typo for 1024; kept to preserve behaviour.
    loader = DataLoader(dataset, batch_size=1028, shuffle=True, drop_last=False)
    epoch_iter = tqdm(range(n_epochs))
    for _ in epoch_iter:
        batch_losses = []
        for x, a in loader:
            # x[:, 0] and x[:, 1] are used as consecutive states (effect = x2 - x1).
            x1, x2, a = x[:, 0], x[:, 1], a[:, 0]
            optimizer.zero_grad()
            pred_total_effect = model(x1, a)
            # No-op action: all zeros except index 0, which this experiment
            # treats as "do nothing".
            noop = torch.zeros_like(a)
            noop[:, 0] = 1.
            pred_exafferent_effect = model(x1, noop)
            # Numerically pred_effect == pred_total_effect, but because the
            # reafferent term subtracts a *detached* exafferent prediction,
            # gradients flow through both the action pass and the no-op pass.
            pred_reafferent_effect = pred_total_effect - pred_exafferent_effect.detach()
            pred_effect = pred_exafferent_effect + pred_reafferent_effect
            total_effect = x2 - x1
            loss = loss_fn(pred_effect, total_effect)
            loss.backward()
            batch_losses.append(loss.detach())
            optimizer.step()
        # An empty loader would make torch.stack fail; report NaN instead.
        mean_loss = torch.stack(batch_losses).mean().item() if batch_losses else float("nan")
        epoch_iter.set_description(f"Training Loss: {mean_loss: .5f}")
        val_losses.append([validate(model) for validate in validators])
    return np.array(val_losses)

def validater(env, criterion=None):
    """Build a closure that scores a model on one fixed episode from ``env``.

    One episode (``max_episode_length=250``) is collected once, when
    ``validater`` is called, and its ground-truth total, reafferent and
    exafferent effects are computed with ``ground_truth``. The returned
    ``validate(model)`` always evaluates on that same episode, so scores from
    successive epochs are directly comparable.

    Args:
        env: Environment used to collect the episode and compute ground truth.
        criterion: Loss ``criterion(pred, target)``; defaults to a fresh
            ``torch.nn.MSELoss()``.

    Returns:
        ``validate(model) -> (total_loss, reafferent_loss, exafferent_loss)``,
        each the criterion's output computed on CPU tensors.
    """
    if criterion is None:
        # Created per call instead of as a default argument, which would be
        # evaluated once at definition time and shared by every validator.
        criterion = torch.nn.MSELoss()

    dataset = make_dataset(env, num_episodes=1, max_episode_length=250, device=DEVICE)
    state, action = dataset.tensors
    x1, x2, a = state[:, 0], state[:, 1], action[:, 0]
    # Pass the whole trajectory: every start state plus the final next state.
    trajectory = torch.cat([x1, x2[-1:]], dim=0).cpu().numpy()
    gt_total, gt_re, gt_ex = [torch.from_numpy(y) for y in ground_truth(env, trajectory)]

    # The no-op action input never changes between calls, so build it once.
    noop = torch.zeros_like(a)
    noop[:, 0] = 1.

    def validate(model):
        """Return the (total, reafferent, exafferent) losses of ``model``."""
        with torch.no_grad():
            pred_total = model(x1, a)
            pred_ex = model(x1, noop)
            pred_re = pred_total - pred_ex
        return (
            criterion(pred_total.cpu(), gt_total),
            criterion(pred_re.cpu(), gt_re),
            criterion(pred_ex.cpu(), gt_ex),
        )

    return validate

def tester(env):
    """Build a closure that plots a model's predictions on one episode.

    An episode (``max_episode_length=100``) and its ground-truth total,
    reafferent and exafferent effects are computed once, when ``tester`` is
    called. The returned ``test(model)`` draws three 12x4 figures (total,
    reafferent, exafferent). Each figure overlays the ground truth (solid
    lines) with the model's prediction (dashed lines, labels prefixed
    "Predicted ").
    """
    dataset = make_dataset(env, num_episodes=1, max_episode_length=100, device=DEVICE)
    states, actions = dataset.tensors
    start, following, act = states[:, 0], states[:, 1], actions[:, 0]
    trajectory = torch.cat([start, following[-1:]], dim=0).cpu().numpy()
    target_total, target_re, target_ex = ground_truth(env, trajectory)
    targets = (target_total, target_re, target_ex)

    def test(model):
        no_op = torch.zeros_like(act)
        no_op[:, 0] = 1.
        with torch.no_grad():
            total = model(start, act).cpu().numpy()
            exafferent = model(start, no_op).cpu().numpy()
        reafferent = total - exafferent
        predictions = (total, reafferent, exafferent)

        for target, prediction in zip(targets, predictions):
            fig = plt.figure(figsize=(12, 4))
            plot(target, style="-", fig=fig)
            # Restart the colour cycle so each prediction matches its target's colour.
            plt.gca().set_prop_cycle(None)
            plot(prediction, style="--", fig=fig, label_prefix="Predicted ")

    return test

In [None]:
# Show difference after pole length change.
# Phase 1 trains on env1 (pole length 0.6); phase 2 continues training the same
# model on env2 (pole length 0.4). Both validators run after every epoch of both
# phases, so the loss curves show how the model adapts after the switch.


def make_env_with_pole_length(length):
    """Create a CartPole env whose ``unwrapped.length`` is set to ``length``.

    ``polemass_length`` (masspole * length) is kept in sync with the new length.
    NOTE(review): in gym's CartPole, ``length`` is documented as half the pole's
    length — confirm for this project's ``make``.
    """
    pole_env = make(render=False, angle_threshold=float("inf"), euler=True)
    pole_env.unwrapped.length = length
    pole_env.unwrapped.polemass_length = pole_env.unwrapped.masspole * length
    return pole_env


env1 = make_env_with_pole_length(0.6)
env2 = make_env_with_pole_length(0.4)

validate1 = validater(env1)
validate2 = validater(env2)

# train() reads the global `epochs`; both phases use the same number of epochs.
epochs = 100
loss1 = train(env1, model, validators=[validate1, validate2])
loss2 = train(env2, model, validators=[validate1, validate2])

# One row per epoch across both phases; the pole length changes at row len(loss1).
loss = np.concatenate([loss1, loss2])



In [None]:
_loss = np.log(loss)
# Curves for validator 0 (env1 episode) and validator 1 (env2 episode); each column
# is one of the (total, reafferent, exafferent) losses returned by a validator.
v1, v2 = _loss[:, 0], _loss[:, 1]
# Epoch at which training switched from env1 to env2.
switch_epoch = len(loss1)

fig, axes = plt.subplots(nrows=1, ncols=2, sharey=False, sharex=True, figsize=(12, 3))
axes[0].set_ylabel("log MSE")

axes[0].plot(np.arange(v1.shape[0]), v1, label=["total", "reafferent", "exafferent"])
axes[1].plot(np.arange(v2.shape[0]), v2)

# One shared legend below both panels.
line, label = axes[0].get_legend_handles_labels()
fig.legend(line, label, loc='lower center', ncol=len(line), bbox_to_anchor=(0.5, -0.1))

lengths = [env1.unwrapped.length, env2.unwrapped.length]
for v, pole_length, ax in zip([v1, v2], lengths, axes):
    ax.vlines([switch_epoch], v.min(), v.max(), linestyles="--", color="black", alpha=0.5)
    ax.set_title(f"Test Episode with pole length={pole_length}")
    ax.set_yticks([])

fig.tight_layout()
# Uncomment to export the figure used in the paper.
#plt.savefig("./media/cartpole-adaptive", dpi=100, bbox_inches='tight')

In [None]:
# Compare ground-truth and predicted effects on a fresh test episode for each
# pole length, using the model as trained after both phases.
for pole_env in (env1, env2):
    tester(pole_env)(model)
