In [None]:
from src.tabs.temp_matching.utils import *
import pandas as pd
import itertools
from tqdm import tqdm
import warnings
from joblib import Parallel, delayed
from sklearn.metrics import f1_score

In [2]:
# Load the dataset once at notebook scope. score_session reads `data_struct`
# as a global, indexing data_struct['trials'] and data_struct['trial_onset']
# by session index.
data_struct = load_data()

In [3]:
def score_session(session_idx, similarity_metric, time_bound, multi_level,stim_temp_trial_type,nostim_temp_trial_type,bin_size,feature_type):
    # input processing
    SESSION_IDX = session_idx
    SIMILIARITY_METRIC = modified_cosine if similarity_metric == 'cosine' else euclidean
    TIME_BOUND = (time_bound[0]/1000,time_bound[1]/1000)
    TWO_TEMP_MODE = not multi_level
    BIN_SIZE = bin_size/1000
    FEATURE_TYPE = FeatureType[feature_type]

    region_filt = data_struct['trials'][SESSION_IDX]['brain_region'] == 'wS1'
    if np.sum(region_filt) < 10:
        return None

    trial_onsets = data_struct['trial_onset'][SESSION_IDX]
    neural_activity = np.array(data_struct['trials'][SESSION_IDX]['spikes'],dtype=object)[region_filt]
    stim_amps = data_struct['trials'][SESSION_IDX]['stim_amp']
    real_stims_binary = stim_amps.astype(bool)
    trial_results = np.array(data_struct['trials'][session_idx]['result'])

    template_stim_amps = [0,4] if TWO_TEMP_MODE else range(5)
    trial_filt = np.vstack([data_struct['trials'][SESSION_IDX]['stim_amp'] == i for i in template_stim_amps])
    temps_count = trial_filt.shape[0]
    stim_templates = []

    for temp_idx in range(temps_count):
        if temp_idx == 0:
            match nostim_temp_trial_type:
                case 'All':
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:]]
                case 'CR':
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:] & (trial_results=='CR')]
                case 'FA': 
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:] & (trial_results=='FA')]
        else:
            match stim_temp_trial_type:
                case 'All':
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:]]
                case 'Hit':
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:] & (trial_results=='hit')] 
                case 'Miss':
                    filtered_onsets = trial_onsets[trial_filt[temp_idx,:] & (trial_results=='miss')] 
        stim_templates.append(get_template(neural_activity,filtered_onsets,BIN_SIZE,TIME_BOUND,FEATURE_TYPE))
    stim_templates = np.array(stim_templates)

    temp_matcher = TemplateMatching(templates=stim_templates,similarity_metric=SIMILIARITY_METRIC)
    
    trial_templates = []
    trials_count = len(trial_onsets)
    for trial_idx in range(trials_count):
        trial_templates.append(get_template(neural_activity,[trial_onsets[trial_idx]],BIN_SIZE,TIME_BOUND,FEATURE_TYPE))
    trial_templates = np.array(trial_templates)

    soft_decode_result = np.zeros(trials_count)
    for trial_idx in range(trials_count):
        sample_distances = temp_matcher.decode_soft(trial_templates[trial_idx,:])
        soft_decode_result[trial_idx] = confidence_calc_from_distance(sample_distances,1)
    CONFIDENCE_THRESHOLD = find_optimal_threshold(real_stims_binary,soft_decode_result)
    hard_decode_result = soft_decode_result >= CONFIDENCE_THRESHOLD

    return f1_score(real_stims_binary,hard_decode_result)

In [4]:
def _safe_score(score_func, keys, combo):
    params = dict(zip(keys, combo))
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("error")  # Turn warnings into exceptions
            score = score_func(**params)
    except Exception:
        score = None
    return {**params, 'score': score}

def evaluate_grid(score_func,param_grid, n_jobs=-1,parallel = True):
    """Score every combination in the Cartesian product of ``param_grid``.

    Parameters
    ----------
    score_func : callable
        Scoring function; it receives one keyword argument per grid key.
    param_grid : dict
        Maps each parameter name to an iterable of candidate values.
    n_jobs : int
        Forwarded to ``joblib.Parallel``; only used when ``parallel`` is true.
    parallel : bool
        True dispatches the evaluations through joblib; False runs them one
        after another in the current process.

    Returns
    -------
    pandas.DataFrame
        One row per combination, built from the dicts returned by
        ``_safe_score``.
    """
    param_names = list(param_grid.keys())
    candidate_values = [param_grid[name] for name in param_names]
    all_combos = list(itertools.product(*candidate_values))

    if not parallel:
        rows = [
            _safe_score(score_func, param_names, one_combo)
            for one_combo in tqdm(all_combos)
        ]
        return pd.DataFrame(rows)

    runner = Parallel(n_jobs=n_jobs)
    rows = runner(
        delayed(_safe_score)(score_func, param_names, one_combo)
        for one_combo in tqdm(all_combos)
    )
    return pd.DataFrame(rows)


# Full search space over every score_session argument. The param_grid_N
# dictionaries in the following cells are smaller slices of this grid.
full_param_grid = {
    'session_idx': list(range(29)),
    'similarity_metric': ['cosine', 'euclidean'],
    # time_bound and bin_size are divided by 1000 inside score_session.
    'time_bound': [(0,50), (0,100), (0,200), (0,500)],
    'multi_level': [True, False],
    'stim_temp_trial_type': ['All', 'Hit', 'Miss'],
    'nostim_temp_trial_type': ['All', 'CR', 'FA'],
    'bin_size': [5, 10, 50, 100],
    'feature_type': [ft.name for ft in FeatureType]
}

In [5]:
# Round 1: one session, templates built from all trials ('All' / 'All').
# Sweeps similarity metric, time window, multi-level mode, bin size and
# feature type.
param_grid_1 = {
    'session_idx': list([0]),
    'similarity_metric': ['cosine', 'euclidean'],
    'time_bound': [(0,50), (0,100), (0,200), (0,500)],
    'multi_level': [True, False],
    'stim_temp_trial_type': ['All'],
    'nostim_temp_trial_type': ['All'],
    'bin_size': [5, 10, 50, 100],
    'feature_type': [ft.name for ft in FeatureType]
}

# Evaluate sequentially (parallel=False) and display one row per combination.
df1 = evaluate_grid(score_session,param_grid_1, parallel=False)
df1

  0%|          | 0/256 [00:00<?, ?it/s]

100%|██████████| 256/256 [08:28<00:00,  1.99s/it]


Unnamed: 0,session_idx,similarity_metric,time_bound,multi_level,stim_temp_trial_type,nostim_temp_trial_type,bin_size,feature_type,score
0,0,cosine,"(0, 50)",True,All,All,5,PEAK,0.688406
1,0,cosine,"(0, 50)",True,All,All,5,FULL,0.809524
2,0,cosine,"(0, 50)",True,All,All,5,FULL_NORMALIZED,0.830065
3,0,cosine,"(0, 50)",True,All,All,5,SUM,0.744526
4,0,cosine,"(0, 50)",True,All,All,10,PEAK,0.724638
...,...,...,...,...,...,...,...,...,...
251,0,euclidean,"(0, 500)",False,All,All,50,SUM,0.671264
252,0,euclidean,"(0, 500)",False,All,All,100,PEAK,0.690763
253,0,euclidean,"(0, 500)",False,All,All,100,FULL,0.671264
254,0,euclidean,"(0, 500)",False,All,All,100,FULL_NORMALIZED,0.737179


In [6]:
# Round 2: selective trial types. All other parameters are held fixed while
# the trial outcomes used to build the stim and no-stim templates are varied.
param_grid_2 = {
    'session_idx': [0],
    'similarity_metric': ['euclidean'],
    'time_bound': [(0,500)],
    'multi_level': [True],
    'stim_temp_trial_type': ['All', 'Hit', 'Miss'],
    'nostim_temp_trial_type': ['All', 'CR', 'FA'],
    'bin_size': [5],
    'feature_type': ['FULL_NORMALIZED']
}

# Evaluate sequentially (parallel=False) and display one row per combination.
df2 = evaluate_grid(score_session,param_grid_2, parallel=False)
df2

100%|██████████| 9/9 [00:25<00:00,  2.81s/it]


Unnamed: 0,session_idx,similarity_metric,time_bound,multi_level,stim_temp_trial_type,nostim_temp_trial_type,bin_size,feature_type,score
0,0,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.993151
1,0,euclidean,"(0, 500)",True,All,CR,5,FULL_NORMALIZED,0.972603
2,0,euclidean,"(0, 500)",True,All,FA,5,FULL_NORMALIZED,0.961938
3,0,euclidean,"(0, 500)",True,Hit,All,5,FULL_NORMALIZED,0.92459
4,0,euclidean,"(0, 500)",True,Hit,CR,5,FULL_NORMALIZED,0.865672
5,0,euclidean,"(0, 500)",True,Hit,FA,5,FULL_NORMALIZED,0.812721
6,0,euclidean,"(0, 500)",True,Miss,All,5,FULL_NORMALIZED,0.708738
7,0,euclidean,"(0, 500)",True,Miss,CR,5,FULL_NORMALIZED,0.701923
8,0,euclidean,"(0, 500)",True,Miss,FA,5,FULL_NORMALIZED,0.712195


In [7]:
# Round 3: go across sessions. A single fixed configuration is evaluated on
# session indices 0..28.
param_grid_3 = {
    'session_idx': list(range(29)),
    'similarity_metric': ['euclidean'],
    'time_bound': [(0,500)],
    'multi_level': [True],
    'stim_temp_trial_type': ['All'],
    'nostim_temp_trial_type': ['All'],
    'bin_size': [5],
    'feature_type': ['FULL_NORMALIZED']
}

# Evaluate sequentially (parallel=False) and display one row per session.
df3 = evaluate_grid(score_session,param_grid_3, parallel=False)
df3

100%|██████████| 29/29 [05:06<00:00, 10.58s/it]


Unnamed: 0,session_idx,similarity_metric,time_bound,multi_level,stim_temp_trial_type,nostim_temp_trial_type,bin_size,feature_type,score
0,0,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.993151
1,1,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,1.0
2,2,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.988662
3,3,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.997669
4,4,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.99061
5,5,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,
6,6,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,
7,7,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.979644
8,8,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,
9,9,euclidean,"(0, 500)",True,All,All,5,FULL_NORMALIZED,0.993197
