In [16]:
import random
import numpy as np
import json
from tqdm import tqdm
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import linear_sum_assignment
from collections import Counter
import time
import os
import pandas as pd
from scipy.signal import correlate
import matplotlib.pyplot as plt

In [36]:
# --- Setup -------------------------------------------------------------------
# Plotting / smoothing helpers, print numpy arrays in full (no summarisation),
# and a flag that presumably gates figure display in later cells (not used in
# the code visible here).
import sys
import colorsys

import matplotlib.pyplot as plt
from scipy.signal import savgol_filter

disp_figs = True
np.set_printoptions(threshold=sys.maxsize)

def create_color_spectrum(num_labels):
    """Return `num_labels` hex colour strings spread around the hue wheel.

    Hues are produced by stepping the golden-ratio conjugate around [0, 1), so
    consecutive labels land far apart in hue. Saturation is fixed at 0.8 and
    lightness at 0.6. Each colour is converted with `colorsys.hls_to_rgb` and
    every channel is truncated (not rounded) to an 8-bit integer.

    Parameters
    ----------
    num_labels : int
        Number of colours to generate.

    Returns
    -------
    list of str
        One '#rrggbb' string per label, in label order.
    """
    phi_conj = 0.618033988749895
    hue_values = (np.arange(num_labels) * phi_conj) % 1.0
    sat_values = np.full(num_labels, 0.8)
    light_values = np.full(num_labels, 0.6)

    def _to_hex(hue, light, sat):
        # colorsys works in HLS order: (hue, lightness, saturation).
        channels = colorsys.hls_to_rgb(hue, light, sat)
        return '#' + ''.join(format(int(255 * c), '02x') for c in channels)

    return [
        _to_hex(hue, light, sat)
        for hue, sat, light in zip(hue_values, sat_values, light_values)
    ]

# Base palette: black first, then matplotlib's ten default "tab10" cycle colours.
existing_colors = np.array([
    '#000000',
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
])

# Extend with 90 generated colours, giving 11 + 90 = 101 entries in total.
num_new_colors = 90
new_colors = create_color_spectrum(num_new_colors)
palette = np.hstack((existing_colors, new_colors))

In [37]:
# Simulation parameters.
N = 100                # size of the first axis of K and A
M = 4                  # number of kernels (last axis of K and A, rows of B)
D = 71                 # kernel length along the time axis
T = 1000               # length of A's time axis
seed = 0               # RNG seed recorded in `params`

num_SM_events = 5      # number of onsets selected per kernel row of B
SM_total_spikes = 10   # number of ones per kernel in K
noise = 50             # number of background ones placed in A

# Collect every parameter into one dict, keyed by its variable name.
params = dict(
    N=N,
    M=M,
    D=D,
    T=T,
    seed=seed,
    num_SM_events=num_SM_events,
    SM_total_spikes=SM_total_spikes,
    noise=noise,
)

In [40]:
# Synthetic data generation.
#   K_dense (N, D, M)  : M binary kernels, each with SM_total_spikes ones in an N x D window.
#   B_dense (M, T - D) : binary onset matrix, num_SM_events ones per row.
#   A_dense (N, T, M)  : channel m holds kernel m stamped at every onset in B_dense[m];
#                        `noise` random background ones are also placed in the last channel.
# The *_sparse variables are the np.where() coordinates of the nonzero entries.

# Seed a dedicated generator so Restart & Run All reproduces the same dataset
# (`seed` was previously recorded in `params` but never used).
rng = np.random.default_rng(seed)

# Guard the top-k selections below: a k larger than the array would silently
# select the wrong number of entries.
assert 0 <= SM_total_spikes <= N * D, "SM_total_spikes must fit in an N x D kernel"
assert 0 <= num_SM_events <= T - D, "num_SM_events must fit in the T - D onset window"
assert 0 <= noise <= N * T, "noise must fit in an N x T channel"

K_rand = rng.random((N, D, M))
K_dense = np.zeros_like(K_rand)

B_rand = rng.random((M, T - D))
B_dense = np.zeros_like(B_rand)

# A_rand must be random: it used to be np.zeros, and ranking an all-zero array
# puts the background "noise" at positions fixed by argsort's tie-breaking.
A_rand = rng.random((N, T, M))
A_dense = np.zeros_like(A_rand)
sorted_indices_A = np.argsort(A_rand[:, :, -1], axis=None)
# Slice with `size - k` rather than `-k`: `x[-0:]` would select everything when k == 0.
top_indices_A = np.unravel_index(sorted_indices_A[sorted_indices_A.size - noise:], (N, T))
A_dense[top_indices_A[0], top_indices_A[1], -1] = 1

for m in range(M):
    # Kernel m: ones at the SM_total_spikes largest entries of its random N x D slice.
    sorted_indices_K = np.argsort(K_rand[:, :, m], axis=None)
    top_indices_K = np.unravel_index(
        sorted_indices_K[sorted_indices_K.size - SM_total_spikes:], (N, D))
    K_dense[top_indices_K[0], top_indices_K[1], m] = 1

    # Onsets of kernel m: the num_SM_events largest entries of B_rand[m]. (This
    # previously ranked K_rand[m, :], which confined onsets to the first D*M frames.)
    sorted_indices_B = np.argsort(B_rand[m, :])
    top_indices_B = np.unravel_index(
        sorted_indices_B[sorted_indices_B.size - num_SM_events:], (T - D,))
    B_dense[m, top_indices_B[0]] = 1

K_sparse = np.where(K_dense)
B_sparse = np.where(B_dense)

# Stamp kernel b into channel b at each of its onsets t. The kernel index was
# previously b - 1, which paired each onset row with the wrong kernel (row 0
# with the last one). Onsets satisfy t < T - D, so the window t:t+D always fits.
# Overlapping stamps (or a stamp on a noise entry in the last channel) add up,
# so entries of A_dense can exceed 1.
for b, t in zip(*B_sparse):
    A_dense[:, t:t + D, b] += K_dense[..., b]

A_sparse = np.where(A_dense)
    
    