-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmazeAgentUtils_reviewResponses.py
2031 lines (1726 loc) · 92.9 KB
/
mazeAgentUtils_reviewResponses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from datetime import datetime
import numbers
from pprint import pprint as pprintq
import os
from scipy.stats import vonmises
from scipy.spatial import distance_matrix
import dill
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
import tomplotlib.tomplotlib as tpl
tpl.figureDirectory = './figures'
tpl.setColorscheme(colorscheme=2)
#Default parameters for MazeAgent.
#MazeAgent.__init__ copies every entry onto the instance as an attribute; entries of the
#user-supplied params dict then overwrite these values.
defaultParams = {
#Maze params
'mazeType' : 'oneRoom', #type of maze (the walls are built by getWalls(mazeType, roomSize))
'stateType' : 'gaussian', #feature on which to TD learn (onehot, gaussian, gaussianCS, gaussianThreshold, circles, bump, fourier)
'movementPolicy' : 'raudies', #movement policy (raudies, randomWalk, trueRandomWalk, leftRightRandomWalk, windowsScreensaver, 1DOrnUhl)
'roomSize' : 1, #maze size scaling parameter, metres
'dt' : None, #simulation time discretisation (None --> half of the smaller of the two STDP trace decay times)
'dx' : 0.01, #space discretisation (for plotting, movement is continuous)
'speedScale' : 0.16, #movement speed scale, metres/second
'rotSpeedScale' : None, #rotational speed scale, radians/second (None --> pi for loop/longCorridor mazes, 3*pi otherwise)
'initPos' : [0.1,0.1], #initial position [x0, y0], metres
'initDir' : [1,0], #initial direction, unit vector
'nCells' : None, #how many features to use (None --> derived from maze area and sigma when centres is also None)
'centres' : None, #array of receptive field positions. Overwrites nCells
'sigma' : 1, #basis cell width scale (irrelevant for onehots)
'doorsClosed' : True, #whether doors are opened or closed in multicompartment maze
'reorderCells' : True, #whether to reorder the cell centres which have been provided
'firingRateLookUp' : False, #use quantised lookup table for firing rates
'biasDoorCross' : False, #if True, in twoRooms maze the agent's turning is biased towards crossing the door
'biasWallFollow' : True, #if True, agent aligns to wall when gets too near.
#TD params
'tau' : 4, #TD decay time, seconds
'TDdx' : 0.01, #rough distance between TD learning updates, metres
'alpha' : 0.01, #TD learning rate (runRat also tries to read it as a [start, end] pair for a decaying rate)
'successorFeatureNorm': 100, #linear scaling on successor feature definition found to improve learning stability
'TDreg' : 0.01, #L2 regularisation
#STDP params
'peakFiringRate' : 5, #peak firing rate of a cell (middle of place field,preferred theta phase)
'tau_STDP_plus' : 20e-3, #pre trace decay time
'tau_STDP_minus' : 40e-3, #post trace decay time
'a_STDP' : -0.4, #amount added to a cell's post trace each time it spikes (the pre trace is incremented by 1)
'eta' : 0.05, #STDP learning rate
'baselineFiringRate' : 0, #baseline firing rate for cells
'use_full_STDP_rule' : False, #whether to use full STDP rule (STDPLearningStep_detailed instead of STDPLearningStep)
'online_mapping' : 'identity', #how to map CA3-->CA1 during learning ('identity', 'Widentity', 'W', or a matrix)
'rownorm' : False, #if True, the detailed STDP step rescales weight rows whose sum exceeds 1
#Theta precession params
'thetaFreq' : 10, #theta frequency
'precessFraction' : 0.5, #fraction of 2pi the preferred phase moves through
'kappa' : 1, # von mises spread parameter
}
class MazeAgent():
"""MazeAgent defines an agent moving around a maze.
The agent moves according to a predefined movement policy.
As the agent moves it learns a successor representation over state vectors according to a TD learning rule.
The movement policy is
(i) continuous in space. There is no discretisation of location. Time is discretised into steps of dt
(ii) completely decoupled from the TD learning.
TD learning is
(i) state general, i.e. it learns generic SRs for feature vectors which are not necessarily onehot. See de Cothi and Barry, 2020
(ii) time continuous. Defined in terms of a memory decay time tau, not unitless gamma. Any two states can be used for a TD learning step irrespective of their separation in time.
As the rat moves and learns, its position and time stamps are continually saved. Periodically a snapshot of the current SR matrix and the state of other parameters in the maze is also saved.
"""
def __init__(self,
             params=None,
             loadFromFileCalled=None):
    """Sets the parameters of the maze and agent (using defaults if not provided)
    and initialises everything. This includes:
    •initialising history dataframes
    •making the maze (a dictionary of "walls" which can't be crossed)
    •setting position, velocity, time
    •discretising space into coordinates for later plotting
    •initialising basis features (gaussian centres, fourier frequencies etc.)
    •initialising SR matrix

    Args:
        params (dict, optional): A dictionary of parameters which you want to differ from the default.
            Defaults to None, which is treated as an empty dictionary.
        loadFromFileCalled (optional): If not None, it is forwarded as `name` to self.loadFromFile()
            and no parameter setting or initialisation is performed here. Defaults to None.
    """
    if loadFromFileCalled is not None:
        self.loadFromFile(name=loadFromFileCalled)
    else:
        print("Setting parameters")
        #start from the module-level defaults...
        for key, value in defaultParams.items():
            setattr(self, key, value)
        #...then apply the user overrides (None default avoids a shared mutable default argument)
        self.updateParams({} if params is None else params)
        print("Initialising")
        self.initialise()
        print("DONE")
def updateParams(self,
params : dict):
"""Updates parameters from a dictionary.
All parameters found in params will be updated to new value
Args:
params (dict): dictionary of parameters to change
initialise (bool, optional): [description]. Defaults to False.
"""
for key, value in params.items():
setattr(self, key, value)
def initialise(self): #should only be called once at the start
    """Initialises the maze and agent. Should only be called once at the start.

    Reads the parameter attributes already set on the instance and builds, in order:
    history/snapshot containers, position/direction/time, the maze walls and extent,
    the discretised coordinate grid, defaults for any parameters left as None,
    the basis features with the successor/weight matrices, the state vector at every
    grid position, a time-zero snapshot and the STDP traces.

    Raises:
        ValueError: if `sigma` is a sequence whose length differs from nCells.
    """
    #initialise history dataframes
    print(" making state/history dataframes")
    self.mazeState = {}
    self.history = pd.DataFrame(columns = ['t','pos','delta','runID'])
    self.snapshots = pd.DataFrame(columns = ['t','M','W','mazeState'])
    self.spikedata = {'CA3':{'times':[],'ids':[]}, 'CA1':{'times':[],'ids':[]}}

    #set pos/vel
    print(" initialising velocity, position and direction")
    #np.array(None) is a 0-d object array, never None, so None must be detected BEFORE
    #conversion or the fallbacks further down can never trigger.
    #float dtype so later in-place writes into self.pos are not truncated to integers.
    self.pos = None if self.initPos is None else np.array(self.initPos, dtype=float)
    self.speed = self.speedScale
    self.dir = None if self.initDir is None else np.array(self.initDir, dtype=float)

    #time and runID
    print(" setting time/run counters")
    self.t = 0
    self.runID = 0
    self.thetaPhase = self.thetaFreq*(self.t%(1/self.thetaFreq))*2*np.pi

    #make maze
    print(" making the maze walls")
    self.walls = getWalls(mazeType=self.mazeType, roomSize=self.roomSize)
    walls = self.walls.copy()
    if self.doorsClosed == False:
        walls.pop('doors', None) #mazes without a 'doors' entry are left unchanged
    self.mazeState['walls'] = walls

    #extent, xArray, yArray, discreteCoords
    print(" discretising position for later plotting")
    binsAlongRoom = round(self.roomSize / self.dx)
    if abs((self.roomSize / self.dx) - binsAlongRoom) > 0.00001:
        print(" dx must be an integer fraction of room size, setting it to %.4f, %g along room length" %(self.roomSize / binsAlongRoom, binsAlongRoom))
        self.dx = self.roomSize / binsAlongRoom
    #bounding box over every wall segment (it always contains the origin because it starts at 0)
    minx, maxx, miny, maxy = 0, 0, 0, 0
    for room in self.walls:
        wa = self.walls[room]
        minx, maxx = min(minx,np.min(wa[...,0])), max(maxx,np.max(wa[...,0]))
        miny, maxy = min(miny,np.min(wa[...,1])), max(maxy,np.max(wa[...,1]))
    self.extent = np.array([minx,maxx,miny,maxy])
    self.width = maxx-minx
    self.height = maxy-miny
    self.xArray = np.arange(minx + self.dx/2, maxx, self.dx)
    self.yArray = np.arange(miny + self.dx/2, maxy, self.dx)[::-1]
    x_mesh, y_mesh = np.meshgrid(self.xArray,self.yArray)
    coordinate_mesh = np.array([x_mesh, y_mesh])
    self.discreteCoords = np.swapaxes(np.swapaxes(coordinate_mesh,0,1),1,2) #an array of discretised position coords over entire map extent
    self.mazeState['extent'] = self.extent

    #handle None params
    print(" handling undefined parameters")
    if self.dt is None:
        self.dt = min(self.tau_STDP_plus,self.tau_STDP_minus) / 2
    if self.pos is None:
        ex = self.extent
        self.pos = np.array([ex[0] + 0.2*(ex[1]-ex[0]),ex[2] + 0.2*(ex[3]-ex[2])])
    if self.dir is None:
        if self.mazeType == 'longCorridor':
            self.dir = np.array([0,1])
        elif self.mazeType == 'loop':
            self.dir = np.array([1,0])
        else:
            self.dir = np.array([1,1]) / np.sqrt(2)
    if self.rotSpeedScale is None:
        if self.mazeType == 'loop' or self.mazeType == 'longCorridor':
            self.rotSpeedScale = np.pi
        else:
            self.rotSpeedScale = 3*np.pi
    if (self.nCells is None) and (self.centres is None):
        ex = self.extent
        area, pcarea = (ex[1]-ex[0])*(ex[3]-ex[2]), np.pi * ((self.sigma/2)**2)
        cellsPerArea = 10
        self.nCells = int(cellsPerArea * area / pcarea) #~10 in any given place
    if self.mazeType == 'TMaze':
        self.LRDecisionPending=True
    #movement-policy bookkeeping read by movementPolicyUpdate
    self.doorPassage = False
    self.doorPassageTime = 0
    self.lastTurnUpdate = -1
    self.randomTurnSpeed = 0

    #initialise basis cells and M (successor matrix)
    print(" initialising basis features for learning")
    if self.stateType in ['gaussian', 'gaussianCS','gaussianThreshold', 'circles','onehot','bump']:
        if self.centres is not None: #centres were provided: they define the number of cells
            self.centres = np.asarray(self.centres)
            self.nCells = self.centres.shape[0]
            self.stateSize = self.nCells
        else: #scatter some ourselves (making sure they aren't too close)
            self.stateSize=self.nCells
            xcentres = np.random.uniform(self.extent[0],self.extent[1],self.nCells)
            ycentres = np.random.uniform(self.extent[2],self.extent[3],self.nCells)
            self.centres = np.array([xcentres,ycentres]).T
            inds = self.centres[:,0].argsort()
            self.centres = self.centres[inds]
            print(" checking basis cells aren't too close")
            min_d = 0.1/0.9
            done = False
            while done != True:
                #each outer pass relaxes the minimum separation by 10% and allows up to 11 resampling rounds
                min_d *= 0.9
                print(" min seperation distance: %.1f cm" %(min_d*100))
                count = 0
                while count <= 10:
                    d = distance_matrix(self.centres,self.centres)
                    d += 0.1*np.eye(d.shape[0]) #lift the zero diagonal so a cell never "overlaps" itself
                    d_xid, d_yid = np.where(d < min_d)
                    print(' ',int(len(d_xid)/2),' overlapping pairs',end='\n')
                    if len(d_xid) == 0:
                        done = True
                        break
                    #d is symmetric: resample only the lower-index member of each close pair
                    to_remove = np.unique(d_xid[d_xid < d_yid])
                    xcentres = np.random.uniform(self.extent[0],self.extent[1],len(to_remove))
                    ycentres = np.random.uniform(self.extent[2],self.extent[3],len(to_remove))
                    self.centres[to_remove] = np.array([xcentres,ycentres]).T
                    count += 1
        self.M = np.eye(self.stateSize)
        self.W = self.M.copy() / self.nCells
        self.M_theta = self.M.copy()
        self.W_notheta = self.W.copy()
        #order the place cells so successor matrix has some structure:
        if self.reorderCells==True:
            if self.mazeType == 'twoRooms': #from centre outwards
                middle = np.array([self.extent[1]/2,self.extent[3]/2])
                distance_to_centre = np.linalg.norm(middle - self.centres,axis=1)
                #sign the distance by which side of the middle the cell lies on
                distance_to_centre = distance_to_centre * (2*(self.centres[:,0]>middle[0])-1)
                inds = distance_to_centre.argsort()
                self.centres = self.centres[inds]
            else: #from left to right
                inds = self.centres[:,0].argsort()
                self.centres = self.centres[inds]
    elif self.stateType == 'fourier':
        self.stateSize = self.nCells
        self.kVectors = np.random.rand(self.nCells,2) - 0.5
        self.kVectors /= np.linalg.norm(self.kVectors, axis=1)[:,None]
        self.kFreq = 2*np.pi / np.random.uniform(0.01,1,size=(self.nCells))
        self.phi = np.random.uniform(0,2*np.pi,size=(self.nCells))
        self.M = np.eye(self.stateSize)
        #the time-zero snapshot below reads W and W_notheta, so they must exist for this state type too
        self.W = self.M.copy() / self.nCells
        self.M_theta = self.M.copy()
        self.W_notheta = self.W.copy()

    #per-cell widths: a scalar sigma is broadcast, a sequence must provide one value per cell
    if hasattr(self.sigma,"__len__"):
        if len(self.sigma) != self.nCells:
            raise ValueError("sigma has length %d but there are %d cells" %(len(self.sigma), self.nCells))
        self.sigmas = self.sigma
    else:
        self.sigmas = np.array([self.sigma]*self.nCells)

    #array of states, one for each discretised position coordinate
    print(" calculating state vector at all discretised positions")
    self.statesAlreadyInitialised = False
    self.discreteStates = self.positionArray_to_stateArray(self.discreteCoords,stateType=self.stateType,verbose=True) #state vector at every discretised position
    self.statesAlreadyInitialised = True

    #store time zero snapshot (pd.concat because DataFrame.append no longer exists in pandas >= 2.0)
    snapshot = pd.DataFrame({'t':[self.t], 'M': [self.M.copy()], 'W': [self.W.copy()],'W_notheta': [self.W_notheta.copy()], 'mazeState':[self.mazeState]})
    self.snapshots = pd.concat([self.snapshots, snapshot])

    #STDP stuff
    print(" initialising STDP weight matrix and traces")
    self.preTrace = np.zeros(self.nCells) #causes potentiation
    self.preTrace_notheta = np.zeros(self.nCells) #causes potentiation
    self.postTrace = np.zeros(self.nCells) #causes depression
    self.postTrace_notheta = np.zeros(self.nCells) #causes depression
    #0-d arrays (not plain numbers) so the STDP steps can update them in place through a reference
    self.lastSpikeTime = np.array(-10.0)
    self.lastSpikeTime_notheta = np.array(-10.0)
    self.spikeCount = np.array(0)
    self.spikeCount_notheta = np.array(0)
def runRat(self,
           trainTime=10,
           saveEvery=0.5,
           TDSRLearn=True,
           STDPLearn=True):
    """The main experiment call.
    A "run" consists of a period where the agent explores the maze according to the movement policy.
    As it explores it learns, by TD, a successor representation over state vectors.
    This can be called multiple times. Each call increments self.runID and appends its trajectory to self.history.
    Snapshots of the current SR matrix and mazeState can be saved along the way.
    Runs can be interrupted with KeyboardInterrupt, data will still be saved.

    Args:
        trainTime (int, optional): How long to explore in minutes. Defaults to 10.
        saveEvery (float, optional): Interval between snapshots, in minutes. Non-numeric or
            non-positive values disable periodic snapshots. Defaults to 0.5.
        TDSRLearn (bool, optional): toggles whether to do TD learning
        STDPLearn (bool, optional): toggles whether to do STDP learning
    """
    steps = int(trainTime * 60 / self.dt) #number of steps to perform
    hist_t = np.zeros(steps)
    hist_pos = np.zeros((steps,2))
    hist_delta = np.zeros(steps)
    lastTDstep, distanceToTD = 0, np.random.exponential(self.TDdx) #distance to next TD update, scale TDdx

    #number of steps between snapshots; at least 1 so the modulo below can never divide by zero
    saveInterval = None
    if isinstance(saveEvery, numbers.Number) and saveEvery > 0:
        saveInterval = max(1, int(saveEvery * 60 / self.dt))

    def currentSnapshot():
        #one-row dataframe holding copies of the matrices at the current time
        return pd.DataFrame({'t':[self.t], 'M': [self.M.copy()], 'W': [self.W.copy()], 'W_notheta':[self.W_notheta.copy()], 'mazeState':[self.mazeState]})

    STDPStateTypes = ['bump','gaussian', 'gaussianCS','gaussianThreshold', 'circles']
    completedSteps = 0 #steps fully recorded; used instead of the loop index so the last step is not dropped

    #Main training loop. Principally on each iteration:
    # • always updates motion policy
    # • often does TD learning step
    # • sometimes saves snapshot
    for i in tqdm(range(steps)):
        try:
            #update pos, velocity, direction and time according to movement policy
            self.movementPolicyUpdate()

            if i > 1:
                #STDP learning step
                if STDPLearn and (self.stateType in STDPStateTypes):
                    stepDt = self.t - hist_t[i-1]
                    if self.use_full_STDP_rule:
                        self.STDPLearningStep_detailed(dt=stepDt)
                    else:
                        self.STDPLearningStep(dt=stepDt)

                #TD learning step
                if TDSRLearn:
                    try:
                        #alpha given as [start, end]: decay geometrically from start to end over the run
                        alpha_ = self.alpha[0] * np.exp(-(i/steps)*(np.log(self.alpha[0]/self.alpha[1])))
                    except (TypeError, IndexError):
                        #scalar alpha: constant learning rate
                        alpha_ = self.alpha
                    if np.linalg.norm(self.pos - hist_pos[lastTDstep]) >= distanceToTD: #moved far enough since the last TD step
                        dtTD = self.t - hist_t[lastTDstep]
                        delta = self.TDLearningStep(pos=self.pos, prevPos=hist_pos[lastTDstep], dt=dtTD, tau=self.tau, alpha=alpha_)
                        lastTDstep = i
                        distanceToTD = np.random.exponential(self.TDdx)
                        hist_delta[i] = delta

            self.thetaPhase = self.thetaFreq*(self.t%(1/self.thetaFreq))*2*np.pi #theta phase at the new time

            #update history arrays
            hist_pos[i] = self.pos
            hist_t[i] = self.t

            #save snapshot
            if (saveInterval is not None) and (i % saveInterval == 0):
                self.snapshots = pd.concat([self.snapshots, currentSnapshot()])

            completedSteps = i + 1
        except KeyboardInterrupt:
            print("Keyboard Interrupt:")
            break

    self.runID += 1
    runHistory = pd.DataFrame({'t':list(hist_t[:completedSteps]), 'pos':list(hist_pos[:completedSteps]),'delta':list(hist_delta[:completedSteps])})
    #pd.concat because DataFrame.append no longer exists in pandas >= 2.0
    self.history = pd.concat([self.history, runHistory])
    self.snapshots = pd.concat([self.snapshots, currentSnapshot()])

    #find and save grid/place cells so you don't have to repeatedly calculate them when plotting
    print("Calculating place and grid cells")
    self.gridFields = self.getGridFields(self.M)
    self.placeFields = self.getPlaceFields(self.M)

    if TDSRLearn:
        #plot the size of the TD updates over the run (steps without a TD update hold 0 and are dropped)
        delta = np.array(hist_delta)
        time = np.array(hist_t)
        time = time[delta!=0] / 60
        delta = delta[delta!=0]
        time, delta = time[::10], delta[::10]
        smooth_delta = [np.mean(delta[max(0,j-100):min(j+100,len(delta))]) for j in range(len(delta))]
        fig, ax = plt.subplots(figsize=(2,1))
        ax.scatter(time,delta,s=0.5,alpha=0.5)
        ax.scatter(time,smooth_delta,s=1,alpha=0.5,c='C2')
        ax.set_xlabel("Time / min")
        ax.set_ylabel("Update size")
def TDLearningStep(self, pos, prevPos, dt, tau, alpha):
    """TD learning step
    Improves the estimates of the SR matrices self.M (raw states) and self.M_theta
    (theta-modulated states) by one TD learning step each. Both matrices are updated in place.
    By default this is done using learning rule for generic feature vectors (see de Cothi and Barry 2020).
    If stateType is onehot, a onehot specific learning rule is used instead (see Stachenfeld et al. 2017)
    Does time continuous TD learning (see Doya, 2000)

    Args:
        pos (array): position at the later time
        prevPos (array): position at the earlier time
        dt (float): time difference between two positions
        tau (float or int): memory decay time (analogous to gamma in TD, gamma = 1 - dt/tau)
        alpha (float): learning rate

    Returns:
        float: norm of the update applied to self.M (the update to self.M_theta is not reported)
    """
    state = self.posToState(pos,stateType=self.stateType)
    prevState = self.posToState(prevPos,stateType=self.stateType)

    #(state, previous state, matrix to update) for the raw and the theta-modulated features.
    #The previous theta state is built from prevState/prevPos: passing the current state in both
    #slots would remove all transition information from the M_theta update.
    #NOTE(review): thetaModulation reads the CURRENT self.thetaPhase for both; the phase at the
    #previous time is not stored anywhere visible here - confirm this is acceptable.
    data = ( (state, prevState, self.M ) ,
             (self.thetaModulation(state, position=pos),
              self.thetaModulation(prevState, position=prevPos),
              self.M_theta) )

    for i, (state_, prevState_, M) in enumerate(data):
        #onehot optimised TD learning
        if self.stateType == 'onehot':
            s_t = np.argwhere(prevState_)[0][0]
            s_tplus1 = np.argwhere(state_)[0][0]
            Delta = state_ + (tau / dt) * ((1 - dt/tau) * M[:,s_tplus1] - M[:,s_t])
            M[:,s_t] += alpha * Delta - 2 * alpha * self.TDreg * M[:,s_t]
        #normal TD learning
        else:
            delta = ((tau * dt) / (tau + dt)) * self.successorFeatureNorm * prevState_ + M @ ((tau/(tau + dt))*state_ - prevState_)
            Delta = np.outer(delta, prevState_)
            M += alpha * Delta - 2 * alpha * self.TDreg * M #regularisation
        if i == 0:
            Del = Delta #only the raw-state update is reported
    return np.linalg.norm(Del)
def STDPLearningStep(self,dt):
    """One STDP update at the agent's current position.

    The state vector at self.pos is used twice: once as-is (updating W_notheta and the
    *_notheta traces/counters) and once after theta modulation (updating W and the plain
    traces/counters). For each, spikes are sampled over the window [self.t, self.t+dt] and
    applied in time order; all arrays are modified in place.

    Args:
        dt (float): Time step length

    Returns:
        float array: scaled firing rates of the theta-modulated pass for this time step
    """
    baseRate = self.posToState(self.pos)
    #each entry: (rate, weights, pre trace, post trace, last spike time, spike counter)
    streams = (
        (baseRate,
         self.W_notheta,
         self.preTrace_notheta,
         self.postTrace_notheta,
         self.lastSpikeTime_notheta,
         self.spikeCount_notheta),
        (self.thetaModulation(baseRate),
         self.W,
         self.preTrace,
         self.postTrace,
         self.lastSpikeTime,
         self.spikeCount),
    )
    allCells = np.arange(self.nCells)

    for streamIdx, (rate, weights, preSyn, postSyn, lastSpike, nSpikes) in enumerate(streams):
        scaledRate = self.peakFiringRate * rate + self.baselineFiringRate
        #a cell counts as having fired once if its Poisson draw is non-zero
        fired = np.random.poisson(scaledRate*dt) != 0
        nSpikes += np.count_nonzero(fired)
        firedAt = np.random.uniform(self.t,self.t+dt,self.nCells)[fired]
        firedBy = allCells[fired]

        #walk through this step's spikes from earliest to latest
        chronological = np.argsort(firedAt)
        for cellID, spikeTime in zip(firedBy[chronological], firedAt[chronological]):
            cell = int(cellID)
            elapsed = spikeTime - lastSpike
            #both traces decay over the time since the previous spike
            preSyn *= np.exp(- elapsed / self.tau_STDP_plus)
            postSyn *= np.exp(- elapsed / self.tau_STDP_minus)
            weights[cell,:] += self.eta * preSyn #row of the spiking cell picks up the pre traces
            weights[:,cell] += self.eta * postSyn #column of the spiking cell picks up the post traces
            postSyn[cell] += self.a_STDP
            preSyn[cell] += 1
            lastSpike += elapsed #in place, so the stored last-spike time advances to spikeTime

        if streamIdx == 1:
            thetaFiringRate = scaledRate
    return thetaFiringRate
def STDPLearningStep_detailed(self,dt):
    """Two-layer STDP update at the agent's current position.

    The state vector at self.pos gives the "pre" layer rates (raw for the *_notheta arrays,
    theta-modulated for the plain ones). The "post" layer rates are obtained by passing the
    pre rates through a mapping matrix chosen by self.online_mapping and rectifying at zero.
    Spikes of both layers are sampled over [self.t, self.t+dt] and applied in time order:
    a pre spike updates the column of that cell with the post traces, a post spike updates
    the row of that cell with the pre traces. Spikes of the theta-modulated pass are appended
    to self.spikedata ('CA3' = pre, 'CA1' = post).

    Args:
        dt (float): Time step length

    Returns:
        float array: concatenated pre and post firing rates of the theta-modulated pass

    Raises:
        ValueError: if self.online_mapping is a string other than 'identity', 'Widentity' or 'W'.
    """
    state = self.posToState(self.pos)
    #each entry: (pre rate, weights, pre trace, post trace, last spike time, spike counter)
    data = ( (state,
              self.W_notheta,
              self.preTrace_notheta,
              self.postTrace_notheta,
              self.lastSpikeTime_notheta,
              self.spikeCount_notheta),
             (self.thetaModulation(state),
              self.W,
              self.preTrace,
              self.postTrace,
              self.lastSpikeTime,
              self.spikeCount),
           )

    for i, (firingRate, W, preTrace, postTrace, lastSpikeTime, spikeCount) in enumerate(data):
        preFiringRate_ = self.peakFiringRate * firingRate + self.baselineFiringRate

        #string comparison is only attempted for strings: comparing a matrix with == is ambiguous in an `if`
        mapping = self.online_mapping
        if isinstance(mapping, str):
            if mapping == "identity":
                mapMatrix = np.identity(self.nCells)
            elif mapping == "Widentity":
                mapMatrix = W + 0.5*np.identity(self.nCells)
            elif mapping == "W":
                mapMatrix = W
            else:
                raise ValueError("unknown online_mapping: %r" % (mapping,))
        else:
            mapMatrix = mapping #user supplied matrix

        postFiringRate_ = np.maximum(0,np.matmul(mapMatrix,preFiringRate_))
        firingRate_ = np.concatenate((preFiringRate_,postFiringRate_))
        layerLabel_ = np.array(['pre']*len(preFiringRate_) + ['post']*len(postFiringRate_))
        neuronIDs = np.concatenate((np.arange(len(preFiringRate_)), np.arange(len(postFiringRate_))))

        n_spike_list = np.random.poisson(firingRate_*dt)
        spikingNeurons = (n_spike_list != 0) #a cell counts as having fired once if its draw is non-zero
        spikeCount += np.count_nonzero(spikingNeurons)
        spikeTimes = np.random.uniform(self.t,self.t+dt,len(neuronIDs))[spikingNeurons]
        spikeIDs = neuronIDs[spikingNeurons]
        spikeLayerLabels = layerLabel_[spikingNeurons]

        #order spikes by their numeric time. (Stacking ids, times and labels into one array turns
        #everything into strings, and sorting those is lexicographic, e.g. '10.01' < '9.99'.)
        order = np.argsort(spikeTimes)
        for cellID, time, layer in zip(spikeIDs[order], spikeTimes[order], spikeLayerLabels[order]):
            cell = int(cellID)
            timeDiff = time - lastSpikeTime
            preTrace *= np.exp(- timeDiff / self.tau_STDP_plus) #traces for all cells decay...
            postTrace *= np.exp(- timeDiff / self.tau_STDP_minus) #traces for all cells decay...
            if layer == 'pre':
                W[:,cell] += self.eta * postTrace #weights from presynaptic neuron should decrease when pre fires (post-before-PRE)
                preTrace[cell] += 1 #update trace
            if layer == 'post':
                W[cell,:] += self.eta * preTrace #weights to postsynaptic neuron should increase when post fires (pre-before-POST)
                postTrace[cell] += self.a_STDP #update trace (post trace probably negative)
            lastSpikeTime += timeDiff #in place, so the stored last-spike time advances to `time`

        if i == 1:
            thetaFiringRate = firingRate_

    if self.rownorm == True:
        #rescale only rows whose sum exceeds 1; each matrix is normalised by ITS OWN row sums
        sumW = np.sum(self.W,axis=1)
        sumW[sumW<1]=1
        self.W = self.W / sumW[:,np.newaxis]
        sumWnt = np.sum(self.W_notheta,axis=1)
        sumWnt[sumWnt<1]=1
        self.W_notheta = self.W_notheta / sumWnt[:,np.newaxis]

    #save spike data (variables hold the last, i.e. theta-modulated, pass)
    CA3spiketimes = spikeTimes[spikeLayerLabels=='pre']
    CA3spikeids = spikeIDs[spikeLayerLabels=='pre']
    CA1spiketimes = spikeTimes[spikeLayerLabels=='post']
    CA1spikeids = spikeIDs[spikeLayerLabels=='post']
    self.spikedata['CA3']['times'].extend(CA3spiketimes)
    self.spikedata['CA3']['ids'].extend(CA3spikeids)
    self.spikedata['CA1']['times'].extend(CA1spiketimes)
    self.spikedata['CA1']['ids'].extend(CA1spikeids)
    return thetaFiringRate
def thetaModulation(self, firingRate, position=None, direction=None):
"""Takes a firing rate vector and modulates it to account for theta phase precession
Args:
firingRate (np.array): The raw (position dependent) firing rate vector to be modulated
position (np.array(2,), optional): The agent position. Defaults to None.
direction (np.array(2,), optional): The agent direction. Defaults to None.
"""
if position is None:
position = self.pos
if direction is None:
direction = self.dir
vectorToCells = self.vectorsToCellCentres(position)
sigmasToCellMidline = (np.dot(vectorToCells,direction) / np.linalg.norm(direction)) / self.sigmas #as mutiple of sigma
preferedThetaPhase = np.pi + sigmasToCellMidline * self.precessFraction * np.pi
phaseDiff = preferedThetaPhase - self.thetaPhase
modulatedFiringRate = firingRate * vonmises.pdf(phaseDiff,kappa=self.kappa) * 2*np.pi
return modulatedFiringRate
def movementPolicyUpdate(self):
"""Movement policy update.
In principle this does a very simple thing:
• updates time by dt,
• updates position along the velocity direction
• updates velocity (speed and direction) according to a movement policy
In reality it's a complex function as the policy requires checking for immediate or upcoming collisions with all walls at each step.
This is done by function self.checkWallIntercepts()
What it does with this info (bounce off wall, turn to follow wall, etc.) depends on policy.
Updates self.t, self.pos, self.dir and self.speed in place and returns nothing.
"""
# NOTE(review): leading whitespace is missing in this copy of the file, so the nesting of the
# branches below cannot be read off the text; the comments describe individual statements only.
# advance the clock by one step and propose a straight-line move along the current heading
dt = self.dt
self.t += dt
proposedNewPos = self.pos + self.speed * self.dir * dt
proposedStep = np.array([self.pos,proposedNewPos])
if (self.biasDoorCross == True) and (self.mazeType == 'twoRooms'):
#if agent crosses into door zone there's its turn direction is biased to try and cross the door
#this is done by setting agents direction in the right direction and not changing it again until after it's crossed
# the door zone is measured from the point [roomSize, roomSize/2]
doorRegionSize = 1
if self.doorPassage == False:
#if step cross into door region
if (np.linalg.norm(self.pos - np.array([self.roomSize,self.roomSize/2])) > doorRegionSize) and (np.linalg.norm(proposedNewPos - np.array([self.roomSize,self.roomSize/2])) < doorRegionSize) and (abs(self.pos[0] - self.roomSize) > 0.01):
# a door passage is started with probability 0.5
if 100*np.random.uniform(0,1) < 50: #start a doorPassage
self.doorPassage = True
self.doorPassageTime = self.t
# NOTE(review): this return ends the update before any position/direction change - confirm intended
return
else: #ignore this
pass
if self.doorPassage == True:
# the passage ends when the step changes side of x = roomSize, or once
# (time since start)*speedScale exceeds 2*doorRegionSize
if ((self.pos[0]<(self.roomSize)) != (proposedNewPos[0]<(self.roomSize))) or ((self.t - self.doorPassageTime)*self.speedScale > 2*doorRegionSize):
self.doorPassage = False
if ((self.pos[0]<(self.roomSize)) != (proposedNewPos[0]<(self.roomSize))):
print("crossed",self.t)
if ((self.t - self.doorPassageTime)*self.speedScale > 2*doorRegionSize):
print("time")
# checkResult[0] is compared below against 'collisionNow', 'collisionAhead' and 'noImmediateCollision';
# checkResult[1] is the wall handed to wallBounceOrFollow
checkResult = self.checkWallIntercepts(proposedStep)
# randomWalk: turn rate drawn from a zero-mean normal with standard deviation rotSpeedScale
if self.movementPolicy == 'randomWalk':
if checkResult[0] != 'collisionNow':
self.pos = proposedNewPos
randomTurnSpeed = np.random.normal(0,self.rotSpeedScale)
self.dir = turn(self.dir,turnAngle=randomTurnSpeed*dt)
elif checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
# trueRandomWalk: turn angle drawn uniformly from [0, 2*pi)
if self.movementPolicy == 'trueRandomWalk':
if checkResult[0] != 'collisionNow':
self.pos = proposedNewPos
self.dir = turn(self.dir,turnAngle=np.random.uniform(0,2*np.pi))
elif checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
# leftRightRandomWalk: turn angle is either 0 or pi
if self.movementPolicy == 'leftRightRandomWalk':
if checkResult[0] != 'collisionNow':
self.pos = proposedNewPos
self.dir = turn(self.dir,turnAngle=np.random.choice([0,np.pi]))
elif checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
# raudies: bounce on collision, optionally follow a wall that is ahead, otherwise move with a
# Rayleigh-distributed speed (scale speedScale) and a periodically redrawn turn rate
if self.movementPolicy == 'raudies':
if checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
elif ((checkResult[0] == 'collisionAhead') and (self.biasWallFollow==True)):
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'follow')
elif (checkResult[0] == 'noImmediateCollision') or (((checkResult[0] == 'collisionAhead') and (self.biasWallFollow==False))):
self.pos = proposedNewPos
self.speed = np.random.rayleigh(self.speedScale)
if self.t - self.lastTurnUpdate >= 0.1: #turn updating done at intervals independent of dt or else many small turns cancel out but few big ones dont
randTurnMean = 0
if self.doorPassage == True:
# during a door passage the mean turn rate is set to +/- rotSpeedScale, with the sign chosen from the
# angle between the heading and the direction towards [roomSize, roomSize/2]
d_theta = theta(self.dir) - theta(np.array([self.roomSize,self.roomSize/2]) - self.pos)
if d_theta > 0: randTurnMean = -self.rotSpeedScale
else: randTurnMean = self.rotSpeedScale
self.randomTurnSpeed = np.random.normal(randTurnMean,self.rotSpeedScale)
self.lastTurnUpdate = self.t
self.dir = turn(self.dir, turnAngle=self.randomTurnSpeed*dt)
# windowsScreensaver: straight lines, direction only changes by bouncing off walls
if self.movementPolicy == 'windowsScreensaver':
if checkResult[0] != 'collisionNow':
self.pos = proposedNewPos
elif checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
# 1DOrnUhl: as windowsScreensaver, plus the speed is incremented by ornstein_uhlenbeck(...) and floored at 0
if self.movementPolicy == '1DOrnUhl':
if checkResult[0] != 'collisionNow':
self.pos = proposedNewPos
elif checkResult[0] == 'collisionNow':
wall = checkResult[1]
self.dir = wallBounceOrFollow(self.dir,wall,'bounce')
self.speed += ornstein_uhlenbeck(dt=dt, x=self.speed, drift=self.speedScale,noise_scale=self.speedScale, coherence_time=5)
self.speed = max(0,self.speed)
# loop maze: the x coordinate wraps around modulo roomSize
if self.mazeType == 'loop':
self.pos[0] = self.pos[0] % self.roomSize
if self.mazeType == 'TMaze':
# once past x = roomSize+0.05 a single choice is made: direction [0,1] with probability 0.66, else [0,-1]
if (self.pos[0] > self.roomSize+0.05) and (self.LRDecisionPending==True):
if np.random.choice([0,1],p=[0.66,0.34]) == 0:
self.dir = np.array([0,1])
else:
self.dir = np.array([0,-1])
self.LRDecisionPending=False
# leaving the extent in y resets position to [0,1], direction to [1,0] and re-arms the choice
# NOTE(review): np.array([0,1]) is an integer array; later in-place writes into self.pos would truncate - confirm
if self.pos[1] > self.extent[3] or self.pos[1] < self.extent[2]:
self.pos = np.array([0,1])
self.dir = np.array([1,0])
self.LRDecisionPending=True
#catchall instances a rat escapes the maze by accident, pops it 2cm within maze
if ((self.pos[0] < self.extent[0]) or
(self.pos[0] > self.extent[1]) or
(self.pos[1] < self.extent[2]) or
(self.pos[1] > self.extent[3])):
print(self.pos)
# clamp each coordinate to lie at least 0.02 inside the extent
self.pos[0] = max(self.pos[0],self.extent[0]+0.02)
self.pos[0] = min(self.pos[0],self.extent[1]-0.02)
self.pos[1] = max(self.pos[1],self.extent[2]+0.02)
self.pos[1] = min(self.pos[1],self.extent[3]-0.02)
print("Rat escaped!")
if self.mazeType == 'TMaze':
self.dir=np.array([1,0])
self.LRDecisionPending = True
# plotter = Visualiser(self)
# plotter.plotTrajectory(starttime=(self.t/60)-0.2, endtime=self.t/60)
def vectorsToCellCentres(self,pos,distance=False):
    """Return the shortest displacement vector from ``pos`` to every cell centre.

    When the maze is an open loop (``mazeType == 'loop'`` and the doors are
    not closed) the x-axis wraps with period ``self.roomSize``. For each cell
    three candidates are compared: the direct vector and the vectors measured
    from ``pos`` shifted by plus/minus one ``roomSize`` along x. The one with
    the smallest norm is kept. In every other case the plain difference
    ``self.centres - pos`` is returned.

    Args:
        pos (array): position vector, shape (2,)
        distance (bool): unused; retained so existing call sites keep working.
    Returns:
        array: shape (nCells, 2); row i is the vector from ``pos`` to cell i.
    """
    pos = np.asarray(pos)
    if self.mazeType == 'loop' and not self.doorsClosed:
        wrap = np.array([self.roomSize, 0])
        candidates = np.array([pos, pos + wrap, pos - wrap])
        # vectors[i, j] = centre i minus candidate start point j
        vectors = self.centres[:, np.newaxis, :] - candidates[np.newaxis, :, :]
        nearest = np.argmin(np.linalg.norm(vectors, axis=-1), axis=1)
        # pick, for each cell, its own shortest candidate (no nCells x nCells intermediate)
        return vectors[np.arange(len(nearest)), nearest]
    # use the position that was asked about, not the agent's current position
    return self.centres - pos
def distanceToCellCentres(self, pos):
    """Calculates the distance from ``pos`` to every cell centre.

    For any maze other than 'twoRooms' this is the norm of the vectors
    returned by ``self.vectorsToCellCentres(pos)``.

    For the 'twoRooms' maze the first entry of ``self.walls['doors']`` is
    read as a segment at a single x coordinate (``wall_x``) and, per cell:
      * cell and ``pos`` strictly on the same side of ``wall_x``: straight-line distance;
      * different sides and doors closed: ``100 * self.roomSize`` (a large stand-in value);
      * different sides, doors open, and ``checkWallIntercepts`` reports
        'collisionNow' for the straight segment: the shorter of the two
        routes passing via either end point of the door segment;
      * otherwise: straight-line distance.

    Args:
        pos (np.array): The position to calculate the distances from
    Returns:
        np.array: (nCells,) array of distances
    """
    if self.mazeType != 'twoRooms':
        return np.linalg.norm(self.vectorsToCellCentres(pos), axis=1)

    pos = np.asarray(pos)
    door = self.walls['doors'][0]
    wall_x = door[0][0]
    # both end points of the door segment, placed at the x coordinate of its first point
    door_ends = (np.array([wall_x, door[0][1]]), np.array([wall_x, door[1][1]]))

    distances = np.zeros(self.nCells)
    for i in range(self.nCells):
        centre = self.centres[i]
        direct = np.linalg.norm(pos - centre)
        same_room = (((centre[0] < wall_x) and (pos[0] < wall_x)) or
                     ((centre[0] > wall_x) and (pos[0] > wall_x)))
        if same_room:
            distances[i] = direct
        elif self.doorsClosed:
            # cell and position in different rooms with no way through
            distances[i] = 100 * self.roomSize
        elif self.checkWallIntercepts(np.array([pos, centre]))[0] == 'collisionNow':
            # the straight line is blocked: go via whichever door end point is shorter
            distances[i] = min(np.linalg.norm(end - pos) + np.linalg.norm(end - centre)
                               for end in door_ends)
        else:
            distances[i] = direct
    return distances
def toggleDoors(self, doorsClosed = None): #this function could be made more advanced to toggle more maze options
    """Opens or closes the doors and updates mazeState.

    ``self.mazeState['walls']`` is replaced by a shallow copy of ``self.walls``
    that contains the 'doors' entry only while the doors are closed;
    ``self.walls`` itself is never modified. ``self.discreteStates`` is then
    recomputed from ``self.discreteCoords`` via ``positionArray_to_stateArray``.

    Args:
        doorsClosed ([bool], optional): True if doors are to be closed, False if doors are to be opened. Defaults to None, in which case the current door state is flipped.
    Returns:
        [dict]: the walls dictionary now stored in ``self.mazeState['walls']``
    """
    if doorsClosed is None:
        self.doorsClosed = not self.doorsClosed
    else:
        self.doorsClosed = doorsClosed

    walls = self.walls.copy()
    if not self.doorsClosed:
        # pop with a default so a walls dictionary without a 'doors' entry does not raise KeyError
        walls.pop('doors', None)
    self.mazeState['walls'] = walls

    # an array of discretised position coords over entire map extent, re-encoded as states
    self.discreteStates = self.positionArray_to_stateArray(self.discreteCoords,stateType=self.stateType)
    return self.mazeState['walls']
def checkWallIntercepts(self,proposedStep,collisionDistance=0.1): #proposedStep = [pos,proposedNextPos]
    """Classify a proposed step [currentPos, nextPos] by whether, and how soon, it meets a wall.

    Every wall segment in ``self.mazeState['walls']`` is tested. Writing the step as
    s1 + lam_s*(s2 - s1) and the wall as w1 + lam_w*(w2 - w1), the two lines cross where
    both expressions are equal. From most to least urgent:
        • 0 <= lam_s <= 1 and 0 <= lam_w <= 1: the crossing lies ON the step ('collisionNow').
        • lam_s > 1, 0 <= lam_w <= 1 and lam_s * |step| <= collisionDistance: the crossing lies
          ahead on the same heading, within collisionDistance of the step start ('collisionAhead').
        • anything else: 'noImmediateCollision'.
    If several walls qualify, the one with the smallest lam_s is reported (the first such wall
    in iteration order on ties), and a collision on the step takes priority over one ahead.

    Args:
        proposedStep (array): The proposed step. np.array( [ [x_current, y_current] , [x_next, y_next] ] )
        collisionDistance (float): look-ahead distance for the 'collisionAhead' case. Defaults to 0.1.
    Returns:
        tuple: (str, wall), (<'collisionNow', 'collisionAhead' or 'noImmediateCollision'> , <the wall in question, or None>)
    """
    stepStart = np.array(proposedStep[0])
    stepEnd = np.array(proposedStep[1])
    stepVec = stepEnd - stepStart
    stepLength = np.linalg.norm(stepVec)
    stepNormal = perp(stepVec)

    # nearest wall found so far in each category, with its lam_s
    hitWall, hitLam = None, np.inf
    aheadWall, aheadLam = None, np.inf

    for wallGroup in self.mazeState['walls'].values(): #current wall state
        for wall in wallGroup:
            wallStart, wallEnd = np.array(wall[0]), np.array(wall[1])
            wallVec = wallEnd - wallStart
            wallNormal = perp(wallVec)
            # a zero denominator is left to produce inf; the range tests below then reject it
            with np.errstate(divide='ignore'):
                lam_s = (np.dot(wallStart, wallNormal) - np.dot(stepStart, wallNormal)) / (np.dot(stepVec, wallNormal))
                lam_w = (np.dot(stepStart, stepNormal) - np.dot(wallStart, stepNormal)) / (np.dot(wallVec, stepNormal))

            if not (0 <= lam_w <= 1):
                # the crossing point is not on this wall segment
                continue
            if 0 <= lam_s <= 1:
                if lam_s < hitLam:
                    hitWall, hitLam = wall, lam_s
            elif (lam_s > 1) and (lam_s * stepLength <= collisionDistance):
                if lam_s < aheadLam:
                    aheadWall, aheadLam = wall, lam_s

    if hitWall is not None:
        return ('collisionNow', hitWall) #first wall you collide with on step
    if aheadWall is not None:
        return ('collisionAhead', aheadWall) #first wall you would collide with along current step
    return ('noImmediateCollision',None)
def getPlaceFields(self, M=None, threshold=None):
    """Calculates receptive fields of all place cells.

    Each field is a row of ``M`` applied as linear weights to
    ``self.discreteStates`` (contracted over its last axis), then shifted
    down by a fraction of that field's own maximum and clipped at zero.

    Args:
        M (array, optional): SR matrix. Defaults to None, in which case ``self.M`` is used.
        threshold (float, optional): fraction of each field's maximum that is
            subtracted before clipping. Defaults to None, meaning 0.9.
    Returns:
        array: thresholded receptive fields, one per row of ``M``, each spanning
            the first two axes of ``self.discreteStates``
    """
    weights = (self.M if M is None else M).copy()
    fraction = 0.9 if threshold is None else threshold

    rawFields = np.einsum("ij,klj->ikl", weights, self.discreteStates)
    # per-field cut-off, broadcast over the two spatial axes
    peaks = np.amax(rawFields, axis=(1, 2))
    cutoff = fraction * peaks[:, None, None]
    return np.maximum(0, rawFields - cutoff)
def getGridFields(self, M, alignToFinal=False):
    """Calculates receptive fields of all grid cells.

    Each field uses one eigenvector of ``M`` (real part only) as linear
    weights on ``self.discreteStates`` (contracted over its last axis);
    the result is clipped at zero.

    Args:
        M (array): SR matrix
        alignToFinal (bool): Since the negative of an eigenvector is also an
            eigenvector, flip the sign of any field whose overlap (dot product)
            with the matching field in ``self.gridFields`` is negative, before
            clipping (for making animations). Fields with exactly zero overlap
            are left unchanged.
    Returns:
        array: clipped receptive fields, one per eigenvector, each spanning
            the first two axes of ``self.discreteStates``
    """
    _, eigvecs = np.linalg.eig(M) #"v[:,i] is the eigenvector corresponding to the eigenvalue w[i]"
    eigvecs = np.real(eigvecs)
    gridFields = np.einsum("ij,kli->jkl", eigvecs, self.discreteStates)

    if alignToFinal:
        finalFlat = np.reshape(self.gridFields, (self.stateSize, -1))
        currentFlat = np.reshape(gridFields, (self.stateSize, -1))
        # row-wise dot products, computed once; only the diagonal of the full product is needed
        overlaps = np.einsum("ij,ij->i", finalFlat, currentFlat)
        # zero overlap keeps its sign rather than wiping out the field
        signs = np.where(overlaps < 0, -1.0, 1.0)
        gridFields = gridFields * signs[:, None, None]

    return np.maximum(0, gridFields)
def posToState(self, pos, stateType=None, normalise=True, cheapNormalise=False,initialisingCells=False): #pos is an [n1, n2, n3, ...., 2] array of 2D positions
if (self.statesAlreadyInitialised == False) or (self.firingRateLookUp == False):
#calculates the firing rate of all cells
pos = np.array(pos)
if stateType == None: stateType = self.stateType
vector_to_cells = self.centres - pos
distance_to_cells = [np.linalg.norm(vector_to_cells,axis=1)]
closest_cell_ID = np.argmin(distance_to_cells)
if (self.mazeType == 'loop') and (self.doorsClosed == False):
distance_to_cells.append(np.linalg.norm(self.centres - pos + [self.extent[1],0],axis=1))
distance_to_cells.append(np.linalg.norm(self.centres - pos - [self.extent[1],0],axis=1))
if (self.mazeType == 'twoRooms'):
distance_to_cells = [self.distanceToCellCentres(pos)]
if stateType == 'onehot':
state = np.zeros(self.nCells)
state[closest_cell_ID] = 1
if stateType == 'gaussianThreshold':
state = np.zeros(self.nCells)
for distance in distance_to_cells:
state += np.maximum(np.exp(-distance**2 / (2*(self.sigmas**2))) - np.exp(-1/2) , 0) / (1-np.exp(-1/2))
# state = state / (self.sigmas) #normalises so same no. spikes emitted for all cell sizes
if stateType == 'gaussian':
state = np.zeros(self.nCells)
for distance in distance_to_cells:
state += np.exp(-distance**2 / (2*(self.sigmas**2)))
state = state/self.sigmas
if stateType == 'bump':
state = np.zeros(self.nCells)
for distance in distance_to_cells:
state[distance<self.sigmas] += np.e * np.exp(-1/(1-(distance/self.sigmas)**2))[distance<self.sigmas]
state[distance>=self.sigmas] += 0
state = state/self.sigmas