In [1]:
import random
import numpy as np
import json
from tqdm import tqdm
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import linear_sum_assignment
from collections import Counter


In [2]:
def generate_synthetic_data(params):
    """
    Build a synthetic spike raster made of background noise plus repeated motifs.

    params is a dict read for the keys 'M', 'N', 'D', 'T', 'nrn_fr', 'pg_fr'
    and 'background_noise_fr'; any other key is ignored. All random draws use
    the global np.random state, so seed it beforehand for reproducible output.

    Returns (A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse):
      K_dense  -- (N, D, M) array of 0/1, one motif kernel per last-axis slice
      K_sparse -- np.where(K_dense) with the motif index shifted to start at 1
      B_dense  -- (M, T) array of 0/1 marking the onset times of every motif
      B_sparse -- np.where(B_dense) with the motif index shifted to start at 1
      A_dense  -- (N, T+D) array of 0/1: noise and all motif occurrences merged
      A_sparse -- (neuron, time, layer) indices of the non-zero entries of the
                  layered raster; layer 0 is the noise, layer b is motif b
    """
    n_motifs, n_neurons, depth, n_steps = (params[k] for k in ('M', 'N', 'D', 'T'))
    nrn_fr = params['nrn_fr']
    pg_fr = params['pg_fr']
    background_noise_fr = params['background_noise_fr']

    # Motif kernels: uniform draws in [0, 1000) thresholded by a per-motif
    # rate that is itself drawn from a Poisson distribution.
    K_dense = 1000 * np.random.rand(n_neurons, depth, n_motifs)
    for motif in range(n_motifs):
        rate = float(np.random.poisson(nrn_fr))
        K_dense[..., motif] = (K_dense[..., motif] < rate).astype('int')
    k_neuron, k_delay, k_motif = np.where(K_dense)
    K_sparse = (k_neuron, k_delay, k_motif + 1)

    # Motif onsets: same thresholding scheme, one row per motif.
    B_dense = 1000 * np.random.rand(n_motifs, n_steps)
    for motif in range(n_motifs):
        rate = float(np.random.poisson(pg_fr))
        B_dense[motif, :] = (B_dense[motif, :] < rate).astype('int')
    b_motif, b_time = np.where(B_dense)
    # Shifting by one makes the first motif use index 1 instead of index 0,
    # which leaves layer 0 of the raster free for the background noise.
    B_sparse = (b_motif + 1, b_time)

    # Layered raster: the time axis is padded by D so that a motif starting at
    # the last onset time still fits.
    layers = np.zeros((n_neurons, n_steps + depth, n_motifs + 1))
    noise = 1000 * np.random.rand(n_neurons, n_steps + depth)
    layers[..., 0] = (noise < background_noise_fr).astype('int')
    for b, t in zip(B_sparse[0], B_sparse[1]):
        layers[:, t:t + depth, b] += K_dense[..., b - 1]

    A_sparse = np.where(layers)
    # Collapse the layers and clip overlapping spikes back to a binary raster.
    A_dense = np.minimum(layers.sum(axis=2), 1)
    return A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse

In [10]:
def scan_raster(T_labels, N_labels, window_dim = None):
    '''
    Detect repeated spiking patterns in a raster given as parallel spike arrays.

    T_labels: array of spike times (rounded to integers internally)
    N_labels: array of neuron labels, one per entry of T_labels
    window_dim: spikes strictly closer than window_dim to a spike form that
        spike's window; templates are cut with twice that reach. Defaults to 100.

    Returns (pattern_template, sublist_keys_filt, pattern_img):
        pattern_template: list with one integer array of (time, neuron) rows
            per detected pattern, times shifted so the earliest row is at 0
        sublist_keys_filt: the repeated interval sequences returned by
            _check_seq for the selected cutoff
        pattern_img: array of shape (n_patterns, N, max_time + 1) holding a 1
            at every template point, or None when no pattern was found
    '''
    print(f'Cleaning spikes...',end='\r')
    if window_dim is None:
        window_dim = 100

    T_labels = np.round(T_labels).astype(int)
    # Taking the unique columns of the stacked labels removes any spikes that
    # occur at the same neuron at the same time.
    T_labels, N_labels = np.unique(np.array([T_labels,N_labels]),axis=1)
    N = max(N_labels)+1

    print(f'Windowing... {len(T_labels)}')
    windows = np.zeros((len(T_labels)),dtype='object')
    for i,window_time in enumerate(T_labels):
        condition = (T_labels > window_time-window_dim) & (T_labels < window_time + window_dim)
        window = np.array([T_labels[condition]-window_time, N_labels[condition]]).T
        # A set of (relative time, neuron) tuples makes window overlap a set intersection
        windows[i] = {tuple(row) for row in window}

    # Sweep the clustering cutoff upwards from 0 in steps of lr and keep the
    # first cutoff that yields the longest repeated interval sequence.
    lr = 0.01
    max_iter = 50
    cutoff = 0
    opt_cutoff = cutoff
    max_seq_rep = 0
    sim_mats = _get_sim_mats(windows, T_labels, N_labels)

    for step in range(1, max_iter + 2):
        clusters = _cluster_windows(cutoff, N_labels, sim_mats)
        _, _, sublist_keys_filt = _check_seq(clusters, T_labels, N_labels)

        if len(sublist_keys_filt) != 0:
            max_ = np.max([len(k) for k in sublist_keys_filt])
            if max_seq_rep < max_:
                max_seq_rep = max_
                opt_cutoff = cutoff

        cutoff += lr
        print(f'iter - {step/max_iter} | cutoff - {cutoff} | opt_cutoff - {opt_cutoff} | most_detections - {max_seq_rep}',end='\r')

    clusters = _cluster_windows(opt_cutoff, N_labels, sim_mats)
    cluster_sq, sq_counts, sublist_keys_filt = _check_seq(clusters, T_labels, N_labels)

    # Recover the occurrence times of every repeated sequence.
    sorted_indices = np.argsort(T_labels)
    sorted_x = T_labels[sorted_indices]

    all_times = []
    for key in sublist_keys_filt:
        # cluster_sq is keyed by str(list); going through list() keeps the
        # lookup valid when key arrives as a row of an object array.
        matching_clusters = cluster_sq[str(list(key))]
        pattern_repetition_labels = np.zeros((len(matching_clusters),len(clusters)))
        for i,k in enumerate(matching_clusters):
            pattern_repetition_labels[i][clusters==k] = 1
            # The running count numbers the spikes of this cluster 1, 2, 3, ...
            pattern_repetition_labels[i] *= np.cumsum(pattern_repetition_labels[i])
        pattern_repetition_labels = np.sum(pattern_repetition_labels,axis=0,dtype='int')

        sorted_y = pattern_repetition_labels[sorted_indices]
        # The earliest spike carrying repetition number i marks occurrence i
        pattern_times = np.array([sorted_x[sorted_y==i][0] for i in range(1,max(pattern_repetition_labels)+1)])
        all_times.append(pattern_times)

    # Overlay the surroundings of every occurrence and keep only the points
    # that show up most often across the occurrences.
    pattern_template = []
    for pattern_times in all_times:
        points = []
        for time in pattern_times:
            condition = (T_labels > time-window_dim*2) & (T_labels < time + window_dim*2)
            points += [tuple(k) for k in np.array([T_labels[condition]-time, N_labels[condition]]).T]
        tally = Counter(points)
        counts = np.array([tally[k] for k in points])
        template = np.array(points)[counts == np.max(counts)]
        template[:,0] -= min(template[:,0])
        pattern_template.append(np.unique(template,axis=0))

    if len(pattern_template) == 0:
        return pattern_template, sublist_keys_filt, None

    win_size = (N,1+max([max(k[:,0]) for k in pattern_template]))
    pattern_img = np.zeros((len(pattern_template),*win_size))
    for p,pattern in enumerate(pattern_template):
        for (t,neuron) in pattern:
            pattern_img[p,neuron,t] = 1

    return pattern_template, sublist_keys_filt, pattern_img

def _get_sim_mats(windows, T_labels, N_labels):
    sim_mats = np.zeros(np.max(N_labels),dtype='object')
    for n in np.unique(N_labels):
        idc = np.where(N_labels==n)[0]
        windows_n = windows[idc]
        if len(windows_n) > 1:
            x = np.zeros((len(windows_n),len(windows_n)))
            for i in range(windows_n.shape[0]):
                for j in range(windows_n.shape[0]):
                    common_rows = windows_n[i].intersection(windows_n[j])
                    num_identical_rows = len(common_rows)
                    x[i,j] = len(common_rows)/min(len(windows_n[i]),len(windows_n[j]))
            np.fill_diagonal(x,0)# make sure the diagonals are zero, this is important the more spikes there are...
            sim_mats[n] = x-1 
    return sim_mats

def _cluster_windows(cutoff, N_labels, sim_mats):
    clusters = np.zeros_like(N_labels)
    for n in np.unique(N_labels):
        idc = np.where(N_labels==n)[0]
        if (type(sim_mats[n]) == np.ndarray) and (not np.all(sim_mats[n] == 0)):
            l = max(clusters)+1
            clusters[idc]= l+fcluster(linkage(sim_mats[n], method='complete'), cutoff, criterion='distance')
    return clusters

def _check_seq(clusters, T_labels, N_labels):

    time_differences = []
    cluster_sq = {}
    for cluster in np.unique(clusters):
        temp = list(np.diff(np.unique(T_labels[clusters == cluster])))
        str_temp = str(temp)
        time_differences.append(temp)
        if str_temp in cluster_sq.keys():
            cluster_sq[str_temp] = cluster_sq[str_temp] + [cluster]
        else:
            cluster_sq[str_temp] = [cluster]

    # Convert the list of lists to a set of tuples to remove duplicates
    unique_sublists_set = set(tuple(sublist) for sublist in time_differences if sublist)

    # Convert the set of tuples back to a list of lists
    unique_sublists = [list(sublist) for sublist in unique_sublists_set]

    # Count the occurrences of each unique sublist in the original list
    sublist_counts = Counter(tuple(sublist) for sublist in time_differences if sublist)

    # Print the unique sublists and their respective counts
    sq_counts = np.zeros(len(sublist_counts)) 
    for i,sublist in enumerate(unique_sublists):
        count = sublist_counts[tuple(sublist)]
        sq_counts[i] = count
    #     print(f"{sublist}: {count} occurrences")
    sublist_keys_np = np.array([list(key) for key in sublist_counts.keys()],dtype='object')
    sublist_keys_filt = sublist_keys_np[np.array(list(sublist_counts.values())) >1] # only bother clustering repetitions that appear for more than one neuron
    
    return cluster_sq, sq_counts, sublist_keys_filt


In [11]:
from scipy.signal import correlate
def get_acc(ground_truths,detected_patterns):
    '''
    Score how well every ground-truth motif is matched by a detected pattern.

    ground_truths: array indexed as (neuron, time, motif)
    detected_patterns: array indexed as (neuron, time, pattern); the time axis
        may have a different length than in ground_truths

    Returns (SM_acc, cross_corr_matrix):
        cross_corr_matrix[g, d] is the peak of the cross-correlation between
            motif g and pattern d, summed over neurons and divided by the
            larger of the two spike counts
        SM_acc[g] is the best score of motif g over all detected patterns
    Both are all zeros when there are no detected patterns.
    '''
    n_truths = ground_truths.shape[2]
    n_detected = detected_patterns.shape[2]
    cross_corr_matrix = np.zeros((n_truths, n_detected))

    # `len(detected_patterns == 0)` measured the first axis of a boolean array
    # and was therefore true for any real input; test the pattern count instead.
    if n_detected == 0:
        return np.zeros(n_truths), cross_corr_matrix

    # Neurons missing from one of the arrays contribute nothing to the
    # correlation, so only the rows present in both are visited.
    n_rows = min(ground_truths.shape[0], detected_patterns.shape[0])

    for g in range(n_truths):
        truth = ground_truths[..., g]
        for d in range(n_detected):
            detected = detected_patterns[..., d]
            norm = max(np.sum(truth), np.sum(detected))
            if norm == 0:
                # Two empty images: leave the score at 0 instead of dividing by zero
                continue
            cross_corr = np.zeros(truth.shape[1] + detected.shape[1] - 1)
            for n in range(n_rows):
                cross_corr += correlate(truth[n], detected[n], mode='full')
            cross_corr_matrix[g, d] = np.max(cross_corr) / norm

    SM_acc = np.max(cross_corr_matrix, axis=1)
    return SM_acc, cross_corr_matrix

In [12]:
def main():
    '''
    Run the detection pipeline on a random sample of the scan_dict parameter grid.

    For every sampled parameter combination and trial, a synthetic raster is
    generated, scanned for patterns and scored against the true motifs. All
    results are written to 'scan_stats.json' and returned.

    Returns a list of dicts with the keys 'idc', 'trial', 'data', one key per
    scanned parameter, and 'performance' (fraction of motifs scoring above
    0.8, mean score). The returned dicts still hold NumPy arrays under 'data'.
    '''
    # Number of random grid points to evaluate; adjust to the available compute
    num_samples = 5
    trials = 1
    results = []

    param_combinations = np.array(np.meshgrid(*scan_dict.values())).T.reshape(-1, len(scan_dict))
    num_iterations = len(param_combinations)

    # random.sample raises when asked for more items than exist, so cap the request
    random_indices = random.sample(range(num_iterations), min(num_samples, num_iterations))

    for idx in tqdm(random_indices):
        for trial in range(trials):
            # The trial number doubles as the seed so every trial is reproducible
            np.random.seed(trial)
            params = {key: int(val) for key, val in zip(scan_dict.keys(), param_combinations[idx])}

            print("Params:", params)
            print("Generating raster plot...")
            _, A_sparse, _, B_sparse, K_dense, K_sparse = generate_synthetic_data(params)
            print("Clustering...")
            pattern_template, sublist_keys_filt, pattern_img = scan_raster(A_sparse[1],A_sparse[0],window_dim=params['D'])
            if not isinstance(pattern_img, np.ndarray):
                # scan_raster hands back None when it found no pattern
                performance_result = (0,0)
            else:
                # get_acc indexes patterns on the last axis, scan_raster on the first
                pattern_img = np.transpose(pattern_img,axes=[1,2,0])
                SM_acc, _ = get_acc(K_dense, pattern_img)
                performance_result = (np.sum(SM_acc>0.8)/len(SM_acc), np.mean(SM_acc))

            result = {
                'idc': idx,
                'trial':trial,
                'data':[A_sparse,K_sparse,B_sparse],
                **params,  # Unpack the parameters as separate columns
                'performance':performance_result
            }

            print(performance_result)
            results.append(result)

    # Write the entire list of results in one go. The 'data' entries are tuples
    # of NumPy arrays, which json cannot encode without the default hook.
    with open('scan_stats.json', 'w') as results_file:
        json.dump(results, results_file, indent=4, default=_json_default)

    return results


def _json_default(obj):
    '''Convert NumPy arrays and scalars into plain Python values for json.dump.'''
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


In [13]:
'''
Model default parameters
'''

M = 4 # Number of Spiking motifs
N = 20 # Number of input neurons
D = 71 # temporal depth of receptive field (number of time bins spanned by one motif)
T = 1000 # number of time bins in which motif onsets can be placed
# The three rates below are compared against uniform draws in [0, 1000) by
# generate_synthetic_data, i.e. each acts as a per-bin probability of rate/1000.
# NOTE(review): the "hz" unit presumably assumes 1 ms bins -- confirm.
nrn_fr = 15 # hz; mean of the Poisson-drawn spike rate inside each motif
pg_fr = 6 # hz; mean of the Poisson-drawn onset rate of each motif
background_noise_fr = 10 # hz; rate of the background noise layer (used as is)
seed=41
# Seeds the global NumPy generator used by generate_synthetic_data
np.random.seed(seed)

In [14]:
# Parameter set built from the default constants defined in the cell above.
# generate_synthetic_data reads every key except 'seed', which it ignores.
default_params = {
    'M':M,
    'N':N,
    'D':D,
    'T':T,
    'nrn_fr':nrn_fr,
    'pg_fr':pg_fr,
    'background_noise_fr':background_noise_fr,
    'seed':seed
}
# Candidate values per parameter; main() forms the full grid of combinations
# from these lists and evaluates a random sample of it.
scan_dict = {
    'M':[1,4,16,32,64],
    'N':[5,30,60,100],
    'D':[10,30,70,150],
    'T':[1000],
    'nrn_fr':[5,10,15],
    'pg_fr':[3,4,5,8,10],
    'background_noise_fr':[0,1,2,5,10]
}

In [17]:
# Generate one synthetic data set from the default parameters; the dense raster
# and the dense onset matrix are discarded.
_, A_sparse, _, B_sparse, K_dense, K_sparse = generate_synthetic_data(default_params)

In [18]:
# First index array of A_sparse: the neuron index of every non-zero raster entry
A_sparse[0]

array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6

In [15]:
# Run the parameter sweep; the output below shows this run stopping with an IndexError
results = main()

  0%|                                                                                            | 0/5 [00:00<?, ?it/s]

Params: {'M': 32, 'N': 5, 'D': 150, 'T': 1000, 'nrn_fr': 5, 'pg_fr': 5, 'background_noise_fr': 0}
Generating raster plot...
Clustering...
Cleaning spikes...Windowing... 706


  0%|                                                                                            | 0/5 [00:02<?, ?it/s]


IndexError: index 4 is out of bounds for axis 0 with size 4

In [19]:
# Parameter combination printed by the failing main() run, copied here to
# reproduce the IndexError step by step
params = {'M': 32, 'N': 5, 'D': 150, 'T': 1000, 'nrn_fr': 5, 'pg_fr': 5, 'background_noise_fr': 0}

In [21]:
# Regenerate the raster for the failing parameter set
_, A_sparse, _, B_sparse, K_dense, K_sparse = generate_synthetic_data(params)
# A_sparse holds (neuron, time, layer) index arrays; keep neurons and times
N_labels, T_labels = A_sparse[0], A_sparse[1]

In [22]:

# Top-level copy of scan_raster's body, so that the intermediate variables can
# be inspected while debugging. It expects T_labels, N_labels and window_dim to
# exist already (window_dim is not assigned in any cell visible here -- TODO
# confirm where it is set) and uses the helpers _get_sim_mats,
# _cluster_windows and _check_seq.
print(f'Cleaning spikes...',end='\r')
if window_dim is None:
    window_dim = 100

T_labels = np.round(T_labels).astype(int)
T_labels, N_labels = np.unique(np.array([T_labels,N_labels]),axis=1) # This removes any spikes that occur at the same neuron at the same time
N=max(N_labels)+1

print(f'Windowing... {len(T_labels)}')
windows = np.zeros((len(T_labels)),dtype='object')
for i,window_time in enumerate(T_labels):
    condition = (T_labels > window_time-window_dim) & (T_labels < window_time + window_dim)
    window = np.array([T_labels[condition]-window_time, N_labels[condition]]).T
    window =  {tuple(row) for row in  window}
    windows[i] = window


# Sweep the clustering cutoff upwards from 0 in steps of lr and remember the
# first cutoff that yields the longest repeated interval sequence
cutoff = 0
lr = 0.01
max_iter=50
iter_ = 0

opt_cutoff = cutoff
max_seq_rep = 0
sim_mats = _get_sim_mats(windows, T_labels, N_labels)

while iter_ <= max_iter: # this is just a for loop...
    clusters = _cluster_windows(cutoff, N_labels, sim_mats)
    cluster_sq, _sq_counts, sublist_keys_filt = _check_seq(clusters, T_labels, N_labels)

    if len(sublist_keys_filt) != 0:
        max_ = np.max([len(k) for k in sublist_keys_filt])
        if max_seq_rep < max_:
            max_seq_rep = max_
            opt_cutoff=cutoff

    cutoff += lr
    iter_ +=1


    print(f'iter - {iter_/max_iter} | cutoff - {cutoff} | opt_cutoff - {opt_cutoff} | most_detections - {max_seq_rep}',end='\r')

clusters = _cluster_windows(opt_cutoff, N_labels, sim_mats)
cluster_sq, sq_counts, sublist_keys_filt = _check_seq(clusters, T_labels, N_labels)


# Recover the occurrence times of every repeated sequence
sorted_indices = np.argsort(T_labels)
sorted_x = T_labels[sorted_indices]

all_times = []
all_labels = []
for key in sublist_keys_filt:
    # cluster_sq is keyed by str(list); going through list() keeps the lookup
    # valid when key arrives as a row of an object array
    matching_clusters = cluster_sq[str(list(key))]
    pattern_repetition_labels = np.zeros((len(matching_clusters),len(clusters)))
    for i,k in enumerate(matching_clusters):
        pattern_repetition_labels[i][clusters==k] = 1
        pattern_repetition_labels[i] *= np.cumsum(pattern_repetition_labels[i])
    pattern_repetition_labels = np.sum(pattern_repetition_labels,axis=0,dtype='int')
    all_labels.append(pattern_repetition_labels)

    sorted_y = pattern_repetition_labels[sorted_indices]
    pattern_times = np.array([sorted_x[sorted_y==i][0] for i in range(1,max(pattern_repetition_labels)+1)])
    all_times.append(pattern_times)

pattern_template = []
patterns = []
for i in range(len(all_times)):
    pattern = []
    pattern_template.append([])
    for time in all_times[i]:
        condition = (T_labels > time-window_dim*2) & (T_labels < time + window_dim*2)
        pattern = [tuple(k) for k in np.array([T_labels[condition]-time, N_labels[condition]]).T] # creating a list of tuples
        pattern_template[-1] += pattern # adds all points of each pattern to template_pattern
        patterns.append(pattern)

for i,pattern in enumerate(pattern_template):
    counts = [pattern.count(k) for k in pattern]
    pattern_template[i] = np.array(pattern)[np.where(counts == np.max(counts))[0]]
    pattern_template[i][:,0] -= min(pattern_template[i][:,0])
    pattern_template[i] = np.unique(pattern_template[i],axis=0)

# `return` is a SyntaxError outside a function, so the results are simply left
# in pattern_template, sublist_keys_filt and pattern_img
if len(pattern_template) == 0:
    pattern_img = None
else:
    win_size = (N,1+max([max(k[:,0]) for k in pattern_template]))
    pattern_img = np.zeros((len(pattern_template),*win_size))
    for p,pattern in enumerate(pattern_template):
        for (i,j) in pattern:
            pattern_img[p,j,i] = 1

def _get_sim_mats(windows, T_labels, N_labels):
intersect = lambda a,b : a[((a[:,None] == B).all(-1).any(1))]
sim_mats = np.zeros(np.max(N_labels),dtype='object')
for n in np.unique(N_labels):
    idc = np.where(N_labels==n)[0]
    windows_n = windows[idc]
    if len(windows_n) > 1:
        x = np.zeros((len(windows_n),len(windows_n)))
        for i in range(windows_n.shape[0]):
            for j in range(windows_n.shape[0]):
                common_rows = windows_n[i].intersection(windows_n[j])
                num_identical_rows = len(common_rows)
                x[i,j] = len(common_rows)/min(len(windows_n[i]),len(windows_n[j]))
        np.fill_diagonal(x,0)# make sure the diagonals are zero, this is important the more spikes there are...
        sim_mats[n] = x-1 
return sim_mats

def _cluster_windows(cutoff, N_labels, sim_mats):
clusters = np.zeros_like(N_labels)
for n in np.unique(N_labels):
    idc = np.where(N_labels==n)[0]
    if (type(sim_mats[n]) == np.ndarray) and (not np.all(sim_mats[n] == 0)):
        l = max(clusters)+1
        clusters[idc]= l+fcluster(linkage(sim_mats[n], method='complete'), cutoff, criterion='distance')
return clusters

def _check_seq(clusters, T_labels, N_labels):

time_differences = []
cluster_sq = {}
for cluster in np.unique(clusters):
    temp = list(np.diff(np.unique(T_labels[clusters == cluster])))
    str_temp = str(temp)
    time_differences.append(temp)
    if str_temp in cluster_sq.keys():
        cluster_sq[str_temp] = cluster_sq[str_temp] + [cluster]
    else:
        cluster_sq[str_temp] = [cluster]

# Convert the list of lists to a set of tuples to remove duplicates
unique_sublists_set = set(tuple(sublist) for sublist in time_differences if sublist)

# Convert the set of tuples back to a list of lists
unique_sublists = [list(sublist) for sublist in unique_sublists_set]

# Count the occurrences of each unique sublist in the original list
sublist_counts = Counter(tuple(sublist) for sublist in time_differences if sublist)

# Print the unique sublists and their respective counts
sq_counts = np.zeros(len(sublist_counts)) 
for i,sublist in enumerate(unique_sublists):
    count = sublist_counts[tuple(sublist)]
    sq_counts[i] = count
#     print(f"{sublist}: {count} occurrences")
sublist_keys_np = np.array([list(key) for key in sublist_counts.keys()],dtype='object')
sublist_keys_filt = sublist_keys_np[np.array(list(sublist_counts.values())) >1] # only bother clustering repetitions that appear for more than one neuron

return cluster_sq, sq_counts, sublist_keys_filt

IndentationError: expected an indented block after function definition on line 98 (3102548471.py, line 99)