In [None]:
from pathlib import Path
from datetime import datetime
from joblib import Parallel, delayed
import spikeinterface.extractors as se
import spikeinterface.sorters as ss
import itertools

In [None]:
from ecephys_analyses.data import channel_groups, paths, parameters

In [None]:
from ecephys_analyses.sorting import run_sorting

In [None]:
# Sanity check: display which sorters spikeinterface reports as installed in this environment.
ss.installed_sorters()

In [None]:

# Sorter configurations to run; each name is handed to `run_sorting` as
# `sorting_condition` (presumably resolved to a parameter set inside
# ecephys_analyses -- verify there). Uncomment an entry to include it.
sorting_conditions = [
    'ks2_5_raw_df',
    # 'ks2_5_raw_minFR=0',
    # 'ks2_5_raw_rigid',
    # 'ks2_5_raw_8s-batches',
    'ks2_5_raw_16s-batches_minFR=0',
    'ks2_5_raw_rigid_minFR=0',
]

# True: sort the CatGT-preprocessed data (forwarded as `catgt_data`).
use_catgt_data = False

# ---- Manual selection: (subject, condition) pairs to sort ----
data_conditions = [
    ('Doppio', 'drift_test_01'),
    ('Doppio', 'drift_test_02'),
    ('Valentino', 'drift_test_01'),
    ('Valentino', 'drift_test_02'),
    ('Allan', 'drift_test_01_imec0'),
    ('Allan', 'drift_test_01_imec1'),
]
# ---------------------------------------------------------------

# Options forwarded unchanged to every `run_sorting` call.
bad_channels = None     # TODO: not yet specified
rerun_existing = False  # False: skip outputs whose dir already holds 'spike_times.npy'
dry_run = False         # True: create the output dirs but do not run the sorter
clean_dat_file = True   # True: delete spikeinterface's `recording.dat` once sorting succeeds

# 1 runs every job serially in this kernel; any other value goes to joblib.Parallel.
n_jobs = 1


In [None]:
# Run `run_sorting` for every (subject, condition) x sorting_condition combination.
# The options and the job list are built once and shared by both execution paths,
# so the serial and parallel branches cannot drift apart when an option changes.
sorting_kwargs = dict(
    catgt_data=use_catgt_data,
    bad_channels=bad_channels,
    rerun_existing=rerun_existing,
    dry_run=dry_run,
    clean_dat_file=clean_dat_file,
)
# Same ordering as itertools.product: all sorting conditions for the first
# data condition, then all for the second, and so on.
sorting_jobs = [
    (subject, condition, sorting_condition)
    for (subject, condition), sorting_condition
    in itertools.product(data_conditions, sorting_conditions)
]

if n_jobs == 1:
    # Serial path: run each job directly in this kernel.
    for subject, condition, sorting_condition in sorting_jobs:
        run_sorting(subject, condition, sorting_condition, **sorting_kwargs)
else:
    # Parallel path: joblib's 'multiprocessing' backend distributes the jobs over
    # a pool of worker processes sized by `n_jobs` (joblib convention, e.g. -1 = all CPUs).
    # NOTE: `parallel` holds the list of `run_sorting` return values, not a
    # Parallel object; the name is kept in case later cells refer to it.
    parallel = Parallel(
        n_jobs=n_jobs,
        backend='multiprocessing',
    )(
        delayed(run_sorting)(subject, condition, sorting_condition, **sorting_kwargs)
        for subject, condition, sorting_condition in sorting_jobs
    )