In [2]:
import random
import numpy as np
import json
from tqdm import tqdm
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import linear_sum_assignment
from collections import Counter
import time
import os
import pandas as pd
from scipy.signal import correlate
np. set_printoptions(threshold=np. inf)

In [3]:
date = 'Sept27th-'
filename = ''.join((date,'scan_stats_SPADE.json'))
filename

'Sept27th-scan_stats_SPADE.json'

In [4]:
import scan
import simulate_data

In [5]:
default_params = {
    'M':8,
    'N':40,
    'D':51,
    'T':1000,
    'seed':0,
    'num_SM_events':16,
    'SM_total_spikes':10,
    'noise':100
}
scan_dict = {
    'M':[1,4,8,16,32],
    'N':[10,20,40,80,120],
    'D':[11,31,51,71,101],
    'num_SM_events':[2,4,8,16,32],
    'SM_total_spikes':[3,5,10,20,50],
    'noise':[0,50,100,500,1000]
}

In [6]:
def get_acc(matrix_x,matrix_y):
    # Calculate cross-correlation matrix
    cc = np.zeros((matrix_x.shape[2], matrix_y.shape[2]))

    for x_channel_idx in range(matrix_x.shape[2]):
        for y_channel_idx in range(matrix_y.shape[2]):
            cc[x_channel_idx, y_channel_idx], _ = max_overlap(matrix_x[...,x_channel_idx],matrix_y[...,y_channel_idx])
    SM_acc = np.max(cc,axis=1)
    return SM_acc, cc

In [7]:
def max_overlap(image, kernel):
    result = np.zeros((image.shape[1]+kernel.shape[1]-1))
    for n in range(image.shape[0]):
        result += correlate(image[n,:], kernel[n,:], mode = 'full')
    return np.max(result)/max(np.sum(image),np.sum(kernel)), np.argmax(result)

In [8]:
def get_imgs(K_dense, pattern_template):
    if len(pattern_template) == 0:
        print('FAIL')
        return pattern_template, None
    
    win_size = (K_dense.shape[0],1+max([max(k[:,0]) for k in pattern_template]))
    pattern_img = np.zeros((len(pattern_template),*win_size))
    for p,pattern in enumerate(pattern_template):
        for (i,j) in pattern:
            pattern_img[p,j,i] = 1
            
    pattern_img = pattern_img.transpose((1,2,0))
    
    return pattern_template, pattern_img

In [9]:
from scipy.signal import correlate
def get_acc(ground_truths,detected_patterns):
    # Calculate cross-correlation matrix
    cross_corr_matrix = np.zeros((ground_truths.shape[2], detected_patterns.shape[2]))
    SM_acc = np.zeros((ground_truths.shape[2]))
    
    if len(detected_patterns) == 0:
        return SM_acc, cross_corr_matrix
    
    for ground_truths_idx in range(ground_truths.shape[2]):
        for detected_patterns_idx in range(detected_patterns.shape[2]):
            cross_corr = np.zeros((ground_truths.shape[1]+detected_patterns.shape[1]-1))
            for n in range(ground_truths.shape[0]):
                cross_corr += correlate(ground_truths[n, :, ground_truths_idx], detected_patterns[n, :, detected_patterns_idx], mode='full')
            max_corr = np.max(cross_corr) / max(np.sum(ground_truths[...,ground_truths_idx]),np.sum(detected_patterns[...,detected_patterns_idx]))
            cross_corr_matrix[ground_truths_idx, detected_patterns_idx] = max_corr
#     print(cross_corr_matrix)
#     print( np.sum(ground_truths[...,ground_truths_idx]))
    SM_acc = np.max(cross_corr_matrix,axis=1)
    return SM_acc, cross_corr_matrix

In [10]:
def check_ground_truth(pattern_template, K_dense):
    _, pattern_img = get_imgs(K_dense, pattern_template)
    SM_acc, cc = get_acc(K_dense, pattern_img)
    return SM_acc, cc, pattern_img
    

In [11]:
param_combinations = []
for param_name, param_values in scan_dict.items():
    for param_value in param_values:
        params = default_params.copy()
        params[param_name] = param_value
        param_combinations.append(params)

In [12]:
len(param_combinations)

30

In [13]:
# Special SPADE packages...
import quantities as pq
import neo
import elephant
import viziphant
from neo.core import SpikeTrain

In [14]:
def fix_spade(patterns, win_size):
    spade_patterns = []
    for pattern in patterns:
        spade_patterns.append(np.array([np.sort(np.array(pattern['itemset'])) % win_size, np.sort(np.array(pattern['neurons']))]).T)
    return spade_patterns

In [None]:
import numpy as np
import random
import os
import pandas as pd
import json
from tqdm import tqdm

# Define default parameters and scan values
default_params = {
    'M': 8,
    'N': 40,
    'D': 51,
    'T': 1000,
    'seed': 0,
    'num_SM_events': 8,
    'SM_total_spikes': 10,
    'noise': 100
}

scan_dict = {
    'M': [1, 4, 8, 16, 32],
    'N': [10, 20, 40, 80, 120],
    'D': [11, 31, 51, 71, 101], 
    'num_SM_events': [2, 4, 8, 16, 32],
    'SM_total_spikes': [3, 5, 10, 20, 50],
    'noise': [0, 50, 100, 500, 1000]
}

# Generate all parameter combinations
param_combinations = []
for param_name, param_values in scan_dict.items():
    for param_value in param_values:
        params = default_params.copy()
        params[param_name] = param_value
        param_combinations.append(params)

num_samples = len(param_combinations)
results = []

# Iterate through parameter combinations
for idx, params in tqdm(enumerate(param_combinations)):

    if os.path.isfile(filename):
        df = pd.read_json(filename)
        with open(filename, 'r') as results_file:
            results = json.load(results_file)
            
    print(params)
    A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse = simulate_data.generate_synthetic_data(params)
    
    if len(A_sparse[0]) <= 2000: # and idx not in df['idx'].tolist():
        start = time.time()
        spike_trains = [SpikeTrain(A_sparse[1][A_sparse[0]==n] ,units= pq.ms, t_stop=params['T']) for n in range(params['N'])]
        patterns = elephant.spade.spade(spike_trains, bin_size=pq.ms, winlen=params['D'], output_format='patterns')['patterns']
        spade_patterns = fix_spade(patterns, params['D'])
        win_size = (K_dense.shape[0],1+max([max(k[:,0]) for k in spade_patterns]))
        spade_imgs = np.zeros((*win_size,len(spade_patterns)))
        for p, pattern in enumerate(spade_patterns):
            for (i,j) in pattern: 
                spade_imgs[j,i,p] = 1

        SM_acc, cc = get_acc(K_dense,spade_imgs)


        end = time.time()

        window_time = np.nan
        cluster_time = np.nan
        sequence_time = np.nan

        result = {
            'idx':idx,
            'M':params['M'],
            'N':params['N'],
            'D':params['D'],
            'T':params['T'],
            'num_SM_events':params['num_SM_events'],
            'SM_total_spikes':params['SM_total_spikes'],
            'noise':params['noise'],
            'window_time': window_time,
            'cluster_time': cluster_time,
            'sequence_time': sequence_time,
            'total_time': end-start,
            'total_spikes':len(A_sparse[1]),
            'total_patterns':len(spade_patterns),
            'SM_acc':SM_acc.tolist()
        }
    else:
        result = {
            'idx':idx,
            'M':params['M'],
            'N':params['N'],
            'D':params['D'],
            'T':params['T'],
            'num_SM_events':params['num_SM_events'],
            'SM_total_spikes':params['SM_total_spikes'],
            'noise':params['noise'],
            'window_time': np.nan,
            'cluster_time': np.nan,
            'sequence_time': np.nan,
            'total_time': np.nan,
            'total_spikes':len(A_sparse[1]),
            'total_patterns':np.nan,
            'SM_acc':[np.nan]
        }

    results.append(result)

    with open(filename, 'w') as results_file:
        json.dump(results, results_file, indent=4)


1it [00:00,  6.98it/s]

{'M': 1, 'N': 40, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
Time for data mining: 0.08275675773620605
{'M': 4, 'N': 40, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
Time for data mining: 0.2990283966064453


2it [00:02,  1.37s/it]

{'M': 8, 'N': 40, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
Time for data mining: 1.008591651916504


3it [00:29, 12.99s/it]

{'M': 16, 'N': 40, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
Time for data mining: 7.367804288864136


4it [08:00, 185.88s/it]

{'M': 32, 'N': 40, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
{'M': 8, 'N': 10, 'D': 51, 'T': 1000, 'seed': 0, 'num_SM_events': 8, 'SM_total_spikes': 10, 'noise': 100}
Time for data mining: 5.9182703495025635


In [None]:
df = pd.read_json(filename)

In [None]:
df

In [17]:
df.groupby(["M"])

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x0000021A3D62A7D0>

In [24]:
# Calculate the correlation matrix
correlation_matrix = df.corr()

# Display the correlation matrix
print(correlation_matrix)

                        M         N         D   T  num_SM_events  \
M                1.000000 -0.019511 -0.003807 NaN      -0.022546   
N               -0.019511  1.000000 -0.003488 NaN      -0.020657   
D               -0.003807 -0.003488  1.000000 NaN      -0.004031   
T                     NaN       NaN       NaN NaN            NaN   
num_SM_events   -0.022546 -0.020657 -0.004031 NaN       1.000000   
SM_total_spikes -0.024370 -0.022329 -0.004357 NaN      -0.025802   
noise           -0.031629 -0.028979 -0.005655 NaN      -0.033487   
window_time      0.331701 -0.062696  0.011927 NaN       0.448662   
cluster_time     0.404064  0.059379 -0.059149 NaN       0.476124   
sequence_time   -0.049144 -0.004694 -0.015947 NaN       0.010094   
total_time       0.118995 -0.006469  0.026695 NaN       0.410689   
total_patterns  -0.102824  0.030625 -0.003024 NaN      -0.030117   

                 SM_total_spikes     noise  window_time  cluster_time  \
M                      -0.024370 -0.031629

  correlation_matrix = df.corr()


In [None]:
plt.figure