In [1]:
import random
import numpy as np
import json
from tqdm import tqdm
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import linear_sum_assignment
from collections import Counter
import time
import os
import pandas as pd
from scipy.signal import correlate
np. set_printoptions(threshold=np. inf)

In [2]:
# Results file name: the run's date prefix followed by a fixed suffix.
date = 'Sept26th-'
filename = f'{date}scan_stats.json'
filename

'Sept26th-scan_stats.json'

In [3]:
import scan
import simulate_data

In [4]:
# Baseline configuration: the value every parameter takes while it is not the
# one being varied. The whole dict is handed to
# simulate_data.generate_synthetic_data, and 'D' is also used as the scan
# window size.
default_params = dict(
    M=8,
    N=40,
    D=51,
    T=1000,
    seed=0,
    num_SM_events=16,
    SM_total_spikes=10,
    noise=100,
)

# Values to try for each parameter. 'T' and 'seed' are listed with a single
# value each, so they contribute one (baseline) configuration apiece.
scan_dict = dict(
    M=[1, 4, 8, 16, 32],
    N=[10, 20, 40, 80, 120],
    D=[11, 31, 51, 71, 101],
    T=[1000],
    seed=[0],
    num_SM_events=[2, 4, 8, 16, 32],
    SM_total_spikes=[3, 5, 10, 20, 50],
    noise=[0, 50, 100, 500, 1000],
)

In [5]:
from scipy.signal import correlate
def get_acc(ground_truths, detected_patterns):
    """Score each ground-truth pattern by its best-matching detected pattern.

    Both inputs are indexed as ``[row, column, pattern]``. For every
    (ground truth, detected) pair, matching rows are cross-correlated along
    the column axis with ``mode='full'`` and the per-row results are summed.
    The peak of that sum is divided by the larger of the two patterns' total
    sums to give the pair's score.

    Parameters
    ----------
    ground_truths : array-like, 3-D
        Reference patterns; the last axis enumerates patterns.
    detected_patterns : array-like
        Detected patterns with the same first-axis length as
        ``ground_truths``. May be empty (no detections), including an empty
        list or an array whose pattern axis has length 0.

    Returns
    -------
    SM_acc : numpy.ndarray, shape (num_ground_truths,)
        Best score over all detected patterns for each ground truth; all
        zeros when nothing was detected.
    cross_corr_matrix : numpy.ndarray, shape (num_ground_truths, num_detected)
        Score of every (ground truth, detected) pair.
    """
    ground_truths = np.asarray(ground_truths)
    detected_patterns = np.asarray(detected_patterns)
    num_truths = ground_truths.shape[2]

    # "Nothing detected" can show up as an empty list, an empty first axis,
    # or a pattern axis of length 0. All of them must short-circuit here:
    # np.max over a zero-width matrix raises ValueError.
    if detected_patterns.size == 0:
        num_detected = detected_patterns.shape[2] if detected_patterns.ndim == 3 else 0
        return np.zeros(num_truths), np.zeros((num_truths, num_detected))

    num_detected = detected_patterns.shape[2]
    cross_corr_matrix = np.zeros((num_truths, num_detected))
    full_length = ground_truths.shape[1] + detected_patterns.shape[1] - 1

    # Per-pattern totals do not depend on the pairing, so compute them once.
    detected_totals = [np.sum(detected_patterns[..., d]) for d in range(num_detected)]

    for ground_truths_idx in range(num_truths):
        truth_total = np.sum(ground_truths[..., ground_truths_idx])
        for detected_patterns_idx in range(num_detected):
            cross_corr = np.zeros(full_length)
            for n in range(ground_truths.shape[0]):
                cross_corr += correlate(
                    ground_truths[n, :, ground_truths_idx],
                    detected_patterns[n, :, detected_patterns_idx],
                    mode='full',
                )
            norm = max(truth_total, detected_totals[detected_patterns_idx])
            # Two all-zero patterns would give 0/0 (NaN); score them as 0.
            score = np.max(cross_corr) / norm if norm != 0 else 0.0
            cross_corr_matrix[ground_truths_idx, detected_patterns_idx] = score

    SM_acc = np.max(cross_corr_matrix, axis=1)
    return SM_acc, cross_corr_matrix

In [6]:
# One-at-a-time scan: each configuration is the baseline with exactly one
# parameter replaced by one of its scan values.
param_combinations = []
for param_name in scan_dict:
    param_values = scan_dict[param_name]
    for param_value in param_values:
        params = {**default_params, param_name: param_value}
        param_combinations += [params]

In [8]:
len(param_combinations)

32

In [None]:
import numpy as np
import random
import os
import pandas as pd
import json
from tqdm import tqdm

# Baseline parameter values used for every run; the whole dict is passed to
# simulate_data.generate_synthetic_data.
default_params = dict(
    M=8,
    N=40,
    D=51,
    T=1000,
    seed=0,
    num_SM_events=16,
    SM_total_spikes=10,
    noise=100,
)

# Values to scan for each parameter. 'T' and 'seed' are not scanned here, so
# they always keep their baseline value.
scan_dict = dict(
    M=[1, 4, 8, 16, 32],
    N=[10, 20, 40, 80, 120],
    D=[11, 31, 51, 71, 101],
    num_SM_events=[2, 4, 8, 16, 32],
    SM_total_spikes=[3, 5, 10, 20, 50],
    noise=[0, 50, 100, 500, 1000],
)

# Build the scan list: for each scanned parameter, one copy of the baseline
# per candidate value, with only that parameter overridden.
param_combinations = []
for param_name in scan_dict:
    param_values = scan_dict[param_name]
    for param_value in param_values:
        params = dict(default_params, **{param_name: param_value})
        param_combinations.append(params)

num_samples = len(param_combinations)

# Resume from results already saved under `filename`, if any. Loading once is
# enough: the file is rewritten from `results` after every run, so re-reading
# it inside the loop would only return what was just written.
results = []
if os.path.isfile(filename):
    with open(filename, 'r') as results_file:
        results = json.load(results_file)

# Iterate through parameter combinations. tqdm wraps the list directly so it
# knows the total and can show a real progress bar and ETA.
for params in tqdm(param_combinations):
    # Time data generation and the scan together.
    start = time.time()
    A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse = simulate_data.generate_synthetic_data(params)
    pattern_template, sublist_keys_filt, window_time, cluster_time, sequence_time = scan.scan_raster(
        A_sparse[1], A_sparse[0], window_dim=params['D']
    )
    end = time.time()

    # Record the parameters that define this run ('seed' is not stored)
    # followed by the stage timings reported by scan_raster and the wall time.
    result = {
        key: params[key]
        for key in ('M', 'N', 'D', 'T', 'num_SM_events', 'SM_total_spikes', 'noise')
    }
    result.update({
        'window_time': window_time,
        'cluster_time': cluster_time,
        'sequence_time': sequence_time,
        'total_time': end - start,
    })
    results.append(result)

    # Checkpoint after every run so an interrupted scan keeps what finished.
    with open(filename, 'w') as results_file:
        json.dump(results, results_file, indent=4)


0it [00:00, ?it/s]

260 Windows
progress - 96.0% | cutoff - 9.48 | opt_cutoff - 1.37 | most_detections - 15etections - 15

1it [00:00,  1.33it/s]

2 patterns found...s... 50% 10.01 | opt_cutoff - 1.37 | most_detections - 15
734 Windows
6 patterns found...s... 83% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 15tections - 155


2it [00:02,  1.48s/it]

1359 Windows
8 patterns found...s... 88% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 15tections - 155


3it [00:13,  5.56s/it]

2571 Windows
39 patterns found...... 97% 10.01 | opt_cutoff - 0.34 | most_detections - 3etections - 33


4it [00:31, 10.47s/it]

4903 Windows
73 patterns found...... 99% 10.01 | opt_cutoff - 0.42 | most_detections - 2etections - 22


5it [02:02, 39.56s/it]

1298 Windows
14 patterns found...... 93% 10.01 | opt_cutoff - 0.83 | most_detections - 4etections - 4


6it [02:13, 29.84s/it]

1335 Windows
2 patterns found...s... 50% 10.01 | opt_cutoff - 0.67 | most_detections - 15tections - 15


7it [02:22, 22.97s/it]

1361 Windows
10 patterns found...... 90% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 15tections - 155


8it [02:38, 20.97s/it]

1367 Windows
22 patterns found...... 95% 10.01 | opt_cutoff - 0.34 | most_detections - 15tections - 155


9it [02:50, 18.23s/it]

1369 Windows
17 patterns found...... 94% 10.01 | opt_cutoff - 0.34 | most_detections - 15tections - 155detections - 2


10it [03:04, 16.73s/it]

1363 Windows
23 patterns found...... 96% 10.01 | opt_cutoff - 1.09 | most_detections - 15tections - 15_detections - 2


11it [03:09, 13.14s/it]

1360 Windows
21 patterns found...... 95% 10.01 | opt_cutoff - 0.61 | most_detections - 15tections - 15_detections - 8


12it [03:16, 11.23s/it]

1354 Windows
12 patterns found...... 92% 10.01 | opt_cutoff - 0.51 | most_detections - 15tections - 15


13it [03:25, 10.60s/it]

1364 Windows
13 patterns found...... 92% 10.01 | opt_cutoff - 0.38 | most_detections - 15tections - 155


14it [03:38, 11.27s/it]

1363 Windows
9 patterns found...s... 89% 10.01 | opt_cutoff - 0.34 | most_detections - 14tections - 144


15it [03:57, 13.79s/it]

1351 Windows
15 patterns found...... 93% 10.01 | opt_cutoff - 0.38 | most_detections - 15tections - 155


16it [04:07, 12.43s/it]

1361 Windows
13 patterns found...... 92% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 15tections - 155


17it [04:26, 14.53s/it]

260 Windows
progress - 94.0% | cutoff - 8.79 | opt_cutoff - 0.09 | most_detections - 1detections - 11

18it [04:28, 10.71s/it]

1 patterns found...s... 0%- 10.01 | opt_cutoff - 0.09 | most_detections - 1
412 Windows
11 patterns found...... 91% 10.01 | opt_cutoff - 0.51 | most_detections - 5etections - 54


19it [04:30,  8.17s/it]

731 Windows
19 patterns found...... 95% 10.01 | opt_cutoff - 0.3 | most_detections - 7etections - 77


20it [04:34,  6.90s/it]

1354 Windows
5 patterns found...s... 80% 10.01 | opt_cutoff - 0.67 | most_detections - 16tections - 165detections - 15


21it [04:49,  9.44s/it]

2578 Windows
23 patterns found...... 96% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 31tections - 311


22it [05:24, 17.04s/it]

480 Windows
progress - 82.0% | cutoff - 5.14 | opt_cutoff - 0.14 | most_detections - 1detections - 11

23it [05:26, 12.46s/it]

1 patterns found...s... 0%- 10.01 | opt_cutoff - 0.14 | most_detections - 1etections - 1
736 Windows
3 patterns found...s... 67% 10.01 | opt_cutoff - 0.47000000000000003 | most_detections - 15tections - 155


24it [05:29,  9.53s/it]

1360 Windows
8 patterns found...s... 88% 10.01 | opt_cutoff - 0.42 | most_detections - 15tections - 155


25it [05:36,  8.79s/it]

2567 Windows
18 patterns found...... 94% 10.01 | opt_cutoff - 0.56 | most_detections - 15tections - 153detections - 13


26it [06:08, 15.92s/it]

6008 Windows
142 patterns found..... 99% 10.01 | opt_cutoff - 0.67 | most_detections - 6etections - 6t_detections - 1


27it [10:24, 88.03s/it]

40000 Windows
Windowing... 27%

In [9]:
param_combinations

[{'M': 1,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 4,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 8,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 16,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 32,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 8,
  'N': 10,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 8,
  'N': 20,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 8,
  'N': 40,
  'D': 51,
  'T': 1000,
  'seed': 0,
  'num_SM_events': 16,
  'SM_total_spikes': 10,
  'noise': 100},
 {'M': 8,
  'N

In [47]:
num_samples = 3

# Full Cartesian grid over the scan values: one row per combination, one
# column per key of scan_dict (in key order).
param_combinations = np.array(np.meshgrid(*scan_dict.values())).T.reshape(-1, len(scan_dict))
num_iterations = len(param_combinations)

# Generate random indices for sampling; never ask for more rows than exist,
# because random.sample raises ValueError in that case.
random_indices = random.sample(range(num_iterations), min(num_samples, num_iterations))

# Resume from results already saved under `filename`, if any. The file is
# rewritten from `results` after every run, so one load up front is enough.
results = []
if os.path.isfile(filename):
    with open(filename, 'r') as results_file:
        results = json.load(results_file)

# Iterate through the sampled parameter combinations
for idx in tqdm(random_indices):
    sampled = {key: int(val) for key, val in zip(scan_dict.keys(), param_combinations[idx])}
    # Start from the baseline so parameters that scan_dict does not list
    # (e.g. 'T' or 'seed' in the reduced scan_dict) are still supplied.
    params = {**default_params, **sampled}

    A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse = simulate_data.generate_synthetic_data(params)
    pattern_template, sublist_keys_filt, window_time, cluster_time, sequence_time = scan.scan_raster(
        A_sparse[1], A_sparse[0], window_dim=params['D']
    )

    print(params)

    # NOTE(review): only the grid index is stored, so a record can be mapped
    # back to its parameters only with the same scan_dict (same keys, values
    # and order) that built the grid above.
    result = {
        'idx': idx,
        'window_time': window_time,
        'cluster_time': cluster_time,
        'sequence_time': sequence_time
    }

    results.append(result)

    # Checkpoint after every run so an interrupted scan keeps what finished.
    with open(filename, 'w') as results_file:
        json.dump(results, results_file, indent=4)
    
    
    

  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

1186 Windows
4 patterns found...s... 75% 10.01 | opt_cutoff - 0.67 | most_detections - 3etections - 3


 33%|████████████████████████████                                                        | 1/3 [00:02<00:04,  2.15s/it]

364 Windows
progress - 98.0% | cutoff - 10.01 | opt_cutoff - 0.13 | most_detections - 1etections - 11

 67%|████████████████████████████████████████████████████████                            | 2/3 [00:02<00:01,  1.28s/it]

1 patterns found...s... 0%
120000 Windows
Windowing... 2%

 67%|████████████████████████████████████████████████████████                            | 2/3 [00:11<00:05,  5.97s/it]

Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%Windowing... 2%




KeyboardInterrupt: 

In [46]:
params['D']

11

In [33]:
params

{'M': 16,
 'N': 10,
 'D': 11,
 'T': 1000,
 'seed': 0,
 'num_SM_events': 8,
 'SM_total_spikes': 5,
 'noise': 1000}

In [48]:
df = pd.read_json(filename)

In [49]:
df.head()

Unnamed: 0,idx,test,window_time,cluster_time,sequence_time
0,6090,test,,,
1,12262,test,,,
2,6635,test,,,
3,15176,,0.860845,1.173667,0.001995
4,4466,,0.059833,0.575803,0.0


In [51]:
# One large configuration, timed end to end (data generation plus scan).
N, M, D, T = 1000, 10, 71, 1000
seed = 0

num_SM_events, SM_total_spikes, noise = 10, 150, 1000

params = dict(
    N=N,
    M=M,
    D=D,
    T=T,
    seed=seed,
    num_SM_events=num_SM_events,
    SM_total_spikes=SM_total_spikes,
    noise=noise,
)

start = time.time()
A_dense, A_sparse, B_dense, B_sparse, K_dense, K_sparse = simulate_data.generate_synthetic_data(params)
pattern_template, sublist_keys_filt, window_time, cluster_time, sequence_time = scan.scan_raster(
    A_sparse[1],
    A_sparse[0],
    window_dim=params['D'],
)
end = time.time()
# Elapsed wall-clock seconds for generation and scan combined.
print(end - start)

15867 Windows
38 patterns found...... 97% 10.01 | opt_cutoff - 0.38 | most_detections - 19tections - 199detections - 7


KeyboardInterrupt: 