In [14]:
from general import *

In [15]:
class MotionPrimitiveNeuron():
    """Layer of leaky integrate-and-fire motion-primitive neurons.

    Each of the ``n`` neurons leaks towards ``V_R`` with time constant ``tau``
    (explicit Euler step of length ``dt``), adds the weighted sum of its
    inputs, and fires (then resets to ``V_R``) when its membrane potential
    exceeds ``V_T``.
    """
    NeuronState = namedtuple('NeuronState', ['V', 'w', 'spk', 'I'])

    def __init__(self, parameters):
        """Store the layer parameters.

        Args:
            parameters: dict with keys 'tau', 'V_R', 'V_T', 'w', 'n',
                'N_input' and 'dt'. 'w' is multiplied element-wise with the
                input of ``forward`` and summed over the last axis, so it
                must broadcast against that input.
        """
        super(MotionPrimitiveNeuron, self).__init__()
        self.tau = parameters['tau']
        self.V_R = parameters['V_R']
        self.V_T = parameters['V_T']
        self.w = parameters['w']
        self.n = parameters['n']
        self.N_input = parameters['N_input']
        self.dt = parameters['dt']
        self.state = None

    def initialize_state(self):
        """Drop the current state; the next ``forward`` call starts from rest (V = V_R)."""
        self.state = None

    def forward(self, input):
        """Advance the layer by one time step.

        Args:
            input: input spikes, multiplied element-wise by ``w`` and summed
                over the last axis (presumably shape (n, N_input) -- TODO
                confirm against callers).

        Returns:
            (V, spk): membrane potentials after the step (already reset to
            V_R where the neuron fired) and the 0/1 spike vector.
        """
        if self.state is None:
            # Lazily create the state; I is allocated but never updated here.
            self.state = self.NeuronState(V=np.full((self.n,), self.V_R),
                                          w=np.array(self.w),
                                          spk=np.zeros(self.n),
                                          I=np.zeros((self.n, self.N_input)))

        V_prev = self.state.V
        w = self.state.w
        I = self.state.I

        # Build new arrays rather than updating V_prev in place: V_prev is the
        # array returned to the caller on the previous step, and mutating it
        # would silently change that earlier result.
        V = V_prev + self.dt * (self.V_R - V_prev) / self.tau
        V = V + np.sum(w * input, axis=-1)

        # heaviside(x, 0) is 0 at x == 0, so a neuron fires only when V > V_T.
        spk = np.heaviside(V - self.V_T, 0)
        V = (1 - spk) * V + spk * self.V_R

        self.state = self.NeuronState(V=V, w=w, spk=spk, I=I)

        return V, spk

In [16]:
def get_encoding(w_pos=(0, 0, 0, 0, 0, 0, 0), w_vel=(0, 0, 0, 0, 0, 0, 0), n=6):
    """Enumerate the 3-slot input combinations of the motion-primitive neurons.

    Every slot is one of 'none', 'Vel-', 'Vel+', 'Pos-', 'Pos+'. Combinations
    with fewer than two active (non-'none') slots are discarded. Each kept
    combination gets a synapse-type id: 0 p-p, 1 v-v, 2 p-v (one slot
    'none'); 3 p-p-v, 4 v-v-p, 5 p-p-p, 6 v-v-v (all slots active). The
    per-combination arrays are then replicated ``n`` times (presumably one
    copy per leg), each copy offset by 12 input channels.

    Args:
        w_pos: weight per synapse type for 'Pos' slots (indexed by type id).
        w_vel: weight per synapse type for 'Vel' slots (indexed by type id).
        n: number of replicated copies.

    Returns:
        perm: (P, 3) string array of the kept combinations.
        base_perm: (P, 3) int array; 0 for 'none', 1..4 for Vel-, Vel+, Pos-, Pos+.
        final_perm: flat int array of length 3*P*n indexing the interleaved
            input vector (4 channels per slot, 12 per copy). 'none' slots
            point at the first channel of their copy and are meant to be masked.
        synapse_type: list of length P*n with the type id of every neuron.
        weights: (P*n, 3) float weights per neuron and slot.
        positive_mask: (P*n, 3), 1 where the slot is active.
        negative_mask: (P*n, 3), 1 where the slot is 'none'.
    """
    encoding = ['none', 'Vel-', 'Vel+', 'Pos-', 'Pos+']
    perm = np.array(list(product(encoding, repeat=3)))
    # Numeric twin of `perm`: -inf marks 'none', 0..3 the channel within a slot.
    base_perm = np.array(list(product([-np.inf, 0, 1, 2, 3], repeat=3)))

    def _classify(permutation):
        # Synapse-type id of one combination, or -1 if < 2 slots are active.
        n_pos = int(np.sum((permutation == 'Pos+') | (permutation == 'Pos-')))
        n_vel = int(np.sum((permutation == 'Vel+') | (permutation == 'Vel-')))
        if 'none' in permutation:
            if n_pos == 2:
                return 0
            if n_vel == 2:
                return 1
            if n_pos == 1 and n_vel == 1:
                return 2
            return -1
        if n_pos == 2:
            return 3
        if n_vel == 2:
            return 4
        if n_pos == 3:
            return 5
        return 6

    synapse_type = [_classify(permutation) for permutation in perm]

    invalid_index = np.where(np.array(synapse_type) == -1)[0]
    synapse_type = list(np.delete(synapse_type, invalid_index))
    perm = np.delete(perm, invalid_index, axis=0)
    base_perm = np.delete(base_perm, invalid_index, axis=0)

    weights = np.zeros_like(perm, dtype=float)
    for i, j in np.ndindex(weights.shape):
        if 'Pos' in perm[i, j]:
            weights[i, j] = w_pos[synapse_type[i]]
        elif 'Vel' in perm[i, j]:
            weights[i, j] = w_vel[synapse_type[i]]

    # Replicate per copy using `n` (previously a hard-coded 6, which made the
    # masks/weights disagree with final_perm for any n != 6).
    negative_mask = np.tile((perm == 'none').astype(float), (n, 1))
    positive_mask = 1 - negative_mask

    weights = np.tile(weights, (n, 1))
    synapse_type = synapse_type * n

    # Offset each slot by 4 channels; 'none' (-inf) is clipped to index 0.
    extra = np.tile(np.array([0, 4, 8]), (base_perm.shape[0], 1))
    final_perm = (base_perm + extra).clip(min=0)

    # Offset each replicated copy by 12 channels.
    extra_2 = np.linspace(0, 12 * (n - 1), num=n).repeat(3 * final_perm.shape[0])
    final_perm = (np.tile(final_perm.flatten(), n) + extra_2).astype(int)

    base_perm = base_perm + 1
    base_perm[base_perm == -np.inf] = 0

    return perm, base_perm.astype(int), final_perm, synapse_type, weights, positive_mask, negative_mask


def prepare_spikes_primitive(spike_velocity, spike_position, permutations, mask, n_joints=18):
    """Gather the 3-slot input spikes of every motion-primitive neuron.

    The velocity and position spike vectors hold two channels per joint
    (indices 2*j and 2*j+1). They are interleaved per joint as
    [vel 2j, vel 2j+1, pos 2j, pos 2j+1], which matches the channel order
    0..3 = Vel-, Vel+, Pos-, Pos+ used by get_encoding. The interleaved
    vector is then indexed with ``permutations``, reshaped to ``mask.shape``
    and multiplied by ``mask``.

    Args:
        spike_velocity: 1-D spikes, at least 2*n_joints entries (extra ignored).
        spike_position: 1-D spikes, at least 2*n_joints entries (extra ignored).
        permutations: flat int indices into the interleaved vector.
        mask: array whose size equals len(permutations); zeroes unused slots.
        n_joints: number of joints (default 18, the previously hard-coded value).

    Returns:
        Array shaped like ``mask`` with the gathered, masked spikes.
    """
    n_channels = 2 * n_joints
    velocity_pairs = np.asarray(spike_velocity)[:n_channels].reshape(n_joints, 2)
    position_pairs = np.asarray(spike_position)[:n_channels].reshape(n_joints, 2)

    # Row j -> [vel 2j, vel 2j+1, pos 2j, pos 2j+1]; ravel keeps joint order.
    pos_vel_spikes = np.concatenate((velocity_pairs, position_pairs), axis=1).ravel()

    return pos_vel_spikes[permutations].reshape(mask.shape) * mask


def convert_to_bins(old_array, n_bins, minimum=0, sum_bool=False):
    """Re-bin a (time, channels) array into ``n_bins`` time bins.

    The time axis is padded by repeating its last sample until its length is
    a multiple of ``n_bins`` (or equals ``n_bins`` when shorter). Consecutive
    samples are then summed per bin.

    Args:
        old_array: 2-D array with time along axis 0.
        n_bins: number of output time bins; must be positive.
        minimum: threshold used when ``sum_bool`` is False.
        sum_bool: if True return the per-bin sums, otherwise 1 where the sum
            exceeds ``minimum`` and 0 elsewhere.

    Returns:
        Array of shape (n_bins, channels).

    Raises:
        ValueError: if ``n_bins`` is not positive (the former retry loop never
            terminated in that case).
    """
    if n_bins <= 0:
        raise ValueError(f"n_bins must be positive, got {n_bins}")

    binned = np.transpose(old_array)

    # Compute the padding directly instead of appending one column per failed
    # reshape inside a bare `except:`.
    pad = (-binned.shape[1]) % n_bins
    if pad:
        binned = np.pad(binned, ((0, 0), (0, pad)), mode='edge')

    if binned.shape[1] != n_bins:
        binned = np.sum(binned.reshape(binned.shape[0], n_bins, -1), axis=2)

    if not sum_bool:
        binned = (binned > minimum).astype(int)

    return np.transpose(binned)

In [17]:
# Load the pre-computed joint-angle trajectories and sensory spike trains.
# NOTE: pickle.load can execute arbitrary code -- only load files produced by
# this project.
def _load_pickle(path):
    with open(path, 'rb') as handle:
        return pickle.load(handle)


joint_angles = _load_pickle('data/joint_angles_nostep')
spike_position = _load_pickle('data/spike_position')
spike_motion = _load_pickle('data/spike_motion')

# Weights per synapse-type id (order: p-p, v-v, p-v, p-p-v, v-v-p, p-p-p, v-v-v).
W_POS = [16.5e-3, 0, 16.5e-3, 15e-3, 16.5e-3, 9e-3, 0e-3]
W_VEL = [0e-3, 10.5e-3, 12e-3, 3e-3, 7.5e-3, 0e-3, 10e-3]

(permutations_string, permutations_base, permutations_final, synapse_type,
 weights_primitive, mask_positive, mask_negative) = get_encoding(W_POS, W_VEL)

# Note: this name would shadow a `time` module if one is in scope.
time = np.linspace(0, constants['T_TOTAL'], num=parameters['N_STEPS'])

primitive_parameters = {
    'tau': 0.5e-3,
    'V_T': -50e-3,
    'V_R': -70e-3,
    'n': len(synapse_type),
    'w': weights_primitive,
    'N_input': 3,
    'dt': constants['dt'],
}

primitive_neuron = MotionPrimitiveNeuron(primitive_parameters)
primitive_neuron.initialize_state()

# Output buffer: (time step, primitive neuron, simulation).
spike_primitive = np.empty((parameters['N_STEPS'], primitive_parameters['n'], parameters['N_SIMULATIONS']))

In [18]:
# Drive the motion-primitive layer with the sensory spikes of every simulation
# and store the resulting spike trains.
for k in range(parameters['N_SIMULATIONS']):
    # Reset per simulation: otherwise the membrane potentials left at the end of
    # simulation k-1 (or of an earlier run of this cell) leak into simulation k.
    primitive_neuron.initialize_state()
    for i in range(parameters['N_STEPS']):
        pos_vel_spikes = prepare_spikes_primitive(spike_motion[i, :, k], spike_position[i, :, k],
                                                  permutations_final, mask_positive)
        _, spike_primitive[i, :, k] = primitive_neuron.forward(pos_vel_spikes)

with open('data/spike_primitive', 'wb') as file:
    pickle.dump(spike_primitive, file)

In [19]:
# Evaluation containers: binary ground truth (step, neuron, simulation),
# confusion counts per neuron and simulation, their sums per synapse-type
# group (ids 0..6 from get_encoding), and the MCC per group.
ground_truth = np.zeros_like(spike_primitive)

confusion_shape = (len(synapse_type), parameters['N_SIMULATIONS'])
true_positive = np.zeros(confusion_shape)
false_positive = np.zeros(confusion_shape)
true_negative = np.zeros(confusion_shape)
false_negative = np.zeros(confusion_shape)

true_positive_groups = np.zeros(7)
false_positive_groups = np.zeros(7)
true_negative_groups = np.zeros(7)
false_negative_groups = np.zeros(7)
mcc_list = np.zeros(7)

In [20]:
# Build the ground truth for every primitive neuron from the joint angles,
# compare it with the simulated spikes in 100 time bins, and compute the MCC
# per synapse-type group plus a mean weighted by the number of neurons per group.
for k in range(parameters['N_SIMULATIONS']): 
    # Two ground-truth channels per joint: column 2*i ("-") and 2*i + 1 ("+").
    ground_mot, ground_pos = np.zeros((parameters['N_STEPS'], 36)), np.zeros((parameters['N_STEPS'], 36))
    for i in range(constants['N_ANGLES']):
        # Midpoint between this joint's min and max angle in simulation k.
        mid = np.max(joint_angles[:, i, k]) / 2 + np.min(joint_angles[:, i, k]) / 2
        # np.diff yields N_STEPS - 1 values, so the last step never gets a velocity label.
        diff = np.diff(joint_angles[:, i, k]) / constants['dt']

        # Velocity: decreasing angle -> column 2*i, increasing -> column 2*i + 1.
        ground_mot[np.where(diff < 0), 0 + 2 * i] = 1
        ground_mot[np.where(diff > 0), 1 + 2 * i] = 1
        # Position: below the midpoint -> column 2*i, above -> column 2*i + 1.
        ground_pos[np.where(joint_angles[:, i, k] < mid), 0 + 2 * i] = 1
        ground_pos[np.where(joint_angles[:, i, k] > mid), 1 + 2 * i] = 1
        
    for j in range(parameters['N_STEPS']):
        # Active slots contribute their 0/1 ground-truth value and 'none' slots
        # contribute 1 via mask_negative, so every entry is 0 or 1.
        ground_truth_j = prepare_spikes_primitive(ground_mot[j, :], ground_pos[j, :], permutations_final, mask_positive) + mask_negative
        ground_truth_j = np.sum(ground_truth_j, axis=1)
        
        # Row sums lie in {0, 1, 2, 3}; > 2.9 means all three slots are satisfied.
        ground_truth[j, ground_truth_j > 2.9, k] = 1
        ground_truth[j, ground_truth_j < 2.9, k] = 0

    # 100 time bins; a bin is 1 if it contains any 1 (convert_to_bins defaults).
    ground_truth_bins = convert_to_bins(ground_truth[:, :, k], 100)
    spike_primitive_bins = convert_to_bins(spike_primitive[:, :, k], 100)
    
    for i in range(primitive_parameters['n']):
        # get_confusion_matrix comes from `general`; presumably called as
        # (prediction, truth) and returning (tp, fp, tn, fn) -- TODO confirm.
        true_positive[i, k], false_positive[i, k], true_negative[i, k], false_negative[i, k] = (
            get_confusion_matrix(spike_primitive_bins[:, i], ground_truth_bins[:, i]))

# Aggregate the counts over all neurons of each synapse-type group (ids 0..6).
for i in range(7):
    # np.where returns a tuple of index arrays; only the total of the selected
    # entries is used below.
    indices = np.where(np.array(synapse_type) == i)
    
    true_positive_groups[i] = np.sum(true_positive[indices, :])
    true_negative_groups[i] = np.sum(true_negative[indices, :])
    false_positive_groups[i] = np.sum(false_positive[indices, :])
    false_negative_groups[i] = np.sum(false_negative[indices, :])
    
    # matthews_correlation comes from `general`; arguments passed as (tp, tn, fp, fn).
    mcc = matthews_correlation(true_positive_groups[i], true_negative_groups[i], false_positive_groups[i], false_negative_groups[i])
        
    mcc_list[i] = mcc
    
# Mean MCC weighted by the number of neurons in each group.
average_mcc = np.average(mcc_list, weights=np.bincount(synapse_type))

In [21]:
# Per-group MCC (rounded to 3 decimals) followed by the weighted mean.
group_labels = ['p-p', 'v-v', 'p-v', 'p-p-v', 'v-v-p', 'p-p-p', 'v-v-v', 'mean']
mcc_column = np.append(np.around(mcc_list, 3), np.around(average_mcc, 3))

df = pd.DataFrame(data={'MCC': mcc_column}, index=group_labels)
df.to_csv('results/primitive_accuracy_table.csv')

print(df)

         MCC
p-p    0.754
v-v    0.391
p-v    0.534
p-p-v  0.548
v-v-p  0.473
p-p-p  0.427
v-v-v  0.359
mean   0.512
