In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import *

In [4]:
def stim_gen(i, size=None):
    """Return a one-hot stimulus vector.

    Parameters
    ----------
    i : int or None
        Index of the active stimulus channel. ``None`` yields an all-zero
        vector (no stimulus this step).
    size : int, optional
        Length of the vector. Defaults to the notebook-level ``nf`` so existing
        calls ``stim_gen(i)`` behave as before; pass it explicitly to avoid
        depending on that global (defined in a later cell).

    Returns
    -------
    numpy.ndarray
        Float vector of length ``size`` with a 1 at position ``i`` (if given).
    """
    if size is None:
        size = nf  # fall back to the notebook-level stimulus count
    stim_vec = np.zeros(size)
    if i is not None:
        stim_vec[i] = 1
    return stim_vec

class iSITH():
    """Geometric bank of time constants and their rates.

    Builds ``ntau`` time constants ``tau_star`` that grow geometrically from
    ``tau_min`` towards ``tau_max`` (neighbour ratio ``1 + c``), together with
    the matching rates ``s = 1 / tau_star``. The remaining constructor
    arguments (``k``, ``dt``, ``g``, ``buff_max``) are only stored here.
    ``ntau`` must be at least 2 (``ntau == 1`` divides by zero).
    """

    def __init__(self, tau_min=.1, tau_max=3, buff_max=None, k=50, ntau=20, dt=1, g=0.0):
        super().__init__()
        # Configuration stored as given.
        self.tau_min, self.tau_max = tau_min, tau_max
        self.k = k
        self.ntau = ntau
        self.dt = dt
        self.g = g
        # Default buffer length: three times the largest time constant.
        self.buff_max = 3*tau_max if buff_max is None else buff_max

        # Ratio between neighbouring time constants, chosen so the last one
        # lands on tau_max (up to rounding); c is that ratio minus one.
        step_ratio = (tau_max/tau_min)**(1./(ntau-1))
        self.c = step_ratio - 1
        exponents = np.arange(ntau, dtype=float)
        self.tau_star = tau_min*(1+self.c)**exponents
        self.s = 1/self.tau_star

class worker():
    """A single-rate unit holding memory, prediction and association state.

    State over ``nf`` stimulus channels:

    * ``F`` -- memory trace; Euler step of ``dF/dt = -s*F`` plus the input pulse.
    * ``P`` -- prediction trace; Euler step of ``dP/dt = +s*P`` plus the
      predicted input ``M @ f_IN``, minus ``P_smax`` on every step.
    * ``M`` -- ``nf x nf`` associations; ``M[cur, past]`` is a running average
      (weight ``a`` on the old value) of ``F[past]`` at the moments stimulus
      ``cur`` is presented.
    """

    def __init__(self, s, a, nf=2):
        self.a = a  # weight kept on the old association value
        self.s = s  # rate of this worker
        self.M = np.zeros([nf, nf])
        self.F = np.zeros(nf)
        self.P = np.zeros(nf)

    def update(self, dt, fi_IN=None, f_IN=None, P_smax=0.0):
        """Advance the worker by one time step of length ``dt``.

        Parameters
        ----------
        dt : float
            Integration step.
        fi_IN : int or None
            Index of the stimulus presented this step, or ``None`` for none.
        f_IN : numpy.ndarray or None
            Input vector for this step (required when ``fi_IN`` is given).
        P_smax : float or numpy.ndarray
            Amount subtracted from ``P`` on this step (the caller passes the
            fastest worker's ``P``).

        Returns
        -------
        tuple of numpy.ndarray
            ``(self.F, self.P)`` -- the live state arrays, not copies.
        """
        # Both Euler terms depend only on the state before this step.
        decay_F = (-self.s*self.F)*dt
        growth_P = (self.s*self.P)*dt

        if fi_IN is None:
            self.F += decay_F
            self.P += growth_P - P_smax
            return (self.F, self.P)

        self.F += decay_F + f_IN
        # Predicted input uses M *before* this step's learning update.
        P_IN = np.dot(self.M, f_IN)
        self.P += growth_P + P_IN - P_smax
        # Associate the current stimulus with every other stimulus still
        # present in memory (F already includes this step's input).
        for past_stim in np.flatnonzero(self.F):
            if past_stim != fi_IN:
                self.M[fi_IN, past_stim] = (self.a*self.M[fi_IN, past_stim]
                                            + (1-self.a)*self.F[past_stim])
        return (self.F, self.P)

In [5]:
# Number of stimulus channels. Doubling it to also encode "not-stimulus"
# channels was considered, but that doubling is currently disabled.
nf = 2
# Euler integration step (same time units as tau_min / tau_max).
dt = 0.01
# Association weight: M <- a*M + (1-a)*F, so a larger a learns more slowly.
a = 0.5
# Small threshold; only referenced by the disabled "Y not predicted" check.
e = 5e-7

In [182]:
# Rates s = 1/tau* for 20 geometrically spaced time constants in [0.1, 3].
S = iSITH(tau_min=0.1, tau_max=3, ntau=20).s

# Stimulus onsets: x, then y, then x again (presumably to check whether y is
# now predicted after learning).
t_x1, t_y, t_x2 = 1, 2, 7

# Simulation horizon and time grid.
t_max = 9  # alternative considered: t_x2 + (t_y - t_x1) + 4
T = np.arange(0, t_max, dt)

In [184]:
# One worker per rate s; each keeps its own F, P and M.
workers = [worker(s, a, nf=nf) for s in S]
predictions = np.zeros([len(S), nf])  # NOTE(review): currently unused

# Map each stimulus onset to an integer step index so onsets are matched
# without exact float comparison against the np.arange grid (an off-grid
# onset previously never fired and blocked every later event).
events = [(0, t_x1), (1, t_y), (0, t_x2)]
onset_stim = {int(round(t_on / dt)): stim for stim, t_on in events}

P_smax = 0  # nothing to subtract on the first step
# Preallocated traces: one row per time step, one column per worker.
F_xtrack = np.zeros((len(T), len(S)))  # F of stimulus x (channel 0)
P_ytrack = np.zeros((len(T), len(S)))  # P of stimulus y (channel 1)

for step, t in enumerate(T):
    fi_IN = onset_stim.get(step)
    f_IN = None if fi_IN is None else stim_gen(fi_IN)
    F = np.zeros([len(S), nf])
    P = np.zeros([len(S), nf])
    for i, w in enumerate(workers):
        (F[i], P[i]) = w.update(dt=dt, fi_IN=fi_IN, f_IN=f_IN, P_smax=P_smax)
    # Copy: workers[0].P is modified in place on the next step, so a plain
    # reference would make workers 1.. subtract the *current* step's value.
    P_smax = workers[0].P.copy()
    # Total y-prediction across workers (was `P_y.sum()`, an undefined name).
    P_sum = P[:, 1].sum()
    # if (t > t_x2) & (P_sum < e):
    #     print("Not Y predicted at %s" % (t))
    F_xtrack[step] = F[:, 0]
    P_ytrack[step] = P[:, 1]

In [6]:
# Memory traces F of stimulus x between the x onset and the y onset
# (inclusive); rows are workers (tau* increases downward), columns are steps.
start_step = int(round(t_x1 / dt))
stop_step = int(round(t_y / dt)) + 1  # include the y-onset step
fig, ax = plt.subplots()
im = ax.imshow(F_xtrack[start_step:stop_step, :].T, aspect="auto")
ax.set(xlabel=f"time step since t = {t_x1}",
       ylabel="worker index (larger = larger tau*)",
       title="F for stimulus x, from x onset to y onset")
fig.colorbar(im, ax=ax);

In [7]:
# Optional: y-prediction traces P_ytrack for t in [7, 8] (steps 700-800 at
# dt = 0.01), i.e. the window starting at the second x onset. Uncomment to show.
#plt.imshow(P_ytrack[700:801,:].T)