In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import *

In [None]:
# Stimulus channels: the real stimuli first, then their "Not" versions in the
# same order (a "Not" channel is what detect_notstim presents when the real
# stimulus was predicted but not observed).
stim = {idx: label for idx, label in enumerate(["X", "Y", "NotX", "NotY"])}
# no. of channels = no. of real stimuli x2 (each has a not-stimulus twin)
nf = len(stim)
# simulation time step and association retention weight (eq 4 in input_update)
dt, a = 0.1, 0.8

In [None]:
# Generate a geometric range of time constants tau_star and their rates s = 1/tau_star
def s_gen(tau_min=1, tau_max=100., buff_max=None, k=20, ntau=40, g=0.0):
    """Return ``[tau_star, s]`` for ``ntau`` geometrically spaced time constants.

    ``tau_star`` starts at ``tau_min`` and grows by a constant ratio per entry so
    that the last entry is (up to rounding) ``tau_max``; ``s`` is ``1/tau_star``.
    ``buff_max``, ``k`` and ``g`` are accepted but not used.
    """
    step = (tau_max/tau_min)**(1./(ntau-1)) - 1
    powers = np.arange(ntau).astype(float)
    tau_star = tau_min*(1 + step)**powers
    return [tau_star, 1/tau_star]

tau_star, S = s_gen()

In [None]:
# Display the rate constants s = 1/tau_star from s_gen (one per row of M, F and P).
S

In [None]:
def stim_gen(i):
    """Return a length-``nf`` input vector with a 1 at stimulus index ``i``.

    When ``i`` is None (no stimulus this step) the vector is all zeros.
    """
    onehot = np.zeros(nf)
    if i is None:
        return onehot
    onehot[i] = 1
    return onehot

In [None]:
def input_update(a,s,s_idx,fi_IN=None):
    """Apply one update of F, P and M (eqs 1-4) for the rate ``s`` at row ``s_idx``.

    Mutates the module-level arrays ``F``, ``P`` and ``M`` in place; they are
    never rebound, so no ``global`` statement is required.

    Parameters
    ----------
    a : float
        Retention weight for the association update (eq 4):
        ``M_new = a * M_old + (1 - a) * F``.
    s : float
        Rate constant for this row (an entry of ``S``).
    s_idx : int
        Row of ``F``, ``P`` and ``M`` to update.
    fi_IN : int or None, default None
        Channel index (key into ``stim``) presented this step, or None when
        nothing is presented. Previously this was written ``fi_IN: None``
        (an annotation, not a default), which made the argument mandatory.
    """
    f_IN = stim_gen(fi_IN)
    # eq 1: decay F by s, add the current input, clip at 0
    F[s_idx] += -s*F[s_idx] + f_IN
    F[s_idx] = np.where(F[s_idx]<0, 0, F[s_idx])
    # eq 2: input to P from the current stimulus through the associations;
    # P_IN[j] = M[s_idx][j, fi_IN]
    P_IN = np.dot(M[s_idx], f_IN)
    # eq 3
    # NOTE(review): unlike eq 1 this uses +s*P (growth rather than decay) and
    # subtracts P[0] (the row for the first s), not this row -- confirm both
    # against the model equations.
    P[s_idx] +=  s*P[s_idx] + P_IN - P[0]
    P[s_idx] = np.where(P[s_idx]<0, 0, P[s_idx])
    # Find stimuli from the past that are still active (F != 0) and update the
    # associations from each of them to the stimulus presented now.
    if fi_IN is not None:
        for past_stim in np.flatnonzero(F[s_idx]):
            # avoid self-prediction updates (F was just incremented at fi_IN)
            if past_stim != fi_IN:
                # eq 4: M[s_idx][presented, past] is an exponential moving
                # average of the past stimulus' F value
                M[s_idx][fi_IN, past_stim] = a * M[s_idx][fi_IN, past_stim] + (1-a) * F[s_idx][past_stim]
def detect_notstim(a,t,s,s_idx,fi_IN=None):
    """Handle predictions at row ``s_idx`` that have reached threshold (P >= 1).

    Every such prediction is reset to 0 and reported with ``print``:

    * if it matches the channel presented this step (``fi_IN``) it is reported
      as observed & predicted;
    * if it is a real stimulus that was not presented, its "Not" channel is
      fed to ``input_update``;
    * if it is already a "Not" channel, that channel is fed to ``input_update``.

    Parameters
    ----------
    a : float
        Association retention weight, passed through to ``input_update``.
    t : float
        Current simulation time (only used in the printed messages).
    s : float
        Rate constant for this row.
    s_idx : int
        Row of ``P`` (and, via ``input_update``, ``F``/``M``) to inspect.
    fi_IN : int or None, default None
        Channel presented this step, or None. Previously written
        ``fi_IN:None`` (an annotation), which made the argument mandatory.
    """
    # Channels [0, nf//2) are the real stimuli and [nf//2, nf) their "Not"
    # versions in the same order (see the stim dict), so p + n_real maps a
    # real stimulus to its not-stimulus. Equals the former literal 2 for nf == 4.
    n_real = nf // 2
    # indices over threshold, computed once before any of them is reset
    predicted = np.where(P[s_idx]>=1)[0]
    for p in predicted:
        if p == fi_IN:
            # report the prediction value before clearing it
            print(stim[p], "Observed & Predicted with p=%s! at s=%s and t=%s"
                  %(P[s_idx][p],s,t))
            P[s_idx][p] = 0
        else:
            P[s_idx][p] = 0
            # NOTE(review): input_update re-applies eqs 1-4 (including the F
            # decay) for this s, i.e. a second update within the same
            # timestep -- confirm that is intended.
            if p < n_real:
                print(stim[p+n_real], "predicted at s=%s and t=%s" %(s,t))
                # present the not-stimulus equivalent of the missing stimulus
                input_update(a,s,s_idx,fi_IN=p+n_real)
            else:
                print(stim[p], "predicted at s=%s and t=%s" %(s,t))
                input_update(a,s,s_idx, fi_IN=p)
def timestep_update(a,t,S,fi_IN=None):
    """Advance the model by one timestep for every rate constant in ``S``.

    For each ``s`` (row ``s_idx``) this runs ``input_update`` with the presented
    channel and then ``detect_notstim`` to handle predictions that reached
    threshold.

    Parameters
    ----------
    a : float
        Association retention weight.
    t : float
        Current simulation time (used in detect_notstim's messages).
    S : sequence of float
        Rate constants; position in ``S`` is the row index into M, F and P.
    fi_IN : int or None, default None
        Channel presented this step, or None. Previously written
        ``fi_IN: None`` (an annotation), which made the argument mandatory.
    """
    for s_idx, s in enumerate(S):
        input_update(a,s,s_idx,fi_IN)
        detect_notstim(a,t,s,s_idx,fi_IN)

In [None]:
# Model state, one row per rate constant in S:
#   M[s_idx] - (nf x nf) associations, indexed [presented, past] in input_update
#   F[s_idx] - decaying record of presented channels (eq 1)
#   P[s_idx] - predictions (eq 3)
M, F, P = np.zeros([len(S), nf, nf]), np.zeros([len(S), nf]), np.zeros([len(S), nf])
# Recorders filled by the simulation loop; the vstack-based ones start with an
# all-zero row. Modify the loop to change which element is tracked.
P_tracker, P_max, F_tracker = (np.zeros(nf) for _ in range(3))
M_tracker = []
# Stimulus schedule: [channel index into stim, presentation time].
# Entries at or beyond t_max are never reached by the loop.
t_max = 10
f = [[0, 1], [1, 1.5], [0, 4], [0, 10.1]]
# index into S of the s value whose state is tracked
s_track = 2

In [None]:
# Run the simulation over T and record the tracked state after every step.
# Stimuli are taken from a time-sorted copy of the schedule `f`, so `f` is no
# longer consumed by this cell.
T = np.arange(0,t_max, dt)
fi_prev = None
pending = sorted(f, key=lambda event: event[1])
for t in T:
    # Present the next stimulus on the first step that reaches its scheduled
    # time, to within half a step. The former exact test `f[0][1] == t` misses
    # times such as 0.3 (3*0.1 == 0.30000000000000004 in floating point), and
    # one missed entry then blocked every later one; it also raised IndexError
    # once the schedule was empty. Two stimuli due on the same step are
    # presented on consecutive steps, since timestep_update takes one channel.
    if pending and pending[0][1] <= t + dt / 2:
        fi_IN = pending.pop(0)[0]
        timestep_update(a,t,S,fi_IN=fi_IN)
    else:
        timestep_update(a,t,S,fi_IN=None)
    P_tracker = np.vstack((P_tracker,P[s_track]))
    # predictions of the last (slowest-rate) row of S
    P_max = np.vstack((P_max,P[-1]))
    M_tracker.append(M[s_track].copy())
    F_tracker = np.vstack((F_tracker,F[s_track]))

In [None]:
# NOTE(review): disabled heatmap version of the tracker plots. If re-enabled,
# the grid below must have three subplots, not (2,1): ax[2] is used for P_max
# (the P row for S[-1]), and ax[2] would raise IndexError on a 2x1 grid.
#fig = plt.figure(figsize = (5,10))
#ax = fig.subplots(2,1)
#ax[0].imshow(F_tracker, aspect='auto')
#ax[0].set_title("F Values, s=%s" %(S[s_track]))
#ax[0].set_ylabel("Time")
#ax[0].set_xticks(np.arange(4),['X', 'Y','Not X', 'Not Y'])
#ax[1].imshow(P_tracker, aspect='auto')
#ax[1].set_title("P Values, s=%s" %(S[s_track]))
#ax[1].set_ylabel("Time")
#ax[1].set_xticks(np.arange(4),['X', 'Y','Not X', 'Not Y'])
#ax[2].imshow(P_max, aspect='auto')
#ax[2].set_title("P Values, s=%s" %(S[-1]))
#ax[2].set_ylabel("Time")
#ax[2].set_xticks(np.arange(4),['X', 'Y','Not X', 'Not Y'])

In [None]:
# F and P of the tracked s over simulated time T.
# The last len(T) tracker rows are plotted (instead of rows 1:) so the curves
# still line up with T if the simulation cell was re-run and the trackers grew;
# for a single run the two selections are identical.
fig, ax = plt.subplots(2, 1, figsize=(6, 10))
ax[0].plot(T, F_tracker[-len(T):, :])
ax[0].set_title("F Values, s=%s" %(S[s_track]))
ax[0].set_xlabel("Time")
ax[0].set_ylabel("F")
ax[0].legend(['X', 'Y','Not X', 'Not Y'], bbox_to_anchor=(1.1, 1))

ax[1].plot(T, P_tracker[-len(T):, :])
ax[1].set_title("P Values, s=%s" %(S[s_track]))
ax[1].set_xlabel("Time")
ax[1].set_ylabel("P")
# trailing ';' suppresses the Legend object's repr in the cell output
ax[1].legend(['X', 'Y','Not-X (X predicted, not observed)', 'Not Y (Y predicted, not observed)']
             ,bbox_to_anchor=(1.1, 1));

In [None]:
# Interactive view of the tracked F, M and P through time.
# Each slider move draws a fresh figure: mutating an existing figure and
# calling draw_idle() only refreshes under an interactive backend (e.g.
# ipympl), and set_offsets/set_array do not rescale the axes or colour limits.
x = np.arange(nf)
labels = [stim[k] for k in x]

def update(i=0.0):
    """Show F, M and P for the tracked s value at simulation time ``i``.

    F_tracker/P_tracker row 0 is the initial all-zero state and row k is the
    state after the step at T[k-1]. M_tracker has no initial entry, so the M
    snapshot matching row k is M_tracker[k-1] (the old code read M_tracker[k],
    one step ahead, and raised IndexError at i == t_max).
    """
    # round(i / dt) instead of int(i*10): int() truncates a float product that
    # lands just below a whole number, and the old form hard-coded dt = 0.1.
    idx = min(int(round(i / dt)), len(F_tracker) - 1)
    m_frame = M_tracker[idx - 1] if idx > 0 else np.zeros((nf, nf))

    fig, ax = plt.subplots(3, 1, figsize=(5, 9))
    ax[0].scatter(x, F_tracker[idx])
    ax[0].set_title("F, s=%s, t=%s" % (S[s_track], i))
    mesh = ax[1].pcolormesh(m_frame)
    fig.colorbar(mesh, ax=ax[1])
    ax[1].set_title("M")
    ax[1].set_xlabel("past stimulus index")
    ax[1].set_ylabel("presented stimulus index")
    ax[2].scatter(x, P_tracker[idx])
    ax[2].set_title("P")
    for axis in (ax[0], ax[2]):
        axis.set_xticks(x)
        axis.set_xticklabels(labels)
    fig.tight_layout()
    plt.show()

interact(update, i = (0.0,t_max,dt));