# The Tolman-Eichenbaum Machine: Unifying Space and Relational Memory through Generalization in the Hippocampal Formation.
James C. R. Whittington, Timothy H. Muller, Shirley Mark, Guifen Chen, Caswell Barry, Neil Burgess, Timothy E. J. Behrens.
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-4-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

https://www.sciencedirect.com/science/article/pii/S009286742031388X

* [1 Introduction](#1-introduction)
* [2 Implementation](#2-implementation)
* [3 Running the Model](#3-running-the-model)
* [4 References](#4-references)
* [5 Contributors](#5-Contributors)

For a more detailed explanation of TEM's theory and its implementation in this framework, see [Description, Implementation & Analysis of the Tolman-Eichenbaum Machine](https://github.com/LukeHollingsworth/Tollman-Eichenbaum-Implementation/blob/main/Description%2C%20Implementation%20and%20Analysis%20of%20the%20Tollman-Eichenbaum%20Machine.pdf).

## 1. Introduction
The Tolman-Eichenbaum machine (TEM) [1] is a model that generates representations resembling some of the neural
activity seen in the hippocampus and entorhinal cortex. The model takes inspiration from Tolman's theory of an internal
representation, or cognitive map [3], and combines it with Eichenbaum's theory of relational memory. TEM uses these
concepts both to learn sequentially within an environment and to learn abstract features that are common across
different environments. Over the course of the agent's trajectory, it learns an abstract representation of the spatial
structure it inhabits and, as a result, can predict the outcome of a state-action pair such as (apple, North) even if it
has never experienced that specific pair before.
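The kind of generalisation described above can be illustrated with a deliberately simple toy, which is not TEM itself. It assumes a square grid world, hard-codes the mapping from actions to displacements (the structure TEM has to learn), and stores observations in a plain dictionary keyed by abstract location. The names `ACTIONS`, `walk`, `build_memory` and `predict` are hypothetical.

```python
# Toy illustration (not TEM): an abstract location code updated by actions
# (path integration) plus a memory that binds locations to observations.

# Hypothetical action -> displacement mapping; TEM has to learn this structure.
ACTIONS = {"North": (0, 1), "South": (0, -1), "East": (1, 0), "West": (-1, 0)}


def walk(start, actions):
    """Return the sequence of abstract locations visited along a walk."""
    x, y = start
    locations = [(x, y)]
    for a in actions:
        dx, dy = ACTIONS[a]
        x, y = x + dx, y + dy
        locations.append((x, y))
    return locations


def build_memory(locations, observations):
    """Bind each visited location to the observation seen there."""
    return dict(zip(locations, observations))


def predict(memory, location, action):
    """Predict the observation after taking `action` from `location`.

    This works for state-action pairs never experienced, provided the target
    location was reached by some other route. Unknown targets give None.
    """
    dx, dy = ACTIONS[action]
    return memory.get((location[0] + dx, location[1] + dy))


# Walk three sides of a 2x2 square: East, North, West.
locs = walk((0, 0), ["East", "North", "West"])
memory = build_memory(locs, ["apple", "banana", "cherry", "date"])
# The agent never stepped North from the apple, but can still predict the outcome.
print(predict(memory, (0, 0), "North"))  # -> date
```

Because the prediction only needs the abstract location reached by the action, the pair (apple, North) can be answered without ever having been experienced.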

The structure of TEM is biologically motivated and reflects the interactions between key brain areas responsible
for spatial navigation and memory. At the lowest level of the model, sensory observations *x* mirror representations
in the lateral entorhinal cortex (LEC) [4], whilst the most abstract representations *g* are analogous to those found
in the medial entorhinal cortex (MEC) [5]. These two representations are handled separately and are only ever
brought together to retrieve memories via the hippocampal (HPC) representation *p*.
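The separation between *x* and *g*, and their conjunction in *p*, can be sketched with a simplified associative memory. This is an illustration under stated assumptions rather than this implementation's code: *p* is taken to be the flattened outer product of *x* and *g*, memories are stored with a Hebbian rule, the *g* codes are orthonormal, and retrieval queries the memory with *g* alone. The function names are hypothetical.

```python
import numpy as np


def conjunctive_code(x, g):
    """Form a hippocampal-like code p as the flattened outer product of x and g."""
    return np.outer(x, g).ravel()


def store(memories, x, g):
    """Hebbian storage: add p p^T to the memory matrix."""
    p = conjunctive_code(x, g)
    return memories + np.outer(p, p)


def retrieve_sensory(memories, g, n_x):
    """Query the memory with g alone and read out a sensory estimate of x."""
    # x is unknown at query time, so pair g with a uniform sensory input.
    p_query = conjunctive_code(np.ones(n_x) / n_x, g)
    p_retrieved = memories @ p_query
    # Undo the flattening and project onto g to recover (a scaled copy of) x.
    return p_retrieved.reshape(n_x, -1) @ g


n_x, n_g = 4, 3
g_codes = np.eye(n_g)             # orthonormal abstract location codes (assumption)
x_codes = np.eye(n_x)[[2, 0, 3]]  # one-hot observation seen at each location
M = np.zeros((n_x * n_g, n_x * n_g))
for x, g in zip(x_codes, g_codes):
    M = store(M, x, g)

x_hat = retrieve_sensory(M, g_codes[1], n_x)
print(int(np.argmax(x_hat)))  # -> 0
```

With orthonormal *g* codes, the query only overlaps with the memory stored at the same location, so the read-out recovers the observation bound to it.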

TEM reproduces several of these well-known neural features and is also able to remap between environments. The
abstract EC representation *g* resembles grid cells and band cells [2], shown in fig. 2; by splitting *g* into 5
temporally filtered streams, the model generates grid cells at a variety of spatial scales (modules). The memories
that TEM forms also have place-like representations. Like the grid cells, these place-like fields span multiple sizes,
mirroring the hierarchical composition of hippocampal place fields [6]. TEM's HPC cells also remap: rather than
preserving their spatial correlations across environments, their fields relocate.
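Splitting a representation into temporally filtered streams can be sketched as exponential smoothing at several rates, one rate per stream. The smoothing constants below are hypothetical, and the leaky-integrator form is an assumption; the filtering used in this implementation may differ.

```python
import numpy as np


def temporally_filter(signal, alphas):
    """Exponentially smooth a (T, D) signal at several rates.

    Returns an array of shape (len(alphas), T, D): one filtered stream per
    smoothing constant. Small alphas give slowly varying streams, large
    alphas track the input closely.
    """
    signal = np.asarray(signal, dtype=float)
    streams = np.zeros((len(alphas),) + signal.shape)
    for f, alpha in enumerate(alphas):
        state = np.zeros(signal.shape[1:])
        for t, value in enumerate(signal):
            # Leaky integration: keep (1 - alpha) of the past, add alpha of the new input.
            state = (1.0 - alpha) * state + alpha * value
            streams[f, t] = state
    return streams


# Five hypothetical smoothing constants, one per stream/module.
alphas = [0.99, 0.3, 0.09, 0.03, 0.01]
x = np.random.RandomState(0).rand(25, 8)  # 25 steps of an 8-dimensional input
streams = temporally_filter(x, alphas)
print(streams.shape)  # -> (5, 25, 8)
```

Slowly varying streams carry information integrated over longer stretches of the walk, which is one way of motivating larger spatial scales in the corresponding modules.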

## 2. Implementation
Our implementation is built around two classes, one for the TEM agent and one for the continuous environment. Both
are initialised in a main file and interact within two nested training loops: an outer loop over training iterations
and an inner loop over the sequence of walks within an environment. The parameters of the model (learning rates,
sizes of the HPC/EC representations, etc.) and of the environments (width, state density, etc.) are set before
training in a separate file, which is passed as an argument to both classes during initialisation. All operations
associated with the model are carried out within the agent class, whereas in the original implementation compression
and initialisation are done outside it. Both the agent and environment classes inherit from core NPG classes and can
therefore be used in any other experimental setup supported by our test bed. The ability to train on batches is a
novel contribution to the NPG framework and will be useful in the future when running models on given experimental
trajectories.
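The division of labour described above can be sketched with two stand-in classes that share one parameter dictionary. `ToyEnvironment`, `ToyAgent`, the parameter keys and the method signatures are hypothetical simplifications: they do not inherit from the NPG base classes and only mirror the shape of the nested loop used in the notebook.

```python
# Minimal sketch of the agent/environment pattern; all names are hypothetical.

params = {"train_iters": 3, "n_walks": 4, "arena_width": 1.0, "learning_rate": 1e-3}


class ToyEnvironment:
    def __init__(self, **pars):
        self.width = pars["arena_width"]
        self.t = 0

    def step(self):
        """Return what the agent needs for one walk (here, just a counter)."""
        self.t += 1
        return self.t


class ToyAgent:
    def __init__(self, **pars):
        self.lr = pars["learning_rate"]
        self.n_walks = pars["n_walks"]
        self.updates = []

    def update(self, observation, walk_index, train_iter):
        """Record the update and report how many walks make up one iteration."""
        self.updates.append((train_iter, walk_index, observation))
        return self.n_walks


# Both classes are initialised from the same parameter dictionary.
env = ToyEnvironment(**params)
agent = ToyAgent(**params)
for i in range(params["train_iters"]):      # outer loop: training iterations
    n_walks = agent.update(env.step(), 0, i)
    for j in range(1, n_walks):              # inner loop: walks within an environment
        agent.update(env.step(), j, i)

print(len(agent.updates))  # -> 12
```

Passing one dictionary to both constructors keeps model and environment settings in a single place, which is the role the separate parameters file plays here.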

For each training iteration, the environment produces a 25-step trajectory within the continuous environment. The
trajectory is determined by the agent's action policy, defined in the agent class, which the environment uses to choose
the action at each step. An example of these walks is given in fig. 8. The agent's step size and all parameters of the
physical environment are fully customisable. For TEM, each batch consists of environments of varying sizes, which
encourages the agent to learn the abstract structure independently of any particular width and depth.
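A batch of 25-step walks of this kind can be sketched as follows. The random-direction policy, the clipping at the walls, the step size and the arena sizes are assumptions for illustration; in this implementation the policy lives in the agent class and may behave differently.

```python
import numpy as np


def random_walk(n_steps, width, depth, step_size, rng):
    """Simulate a random-direction walk in a continuous rectangular arena.

    Positions that would leave the arena are clipped to its boundary.
    """
    pos = np.array([width / 2.0, depth / 2.0])
    trajectory = [pos.copy()]
    for _ in range(n_steps):
        # Random-direction policy (an assumption; the agent's policy may differ).
        angle = rng.uniform(0, 2 * np.pi)
        pos = pos + step_size * np.array([np.cos(angle), np.sin(angle)])
        pos = np.clip(pos, [0.0, 0.0], [width, depth])
        trajectory.append(pos.copy())
    return np.array(trajectory)


rng = np.random.RandomState(0)
# A batch of arenas of different (hypothetical) sizes, one 25-step walk each.
sizes = [(8, 8), (9, 9), (10, 10), (11, 11)]
batch = [random_walk(25, w, d, step_size=0.5, rng=rng) for w, d in sizes]
print([traj.shape for traj in batch])  # -> [(26, 2), (26, 2), (26, 2), (26, 2)]
```

Each trajectory includes the starting position, hence 26 rows for 25 steps.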

## 3. Running the Model

#### TEM Virtual Environment
This implementation uses an older version of TensorFlow (1.9.0) and consequently requires Python 3.6. We suggest setting up a virtual (conda) environment with the following packages:
```
Python=3.6
TensorFlow=1.9.0
setuptools=39.0.1

astropy
matplotlib
numpy
seaborn
tqdm
scikit-image
scipy
```
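Before running the notebook, it can help to check that the active interpreter and TensorFlow match the pinned versions above. The helper below (`check_environment`, a hypothetical name) is a minimal sketch that only inspects the running interpreter and each package's `__version__` attribute.

```python
import importlib
import sys

# Versions pinned in the package list above.
REQUIRED_PYTHON = (3, 6)
REQUIRED_PACKAGES = {"tensorflow": "1.9.0"}


def check_environment(required_python=REQUIRED_PYTHON, required_packages=REQUIRED_PACKAGES):
    """Report whether the running interpreter and packages match the pins.

    Returns a dict mapping each name to (matches, installed_version).
    """
    running = tuple(sys.version_info[:2])
    report = {"python": (running == tuple(required_python), "{}.{}".format(*running))}
    for name, wanted in required_packages.items():
        try:
            installed = getattr(importlib.import_module(name), "__version__", None)
        except Exception:  # not installed, or fails to import
            installed = None
        report[name] = (installed == wanted, installed)
    return report


for name, (ok, version) in check_environment().items():
    print("{:<12} {:<8} {}".format(name, str(version), "OK" if ok else "MISMATCH"))
```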

#### Running TEM
The TEM model is run from the [notebook file](whittington_2020_examples.ipynb). To run the model with default parameters, simply *Run All*. If you would like to alter these parameters, they can be found in the [parameters file](agents/TEM_parameters.py). The model itself can be found [here](agents/whittington_2020.py).
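To change parameters without editing the parameters file, one option is to override entries of the dictionary returned by `default_params()` before constructing the environment and agent. The helper `override_params` is a hypothetical sketch shown on a stand-in dictionary; only `train_iters` is a key the notebook actually uses, and the values are illustrative.

```python
def override_params(defaults, **changes):
    """Return a copy of `defaults` with `changes` applied.

    Raises KeyError for unknown keys so that typos do not silently
    create new, unused parameters.
    """
    unknown = set(changes) - set(defaults)
    if unknown:
        raise KeyError("Unknown parameter(s): {}".format(sorted(unknown)))
    updated = dict(defaults)
    updated.update(changes)
    return updated


# Stand-in for default_params(); the real dictionary has many more entries.
defaults = {"train_iters": 1000, "learning_rate": 1e-3}
pars = override_params(defaults, train_iters=100)
print(pars["train_iters"])  # -> 100
```

In the notebook this would read `pars = override_params(default_params(), train_iters=100)`, leaving the defaults untouched.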

## 4. References
[1] J. C. Whittington, T. H. Muller, S. Mark, G. Chen, C. Barry, N. Burgess, and T. E. Behrens, “The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation,” Cell, vol. 183, pp. 1249–1263.e23, Nov. 2020.

[2] J. Krupic, N. Burgess, and J. O’Keefe, “Neural representations of location composed of spatially periodic bands,” Science, vol. 337, pp. 853–857, Aug. 2012.

[3] E. C. Tolman, “Cognitive maps in rats and men,” Psychological Review, vol. 55, no. 4, pp. 189–208, 1948.

[4] S. S. Deshmukh and J. J. Knierim, “Representation of non-spatial and spatial information in the lateral entorhinal cortex,” Frontiers in Behavioral Neuroscience, vol. 5, 2011.

[5] F. Savelli, D. Yoganarasimha, and J. J. Knierim, “Influence of boundary removal on the spatial representations of the medial entorhinal cortex,” Hippocampus, vol. 18, pp. 1270–1282, Dec. 2008.

[6] M. L. Shapiro, H. Tanila, and H. Eichenbaum, “Cues that hippocampal place cells encode: Dynamic and hierarchical representation of local and distal stimuli,” Hippocampus, vol. 7, no. 6, pp. 624–642, 1997.


<img src="/nfs/nhome/live/lhollingsworth/Documents/Miscelaneous/TEM_image.jpg" width=1000>

# Running Model

```python
import sys
sys.path.append("../")  # add the parent directory to the import path

import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from sehec.arenas.TEMenv import *
from sehec.agents.whittington_2020 import *
from sehec.agents.TEM_extras.TEM_parameters import *

# Default model and environment parameters
pars = default_params()
```

Initialise environment(s) and agent.

```python
env_name = 'TEMenv'
mod_name = 'TEM'

# Both classes are initialised from the same parameter dictionary
envs = TEMenv(environment_name=env_name, **pars)
agent = TEM(model_name=mod_name, **pars)
```

Run the model.

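```python
for i in tqdm(range(pars['train_iters'])):
    # First walk of this training iteration; the agent also returns the number of walks
    adjs, trans, allowed = envs.step()
    n_walk, _ = agent.update(adjs, trans, allowed, 0, i)
    for j in range(n_walk - 1):
        # RL loop over the remaining walks in this environment
        adjs, trans, allowed = envs.step()
        _, history = agent.update(adjs, trans, allowed, j + 1, i)

# Plot and save the trajectory history returned by the final update
envs.plot_trajectory(history_data=history)
plt.savefig('saved_plot_trajectory{0}.png'.format(j))
```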

Reset the model.

```python
agent.reset()
```

## Plotting Model Results

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import datetime
import re
from os import listdir
import sys
sys.path.insert(0, '../Summaries')

# Import the module itself (used below as TEM_plotting_functions.get_data and .sort_data)
# as well as its public functions by name
from sehec.agents.TEM_extras import TEM_plotting_functions
from sehec.agents.TEM_extras.TEM_plotting_functions import *
from sehec.agents.TEM_extras.TEM_arb_functions import *
from sehec.agents.TEM_extras.TEM_helper_functions import *
from sehec.agents.TEM_extras.TEM_behaviour_analyses import *
from sehec.agents.TEM_extras.TEM_environment_functions import *
```

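*Note: `save_dirs` below is hard-coded to a local path; set it to the directory containing your saved run summaries before running the next cell.*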

```python
# Directory (or directories) containing the saved run summaries
save_dirs = ['/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/EHC_model_comparison/sehec/models/Summaries/']

# Identify the run to load
date = '2022-11-01'
run = '1'

recent = -1  # how far back into history of saved data
data, para, list_of_files, save_path = TEM_plotting_functions.get_data(save_dirs, run, date, recent)

A_RNN, x_all, g_all, p_all, p_gen_all, acc_s_t_to, acc_s_t_from, positions, timeseries = data
params, widths, batch_id, g_size, p_size, s_size, s_size_comp, n_freq, width, states = para
print(params['world_type'])

mult = 4  # upsampling factor for the rate maps
smoothing = 1
cmap = 'jet'
maxmin = True

seaborn.set_style(style='white')
seaborn.set_style({'axes.spines.bottom': False, 'axes.spines.left': False,
                   'axes.spines.right': False, 'axes.spines.top': False})

masks, g_lim, p_lim = TEM_plotting_functions.sort_data(g_all, p_all, widths, mult, smoothing, params, batch_id,
                                                       g_max_0=False, p_max_0=True)

# Indices (within the batch) of the environments to plot
env0 = 1
env1 = 2
env2 = 3
env3 = 4
```

Agent coverage:

```python
plt.figure(figsize=(10, 5))
for i, env in enumerate([env2, env3]):
    plt.subplot(1, 2, i + 1)
    # Reshape the per-state values into the arena layout
    cell_reshaped = reshape_cells(positions[env], widths[batch_id[env]], params['world_type'])
    plt.imshow(cell_reshaped)
    plt.colorbar()

plt.show()

print(min(positions[env0]), min(positions[env1]))
```

Accuracy Map:

```python
plt.figure(figsize=(10, 10))
for i, env in enumerate([env0, env1]):
    # Top row: accuracy of predictions into each state
    plt.subplot(2, 2, i + 1)
    cell_reshaped = reshape_cells(acc_s_t_to[env], widths[batch_id[env]], params['world_type'])
    plt.imshow(cell_reshaped, vmax=1, vmin=0)
    plt.title('accuracy to')
    plt.colorbar()

    # Bottom row: accuracy of predictions from each state
    plt.subplot(2, 2, i + 3)
    cell_reshaped = reshape_cells(acc_s_t_from[env], widths[batch_id[env]], params['world_type'])
    plt.imshow(cell_reshaped, vmax=1, vmin=0)
    plt.title('accuracy from')
    plt.colorbar()

plt.show()
```

#### Entorhinal Cortex
Grid cells:

```python
square_plot(g_all[env0], widths[batch_id[env0]], name='g0', maxmin=maxmin, shiny=None,
            hexy=params['world_type'], lims=g_lim, mult=mult, smoothing=smoothing, cmap=cmap, mask=masks[env0])

square_plot(g_all[env1], widths[batch_id[env1]], name='g1', maxmin=maxmin, shiny=None,
            hexy=params['world_type'], lims=g_lim, mult=mult, smoothing=smoothing, cmap=cmap, mask=masks[env1])
```

Autocorrelations:

```python
square_autocorr_plot(g_all[env0], widths[batch_id[env0]], name='g0_auto',
                     hexy=params['world_type'], mult=mult, smoothing=smoothing, cmap=cmap, circle=True)

square_autocorr_plot(g_all[env1], widths[batch_id[env1]], name='g1_auto',
                     hexy=params['world_type'], mult=mult, smoothing=smoothing, cmap=cmap, circle=True)
```
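A spatial autocorrelogram of a rate map can be sketched with `scipy.signal.correlate2d`. This is a generic version (mean subtraction, normalisation so the central peak equals 1) and may not match how `square_autocorr_plot` computes it; `spatial_autocorrelogram` is a hypothetical name, and a periodic test pattern stands in for a real grid-cell map.

```python
import numpy as np
from scipy.signal import correlate2d


def spatial_autocorrelogram(rate_map):
    """Normalised 2D spatial autocorrelation of a firing-rate map.

    The map is mean-subtracted (ignoring NaNs), unvisited NaN bins are set
    to 0, and the result is divided by its zero-lag value.
    """
    m = np.asarray(rate_map, dtype=float)
    m = m - np.nanmean(m)
    m = np.nan_to_num(m)  # unvisited (NaN) bins contribute nothing
    ac = correlate2d(m, m, mode="full")
    centre = ac[m.shape[0] - 1, m.shape[1] - 1]  # zero-lag position
    return ac / centre


# A periodic test pattern (period 5 bins) stands in for a grid-cell rate map.
xx, yy = np.meshgrid(np.arange(20), np.arange(20))
rate_map = np.cos(2 * np.pi * xx / 5) + np.cos(2 * np.pi * yy / 5)
ac = spatial_autocorrelogram(rate_map)
print(ac.shape)  # -> (39, 39)
```

Peaks away from the centre at multiples of the field spacing are what make grid-like periodicity visible in these plots.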

#### Hippocampus
Place cells:

```python
square_plot(p_all[env0], widths[batch_id[env0]], name='p0', shiny=None,
            hexy=params['world_type'], lims=p_lim, mult=mult, smoothing=smoothing, cmap=cmap, mask=masks[env0])

square_plot(p_all[env1], widths[batch_id[env1]], name='p1', shiny=None,
            hexy=params['world_type'], lims=p_lim, mult=mult, smoothing=smoothing, cmap=cmap, mask=masks[env1])
```
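Remapping of a place-like cell between two environments is often quantified as the Pearson correlation between its rate maps. The sketch below (`rate_map_correlation` is a hypothetical name) assumes both maps are binned on the same grid, which holds only for environments of equal width, and uses synthetic maps rather than `p_all`.

```python
import numpy as np


def rate_map_correlation(map_a, map_b):
    """Pearson correlation between two rate maps over bins valid in both."""
    a = np.asarray(map_a, dtype=float).ravel()
    b = np.asarray(map_b, dtype=float).ravel()
    valid = ~(np.isnan(a) | np.isnan(b))
    return np.corrcoef(a[valid], b[valid])[0, 1]


rng = np.random.RandomState(1)
cell_env_a = rng.rand(10, 10)
stable = cell_env_a + 0.05 * rng.rand(10, 10)                   # field kept in place
remapped = rng.permutation(cell_env_a.ravel()).reshape(10, 10)  # field relocated
print(rate_map_correlation(cell_env_a, stable) > rate_map_correlation(cell_env_a, remapped))  # -> True
```

A cell that remaps by relocating its field gives a correlation near zero, whereas a stable field stays close to 1.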

## Behavioural Analysis

```python
seaborn.set_style(style='white')

params['acc_simu'] = 1  # how accurate the simulated node/edge agent is
recent = -1  # how far back into history of saved data
filt_size = 61  # smoothing window size (must be odd)
n = 10
fracs = [x / n for x in range(n + 2)]  # for assessing accuracy within certain proportions of nodes visited

# For the steps-since-visited analysis: assess accuracy within these numbers of steps
if params['world_type'] in ['family_tree', 'line_ti', 'tonegawa']:
    a_s = [0, 10, 20]
else:
    a_s = [0, 4, 10, 20, 40, 60, 100, 200, 300, 400, 600]

# Load data
positions_link, coos, env_info, distance_info = link_inferences(save_path, list_of_files, widths, batch_id, params,
                                                                index=recent)
n_states, wids = env_info

n_available_states = np.zeros_like(wids)
n_available_edges = np.zeros_like(wids)

# Hard-coded counts per environment in the batch; consistent with width**2 states and
# 4 * width * (width - 1) + width**2 edges for widths 8 to 11
n_available_edgess = [460, 460, 561, 561, 288, 369, 460, 561, 288, 369, 460, 561, 288, 288, 369, 369]
n_available_statess = [100, 100, 121, 121, 64, 81, 100, 121, 64, 81, 100, 121, 64, 64, 81, 81]

for i in range(len(n_available_edgess)):
    n_available_edges[i] = n_available_edgess[i]

for i in range(len(n_available_statess)):
    n_available_states[i] = n_available_statess[i]

# Perform behavioural analyses, partitioning results into environments of the same size
allowed_widths = sorted(np.unique([widths[b_id] for b_id in batch_id]))
results = []
for allowed_wid in allowed_widths:
    p_cors, nodes_visited_all, edges_visited_all, time_vis_anal = analyse_link_inference(
        allowed_wid, fracs, a_s, positions_link, coos, env_info, params, n_available_edges, n_available_states)
    p_cors = [ind for ind in p_cors if len(ind) > 0]
    results.append([p_cors, nodes_visited_all, edges_visited_all, time_vis_anal])
```
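The steps-since-visited analysis can be sketched as: for each time step, count how many steps have passed since the current state was last visited, then average prediction accuracy within lag bins. This is an assumption about what `a_s` and `time_vis_anal` represent, not the code of `analyse_link_inference`; `steps_since_last_visit` and `accuracy_by_lag` are hypothetical names, and the data are made up.

```python
import numpy as np


def steps_since_last_visit(states):
    """For each time step, steps elapsed since the current state was last visited.

    First visits are marked with -1.
    """
    last_seen = {}
    out = []
    for t, s in enumerate(states):
        out.append(t - last_seen[s] if s in last_seen else -1)
        last_seen[s] = t
    return np.array(out)


def accuracy_by_lag(states, correct, bin_edges):
    """Mean prediction accuracy for revisits whose lag falls in each bin.

    Bin k covers lags in (bin_edges[k], bin_edges[k + 1]]; empty bins give NaN.
    First visits (lag -1) are excluded because every bin starts above 0.
    """
    lags = steps_since_last_visit(states)
    correct = np.asarray(correct, dtype=float)
    accs = []
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (lags > lo) & (lags <= hi)
        accs.append(correct[in_bin].mean() if in_bin.any() else np.nan)
    return np.array(accs)


states = [0, 1, 0, 2, 1, 0, 3, 0]
correct = [0, 0, 1, 0, 0, 1, 0, 1]
print(steps_since_last_visit(states).tolist())  # -> [-1, -1, 2, -1, 3, 3, -1, 2]
print(accuracy_by_lag(states, correct, [0, 2, 4]).tolist())  # -> [1.0, 0.5]
```

Bin edges such as `a_s` would slot into `bin_edges` directly under this interpretation.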

Model Accuracy:

```python
plot_acc_vs_sum_nodes_edges(results, allowed_widths, coos, filt_size, wids, n_available_states, n_available_edges)
```
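`filt_size` is described as a smoothing window that must be odd. A centred moving average shows why: an odd window has a well-defined middle sample. This sketch (`moving_average` is a hypothetical name) is an assumption about the kind of smoothing applied, not the code inside `plot_acc_vs_sum_nodes_edges`.

```python
import numpy as np


def moving_average(values, window):
    """Centred moving average with an odd window.

    Near the edges, only the samples that exist inside the window are averaged.
    """
    if window % 2 != 1:
        raise ValueError("window must be odd so the average is centred")
    values = np.asarray(values, dtype=float)
    half = window // 2
    out = np.empty(len(values))
    for i in range(len(values)):
        lo, hi = max(0, i - half), min(len(values), i + half + 1)
        out[i] = values[lo:hi].mean()
    return out


print(moving_average([0, 0, 3, 0, 0], 3).tolist())  # -> [0.0, 1.0, 1.0, 1.0, 0.0]
```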