## Imports

In [None]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import nest as sim
import numpy as np
import pandas
from collections import Counter
import time as tm
import scipy.stats
import scipy.io
from time import sleep, process_time
import sys
import os
import pandas as pd
from scipy.stats import pearsonr, spearmanr

## Functions

In [None]:
def cart2ring(x, y, offset):
    """Project Cartesian coordinates onto the three ring axes.

    Ring ``k`` (k = 0, 1, 2) is the signed projection of ``(x, y)`` onto the
    unit normal ``(-sin(a_k), cos(a_k))`` with ``a_k = offset + 60*k`` degrees::

        ring_k = y * cos(a_k) - x * sin(a_k)

    Computing this in closed form keeps it well defined for ``x == 0`` and for
    offsets where ``tan`` of an axis angle is zero (e.g. ``offset == 0`` or
    ``offset == 60``), which a ``x / tan(angle)`` formulation turns into NaN.

    Parameters
    ----------
    x, y : scalar or array_like
        Cartesian coordinates. Must contain the same, non-zero number of
        elements. Non-scalar input is flattened.
    offset : float
        Angular offset of ring 1, in degrees (any value; trig is periodic).

    Returns
    -------
    numpy.ndarray
        Shape ``(3,)`` when both ``x`` and ``y`` are scalars, otherwise
        ``(n, 3)`` with columns ``ring1, ring2, ring3``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in size or are empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    if x.size != y.size:
        raise ValueError("x and y must contain the same number of elements")
    if x.size == 0:
        raise ValueError("x and y must not be empty")

    # Keep true scalars 0-d (-> shape (3,) result); flatten everything else so
    # the result is always (n, 3), including length-1 arrays.
    if x.ndim or y.ndim:
        x, y = x.ravel(), y.ravel()

    axis_angles = np.radians(offset + np.array([0.0, 60.0, 120.0]))
    rings = [y * np.cos(angle) - x * np.sin(angle) for angle in axis_angles]

    return np.array(rings).T

def ring2cart(ring1, ring2, ring3, offset):
    """Convert three ring coordinates back into Cartesian ``(x, y)``.

    Ring ``k`` (k = 0, 1, 2) constrains the point to the line::

        y * cos(a_k) - x * sin(a_k) = ring_k,   a_k = offset + 60*k degrees

    (the inverse of the ring projection used by ``cart2ring``). Each pair of
    lines is intersected exactly with Cramer's rule. Their normals are always
    60 or 120 degrees apart, so the 2x2 determinant ``sin(a_j - a_i)`` is never
    zero and no offset needs special handling. The three pairwise intersections
    are averaged; for consistent input (``ring1 - ring2 + ring3 == 0``) they
    coincide.

    Parameters
    ----------
    ring1, ring2, ring3 : scalar or array_like
        Ring coordinates; flattened, and must contain the same number of
        elements.
    offset : float
        Angular offset of ring 1, in degrees.

    Returns
    -------
    numpy.ndarray
        Shape ``(2,)`` for a single point, otherwise ``(n, 2)`` with columns
        ``x, y``.

    Raises
    ------
    ValueError
        If the inputs are empty or differ in size.
    """
    rings = [np.ravel(np.asarray(ring, dtype=float)) for ring in (ring1, ring2, ring3)]

    if rings[0].size == 0:
        raise ValueError("ring inputs must not be empty")
    if not rings[0].size == rings[1].size == rings[2].size:
        raise ValueError("ring1, ring2 and ring3 must contain the same number of elements")

    axis_angles = np.radians(offset + np.array([0.0, 60.0, 120.0]))

    def _intersect(i, j):
        """Return the (x, y) intersection of ring lines i and j."""
        a, b = axis_angles[i], axis_angles[j]
        det = np.sin(b - a)  # +/- sin(60 deg) or sin(120 deg); never zero
        x = (rings[i] * np.cos(b) - rings[j] * np.cos(a)) / det
        y = (rings[i] * np.sin(b) - rings[j] * np.sin(a)) / det
        return x, y

    intersections = [_intersect(0, 1), _intersect(1, 2), _intersect(0, 2)]

    # Average the three triangle vertices (centroid); squeeze so a single point
    # comes back as shape (2,) rather than (1, 2).
    x = np.squeeze(np.mean([point[0] for point in intersections], axis=0))
    y = np.squeeze(np.mean([point[1] for point in intersections], axis=0))

    return np.array([x, y]).T

def wrap_to_distance(distance, boundary):
    """Wrap ``distance`` into the half-open interval ``[0, boundary)``.

    Uses floored modulo (``np.mod``), so for a positive ``boundary`` negative
    values wrap from the top, e.g. ``-1`` with ``boundary=360`` gives ``359``.
    The result uses the promoted dtype of ``distance`` and ``boundary``, so an
    integer array wrapped by a fractional boundary keeps its fractional part.
    Scalars, lists and arrays are all accepted.

    Parameters
    ----------
    distance : scalar or array_like
        Value(s) to wrap.
    boundary : scalar
        Period to wrap by.

    Returns
    -------
    numpy.ndarray or scalar
        The wrapped value(s); ``distance`` itself is not modified.
    """
    return np.mod(distance, boundary)

def ring_mean_activity(data, centre = True):
    """Return the circular (population-vector) mean position of activity on a ring.

    Parameters
    ----------
    data : array_like, shape (timesteps, ring_size)
        A single ring's activity history. Entries are truncated to integers
        and used as per-cell counts (weights).
    centre : bool, default True
        If False, cell ``k`` points along the ray ``k * 2*pi / ring_size``
        (rays span the half-open interval [0, 2*pi)). If True, every ray is
        shifted by half a cell so it points through the middle of its arc.

    Returns
    -------
    numpy.ndarray, shape (timesteps,)
        Count-weighted circular mean direction per timestep, converted to
        (fractional) ring-index units. Timesteps with no integer activity, or
        whose mean is undefined (NaN), are reported as 0.

    Raises
    ------
    ValueError
        If any truncated count is negative.
    """
    data = np.asarray(data)
    ring_size = data.shape[1]
    arc_per_ring_segment = (2 * np.pi) / ring_size

    # Integer multiples of the arc give exactly ring_size rays; np.arange with a
    # float stop/step can pick up an extra element from rounding.
    rays = np.arange(ring_size) * arc_per_ring_segment

    if centre:
        rays = rays + arc_per_ring_segment / 2

    counts = data.astype('int')

    if np.any(counts < 0):
        raise ValueError("ring activity counts must be non-negative")

    # Weighted circular mean. arctan2 of the summed sines/cosines equals
    # arctan2 of their means, so each spike contributes its ray once per count.
    sin_sum = counts @ np.sin(rays)
    cos_sum = counts @ np.cos(rays)
    mean_activity = np.arctan2(sin_sum, cos_sum) % (2 * np.pi)

    # Silent timesteps have no defined direction: report 0 explicitly rather
    # than relying on arctan2(+-0, +-0), whose result depends on zero signs.
    mean_activity[counts.sum(axis = 1) == 0] = 0
    mean_activity[np.isnan(mean_activity)] = 0

    mean_activity_ring_index = mean_activity * (ring_size / (2 * np.pi))

    return mean_activity_ring_index

In [None]:
for i in range(100):

    sim.set_verbosity("M_ERROR")
    
    sim.ResetKernel()
    
    sim.local_num_threads = 8 # NEST recommends 1 thread per core

    # 'simulate': Run on NEST to generate results, this will save the results to .npy files
    # 'load': Load a prior run without running the NEST simulator

    simulate_or_load = 'simulate'

    # 'spiral': Generate a spiral trajectory
    # 'rat': Load one or more Sargolini datasets from file

    spiral_or_rat = 'rat'
    concatenate_rat = False

    np.setbufsize(8192 * 4)

    # Population cell counts

    # N_ex: number of cells in the excitatory rings
    # N_in: number of cells in the inhibitory rings
    # N_cj: number of cells in the conjunctive rings
    # rings: number of mono-axis rings ('directional rings')
    # omni_rings: number of axis-invariant rings ('speed rings')
    # window_size: how large a window of excitatory ring cells is each principle axis cell sensitive to?
    # N_pa_cells_per_ring: how many PA cells exist for each ring (more = finer binning of excitatory ring activity)
    # N_pyramidals: how many pyramidal cells are there?

    calibration_mode = False # Set input to 0, to study static ring
    corrections = True

    N_ex = 120
    N_in = N_ex
    N_cj = N_ex

    rings = 3
    omni_rings = 0

    window_size = 15

    N_pa_cells_per_ring = N_ex // window_size

    rp_offset = 0

    N_pyramidals = N_pa_cells_per_ring ** rings

    tension = True

    minimum_input = 2700#675#2000#1000

    # Connection Gaussian weight parameters

    # sigma: variance for excitatory ring -> inhibitory ring weight Gaussian
    # in_sigma: variance for inhibitory ring -> excitatory ring weight Gaussian
    # in_cj_sigma: variance for inhibitory ring -> conjunctive rings weight Gaussian
    # mu: mean offset of Gaussians (when used)
    # prune_smaller_than: weights below this threshold will become 0, effectively removing the connection

    sigma = 0.1#0.12#0.1
    in_sigma = 0.1
    in_cj_sigma = 0.09375#0.075#0.095#0.0925#0.08#0.09#0.08 # was at about 1-1.02 before, when coming back to this, give these a try
    mu = 0.5
    prune_smaller_than = 10

    smooth_sigma = 10

    # Connection scalar weight parameters. These are signed as appropriate later on

    # base_ex: excitatory ring -> inhibitory ring weight strength (+)
    # base_in: inhibitory ring -> excitatory ring weight strength (-)
    # base_cj: conjunctive rings -> excitatory rings weight strength (+)
    # w_ex_cj: excitatory rings -> conjunctive rings weight strength (+)
    # w_in_cj: inhibitory_rings -> conjunctive rings weight strength (-)
    # w_ex_pa: excitatory_rings -> principle axis cells weight strength (+)
    # w_pa_py: principle axis cells -> pyramidal cells weight strength (+)

    # Spiral settings:

    # base_ex = 5000
    # base_in = 1500
    # base_cj = 500
    # w_ex_cj = 440
    # w_in_cj = 1800 # Was about 800 before, put up a lot higher before velocity calculation was changed to properly share out input
    # w_ex_pa = 300
    # w_pa_py = 200

    # Rat settings:

    # Next week: dropping inhibition is helping with integrating small velocities, try to lower it further and lower excitation too if needs be

    base_ex = 1750#4000#5000
    base_in = 0#5000#1500
    base_cj = 500
    w_ex_cj = 0
    w_in_cj = 3000#3500#4000#4500#3600#1700#700 # Was about 800 before, put up a lot higher before velocity calculation was changed to properly share out input
    w_ex_pa = 100#80
    w_pa_py = 300
    w_in_pa = 1000

    cj_in_offset = 0 # Is the inhibitory 'bowl' biased towards the direction of input? If so, by how many cells?

    # Synaptic transmission delay (I believe this includes the synapse proper and action potential travel time)

    delay = 0.1

    # Velocity scaling parameters

    # I_vel: multiply incoming velocity by this amount to get the input current representing the vestibular signal
    # velocity_threshold: Are very small values for velocity set to zero?
    # miniumum_velocity: What is the minimum non-zero velocity? (only works if velocity_threshold is True)

    I_vel = 2000000#800 # Seems to work best if you can get the velocities in a 0-4000 range

    # Are conjuctive cells synapsed onto by excitatory layer or inhibitory layer.

    # 'positive': scalar excitatory weight, conjunctive weights must be tuned to act as a coincidence detector for input velocity and bump activity
    # 'negative': Gaussian inhibitory weight, suppresses incoming velocity input too far from the attractor bump
    # 'both': inhibitory 'bowl' as per 'negative' and self-reinforcing excitatory connections

    conjunctive_mode = 'negative'

    # Intrinsic excitation of the excitatory ring, constant current in picoamps

    intrinsic_excitation = 0.#225.

    theta = False
    
    stochastic_membrane_potential = True
    
    stochastic_input_current = True

    # Initial (bump-forming) current injection parameters. This is a short spike of input to form the initial attractor state, to be adjusted by conjunctive input

    # I_init: strength of input current in picoamps
    # I_init_dur: how long this is applied for, in milliseconds
    # I_init_pos: where is this applied, in ring index (1-120). NEST, for better or worse, has neuron IDs starting at 1

    I_init = 350.0#300.0
    I_init_dur = 100.0
    I_init_pos = 60 - 1#(N_ex - 1)
    
    N_vp = sim.GetKernelStatus(['total_num_virtual_procs'])[0]

    if not os.path.exists("results_window_size_15.csv"):

        master_seed = 9032867582

        N_vp = sim.GetKernelStatus(['total_num_virtual_procs'])[0]

        sim.SetKernelStatus({'grng_seed' : master_seed+N_vp})

        sim.SetKernelStatus({'rng_seeds' : range(master_seed+N_vp+1, master_seed+2*N_vp+1)})

        membrane_seed = 2390786556

        input_seed = 6983476394

    else:

        results_dataframe = pd.read_csv("results_window_size_15.csv")

        master_seed = results_dataframe["Master Seed"].to_numpy()[-1] + 1

        sim.SetKernelStatus({'grng_seed' : master_seed+N_vp})

        sim.SetKernelStatus({'rng_seeds' : range(master_seed+N_vp+1, master_seed+2*N_vp+1)})

        membrane_seed = results_dataframe["Membrane Seed"].to_numpy()[-1] + 1

        input_seed = results_dataframe["Input Seed"].to_numpy()[-1] + 1
        
    membrane_rng = np.random.default_rng(seed = membrane_seed)
    
    input_rng = np.random.default_rng(seed = input_seed)

    ## Create neuron populations from the above parameters

    # Lists to store each ring's population

    exc = []
    inh = []
    l = []
    r = []
    pa_cells = []
    #input_pa_cells = []

    for i in range(rings):

        exc.append(sim.Create("iaf_psc_alpha", N_ex))

        inh.append(sim.Create("iaf_psc_alpha", N_in)) # Inhibitory layer

        l.append(sim.Create("iaf_psc_alpha", N_cj)) # Conjunctive layer for left turn
        r.append(sim.Create("iaf_psc_alpha", N_cj)) # Conjunctive layer for right turn

        pa_cells.append(sim.Create("iaf_psc_alpha", N_pa_cells_per_ring))

        #input_pa_cells.append(sim.Create("iaf_psc_alpha",N_pa_cells_per_ring))

    # The pyramidal cells associate across rings    

    pyramidal_cells = sim.Create("iaf_psc_alpha", N_pyramidals)
    
    if stochastic_membrane_potential:
        
        for ring in range(rings):
            
            for neuron in exc[ring]:
                
                sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})
                
            for neuron in inh[ring]:
                
                sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})
                
            for neuron in l[ring]:
                
                sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})
                
            for neuron in r[ring]:
                
                sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})
                              
            for neuron in pa_cells[ring]:
                
                sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})
                
        for neuron in pyramidal_cells:

            sim.SetStatus([neuron], params = {"V_m": membrane_rng.integers(low = -70, high = 55) * 1.})

                
    input_grid_devices = sim.Create('step_current_generator', N_pyramidals)

    ## Define connection weight matrices

    # Empty matrices

    w_ex = np.empty((N_in,N_ex))
    w_in = np.empty((N_ex,N_in))

    for e in range(N_ex):
        for i in range(N_in):
            # Find minimum (true) distance between adjacent cells
            d1 = abs(e/N_ex - i/N_in)
            d2 = abs(e/N_ex - i/N_in -1)
            d3 = abs(e/N_ex - i/N_in +1)
            d = min(abs(d1),abs(d2),abs(d3))
            # Create gaussian value based on parameters above to define connection strengths
            w_gauss = np.exp(-(d - mu)**2/2/sigma**2) # Exitatory -> inhibitory
            w_ring = np.exp(-(d)**2/2/in_sigma**2) # Inhibitory -> excitatory
            # Assign appropriate weight values to matrices
            w_ex[i,e] = base_ex * w_gauss
            w_in[e,i] = base_in * w_ring 

    # Very small weights become 0

    w_ex[w_ex<prune_smaller_than] = 0
    w_in[w_in<prune_smaller_than] = 0

    # Plot weight matrix interactions as a sanity check. Should be an 'arch' of inhibition, leaving the suppressing areas far from the injection site

    intrinsic_input = np.tile(450., N_ex)

    injection_site = I_init_pos

    # As before, connection weight matrices, this time between conjunctive layers and the excitatory layer

    w_l = np.empty((N_ex,N_cj))
    w_r = np.empty((N_ex,N_cj))

    for c in range(N_cj):  
        for e in range(N_ex):
            # Minimum distance, this time between each conjunctive cell and the excitatory cell displaced 1 away (e +/- 1)
            # Left is anticlockwise, therefore drives the cell immediately to the left
            # Right is clockwise, therefore drives the cell immediately to the right
            d1 = abs((e-1)/N_cj - c/N_ex)
            d2 = abs((e-1)/N_cj - c/N_ex -1)
            d3 = abs((e-1)/N_cj - c/N_ex +1)
            d = min(abs(d1),abs(d2),abs(d3))
            w_l[e,c] = base_cj * (np.exp(-(d)**2/2/sigma**2))

            d1 = abs((e+1)/N_cj - c/N_ex)
            d2 = abs((e+1)/N_cj - c/N_ex -1)
            d3 = abs((e+1)/N_cj - c/N_ex +1)
            d = min(abs(d1),abs(d2),abs(d3))
            w_r[e,c] = base_cj * (np.exp(-(d)**2/2/sigma**2))

    # Set all not the max to zero; makes sure the conjunctive cells only drive the immediate neighbour
    # Still uses the Gaussian connection weight, just doesn't use the whole Gaussian (for now)

    m = np.amax(w_l)
    w_l[w_l<m] = 0
    m = np.amax(w_r)
    w_r[w_r<m] = 0

    # Gaussian weight matrix for inhibitory->left conjunctive cells (if conjuctive_mode == 'negative')

    w_in_l_cj_gauss = np.empty((N_cj,N_in))

    for i in range(N_in):
        for c in range(N_cj):  
            # Minimum distance, this time between each conjunctive cell and the excitatory cell displaced 1 away (e +/- 1)
            # Left is anticlockwise, therefore drives the cell immediately to the left
            # Right is clockwise, therefore drives the cell immediately to the right
            d1 = abs((c-cj_in_offset)/N_cj - i/N_in)
            d2 = abs((c-cj_in_offset)/N_cj - i/N_in -1)
            d3 = abs((c-cj_in_offset)/N_cj - i/N_in +1)
            d = min(abs(d1),abs(d2),abs(d3))
            w_in_l_cj_gauss[c,i] = w_in_cj * (np.exp(-(d)**2/2/in_cj_sigma**2))

    w_in_l_cj_gauss = w_in_l_cj_gauss# - np.max(w_in_cj_gauss)

    # Very small weights become 0

    w_in_l_cj_gauss[w_in_l_cj_gauss<prune_smaller_than] = 0
    w_in_l_cj_gauss[w_in_l_cj_gauss<prune_smaller_than] = 0

    # Gaussian weight matrix for inhibitory->right conjunctive cells (if conjuctive_mode == 'negative')

    w_in_r_cj_gauss = np.empty((N_cj,N_in))

    for i in range(N_in):
        for c in range(N_cj):  
            # Minimum distance, this time between each conjunctive cell and the excitatory cell displaced 1 away (e +/- 1)
            # Left is anticlockwise, therefore drives the cell immediately to the left
            # Right is clockwise, therefore drives the cell immediately to the right
            d1 = abs((c+cj_in_offset)/N_cj - i/N_in)
            d2 = abs((c+cj_in_offset)/N_cj - i/N_in -1)
            d3 = abs((c+cj_in_offset)/N_cj - i/N_in +1)
            d = min(abs(d1),abs(d2),abs(d3))
            w_in_r_cj_gauss[c,i] = w_in_cj * (np.exp(-(d)**2/2/in_cj_sigma**2))

    w_in_r_cj_gauss = w_in_r_cj_gauss# - np.max(w_in_cj_gauss)

    # Very small weights become 0

    w_in_r_cj_gauss[w_in_r_cj_gauss<prune_smaller_than] = 0
    w_in_r_cj_gauss[w_in_r_cj_gauss<prune_smaller_than] = 0

    ## Wire everything up

    for i in range(rings):

        # Excitatory and inhibitory set to connect all to all, using the prior calculated weight matrix

        exc_2_inh = sim.Connect(exc[i],inh[i],'all_to_all',syn_spec={'weight': w_ex, 'delay': delay})
        inh_2_exc = sim.Connect(inh[i],exc[i],'all_to_all',syn_spec={'weight': -w_in, 'delay': delay})

        # Conjunctive layers connecting to the excitatory layer, with weights

        l_2_exc = sim.Connect(l[i],exc[i],'all_to_all',syn_spec={'weight': w_l, 'delay': delay})
        r_2_exc = sim.Connect(r[i],exc[i],'all_to_all',syn_spec={'weight': w_r, 'delay': delay})

        if conjunctive_mode == 'positive':

            # Excitatory connecting one-to-one to both conjunctive layers, with fixed weight.  A 'coincidence detector'.

            exc_2_l = sim.Connect(exc[i],l[i],'one_to_one',syn_spec={'weight': w_ex_cj, 'delay': delay})
            exc_2_r = sim.Connect(exc[i],r[i],'one_to_one',syn_spec={'weight': w_ex_cj, 'delay': delay})

        elif conjunctive_mode == 'negative':

            # Inhibitory connecting one-all_to_all-one to both conjunctive layers, with inverse Gaussian weights

            inh_2_l = sim.Connect(inh[i],l[i],'all_to_all',syn_spec={'weight': -w_in_l_cj_gauss, 'delay': delay})
            inh_2_r = sim.Connect(inh[i],r[i],'all_to_all',syn_spec={'weight': -w_in_r_cj_gauss, 'delay': delay})

        elif conjunctive_mode == 'both':

            exc_2_l = sim.Connect(exc[i],l[i],'one_to_one',syn_spec={'weight': w_ex_cj, 'delay': delay})
            exc_2_r = sim.Connect(exc[i],r[i],'one_to_one',syn_spec={'weight': w_ex_cj, 'delay': delay})

            inh_2_l = sim.Connect(inh[i],l[i],'all_to_all',syn_spec={'weight': -w_in_l_cj_gauss, 'delay': delay})
            inh_2_r = sim.Connect(inh[i],r[i],'all_to_all',syn_spec={'weight': -w_in_r_cj_gauss, 'delay': delay})

    ## Wire everything up

    # Connect a N-wide window of the ring to each principle axis cell

    windows = []

    for ring in exc:

        for i in range(0, N_ex, window_size):

            window = ring[((i+rp_offset) % N_ex):((i+rp_offset) % N_ex) + window_size]

            if i + rp_offset + window_size > N_ex and len(window) < window_size:

                window = window + ring[0:rp_offset]

            windows.append(window)

    for i in range(rings):

        for j in range(N_pa_cells_per_ring):

            sim.Connect(windows[i*N_pa_cells_per_ring+j], [pa_cells[i][j]],'all_to_all',syn_spec={'weight': w_ex_pa, 'delay': delay})

    total_pa_cells = N_pa_cells_per_ring * rings

    cell_indices = np.zeros(shape = (N_pyramidals))
    target_cells = np.zeros(shape = (N_pyramidals))
    source_pa_cells = np.zeros(shape = (rings, N_pyramidals))
    source_pa_cell_indices = np.zeros(shape = (rings, N_pyramidals))

    in_range = True

    for r1 in (x for x in range(N_pa_cells_per_ring) if in_range is True): # Ring 1

        for r2 in (y for y in range(N_pa_cells_per_ring) if in_range is True): # Ring 2

            for r3 in (z for z in range(N_pa_cells_per_ring) if in_range is True): # Ring 3

                #cell_index = (((r1 + rp_offset) % N_pa_cells_per_ring) * N_pa_cells_per_ring ** 2) + (((r2 + rp_offset) % N_pa_cells_per_ring) * N_pa_cells_per_ring) + ((r3 + rp_offset) % N_pa_cells_per_ring) # Steps from 0 to max
                cell_index = (r1 * N_pa_cells_per_ring ** 2) + (r2 * N_pa_cells_per_ring) + r3 # Steps from 0 to max

                if cell_index != N_pyramidals:

                    target_cell = pyramidal_cells[cell_index]

                    sim.Connect([pa_cells[0][r1]], [target_cell],'all_to_all',syn_spec={'weight': w_pa_py, 'delay': delay})
                    sim.Connect([pa_cells[1][r2]], [target_cell],'all_to_all',syn_spec={'weight': w_pa_py, 'delay': delay})
                    sim.Connect([pa_cells[2][r3]], [target_cell],'all_to_all',syn_spec={'weight': w_pa_py, 'delay': delay})

                    # Gather up data for Pandas, to be used later in grid cell evalutation

                    cell_indices[cell_index] = cell_index
                    target_cells[cell_index] = target_cell
                    source_pa_cells[0][cell_index] = pa_cells[0][r1]
                    source_pa_cells[1][cell_index] = pa_cells[1][r2]
                    source_pa_cells[2][cell_index] = pa_cells[2][r3]
                    source_pa_cell_indices[0][cell_index] = r1
                    source_pa_cell_indices[1][cell_index] = r2
                    source_pa_cell_indices[2][cell_index] = r3

    #             else:

    #                 in_range = False

    pa_to_pyramidal_connections = pd.DataFrame({'Target Cell Index': cell_indices,
                                                'Target Pyramidal Cell': target_cells,
                                                'Ring 1 Index': source_pa_cell_indices[0],
                                                'Ring 2 Index': source_pa_cell_indices[1],
                                                'Ring 3 Index': source_pa_cell_indices[2],
                                                'Ring 1 PA Cell': source_pa_cells[0],
                                                'Ring 2 PA Cell': source_pa_cells[1],
                                                'Ring 3 PA Cell': source_pa_cells[2],})

    # Now do the same in the opposite direction: assign each unique combination of RP
    # cells (one per ring) an incoming 'grid input cell' that can be driven by external
    # input cues.

    # For convenience, and to avoid modelling an extra population of 'pass through'
    # cells, each input device connects directly onto the excitatory rings, at the
    # (lower) midpoint of the RP receptive field window on each of the three rings.

    # The non-existent input RP cells are given here as placeholders, for managing input
    # to the ring in a format plausible for downstream brain areas to be aware of; it is
    # assumed that the state of the rings themselves is too granular.

    # NOTE: pandas is already imported as `pd` at module level. It is deliberately not
    # re-imported here: a function-scope `import pandas as pd` would make `pd` local to the
    # whole enclosing function, breaking any earlier `pd.` use in that function.

    total_pa_cells = N_pa_cells_per_ring * rings

    source_cell_indices = np.zeros(shape = (N_pyramidals))
    source_cells = np.zeros(shape = (N_pyramidals))
    target_ring_cells = np.zeros(shape = (rings, N_pyramidals))
    target_ring_cell_indices = np.zeros(shape = (rings, N_pyramidals))
    target_virtual_rp = np.zeros(shape = (rings, N_pyramidals))

    # Retained for compatibility; no early-exit is implemented, so every RP combination
    # (r1, r2, r3) is visited.
    in_range = True

    for r1 in range(N_pa_cells_per_ring): # Ring 1

        for r2 in range(N_pa_cells_per_ring): # Ring 2

            for r3 in range(N_pa_cells_per_ring): # Ring 3

                # Flat index of this (r1, r2, r3) combination; steps from 0 to max
                source_cell_index = (r1 * N_pa_cells_per_ring ** 2) + (r2 * N_pa_cells_per_ring) + r3

                # Skip combinations that do not fit in the arrays sized by N_pyramidals.
                # (This previously tested `cell_index`, a value left over from the
                # PA -> pyramidal loop that never changes inside this loop.)
                if source_cell_index >= N_pyramidals:
                    continue

                source_cell = input_grid_devices[source_cell_index]

                # Midpoint of the r-th window of `window_size` excitatory cells on each ring
                target_cell_index_r1 = (r1 + 1) * window_size - (window_size // 2) - 1
                target_cell_index_r2 = (r2 + 1) * window_size - (window_size // 2) - 1
                target_cell_index_r3 = (r3 + 1) * window_size - (window_size // 2) - 1

                target_cell_r1 = exc[0][target_cell_index_r1]
                target_cell_r2 = exc[1][target_cell_index_r2]
                target_cell_r3 = exc[2][target_cell_index_r3]

                # Default synapse parameters; an explicit
                # syn_spec={'weight': w_in_pa, 'delay': delay} was previously considered.
                sim.Connect([source_cell], [target_cell_r1], 'all_to_all')
                sim.Connect([source_cell], [target_cell_r2], 'all_to_all')
                sim.Connect([source_cell], [target_cell_r3], 'all_to_all')

                # Gather up data for Pandas, to be used later in grid cell evaluation

                source_cell_indices[source_cell_index] = source_cell_index
                source_cells[source_cell_index] = source_cell
                target_virtual_rp[0][source_cell_index] = r1
                target_virtual_rp[1][source_cell_index] = r2
                target_virtual_rp[2][source_cell_index] = r3
                target_ring_cells[0][source_cell_index] = target_cell_r1
                target_ring_cells[1][source_cell_index] = target_cell_r2
                target_ring_cells[2][source_cell_index] = target_cell_r3
                target_ring_cell_indices[0][source_cell_index] = target_cell_index_r1
                target_ring_cell_indices[1][source_cell_index] = target_cell_index_r2
                target_ring_cell_indices[2][source_cell_index] = target_cell_index_r3

    # One row per input device: its virtual RP coordinates and the excitatory cell it drives on each ring
    input_device_to_ring_cells_connections = pd.DataFrame({ 'Source Device Index': source_cell_indices,
                                                            'Source Device': source_cells,
                                                            'Ring 1 Virtual RP' : target_virtual_rp[0],
                                                            'Ring 2 Virtual RP' : target_virtual_rp[1],
                                                            'Ring 3 Virtual RP' : target_virtual_rp[2],
                                                            'Ring 1 Index': target_ring_cell_indices[0],
                                                            'Ring 2 Index': target_ring_cell_indices[1],
                                                            'Ring 3 Index': target_ring_cell_indices[2],
                                                            'Ring 1 Cell': target_ring_cells[0],
                                                            'Ring 2 Cell': target_ring_cells[1],
                                                            'Ring 3 Cell': target_ring_cells[2],})

    ## Record spike activity

    # One spike detector per population per ring, each connected to every cell of that
    # population. 'withgid'/'withtime' make the detector log the sender (global neuron id)
    # and the spike time, in milliseconds, for each event.

    exc_spikes = []
    inh_spikes = []
    pa_spikes = []
    left_cj_spikes = []
    right_cj_spikes = []

    for i in range(rings):

        # Excitatory ring cells
        exc_spikes.append(sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True}))
        sim.Connect(exc[i],exc_spikes[i])

        # Inhibitory ring cells
        inh_spikes.append(sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True}))
        sim.Connect(inh[i],inh_spikes[i])

        # PA (RP) cells of this ring
        pa_spikes.append(sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True}))
        sim.Connect(pa_cells[i],pa_spikes[i])

        # Left conjunctive cells (population `l`)
        left_cj_spikes.append(sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True}))
        sim.Connect(l[i],left_cj_spikes[i])

        # Right conjunctive cells (population `r`)
        right_cj_spikes.append(sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True}))
        sim.Connect(r[i],right_cj_spikes[i])

    # A single detector covering all pyramidal cells
    pyramidal_spikes = sim.Create("spike_detector", 1, params={"withgid": True,"withtime": True})
    sim.Connect(pyramidal_cells,pyramidal_spikes)

    ## Input to the network

    ## PredNet representations

    # Load the saved PredNet representations for each test set
    # (data/NRP_reps/<folder>/both/representations.npy) and stack them into one array.

    representation_root_folder = 'data/NRP_reps/'

    representation_folders = [ "playground_ordered_testset1",
                               "playground_ordered_testset2",
                               "playground_ordered_testset3",
                               "playground_ordered_testset4",
                               "playground_ordered_testset5",
                               "playground_ordered_testset6" ]

    representations = [np.load(representation_root_folder + folder + "/both/representations.npy") for folder in representation_folders]

    # Number of representations (rows along axis 0) in each test set; the 'rat' input branch
    # uses this to trim each test set's matched-pose timestamps to the same count
    representations_per_dataset = [reps.shape[0] for reps in representations]

    # Concatenate along axis 0. Assumes each file holds a 2-D (n_representations, features)
    # array so that representations[k] is one representation vector -- TODO confirm
    representations = np.vstack(representations)

    ### Position data

    # Two options for getting input data for the network. In all cases, position data is
    # loaded/generated here, and velocity is derived from it further below.
    # 'spiral': synthetic outward spiral trajectory
    # 'rat': the NRP_2021 test-set trajectories (raw_pose.csv), concatenated into a single
    #        sequence with monotonically increasing timestamps
    # Both branches define `pos_x`, `pos_y`, `time` (milliseconds) and `timestamps`
    # (seconds); the 'rat' branch additionally defines `representation_times`.

    if spiral_or_rat == 'spiral':

        number_of_turns = 300
        numT = number_of_turns * 1000 * np.pi
        print(numT/1000)
        dt = 20
        t = np.arange(0,sim_len,dt)*1.
        time = [i * 1. for i in t if i < sim_len]
        ts = np.arange(0,numT,numT/len(t))/1000.
        V = 30
        dr = 5
        # Archimedean spiral: ra = (dr / 2*pi) * |ph|, i.e. the radius grows by dr per turn
        ph = -np.sqrt(((V * (4*np.pi) * ts) / dr))
        ra =  np.sqrt(((V * dr * ts) / np.pi))

        pos_x = ra * np.cos(ph)
        pos_y = ra * np.sin(ph)

        # Previously only the 'rat' branch set `timestamps`, yet the simulation state arrays
        # are sized with len(timestamps). Define it here too, in seconds, matching the
        # 'rat' convention time = timestamps * 1000.
        timestamps = np.asarray(time) / 1000.

    elif spiral_or_rat == 'rat':

        # Roughly 92 20ms timesteps per representation

        start_sample = 0
        N_data_samples = 4000

        from scipy.ndimage import median_filter, gaussian_filter

        # Load rat trajectory data from file. Columns used below: 0 = timestamp
        # (converted from nanoseconds further down), 1 = x, 2 = y, 3 = theta.

        data_folders = ['NRP_2021_testset1', 'NRP_2021_testset2', 'NRP_2021_testset3',
                        'NRP_2021_testset4', 'NRP_2021_testset5', 'NRP_2021_testset6']

        rat_dataset = [np.loadtxt('data/NRP_data/{}/raw_pose.csv'.format(folder), skiprows = 1, delimiter = ',') for folder in data_folders]

        # Find where each dataset ends. Timesteps also reset at these points.

        dataset_ends = [len(testset) for testset in rat_dataset]

        # Row index of the last sample of each dataset in the combined array

        dataset_ends_cumulative = [sum(dataset_ends[:i]) - 1 for i in range(1,len(dataset_ends)+1)]

        # Apply the required position fixes as per the README

        rat_dataset[2][:,1] = rat_dataset[2][:,1] - 3 # Testset 3 needs adjusting -3m X and -6.1 Y
        rat_dataset[2][:,2] = rat_dataset[2][:,2] - 6.1 # Testset 3 needs adjusting -3m X and -6.1 Y
        rat_dataset[3][:,2] = rat_dataset[3][:,2] - 6.1 # Testset 4 needs adjusting -6.1 Y

        rat_dataset = np.vstack(rat_dataset)

        # Timestamp at which each dataset ends

        dataset_max_timestamps = rat_dataset[dataset_ends_cumulative, 0]

        # Running sum of those end timestamps: dataset k (k >= 1) is offset by the sum of the end
        # timestamps of datasets 0..k-1, giving a combined, monotonically increasing clock

        dataset_cumulative_max_timestamps = [sum(dataset_max_timestamps[:i]) for i in range(1,len(dataset_max_timestamps)+1)]

        rat_dataset_timesteps_in_sequence = rat_dataset.copy()

        rat_dataset_timesteps_in_sequence[dataset_ends_cumulative[0]+1:, 0] = np.hstack([rat_dataset[i+1:j+1, 0] + m for i, j, m in zip(dataset_ends_cumulative[:-1], dataset_ends_cumulative[1:], dataset_cumulative_max_timestamps[:-1])])

        assert np.all(rat_dataset_timesteps_in_sequence[1:, 0] > rat_dataset_timesteps_in_sequence[:-1, 0])

        # Now do the same for the representations' matched-pose timestamps, trimmed to the
        # number of representations loaded for each test set

        rat_representation_timestamps = [np.load('data/NRP_data/{}/representation_matched_poses.npy'.format(folder))[:rep_count, 0] for folder, rep_count in zip(data_folders, representations_per_dataset)]

        # Find where each dataset ends. Timesteps also reset at these points.

        dataset_ends = [len(testset) for testset in rat_representation_timestamps]

        # Row index of the last entry of each dataset in the combined array

        dataset_ends_cumulative = [sum(dataset_ends[:i]) - 1 for i in range(1,len(dataset_ends)+1)]

        rat_representation_timestamps = np.hstack(rat_representation_timestamps)

        # Timestamp at which each dataset ends

        dataset_max_timestamps = rat_representation_timestamps[dataset_ends_cumulative]

        # Offset each dataset by the summed end timestamps of the preceding datasets

        dataset_cumulative_max_timestamps = [sum(dataset_max_timestamps[:i]) for i in range(1,len(dataset_max_timestamps)+1)]

        rat_representations_timesteps_in_sequence = rat_representation_timestamps.copy()

        rat_representations_timesteps_in_sequence[dataset_ends_cumulative[0]+1:] = np.hstack([rat_representation_timestamps[i+1:j+1] + m for i, j, m in zip(dataset_ends_cumulative[:-1], dataset_ends_cumulative[1:], dataset_cumulative_max_timestamps[:-1])])

        assert np.all(rat_representations_timesteps_in_sequence[1:] > rat_representations_timesteps_in_sequence[:-1])

        # Restrict to the requested window of pose samples, and keep only representations whose
        # timestamps fall strictly inside that window

        rat_dataset_timesteps_in_sequence = rat_dataset_timesteps_in_sequence[start_sample:start_sample+N_data_samples]
        rat_representations_timesteps_in_sequence = rat_representations_timesteps_in_sequence[(rat_representations_timesteps_in_sequence > rat_dataset_timesteps_in_sequence[0,0]) & (rat_representations_timesteps_in_sequence < rat_dataset_timesteps_in_sequence[-1,0])]

        # Now find where injections are required: for each representation, the index of the
        # first pose sample whose timestamp is >= the representation's timestamp

        rat_injection_index = np.searchsorted(rat_dataset_timesteps_in_sequence[:, 0], rat_representations_timesteps_in_sequence)

        # Get rid of any redundant injections (only 2 at last check, so doesn't seem to be a systematic issue)
        # NOTE(review): this keeps entry k only when entry k+1 is larger, so besides collapsing
        # repeated indices it always discards the final injection index. `representations` is not
        # filtered in step with it -- confirm representation indexing still lines up.

        rat_injection_index = rat_injection_index[np.nonzero(np.diff(rat_injection_index) > 0)]

        assert np.all(rat_injection_index[1:] > rat_injection_index[:-1])

        # Create variables for velocity calculation later

        pos_x = rat_dataset_timesteps_in_sequence[:,1]
        #pos_x = gaussian_filter(pos_x, sigma = smooth_sigma, mode = 'nearest')
        pos_y = rat_dataset_timesteps_in_sequence[:,2]
        #pos_y = gaussian_filter(pos_y, sigma = smooth_sigma, mode = 'nearest')

        theta = rat_dataset_timesteps_in_sequence[:,3]

        # Time elapsed since the first sample of the window

        timestamps = rat_dataset_timesteps_in_sequence[:, 0] - np.min(rat_dataset_timesteps_in_sequence[:, 0])

        timestamps = timestamps / 1000000000 # Get from nanoseconds to seconds

        representation_timestamps = timestamps[rat_injection_index]

        time = timestamps * 1000 # Get from seconds to milliseconds

        representation_times = time[rat_injection_index]

    else:

        raise ValueError(f"spiral_or_rat must be 'spiral' or 'rat', got {spiral_or_rat!r}")

    ### Calculate velocity

    # Calculate velocity to convert to step current.
    # NOTE: unlike the head direction network, no boosting/thresholding of small velocities
    # is applied here (the `velocity_threshold` handling below is commented out); instead a
    # constant `minimum_input` is added to every conjunctive input.

    # Per-step displacement; one element shorter than pos_x / pos_y
    vel_x = np.diff(pos_x)
    vel_y = np.diff(pos_y)

    # Calibration mode: zero velocity everywhere, so every conjunctive input below reduces
    # to minimum_input
    if calibration_mode:

        vel_x = np.zeros_like(vel_x)
        vel_y = np.zeros_like(vel_y)

    # Scale displacements by the velocity gain I_vel
    vel_x,vel_y = vel_x*I_vel, vel_y*I_vel

    # Now we split this across the rings according to their direction of travel

    # Axes are:
    # Y, as usual
    # X_plus_60 (60 degree offset from Y around origin, diagonal bottom-right to upper-left)
    # X_plus_120 (120 degree offset from Y around origin, diagonal bottom-left to upper-right)
    # NOTE(review): velocity_angle = arctan2(vel_y, vel_x), so cos(velocity_angle) is the
    # component along +x even though this axis is called "Y" -- confirm the naming.

    velocity_magnitude = np.sqrt(vel_x ** 2 + vel_y ** 2)

    #if velocity_threshold:

    #    velocity_magnitude = velocity_magnitude + minimum_velocity

    #    velocity_magnitude[velocity_magnitude < minimum_velocity] = 0.

    velocity_angle = np.arctan2(vel_y, vel_x)

    # Signed velocity component along each ring's axis, for use in later analysis

    Y_input_total = velocity_magnitude * np.cos(velocity_angle)

    Y_plus_60_offset = np.radians(60)
    Y_plus_120_offset = np.radians(120)

    Y_plus_60_input_total = velocity_magnitude * np.cos(velocity_angle - Y_plus_60_offset)
    Y_plus_120_input_total = velocity_magnitude * np.cos(velocity_angle - Y_plus_120_offset)

    # Now split into positive and negative to feed to left and right conjunctive cells respectively.
    # For every axis below, with m = velocity_magnitude and c = cos(velocity_angle - axis offset):
    #   left  = m + m*max(c, 0)  + (m - m*max(-c, 0)) + minimum_input = 2*m + m*c + minimum_input
    #   right = m + m*max(-c, 0) + (m - m*max(c, 0))  + minimum_input = 2*m - m*c + minimum_input
    # so both sides are driven by overall speed, and left - right = 2*m*c is the signed
    # velocity along the axis.

    velocity_component = np.cos(velocity_angle)

    positive_component = velocity_component.copy()
    negative_component = velocity_component.copy()

    positive_component[positive_component < 0] = 0.
    negative_component[negative_component > 0] = 0.

    Y_input_l = velocity_magnitude * positive_component
    Y_input_r = velocity_magnitude * negative_component

    # Make the right-hand share non-negative
    Y_input_r = -Y_input_r

    Y_input_l_compliment = velocity_magnitude - Y_input_l
    Y_input_r_compliment = velocity_magnitude - Y_input_r

    Y_input_l = velocity_magnitude + Y_input_l + Y_input_r_compliment + minimum_input
    Y_input_r = velocity_magnitude + Y_input_r + Y_input_l_compliment + minimum_input


    # Same split for the axis offset by 60 degrees
    velocity_component_60 = np.cos(velocity_angle - Y_plus_60_offset)

    positive_component_60 = velocity_component_60.copy()
    negative_component_60 = velocity_component_60.copy()

    positive_component_60[positive_component_60 < 0] = 0.
    negative_component_60[negative_component_60 > 0] = 0.

    Y_plus_60_input_l = velocity_magnitude * positive_component_60 # np.cos(positive_angle - Y_plus_60_offset)
    Y_plus_60_input_r = velocity_magnitude * negative_component_60 # np.cos(negative_angle - Y_plus_60_offset)

    Y_plus_60_input_r = -Y_plus_60_input_r

    Y_plus_60_input_l_compliment = velocity_magnitude - Y_plus_60_input_l
    Y_plus_60_input_r_compliment = velocity_magnitude - Y_plus_60_input_r

    Y_plus_60_input_l = velocity_magnitude + Y_plus_60_input_l + Y_plus_60_input_r_compliment + minimum_input
    Y_plus_60_input_r = velocity_magnitude + Y_plus_60_input_r + Y_plus_60_input_l_compliment + minimum_input


    # Same split for the axis offset by 120 degrees
    velocity_component_120 = np.cos(velocity_angle - Y_plus_120_offset)

    positive_component_120 = velocity_component_120.copy()
    negative_component_120 = velocity_component_120.copy()

    positive_component_120[positive_component_120 < 0] = 0.
    negative_component_120[negative_component_120 > 0] = 0.

    Y_plus_120_input_l = velocity_magnitude * positive_component_120 # np.cos(positive_angle - Y_plus_120_offset)
    Y_plus_120_input_r = velocity_magnitude * negative_component_120 # np.cos(negative_angle - Y_plus_120_offset)

    Y_plus_120_input_r = -Y_plus_120_input_r

    Y_plus_120_input_l_compliment = velocity_magnitude - Y_plus_120_input_l
    Y_plus_120_input_r_compliment = velocity_magnitude - Y_plus_120_input_r

    Y_plus_120_input_l = velocity_magnitude + Y_plus_120_input_l + Y_plus_120_input_r_compliment + minimum_input
    Y_plus_120_input_r = velocity_magnitude + Y_plus_120_input_r + Y_plus_120_input_l_compliment + minimum_input


    # Drive the left/right conjunctive populations of each ring with step currents that follow
    # the per-axis velocity inputs. When `stochastic_input_current` is set, Gaussian noise is
    # first added to every input trace.

    if stochastic_input_current:

        N_samples = len(Y_input_l)

        # NOTE(review): mean(diff(x)) telescopes to (x[-1] - x[0]) / (N - 1), so this scale is
        # usually tiny and may be negative; a mean absolute step, e.g.
        # np.mean(np.abs(np.diff(Y_input_l))), may have been intended. abs() keeps the value
        # unchanged when it is non-negative and stops numpy rejecting a negative `scale`.
        input_current_sigma = abs(np.mean(np.diff(Y_input_l))) / 2

        # One independent noise draw per input trace. (Previously the ring-3 left trace was
        # perturbed twice and the ring-3 right trace not at all.)
        Y_input_l = Y_input_l + input_rng.normal(scale = input_current_sigma, size = N_samples)
        Y_input_r = Y_input_r + input_rng.normal(scale = input_current_sigma, size = N_samples)
        Y_plus_60_input_l = Y_plus_60_input_l + input_rng.normal(scale = input_current_sigma, size = N_samples)
        Y_plus_60_input_r = Y_plus_60_input_r + input_rng.normal(scale = input_current_sigma, size = N_samples)
        Y_plus_120_input_l = Y_plus_120_input_l + input_rng.normal(scale = input_current_sigma, size = N_samples)
        Y_plus_120_input_r = Y_plus_120_input_r + input_rng.normal(scale = input_current_sigma, size = N_samples)

    else:

        input_current_sigma = None

    # Velocities come from np.diff(pos), so value k applies from time[k + 1]; hence time[1:].

    # Connect y input to conjunctive layers (ring 0)

    y_l_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(y_l_input,{'amplitude_times': time[1:],'amplitude_values': Y_input_l})

    y_r_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(y_r_input,{'amplitude_times': time[1:],'amplitude_values': Y_input_r})

    sim.Connect(y_l_input,l[0],'all_to_all')
    sim.Connect(y_r_input,r[0],'all_to_all')

    # Connect y_plus_60 input to conjunctive layers (ring 1)

    Y_plus_60_l_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(Y_plus_60_l_input,{'amplitude_times': time[1:],'amplitude_values': Y_plus_60_input_l})
    Y_plus_60_r_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(Y_plus_60_r_input,{'amplitude_times': time[1:],'amplitude_values': Y_plus_60_input_r})

    sim.Connect(Y_plus_60_l_input,l[1],'all_to_all')
    sim.Connect(Y_plus_60_r_input,r[1],'all_to_all')

    # Connect y_plus_120 input to conjunctive layers (ring 2)

    Y_plus_120_l_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(Y_plus_120_l_input,{'amplitude_times': time[1:],'amplitude_values': Y_plus_120_input_l})
    Y_plus_120_r_input = sim.Create('step_current_generator', 1)
    sim.SetStatus(Y_plus_120_r_input,{'amplitude_times': time[1:],'amplitude_values': Y_plus_120_input_r})

    sim.Connect(Y_plus_120_l_input,l[2],'all_to_all')
    sim.Connect(Y_plus_120_r_input,r[2],'all_to_all')

    ### Bump-forming current generator

    # Inject current for a given duration to start the bump off: amplitude I_init from
    # t = 0.1 ms until 0.1 + I_init_dur, then 0. One generator drives the excitatory cell at
    # index I_init_pos on each of the three rings, so all rings start at the same position.

    bump_init = sim.Create('step_current_generator', 1, params = {'amplitude_times':[0.1,0.1+I_init_dur],'amplitude_values':[I_init,0.0]})
    sim.Connect(bump_init,[exc[0][I_init_pos]])
    sim.Connect(bump_init,[exc[1][I_init_pos]])
    sim.Connect(bump_init,[exc[2][I_init_pos]])

    ## Run simulation in tandem with representation corrections

    # Get the memory set up and data loaded
    
    sense_memories = [] # Save representations from PredNet
    location_memories = [] # Save ring coordinates from NEST

    # Minimum Pearson correlation between the current representation and a stored one for
    # that memory to be recalled (compared with best_match_value in the simulation loop)
    recall_threshold = 0.8

    # Excitatory spike detectors are reset (n_events = 0) every this many ticks
    history_timestep_window = 5

    # Run simulation in small steps, injecting when required

    # From the docs: Calling SetStatus() inside a RunManager() context or between Prepare() and Cleanup() will lead to unpredictable results.

    # Per-step run durations (ms). NOTE(review): astype(int) truncates fractional deltas, so the
    # simulated clock can fall behind `time` over many steps -- confirm this is acceptable.
    time_deltas = np.diff(time).astype(int)

    # Excitatory state history is only recorded once t exceeds this value (ms)
    start_delay = 100

    no_spike_timeout = 5

    # Counters for recalled memories and newly created memories
    recalls = 0
    new_memories = 0

    correction_duration = 500. # 500.
    correction_current = 450. # 450 is the sweet spot to apply corrections without wild spinning of the bump
    confidence_scaling = False # Multiply correction_current by Pearson correlation

    # 'mono' drives only the recalled input device; 'gaussian' also drives its two neighbours
    injection_type = "mono"

    # Latest per-cell excitatory spike counts per ring, and their history per tick
    current_exc_state = np.zeros(shape = (rings, N_ex))
    exc_state_history = np.zeros(shape = (len(timestamps), rings, N_ex))

    # Latest per-cell PA (RP) spike counts per ring
    current_rp_state = np.zeros(shape = (rings, N_pa_cells_per_ring))

    # Per tick, excitatory cells targeted by a corrective injection are marked with 1
    injection_state_history = np.zeros_like(exc_state_history)

    memory_state_history = []#pd.DataFrame(columns = ("Timestep", "Memory Created", "Recalled Memory", 
                           #                        "Active Ring 1 RP", "Active Ring 2 RP", "Active Ring 3 RP",
                           #                        "Recalled Ring 1 RP", "Recalled Ring 2 RP", "Recalled Ring 3 RP"))

    # Index into `representations` / `representation_times` for the current correction step
    representation_index = 0

    # NEST id of the input device to drive at the next correction; None until a recall selects one
    input_device = None

    if simulate_or_load == 'simulate':

        for time_delta, t, tick in zip(time_deltas, time, range(len(time))):

            if input_device is not None and corrections == True and np.all(None not in current_most_active_rp) and t == representation_times[representation_index]:

                if injection_type == 'mono':

                    if not confidence_scaling:

                        sim.SetStatus([input_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current, 0.0]})

                    elif confidence_scaling and best_match_value is not None:

                        sim.SetStatus([input_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current * best_match_value, 0.0]})

                        best_match_value == None

                elif injection_type == 'gaussian':

                    left_device = input_device - 1
                    right_device = input_device + 1

                    if left_device not in input_grid_devices:

                        left_device = input_grid_devices[-1]

                    if right_device not in input_grid_devices:

                        right_device = input_grid_devices[0]

                    if not confidence_scaling:

                        sim.SetStatus([input_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current, 0.0]})
                        sim.SetStatus([left_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current, 0.0]})
                        sim.SetStatus([right_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current, 0.0]})

                    elif confidence_scaling and best_match_value is not None:

                        sim.SetStatus([input_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current * best_match_value, 0.0]})
                        sim.SetStatus([left_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current * best_match_value, 0.0]})
                        sim.SetStatus([right_device], {'amplitude_times': [t, t+correction_duration],'amplitude_values': [correction_current * best_match_value, 0.0]})

                        best_match_value == None

                    input_device == None

            sim.Prepare()

            sim.Run(time_delta)

            ring1_exc, ring1_spikes_exc = np.unique(sim.GetStatus(exc_spikes[0])[0]['events']['senders'], return_counts = True)
            ring2_exc, ring2_spikes_exc = np.unique(sim.GetStatus(exc_spikes[1])[0]['events']['senders'], return_counts = True)
            ring3_exc, ring3_spikes_exc = np.unique(sim.GetStatus(exc_spikes[2])[0]['events']['senders'], return_counts = True)

            current_exc_state[0, ring1_exc-min(exc[0])] = ring1_spikes_exc if ring1_spikes_exc.size > 0 else 0
            current_exc_state[1, ring2_exc-min(exc[1])] = ring2_spikes_exc if ring2_spikes_exc.size > 0 else 0
            current_exc_state[2, ring3_exc-min(exc[2])] = ring3_spikes_exc if ring3_spikes_exc.size > 0 else 0

            ring1_most_active_exc_index = np.argmax(current_exc_state[0, :]) if np.argmax(current_exc_state[0, :]) is not None else None
            ring2_most_active_exc_index = np.argmax(current_exc_state[1, :]) if np.argmax(current_exc_state[1, :]) is not None else None
            ring3_most_active_exc_index = np.argmax(current_exc_state[2, :]) if np.argmax(current_exc_state[2, :]) is not None else None

            current_most_active_exc = (ring1_most_active_exc_index, ring2_most_active_exc_index, ring3_most_active_exc_index)

            if np.all(None not in current_most_active_exc) and t > start_delay:

                exc_state_history[tick, 0, ring1_exc-min(exc[0])] = ring1_spikes_exc
                exc_state_history[tick, 1, ring2_exc-min(exc[1])] = ring2_spikes_exc
                exc_state_history[tick, 2, ring3_exc-min(exc[2])] = ring3_spikes_exc

            ring1_rp, ring1_spikes_rp = np.unique(sim.GetStatus(pa_spikes[0])[0]['events']['senders'], return_counts = True)
            ring2_rp, ring2_spikes_rp = np.unique(sim.GetStatus(pa_spikes[1])[0]['events']['senders'], return_counts = True)
            ring3_rp, ring3_spikes_rp = np.unique(sim.GetStatus(pa_spikes[2])[0]['events']['senders'], return_counts = True)

            current_rp_state[0, ring1_rp-min(pa_cells[0])] = ring1_spikes_rp if ring1_spikes_rp.size > 0 else 0
            current_rp_state[1, ring2_rp-min(pa_cells[1])] = ring2_spikes_rp if ring2_spikes_rp.size > 0 else 0
            current_rp_state[2, ring3_rp-min(pa_cells[2])] = ring3_spikes_rp if ring3_spikes_rp.size > 0 else 0

            ring1_most_active_rp_index = np.argmax(current_rp_state[0, :]) if np.argmax(current_rp_state[0, :]) is not None else None
            ring2_most_active_rp_index = np.argmax(current_rp_state[1, :]) if np.argmax(current_rp_state[1, :]) is not None else None
            ring3_most_active_rp_index = np.argmax(current_rp_state[2, :]) if np.argmax(current_rp_state[2, :]) is not None else None

            current_most_active_rp = (ring1_most_active_rp_index, ring2_most_active_rp_index, ring3_most_active_rp_index)

            recalled_ring_state_r1, recalled_ring_state_r2, recalled_ring_state_r3 = -1, -1, -1

            sim.Cleanup()

            # Reset spike detectors so that per-time-delta spike counts are recorded only

            if tick % history_timestep_window == 0:

                sim.SetStatus(exc_spikes[0], {'n_events': 0})
                sim.SetStatus(exc_spikes[1], {'n_events': 0})
                sim.SetStatus(exc_spikes[2], {'n_events': 0})

            if corrections == True and np.all(None not in current_most_active_rp) and t == representation_times[representation_index]:

                if len(sense_memories) == 0:

                    sense_memories.append(representations[representation_index])

                if len(location_memories) == 0:

                    location_memories.append(current_most_active_rp)
                    #location_memories.append((0,0,0))

                    memory_state_history.append({"Timestep": tick, "Memory Created": new_memories, "Active Ring 1": current_most_active_rp[0], 
                                                                        "Active Ring 2": current_most_active_rp[1], "Active Ring 3": current_most_active_rp[2]})

                pearson_coorelation = np.empty(shape = (len(sense_memories)))

                for sm, sense_memory in enumerate(sense_memories):

                    pearson_coorelation[sm] = pearsonr(representations[representation_index], sense_memory)[0]

                best_match_index = np.argmax(pearson_coorelation) if pearson_coorelation.size > 0 else None

                best_match_element = sense_memories[best_match_index] if best_match_index is not None else None

                best_match_value = pearson_coorelation[best_match_index] if best_match_index is not None else None


                if best_match_value is None and timestamp > start_delay:

                    no_spike_timeout -= 1


                elif best_match_value >= recall_threshold and t > representation_times[0]: # The 2nd condition is to prevent recalls happening immediately for the first (automatically added) memory. Recalls can still happen later to this memory as per usual

                    recalls += 1

                    recalled_ring_state_r1, recalled_ring_state_r2, recalled_ring_state_r3 = location_memories[best_match_index]

                    # recalled_ring_state_r1 += 1
                    # recalled_ring_state_r2 += 1
                    # recalled_ring_state_r3 += 1

    #                     ring1_injection_device = input_device_to_ring_cells_connections[pa_to_pyramidal_connections["Ring 1 Index"] == recalled_ring_state_r1]["Source Device"]
    #                     ring2_injection_device = input_device_to_ring_cells_connections[pa_to_pyramidal_connections["Ring 2 Index"] == recalled_ring_state_r2]["Source Device"]
    #                     ring3_injection_device = input_device_to_ring_cells_connections[pa_to_pyramidal_connections["Ring 3 Index"] == recalled_ring_state_r3]["Source Device"]

                    #try:

                    input_device = int(input_device_to_ring_cells_connections.query("`Ring 1 Virtual RP` == @recalled_ring_state_r1 and `Ring 2 Virtual RP` == @recalled_ring_state_r2 and `Ring 3 Virtual RP` == @recalled_ring_state_r3")['Source Device'].values[0])

                    #except:

                        #print(f"Query unsuccessful; likely corrective input targets do not exist on ring. Corrective input targets 1,2,3: {recalled_ring_state_r1},{recalled_ring_state_r2},{recalled_ring_state_r3}")

                    #print(f"Latest exc spike time: {sim.GetStatus(exc_spikes[0])[0]['events']['times'][-1]}")

                    target_exc_cell_r1 = int(input_device_to_ring_cells_connections.query("`Ring 1 Virtual RP` == @recalled_ring_state_r1 and `Ring 2 Virtual RP` == @recalled_ring_state_r2 and `Ring 3 Virtual RP` == @recalled_ring_state_r3")['Ring 1 Index'].values[0])
                    target_exc_cell_r2 = int(input_device_to_ring_cells_connections.query("`Ring 1 Virtual RP` == @recalled_ring_state_r1 and `Ring 2 Virtual RP` == @recalled_ring_state_r2 and `Ring 3 Virtual RP` == @recalled_ring_state_r3")['Ring 2 Index'].values[0])
                    target_exc_cell_r3 = int(input_device_to_ring_cells_connections.query("`Ring 1 Virtual RP` == @recalled_ring_state_r1 and `Ring 2 Virtual RP` == @recalled_ring_state_r2 and `Ring 3 Virtual RP` == @recalled_ring_state_r3")['Ring 3 Index'].values[0])

                    #print(f"Injecting with device {input_device} targetting RPs: {(recalled_ring_state_r1, recalled_ring_state_r2, recalled_ring_state_r3)} at time {t}, timestep: {tick} until {t+correction_duration}")

                    injection_state_history[tick:tick+int(correction_duration // 20), 0, target_exc_cell_r1] = 1
                    injection_state_history[tick:tick+int(correction_duration // 20), 1, target_exc_cell_r2] = 1
                    injection_state_history[tick:tick+int(correction_duration // 20), 2, target_exc_cell_r3] = 1

                    memory_state_history.append({"Timestep": tick, "Recalled Memory": best_match_index, "Recalled Ring 1": recalled_ring_state_r1, 
                                                                        "Recalled Ring 2": recalled_ring_state_r2, "Recalled Ring 3": recalled_ring_state_r3}) 

                elif best_match_value < recall_threshold:

                    new_memories += 1

                    sense_memories.append(representations[representation_index])

                    location_memories.append(current_most_active_rp)

                    memory_state_history.append({"Timestep": tick, "Memory Created": new_memories, "Active Ring 1": current_most_active_rp[0], 
                                                                        "Active Ring 2": current_most_active_rp[1], "Active Ring 3": current_most_active_rp[2]}) 

    #             else:

    #                 print("Help")

                if no_spike_timeout <= 0:

                    print(f"No spikes have been recorded for {no_spike_timeout} input cycles, stopping...")

                    break

                if representation_index < (len(representation_times) - 1):

                    representation_index += 1

            print(f"Timestep: {tick+1}/{len(time_deltas)}; Sim Time: {int(t)}; Ring State: {current_most_active_exc}; RP State: {current_most_active_rp}; Injection Sites: {(recalled_ring_state_r1, recalled_ring_state_r2, recalled_ring_state_r3)}; Memories Stored: {new_memories}; Recall events: {recalls}", end = '\r')

    # Persist the excitatory ring activity history; the filename records whether corrective input was enabled.
    exc_state_filename = "corrected_exc_state_history.npy" if corrections else "uncorrected_exc_state_history.npy"

    np.save(exc_state_filename, exc_state_history)


    # Decode the three rings' activity into an (x, y) trajectory (ring_xy), then rescale it onto the
    # value range of the ground-truth positions pos_x / pos_y.
    circular_mean = True # If False, ring position is taken to be the index of the most active cell

    # Metric computed further below: 'spearman', 'error' or 'distance'.
    metric_plot = 'distance'

    # 'mean' converts the cumulative error/distance series into a running mean.
    cumulative_or_mean = 'mean'
    

    # Cell positions in radians. Not used in the visible code below — presumably for plotting.
    # NOTE(review): `% N_ex` on radian values is a no-op whenever N_ex > 2*pi; `% (2*np.pi)` may have
    # been intended — confirm.
    cell_indices_for_plotting = (np.linspace(0, 2*np.pi, num = N_ex, endpoint = False) + ((2*np.pi) / N_ex)) % N_ex

    # Per timestep and ring: angle of the most active cell (same `% N_ex` caveat as above) and its activity.
    # Assumes exc_state_history is (timesteps, 3 rings, cells), as the [:, k, :] slicing below suggests — confirm.
    most_active_cell_indices = ((np.argmax(exc_state_history, axis = 2) + 1) / N_ex * 2*np.pi) % N_ex
    most_active_cell_values = np.max(exc_state_history, axis = 2)

    
    # ring_mean_activity is defined elsewhere in the file; the meaning of centre=True is not visible here.
    ring_1_mean_activity_index = ring_mean_activity(exc_state_history[:, 0, :], centre = True)
    ring_2_mean_activity_index = ring_mean_activity(exc_state_history[:, 1, :], centre = True)
    ring_3_mean_activity_index = ring_mean_activity(exc_state_history[:, 2, :], centre = True)

    # Unwrap each ring's position so it is continuous across the ring's wrap-around point.
    # NOTE(review): period is N_ex // pi (floor division), not N_ex. In the argmax branch the positions
    # are cell indices 0..N_ex-1, where a period of N_ex would be expected — confirm this matches the
    # units of ring_mean_activity and what ring2cart expects.
    if circular_mean:

        ring_1_unwrapped = np.unwrap(ring_1_mean_activity_index, period = N_ex // (1 * np.pi))
        ring_2_unwrapped = np.unwrap(ring_2_mean_activity_index, period = N_ex // (1 * np.pi))
        ring_3_unwrapped = np.unwrap(ring_3_mean_activity_index, period = N_ex // (1 * np.pi))

    else:

        ring_1_unwrapped = np.unwrap(np.argmax(exc_state_history[:, 0, :], axis = 1), period = N_ex // (1 * np.pi))
        ring_2_unwrapped = np.unwrap(np.argmax(exc_state_history[:, 1, :], axis = 1), period = N_ex // (1 * np.pi))
        ring_3_unwrapped = np.unwrap(np.argmax(exc_state_history[:, 2, :], axis = 1), period = N_ex // (1 * np.pi))

    # Map the three ring coordinates back to Cartesian coordinates. ring2cart is defined elsewhere;
    # presumably the inverse of cart2ring, whose offset is in degrees — confirm.
    ring_xy = ring2cart(ring_1_unwrapped, ring_2_unwrapped, ring_3_unwrapped, offset = -90)

    # Min-max normalise each decoded axis to [0, 1].
    # NOTE(review): an axis whose decoded value never changes divides by zero here (NaN results).
    ring_xy[:, 0] = (ring_xy[:, 0] - min(ring_xy[:, 0])) / (max(ring_xy[:, 0]) - min(ring_xy[:, 0]))
    ring_xy[:, 1] = (ring_xy[:, 1] - min(ring_xy[:, 1])) / (max(ring_xy[:, 1]) - min(ring_xy[:, 1]))

    # Stretch onto the ground-truth range, so only the shape of the trajectory (not its absolute
    # offset or scale) is compared against pos_x / pos_y.
    ring_xy[:, 0] = ring_xy[:, 0] * (max(pos_x) - min(pos_x)) + min(pos_x)
    ring_xy[:, 1] = ring_xy[:, 1] * (max(pos_y) - min(pos_y)) + min(pos_y)


    # Compare the decoded trajectory (ring_xy) with the ground-truth positions (pos_x, pos_y).
    # The Euclidean distance metric is computed unconditionally: the results CSV written afterwards
    # reads pointwise_distance_over_time / cumulative_distance_over_time whatever metric_plot is.
    pointwise_distance_over_time = np.sqrt((pos_x[:] - ring_xy[:, 0]) ** 2 + (pos_y[:] - ring_xy[:, 1]) ** 2)
    cumulative_distance_over_time = np.cumsum(pointwise_distance_over_time)

    if cumulative_or_mean == 'mean':

        # Running mean: divide the k-th partial sum by k.
        cumulative_distance_over_time = cumulative_distance_over_time / np.arange(1, len(cumulative_distance_over_time) + 1)

    if metric_plot == 'spearman':

        # Rank correlation of distance-from-origin, ignoring the first start_delay timesteps.
        ground_truth_vector_magnitude = np.sqrt(pos_x[start_delay:] ** 2 + pos_y[start_delay:] ** 2)
        estimated_vector_magnitude = np.sqrt(ring_xy[start_delay:,0] ** 2 + ring_xy[start_delay:,1] ** 2)

        # One spearmanr call per growing window (first 2, 3, ... samples), so each r and its p-value
        # describe exactly the same samples.
        spearman_over_time = [spearmanr(ground_truth_vector_magnitude[:i + 2], estimated_vector_magnitude[:i + 2], alternative = 'greater') for i in range(len(pos_x[start_delay:]) - 1)]
        spearman_r_over_time = [result[0] for result in spearman_over_time]
        spearman_p_over_time = [result[1] for result in spearman_over_time]

    elif metric_plot == 'error':

        # Manhattan (L1) error per timestep and its running sum (or mean).
        pointwise_error_over_time = np.abs(pos_x[:] - ring_xy[:, 0]) + np.abs(pos_y[:] - ring_xy[:, 1])
        cumulative_error_over_time = np.cumsum(pointwise_error_over_time)

        if cumulative_or_mean == 'mean':

            cumulative_error_over_time = cumulative_error_over_time / np.arange(1, len(cumulative_error_over_time) + 1)
            

    # Append this run's parameters and summary errors as one row of the shared results CSV.
    results_csv_path = "results_window_size_15.csv"

    # cumulative_distance_over_time is a running mean when cumulative_or_mean == 'mean' and a running
    # sum otherwise; normalise so the "Mean Error" column is a mean in both cases.
    mean_error = cumulative_distance_over_time[-1]

    if cumulative_or_mean != 'mean':

        mean_error = mean_error / len(cumulative_distance_over_time)

    # Built once so both branches write exactly the same columns.
    run_results = {"Master Seed": master_seed, "Membrane Seed": membrane_seed, "Input Seed": input_seed, "Input Sigma": input_current_sigma,
                   "Corrected":corrections, "Minimum Input": minimum_input, "Window Size": window_size, "RP Offset": rp_offset,
                   "Correction Duration": correction_duration, "Correction Current": correction_current, "Confidence Scaling": confidence_scaling,
                   "Recall Threshold": recall_threshold, "Memories Stored": len(sense_memories), "Recalls": recalls, 
                   "Mean Error": mean_error, "Final Position Error": pointwise_distance_over_time[-1]}

    if os.path.exists(results_csv_path):

        results_dataframe = pd.read_csv(results_csv_path)

        new_results = pd.DataFrame(run_results, index = [len(results_dataframe) + 1])

        results_dataframe = pd.concat([results_dataframe, new_results])

    else:

        results_dataframe = pd.DataFrame(run_results, index = [1])

    # The index is not written; rows are identified by their parameter columns.
    results_dataframe.to_csv(results_csv_path, index = False)