# Explanations of the MazeAgent_reviewResponses Script from Tom George's STDP-SR Repository
The official repository that the code is taken from can be found here: https://github.com/TomGeorge1234/STDP-SR/tree/main


### Imports and Packages
Imports and packages the script needs; any that are missing must first be installed in the Python environment.

In [None]:
import numpy as np 

import pandas as pd 
from tqdm.notebook import tqdm
from datetime import datetime 
import numbers
from pprint import pprint as pprintq
import os
from scipy.stats import vonmises
from scipy.spatial import distance_matrix
import dill 
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt 
import matplotlib
from matplotlib.animation import FuncAnimation

import tomplotlib.tomplotlib as tpl

# set a directory for the figures to be saved to
tpl.figureDirectory = './figures'
tpl.set_colorscheme(colorscheme=2) # set the colour scheme for the figures. See tomplotlib documentation for options

## Default Parameters of the agent
The parameters the agent uses unless they are overridden (e.g. via the params dictionary passed to MazeAgent)

### Maze Parameters:

- **mazeType** = defines the maze the agent moves in -> the available mazes are defined in the getWalls() function

- **stateType** = defines the "shape" of the basis features that TD (temporal difference) learning operates on -> options are gaussian, onehot (constant magnitude inside the field, zero outside), gaussianCS (unclear), circles and bump (also unclear); initialise() additionally handles "gaussianThreshold" and "fourier" options

- **movementPolicy** = defines the movement policy of the agent (the defaults comment lists raudies, random walk and windows screensaver) -> what each policy looks like and how it is updated at each time step dt is further defined in the function movementPolicyUpdate()

- **roomSize** = a parameter that defines the size of the respective maze that was set in mazeType

- **dt** = the simulation time step, in seconds -> discretises continuous time and determines how often the agent's position and the other variables are updated. The default is None, in which case initialise() sets dt to half the shorter STDP time constant, min(tau_STDP_plus, tau_STDP_minus)/2 (10 ms with the defaults) -> a compromise between computational cost and model precision

- **dx** = discretises space for plotting and spatial lookups (e.g. the state vector at every grid point); the agent's movement itself is continuous. initialise() adjusts dx if necessary so that it divides roomSize an integer number of times

- **speedScale** = the movement speed scale of the agent, in metres/second

- **rotSpeedScale** = how much the agent rotates per second -> radians/sec

- **initPos** = position [x0, y0] of the agent when it is first initialised, not at the start of every trial -> if you create a new agent for every trial, each one starts at this position, but if you reuse the same agent across trials, only the first trial starts here

- **initDir** = initial heading of the agent, as a unit vector -> same caveat as for initPos

- **nCells** = how many features to use -> *Question:* basis features or successor features? => basis features: initialise() sets stateSize = nCells and builds M as an nCells × nCells matrix, so this is also the number of successor features

- **centres** = an array of the positions of the receptive field centres -> if provided, it overrides nCells (initialise() sets nCells to the number of rows in centres), so "centres" has the final word

- **sigma** = controls the width of the basis features' place fields -> each cell has a receptive field centred somewhere in the maze, and sigma sets its width: small sigma = narrow place field; large sigma = wide place field. Its effect depends on stateType: for Gaussian cells sigma effectively sets the width of the tuning curve, while the defaults comment says it is irrelevant for onehot cells, whose size may instead depend on nCells (this needs to be double-checked!). If neither nCells nor centres is given, sigma also sets the number of cells (about 10 cells per disc of radius sigma/2)

- **doorsClosed** = whether the doors in a multicompartment maze are open or closed -> in George et al. (2023) this presumably applies to the two-room maze; when False, the 'doors' walls are removed from mazeState (*Note to self: might be interesting to look at if the project reaches a DNMTP setting -> i.e. could a condition like this determine whether an arm is open or not?*)

- **reorderCells** = whether to reorder the cell centres (given in "centres" or generated from "nCells"). This doesn't move the cells in the maze. It only changes how they are indexed, which gives the SR matrix visible structure: cells are sorted left to right by x coordinate, except in the twoRooms maze, where they are sorted by their signed distance from the maze centre

- **firingRateLookUp** = when True, the model reads basis-feature firing rates from a quantised table of precomputed values instead of evaluating the place-field tuning curve every time -> less computationally demanding (*Note: would be interesting to know how position is discretised in the table*)

- **biasDoorCross** = whether door crossings in the two-room maze are biased (True/False) -> relevant for the two-room maze in George et al. (2023)

- **biasWallFollow** = whether the agent is biased to follow walls (thigmotaxis), True/False -> per the code comment, the agent aligns to the wall when it gets too near; relevant for the two-room maze
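The reorderCells step can be illustrated by a standalone version of the ordering logic from initialise(). The function name `reorder_centres` is hypothetical. Like the original, it assumes the maze extent starts at (0, 0), so the middle is (extent[1]/2, extent[3]/2).

```python
import numpy as np

def reorder_centres(centres, extent, maze_type):
    """Hypothetical standalone version of the reorderCells logic in initialise().

    centres: (nCells, 2) array of [x, y] field centres.
    extent:  [minx, maxx, miny, maxy]; like the original, assumes minx = miny = 0.
    """
    centres = np.asarray(centres, dtype=float)
    if maze_type == 'twoRooms':
        middle = np.array([extent[1] / 2, extent[3] / 2])
        distance_to_centre = np.linalg.norm(middle - centres, axis=1)
        # negative for cells left of the middle, positive for cells right of it
        signed = distance_to_centre * (2 * (centres[:, 0] > middle[0]) - 1)
        inds = signed.argsort()
    else:
        inds = centres[:, 0].argsort()  # plain left-to-right ordering by x
    return centres[inds]

centres = np.array([[0.9, 0.5], [0.1, 0.5], [0.45, 0.5], [0.55, 0.5]])
print(reorder_centres(centres, [0, 1, 0, 1], 'twoRooms')[:, 0])
```

For the twoRooms maze, the "centre outwards" ordering starts at the far wall of the left room and runs in to the door. It then runs from the door out to the far wall of the right room, so the index order roughly follows a path through the door.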

### TD Parameters:

- **tau** = the TD decay time, in seconds -> sets the predictive time horizon of the SR and is analogous to the discount factor γ in discrete-time TD learning: large tau -> long predictive horizon; small tau -> short predictive horizon
- **TDdx** = controls where in space TD updates happen -> the rough distance, in metres, between TD learning updates (runRat draws each gap from an exponential distribution with mean TDdx)
- **alpha** = the TD learning rate (as in standard RL TD learning) -> controls how fast the successor features converge: larger alpha => faster learning. If alpha is given as a pair [start, end], runRat decays it exponentially from start to end over the run
- **successorFeatureNorm** = a scaling constant that rescales the magnitude of the SR: *M ← Norm × M*. Apparently used to improve the numerical stability of learning and to keep the TD and STDP SRs comparable
- **TDreg** = L2 regularisation (weight decay) of the TD weights -> keeps the weights from growing too large and helps prevent divergence in long simulations
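The sketch below makes the link between tau and a discount factor concrete. It is a minimal TD update of the SR for generic feature vectors in the style of de Cothi and Barry (2020), assuming the common discretisation γ = exp(-Δt/τ). It is not the repo's TDLearningStep (only that function's docstring appears in these notes), and it leaves out successorFeatureNorm and TDreg. `gamma_from_tau` and `td_sr_step` are hypothetical names.

```python
import numpy as np

def gamma_from_tau(dt, tau):
    """Discount over an interval dt for TD decay time tau (assumed gamma = exp(-dt/tau))."""
    return np.exp(-dt / tau)

def td_sr_step(M, phi, phi_next, dt, tau, alpha):
    """One TD update of M, where the successor features are psi(s) = M @ phi(s)."""
    gamma = gamma_from_tau(dt, tau)
    delta = phi + gamma * M @ phi_next - M @ phi  # TD error vector
    M = M + alpha * np.outer(delta, phi)          # credit the current features
    return M, delta

# Deterministic ring of n onehot states, 0 -> 1 -> ... -> n-1 -> 0, one step per second
n, dt, tau, alpha = 4, 1.0, 1.0, 0.1
states = np.eye(n)
M = np.eye(n)
for step in range(2000):
    s, s_next = step % n, (step + 1) % n
    M, _ = td_sr_step(M, states[s], states[s_next], dt, tau, alpha)

# Analytical SR for this ring: M = sum_k gamma^k P^k = (I - gamma P)^-1, with P moving s -> s+1
P = np.roll(np.eye(n), 1, axis=0)
M_true = np.linalg.inv(np.eye(n) - gamma_from_tau(dt, tau) * P)
print(np.abs(M - M_true).max())
```

Larger tau pushes γ towards 1, so the SR sums occupancy over a longer future. It also converges more slowly.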

### STDP Parameters:

- **peakFiringRate** = peak firing rate of a place cell, reached at the centre of its place field when theta is at the cell's preferred phase
- **tau_STDP_plus** = the pre-synaptic trace decay time (20 ms by default) -> time constant of the pre-before-post potentiation window, i.e. how long a pre-synaptic spike leaves a trace
- **tau_STDP_minus** = the post-synaptic trace decay time (40 ms by default) -> how long a post-synaptic spike influences depression (LTD), i.e. the post-before-pre depression window
- **a_STDP** = according to the code comment, the pre-before-post potentiation factor, with post-before-pre = 1 -> sets the relative strength of potentiation vs depression. The default is negative (-0.4), so it is worth checking in STDPLearningStep() which of the two terms it actually scales
- **eta** = the STDP learning rate (equivalent to alpha in TD) -> higher eta = faster weight change; lower eta = slower convergence
- **baselineFiringRate** = a constant firing rate added everywhere -> usually set to 0 (the default) so that weight changes are driven only by spatially selective firing
- **use_full_STDP_rule** = whether to use the full STDP rule (STDPLearningStep_detailed()) or a simplified version (STDPLearningStep()) (*Note: what does the simplified version leave out?*)
- **online_mapping** = determines how CA3 is mapped onto CA1 during learning -> "identity" means that CA1 initially mirrors CA3 and STDP then modifies this towards the successor representation
- **rownorm** = whether to normalise each row of the weight matrix (i.e. each neuron's weights); with the default False, learning is pure STDP (according to ChatGPT, so needs to be double-checked!)
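The trace-based STDP idea behind these parameters can be sketched with a generic event-driven, pair-based rule for a single synapse. This illustration rests on several assumptions and is not the repo's STDPLearningStep:

- `stdp_weight_change` is a hypothetical name.
- Potentiation has amplitude 1.
- Depression uses a separate positive amplitude `A_minus`, which deliberately sidesteps the sign convention of a_STDP.
- Only tau_STDP_plus, tau_STDP_minus and eta take their default values from above.

```python
import numpy as np

def stdp_weight_change(pre_times, post_times, eta=0.05, tau_plus=20e-3,
                       tau_minus=40e-3, A_minus=0.4):
    """Total weight change of one pre->post synapse under a hypothetical pair-based rule.

    pre trace  (decay tau_plus):  read at each post spike -> potentiation
    post trace (decay tau_minus): read at each pre spike  -> depression
    """
    events = sorted([(t, 'pre') for t in pre_times] + [(t, 'post') for t in post_times])
    pre_trace = post_trace = 0.0
    last_t, dw = None, 0.0
    for t, kind in events:
        if last_t is not None:  # decay both traces since the previous event
            pre_trace *= np.exp(-(t - last_t) / tau_plus)
            post_trace *= np.exp(-(t - last_t) / tau_minus)
        if kind == 'pre':
            dw -= eta * A_minus * post_trace  # post-before-pre: depression
            pre_trace += 1.0
        else:
            dw += eta * pre_trace             # pre-before-post: potentiation
            post_trace += 1.0
        last_t = t
    return dw

print(stdp_weight_change([0.0], [0.01]))  # pre leads post by 10 ms -> positive
print(stdp_weight_change([0.01], [0.0]))  # post leads pre by 10 ms -> negative
```

The pre trace decays with tau_STDP_plus and the post trace with tau_STDP_minus. The two time constants therefore set the widths of the potentiation and depression windows respectively.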

### Phase Precession Parameters:

- **thetaFreq** = the theta oscillation frequency, in Hz (10 by default)
- **precessFraction** = fraction of 2π through which the preferred firing phase shifts across a place field -> 0.5 means it shifts through half a theta cycle => effectively controls the strength of phase precession: higher -> stronger ordering of spikes within a single traversal
- **kappa** = the von Mises spread parameter -> sets the sharpness of phase tuning, i.e. how tightly spikes are locked to their preferred phase => larger kappa -> stronger sequential compression (*Note: maybe something to look at if we change the theta frequency, where we may then also need tighter phase locking*)
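Here is a minimal sketch of how thetaFreq, precessFraction and kappa could combine into a theta-modulated firing rate for one cell on a 1D track. The following are hypothetical choices made for illustration:

- the Gaussian spatial field;
- the position-to-preferred-phase mapping (linear across a field of width 2σ, π at the field centre, decreasing as the agent moves through the field);
- all function names.

The phase tuning factor exp(κ(cos(θ − θ_pref) − 1)) is the von Mises density rescaled to peak at 1. The theta phase itself is computed with the same formula as in the script.

```python
import numpy as np

def theta_phase(t, theta_freq=10):
    """Theta phase in [0, 2*pi), computed as in MazeAgent.initialise()/runRat()."""
    return theta_freq * (t % (1 / theta_freq)) * 2 * np.pi

def preferred_phase(x, centre, sigma, precess_fraction=0.5):
    """Hypothetical linear precession: pi at the field centre, shifting by
    2*pi*precess_fraction in total between entering (x = centre - sigma) and
    leaving (x = centre + sigma) the field while moving in +x."""
    frac_through = (x - centre) / (2 * sigma)  # -0.5 at entry, +0.5 at exit
    return np.pi - 2 * np.pi * precess_fraction * frac_through

def theta_modulated_rate(x, t, centre, sigma, peak_rate=5, kappa=1,
                         theta_freq=10, precess_fraction=0.5):
    spatial = np.exp(-(x - centre) ** 2 / (2 * sigma ** 2))  # assumed Gaussian field
    dphi = theta_phase(t, theta_freq) - preferred_phase(x, centre, sigma, precess_fraction)
    phase_tuning = np.exp(kappa * (np.cos(dphi) - 1))       # von Mises, max 1
    return peak_rate * spatial * phase_tuning

# at the field centre, at the preferred phase (pi, reached at t = 50 ms for 10 Hz theta)
print(theta_modulated_rate(0.5, 0.05, centre=0.5, sigma=0.1))
```

Raising kappa makes phase_tuning drop off faster away from the preferred phase. Raising precess_fraction spreads the preferred phases of a traversal over more of the theta cycle.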


In [None]:
#Default parameters for MazeAgent 
defaultParams = { 

          #Maze params 
          'mazeType'            : 'oneRoom',  #type of maze, define in getWalls() function
          'stateType'           : 'gaussian', #feature on which to TD learn (onehot, gaussian, gaussianCS, circles, bump)
          'movementPolicy'      : 'raudies',  #movement policy (raudies, random walk, windows screensaver)
          'roomSize'            : 1,          #maze size scaling parameter, metres
          'dt'                  : None,       #simulation time discretisation, seconds (None -> min(tau_STDP_plus, tau_STDP_minus)/2)
          'dx'                  : 0.01,       #space discretisation (for plotting, movement is continuous)
          'speedScale'          : 0.16,       #movement speed scale, metres/second
          'rotSpeedScale'       : None,       #rotational speed scale, radians/second
          'initPos'             : [0.1,0.1],  #initial position [x0, y0], metres
          'initDir'             : [1,0],      #initial direction, unit vector
          'nCells'              : None,       #how many features to use
          'centres'             : None,       #array of receptive field positions. Overwrites nCells
          'sigma'               : 1,          #basis cell width scale (irrelevant for onehots)
          'doorsClosed'         : True,       #whether doors are opened or closed in multicompartment maze
          'reorderCells'        : True,       #whether to reorder the cell centres which have been provided
          'firingRateLookUp'    : False,      #use quantised lookup table for firing rates 
          'biasDoorCross'       : False,      #if True, in twoRoom maze door crossings are biased towards
          'biasWallFollow'      : True,       #if True, agent aligns to wall when gets too near.

          #TD params 
          'tau'                 : 4,          #TD decay time, seconds
          'TDdx'                : 0.01,       #rough distance between TD learning updates, metres 
          'alpha'               : 0.01,       #TD learning rate 
          'successorFeatureNorm': 100,        #linear scaling on successor feature definition found to improve learning stability
          'TDreg'               : 0.01,       #L2 regularisation 
          
          #STDP params
          'peakFiringRate'      : 5,          #peak firing rate of a cell (middle of place field,preferred theta phase)
          'tau_STDP_plus'       : 20e-3,      #pre trace decay time
          'tau_STDP_minus'      : 40e-3,      #post trace decay time
          'a_STDP'              : -0.4,       #pre-before-post potentiation factor (post-before-pre = 1) 
          'eta'                 : 0.05,       #STDP learning rate
          'baselineFiringRate'  : 0,          #baseline firing rate for cells 
          'use_full_STDP_rule'  : False,      #whether to use full STDP rule     
          'online_mapping'      : 'identity',  #how to map CA3-->CA1 during learning
          'rownorm'             : False,
            


          #Theta precession params
          'thetaFreq'           : 10,         #theta frequency, Hz
          'precessFraction'     : 0.5,        #fraction of 2pi the prefered phase moves through
          'kappa'               : 1,          # von mises spread parameter

}

## "def init" & "def updateParams"

### **def init**
- Sets the parameters of the agent: either loads them from a file (if loadFromFileCalled is given) or starts from defaultParams and overrides them with anything passed in params via updateParams.
- Initialises the rest of the class (i.e. the maze, speed etc., since these are also set by the parameters).

### **def updateParams**
- simply updates the parameters of the class based on what is given in the dictionary
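The default-then-override pattern used by __init__ and updateParams can be reproduced with a toy class. `ToyAgent` and its two-entry defaults dictionary are hypothetical stand-ins for MazeAgent and defaultParams. The point is that setattr accepts any key, so a misspelt parameter name silently becomes a new attribute instead of raising an error.

```python
defaults = {'tau': 4, 'alpha': 0.01}  # hypothetical, cut-down defaultParams

class ToyAgent:
    def __init__(self, params=None):
        for key, value in defaults.items():  # 1) every default becomes an attribute
            setattr(self, key, value)
        self.updateParams(params or {})       # 2) user values overwrite defaults

    def updateParams(self, params):
        for key, value in params.items():
            setattr(self, key, value)

agent = ToyAgent({'alpha': 0.1, 'aplha': 0.5})  # note the typo in the second key
print(agent.tau, agent.alpha, agent.aplha)
```

The sketch uses None rather than a mutable {} as the default argument. The original is safe either way, because it never mutates params.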

In [None]:
class MazeAgent():
    """MazeAgent defines an agent moving around a maze. 
    The agent moves according to a predefined movement policy
    As the agent moves it learns a successor representation over state vectors according to a TD learning rule 
    The movement policy is 
        (i)  continuous in space. There is no discretisation of location. Time is discretised into steps of dt
        (ii) completely decoupled from the TD learning.
    TD learning is 
        (i)  state general. i.e. it learns generic SRs for feature vectors which are not necessarily onehot. See de Cothi and Barry, 2020  
        (ii) time continuous. Defined in terms of a memory decay time tau, not unitless gamma. Any two states can be used for a TD learning step irrespective of their separation in time. 
    As the rat moves and learns its position and time stamps are continually saved. Periodically a snapshot of the current SR matrix and state of other parameters in the maze are also saved. 
    """   
    def __init__(self,
                params={},
                loadFromFileCalled=None):
        """Sets the parameters of the maze and agent (using defaults if not provided) 
        and initialises everything. This includes: 
        •initialising history dataframes
        •making the maze (a dictionary of "walls" which cant be crossed)
        •setting position, velocity, time
        •discretising space into coordinates for later plotting
        •initialising basis features (gaussian centres, fourier frequencies etc.)
        •initialising SR matrix 

        Args:
            params (dict, optional): A dictionary of parameters which you want to differ from the default. Defaults to {}.
        """        
        if loadFromFileCalled is not None: 
            self.loadFromFile(name=loadFromFileCalled)
            
        else:
            print("Setting parameters")
            for key, value in defaultParams.items():
                setattr(self, key, value) #set every default parameter as an attribute of the class; overridden by params below
            self.updateParams(params) #update any parameters provided in params dictionary

            print("Initialising")
            self.initialise() #initialise the rest of the class (position, velocity, time, maze, features, SR matrix etc.)
            print("DONE")

    def updateParams(self,
                     params : dict):        
        """Updates parameters from a dictionary. 
        All parameters found in params will be updated to new value

        Args:
            params (dict): dictionary of parameters to change
        """        
        for key, value in params.items():
            setattr(self, key, value)

### **def initialise** 
- initialises the agent and the maze 
    - creates the dataframes that store the output (history and snapshots) plus a dictionary for spike data
    - sets the initial velocity, position and direction of the agent, based on the parameters (speedScale, initPos, initDir)
    - initialises the time (t = 0) and the run counter (runID = 0), and sets the theta phase -> i.e. the phase of theta when the agent is initialised; when it is recomputed later it gives the current theta phase at time *t*, in [0, 2π)
    - sets up the maze as defined in params (also considering the "doors" parameter)
    - creates a spatial discretisation grid of the maze, mainly for plotting and spatial lookup (e.g., rate maps, SR maps)
    - computes the boundary coordinates of the maze (extent) as well as its height and width
    - create set of evenly spaced x- and y-coordinates
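    - handles undefined (None) parameters (*note: why not use fixed defaults instead? Probably because these defaults depend on other parameters, e.g. dt on the STDP time constants and nCells on the maze area and sigma*)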
- initialises basis features and successor matrix M
    - sets up the basis features and generates place cell centres if they are not given -> not an issue if we just provide the centres
- reorders place cells in the matrices from left to right or from inside out
- initialises STDP weight matrix and STDP pre and post traces
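The discretisation steps in the list above can be checked in isolation. The sketch copies the dx snapping, extent, meshgrid and swapaxes logic from initialise() into a hypothetical `discretise` function. It assumes walls is a dict of arrays whose last axis holds (x, y) coordinates, which matches how initialise() indexes them. The exact layout returned by getWalls() is not shown in these notes, so the square room below is an assumed example.

```python
import numpy as np

def discretise(walls, room_size, dx):
    # snap dx so that it divides room_size an integer number of times
    n = round(room_size / dx)
    if abs(room_size / dx - n) > 1e-5:
        dx = room_size / n
    minx = maxx = miny = maxy = 0
    for wa in walls.values():
        minx, maxx = min(minx, np.min(wa[..., 0])), max(maxx, np.max(wa[..., 0]))
        miny, maxy = min(miny, np.min(wa[..., 1])), max(maxy, np.max(wa[..., 1]))
    xArray = np.arange(minx + dx / 2, maxx, dx)        # grid-cell centres, left to right
    yArray = np.arange(miny + dx / 2, maxy, dx)[::-1]  # top to bottom, as in an image
    x_mesh, y_mesh = np.meshgrid(xArray, yArray)
    # (2, ny, nx) -> (ny, nx, 2): coords[row, col] = [x, y]
    coords = np.swapaxes(np.swapaxes(np.array([x_mesh, y_mesh]), 0, 1), 1, 2)
    return dx, np.array([minx, maxx, miny, maxy]), coords

# assumed wall layout for a 1 m square room: 4 walls, each [[x0, y0], [x1, y1]]
square = {'room1': np.array([[[0, 0], [0, 1]], [[0, 1], [1, 1]],
                             [[1, 1], [1, 0]], [[1, 0], [0, 0]]], dtype=float)}
dx, extent, coords = discretise(square, room_size=1, dx=0.03)
print(dx, extent, coords.shape, coords[0, 0])
```

Because yArray is reversed, row 0 of discreteCoords is the top of the maze. This makes the array line up with image-style plotting (e.g. imshow).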

    
    

In [None]:
def initialise(self): #should only be called once at the start 
        """Initialises the maze and agent. Should only be called once at the start.
        """        
        #initialise history dataframes
        print("   making state/history dataframes")
        self.mazeState = {}
        self.history = pd.DataFrame(columns = ['t','pos','delta','runID']) 
        self.snapshots = pd.DataFrame(columns = ['t','M','W','mazeState'])
        self.spikedata = {'CA3':{'times':[],'ids':[]}, 'CA1':{'times':[],'ids':[]}}

        #set pos/vel
        print("   initialising velocity, position and direction")
        self.pos = np.array(self.initPos)
        self.speed = self.speedScale
        self.dir = np.array(self.initDir)

        #time and runID
        print("   setting time/run counters")
        self.t = 0
        self.runID = 0  
        self.thetaPhase = self.thetaFreq*(self.t%(1/self.thetaFreq))*2*np.pi


        #make maze 
        print("   making the maze walls")
        self.walls = getWalls(mazeType=self.mazeType, roomSize=self.roomSize)
        walls = self.walls.copy()
        if self.doorsClosed == False: 
            del walls['doors']
            self.mazeState['walls'] = walls
        elif self.doorsClosed == True: 
            self.mazeState['walls'] = walls

        #extent, xArray, yArray, discreteCoords
        print("   discretising position for later plotting")
        if abs((self.roomSize / self.dx) - round(self.roomSize / self.dx)) > 0.00001: # check if dx is an integer divisor of room size
            print("      dx must be an integer fraction of room size, setting it to %.4f, %g along room length" %(self.roomSize / round(self.roomSize / self.dx), round(self.roomSize / self.dx)))
            self.dx = self.roomSize / round(self.roomSize / self.dx) 
        minx, maxx, miny, maxy = 0, 0, 0, 0
        for room in self.walls: #find the extent of the maze by looking at the max and min x and y values of all the walls
            wa = self.walls[room]
            minx, maxx, miny, maxy = min(minx,np.min(wa[...,0])), max(maxx,np.max(wa[...,0])), min(miny,np.min(wa[...,1])), max(maxy,np.max(wa[...,1])) 
        self.extent = np.array([minx,maxx,miny,maxy]) 
        self.width = maxx-minx
        self.height = maxy-miny
        self.xArray = np.arange(minx + self.dx/2, maxx, self.dx)
        self.yArray = np.arange(miny + self.dx/2, maxy, self.dx)[::-1]
        x_mesh, y_mesh = np.meshgrid(self.xArray,self.yArray) #create meshgrid -> 2D grid of x and y coordinates for later plotting
        coordinate_mesh = np.array([x_mesh, y_mesh]) # combine into coordinate pairs
        self.discreteCoords = np.swapaxes(np.swapaxes(coordinate_mesh,0,1),1,2) #an array of discretised position coords over entire map extent 
        self.mazeState['extent'] = self.extent # stores plotting boundaries in mazeState for later use

        #handle None params
        print("   handling undefined parameters")
        if self.dt is None: #set dt to half the shorter STDP time constant, for learning stability
            self.dt = min(self.tau_STDP_plus,self.tau_STDP_minus) / 2
        if self.initPos is None: #if no initial position provided, set to be 20% in from bottom left corner (self.pos is np.array(None) here, never None)
            ex = self.extent
            self.pos = np.array([ex[0] + 0.2*(ex[1]-ex[0]),ex[2] + 0.2*(ex[3]-ex[2])])
        if self.initDir is None: #if no initial direction provided, set to be along the long axis of the maze (self.dir is np.array(None) here, never None)
            if self.mazeType == 'longCorridor': self.dir = np.array([0,1]) # for corridor maze
            elif self.mazeType == 'loop': self.dir = np.array([1,0]) # for loop maze
            else: self.dir = np.array([1,1]) / np.sqrt(2) # for other mazes
        if self.rotSpeedScale is None: # if no rotational speed scale provided, set to be lower for mazes where precise turns are required
            if self.mazeType == 'loop' or self.mazeType == 'longCorridor': # for loop and corridor mazes
                self.rotSpeedScale = np.pi
            else: 
                self.rotSpeedScale = 3*np.pi # for other mazes where more rapid turning is possible without crashing into walls
        if (self.nCells is None) and (self.centres is None): #if no number of cells or cell centres provided, set number of cells according to maze size and sigma (cell width)
            ex = self.extent
            area, pcarea  = (ex[1]-ex[0])*(ex[3]-ex[2]), np.pi * ((self.sigma/2)**2)
            cellsPerArea = 10
            self.nCells = int(cellsPerArea * area / pcarea) #~10 in any given place
        if self.mazeType == 'TMaze': #if T maze, keep track of whether a left/right decision at the T junction is still pending
            self.LRDecisionPending=True #boolean to keep track of whether the agent still has to make a left/right decision at the T junction
        self.doorPassage = False #boolean to keep track of whether the agent is currently passing through the door in the two room maze (used for plotting and for biasing movement if biasDoorCross is True)
        self.doorPassageTime = 0 # counter to keep track of how long the agent has been passing through the door  
        self.lastTurnUpdate = -1 
        self.randomTurnSpeed = 0

        #initialise basis cells and M (successor matrix)
        print("   initialising basis features for learning")

        if self.stateType in ['gaussian', 'gaussianCS','gaussianThreshold', 'circles','onehot','bump']:
            if self.centres is not None: #if locations for cell centres were provided...
                self.nCells = self.centres.shape[0] #set number of cells according to number of centres provided
                self.stateSize = self.nCells #set state size according to number of cells provided -> stateSize is with that the number of cells 
            else: #scatter some ourselves (making sure they aren't too close)-> if centres are not provided
                self.stateSize=self.nCells #set state size according to number of cells provided
                xcentres = np.random.uniform(self.extent[0],self.extent[1],self.nCells) #randomly generate x coordinates
                ycentres = np.random.uniform(self.extent[2],self.extent[3],self.nCells) # randomly generate y coordinates
                self.centres = np.array([xcentres,ycentres]).T #set the cell centres to be the randomly generated coordinates
                inds = self.centres[:,0].argsort() #order the cell centres by x coordinate (for plotting purposes)
                self.centres = self.centres[inds] #apply the x-ordering
                print("   checking basis cells aren't too close") 
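                min_d = 0.1/0.9 #starting value: the first pass of the loop below multiplies by 0.9, giving a 10 cm minimum separation
                done = False #becomes True once no two centres are closer than min_d (overlapping cells are re-scattered up to 11 times per min_d, after which min_d is relaxed by 10%)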
                while done != True: #while the cell centres are not sufficiently far apart
                    min_d *= 0.9
                    print("     min separation distance:  %.1f cm" %(min_d*100))
                    count = 0
                    while count <= 10:
                        d = distance_matrix(self.centres,self.centres)
                        d  += 0.1*np.eye(d.shape[0])
                        d_xid, d_yid = np.where(d < min_d)
                        print('      ',int(len(d_xid)/2),' overlapping pairs',end='\n')
                        if len(d_xid) == 0:
                            done = True 
                            break
                        to_remove = []
                        for i in range(len(d_xid)):
                            if d_xid[i] < d_yid[i]:
                                to_remove.append(d_xid[i])
                        to_remove = np.unique(to_remove)
                        xcentres = np.random.uniform(self.extent[0],self.extent[1],len(to_remove))
                        ycentres = np.random.uniform(self.extent[2],self.extent[3],len(to_remove))
                        self.centres[to_remove] = np.array([xcentres,ycentres]).T
                        count += 1
            self.M = np.eye(self.stateSize) # initialise an identity matrix (1s on diagonal, 0s everywhere else), of size stateSize i.e. number of cells (makes sense!)
            self.W = self.M.copy() / self.nCells # copy of self.M but divided by number of nCells (why?)
            self.M_theta = self.M.copy() # copy of self.M -> Matrix for theta
            self.W_notheta = self.W.copy() # copy of self.W -> matrix for no theta



            #order the place cells so successor matrix has some structure:
            if self.reorderCells==True: # from parameters at beginning
                if self.mazeType == 'twoRooms': #from centre outwards
                    middle = np.array([self.extent[1]/2,self.extent[3]/2])
                    distance_to_centre = np.linalg.norm(middle - self.centres,axis=1)
                    distance_to_centre = distance_to_centre * (2*(self.centres[:,0]>middle[0])-1)
                    inds = distance_to_centre.argsort()
                    self.centres = self.centres[inds]
                else: #from left to right
                    inds = self.centres[:,0].argsort()
                    self.centres = self.centres[inds]

        elif self.stateType == 'fourier':
            self.stateSize = self.nCells
            self.kVectors = np.random.rand(self.nCells,2) - 0.5
            self.kVectors /= np.linalg.norm(self.kVectors, axis=1)[:,None]
            self.kFreq = 2*np.pi / np.random.uniform(0.01,1,size=(self.nCells))
            self.phi = np.random.uniform(0,2*np.pi,size=(self.nCells))
            self.M = np.eye(self.stateSize)
            #self.M = np.zeros((self.stateSize,self.stateSize))
        
        if hasattr(self.sigma,"__len__"):
            if self.sigma.__len__() == self.nCells:
                self.sigmas = self.sigma
        else:
            self.sigmas = np.array([self.sigma]*self.nCells)

        #array of states, one for each discretised position coordinate 
        print("   calculating state vector at all discretised positions")
        self.statesAlreadyInitialised = False
        self.discreteStates = self.positionArray_to_stateArray(self.discreteCoords,stateType=self.stateType,verbose=True) #an array of discretised position coords over entire map extent 
        self.statesAlreadyInitialised = True

        #store time zero snapshot
        snapshot = pd.DataFrame({'t':[self.t], 'M': [self.M.copy()], 'W': [self.W.copy()],'W_notheta': [self.W_notheta.copy()], 'mazeState':[self.mazeState]})
        self.snapshots = pd.concat([self.snapshots, snapshot], ignore_index=True)

        #STDP stuff
        print("   initialising STDP weight matrix and traces")
        self.preTrace = np.zeros(self.nCells) #causes potentiation -> pre trace for theta
        self.preTrace_notheta = np.zeros(self.nCells) #causes potentiation -> pre trace for no theta
        self.postTrace = np.zeros(self.nCells) #causes depression -> post trace for theta
        self.postTrace_notheta = np.zeros(self.nCells) #causes depression -> post trace for no theta
        self.lastSpikeTime = np.array(-10.0) # time of last spike for each cell, used for calculating STDP updates at end of each step (initialized to a long time ago so that there are no initial updates)
        self.lastSpikeTime_notheta = np.array(-10.0) # same for no theta
        self.spikeCount = np.array(0) # counter to keep track of how many spikes have been fired in total (used for calculating STDP updates at end of each step)
        self.spikeCount_notheta = np.array(0) # counter for no theta
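The centre-scattering loop in the cell above works in three stages: it draws random centres, re-scatters any cell that is too close to another, and after repeated failures relaxes the minimum separation by 10%. It can be sketched as a standalone function. `scatter_centres` is a hypothetical name. It ignores self-distances with an infinite diagonal instead of the original's `0.1*np.eye` offset, and it uses a seeded NumPy Generator for reproducibility.

```python
import numpy as np

def scatter_centres(n_cells, extent, rng, start_min_d=0.1, max_rescatters=10):
    """Random centres in extent [minx, maxx, miny, maxy] with no pair closer than min_d."""
    def random_points(k):
        return np.column_stack([rng.uniform(extent[0], extent[1], k),
                                rng.uniform(extent[2], extent[3], k)])

    centres = random_points(n_cells)
    min_d = start_min_d / 0.9
    while True:
        min_d *= 0.9  # first pass uses start_min_d, later passes relax it by 10%
        for _ in range(max_rescatters + 1):
            d = np.linalg.norm(centres[:, None, :] - centres[None, :, :], axis=-1)
            np.fill_diagonal(d, np.inf)  # ignore each cell's distance to itself
            i_idx, j_idx = np.where(d < min_d)
            if len(i_idx) == 0:
                return centres, min_d
            to_move = np.unique(i_idx[i_idx < j_idx])  # move one cell from each close pair
            centres[to_move] = random_points(len(to_move))

rng = np.random.default_rng(0)
centres, min_d = scatter_centres(50, [0, 1, 0, 1], rng)
print(centres.shape, round(min_d * 100, 1), "cm")
```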

### **def runRat**
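Runs an episode/trial in which the agent explores the maze according to its movement policy and learns the SR by TD learning and/or STDP learning (both are enabled by default). STDP learning only runs for the bump, gaussian, gaussianCS, gaussianThreshold and circles state types.
Saves data from the run in self.history and increments runID after each run. Note that the runHistory dataframe built at the end of the run has no runID column, so the runID entries in self.history end up empty (NaN).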
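A small sketch can check the bookkeeping at the start of runRat (number of steps, snapshot interval) and the randomised spacing of TD updates. It mirrors the arithmetic in the runRat code with the default dt of 10 ms. The movement policy is replaced by a hypothetical straight-line walk at the default speedScale, so only the TD-trigger logic is reproduced.

```python
import numpy as np

dt, train_time, save_every = 0.01, 10, 0.5   # s, minutes, minutes (defaults)
steps = int(train_time * 60 / dt)            # as in runRat
save_interval = int(save_every * 60 / dt)    # a snapshot whenever i % save_interval == 0
n_snapshots = sum(1 for i in range(steps) if i % save_interval == 0)
print(steps, save_interval, n_snapshots)

# TD updates: triggered once the agent is >= distanceToTD from the last TD position,
# with distanceToTD redrawn from an exponential distribution of mean TDdx each time.
rng = np.random.default_rng(1)
TDdx, speed = 0.01, 0.16
pos = np.arange(steps) * speed * dt          # hypothetical straight-line path, metres
last, threshold, gaps = 0, rng.exponential(TDdx), []
for i in range(1, steps):
    if pos[i] - pos[last] >= threshold:
        gaps.append(pos[i] - pos[last])
        last, threshold = i, rng.exponential(TDdx)
print(len(gaps), np.mean(gaps))
```

TD updates are therefore tied to distance travelled rather than elapsed time. The mean gap comes out slightly above TDdx because the threshold is only checked once per simulation step.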

In [None]:
def runRat(self,
            trainTime=10,
            saveEvery=0.5,
            TDSRLearn=True,
            STDPLearn=True):
        """The main experiment call.
        A "run" consists of a period where the agent explores the maze according to the movement policy. 
        As it explores it learns, by TD, a successor representation over state vectors. 
        The can be called multiple times. Each successive run will be saved in self.history with an increasing runID
        Snapshots of the current SR matrix and mazeState can be saved along the way
        Runs can be interrupted with KeyboardInterrupt, data will still be saved. 
        Args:
            trainTime (int, optional): How long to explore in minutes. Defaults to 10.
            saveEvery (float, optional): Frequency to save snapshots, in minutes. Defaults to 0.5.
            TDSRLearn (bool,optional): toggles whether to do TD learning 
            STDPLearn (bool, optional): toggles whether to do STDP learning 
        """        
        steps = int(trainTime * 60 / self.dt) #number of steps to perform 

        hist_t = np.zeros(steps) #history array for time
        hist_pos = np.zeros((steps,2)) #history array for position
        hist_delta = np.zeros(steps) #history array for TD learning update size (delta)

        lastTDstep, distanceToTD = 0, np.random.exponential(self.TDdx) #random distance to the next TD update, mean TDdx (1 cm by default)
        
        """Main training loop. Principally on each iteration: 
            • always updates motion policy 
            • often does TD learning step

            • sometimes saves snapshot"""
        for i in tqdm(range(steps)): #main training loop

            try:
                #update pos, velocity, direction and time according to movement policy
                self.movementPolicyUpdate()
                if i > 1:

                    # print(self.pos)
                    """STDP learning step"""
                    if (STDPLearn == True) and (self.stateType in  ['bump','gaussian', 'gaussianCS','gaussianThreshold', 'circles']):
                        if self.use_full_STDP_rule == True:
                            _ = self.STDPLearningStep_detailed(dt = self.t - hist_t[i-1])
                        else:
                            _ = self.STDPLearningStep(dt = self.t - hist_t[i-1])

                            
                    """TD learning step"""
                    if TDSRLearn == True: 
                        
                        alpha = self.alpha
                        try: alpha_ = alpha[0] * np.exp(-(i/steps)*(np.log(self.alpha[0]/self.alpha[1]))) #decaying alpha
                        except: alpha_ = self.alpha
                        

                        if np.linalg.norm(self.pos - hist_pos[lastTDstep]) >= distanceToTD: #if it has moved further than distanceToTD since the last TD step
                            dtTD = self.t - hist_t[lastTDstep]
                            delta = self.TDLearningStep(pos=self.pos, prevPos=hist_pos[lastTDstep], dt=dtTD, tau=self.tau, alpha=alpha_)
                            lastTDstep = i 
                            distanceToTD = np.random.exponential(self.TDdx)
                            hist_delta[i] = delta



                self.thetaPhase = self.thetaFreq*(self.t%(1/self.thetaFreq))*2*np.pi #current theta phase in [0, 2pi); thetaFreq is 10 Hz by default

                #update history arrays
                hist_pos[i] = self.pos
                hist_t[i] = self.t

                #save snapshot 
                if (isinstance(saveEvery, numbers.Number)) and (i % int(saveEvery * 60 / self.dt) == 0): #if it's time to save a snapshot (every saveEvery minutes)
                    snapshot = pd.DataFrame({'t':[self.t], 'M': [self.M.copy()], 'W': [self.W.copy()], 'W_notheta':[self.W_notheta.copy()], 'mazeState':[self.mazeState]})
                    self.snapshots = pd.concat([self.snapshots, snapshot], ignore_index=True)

            except KeyboardInterrupt: 
                print("Keyboard Interrupt:")
                break
            # except ValueError as error:
            #     print("ValueError:")
            #     print(error)
            #     print(f"   Rat position: {self.pos}")
            #     break

        self.runID += 1 #increment runID for next run
        runHistory = pd.DataFrame({'t':list(hist_t[:i]), 'pos':list(hist_pos[:i]),'delta':list(hist_delta[:i])}) # make a dataframe of the history of this run
        self.history = pd.concat([self.history, runHistory], ignore_index=True) # add this run's history to the overall history dataframe
        snapshot = pd.DataFrame({'t': [self.t], 'M': [self.M.copy()], 'W': [self.W.copy()], 'W_notheta':[self.W_notheta.copy()], 'mazeState':[self.mazeState]}) #make a final snapshot at the end of the run
        self.snapshots = pd.concat([self.snapshots, snapshot], ignore_index=True) #add this final snapshot to the snapshots dataframe

        #find and save grid/place cells so you don't have to repeatedly calculate them when plotting 
        print("Calculating place and grid cells")
        self.gridFields = self.getGridFields(self.M) # calculate grid fields from successor matrix M and save as attribute of class
        self.placeFields = self.getPlaceFields(self.M) # calculate place fields from successor matrix M and save as attribute of class

        if TDSRLearn == True: # if TD learning was performed, make some plots of the learning
            # plotter = Visualiser(self)
            # plotter.plotTrajectory(starttime=(self.t/60)-0.2, endtime=self.t/60)
            delta = np.array(hist_delta)
            time = np.array(hist_t)
            time = time[delta!=0] / 60
            delta = delta[delta!=0]
            time, delta = time[::10], delta[::10]
            smooth_delta = [np.mean(delta[max(0,i-100):min(i+100,len(delta))]) for i in range(len(delta))]
            fig, ax = plt.subplots(figsize=(2,1))
            ax.scatter(time,delta,s=0.5,alpha=0.5)
            ax.scatter(time,smooth_delta,s=1,alpha=0.5,c='C2')
            ax.set_xlabel("Time / min")
            ax.set_ylabel("Update size")


### **def TDLearningStep**
Defines a single TD learning step: how the successor representation (SR) matrix M, and its theta-modulated counterpart M_theta, are updated with a continuous-time TD learning rule. For "onehot" basis features it uses a cheaper, column-wise update instead of the generic outer-product rule.
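As a hedged illustration of the generic (non-onehot) rule, the sketch below re-implements the update from the code cell as a standalone function, td_update (a hypothetical name), and runs it on a toy environment: three onehot states visited in a fixed cycle. The toy environment, the parameter values, feature_norm = 1 and zero regularisation are all assumptions for illustration. At the fixed point each column satisfies M[:, s] = k·e_s + gamma·M[:, s_next] with k = tau·dt/(tau + dt) and gamma = tau/(tau + dt), so M should converge to k·inv(I - gamma·Q), where Q maps each state to its successor.

```python
import numpy as np

def td_update(M, state, prev_state, dt, tau, alpha, feature_norm=1.0, reg=0.0):
    """One continuous-time TD update of the SR matrix M for generic feature vectors (in place)."""
    # TD error: discounted feature term plus bootstrapped change in predicted successor features
    delta = ((tau * dt) / (tau + dt)) * feature_norm * prev_state \
        + M @ ((tau / (tau + dt)) * state - prev_state)
    Delta = np.outer(delta, prev_state)  # credit the error to the features active at the previous step
    M += alpha * Delta - 2 * alpha * reg * M  # reg > 0 adds L2 weight decay
    return np.linalg.norm(Delta)  # update size, as used for monitoring learning

# Toy environment (assumption): three onehot states visited in the fixed cycle 0 -> 1 -> 2 -> 0
n, dt, tau, alpha = 3, 0.5, 1.0, 0.1
features = np.eye(n)
M = np.zeros((n, n))
s = 0
for _ in range(3000):
    s_next = (s + 1) % n
    td_update(M, features[s_next], features[s], dt, tau, alpha)
    s = s_next

# Fixed point: M = k * inv(I - gamma * Q), with Q[s_next, s] = 1 mapping each state to its successor
k, gamma = tau * dt / (tau + dt), tau / (tau + dt)
Q = np.roll(np.eye(n), 1, axis=0)
M_expected = k * np.linalg.inv(np.eye(n) - gamma * Q)
print(np.allclose(M, M_expected, atol=1e-6))  # True
```

Each column of M ends up as a discounted sum of the features that follow that state, which is the successor representation for this toy cycle.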

In [None]:
def TDLearningStep(self, pos, prevPos, dt, tau, alpha):
        """TD learning step
            Improves the estimate of the SR matrix, M, by one TD learning step. 
            By default this uses the learning rule for generic feature vectors (see de Cothi and Barry 2020). 
            If stateType is onehot, additional efficiencies are gained by using a onehot-specific learning rule (see Stachenfeld et al. 2017).
            Uses continuous-time TD learning (see Doya, 2000).
        Args:
            pos (array): position at the current time t
            prevPos (array): position at the previous TD step, time t - dt
            dt (float): time difference between the two positions
            tau (float or int): memory decay time (analogous to gamma in TD, gamma = 1 - dt/tau)
            alpha (float): learning rate
        Returns:
            float: norm of the update to M in the non-theta case, used to monitor learning progress
        """
        state = self.posToState(pos,stateType=self.stateType) # get state vector for current position
        prevState = self.posToState(prevPos,stateType=self.stateType) # get state vector for previous position

        data = ( (state,                        prevState,                    self.M        ) , 
                 (self.thetaModulation(state),  self.thetaModulation(state),  self.M_theta) ) # data for both theta and no theta cases, to be looped through for learning updates

        
        for i, (state, prevState, M) in  enumerate(data): 
            #onehot optimised TD learning 
            if self.stateType == 'onehot': 
                s_t = np.argwhere(prevState)[0][0]
                s_tplus1 = np.argwhere(state)[0][0]
                Delta = state + (tau / dt) * ((1 - dt/tau) * M[:,s_tplus1] - M[:,s_t])
                M[:,s_t] += alpha * Delta - 2 * alpha * self.TDreg * M[:,s_t]

            #normal TD learning 
            else:
                delta = ((tau * dt) / (tau + dt)) * self.successorFeatureNorm * prevState + M @ ((tau/(tau + dt))*state - prevState) # continuous-time TD error; successorFeatureNorm scales the feature (reward-like) term
                Delta = np.outer(delta, prevState) # TD learning update for M (outer product of TD error and previous state)
                M += alpha * Delta - 2 * alpha * self.TDreg * M # second term is L2 regularisation (weight decay), which keeps the weights from growing too large and destabilising learning
            
            if i == 0: 
                Del = Delta # store the TD update for the non-theta case, to be returned for monitoring learning progress (e.g. in plots of update size over time)

        return np.linalg.norm(Del) # return size of TD update for non-theta case, as a measure of learning progress

### **def STDPLearningStep**
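Performs one STDP learning step. It samples Poisson spikes from the basis cells' firing rates, both without and with theta modulation, and updates the corresponding weight matrices (W_notheta and W) with a trace-based STDP rule for every pair of cells. It returns the theta-modulated firing rates.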
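The trace-based rule can be seen in isolation with a minimal sketch, assuming two cells, hand-picked spike times in seconds and illustrative parameter values (eta, a_stdp, tau_plus and tau_minus below are assumptions, not the script's defaults). The hypothetical function stdp_spikes applies the same per-spike operations as the spike loop in STDPLearningStep, with W[post, pre] indexing: a spike from cell 0 followed 10 ms later by a spike from cell 1 should potentiate W[1, 0] by eta·exp(-0.01/tau_plus) and, because a_stdp is negative, depress the reverse weight W[0, 1].

```python
import numpy as np

def stdp_spikes(spike_list, n_cells, eta=0.01, a_stdp=-0.4, tau_plus=0.02, tau_minus=0.02):
    """Apply trace-based STDP to a time-ordered list of (cell, time) spikes.

    W[post, pre] is the weight from cell `pre` onto cell `post`.
    """
    W = np.zeros((n_cells, n_cells))
    pre_trace = np.zeros(n_cells)
    post_trace = np.zeros(n_cells)
    last_spike_time = 0.0
    for cell, time in spike_list:
        time_diff = time - last_spike_time
        # every trace decays exponentially over the time since the previous spike
        pre_trace *= np.exp(-time_diff / tau_plus)
        post_trace *= np.exp(-time_diff / tau_minus)
        W[cell, :] += eta * pre_trace   # cell as post: potentiate inputs that fired before it
        W[:, cell] += eta * post_trace  # cell as pre: depress outputs to cells that fired before it (a_stdp < 0)
        post_trace[cell] += a_stdp
        pre_trace[cell] += 1.0
        last_spike_time = time
    return W

# cell 0 fires at t = 0 s, cell 1 fires 10 ms later
W = stdp_spikes([(0, 0.0), (1, 0.01)], n_cells=2)
print(W[1, 0] > 0, W[0, 1] < 0)  # True True
```

Updating all traces lazily, only when a spike arrives, is what keeps the loop cheap: nothing needs to be computed between spikes.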

In [None]:
def STDPLearningStep(self,dt):       
        """Takes the current theta phase and estimates firing rates for all basis cells according to a simple theta sweep model. 
           From here it samples spikes and performs STDP learning on a weight matrix.

        Args:
            dt (float): Time step length 

        Returns:
            float array: vector of theta-modulated firing rates for this time step 
        """   
        state = self.posToState(self.pos)

        data = ( (state,
                self.W_notheta,  
                self.preTrace_notheta,  
                self.postTrace_notheta,  
                self.lastSpikeTime_notheta, 
                self.spikeCount_notheta),

                 (self.thetaModulation(state),
                self.W,          
                self.preTrace,          
                self.postTrace,          
                self.lastSpikeTime, 
                self.spikeCount), 
                ) # data for both theta and no theta cases, to be looped through for learning updates

        
        for i, (firingRate, W, preTrace, postTrace, lastSpikeTime, spikeCount) in enumerate(data): 
            firingRate_ = self.peakFiringRate * firingRate + self.baselineFiringRate # scale to the peak firing rate and add the baseline firing rate
            n_spike_list = np.random.poisson(firingRate_*dt)
            
            spikingNeurons = (n_spike_list != 0) #in short time dt cells can spike 0 or 1 time only (good enough approximation) 
            spikeCount += sum(spikingNeurons) # update spike count for monitoring purposes
            spikeTimes = np.random.uniform(self.t,self.t+dt,self.nCells)[spikingNeurons] # sample spike times for cells that spiked, uniformly across the time step
            spikeIDs = np.arange(self.nCells)[spikingNeurons] # get the IDs of the cells that spiked (i.e. their index in the firing rate vector)
            spikeList = np.vstack((spikeIDs,spikeTimes)).T # combine spike times and IDs into a list for looping through
            spikeList = spikeList[np.argsort(spikeList[:,1])] # order the spike list by spike time, so that STDP updates are performed in the correct order (pre-before-post potentiation, post-before-pre depression)

            for spikeInfo in spikeList: # loop through each spike in order of spike time and perform STDP updates
                cell, time = int(spikeInfo[0]), spikeInfo[1] 
                timeDiff = time - lastSpikeTime 

                preTrace        *= np.exp(- timeDiff / self.tau_STDP_plus) #traces for all cells decay...
                postTrace       *= np.exp(- timeDiff / self.tau_STDP_minus) #traces for all cells decay...
                W[cell,:]       += self.eta * preTrace #weights onto this cell (as postsynaptic neuron) increase: pre-before-post potentiation
                W[:,cell]       += self.eta * postTrace #weights from this cell (as presynaptic neuron) decrease, as the post trace is presumably negative: post-before-pre depression
                postTrace[cell] += self.a_STDP  #update post trace (a_STDP is presumably negative)
                preTrace[cell]  += 1 #update pre trace 

                lastSpikeTime += timeDiff

            if i == 1: 
                thetaFiringRate = firingRate_

        return thetaFiringRate

### **STDPLearningStep_detailed**
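Detailed version of the STDP learning step. Compared with STDPLearningStep it:
- models two layers: presynaptic (CA3) cells driven by the basis-cell state, and postsynaptic (CA1) cells whose rates are the presynaptic rates passed through a mapping matrix set by online_mapping ("identity", "Widentity", "W", or a user-supplied matrix) and rectified at zero
- potentiates weights when a postsynaptic cell spikes (pre-before-post) and depresses them when a presynaptic cell spikes (post-before-pre)
- optionally row-normalises W and W_notheta when rownorm is True, so that no row sums to more than 1
- stores the CA3 and CA1 spike times and IDs in self.spikedata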
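The choice of postsynaptic rates can be sketched on its own. The function post_firing_rates below is a hypothetical helper, not part of the script; it mirrors the online_mapping branch of STDPLearningStep_detailed ("identity", "Widentity", "W", or a custom matrix) followed by rectification at zero. It checks isinstance(..., str) first so that a user-supplied NumPy matrix is never compared with a string, which in recent NumPy versions returns an array rather than a single bool. The example rates and weights are made up.

```python
import numpy as np

def post_firing_rates(pre_rates, W, online_mapping="identity"):
    """Map presynaptic (CA3) rates to postsynaptic (CA1) rates, rectified at zero (hypothetical helper)."""
    n = len(pre_rates)
    if not isinstance(online_mapping, str):
        map_matrix = np.asarray(online_mapping, float)  # user-supplied (n, n) matrix
    elif online_mapping == "identity":
        map_matrix = np.identity(n)
    elif online_mapping == "Widentity":
        map_matrix = W + 0.5 * np.identity(n)
    elif online_mapping == "W":
        map_matrix = W
    else:
        raise ValueError(f"unknown online_mapping: {online_mapping!r}")
    return np.maximum(0, map_matrix @ pre_rates)  # firing rates cannot be negative

pre = np.array([10.0, 0.0, 5.0])
W = np.array([[0.0, 0.0, 1.0],
              [1.0, 0.0, 0.0],
              [0.0, 0.0, -2.0]])
print(post_firing_rates(pre, W, "W"))          # rates 5, 10 and 0 (the negative drive is rectified)
print(post_firing_rates(pre, W, "Widentity"))  # rates 10, 10 and 0
```

With "identity" each CA1 cell simply copies its CA3 partner, while "W" and "Widentity" let the learned weights shape the postsynaptic activity.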

In [None]:
def STDPLearningStep_detailed(self,dt):       
        """Takes the current theta phase and estimates firing rates for all basis cells according to a simple theta sweep model. 
           From here it samples spikes and performs STDP learning on a weight matrix.

        Args:
            dt (float): Time step length 

        Returns:
            float array: vector of theta-modulated firing rates for this time step (presynaptic rates followed by postsynaptic rates)
        """   
        state = self.posToState(self.pos)

        data = ( (state,
                self.W_notheta,  
                self.preTrace_notheta,  
                self.postTrace_notheta,  
                self.lastSpikeTime_notheta, 
                self.spikeCount_notheta),

                 (self.thetaModulation(state),
                self.W,          
                self.preTrace,          
                self.postTrace,          
                self.lastSpikeTime, 
                self.spikeCount), 
                )

        
        for i, (firingRate, W, preTrace, postTrace, lastSpikeTime, spikeCount) in enumerate(data): 
            preFiringRate_ = self.peakFiringRate * firingRate + self.baselineFiringRate # scale to the peak firing rate and add the baseline firing rate
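            if not isinstance(self.online_mapping, str): # user-supplied mapping matrix; checked first because comparing a NumPy array with a string does not return a single bool
                mapMatrix = self.online_mapping
            elif self.online_mapping == "identity": 
                mapMatrix =  np.identity(self.nCells)
            elif self.online_mapping == "Widentity": 
                mapMatrix =  W + 0.5*np.identity(self.nCells)
            elif self.online_mapping == "W":
                mapMatrix =  W
            else: 
                raise ValueError(f"Unknown online_mapping: {self.online_mapping}")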

            postFiringRate_ = np.maximum(0,np.matmul(mapMatrix,preFiringRate_))
            firingRate_ = np.concatenate((preFiringRate_,postFiringRate_))
            layerLabel_ = np.array(['pre']*len(preFiringRate_) + ['post']*len(postFiringRate_))
            neuronIDs = np.concatenate((np.arange(len(preFiringRate_)), np.arange(len(postFiringRate_))))
            n_spike_list = np.random.poisson(firingRate_*dt)
            
            spikingNeurons = (n_spike_list != 0) #in short time dt cells can spike 0 or 1 time only (good enough approximation) 
            spikeCount += sum(spikingNeurons)
            spikeTimes = np.random.uniform(self.t,self.t+dt,len(neuronIDs))[spikingNeurons]
            spikeIDs = neuronIDs[spikingNeurons]
            spikeLayerLabels = layerLabel_[spikingNeurons]
            spikeList = np.vstack((spikeIDs,spikeTimes,spikeLayerLabels)).T
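            spikeList = spikeList[np.argsort(spikeList[:,1].astype(float))] # vstack with the string layer labels turns every column into strings, so convert the times back to float before sorting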

            for spikeInfo in spikeList:
                cell, time, layer = int(spikeInfo[0]), float(spikeInfo[1]), spikeInfo[2]
                timeDiff = time - lastSpikeTime 

                preTrace        *= np.exp(- timeDiff / self.tau_STDP_plus) #traces for all cells decay...
                postTrace       *= np.exp(- timeDiff / self.tau_STDP_minus) #traces for all cells decay...
                if layer == 'pre':
                    W[:,cell]       += self.eta * postTrace #weights from presynaptic neuron should decrease when pre fires (post-before-PRE) 
                    preTrace[cell]  += 1 #update trace 
                if layer == 'post':
                    W[cell,:]       += self.eta * preTrace #weights to postsynaptic neuron should increase when post fires (pre-before-POST)
                    postTrace[cell] += self.a_STDP  #update post trace (a_STDP is presumably negative, so this trace drives depression)

                lastSpikeTime += timeDiff

            if i == 1: 
                thetaFiringRate = firingRate_
        
        if self.rownorm == True: 
            # self.W = self.W / np.linalg.norm(self.W,axis=1)[:,np.newaxis]
            # self.W_notheta = self.W_notheta / np.linalg.norm(self.W_notheta,axis=1)[:,np.newaxis]
            sumW = np.sum(self.W,axis=1)
            sumW[sumW<1]=1
            self.W = self.W / sumW[:,np.newaxis]
            sumWnt = np.sum(self.W_notheta,axis=1)
            sumWnt[sumWnt<1]=1
            self.W_notheta = self.W_notheta / sumWnt[:,np.newaxis]
        
        #save spike data
        CA3spiketimes = spikeTimes[spikeLayerLabels=='pre']
        CA3spikeids = spikeIDs[spikeLayerLabels=='pre']
        CA1spiketimes = spikeTimes[spikeLayerLabels=='post']
        CA1spikeids = spikeIDs[spikeLayerLabels=='post']
        self.spikedata['CA3']['times'].extend(CA3spiketimes)
        self.spikedata['CA3']['ids'].extend(CA3spikeids)
        self.spikedata['CA1']['times'].extend(CA1spiketimes)
        self.spikedata['CA1']['ids'].extend(CA1spikeids)

        return thetaFiringRate

### **thetaModulation**
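Takes a position-dependent firing rate (or state) vector and modulates it according to theta phase precession, i.e. it determines when within each theta cycle each cell fires. It is applied to the state vector inside TDLearningStep, STDPLearningStep and STDPLearningStep_detailed. Each cell fires mostly around a preferred theta phase, which is π when the agent is at the cell centre, later while the centre is still ahead of the agent and earlier once the agent has passed it, so the preferred phase precesses as the agent crosses the place field.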
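A minimal standalone sketch of the same modulation, assuming a single cell centred exactly on the agent and illustrative values for kappa, precess_fraction and the field width (all assumptions; theta_modulation here is a hypothetical function, not the class method). Sweeping the theta phase over one full cycle shows the two properties of the 2π-scaled von Mises factor: the rate averaged over the cycle equals the unmodulated rate, and the peak falls at phase π for a cell whose centre coincides with the agent's position.

```python
import numpy as np
from scipy.stats import vonmises

def theta_modulation(firing_rate, vectors_to_centres, direction, sigmas, theta_phase,
                     precess_fraction=0.5, kappa=1.0):
    """Scale firing rates by a von Mises function of theta phase with linear phase precession."""
    unit_dir = direction / np.linalg.norm(direction)
    # signed distance to each field centre along the heading, in multiples of the field width
    sigmas_to_midline = (vectors_to_centres @ unit_dir) / sigmas
    # pi at the centre; later phases while the centre is ahead, earlier once it is behind
    preferred_phase = np.pi + sigmas_to_midline * precess_fraction * np.pi
    phase_diff = preferred_phase - theta_phase
    # 2*pi*pdf averages to 1 over a cycle, so the cycle-averaged rate is unchanged
    return firing_rate * vonmises.pdf(phase_diff, kappa=kappa) * 2 * np.pi

rate = np.array([2.0])
vec = np.zeros((1, 2))           # one cell centred exactly on the agent
heading = np.array([1.0, 0.0])
sigma = np.array([0.1])
phases = np.linspace(0, 2 * np.pi, 1000, endpoint=False)
rates = np.array([theta_modulation(rate, vec, heading, sigma, p)[0] for p in phases])
print(rates.mean(), phases[rates.argmax()])  # approximately 2.0 and pi
```

Larger kappa narrows the window of the cycle in which the cell fires, while the cycle-averaged rate stays the same.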

In [None]:
def thetaModulation(self, firingRate, position=None, direction=None):
        """Takes a firing rate vector and modulates it to account for theta phase precession

        Args:
            firingRate (np.array): The raw (position dependent) firing rate vector to be modulated 
            position (np.array(2,), optional): The agent position. Defaults to None.
            direction (np.array(2,), optional): The agent direction. Defaults to None.
        """        
        if position is None:
            position = self.pos # if no position provided, use current position of the agent
        if direction is None:
            direction = self.dir # if no direction provided, use current direction of the agent

        vectorToCells = self.vectorsToCellCentres(position) # vector from agent to each cell centre
        sigmasToCellMidline = (np.dot(vectorToCells,direction) / np.linalg.norm(direction))  / self.sigmas  # signed distance to each cell centre along the direction of movement, as a multiple of the cell width (sigma)
        preferedThetaPhase = np.pi + sigmasToCellMidline * self.precessFraction * np.pi # preferred theta phase for each cell according to a simple linear phase precession model, where the cell is at its preferred firing phase (pi) when the agent is at the cell centre, and precesses linearly to earlier phases as the agent traverses the place field 

        phaseDiff = preferedThetaPhase - self.thetaPhase # difference between current theta phase and preferred theta phase for each cell
        modulatedFiringRate = firingRate * vonmises.pdf(phaseDiff,kappa=self.kappa) * 2*np.pi # modulate firing rate with a von Mises function of the phase difference; the 2*pi factor makes the modulation average to 1 over a theta cycle, so firing is redistributed across the cycle while the cycle-averaged rate is unchanged (the peak rate is higher than the unmodulated rate)

        return modulatedFiringRate 

### **def movementPolicyUpdate**
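- advances time by dt and proposes a step along the current direction at the current speed
- checks the proposed step for wall collisions with checkWallIntercepts()
- updates position, direction and (for some policies) speed according to movementPolicy: 'randomWalk', 'trueRandomWalk', 'leftRightRandomWalk', 'raudies', 'windowsScreensaver' or '1DOrnUhl'
- applies maze-specific rules: biased door crossing in 'twoRooms', wrap-around in 'loop', the left/right choice in 'TMaze'
- pushes the agent back 2 cm inside the maze extent if it escapes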
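The helpers turn() and wallBounceOrFollow() used by movementPolicyUpdate are not shown in this section. The sketch below gives plausible stand-ins, under the assumption that turn rotates a direction vector anticlockwise by a given angle and that a 'bounce' reflects the direction about the wall's direction; the names turn, bounce and random_walk_step are hypothetical and may differ from the repository's implementations. random_walk_step shows one obstacle-free step of the 'randomWalk' policy.

```python
import numpy as np

def turn(direction, turn_angle):
    """Rotate a 2D direction vector anticlockwise by turn_angle radians (hypothetical stand-in)."""
    c, s = np.cos(turn_angle), np.sin(turn_angle)
    return np.array([[c, -s], [s, c]]) @ np.asarray(direction, float)

def bounce(direction, wall):
    """Reflect a direction off a wall [[x1, y1], [x2, y2]] (hypothetical stand-in for a 'bounce')."""
    w = np.asarray(wall[1], float) - np.asarray(wall[0], float)
    w_hat = w / np.linalg.norm(w)
    # keep the component along the wall, reverse the component perpendicular to it
    return 2 * np.dot(direction, w_hat) * w_hat - np.asarray(direction, float)

def random_walk_step(pos, direction, speed, dt, rot_speed_scale, rng):
    """One obstacle-free 'randomWalk' step: move along direction, then turn by a Gaussian angle."""
    new_pos = pos + speed * direction * dt
    turn_speed = rng.normal(0, rot_speed_scale)  # zero mean: no systematic turning bias
    return new_pos, turn(direction, turn_speed * dt)

rng = np.random.default_rng(0)
pos, direction = np.array([0.5, 0.5]), np.array([1.0, 0.0])
for _ in range(100):
    pos, direction = random_walk_step(pos, direction, speed=0.1, dt=0.01, rot_speed_scale=np.pi, rng=rng)
print(np.linalg.norm(direction))  # rotations keep the direction a unit vector (up to rounding)
```

Because the turn angle is a turning speed multiplied by dt, the amount of turning per second does not depend on the simulation step size.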

In [None]:
def movementPolicyUpdate(self):
        """Movement policy update. 
            In principle this does a very simple thing: 
            • updates time by dt, 
            • updates position along the velocity direction 
            • updates velocity (speed and direction) according to a movement policy
            In reality it's a complex function as the policy requires checking for immediate or upcoming collisions with all walls at each step.
            This is done by function self.checkWallIntercepts()
            What it does with this info (bounce off wall, turn to follow wall, etc.) depends on policy. 
        """

        dt = self.dt # time step for movement update
        self.t += dt # update time by dt
        proposedNewPos = self.pos + self.speed * self.dir * dt # proposed new position if it moves along current direction at current speed for one time step
        proposedStep = np.array([self.pos,proposedNewPos]) # proposed step as a line segment from current position to proposed new position, used for checking wall intercepts
        
        # Only relevant for 2-room maze!
        if (self.biasDoorCross == True) and (self.mazeType == 'twoRooms'): 
            #if the agent steps into the door zone, a door passage starts with 50% probability
            #while doorPassage is True, the 'raudies' policy biases random turns towards the door until the agent crosses it or the passage times out
            doorRegionSize = 1
            if self.doorPassage == False:
                #if step cross into door region
                if (np.linalg.norm(self.pos - np.array([self.roomSize,self.roomSize/2])) > doorRegionSize) and (np.linalg.norm(proposedNewPos - np.array([self.roomSize,self.roomSize/2])) < doorRegionSize) and (abs(self.pos[0] - self.roomSize) > 0.01):
                    if 100*np.random.uniform(0,1) < 50: #start a doorPassage
                        self.doorPassage = True
                        self.doorPassageTime = self.t
                        return
                    else: #ignore this 
                        pass
            if self.doorPassage == True: 
                if ((self.pos[0]<(self.roomSize)) != (proposedNewPos[0]<(self.roomSize))) or ((self.t - self.doorPassageTime)*self.speedScale > 2*doorRegionSize):
                    self.doorPassage = False
                    if ((self.pos[0]<(self.roomSize)) != (proposedNewPos[0]<(self.roomSize))): 
                        print("crossed",self.t)
                    if ((self.t - self.doorPassageTime)*self.speedScale > 2*doorRegionSize):
                        print("time")
        
        # Check whether the proposed step collides with a wall now, soon, or not at all
        checkResult = self.checkWallIntercepts(proposedStep) # returns a tuple (status, wall): status is 'collisionNow', 'collisionAhead' or 'noImmediateCollision', and wall is the wall involved (if any)
        if self.movementPolicy == 'randomWalk': # what to do if the policy is "randomWalk"
            if checkResult[0] != 'collisionNow': # in case of no collision, update position and direction according to random walk policy (move to proposed new position and turn by a small random angle)
                self.pos = proposedNewPos
                randomTurnSpeed = np.random.normal(0,self.rotSpeedScale) #scaled by Gaussian with mean 0 so that on average no bias in turn direction, but with some variability controlled by rotSpeedScale
                self.dir = turn(self.dir,turnAngle=randomTurnSpeed*dt)
            elif checkResult[0] == 'collisionNow': # in case of collision, update direction according to bounce policy (bounce off wall according to angle of incidence)
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
        
        if self.movementPolicy == 'trueRandomWalk': # compared to randomWalk, this one has fully random direction changes at each step
            if checkResult[0] != 'collisionNow': 
                self.pos = proposedNewPos
                self.dir = turn(self.dir,turnAngle=np.random.uniform(0,2*np.pi)) # scaled by uniform distribution so that turn angle is fully random between 0 and 360 degrees
            elif checkResult[0] == 'collisionNow':
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
        
        if self.movementPolicy == 'leftRightRandomWalk': # compared to randomWalk, the direction change at each step is chosen between 2 options only, keep going straight or turn around, with equal probability
            if checkResult[0] != 'collisionNow': 
                self.pos = proposedNewPos
                self.dir = turn(self.dir,turnAngle=np.random.choice([0,np.pi])) # turn angle is randomly chosen to be either 0 (keep going straight) or pi (turn around), with equal probability
            elif checkResult[0] == 'collisionNow':
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
        
        if self.movementPolicy == 'raudies': # rat-like random motion: speed is drawn from a Rayleigh distribution each step, turning speed is resampled every 0.1 s, the agent bounces off walls it hits and, if biasWallFollow is True, turns to follow walls that are just ahead
            if checkResult[0] == 'collisionNow':
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce') # in case of collision now bounce off wall according to angle of incidence 
            elif ((checkResult[0] == 'collisionAhead') and (self.biasWallFollow==True)):
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'follow') # if collision is ahead (i.e. proposed step would collide with wall but current position is not yet colliding), and biasWallFollow is True, then turn to follow the wall rather than bouncing off it
            elif (checkResult[0] == 'noImmediateCollision') or (((checkResult[0] == 'collisionAhead') and (self.biasWallFollow==False))):
                self.pos = proposedNewPos # if no immediate collision, or if collision is ahead but biasWallFollow is False, then move to proposed new position and turn by a small random angle

            
            self.speed = np.random.rayleigh(self.speedScale)
            if self.t - self.lastTurnUpdate >= 0.1: #turn speed is resampled at fixed 0.1 s intervals, independent of dt; otherwise many small random turns would average out whereas a few large ones do not
                randTurnMean = 0
                if self.doorPassage == True: 
                        d_theta  = theta(self.dir) - theta(np.array([self.roomSize,self.roomSize/2]) - self.pos)
                        if d_theta > 0: randTurnMean = -self.rotSpeedScale
                        else: randTurnMean = self.rotSpeedScale
                self.randomTurnSpeed = np.random.normal(randTurnMean,self.rotSpeedScale)
                self.lastTurnUpdate = self.t
            self.dir = turn(self.dir, turnAngle=self.randomTurnSpeed*dt)

        if self.movementPolicy == 'windowsScreensaver': # moves in straight lines with no random turns and bounces off walls according to the angle of incidence, like the Windows screensaver
            if checkResult[0] != 'collisionNow': 
                self.pos = proposedNewPos # has no turns basically just bounces off walls according to angle of incidence 
            elif checkResult[0] == 'collisionNow':
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
        
        if self.movementPolicy == '1DOrnUhl': # no random turns (direction only changes when bouncing off walls); speed follows an Ornstein-Uhlenbeck process (coherence_time=5) clipped at zero, so it varies over time and the agent can come to a stop
            if checkResult[0] != 'collisionNow': 
                self.pos = proposedNewPos
            elif checkResult[0] == 'collisionNow':
                wall = checkResult[1]
                self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
            self.speed += ornstein_uhlenbeck(dt=dt, x=self.speed, drift=self.speedScale,noise_scale=self.speedScale, coherence_time=5)
            self.speed = max(0,self.speed)

        
        if self.mazeType == 'loop':
            self.pos[0] = self.pos[0] % self.roomSize # if it's a loop maze, then if it goes off one end it comes back on the other end, so we take the modulus of the position with the room size to achieve this looping effect
        
        # T-maze specific movement
        if self.mazeType == 'TMaze': # after passing the junction the agent must turn up (+y, probability 0.66) or down (-y, probability 0.34); if it leaves the arms it is reset to the start of the stem
            if (self.pos[0] > self.roomSize+0.05) and (self.LRDecisionPending==True):
                if np.random.choice([0,1],p=[0.66,0.34]) == 0:
                    self.dir = np.array([0.,1.])
                else:
                    self.dir = np.array([0.,-1.])
                self.LRDecisionPending=False
            if self.pos[1] > self.extent[3] or self.pos[1] < self.extent[2]:
                self.pos = np.array([0.,1.]) # float array, so later in-place updates of pos are not truncated to integers
                self.dir = np.array([1.,0.])
                self.LRDecisionPending=True


        #catch-all for the rare case that the rat escapes the maze by accident: clamp it back to 2 cm inside the maze extent
        if ((self.pos[0] < self.extent[0]) or 
            (self.pos[0] > self.extent[1]) or 
            (self.pos[1] < self.extent[2]) or 
            (self.pos[1] > self.extent[3])):
            print(self.pos)
            self.pos[0] = max(self.pos[0],self.extent[0]+0.02)
            self.pos[0] = min(self.pos[0],self.extent[1]-0.02)
            self.pos[1] = max(self.pos[1],self.extent[2]+0.02)
            self.pos[1] = min(self.pos[1],self.extent[3]-0.02)
            print("Rat escaped!")
            if self.mazeType == 'TMaze':
                self.dir=np.array([1,0])
                self.LRDecisionPending = True
            # plotter = Visualiser(self)
            # plotter.plotTrajectory(starttime=(self.t/60)-0.2, endtime=self.t/60)

### **def vectorsToCellCentres**
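Calculates the shortest displacement vector from a given position to every place-field centre. In a 'loop' maze with the doors open, the shortest path can wrap around the end of the track, so the function compares the direct vector with those from the position shifted by ±roomSize and keeps the shortest. Walls are not taken into account here. In effect it tells each cell:
“How far away am I from you, and in which direction?”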
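A vectorised sketch of the loop case, assuming the track wraps around in x with period room_size (the function name vectors_to_centres_loop and the example values are hypothetical). Indexing with vectors[np.arange(n), shortest] picks each cell's shortest image directly, which gives the same result as the np.diagonal(vectors[:, shortest, :], axis1=0, axis2=1).T construction in the code cell but without building an (nCells, nCells, 2) intermediate array.

```python
import numpy as np

def vectors_to_centres_loop(pos, centres, room_size):
    """Shortest vectors from pos to each centre on a track that wraps around in x (hypothetical helper)."""
    shifts = np.array([[0.0, 0.0], [room_size, 0.0], [-room_size, 0.0]])
    images = pos[np.newaxis, :] + shifts                             # (3, 2): pos and its two wrapped copies
    vectors = centres[:, np.newaxis, :] - images[np.newaxis, :, :]   # (nCells, 3, 2)
    shortest = np.argmin(np.linalg.norm(vectors, axis=-1), axis=1)   # (nCells,) index of the nearest image
    return vectors[np.arange(len(centres)), shortest]                # (nCells, 2)

centres = np.array([[0.5, 0.5], [5.0, 0.5]])
pos = np.array([9.5, 0.5])
print(vectors_to_centres_loop(pos, centres, room_size=10.0))
# first cell is 1.0 ahead via the wrap-around, second is 4.5 behind
```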

In [None]:
def vectorsToCellCentres(self,pos,distance=False):
        """Takes a position vector shape (2,) and returns an array of shape (nCells,2) of the 
        shortest vector path to all cells, taking into account loop geometry etc. 

        Args:
            pos (array): position vector shape (2,)
            distance (bool, optional): unused

        Returns:
            shortest_vectors (array): shape (nCells,2)
        """        
        if self.mazeType == 'loop' and self.doorsClosed == False: # if it's a loop maze with doors open, then the shortest path to a cell might be either the direct path or the path that goes around the loop, so we calculate both and take the shorter one
            pos_plus = pos + np.array([self.roomSize,0]) 
            pos_minus = pos - np.array([self.roomSize,0])
            positions = np.array([pos,pos_plus,pos_minus])
            vectors = self.centres[:,np.newaxis,:] - positions[np.newaxis,:,:]
            shortest = np.argmin(np.linalg.norm(vectors,axis=-1),axis=1)
            shortest_vectors = np.diagonal(vectors[:,shortest,:],axis1=0,axis2=1).T

        else: # in all other cases the shortest path to each cell is just the direct path, so we can just take the vector from the position to the cell centre
            shortest_vectors = self.centres - pos # use the pos argument (not self.pos) so the function works for any position
            
        return shortest_vectors 

### **distanceToCellCentres**
Calculates the shortest distance from a given position to each cell centre. In all mazes except the two-room maze, this is simply the Euclidean length of the vectors returned by vectorsToCellCentres. In the two-room maze the function first checks whether the cell centre and the position are in the same room, in which case it uses the Euclidean distance. If they are in different rooms and the doors are closed, it returns a large distance (100*roomSize); if the doors are open and the straight line between them hits a wall, it returns the length of the shorter detour via one of the two ends of the doorway.
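A simplified sketch of the two-room walking distance, assuming a single vertical dividing wall at x = wall_x with an open doorway between door_y1 and door_y2 (two_room_distance and its parameter names are hypothetical). Instead of calling checkWallIntercepts against all maze walls, it tests whether the straight line crosses x = wall_x inside the doorway; if it does not, the agent detours via the nearer end of the doorway, as in the two-rooms branch of distanceToCellCentres.

```python
import numpy as np

def two_room_distance(pos, centre, wall_x, door_y1, door_y2):
    """Walking distance from pos to centre across a vertical dividing wall with an open doorway.

    Hypothetical helper: the wall is the line x = wall_x and the doorway spans door_y1..door_y2.
    """
    pos, centre = np.asarray(pos, float), np.asarray(centre, float)
    straight = np.linalg.norm(centre - pos)
    if (pos[0] < wall_x) == (centre[0] < wall_x):
        return straight  # same room: Euclidean distance
    # y-coordinate at which the straight line from pos to centre crosses the wall
    frac = (wall_x - pos[0]) / (centre[0] - pos[0])
    y_cross = pos[1] + frac * (centre[1] - pos[1])
    if min(door_y1, door_y2) <= y_cross <= max(door_y1, door_y2):
        return straight  # the straight line passes through the doorway
    # otherwise walk to one end of the doorway and on to the centre, whichever is shorter
    ends = [np.array([wall_x, door_y1]), np.array([wall_x, door_y2])]
    return min(np.linalg.norm(end - pos) + np.linalg.norm(centre - end) for end in ends)

print(two_room_distance((0.5, 0.5), (1.5, 0.5), wall_x=1.0, door_y1=0.4, door_y2=0.6))  # 1.0, straight through the door
print(two_room_distance((0.5, 0.1), (1.5, 0.1), wall_x=1.0, door_y1=0.4, door_y2=0.6))  # 2*sqrt(0.34), about 1.166
```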

In [None]:
def distanceToCellCentres(self, pos):
        """Calculates distance to cell centres. 
           In the case of the two room maze, this distance is the shortest feasible walk carefully accounting for doorways etc. 
        Args:
            pos (np.array): The position to calculate the distances from 

        Returns:
            np.array: (nCells,) array of distances
        """         

        if self.mazeType == 'twoRooms': 
            distances = np.zeros(self.nCells)
            wall_x = self.walls['doors'][0][0][0]
            wall_y1, wall_y2 = self.walls['doors'][0][0][1], self.walls['doors'][0][1][1]
            for i in range(self.nCells):
                vec = np.array(pos - self.centres[i])
                if ((self.centres[i][0] < wall_x) and (pos[0] < wall_x)) or ((self.centres[i][0] > wall_x) and (pos[0] > wall_x)):
                    distances[i] = np.linalg.norm(vec)
                else: #cell and position in different rooms 
                    if self.doorsClosed == True:
                        distances[i] = 100*self.roomSize 
                        print("doorsClosed")
                    else:
                        step = np.array([pos,self.centres[i]])
                        if self.checkWallIntercepts(step)[0] == 'collisionNow':
                            pastBottomWall = np.linalg.norm(np.array([wall_x,wall_y1]) - pos) + np.linalg.norm(np.array([wall_x,wall_y1]) - self.centres[i])
                            pastTopWall = np.linalg.norm(np.array([wall_x,wall_y2]) - pos) + np.linalg.norm(np.array([wall_x,wall_y2]) - self.centres[i])
                            distances[i] = min(pastBottomWall,pastTopWall)
                        else: 
                            distances[i] = np.linalg.norm(vec)

        else: 
            shortest_vector = self.vectorsToCellCentres(pos)
            distances = np.linalg.norm(shortest_vector,axis=1)

        return distances

### **toggleDoors**
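Opens or closes the doors (or flips their current state if no argument is given) and updates mazeState['walls'], which contains the 'doors' walls only while the doors are closed. It then recomputes discreteStates for the new maze layout.
*Note: Could be interesting to look at in case of a DNMTP task*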
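The wall bookkeeping is a plain dictionary operation. A minimal sketch, assuming walls is a dictionary whose values are lists of [[x1, y1], [x2, y2]] segments with the door segments stored under 'doors' (the function name active_walls, the 'room' key and all coordinates are made up): a shallow copy is taken so that removing 'doors' from the active layout leaves the master wall dictionary untouched.

```python
def active_walls(walls, doors_closed):
    """Return the wall dictionary used for collisions: 'doors' is dropped when the doors are open."""
    active = walls.copy()          # shallow copy: the master dictionary keeps its 'doors' entry
    if not doors_closed:
        active.pop('doors', None)  # pop with a default, so a layout without doors does not raise KeyError
    return active

walls = {
    'room': [[[0, 0], [2, 0]], [[2, 0], [2, 1]], [[2, 1], [0, 1]], [[0, 1], [0, 0]]],
    'doors': [[[1, 0.4], [1, 0.6]]],
}
print(sorted(active_walls(walls, doors_closed=False)))  # ['room']
print(sorted(active_walls(walls, doors_closed=True)))   # ['doors', 'room']
```

Using pop with a default instead of del is a design choice of this sketch; the original can use del because its copy always starts from the full wall dictionary.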

In [None]:
def toggleDoors(self, doorsClosed = None): #this function could be made more advanced to toggle more maze options
        """Opens or closes door and updates mazeState
            mazeState stores the most recent version of the maze walls dictionary which will include 'door' wall only if doorsClosed is True
        Args:
            doorsClosed ([bool], optional): True if doors to be closed, False if doors to be opened. Defaults to None, in which case current door state is flipped.
        Returns:
            [dict]: the walls dictionary
        """        
        if doorsClosed is not None: 
            self.doorsClosed = doorsClosed
        else: self.doorsClosed = not self.doorsClosed

        walls = self.walls.copy()
        if self.doorsClosed == False: 
            del walls['doors']
            self.mazeState['walls'] = walls
        elif self.doorsClosed == True: 
            self.mazeState['walls'] = walls

        self.discreteStates = self.positionArray_to_stateArray(self.discreteCoords,stateType=self.stateType) #an array of discretised position coords over entire map extent 

        return self.mazeState['walls']

### **checkWallIntercepts**
Checks, for a proposed step, whether it would collide with a wall on this step, whether a wall lies on the current trajectory within collisionDistance (0.1 by default), or whether no collision is coming up at all.
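The intersection test can be sketched independently. Writing the step as s1 + lam_s·(s2 - s1) and the wall as w1 + lam_w·(w2 - w1), the crossing point lies on the wall when 0 ≤ lam_w ≤ 1; it lies on the step itself when 0 ≤ lam_s ≤ 1, and ahead of it when lam_s > 1. The sketch below is a simplified, hypothetical classify_step that assumes perp rotates a vector by 90 degrees and ranks the outcomes as the docstring does (collision now, then collision ahead within collision_distance, then no immediate collision); it is not the script's exact control flow.

```python
import numpy as np

def perp(v):
    """Rotate a 2D vector by 90 degrees (assumed behaviour of the script's perp helper)."""
    return np.array([-v[1], v[0]])

def classify_step(step, walls, collision_distance=0.1):
    """Return ('collisionNow' | 'collisionAhead' | 'noImmediateCollision', wall or None)."""
    s1, s2 = np.asarray(step[0], float), np.asarray(step[1], float)
    ds = s2 - s1
    step_length = np.linalg.norm(ds)
    now, ahead = [], []
    for wall in walls:
        w1, w2 = np.asarray(wall[0], float), np.asarray(wall[1], float)
        dw = w2 - w1
        # solve s1 + lam_s*ds = w1 + lam_w*dw; parallel lines give inf/nan, which fail the tests below
        with np.errstate(divide='ignore', invalid='ignore'):
            lam_s = np.dot(w1 - s1, perp(dw)) / np.dot(ds, perp(dw))
            lam_w = np.dot(s1 - w1, perp(ds)) / np.dot(dw, perp(ds))
        if 0 <= lam_s <= 1 and 0 <= lam_w <= 1:
            now.append(wall)
        elif lam_s > 1 and 0 <= lam_w <= 1 and lam_s * step_length <= collision_distance:
            ahead.append(wall)
    if now:
        return 'collisionNow', now[0]
    if ahead:
        return 'collisionAhead', ahead[0]
    return 'noImmediateCollision', None

wall = [[0.08, -1.0], [0.08, 1.0]]
print(classify_step([[0, 0], [0.05, 0]], [wall])[0])  # collisionAhead (wall 0.08 away, step only 0.05 long)
print(classify_step([[0, 0], [0.10, 0]], [wall])[0])  # collisionNow
```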

In [None]:
def checkWallIntercepts(self,proposedStep,collisionDistance=0.1): #proposedStep = [pos,proposedNextPos]
        """Given the current proposed step [currentPos, nextPos] it calculates whether a collision with any of the walls exists along this step.
        There are three possibilities from most worrying to least:
            • there is a collision ON the current step. Do something immediately.
            • there is a collision along the current trajectory within collisionDistance, but not on the current step. Consider doing something.
            • there is no collision coming up soon. Carry on as you are. 
        Args:
            proposedStep (array): The proposed step. np.array( [ [x_current, y_current] , [x_next, y_next] ] )

        Returns:
            tuple: (str, array), (<whether there is no collision, collision now or collision ahead> , <the wall in question>)
        """        
        s1, s2 = np.array(proposedStep[0]), np.array(proposedStep[1])
        pos = s1
        ds = s2 - s1 # displacement vector from current to next position
        stepLength = np.linalg.norm(ds) # length of proposed step
        ds_perp = perp(ds) # perpendicular vector to proposed step, used for calculating intercepts with walls

        collisionList = [[],[]] # list of walls that collide with current step 
        futureCollisionList = [[],[]] # list of walls that collide with current trajectory but not current step 
        #check if the current step results in a collision 
        walls = self.mazeState['walls'] #current wall state

        for wallObject in walls.keys():
            for wall in walls[wallObject]:
                w1, w2 = np.array(wall[0]), np.array(wall[1])
                dw = w2 - w1
                dw_perp = perp(dw)

                # calculates point of intercept between the line passing along the current step direction and the lines passing along the walls,
                # if this intercept lies on the current step and on the current wall (0 < lam_s < 1, 0 < lam_w < 1) this implies a "collision" 
                # if it lies ahead of the current step and on the current wall (lam_s > 1, 0 < lam_w < 1) then we should "veer" away from this wall
                # this occurs iff the solution to s1 + lam_s*(s2-s1) = w1 + lam_w*(w2 - w1) satisfies 0 <= lam_s & lam_w <= 1
                with np.errstate(divide='ignore'):
                    lam_s = (np.dot(w1, dw_perp) - np.dot(s1, dw_perp)) / (np.dot(ds, dw_perp))
                    lam_w = (np.dot(s1, ds_perp) - np.dot(w1, ds_perp)) / (np.dot(dw, ds_perp))

                #there are two situations we need to worry about: 
                # • 0 < lam_s < 1 and 0 < lam_w < 1: the collision is ON the current proposed step . Do something immediately.
                # • lam_s > 1     and 0 < lam_w < 1: the collision is on the current trajectory, some time in the future. Maybe do something. 
                if (0 <= lam_s <= 1) and (0 <= lam_w <= 1):
                    collisionList[0].append(wall)
                    collisionList[1].append([lam_s,lam_w])
                    continue

                if (lam_s > 1) and (0 <= lam_w <= 1):
                    if lam_s * stepLength <= collisionDistance: #if the future collision is under collisionDistance away
                        futureCollisionList[0].append(wall)
                        futureCollisionList[1].append([lam_s,lam_w])
                        continue
        
        if len(collisionList[0]) != 0: # if there is a collision on the current step, return this wall and the fact that it's a current collision (i.e. we need to do something about it immediately)
            wall_id = np.argmin(np.array(collisionList[1])[:,0]) #first wall you collide with on step 
            wall = collisionList[0][wall_id]
            return ('collisionNow', wall)
        
        elif len(futureCollisionList[0]) != 0: # if there is a collision ahead on the current trajectory, return this wall and the fact that it's an upcoming collision (i.e. we might want to do something about it, depending on policy and whether biasWallFollow is True or not)
            wall_id = np.argmin(np.array(futureCollisionList[1])[:,0]) #first wall you would collide with along current step 
            wall = futureCollisionList[0][wall_id]
            return ('collisionAhead', wall)
        
        else:
            return ('noImmediateCollision',None) # if there are no collisions on current step or ahead on current trajectory, return None and the fact that there is no immediate collision (i.e. we can carry on as we are)
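The intercept test in `checkWallIntercepts` comes from writing the step and the wall as parametric lines, s1 + lam_s*(s2 - s1) = w1 + lam_w*(w2 - w1), and dotting both sides with the perpendicular of one direction to eliminate the other unknown. The sketch below isolates that calculation for a single wall. `intercept_params` and `classify` are hypothetical helper names, not functions from the repository; the thresholds mirror the method above.

```python
import numpy as np

def perp(a):
    # rotate a 2-vector by +90 degrees: (x, y) -> (-y, x)
    return np.array([-a[1], a[0]])

def intercept_params(s1, s2, w1, w2):
    """Solve s1 + lam_s*(s2 - s1) = w1 + lam_w*(w2 - w1) for (lam_s, lam_w)."""
    ds, dw = s2 - s1, w2 - w1
    with np.errstate(divide="ignore", invalid="ignore"):
        # dotting with perp(dw) removes the lam_w term; dotting with perp(ds) removes the lam_s term
        lam_s = np.dot(w1 - s1, perp(dw)) / np.dot(ds, perp(dw))
        lam_w = np.dot(s1 - w1, perp(ds)) / np.dot(dw, perp(ds))
    return lam_s, lam_w

def classify(s1, s2, w1, w2, collisionDistance=0.1):
    lam_s, lam_w = intercept_params(s1, s2, w1, w2)
    stepLength = np.linalg.norm(s2 - s1)
    if 0 <= lam_s <= 1 and 0 <= lam_w <= 1:
        return "collisionNow"    # intercept lies on the step and on the wall
    if lam_s > 1 and 0 <= lam_w <= 1 and lam_s * stepLength <= collisionDistance:
        return "collisionAhead"  # intercept lies beyond the step but within collisionDistance
    return "noImmediateCollision"

wall = np.array([[1.0, 0.0], [1.0, 1.0]])  # vertical wall at x = 1
print(classify(np.array([0.5, 0.5]), np.array([1.2, 0.5]), *wall))    # step crosses the wall
print(classify(np.array([0.92, 0.5]), np.array([0.95, 0.5]), *wall))  # wall is 8 cm ahead
```

When the step is parallel to the wall both denominators are zero; the resulting inf/NaN values fail every comparison, so the step falls through to 'noImmediateCollision', as in the method.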

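### **getPlaceFields**
Calculates the place fields of all place cells. Each place field is a linear combination of the basis feature fields weighted by one row of the SR matrix M, so these are the fields of the successor features rather than of the basis features. The field is then thresholded: 90% of the field's own maximum (or the fraction given by `threshold`) is subtracted and negative values are set to zero.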

In [None]:
def getPlaceFields(self, M=None, threshold=None):
        """Calculates receptive fields of all place cells 
            There is one place cell for each feature cell. 
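            A place cell (as in de Cothi 2020) is defined as a thresholded linear combination of feature cells
            where the linear combination is a row of the SR matrix. 
        Args:
            M (array, optional): SR matrix. Defaults to self.M
            threshold (float, optional): fraction of each field's maximum that is subtracted before rectifying. Defaults to 0.9
        Returns:
            array: Receptive fields of shape [nCells, nX, nY]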
        """        
        if M is None: 
            M = self.M
        M = M.copy()
        #normalise: 
        # M = M / np.diag(M)[:,np.newaxis]
        if threshold is None: 
            placeCellThreshold =  0.9  #place cell threshold value (fraction of its maximum)
        else: 
            placeCellThreshold =  threshold
        placeFields = np.einsum("ij,klj->ikl",M,self.discreteStates)
        threshold = placeCellThreshold*np.amax(placeFields,axis=(1,2))[:,None,None]
        # threshold = placeCellThreshold
        placeFields = np.maximum(0,placeFields - threshold)
        return placeFields
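A minimal sketch of the same einsum-and-threshold step on made-up data, to show what the index string `"ij,klj->ikl"` does: field `i` at grid point `(k, l)` is the sum over feature cells `j` of `M[i, j] * discreteStates[k, l, j]`. The 10 x 10 grid, the random feature activity and the small SR-like `M` are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
nCells, nX, nY = 4, 10, 10
discreteStates = rng.random((nX, nY, nCells))   # hypothetical feature activity on a 10 x 10 grid
M = np.eye(nCells) + 0.5 * np.eye(nCells, k=1)  # hypothetical SR-like matrix

# raw[i, k, l] = sum_j M[i, j] * discreteStates[k, l, j]
raw = np.einsum("ij,klj->ikl", M, discreteStates)
threshold = 0.9 * raw.max(axis=(1, 2))[:, None, None]  # 90% of each cell's own maximum
placeFields = np.maximum(0, raw - threshold)           # subtract, then rectify
print(placeFields.shape, (placeFields > 0).mean())
```

Because the threshold is relative to each field's own peak, every cell keeps a small non-zero patch around its maximum regardless of its overall firing scale.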

### **getGridFields**
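Computes grid-cell firing fields from the eigenvectors of the SR matrix M: each eigenvector is used as a set of weights over the feature cells, and the resulting map is rectified (negative values set to zero; `gridCellThreshold` is 0, so no further threshold is applied).
Since M is nCells x nCells it has nCells eigenvectors, so there are as many grid cells as feature cells. `np.linalg.eig` returns the eigenvectors in no guaranteed order and each one is only defined up to its sign, which is why `alignToFinal` flips signs to match the final grid fields when making animations.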

In [None]:
def getGridFields(self, M, alignToFinal=False):
        """Calculates receptive fields of all grid cells 
            There is an equal number of grid cells as place cells and feature cells. 
            A grid cell (as in de Cothi 2020) is defined as a thresholded linear combination of feature cells
            where the linear combination weights are the eigenvectors of the SR matrix. 
        Args:
            M (array): SR matrix
            alignToFinal (bool): since the negative of an eigenvector is also an eigenvector, flip signs to maximise overlap with the final grid fields (for making animations)
        Returns:
            array: Receptive fields of shape [nCells, nX, nY]
        """
        M = M.copy()
        _, eigvecs = np.linalg.eig(M) #"v[:,i] is the eigenvector corresponding to the eigenvalue w[i]"
        eigvecs = np.real(eigvecs)
        gridCellThreshold = 0 
        gridFields = np.einsum("ij,kli->jkl",eigvecs,self.discreteStates)
        threshold = gridCellThreshold*np.amax(gridFields,axis=(1,2))[:,None,None]
        if alignToFinal == True:
            # an eigenvector is only defined up to sign: flip each field so it overlaps positively with the final grid fields
            grids_final_flat = np.reshape(self.gridFields,(self.stateSize,-1))
            grids_flat = np.reshape(gridFields,(self.stateSize,-1))
            # row-wise dot products, same as np.diag(grids_final_flat @ grids_flat.T); computed once rather than inside a loop
            dotprodsigns = np.sign(np.sum(grids_final_flat*grids_flat,axis=1))
            gridFields *= dotprodsigns[:,None,None]
        gridFields = np.maximum(0,gridFields)
        return gridFields
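To see why eigenvectors of an SR look grid-like, here is a sketch on an assumed toy environment: an unbiased random walk on a ring of 40 states with discount `gamma = 0.9`, for which M = (I - gamma T)^-1. This is not the repository's maze; it also swaps `np.linalg.eig` for `np.linalg.eigh`, which is valid here only because this particular M is symmetric.

```python
import numpy as np

n, gamma = 40, 0.9
# Hypothetical transition matrix: unbiased random walk on a ring of n states
T = np.zeros((n, n))
for i in range(n):
    T[i, (i - 1) % n] = 0.5
    T[i, (i + 1) % n] = 0.5
M = np.linalg.inv(np.eye(n) - gamma * T)  # SR: sum_t gamma^t T^t = (I - gamma T)^-1

# M is symmetric here, so eigh gives real eigenpairs; the method itself uses the general np.linalg.eig
eigvals, eigvecs = np.linalg.eigh(M)
order = np.argsort(eigvals)[::-1]  # sort by decreasing eigenvalue
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

# Each eigenvector is a Fourier mode of the ring; rectifying it gives periodic bumps
gridFields = np.maximum(0, eigvecs.T)  # one "grid cell" per eigenvector, shape (n, n)
```

The top eigenvector is flat (eigenvalue 1/(1 - gamma)), the next ones oscillate once, twice, and so on around the ring, which is the periodic structure the rectified fields inherit.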

### **posToState**
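Returns the firing rate of every feature cell at a single 2D position `pos`.
- computes the distance from `pos` to each cell centre; for the `loop` maze with open doors it also uses the distances to copies of `pos` shifted by ± the track length, so fields wrap around, and for `twoRooms` it uses `distanceToCellCentres`
- converts the distance(s) into activity according to `stateType` (`onehot`, `gaussianThreshold`, `gaussian` or `bump`)

If the states are already initialised and `firingRateLookUp` is not `False`, the rate is instead read from the precomputed table `discreteStates` at the nearest discretised position, which is much faster than recomputing it.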

In [None]:
def posToState(self, pos, stateType=None, normalise=True, cheapNormalise=False,initialisingCells=False): #pos is a single 2D position, np.array([x, y])
        
        if (self.statesAlreadyInitialised == False) or (self.firingRateLookUp == False):
            #calculates the firing rate of all cells 
            pos = np.array(pos)
            if stateType == None: stateType = self.stateType
        
            vector_to_cells = self.centres - pos
            distance_to_cells = [np.linalg.norm(vector_to_cells,axis=1)]
            closest_cell_ID = np.argmin(distance_to_cells)

            if (self.mazeType == 'loop') and (self.doorsClosed == False):
                distance_to_cells.append(np.linalg.norm(self.centres - pos + [self.extent[1],0],axis=1))
                distance_to_cells.append(np.linalg.norm(self.centres - pos - [self.extent[1],0],axis=1))
            
            if (self.mazeType == 'twoRooms'): 
                distance_to_cells = [self.distanceToCellCentres(pos)]

            if stateType == 'onehot':
                state = np.zeros(self.nCells)
                state[closest_cell_ID] = 1
            
            if stateType == 'gaussianThreshold':
                state = np.zeros(self.nCells)
                for distance in distance_to_cells: 
                    state += np.maximum(np.exp(-distance**2 / (2*(self.sigmas**2))) - np.exp(-1/2) , 0) / (1-np.exp(-1/2))
                    # state = state / (self.sigmas) #normalises so same no. spikes emitted for all cell sizes

            if stateType == 'gaussian':
                state = np.zeros(self.nCells)
                for distance in distance_to_cells: 
                    state += np.exp(-distance**2 / (2*(self.sigmas**2)))
                state = state/self.sigmas # normalise once, after summing over all (periodic) distances

            if stateType == 'bump':
                state = np.zeros(self.nCells)
                for distance in distance_to_cells:
                    state[distance<self.sigmas] += np.e * np.exp(-1/(1-(distance/self.sigmas)**2))[distance<self.sigmas]
                state = state/self.sigmas # normalise once, after summing over all (periodic) distances
        
        else:
            #uses look up table to rapidly pull out the firing rates without having to recalculate them 

            closestQuantisedArea = np.unravel_index(np.argmin(np.linalg.norm(self.discreteCoords - pos,axis=-1)),self.discreteCoords.shape[:-1])
            state = self.discreteStates[closestQuantisedArea]


        return state
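A small sketch contrasting the direct 'gaussian' computation with the nearest-grid-point lookup used when `firingRateLookUp` is on. The five cell centres, `sigmas = 0.1` and the 51 x 51 grid are assumptions chosen for illustration; `gaussian_state` and `lookup_state` are hypothetical stand-ins for the two branches of `posToState`.

```python
import numpy as np

# Hypothetical setup: 5 feature cells across a 1 m square, all with width sigma = 0.1 m
centres = np.array([[0.1, 0.5], [0.3, 0.5], [0.5, 0.5], [0.7, 0.5], [0.9, 0.5]])
sigmas = np.full(len(centres), 0.1)

def gaussian_state(pos):
    """Firing rate of every cell at one 2D position (mirrors the 'gaussian' branch)."""
    distance = np.linalg.norm(centres - pos, axis=1)
    return np.exp(-distance**2 / (2 * sigmas**2)) / sigmas

# Precompute the lookup table on a 51 x 51 grid of positions
xs = np.linspace(0, 1, 51)
discreteCoords = np.stack(np.meshgrid(xs, xs, indexing="ij"), axis=-1)  # shape (51, 51, 2)
discreteStates = np.zeros(discreteCoords.shape[:-1] + (len(centres),))
for idx in np.ndindex(discreteCoords.shape[:-1]):
    discreteStates[idx] = gaussian_state(discreteCoords[idx])

def lookup_state(pos):
    """Nearest-grid-point lookup, as in the firingRateLookUp branch."""
    flat_id = np.argmin(np.linalg.norm(discreteCoords - pos, axis=-1))
    return discreteStates[np.unravel_index(flat_id, discreteCoords.shape[:-1])]

pos = np.array([0.523, 0.488])
exact, approx = gaussian_state(pos), lookup_state(pos)
print(np.argmax(exact), np.abs(exact - approx).max())
```

The lookup trades a small discretisation error (set by the grid spacing) for skipping the distance calculation on every time step.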

### **positionArray_to_stateArray**
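Applies posToState to every position in an array of positions (e.g. every point of the discretised maze) and returns the feature-cell activity at each one, i.e. the neural feature representation. An input of shape (n1, n2, ..., 2) gives an output of shape (n1, n2, ..., nCells).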

In [None]:
def positionArray_to_stateArray(self, positionArray, stateType=None,verbose=False): 
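        """Takes an array of 2D positions of shape (n1, n2, n3, ..., 2) and
        returns the state vector for each of these positions, shape (n1, n2, n3, ..., N), where N = nCells is the size of the state vector
        Args:
            positionArray (np.ndarray): array of 2D positions, shape (n1, n2, n3, ..., 2)
            stateType (str, optional): feature type passed on to posToState. Defaults to None (uses self.stateType).
            verbose (bool, optional): show a tqdm progress bar. Defaults to False.
        Returns:
            np.ndarray: state vectors, shape (n1, n2, n3, ..., nCells)
        """        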
        if stateType == None: stateType = self.stateType
        states = np.zeros(positionArray.shape[:-1] + (self.nCells,))
        if verbose == False:
            for idx in np.ndindex(positionArray.shape[:-1]):
                states[idx] = self.posToState(pos = positionArray[idx],stateType = stateType)
        else: 
            for idx in tqdm(np.ndindex(positionArray.shape[:-1]),total=int(positionArray.size/2)):
                states[idx] = self.posToState(pos = positionArray[idx],stateType = stateType)

        return states
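The method loops over the leading dimensions with `np.ndindex` because `posToState` handles one position at a time and contains maze-specific branches. For a purely distance-based feature the same result can be obtained with broadcasting; the sketch below checks the two against each other. The centres, `sigma` and helper names are hypothetical.

```python
import numpy as np

centres = np.array([[0.2, 0.2], [0.5, 0.5], [0.8, 0.8]])  # hypothetical cell centres
sigma = 0.1

def pos_to_state(pos):
    # one position (2,) -> one state vector (nCells,)
    return np.exp(-np.sum((centres - pos) ** 2, axis=-1) / (2 * sigma**2))

def position_array_to_state_array(positionArray):
    # loop over every index of the leading dimensions, as the method does
    states = np.zeros(positionArray.shape[:-1] + (len(centres),))
    for idx in np.ndindex(positionArray.shape[:-1]):
        states[idx] = pos_to_state(positionArray[idx])
    return states

def position_array_to_state_array_vectorised(positionArray):
    # broadcast (..., 1, 2) against (nCells, 2) -> (..., nCells, 2), then reduce over the last axis
    diff = positionArray[..., None, :] - centres
    return np.exp(-np.sum(diff ** 2, axis=-1) / (2 * sigma**2))

positions = np.random.default_rng(0).random((4, 5, 2))
loop_states = position_array_to_state_array(positions)
vec_states = position_array_to_state_array_vectorised(positions)
print(loop_states.shape)
```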

### **averageM**
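- Rolls each row i of the SR matrix M by `nCells/2 - i`, so that its diagonal element lands in the middle column, then takes the mean and standard deviation over rows (each normalised by its own maximum).
- The result is the average predictive shape of the SR relative to the current position, independent of where on the track the cell sits. Because `np.roll` wraps around, this is mainly meaningful for periodic tracks such as the `loop` maze.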

In [None]:
def averageM(self, M=None):
        if M is None: # `M == None` raises "truth value of an array is ambiguous" when an array is passed
            M = self.M
        M_copy = M.copy()
        roll = int(self.nCells/2)
        for i in range(self.nCells): # Livi: changed from agent.nCells to self.nCells because agent.nCells raised an error
            M_copy[i,:] = np.roll(M[i,:],-i+roll)
        M_av,M_std = np.mean(M_copy,axis=0),np.std(M_copy,axis=0)
        M_av, M_std = M_av/np.max(M_av), M_std/np.max(M_std)
        return M_av, M_std
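A standalone sketch of the row-alignment trick on an assumed forward-biased SR for a ring of 8 states, M[i, j] = 0.5 ** ((j - i) mod 8), plus a little noise so the standard deviation is non-zero. `average_M` is a hypothetical re-implementation of the method for illustration.

```python
import numpy as np

def average_M(M):
    nCells = M.shape[0]
    roll = nCells // 2  # same as int(nCells/2) for positive nCells
    M_aligned = np.empty_like(M)
    for i in range(nCells):
        # np.roll shifts entries right by (-i + roll): the diagonal entry M[i, i] lands in column `roll`
        M_aligned[i, :] = np.roll(M[i, :], -i + roll)
    M_av, M_std = M_aligned.mean(axis=0), M_aligned.std(axis=0)
    return M_av / M_av.max(), M_std / M_std.max()

# Hypothetical forward-biased SR on a ring of 8 states, plus a little noise
n = 8
offsets = (np.arange(n)[None, :] - np.arange(n)[:, None]) % n
rng = np.random.default_rng(0)
M = 0.5 ** offsets + 0.01 * rng.standard_normal((n, n))
M_av, M_std = average_M(M)
print(np.round(M_av, 2))
```

The averaged profile peaks in the middle column and is larger just after it than just before it, which is the forward (predictive) skew the later metrics quantify.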

### **getMetrics**
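What it does:
- Computes quantitative metrics comparing:
    - STDP weights with theta (W)
    - STDP weights without theta (W_notheta)
    - True TD successor matrix (M)
- W and W_notheta are taken from the snapshot closest to `time` (in seconds here; the last snapshot if `time` is None), whereas M always comes from the final snapshot. All three are row-aligned with `rowAlignMatrix` first.

Specifically computes:
- R² → how similar W (and W_notheta) is to M
- SNR → range of the mean aligned row divided by the mean standard deviation of the 10 columns around the centre, i.e. how clean/structured the learned weights are
- Skewness → forward bias (predictive asymmetry) of the mean aligned row; NaN if `getSkewness` raises a RuntimeError
- Peak location → position of the maximum of the mean aligned row minus 2.5, i.e. how far the prediction is shifted forward (2.5 presumably being the position of the aligned centre, e.g. half of a 5 m track)

Conceptually:
- Measures how well STDP (with vs without theta) approximates the TD successor representation.
- This is essentially the main evaluation function for the thesis.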

In [None]:
def getMetrics(self,time=None):
        if time is not None: 
            hist_id = self.snapshots['t'].sub(time).abs().to_numpy().argmin()
            snapshot = self.snapshots.iloc[hist_id]
        else:
            snapshot = self.snapshots.iloc[-1]

        x = self.centres[:,0].copy()

        W = rowAlignMatrix(snapshot['W'].copy())
        W_notheta = rowAlignMatrix(snapshot['W_notheta'].copy())
        M = rowAlignMatrix(self.snapshots.iloc[-1]['M'].copy())

        mid = int(self.nCells / 2)

        #R2s
        R_W = Rsquared(W,M)              
        R_Wnotheta = Rsquared(W_notheta,M)  

        #SNRs
        SNR_W = (np.max(np.mean(W,axis=0)) - np.min(np.mean(W,axis=0))) / np.mean(np.std(W,axis=0)[mid-5:mid+5])
        SNR_Wnotheta = (np.max(np.mean(W_notheta,axis=0)) - np.min(np.mean(W_notheta,axis=0))) / np.mean(np.std(W_notheta,axis=0)[mid-5:mid+5])

        #skews
        W_flat = np.mean(W,axis=0)/np.trapz(np.mean(W,axis=0),x)
        W_flat /= np.max(W_flat)
        Wnotheta_flat = np.mean(W_notheta,axis=0)/np.trapz(np.mean(W_notheta,axis=0),x)
        Wnotheta_flat /= np.max(Wnotheta_flat)
        M_flat = np.mean(M,axis=0)/np.trapz(np.mean(M,axis=0),x)
        M_flat /= np.max(M_flat)
        try: skew_W = getSkewness(W_flat)
        except RuntimeError: skew_W = np.nan # np.NaN was removed in NumPy 2.0; np.nan works in all versions
        try: skew_Wnotheta = getSkewness(Wnotheta_flat)
        except RuntimeError: skew_Wnotheta = np.nan
        try: skew_M = getSkewness(M_flat)
        except RuntimeError: skew_M = np.nan

        #peaks 
        peak_W = x[np.argmax(W_flat)] - 2.5
        peak_Wnotheta = x[np.argmax(Wnotheta_flat)] - 2.5
        peak_M = x[np.argmax(M_flat)] - 2.5
        return R_W, R_Wnotheta, SNR_W, SNR_Wnotheta, float(skew_W), float(skew_Wnotheta), float(skew_M), peak_W, peak_Wnotheta, peak_M
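The SNR and peak-offset formulas can be tried on synthetic data. The sketch below assumes a 5 m track with 50 cells (so the centre column sits at x = 2.5) and a made-up row-aligned weight matrix whose rows peak three cells ahead of the centre. It skips the `np.trapz` area normalisation from the method, since dividing by a positive area does not move the argmax; `R²` and skewness are left out because `Rsquared` and `getSkewness` are helpers not shown here.

```python
import numpy as np

rng = np.random.default_rng(1)
nCells = 50
x = np.linspace(0, 5, nCells, endpoint=False)  # hypothetical cell centres on a 5 m track, x[mid] = 2.5
mid = int(nCells / 2)

# Hypothetical row-aligned weights: each row peaks 3 cells ahead of the centre column, plus noise
profile = np.exp(-(np.arange(nCells) - (mid + 3)) ** 2 / (2 * 4.0 ** 2))
W = profile[None, :] + 0.05 * rng.standard_normal((nCells, nCells))

mean_row = W.mean(axis=0)
# SNR as in getMetrics: range of the mean row / mean std over the 10 central columns
SNR = (mean_row.max() - mean_row.min()) / W.std(axis=0)[mid - 5:mid + 5].mean()
# Peak offset: location of the maximum of the normalised mean row relative to x = 2.5
W_flat = mean_row / mean_row.max()
peak = x[np.argmax(W_flat)] - 2.5
print(round(SNR, 1), round(peak, 2))
```

A positive peak offset means the averaged weights point ahead of the current position, i.e. they are predictive.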


### **saveToFile** & **loadFromFile**
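-> saveToFile stores all of the agent's attributes (`self.__dict__`) in a single `.npz` file, i.e. a snapshot of the whole simulation. The dictionary is pickled into the archive under the key `arr_0`.

-> loadFromFile reverses this: it loads that dictionary (with `allow_pickle=True`) and sets every entry back as an attribute on the agent, reconstructing the saved simulation state. Because unpickling can execute arbitrary code, only load files you trust.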

In [None]:
def saveToFile(self,name,directory="../savedObjects/"):
        np.savez(directory+name+".npz",self.__dict__)
        return

def loadFromFile(self,name,directory="../savedObjects/"):
    attributes_dictionary = np.load(directory+name+".npz",allow_pickle=True)['arr_0'].item()
    print("Loading attributes...",end="")
    for key, value in attributes_dictionary.items():
        setattr(self, key, value)
    print("done. use 'agent.__dict__.keys()'  to see available attributes")
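A minimal round trip of the same save/load pattern, using a hypothetical `Dummy` class in place of the agent and a temporary directory instead of `../savedObjects/`. Passing the dictionary positionally to `np.savez` stores it as a 0-d object array named `arr_0`, which `.item()` turns back into a dict.

```python
import os
import tempfile
import numpy as np

class Dummy:  # hypothetical stand-in for the agent object
    pass

agent = Dummy()
agent.nCells = 3
agent.M = np.eye(3)

restored = Dummy()
with tempfile.TemporaryDirectory() as directory:
    path = os.path.join(directory, "demo.npz")
    np.savez(path, agent.__dict__)  # dict stored as a pickled 0-d object array under 'arr_0'
    with np.load(path, allow_pickle=True) as data:
        attributes = data["arr_0"].item()
    for key, value in attributes.items():
        setattr(restored, key, value)
print(sorted(restored.__dict__.keys()))
```

Using `os.path.join` avoids relying on `directory` ending with a slash, which the string concatenation in the methods does.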


## Non-Class Functions
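- getWalls
    - stores and returns a dictionary of wall segments for each mazeType; each entry is an array of walls, each wall given as [[x_start, y_start], [x_end, y_end]] and scaled by roomSize
- wallBounceOrFollow
    - returns a new direction given the current direction of travel, a wall and an instruction: 'bounce' reflects the direction elastically off the wall, 'follow' returns the direction parallel to the wall that is closest to the current heading. This is what the agent does on a collision, and how e.g. a wall-following policy is implemented
- turn
    - turns the current direction by a given angle (in radians) and returns the new unit direction
- perp
    - returns the vector *b* perpendicular to an input 2-vector *a*, rotated by +90 degrees: b = (-a_y, a_x)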
- theta
    - This function computes the direction angle (heading) of a vector or movement segment, expressed between 0 and 2π. It accepts either:
        - A direction vector → shape (2,)
            → Returns the angle of that vector.
        - A start and end position → shape (2,2)
            → Computes the direction from start → end, then returns its angle.

In [None]:
def getWalls(mazeType, roomSize=1):
    """Stores and returns dictionaries containing all the walls of a maze
    Args:
        mazeType (str): Name of the maze 
        roomSize (int, optional): scaling parameter for roomsize. Defaults to 1 metre.
    Returns:
        dict: wall dictionary
    """    
    walls = {}
    rs = roomSize
    if mazeType == 'oneRoom':
        walls['room1'] = np.array([
                                [[0,0],[0,rs]],
                                [[0,rs],[rs,rs]],
                                [[rs,rs],[rs,0]],
                                [[rs,0],[0,0]]])
    elif mazeType == 'twoRooms':
        walls['room1'] = np.array([
                                [[0,0],[0,rs]],
                                [[0,rs],[rs,rs]],
                                [[rs,rs],[rs,0.6*rs]],
                                [[rs,0.4*rs],[rs,0]],
                                [[rs,0],[0,0]]])
        walls['room2'] = np.array([
                                [[rs,0],[rs,0.4*rs]],
                                [[rs,0.6*rs],[rs,rs]],
                                [[rs,rs],[2*rs,rs]],
                                [[2*rs,rs],[2*rs,0]],
                                [[2*rs,0],[rs,0]]])
        walls['doors'] = np.array([[[rs,0.4*rs],[rs,0.6*rs]]])
    elif mazeType == 'fourRooms':
        walls['room1'] = np.array([
                                [[0,0],[0,rs]],
                                [[0,rs],[0.4*rs,rs]],
                                [[0.6*rs,rs],[rs,rs]],
                                [[rs,rs],[rs,0.6*rs]],
                                [[rs,0.4*rs],[rs,0]],
                                [[rs,0],[0,0]]])
        walls['room2'] = np.array([
                                [[rs,0],[rs,0.4*rs]],
                                [[rs,0.6*rs],[rs,rs]],
                                [[rs,rs],[1.4*rs,rs]],
                                [[1.6*rs,rs],[2*rs,rs]],
                                [[2*rs,rs],[2*rs,0]],
                                [[2*rs,0],[rs,0]]])
        walls['room3'] = np.array([
                                [[0,rs],[0.4*rs,rs]],
                                [[0.6*rs,rs],[rs,rs]],
                                [[rs,rs],[rs,1.4*rs]],
                                [[rs,1.6*rs],[rs,2*rs]],
                                [[rs,2*rs],[0,2*rs]],
                                [[0,2*rs],[0,rs]]])
        walls['room4'] = np.array([
                                [[rs,rs],[1.4*rs,rs]],
                                [[1.6*rs,rs],[2*rs,rs]],
                                [[2*rs,rs],[2*rs,2*rs]],
                                [[2*rs,2*rs],[rs,2*rs]],
                                [[rs,2*rs],[rs,1.6*rs]],
                                [[rs,1.4*rs],[rs,rs]]])
        walls['doors'] = np.array([[[rs,0.4*rs],[rs,0.6*rs]],
                                        [[0.4*rs,rs],[0.6*rs,rs]],
                                        [[rs,1.4*rs],[rs,1.6*rs]],
                                        [[1.4*rs,rs],[1.6*rs,rs]]])
    elif mazeType == 'twoRoomPassage':
        walls['room1'] = np.array([
                                [[0,0],[rs,0]],
                                [[rs,0],[rs,rs]],
                                [[rs,rs],[0.75*rs,rs]],
                                [[0.25*rs,rs],[0,rs]],
                                [[0,rs],[0,0]]])
        walls['room2'] = np.array([
                                [[rs,0],[2*rs,0]],
                                [[2*rs,0],[2*rs,rs]],
                                [[2*rs,rs],[1.75*rs,rs]],
                                [[1.25*rs,rs],[rs,rs]],
                                [[rs,rs],[rs,0]]])
        walls['room3'] = np.array([
                                [[0,rs],[0,1.4*rs]],
                                [[0,1.4*rs],[2*rs,1.4*rs]],
                                [[2*rs,1.4*rs],[2*rs,rs]]])
        walls['doors'] = np.array([[[0.25*rs,rs],[0.75*rs,rs]],
                                [[1.25*rs,rs],[1.75*rs,rs]]])
    elif mazeType == 'longCorridor':
        walls['room1'] = np.array([
                                [[0,0],[0,rs]],
                                [[0,rs],[rs,rs]],
                                [[rs,rs],[rs,0]],
                                [[rs,0],[0,0]]])
        walls['longbarrier'] = np.array([
                                [[0.1*rs,0],[0.1*rs,0.9*rs]],
                                [[0.2*rs,rs],[0.2*rs,0.1*rs]],
                                [[0.3*rs,0],[0.3*rs,0.9*rs]],
                                [[0.4*rs,rs],[0.4*rs,0.1*rs]],
                                [[0.5*rs,0],[0.5*rs,0.9*rs]],
                                [[0.6*rs,rs],[0.6*rs,0.1*rs]],
                                [[0.7*rs,0],[0.7*rs,0.9*rs]],
                                [[0.8*rs,rs],[0.8*rs,0.1*rs]],
                                [[0.9*rs,0],[0.9*rs,0.9*rs]]])
    elif mazeType == 'rectangleRoom':
        ratio = np.pi/2.8
        walls['room1'] = np.array([
                                [[0,0],[0,rs]],
                                [[0,rs],[ratio*rs,rs]],
                                [[ratio*rs,rs],[ratio*rs,0]],
                                [[ratio*rs,0],[0,0]]])
    elif mazeType == 'loop':
        height = 0.2
        walls['room'] = np.array([
                                [[0,0],[rs,0]],
                                [[0,height],[rs,height]]])
        walls['doors'] = np.array([
                                [[0,0],[0,height]],
                                [[rs,0],[rs,height]]])
    
    elif mazeType == 'TMaze':
        walls['corridors'] = np.array([
                                [[0,0.05+rs],[rs,0.05+rs]],
                                [[0,-0.05+rs],[rs,-0.05+rs]],
                                [[rs,0.05+rs],[rs,rs+rs]],
                                [[rs+0.1,rs+rs],[rs+0.1,-rs+rs]],
                                [[rs,-0.05+rs],[rs,-rs+rs]]]) 
        walls['doors'] = np.array([[[rs,-0.05+rs-0.5],[rs+0.1,-0.05+rs-0.5]]])                               

    return walls

#MOVEMENT FUNCTIONS
def wallBounceOrFollow(currentDirection,wall,whatToDo='bounce'):
    """Given current direction, and wall and an instruction returns a new direction which is the result of implementing that instruction on the current direction
        wrt the wall. e.g. 'bounce' returns direction after elastic bounce off wall. 'follow' returns direction parallel to wall (closest to current heading)
    Args:
        currentDirection (array): the current direction vector
        wall (array): start and end coordinates of the wall
        whatToDo (str, optional): 'bounce' or 'follow'. Defaults to 'bounce'.
    Returns:
        array: new direction
    """    
    if whatToDo == 'bounce':
        wallPerp = perp(wall[1] - wall[0])
        if np.dot(wallPerp,currentDirection) <= 0:
            wallPerp = -wallPerp #it is now the perpendicular with smallest angle to dir 
        wallPar = wall[1] - wall[0]
        if np.dot(wallPar,currentDirection) <= 0:
            wallPar = -wallPar #it is now the parallel with smallest angle to dir 
        wallPar, wallPerp = wallPar/np.linalg.norm(wallPar), wallPerp/np.linalg.norm(wallPerp) #normalise
        dir_ = wallPar * np.dot(currentDirection,wallPar) - wallPerp * np.dot(currentDirection,wallPerp)
        newDir = dir_/np.linalg.norm(dir_)
    elif whatToDo == 'follow':
        wallPar = wall[1] - wall[0]
        if np.dot(wallPar,currentDirection) <= 0:
            wallPar = -wallPar #it is now the parallel with smallest angle to dir 
        wallPar = wallPar/np.linalg.norm(wallPar)
        dir_ = wallPar * np.dot(currentDirection,wallPar)
        newDir = dir_/np.linalg.norm(dir_)
    return newDir

def turn(currentDirection, turnAngle):
    """Turns the current direction by an amount theta, modulus 2pi
    Args:
        currentDirection (array): current direction 2-vector
        turnAngle (float): angle ot turn in radians
    Returns:
        array: new direction
    """    
    theta_ = theta(currentDirection)
    theta_ += turnAngle
    theta_ = np.mod(theta_, 2*np.pi)
    newDirection = np.array([np.cos(theta_),np.sin(theta_)])
    return newDirection

def perp(a=None):
    """Given 2-vector, a, returns its perpendicular
    Args:
        a (array, optional): 2-vector direction. Defaults to None.
    Returns:
        array: perpendicular to a
    """    
    b = np.empty_like(a)
    b[0] = -a[1]
    b[1] = a[0] 
    return b

def theta(segment):
    """Given a 'segment' (either 2x2 start and end positions or 2x1 direction bearing) 
         returns the 'angle' of this segment modulo 2pi
    Args:
        segment (array): The segment, (2,2) or (2,) array 
    Returns:
        float: angle of segment
    """    
    eps = 1e-6
    if segment.shape == (2,): 
        return np.mod(np.arctan2(segment[1],(segment[0] + eps)),2*np.pi)
    elif segment.shape == (2,2):
        return np.mod(np.arctan2((segment[1][1]-segment[0][1]),(segment[1][0] - segment[0][0] + eps)), 2*np.pi)
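A short sketch of the 'bounce' rule and the heading angle on two simple walls. `bounce` and `theta` here are simplified, hypothetical re-implementations (no `eps` term, no sign flipping) used only to check the geometry.

```python
import numpy as np

def perp(a):
    # rotate a 2-vector by +90 degrees: (x, y) -> (-y, x)
    return np.array([-a[1], a[0]])

def bounce(direction, wall):
    """Elastic reflection of a direction off a wall segment [[x0, y0], [x1, y1]]."""
    wall_par = (wall[1] - wall[0]) / np.linalg.norm(wall[1] - wall[0])
    wall_perp = perp(wall_par)
    # keep the component along the wall, flip the component normal to it
    new_dir = wall_par * np.dot(direction, wall_par) - wall_perp * np.dot(direction, wall_perp)
    return new_dir / np.linalg.norm(new_dir)

def theta(direction):
    # heading angle in [0, 2*pi)
    return np.mod(np.arctan2(direction[1], direction[0]), 2 * np.pi)

right_wall = np.array([[1.0, 0.0], [1.0, 1.0]])
top_wall = np.array([[0.0, 1.0], [1.0, 1.0]])
d1 = bounce(np.array([1.0, 0.0]), right_wall)             # head-on: direction reverses
d2 = bounce(np.array([1.0, 1.0]) / np.sqrt(2), top_wall)  # 45 degrees in, 45 degrees out
print(d1, d2, theta(d1), theta(d2))
```

Each term of the reflection contains the wall vector twice, so flipping the sign of `wallPar` or `wallPerp` (as `wallBounceOrFollow` does) leaves the bounced direction unchanged; those flips only matter for 'follow'.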


## **Visualiser Class**
- plotMazeStructure
    - plots the maze walls as a figure (thick dark-grey lines, equal aspect, axes hidden)
- plotTrajectory
    - plots the agent's positions between `starttime` and `endtime` (in minutes) on top of the maze, as scatter points or a line, subsampled so points are roughly 1.5 cm apart at speed `speedScale`
- plotM
    - plots a matrix (SR or weights) as a heatmap, rescaled so it can be visually compared to a reference matrix (`colormatchto`, by default the TD matrix M).
    - Depending on whichM:
        - 'M' → TD successor matrix
        - 'W' → STDP weights (with theta)
        - 'W_notheta' → STDP weights without theta modulation
        - 'M_theta' → presumably the SR computed under theta-modulated activity, or a theta-filtered version of M (not checked)
    - 'M_theta' and 'W_notheta' are read from the agent's current state, not from the chosen snapshot
- addTimestamp
    - writes the time of a saved history entry onto the axes
- plotPlaceField
    - plots the place field of a single cell (index `number`, chosen at random if not given) over the maze and marks its peak with an x; with `STDP=True` the field is built from the STDP weights W (or W_notheta) instead of the TD matrix M
- plotReceptiveField
    - plots the receptive field of one basis feature cell (a slice of `discreteStates`) as a heatmap over the maze and marks its peak
- plotGridField
    - plots the grid fields computed from the eigenvectors of M (or of W.T when `STDP=True`); `number='many'` plots several at once
- plotFeatureCells
    - presumably plots the basis feature cells (output still to be checked)
- plotHeatMap
    - presumably a heat map of the agent's positions/trajectory (still to be checked)
- animateField
    - animation of either 'place' cells, 'grid' cells or the synaptic weight matrix 'M'
- plotMAveraged
    - plots averaged matrices (exact output still to be checked)
    - apparently only works for the 1D open loop maze
- plotMetrics
    - plots the metrics (SNR, R²) and related figures, apparently
- plotFieldSilhouette
    - plots the silhouette of a neuron's receptive (place) field alongside its predictive fields from TD, STDP with theta and STDP without theta
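Most Visualiser methods pick a snapshot with the idiom `snapshots['t'].sub(time*60).abs().to_numpy().argmin()`, i.e. the row whose time is closest to `time` minutes (the 't' column is in seconds; `getMetrics` uses `sub(time)` without the factor 60, so there `time` appears to be in seconds). A sketch with a hypothetical snapshot table, one row every 30 s:

```python
import numpy as np
import pandas as pd

# Hypothetical snapshot table: one row every 30 s of simulated time
snapshots = pd.DataFrame({"t": np.arange(0, 601, 30.0),
                          "label": [f"snap{i}" for i in range(21)]})

def nearest_snapshot(snapshots, time_minutes):
    # |t - time| for every row, then the position of the smallest difference
    hist_id = snapshots["t"].sub(time_minutes * 60).abs().to_numpy().argmin()
    return snapshots.iloc[hist_id]

snap = nearest_snapshot(snapshots, 2.3)  # 138 s lies closer to 150 s than to 120 s
print(snap["t"], snap["label"])
```

Converting to NumPy before `argmin` returns a positional index, which is what `.iloc` expects, regardless of the DataFrame's own index labels.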

In [None]:
class Visualiser():
    def __init__(self, mazeAgent):
        self.mazeAgent = mazeAgent
        self.snapshots = mazeAgent.snapshots
        self.history = mazeAgent.history

    def plotMazeStructure(self,fig=None,ax=None,hist_id=-1,save=False):
        snapshot = self.snapshots.iloc[hist_id]
        extent, walls = snapshot['mazeState']['extent'], snapshot['mazeState']['walls']
        if (fig, ax) == (None, None): 
            fig, ax = plt.subplots(figsize=(4*(extent[1]-extent[0]),4*(extent[3]-extent[2])))
        for wallObject in walls.keys():
            for wall in walls[wallObject]:
                ax.plot([wall[0][0],wall[1][0]],[wall[0][1],wall[1][1]],color='darkgrey',linewidth=8)
            # ax.set_xlim(left=extent[0]-0.05,right=extent[1]+0.05)
            # ax.set_ylim(bottom=extent[2]-0.05,top=extent[3]+0.05)
        ax.set_aspect('equal')
        ax.grid(False)
        ax.axis('off')
        if save == True: 
            saveFigure(fig, 'mazeStructure')
        return fig, ax
    
    def plotTrajectory(self,fig=None, ax=None, hist_id=-1,starttime=0,endtime=2,plotAsLine=False):
        skiprate = max(1,int(0.015/(self.mazeAgent.speedScale * self.mazeAgent.dt)))
        if (fig, ax) == (None, None):
            fig, ax = self.plotMazeStructure(hist_id=hist_id)
        startid = self.history['t'].sub(starttime*60).abs().to_numpy().argmin()
        endid = self.history['t'].sub(endtime*60).abs().to_numpy().argmin()
        trajectory = np.stack(self.history['pos'][startid:endid])[::skiprate]
        if plotAsLine == False:
            ax.scatter(trajectory[:,0],trajectory[:,1],s=10,alpha=0.7,zorder=2)
        elif plotAsLine == True:
            ax.plot(trajectory[:,0],trajectory[:,1])
        saveFigure(fig, "trajectory")
        return fig, ax

    
    def plotM(self,hist_id=-1, time=None, M=None,fig=None,ax=None,save=True,savename="",title="",show=True,plotTimeStamp=False,colorbar=True,whichM='M',colormatchto='TD_M'):
        if time is not None: 
            hist_id = self.snapshots['t'].sub(time*60).abs().to_numpy().argmin()
        snapshot = self.snapshots.iloc[hist_id]
        if (ax is not None) and (fig is not None): 
            ax.clear()
        else:
            fig, ax = plt.subplots(figsize=(2,2))
        if M is None:
            if whichM == 'M': M = snapshot['M'].copy()
            elif whichM == 'W': M = snapshot['W'].copy()
            elif whichM == 'M_theta': M = self.mazeAgent.M_theta.copy()
            elif whichM == 'W_notheta': M = self.mazeAgent.W_notheta.copy()

        t = int(np.round(snapshot['t']))
        most_positive = np.max(M)
        most_negative = np.min(M)

        if colormatchto == 'TD_M': 
            M_colormatch = self.mazeAgent.M
        elif colormatchto == 'W_onDiag': 
            M_colormatch = self.mazeAgent.W_onDiag
        # elif colormatchto is not None:
        #     M_colormatch = np.load(colormatchto)
        else: 
            M_colormatch = np.array([-1,1])

        # if np.min(M)/np.min(M_colormatch) > np.max(M)/np.max(M_colormatch):
            # M *= np.min(M_colormatch)/np.min(M)
        # else:
        M_ = M.copy()
        # np.fill_diagonal(M_,0)
        non_diag_max = np.max(M_)
        non_diag_kind_max = np.mean(M_[M_>0.9*non_diag_max])
        M *= np.max(M_colormatch)/non_diag_kind_max

        im = ax.imshow(M,cmap='viridis',vmin=np.min(M_colormatch),vmax=np.max(M_colormatch))
        divider = make_axes_locatable(ax)
        try: cax.clear()
        except: 
            pass
        if colorbar == True:
            cax = divider.append_axes("right", size="5%", pad=0.05)
            cb = fig.colorbar(im, cax=cax, ticks=[0])
            cb.outline.set_visible(False)

        ax.set_aspect('equal')
        ax.grid(False)
        ax.axis('off')
        ax.set_title(title)
        if save==True:
            saveFigure(fig, "M"+savename)
        if plotTimeStamp == True: 
            ax.text(100, 5,"%g s"%t, fontsize=5,c='w',horizontalalignment='center',verticalalignment='center')
        if show==False:
            plt.close(fig)
        
        try:
            return fig, ax, cb, cax
        except:
            return fig, ax 

    def addTimestamp(self, fig, ax, i=-1):
        t = self.mazeAgent.saveHist[i]['t']
        ax.text(0, 0, "%.2f" % t) # Axes.text takes the string as its third argument `s`; the keyword `t=` is not accepted

    def plotPlaceField(self, M=None, hist_id=-1, time=None, fig=None, ax=None, number=None, show=True, animationCall=False, plotTimeStamp=False,save=True,STDP=False,threshold=None,fitEllipse_=False,no_theta=False):
        if M is None: 
            #time in minutes
            if time is not None: 
                hist_id = self.snapshots['t'].sub(time*60).abs().to_numpy().argmin()
            #if a figure/ax objects are passed, clear the axis and replot the maze
            if (ax is not None) and (fig is not None): 
                ax.clear()
                self.plotMazeStructure(fig=fig, ax=ax, hist_id=hist_id)
            # else if they are not passed plot the maze
            if (fig, ax) == (None, None):
                fig, ax = self.plotMazeStructure(hist_id=hist_id)
            
            if number == None: number = np.random.randint(0,self.mazeAgent.stateSize-1)
            
            snapshot = self.snapshots.iloc[hist_id]
            M = snapshot['M']
            if STDP==True: 
                if no_theta == True:
                    snapshot = self.snapshots.iloc[hist_id-22]
                    M = snapshot['W_notheta']
                else:
                    M = snapshot['W']
            t = int(np.round(snapshot['t'] / 60))
        else:
            snapshot = self.snapshots.iloc[hist_id]
            fig, ax = self.plotMazeStructure(hist_id=hist_id)
        extent = snapshot['mazeState']['extent']
        placeFields = self.mazeAgent.getPlaceFields(M=M,threshold=threshold)
        ax.imshow(placeFields[number],extent=extent,interpolation=None)

        if fitEllipse_ == True: 
            (X,Y,Z),_,_ = fitEllipse(placeFields[number],coords=self.mazeAgent.discreteCoords)
            ax.contour(X, Y, Z, levels=[1], colors=('w'), linewidths=2, linestyles="dashed")

        
        if self.mazeAgent.mazeType == 'loop':
            x = self.mazeAgent.discreteCoords[10,:,0]

        peakid= np.argmax(placeFields[number])
        peakcoord = self.mazeAgent.discreteCoords.reshape(-1,2)[peakid]
        ax.scatter(peakcoord[0],peakcoord[1],marker='x',s=130,color='darkgrey',linewidth=4,alpha=1)

        if plotTimeStamp == True: 
            ax.text(extent[1]-0.07, extent[3]-0.05,"%g"%t, fontsize=5,c='w',horizontalalignment='center',verticalalignment='center')
        if show==False:
            plt.close(fig)
        if save==True:
            saveFigure(fig, "placeField")
        return fig, ax
    

    def plotReceptiveField(self, number=None, hist_id=-1, fig=None, ax=None, show=True, fitEllipse_=False):
        if (fig, ax) == (None, None):
            fig, ax = self.plotMazeStructure(hist_id=hist_id)
        if number == None: number = np.random.randint(0,self.mazeAgent.nCells-1)
        extent = self.mazeAgent.extent
        rf = self.mazeAgent.discreteStates[..., number]
        ax.imshow(rf,extent=extent,interpolation=None)
        peakid= np.argmax(rf)
        peakcoord = self.mazeAgent.discreteCoords.reshape(-1,2)[peakid]
        ax.scatter(peakcoord[0],peakcoord[1],marker='x',s=130,color='darkgrey',linewidth=4,alpha=1)
        if self.mazeAgent.mazeType == 'loop':
            x = self.mazeAgent.discreteCoords[10,:,0]
        if fitEllipse_ == True: 
            (X,Y,Z),_,_= fitEllipse(rf,coords=self.mazeAgent.discreteCoords)
            ax.contour(X, Y, Z, levels=[1], colors=('w'), linewidths=2, linestyles="dashed")
        if show==False:
            plt.close(fig)
        saveFigure(fig, "receptiveField")
        return fig, ax
    

    def plotGridField(self, hist_id=-1, time=None, fig=None, ax=None, number=0, show=True, animationCall=False, plotTimeStamp=False,save=True,STDP=False):
        if time is not None: 
            hist_id = self.snapshots['t'].sub(time*60).abs().to_numpy().argmin()
        snapshot = self.snapshots.iloc[hist_id]
        M = snapshot['M']
        if STDP==True: 
            M = snapshot['W'].T    
        t = snapshot['t'] / 60
        extent = snapshot['mazeState']['extent']
        if hist_id == -1 and animationCall == False:
            gridFields = self.mazeAgent.gridFields
        else:
            gridFields = self.mazeAgent.getGridFields(M=M,alignToFinal=True)

        def sigmoid(x):
            return np.exp(x) / (np.exp(x) + np.exp(-x))
        if number == 'many': 
            fig = plt.figure(figsize=(10, 10*((extent[3]-extent[2])/(extent[1]-extent[0]))))
            gs = matplotlib.gridspec.GridSpec(6, 6, hspace=0.1, wspace=0.1)
            c=0
            # numberstoplot = np.array([60 + 5*i for i in np.arange(36)])
            numberstoplot = np.concatenate((np.array([0,1,2,3,4,5]),np.geomspace(6,gridFields.shape[0]-1,30).astype(int)))
            for i in range(6):
                for j in range(6):
                    ax = plt.subplot(gs[i,j])
                    # ax.imshow(sigmoid(gridFields[numberstoplot[c]]),extent=extent,interpolation=None)
                    ax.imshow(gridFields[numberstoplot[c]],extent=extent,interpolation=None)
                    ax.grid(False)
                    ax.axis('off')
                    ax.text(extent[1]-0.07, extent[3]-0.05,str(numberstoplot[c]+1),fontsize=5,c='w',horizontalalignment='center',verticalalignment='center')
                    c+=1

        else:
            #if a figure/ax objects are passed, clear the axis and replot the maze
            if (ax is not None) and (fig is not None): 
                ax.clear()
                self.plotMazeStructure(fig=fig, ax=ax, hist_id=hist_id)
            # else if they are not passed plot the maze
            if (fig, ax) == (None, None):
                fig, ax = self.plotMazeStructure(hist_id=hist_id)
            
            if number is None: number = np.random.randint(0,self.mazeAgent.stateSize-1)

            ax.imshow(gridFields[number],extent=extent,interpolation=None)

            if plotTimeStamp == True: 
                ax.text(extent[1]-0.07, extent[3]-0.05,"%g"%t, fontsize=5,c='w',horizontalalignment='center',verticalalignment='center')
            if show==False:
                plt.close(fig)
        
        if save==True:
            saveFigure(fig, "gridField")
        return fig, ax
        
    def plotFeatureCells(self, hist_id=-1,textlabel=True,shufflebeforeplot=False,centresOnly=False,onepink=False,threetypes=False):
        fig, ax = self.plotMazeStructure(hist_id=hist_id)
        centres = self.mazeAgent.centres.copy()
        ids = np.arange(len(centres))
        if shufflebeforeplot==True:
            np.random.shuffle(ids)
        centres = centres[ids]
        for (i, centre) in enumerate(centres):
            # if i%10==0:
                if textlabel==True:
                    ax.text(centre[0],centre[1],str(ids[i]),fontsize=15,horizontalalignment='center',verticalalignment='center')
                if self.mazeAgent.mazeType == 'TMaze':
                    if abs(centre[1]-1)<0.001: color = 'C0'
                    elif centre[1]>1.001: color = 'C1'
                    elif centre[1]<0.999: color = 'C2'
                    else: color = 'C3'
                else:
                    color = 'C'+str(i)
                if centresOnly == True: 
                    alpha=1
                    c='darkgrey'
                    if onepink == True:
                        if i == 30:
                            if self.mazeAgent.mazeType == 'twoRooms':
                                centre = np.array([3,1.1])
                                c = 'C3'
                    if self.mazeAgent.mazeType == 'twoRooms': 
                        s = 400; linewidth = 9
                    else: 
                        s = 130; linewidth = 4
                    if threetypes==True: 
                        alpha=0.7
                        if i%3 == 0:
                            centre -= [0.02,-0.03]
                            c='C0'
                        elif i%3 == 1:
                            c='C1'
                        elif i%3 == 2:
                            centre += [0.02,-0.03]
                            c='C3'

                    ax.scatter(centre[0],centre[1],marker='x',s=s,color=c,linewidth=linewidth,alpha=alpha)
                else:
                    circle = matplotlib.patches.Ellipse((centre[0],centre[1]), 2*self.mazeAgent.sigmas[i], 2*self.mazeAgent.sigmas[i], alpha=0.5, facecolor=color)
                    ax.add_patch(circle)
                
        saveFigure(fig, "basis")
        return fig, ax 
    
    def plotHeatMap(self,smoothing=1):
        posdata = np.stack(self.mazeAgent.history['pos'])
        bins = [int(n/smoothing) for n in list(self.mazeAgent.discreteCoords.shape[:2])]
        bins.reverse()
        hist = np.histogram2d(posdata[:,0],posdata[:,1],bins=bins)[0]
        fig, ax = self.plotMazeStructure(hist_id=-1)
        ax.imshow(hist.T, extent=self.mazeAgent.extent)
        return fig, ax



    def animateField(self, number=0,field='place',interval=100):
        if field == 'place':
            fig, ax = self.plotPlaceField(hist_id=0,number=number,show=False,save=False)
            anim = FuncAnimation(fig, self.plotPlaceField, fargs=(None, fig, ax, number, False, True, True, False), frames=len(self.snapshots), repeat=False, interval=interval)
        elif field == 'grid':
            fig, ax = self.plotGridField(hist_id=0,number=number,show=False,save=False)
            anim = FuncAnimation(fig, self.plotGridField, fargs=(None, fig, ax, number, False, True, True, False), frames=len(self.snapshots), repeat=False, interval=interval)
        elif field == 'M':
            fig, ax = self.plotM(hist_id=0,show=False,save=False,colorbar=False)
            anim = FuncAnimation(fig, self.plotM, fargs=(None, fig, ax, False,"", "Synaptic Weight Matrix \n (STDP Hebbian learning)", False,False,False,True), frames=len(self.snapshots), repeat=False, interval=interval)
        
        today = datetime.strftime(datetime.now(),'%y%m%d')
        now = datetime.strftime(datetime.now(),'%H%M')
        saveFigure(anim,saveTitle=field+"Animation",anim=True)
        return anim


    
    def plotMAveraged(self,time=None, x_ticks=None, plot_no_theta=True, color='C1',ylim=None, renorm=True):
        # only works/defined for open loop maze
        if time is not None: 
            hist_id = self.snapshots['t'].sub(time*60).abs().to_numpy().argmin()
            snapshot = self.snapshots.iloc[hist_id]
        else:
            snapshot = self.snapshots.iloc[-1]

        M = snapshot['M'].copy()
        W = snapshot['W'].copy() 
        W_notheta = snapshot['W_notheta'].copy()
        roll = int(self.mazeAgent.nCells/2)
        M_copy, W_copy, W_notheta_copy = M.copy(), W.copy(), W_notheta.copy()
        for i in range(self.mazeAgent.nCells):
            M_copy[i,:] = np.roll(M[i,:],-i+roll)
            W_copy[i,:] = np.roll(W[i,:],-i+roll)
            W_notheta_copy[i,:] = np.roll(W_notheta[i,:],-i+roll)

        M_av,M_std = np.mean(M_copy,axis=0),np.std(M_copy,axis=0)
        W_av,W_std = np.mean(W_copy,axis=0),np.std(W_copy,axis=0)
        W_notheta_av,W_notheta_std = np.mean(W_notheta_copy,axis=0),np.std(W_notheta_copy,axis=0)

        # print(f"skew W: {getSkewness(W_av)} vs skew W_notheta: {getSkewness(W_notheta_av)} vs skew M: {getSkewness(M_av)}")
        # print(f"mass ratio = {np.sum(W_av[:int(len(W_av/2))])/np.sum(W_av[int(len(W_av/2)):])} vs {np.sum(W_notheta_av[:int(len(W_notheta_av/2))])/np.sum(W_notheta_av[int(len(W_notheta_av/2)):])} (no theta)")
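        # ratio of summed weight in the columns before the centre column (index roll) to the columns from it onwards
        print(f"mass ratio = {np.sum(W_av[:roll])/np.sum(W_av[roll:])} vs {np.sum(W_notheta_av[:roll])/np.sum(W_notheta_av[roll:])} (no theta)")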

        if renorm == True: 
            W_norm = np.maximum(np.max(W_notheta_av),np.max(W_av))
            M_av,M_std = M_av/(np.max(M_av)), M_std/(np.max(M_av))
            W_av,W_std = W_av/W_norm, W_std/W_norm
            W_notheta_av,W_notheta_std = W_notheta_av/W_norm, W_notheta_std/W_norm
        x = self.mazeAgent.centres[:,0]
        x = x-x[roll]

        for i in range(len(x)):
            if x[i] > self.mazeAgent.extent[1]/2:
                x[i] = x[i] - self.mazeAgent.extent[1]

        roll = int(self.mazeAgent.nCells/2)


        fig, ax = plt.subplots(2,1,figsize=(2,2))


        Rs_wav = Rsquared(W,M)
        Rsq_wnothetaav = Rsquared(W_notheta,M)

        ax[1].plot(x,M_av,c='C0',linewidth=2, label = r" ")
        ax[0].plot(x,W_av,c=color,label=r" ",linewidth=2)
        # ax[1].plot(x,M_theta_av,c='C0',linewidth=1.5,alpha=0.5,linestyle='dotted')

        ax[1].fill_between(x,M_av+M_std,M_av-M_std,facecolor='C0',alpha=0.2)
        ax[0].fill_between(x,W_av+W_std,W_av-W_std,facecolor=color,alpha=0.2)
        if plot_no_theta == True: 
            ax[0].fill_between(x,W_notheta_av+W_notheta_std,W_notheta_av-W_notheta_std,facecolor=color,alpha=0.2)
            ax[0].plot(x,W_notheta_av,c=color,label=r" ",linewidth=1.5,alpha=0.7,linestyle='--', dashes=(1, 1))

        ax[0].set_yticks([])
        ax[1].set_yticks([])
        ax[0].set_xlim(min(x),max(x))
        ax[1].set_xlim(min(x),max(x))
        if ylim is not None: 
            ax[0].set_ylim(0,ylim)
        else:
            ax[0].set_ylim(min(W_av-W_std),max(W_av+W_std))

        ticks = (x_ticks or [-2,-1,0,1,2])
        ticklabels = [""]*len(ticks)
        ax[0].set_xticks(ticks)
        ax[0].set_xticklabels(ticklabels)
        ax[1].set_xticks(ticks)
        ax[1].set_xticklabels(ticklabels)
        # ax[0].tick_params(width=2,color='darkgrey')
        # ax[1].tick_params(width=2,color='darkgrey')
        plt.grid(False)

        ax[0].spines['left'].set_position('zero')
        # ax[0].spines['left'].set_color('darkgrey')
        ax[0].spines['left'].set_linewidth(2)
        ax[0].spines['right'].set_color('none')        
        ax[0].spines['bottom'].set_position('zero')
        # ax[0].spines['bottom'].set_color('darkgrey')
        ax[0].spines['bottom'].set_linewidth(2)
        ax[0].spines['top'].set_color('none')

        ax[1].spines['left'].set_position('zero')
        # ax[1].spines['left'].set_color('darkgrey')
        ax[1].spines['left'].set_linewidth(2)
        ax[1].spines['right'].set_color('none')        
        ax[1].spines['bottom'].set_position('zero')
        # ax[1].spines['bottom'].set_color('darkgrey')
        ax[1].spines['bottom'].set_linewidth(2)
        ax[1].spines['top'].set_color('none')

        ax[0].legend(frameon=False)    
        ax[1].legend(frameon=False)    
        return fig, ax

    def plotMetrics(self,total_time=None, x_ticks=None):
        t         = []

        W_snr         = [] 
        W_notheta_snr = []

        W_r2 = []
        W_notheta_r2  = []

        x = self.mazeAgent.centres[:,0]

        M = rowAlignMatrix(self.mazeAgent.snapshots.iloc[-1]['M'])

        for i in range(len(self.mazeAgent.snapshots)-1):
            snapshot = self.mazeAgent.snapshots.iloc[i]
            time = snapshot['t']
            if time >= 31:
                
                R2_W, R2_Wnotheta, SNR_W, SNR_Wnotheta, skew_W, skew_Wnotheta, skew_M, peak_W, peak_Wnotheta, peak_M = self.mazeAgent.getMetrics(time=time)
        
                t.append(time/60)

                W_snr.append(SNR_W)
                W_notheta_snr.append(SNR_Wnotheta)
            
                W_r2.append(R2_W)
                W_notheta_r2.append(R2_Wnotheta)

        
        snapshot = self.mazeAgent.snapshots.iloc[-1]
        time = snapshot['t']

        fig, ax = plt.subplots(2,1,figsize=(1,2),sharex=True)

        W_r2, W_notheta_r2 = np.array(W_r2), np.array(W_notheta_r2)
        thresh = 0.5 
        t_50 = t[np.argmin(np.abs(W_r2 - thresh))] 
        t_50_notheta = t[np.argmin(np.abs(W_notheta_r2 - thresh))] 
        print(f"t0.5 {t_50:.3f} vs t0.5notheta {t_50_notheta:.3f}")

        end = -1
        if total_time is not None: 
            t_last = np.argmin(np.abs(np.array(t) - total_time))
            end = t_last
        ax[1].plot(t[:end],W_snr[:end],c='C1',linewidth=2, label=r"$\theta$")
        ax[1].plot(t[:end],W_notheta_snr[:end],c='C1',linewidth=1.5,linestyle='--', dashes=(1, 1),alpha=0.7,label=r"No $\theta$")

        ax[0].plot(t[:end],W_r2[:end],c='C1',linewidth=2)
        ax[0].plot(t[:end],W_notheta_r2[:end],c='C1',linewidth=1.5,linestyle='--', dashes=(1, 1),alpha=0.7)




        ax[1].set_ylim(bottom=0,top=max(W_snr)+0.15)
        ax[0].set_ylim(bottom=0,top=1)

        ticks = (x_ticks or [0,15,30])
        ticklabels = [""]*len(ticks)
        ax[1].set_xticks(ticks)
        ax[1].set_xticklabels(ticklabels)
        ax[0].set_xticks(ticks)
        ax[0].set_xticklabels(ticklabels)

        # ax[1].tick_params(width=2,color='darkgrey')
        # ax[0].tick_params(width=2,color='darkgrey')
        ax[1].set_yticks([0,3,6])
        ax[1].set_yticklabels(["","",""])
        ax[0].set_yticks([0,0.5,1])
        ax[0].set_yticklabels(["","",""])

        for i in range(2):

            ax[i].spines['left'].set_position('zero')
            # ax[i].spines['left'].set_color('darkgrey')
            ax[i].spines['left'].set_linewidth(2)
            ax[i].spines['right'].set_color('none')        
            ax[i].spines['bottom'].set_position('zero')
            # ax[i].spines['bottom'].set_color('darkgrey')
            ax[i].spines['bottom'].set_linewidth(2)
            ax[i].spines['top'].set_color('none')
        
        return fig, ax 


    def plotFieldSilhouette(self, N=25, plot_pf=True, plot_pf_notheta=False, plot_pf_M=True, plot_rf=False,no_theta=False):
        hist_id = self.snapshots['t'].sub(30*60).abs().to_numpy().argmin()
        snapshot = self.snapshots.iloc[hist_id-22]

        x = self.mazeAgent.discreteCoords[10,:,0]
        rf = self.mazeAgent.discreteStates[10,:,N]
        pf = self.mazeAgent.getPlaceFields(M=self.mazeAgent.W, threshold=0)[N][10,:]
        pf_notheta = self.mazeAgent.getPlaceFields(M=snapshot['W_notheta'], threshold=0)[N][10,:]
        pf_M = self.mazeAgent.getPlaceFields(M=self.mazeAgent.M, threshold=0)[N][10,:]
        rf, pf, pf_notheta, pf_M = rf/np.trapz(rf,x), pf/np.trapz(pf,x), pf_notheta/np.trapz(pf_notheta,x), pf_M/np.trapz(pf_M,x)

        fig, ax = plt.subplots(figsize=(2,0.5))
        ax.set_xlim(0,5)
        if plot_rf == True:
            ax.fill_between(x[rf>=0],rf[rf>=0],0,facecolor="C2",alpha=0.5)
        if plot_pf_M == True:
            ax.fill_between(x[pf_M>=0],pf_M[pf_M>=0],0,facecolor="C0",alpha=0.5)        
        if plot_pf == True:
            ax.fill_between(x[pf>=0],pf[pf>=0],0,facecolor="C1",alpha=0.5)
        if plot_pf_notheta == True:
            ax.fill_between(x[pf_notheta>=0],pf_notheta[pf_notheta>=0],0,facecolor="C1",alpha=0.5)
        
        pf = self.mazeAgent.getPlaceFields(M=self.mazeAgent.W, threshold=0)[N]
        pf_notheta = self.mazeAgent.getPlaceFields(M=snapshot['W_notheta'], threshold=0)[N]
        pf_M = self.mazeAgent.getPlaceFields(M=self.mazeAgent.M, threshold=0)[N]
        print("R2: M-->W=",Rsquared(pf,pf_M),"M-->Wnotheta=",Rsquared(pf_notheta,pf_M))
        tpl.xyAxes(ax)
        ax.spines['left'].set_color('none')
        ax.set_xticks([0,2.5,5])
        ax.set_yticks([])
        ax.set_xticklabels(["","",""])
        plt.tight_layout()

        return fig, ax

## **Helper Functions**
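- fitEllipse
    - fits an ellipse to a 2D array such as a place-field heat map: it takes the contour at `threshold` × the array's maximum and least-squares fits the conic `Ax^2 + Bxy + Cy^2 + Dx + Ey = 1` to the contour points. Returns the fitted surface (for contour plotting), the eccentricity and the orientation angle. Used in plotPlaceField/plotReceptiveField to overlay an ellipse on a field, e.g. to check for elongation along walls
- rowAlignMatrix
    - circularly shifts each row i of a square matrix M so that its diagonal element M[i, i] lands in the centre column (index N//2); the rows can then be averaged to give the mean weight profile around each cell (plotMAveraged does the same shift inline)
- saveFigure
    - saves a figure (as .pdf) or an animation (as .mp4) to `figureDirectory/<yymmdd>/<saveTitle>_<HHMM>`; if that file already exists, a counter (_1, _2, ...) is appended instead of overwriting. Returns the save path without extension
- pickleAndSave
    - "pickles" (serialises) any Python object, e.g. a trained agent, to `saveDir/name.pkl`. Pickling converts an object into a byte stream that can be written to disk and loaded back later; `dill` is an extended version of Python's built-in `pickle` module that can also serialise objects `pickle` cannot, such as lambdas. Not an efficient storage format, but easy; it silently overwrites an existing file with the same name (!!!)
- loadAndDepickle
    - loads an object saved with pickleAndSave back from `saveDir/name.pkl`
- Rsquared
    - computes R², the squared Pearson correlation coefficient between two arrays of the same size
- getCOM
    - computes the centre of mass of a 2D array: the value-weighted average row and column indices, truncated to integers (it also prints the array's shape)
- getMoment
    - computes the n-th moment of a function given as samples (x and y = f(x)) rather than of a sample of points: `sum((x - c)^n * y) / sum(y)`, i.e. y is treated as an unnormalised distribution over x and c is the point the moment is taken about
- getCircularMoment
    - computes the p-th trigonometric moment of a circular distribution (angles `theta` weighted by `y`) and returns its mean resultant length Rp and mean direction Tp
- getSkewness
    - computes the skewness of a curve y, treating it as a distribution over x evenly spaced on [0, 2π] (negative values are clipped to zero): either the standard third standardised moment or, with `circular=True`, the circular skewness from the NCSS reference cited in the code
- getPeak
    - returns the x position of the maximum of y, optionally after smoothing y with a 3-point moving average
- fisherZ
    - Fisher z-transform of a correlation coefficient r: `z = 0.5 * ln((1 + r) / (1 - r))`
- ornstein_uhlenbeck
    - returns the update dx for one time step of an Ornstein-Uhlenbeck process: a random process that is continually pulled back towards a mean (`drift`) while being perturbed by noise. It produces smooth, temporally correlated random values (e.g. a velocity) that change direction on a timescale set by `coherence_time`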
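To make two of the purely numerical helpers concrete, here is a small sketch on toy data. It copies `rowAlignMatrix` and `Rsquared` from the cell below so that it runs on its own, then checks two properties: after alignment, every row of a diagonally peaked matrix has its maximum in the centre column, and `Rsquared` agrees with the square of NumPy's Pearson correlation coefficient (`np.corrcoef`). The toy matrix and the random data are illustrative assumptions, not anything from the repository.

```python
import numpy as np

def rowAlignMatrix(M):
    # roll row i left by i - N//2 so that its diagonal element M[i, i]
    # lands in column N//2 (same logic as the helper in the cell below)
    M_copy = M.copy()
    roll = int(M.shape[0] / 2)
    for i in range(M.shape[0]):
        M_copy[i, :] = np.roll(M[i, :], -i + roll)
    return M_copy

def Rsquared(y1, y2):
    # squared Pearson correlation; population std, hence the 1/N prefactor
    return ((1 / y1.size) * np.sum((y1 - np.mean(y1)) * (y2 - np.mean(y2)))
            / (np.std(y1) * np.std(y2))) ** 2

# toy 6x6 matrix whose rows peak on the diagonal and decay with distance from it
N = 6
idx = np.arange(N)
M = np.exp(-np.abs(idx[:, None] - idx[None, :]))

aligned = rowAlignMatrix(M)
print(np.argmax(aligned, axis=1))  # [3 3 3 3 3 3]: every row now peaks in column N//2

# correlated random data to compare against NumPy's correlation coefficient
rng = np.random.default_rng(0)
a = rng.normal(size=100)
b = 2 * a + rng.normal(scale=0.5, size=100)
print(np.isclose(Rsquared(a, b), np.corrcoef(a, b)[0, 1] ** 2))  # True
```

In the notebook, the same alignment is applied to M and W (by `rowAlignMatrix` in plotMetrics and inline in plotMAveraged) before their rows are averaged or compared.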

In [None]:
def fitEllipse(image, threshold=0.5,coords=None,verbose=True):
    """Takes an array (image) and fits an ellipse to it. 
    It does this by finding the contour at a relative threshold, then least-squares regressing the contour points against the formula Ax2 + Bxy + Cy2 + Dx + Ey = 1

    Args:
        image (np.array()): The image or array to which the ellipse is fitted
        threshold (float, optional): The relative threshold (fraction of the image maximum) at which the edge of the image is defined. Defaults to 0.5.
        coords (np.array(image.shape,2)): The underlying coordinates of the image. Must be provided: despite the None default, they are not inferred.
        verbose (bool, optional): If True, prints the eccentricity. Defaults to True.

    Returns:
        tuple: (X, Y, Z) coordinate arrays of the ellipse function (can be contour plotted at level 1 on top of the image), the eccentricity, and the orientation angle
    """ 

    fig2, ax2 = plt.subplots()
    cs = ax2.contour(coords[...,0],coords[...,1],image,np.array([threshold*max(image.flatten())])).collections[0].get_paths()[0].vertices
    plt.close()
    x,y = cs[:,0], cs[:,1]
    X = np.stack((x**2,x*y,y**2,x,y)).T
    F = 1
    coords_ = coords.reshape(-1,2)
    xmin,xmax,ymin,ymax=min(coords_[:,0]),max(coords_[:,0]),min(coords_[:,1]),max(coords_[:,1])
    Y = F*np.ones(X.shape[0])
    (A,B,C,D,E) = np.matmul(np.linalg.inv(np.matmul(X.T,X) + 0.0*np.identity(X.T.shape[0])),np.matmul(X.T,Y)) #least squares fit ellipse Ax2 + Bxy + Cy2 + Dx + Ey = F
    x_coord = np.linspace(xmin,xmax,1000)
    y_coord = np.linspace(ymin,ymax,1000)   
    X_coord, Y_coord = np.meshgrid(x_coord, y_coord)
    Z_coord = A * X_coord ** 2 + B * X_coord * Y_coord + C * Y_coord**2 + D * X_coord + E * Y_coord 
    #finally get eccentricity 
    # matrix of the conic Ax2 + Bxy + Cy2 + Dx + Ey - F = 0, so the constant entry is -F
    m = np.array([[A,   B/2,  D/2],
                  [B/2, C,    E/2],
                  [D/2, E/2,  -F]])
    if np.linalg.det(m) < 0:
        eta = 1
    else:
        eta = -1
    eccen = np.sqrt((2*np.sqrt((A-C)**2 + B**2))/(eta*(A+C)+np.sqrt((A-C)**2 + B**2)))
    if verbose == True:
        print("Eccentricity = %.3f" %eccen)

    angle = np.arctan((1/B)*(C-A-np.sqrt((A-C)**2+B**2)))

    return (X_coord, Y_coord, Z_coord), eccen, angle

def rowAlignMatrix(M):
    M_copy = M.copy()
    roll = int(M.shape[0]/2)
    for i in range(M.shape[0]):
        M_copy[i,:] = np.roll(M[i,:],-i+roll)
    return M_copy

def saveFigure(fig,saveTitle="",transparent=True,anim=False,specialLocation=None,figureDirectory="../figures/"):
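    """saves figure to file, by date (folder) and time (name) 
    Args:
        fig (matplotlib fig object or FuncAnimation): the figure (or animation, if anim=True) to be saved
        saveTitle (str, optional): name to be saved as. The current time will be appended to this. Defaults to "".
        anim (bool, optional): save as an .mp4 animation instead of a .pdf figure. Defaults to False.
        figureDirectory (str, optional): root folder; a subfolder named after today's date (yymmdd) is used inside it. Defaults to "../figures/".
    Returns:
        str: the save path, without file extension
    """	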

    today =  datetime.strftime(datetime.now(),'%y%m%d')
    figdir = figureDirectory + f"{today}/"
    # create the dated folder, plus any missing parent folders (e.g. figureDirectory itself)
    os.makedirs(figdir, exist_ok=True)
    now = datetime.strftime(datetime.now(),'%H%M')
    path_ = f"{figdir}{saveTitle}_{now}"
    path = path_
    i=1
    while True:
        if os.path.isfile(path+".pdf") or os.path.isfile(path+".mp4"):
            path = path_+"_"+str(i)
            i+=1
        else: break
    if anim == True:
        fig.save(path + ".mp4")
    else:
        fig.savefig(path+".pdf", dpi=400,transparent=transparent)
    
    if specialLocation is not None: 
        fig.savefig(specialLocation, dpi=400,transparent=transparent,bbox_inches='tight')

    return path

def pickleAndSave(class_,name,saveDir='../savedObjects/'):
	"""pickles and saves a class
	this is not an efficient way to save the data, but it is easy 
	this will overwrite previous saves without warning
	Args:
		class_ (any class): the class/model to save
		name (str): the name to save it under 
		saveDir (str, optional): Directory to save into. Defaults to '../savedObjects/'.
	"""	
	with open(saveDir + name+'.pkl', 'wb') as output:
		dill.dump(class_, output)
	return 

def loadAndDepickle(name, saveDir='../savedObjects/'):
	"""Loads and depickles a class saved using pickleAndSave
	Args:
		name (str): name it was saved as
		saveDir (str, optional): Directory it was saved in. Defaults to '../savedObjects/'.
	Returns:
		class: the class/model
	"""	
	with open(saveDir + name+'.pkl', 'rb') as input:
		item = dill.load(input)
	return item

def Rsquared(y1, y2):
    """R squared between two arrays 

    Args:
        y1 (np.array()): array 1
        y2 (np.array()): array 2

    Returns:
        float: R squared value between them 
    """    
    return ((1/y1.size) * np.sum((y1-np.mean(y1)) * (y2-np.mean(y2))) / (np.std(y1) * np.std(y2)))**2


def getCOM(array):
    print(array.shape)
    i_av, j_av = 0, 0
    for i in range(array.shape[0]):
        for j in range(array.shape[1]):
            i_av += array[i,j]*i
            j_av += array[i,j]*j
    i_av /= np.sum(array)
    j_av /= np.sum(array)
    i_av = int(i_av)
    j_av = int(j_av)
    return (i_av,j_av)

def getMoment(x,y,moment=1,c=0):
    """Get moments not from sample of points but from a function (list of x, and f(x)=y)

    Args:
        x (np.array): independent variable
        y (np.array): dependent variable
        moment (int, optional): which moment ot get. Defaults to 1.
        c (float): about which point to find moment, defaults to zero
    """    
    s_x = 0
    s_y = 0
    for i in range(len(x)):
        s_x += ( (x[i] - c)**moment ) * y[i]
        s_y += y[i]
    return s_x / s_y

def getCircularMoment(theta,y,moment=1,c=0):
    """p-th trigonometric moment of the angles theta weighted by y.
    Returns the mean resultant length Rp and the mean direction Tp in [0, 2*pi). (c is unused.)
    """
    Cp = np.sum(np.cos(moment*theta)*y) / np.sum(y)
    Sp = np.sum(np.sin(moment*theta)*y) / np.sum(y)
    Rp = np.sqrt(Cp**2 + Sp**2)
    # arctan2 resolves the quadrant for every sign combination, including Sp == 0 or Cp == 0
    Tp = np.arctan2(Sp, Cp) % (2*np.pi)
    return Rp, Tp



def getSkewness(y,circular=False):
    if not np.all(y>=0):
        y = np.maximum(y,0)
    x = np.linspace(0,2*np.pi,len(y))
    if circular == False: 
        mean = getMoment(x,y)
        std = np.sqrt(getMoment(x,y,moment=2,c=mean))
        skewness = getMoment(x,y,moment=3,c=mean) / std**3
    if circular == True: #NCSS Statistical Software NCSS.com, Chapter 230, Circular Data Analysis, https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/NCSS/Circular_Data_Analysis.pdf
        R1, T1 = getCircularMoment(x,y,moment=1)
        R2, T2 = getCircularMoment(x,y,moment=2)
        V = 1-R1
        skewness = R2*np.sin(2*T1-T2) / (1-R1)**(3/2)

    return skewness


def getPeak(x,y,smooth=True):
    if smooth == True: 
        y_smooth = np.empty_like(y)
        for i in range(len(y)):
            y_smooth[i] = np.mean(y[max(0,i-1):min(i+2,len(y))])
        peak = x[np.argmax(y_smooth)]
    else:
        peak = x[np.argmax(y)]
    return peak 

def fisherZ(r):
    return np.log((1+r)/(1-r)) / 2


def ornstein_uhlenbeck(dt, x, drift=0.0, noise_scale=0.2, coherence_time=5.0):
    """One (Euler-Maruyama) update step of an Ornstein-Uhlenbeck process in x.
    x can be multidimensional 
    Args:
        dt: update time step
        x: the stochastic variable being updated
        drift (float, or same type as x, optional): The mean value that x is pulled back towards. Defaults to 0.
        noise_scale (float, or same type as x, optional): Magnitude of deviations from drift. Defaults to 0.2 (20 cm s^-1 if units of x are in metres).
        coherence_time (float, optional): Effectively over what time scale you expect x to change directions. Defaults to 5.
    Returns:
        dx (same type as x); the required update to x
    """
    x = np.array(x)
    drift = drift * np.ones_like(x)
    noise_scale = noise_scale * np.ones_like(x)
    coherence_time = coherence_time * np.ones_like(x)
    sigma = np.sqrt((2 * noise_scale ** 2) / (coherence_time * dt))
    theta = 1 / coherence_time
    dx = theta * (drift - x) * dt + sigma * np.random.normal(size=x.shape, scale=dt)
    return dx
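As a sanity check on `ornstein_uhlenbeck`, the sketch below copies the function above and simulates 5000 independent processes for 100 time units, with arbitrarily chosen parameters (dt = 0.05, drift = 0.3, noise_scale = 0.2 and the default coherence_time = 5). If the update is a correct Euler-Maruyama step of an Ornstein-Uhlenbeck process, the final values should be spread around `drift` with a standard deviation close to `noise_scale`.

```python
import numpy as np

def ornstein_uhlenbeck(dt, x, drift=0.0, noise_scale=0.2, coherence_time=5.0):
    # copy of the helper above: returns dx for one Euler-Maruyama step
    x = np.array(x)
    drift = drift * np.ones_like(x)
    noise_scale = noise_scale * np.ones_like(x)
    coherence_time = coherence_time * np.ones_like(x)
    sigma = np.sqrt((2 * noise_scale ** 2) / (coherence_time * dt))
    theta = 1 / coherence_time
    dx = theta * (drift - x) * dt + sigma * np.random.normal(size=x.shape, scale=dt)
    return dx

np.random.seed(0)            # the helper draws from NumPy's global random generator
dt, n_steps = 0.05, 2000     # 100 time units = 20 coherence times
drift, noise_scale = 0.3, 0.2

x = np.zeros(5000)           # 5000 independent 1D processes, all starting at 0
for _ in range(n_steps):
    x = x + ornstein_uhlenbeck(dt, x, drift=drift, noise_scale=noise_scale)

# after many coherence times the values are stationary: mean ~ drift, std ~ noise_scale
print(f"mean {x.mean():.3f} (drift {drift}), std {x.std():.3f} (noise_scale {noise_scale})")
```

Simulating many independent copies in one array (the helper accepts multidimensional x) keeps the loop short; a single long trajectory would also work, but it needs many coherence times before its time average settles.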