# Creating animations with mouse behavior and hippocampal neurons

In [1]:
import ratinabox # Import this magical package: https://github.com/TomGeorge1234/RatInABox
from ratinabox.Environment import Environment
from ratinabox.Agent import Agent
from ratinabox.Neurons import *

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation
import PIL

import numpy as np
from tqdm.notebook import tqdm

import seaborn as sns
import cmasher

In [2]:
def orient_image_along_path(img, path_x, path_y, frame_index, ax, center_of_mass_fraction=(0.5,0.5)):
    
    if frame_index==0:
        frame_index+=1
        
    _, img_width, _, img_height = img.get_extent()
    anchor_position = (path_x[frame_index],path_y[frame_index])   
    image_position = (
        anchor_position[0]-img_width*center_of_mass_fraction[0],
        anchor_position[1]-img_height*center_of_mass_fraction[1])
    
    if (path_x[frame_index] > path_x[frame_index-1]):
        offset = -np.pi/2
    else:
        offset = np.pi/2
        
    theta = np.arctan((path_y[frame_index] - path_y[frame_index-1])/(path_x[frame_index] - path_x[frame_index-1]))+offset
    transform = matplotlib.transforms.Affine2D().translate(*image_position).rotate_around(*anchor_position, theta=theta)
    img.set_transform(transform + ax.transData)

## 2D case

In [None]:
# Setting up RatInABox: a default environment, one agent and a single neuron
# of each cell type.
Env = Environment()
Ag = Agent(Env, params={"dt": 0.005, "speed_mean": 0.3})
PCs = PlaceCells(Ag, params={"n": 1, "widths": 0.05, "max_fr": 30})
GCs = GridCells(Ag, params={"n": 1, "gridscale": 0.25, "max_fr": 30})
BCs = BoundaryVectorCells(Ag, params={"n": 1, "widths": 0.05, "max_fr": 30})
HDs = HeadDirectionCells(Ag, params={"n": 1, "max_fr": 30})

# Simulate 40 / dt steps (40 s, assuming dt is in seconds). Each step moves
# the agent first, then updates every population in a fixed order.
n_steps = int(40 / Ag.dt)
neuron_populations = (PCs, GCs, BCs, HDs)
for _ in range(n_steps):
    Ag.update()
    for population in neuron_populations:
        population.update()

# Quick visual checks: the trajectory, then spatial rate maps with spikes overlaid.
Ag.plot_trajectory()
for population in (PCs, GCs, BCs):
    population.plot_rate_map(spikes=True)

## Animating the rat

<video width="500"
       src="assets/GCs spikes.mp4"
       controls>
</video>

In [11]:
def setup_environment_figure(figsize=(10,10)):
    """Create a black, axis-free figure whose data limits are [-0.2, 1.2] on both axes.

    Returns the ``(fig, ax)`` pair produced by ``plt.subplots``.
    """
    background = "black"
    limits = (-0.2, 1.2)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=figsize, dpi=300)
    fig.set_facecolor(background)
    ax.axis(False)
    ax.set_xlim(*limits)
    ax.set_ylim(*limits)
    ax.set_facecolor(background)
    return fig, ax

def get_rat_image(ax, rat_img_path="assets/mouse dorsal view.png", rat_x_extent=0.05):
    """Draw the animal sprite on ``ax`` and return the resulting ``AxesImage``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    rat_img_path : str
        Path of the sprite image file.
    rat_x_extent : float
        Width of the sprite in data units. The height follows from the file's
        aspect ratio.

    The sprite's lower-left corner is placed at the data origin, and it is drawn
    at ``zorder=1``. Callers reposition it afterwards via its transform.
    """
    # ``import PIL`` on its own does not guarantee the ``PIL.Image`` submodule is
    # loaded, so import it explicitly here.
    from PIL import Image

    # Image.open is lazy and keeps the file handle open. Read the pixels (imshow
    # does this) inside the ``with`` so the handle is closed afterwards.
    with Image.open(rat_img_path) as rat:
        aspect = rat.height / rat.width
        rat_img = ax.imshow(
            rat,
            extent=(0, rat_x_extent, 0, aspect * rat_x_extent),
            zorder=1,
        )

    return rat_img

In [None]:
# ---- Which elements to show:

SHOW_RAT = True
SHOW_SPIKES = True
SPIKING_CELL = GCs
SPIKING_NEURON = 0  # index (within SPIKING_CELL) of the neuron whose spikes are drawn
SHOW_PATH = True


# Calculating necessary arrays
trajectory = np.array(Ag.history["pos"])
# Arrange the spike history as (timesteps, neurons) and keep one neuron, so the
# mask has exactly one entry per trajectory sample. A plain reshape(-1) only
# lines up when the population has a single neuron.
# NOTE(review): assumes history["spikes"] holds one entry per recorded position.
spikes_array = np.array(SPIKING_CELL.history["spikes"]).reshape(len(trajectory), -1)[:, SPIKING_NEURON]
spikes_t_mask = spikes_array.astype(bool)
spike_locations = trajectory[spikes_t_mask, :]


# Plotting
fig, ax = setup_environment_figure()
x, y = trajectory[:, 0], trajectory[:, 1]

if SHOW_RAT:
    rat_img = get_rat_image(ax)
if SHOW_PATH:
    rat_path = ax.plot(x, y, zorder=0, color="gray", lw=0.75)[0]
if SHOW_SPIKES:
    spike_scatter = ax.plot(spike_locations[:, 0], spike_locations[:, 1], ".", color="#a645de", ms=10)[0]


def animate(t_pos):
    """Update the enabled artists for trajectory index ``t_pos`` and return them."""
    returned_artists = []
    # Include the current sample, so the path and spikes reach the rat's position.
    upto = t_pos + 1
    if SHOW_RAT:
        orient_image_along_path(rat_img, x, y, t_pos, ax, (0.5, 0.5))
        returned_artists.append(rat_img)
    if SHOW_PATH:
        rat_path.set_data(x[:upto], y[:upto])
        returned_artists.append(rat_path)
    if SHOW_SPIKES:
        # spikes_t_mask is fixed, so it is computed once above, not per frame.
        visible_spikes = trajectory[:upto][spikes_t_mask[:upto], :]
        spike_scatter.set_data(visible_spikes[:, 0], visible_spikes[:, 1])
        returned_artists.append(spike_scatter)

    return tuple(returned_artists)


# Render every 4th trajectory sample. Saving to .mp4 uses Matplotlib's default
# movie writer (typically ffmpeg, which must be installed).
anim = matplotlib.animation.FuncAnimation(fig, animate, frames=tqdm(np.arange(1, len(trajectory), 4)), interval=30)
anim.save("assets/GCs spikes.mp4")

## Creating rate maps


<table>
    <tr>
      <td>
      <img src='assets/PC firing map.png' width=300>
      </td>
      <td>
      <img src='assets/GC firing map.png' width=300>
      </td>
      <td>
      <img src='assets/BC firing map.png' width=300>
      </td>
     </tr>
</table>

In [19]:
def plot_rate_map(cells,ax=None, which_neurons=0):
    """Plot one neuron's rate map with the notebook's styling and return its image.

    Parameters
    ----------
    cells : ratinabox Neurons
        Population whose rate map is drawn via ``cells.plot_rate_map``.
    ax : matplotlib.axes.Axes, optional
        Target axes. If None, a new figure is made with ``setup_environment_figure``.
    which_neurons : int, default 0
        Index of the neuron within ``cells`` to plot.

    Returns
    -------
    matplotlib.image.AxesImage
        The rate-map image. Its colormap is set to ``cmasher.cosmic`` and its
        interpolation to bilinear.
    """
    if ax is None:
        _, ax = setup_environment_figure()
    # RatInABox reads a bare int for ``chosen_neurons`` as a *count* of evenly
    # spaced neurons, so wrap the index in a list to select that one neuron.
    cells.plot_rate_map(chosen_neurons=[which_neurons], ax=ax, colorbar=False)
    # NOTE(review): assumes the rate map is the first image on ``ax``.
    im = ax.get_images()[0]
    im.set_cmap(cmasher.cosmic)
    im.set_interpolation("bilinear")
    return im

# Save one styled rate-map figure per population: place, grid and boundary cell.
for population, out_path in (
    (PCs, "assets/PC firing map.png"),
    (GCs, "assets/GC firing map.png"),
    (BCs, "assets/BC firing map.png"),
):
    fig, ax = setup_environment_figure()
    plot_rate_map(population, ax=ax)
    plt.savefig(out_path)

## Animating firing rates


<video width="1000"
       src="assets/Rate curves.mp4"
       controls>
</video>


In [None]:
cells = [PCs,GCs, BCs]

# One 1-D firing-rate trace per population. reshape(-1) assumes each population
# has a single neuron; with n > 1 it would interleave the neurons.
firing_rates = [np.array(cell.history["firingrate"]).reshape(-1) for cell in cells]

frames = np.arange(len(PCs.history["firingrate"]))

rate_colors = ["#b03bff", "#ff3b4b", "#fff71c", "#1cc6ff"]

# squeeze=False keeps ``axs`` a 1-D array of Axes even if ``cells`` has one entry.
fig, axs = plt.subplots(len(cells), 1, figsize=(15, 4), dpi=200, sharex=True, squeeze=False)
axs = axs[:, 0]
fig.set_facecolor("black")

for ax in axs:
    ax.set_facecolor("black")
    ax.axis(False)

rate_plots = []
rate_fills = []
for k, ax in enumerate(axs):
    rate_plots.append(ax.plot(frames, firing_rates[k], color=rate_colors[k])[0])
    rate_fills.append(ax.fill_between(frames, 0, firing_rates[k], color=rate_colors[k], alpha=0.25))


def animate(frame):
    """Reveal every trace up to timestep ``frame`` and return the updated artists."""
    t_mask = frames > frame  # samples after the current frame are hidden
    for k, ax in enumerate(axs):
        y = np.ma.masked_where(t_mask, firing_rates[k])
        rate_plots[k].set_data(frames, y)
        # fill_between creates a new static collection, so swap out the old one.
        # ``ax.collections.clear()`` cannot be used: since Matplotlib 3.5,
        # Axes.collections is an immutable ArtistList.
        rate_fills[k].remove()
        rate_fills[k] = ax.fill_between(frames, 0, y, color=rate_colors[k], alpha=0.25)
    return rate_plots + rate_fills


animation = matplotlib.animation.FuncAnimation(fig, animate, frames=tqdm(frames[::4]), interval=30)
animation.save("assets/Rate curves.mp4")