In [1]:
from brian2 import *
import numpy as np
import matplotlib.pyplot as plt
from sklearn.feature_selection import mutual_info_regression

In [2]:
def strobe(tvec,signal,period):
    n = 0
    strobedSignal = []
    strobedTime = []
    for t in tvec:
        if t % period == 0:
#             print(t)
            strobedTime.append(t)
            strobedSignal.append(signal[n])
        n += 1
    return [strobedTime, strobedSignal]

In [4]:
# compute the rank of a spike train
def sTrainToRank(sTrain, delta_t, duration):
    """
    This function takes a spike train as input (a list of times when spikes occur)
    and returns a rank for that spike train by converting it to a binary sequence
    and converting that to a base 10 number.
    
    For example, a spike train sTrain = np.array([0.5,1.3,3.4,6.7,8.9,9.3]) would
    get converted to the string 1101001011 if a time bin delta_t=1 was used. This
    binary string would then be converted to 843 in base 10.
    
    sTrain = the list of spike times
    delta_t = the size of the time bin
    duration = the length of the time series to consider
    """
    Ts = int(duration) # consider time sequences of length 4*ms

    binTrain = np.zeros(Ts//delta_t)
    count = 0
    for j in range(len(binTrain)):
        state = int(0)
        for s in sTrain:
            condition = (s >= j*delta_t and s < (j+1)*delta_t)
            if condition:
                state = 1
                count += 1 # spike count
        binTrain[j] = int(state)

    # convert the list binTrain to a string
    
    binTrainStr=""
    for i in binTrain:
        binTrainStr += str(int(i))

    # convert the binary sequence to a number in base 10
    binTrainBase10 = int(binTrainStr, 2)
    ret = [binTrainBase10, binTrainStr, count]
    return ret

sTrain = np.array([0.5,1.3,3.4,6.7,8.9,9.3])
delta_t = 1
duration = 10
binTrainBase10, binTrainStr, count = sTrainToRank(sTrain, delta_t, duration)
print(binTrainBase10)
print(binTrainStr)
print(f'spike count = {count}')

843
1101001011
spike count = 6


In [59]:
start_scope()
duration = 5000*ms
dt = defaultclock.dt
tVec = np.arange(0,duration/ms+dt/ms,dt/ms)
noise = np.random.normal(0,1,len(tVec))
f1 = 5*Hz # frequency of I(t) in (Hz)
f2 = 10*Hz
f3 = 2*Hz
det = np.sin(2*np.pi*tVec*f1/(1000*Hz)) + np.sin(2*np.pi*tVec*f2/(1000*Hz)) - np.sin(2*np.pi*tVec*f3/(1000*Hz))
signal = det + noise
# signal = noise
In = TimedArray(signal,dt)

tau = 10*ms
eqs = '''
dv/dt = (1-v+In(t))/tau : 1 (unless refractory)
'''

G = NeuronGroup(1, eqs, threshold='v>1', reset='v=0', refractory = 5*ms, method='exact')
M = StateMonitor(G, ['v'], record=True)
S = SpikeMonitor(G)
run(duration)

# fig, ax = plt.subplots(2,1)
# ax[0].plot(tVec, signal)
# ax[0].set_xticks([])
# ax[0].set_ylabel('stimulus')
# ax[1].plot(M.t/ms, M.v[0])
# ax[1].plot(S.t/ms,np.zeros(len(S.t/ms)),marker='o',linestyle='',color='black')
# ax[1].set_ylabel('voltage')
# ax[1].set_xlabel('time (ms)')
# plt.show()

In [60]:
T=50
st = S.t/ms % T
stOG = S.t/ms
delta_t = 5
stimVec = strobe(tVec,signal,T)[1]
rankVec = np.zeros(len(stimVec))
countVec = np.zeros(len(stimVec))
# print(f'stOG = {stOG}')
# print(f'st = {st}')
# print('\ndecomposed:')
n0 = 0
j = 0
for n in range(len(st)-1):
    if st[n+1] < st[n]:
        sub_st = st[n0:n+1]
#         print(sub_st)
        binTrainBase10, binTrainStr, count = sTrainToRank(sub_st, delta_t, T)
        print(f'{sub_st} = {binTrainStr} has rank {binTrainBase10} and spike count = {count}')
        rankVec[j] = binTrainBase10
        countVec[j] = count
        n0 = n+1
        j+=1
    if stOG[n+1] > duration/ms - T:
#         print(stOG[n+1])
        sub_st = st[n+1:]
        binTrainBase10, binTrainStr, count = sTrainToRank(sub_st, delta_t, T)
        print(f'{sub_st} = {binTrainStr} has rank {binTrainBase10} and spike count = {count}')
        rankVec[j] = binTrainBase10
        countVec[j] = count
        break

# print(f'\nstimVec = {stimVec}')
# print(f'rankVec = {rankVec}')
# print(f'countVec = {countVec}')

[12.8 23.3 34.7 49.5] = 0010101001 has rank 169 and spike count = 4
[15.6 25.6 35.5 46.9] = 0001010101 has rank 85 and spike count = 4
[12.5 44.1] = 0010000010 has rank 130 and spike count = 2
[ 7.3 18.2 29.9 46.6] = 0101010001 has rank 337 and spike count = 4
[46.5] = 0000000001 has rank 1 and spike count = 1
[ 6.3 15.1 23.5 31.5 40.7 49.7] = 0101101011 has rank 363 and spike count = 6
[11.3 41. ] = 0010000010 has rank 130 and spike count = 2
[12.6 22.6 35.9] = 0010100100 has rank 164 and spike count = 3
[ 3.6 12.9 21.5 29.7 38.1 46.6] = 1010110101 has rank 693 and spike count = 6
[ 6.  16.7 29.5 42.8] = 0101010010 has rank 338 and spike count = 4
[ 4.8 15.2 26.2 42.8] = 1001010010 has rank 594 and spike count = 4
[11.7 22.1 32.3 45.6] = 0010101001 has rank 169 and spike count = 4
[14.6 24.7 34.1 44.2] = 0010101010 has rank 170 and spike count = 4
[ 6.6 29.4 47.6] = 0100010001 has rank 273 and spike count = 3
[ 9.4 20.5 31.8 47.8] = 0100101001 has rank 297 and spike count = 4
[46.8] =

In [61]:
X = np.reshape(countVec,(len(countVec),1))
Y = stimVec
MI = mutual_info_regression(X,Y)[0]
print(MI)

0.10963847121227621


In [62]:
def computeSMI(X, Y, m, k):
    '''
    computes the sliced mutual information (see Goldfeld & Greenewald 2021)
        between X and Y with m slices.
    Assumes X has shape (dimX) X (obs) and Y has shape (dimY) X (obs) where
        'obs' is the number of observations.
    This code uses the KSG method for estimating mutual information, imported
        from sklearn.feature_selection
    '''
    dimX = X.shape[0]
    dimY = Y.shape[0]
    SMI = 0
    for mi in range(m):
        Theta = np.random.uniform(0,1,dimX)
        Phi = np.random.uniform(0,1,dimY)
        newX = np.dot(Theta,X)
        newX = np.reshape(newX, (len(newX),1))
        newY = np.dot(Phi,Y)
        Si = mutual_info_regression(newX, newY, n_neighbors=k)[0]
        SMI += Si/np.log(2)
    SMI = SMI/m
    return SMI

In [66]:
X = countVec + np.random.normal(0,10**(-4),len(countVec))
print(len(X))
X = np.reshape(X, (1,len(X)))
Y = np.array(stimVec)
print(len(Y))
Y = np.reshape(Y, (1,len(Y)))
m=3
k=3
MIfunc = computeSMI(X,Y,m,k)
print(MIfunc)

101
101
0.15247453922492588


In [40]:
print(countVec)

[1. 1. 2. 1. 2. 1. 1. 2. 0. 0. 0.]


In [69]:
#### SLICED MI USED HERE ####

m=3
k=1
SCs = np.unique(countVec) # all observed spike counts
# SCs = [1]
# print(SCs)
MItemp = 0
for sc in SCs:
    print(f'spike count = {sc}')
#     idx = np.argwhere(countVec==sc)[:,0]
#     idx = np.argwhere(countVec==sc)
    idx = np.nonzero(countVec==sc)[0]
    probSC = len(idx)/len(countVec)
    print(f'probability of spike count = {sc} is {probSC}')
    x = rankVec[idx]
#     print(rankVec[idx])
#     x = np.reshape(x,(len(x),1))
    y = np.array(stimVec)
    y = y[idx]
    X = np.reshape(x, (1,len(x)))
    Y = np.reshape(y, (1,len(y)))
    if len(x)>1: # don't use sub-samples with only 1 element
#         mi = mutual_info_regression(x, y, n_neighbors=1)[0]
        mi = computeSMI(X,Y,m,k)
        print(f'\tconditional timing MI = {mi}')
        MItemp+=probSC*mi
print(f'temporal MI = {MItemp}')
#     print(rankVec[idx])

Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7fb2e3d98160>
Traceback (most recent call last):
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<loc

spike count = 0.0
probability of spike count = 0.0 is 0.48514851485148514
	conditional timing MI = 1.6017132519074586e-16
spike count = 1.0
probability of spike count = 1.0 is 0.0297029702970297
	conditional timing MI = 0.0
spike count = 2.0
probability of spike count = 2.0 is 0.0891089108910891
	conditional timing MI = 1.3108296608288743
spike count = 3.0
probability of spike count = 3.0 is 0.07920792079207921
	conditional timing MI = 1.3662665536037746
spike count = 4.0
probability of spike count = 4.0 is 0.19801980198019803
	conditional timing MI = 0.18654813484905475
spike count = 5.0
probability of spike count = 5.0 is 0.0594059405940594
	conditional timing MI = 0.0
spike count = 6.0
probability of spike count = 6.0 is 0.0594059405940594
	conditional timing MI = 0.0
temporal MI = 0.2619659611214967


In [65]:
#### NO SLICED MI USED HERE ####

SCs = np.unique(countVec) # all observed spike counts
# print(SCs)
MItemp = 0
for sc in SCs:
    print(f'spike count = {sc}')
#     idx = np.argwhere(countVec==sc)[:,0]
#     idx = np.argwhere(countVec==sc)
    idx = np.nonzero(countVec==sc)[0]
    probSC = len(idx)/len(countVec)
    print(f'probability of spike count = {sc} is {probSC}')
    x = rankVec[idx]
#     print(rankVec[idx])
    x = np.reshape(x,(len(x),1))
    y = np.array(stimVec)
    y = y[idx]
    if len(x)>1: # don't use sub-samples with only 1 element
        mi = mutual_info_regression(x, y, n_neighbors=1)[0]
        print(f'\tconditional timing MI = {mi}')
        MItemp+=probSC*mi
print(f'temporal MI = {MItemp}')
#     print(rankVec[idx])

Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7fb2e3d98310>
Traceback (most recent call last):
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/home/zmobille3/anaconda/lib/python3.9/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'


spike count = 0.0
probability of spike count = 0.0 is 0.48514851485148514
	conditional timing MI = 1.1102230246251565e-16
spike count = 1.0
probability of spike count = 1.0 is 0.0297029702970297
	conditional timing MI = 0
spike count = 2.0
probability of spike count = 2.0 is 0.0891089108910891
	conditional timing MI = 0.9085978835978836
spike count = 3.0
probability of spike count = 3.0 is 0.07920792079207921
	conditional timing MI = 0.9470238095238095
spike count = 4.0
probability of spike count = 4.0 is 0.19801980198019803
	conditional timing MI = 0.13180531370933857
spike count = 5.0
probability of spike count = 5.0 is 0.0594059405940594
	conditional timing MI = 0
spike count = 6.0
probability of spike count = 6.0 is 0.0594059405940594
	conditional timing MI = 0
temporal MI = 0.18207601685899213
