In [None]:
%load_ext autoreload
%autoreload 2


In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import itertools

In [None]:
import ecephys.utils
from ecephys_analyses.psth import make_psth_heatmap, make_pooled_psth_hist
from ecephys_analyses.data import paths

In [None]:
# Global figure defaults: size in inches, and resolution
# (a dpi of 200 renders sharply but makes plotting slower).
plt.rcParams['figure.figsize'] = [8, 4]
plt.rcParams['figure.dpi'] = 200

# Font-size tiers used throughout the plots.
SMALL_SIZE = 12
MEDIUM_SIZE = 15
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)                               # default text size
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)  # axes title and x/y labels
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)                  # tick labels on both axes
plt.rc('legend', fontsize=SMALL_SIZE)                         # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                       # figure suptitle

# Pooled PSTH histograms for all subjects, conditions, states and regions

In [None]:
# Directory where all figures are saved. Set OUTPUT_DIR = None to instead save each plot in
# root_dir/subject/condition/plots/sorting_condition.
OUTPUT_DIR = './plots/psth_hist'


#### Data params

root_key = 'eStim'  # A key in `ecephys_analyses/data/conditions.yml` specifying the path where all subjects' data is saved

condition_group = 'eStim'  # e.g. 'eStim'. A key in `ecephys_analyses/data/conditions.yml` specifying the list of conditions to process

subjects = [
    # 'Santiago',
]  # Subjects to process. Must be non-empty for any figure to be produced.

sorting_condition = 'ks2_5_catgt_Th=12-10_lam=50_8s-batches_postpro_2'  # Name of the curated Kilosort output directory, shared by all subjects/conditions

states = [
    "N2", "Wake", "REM", "Sevo"
]  # States to compute. The `stim_times.csv` file should have a `stim_state` column

regions = [
    'cortex_base',
    'thalamus_base',
#     'all',
]  # Regions to compute. If not 'all', each entry should be a key in `ecephys_analyses/data/regions.yml`

selection_intervals = {
    'fr': (0.0, float('inf')),
}  # Cluster selection, presumably (min, max) bounds per metric — confirm in make_pooled_psth_hist.
   # Keys are columns from cluster_info.tsv or metrics.csv; add more entries as needed.

selected_groups = ['good', 'mua', 'noise_contam']  # Cluster groups to include (presumably cluster_info.tsv `group` labels — confirm)


#### PSTH params

normalize = 'baseline_norm'  # Normalization mode passed to make_pooled_psth_hist

# norm_window = [-10000, 2000]  # Window for computing and normalizing PSTH. [norm_window[0], -binsize] is used as baseline to normalize the PSTH
norm_window = [-4000, 2000]  # For Allan only, since the pulses are close together
plot_window = [-200, 800]  # Window for plotting the PSTH. (Can be shorter than window for normalization)
binsize = 5  # To bin spikes

#### Plot

# NOTE(review): the original comment called this the "color scale domain for z-scored data",
# which looks copied from the heatmap notebook; for this histogram it is presumably the
# y-axis range of the normalized PSTH — confirm in make_pooled_psth_hist.
ylim = [0, 3]
# ylim = None

#### Run

n_jobs = 10  # Number of parallel jobs. 1 runs everything serially in this kernel.


# Keyword arguments shared by every make_pooled_psth_hist call in this notebook.
kwargs = {
    'selection_intervals': selection_intervals,
    'selected_groups': selected_groups,
    'binsize': binsize,
    'normalize': normalize,
    'norm_window': norm_window,
    'plot_window': plot_window,
    'ylim': ylim,
    'save': True,
    'show': False,
    'output_dir': OUTPUT_DIR,
    'root_key': root_key,
}

# Every (subject, condition) pair to process: for each subject, the conditions listed
# under `condition_group` in `ecephys_analyses/data/conditions.yml`.
data = [
    (subject, condition)
    for subject in subjects
    for condition in paths.get_conditions(subject, condition_group)
]
print(f"N={len(data)} conditions: {data}")

In [None]:
# Example figure for the first (subject, condition) pair, region and state, to sanity-check
# the parameters above before launching the full batch below.
if not (data and regions and states):
    # Fail with an actionable message instead of a bare IndexError on `data[0]`
    # (`subjects` is empty until the config cell is filled in).
    raise ValueError(
        "Nothing to plot: `data`, `regions` and `states` must all be non-empty. "
        "Add at least one subject to `subjects` in the config cell."
    )
fig, ax = make_pooled_psth_hist(
    data[0][0], data[0][1], sorting_condition,
    region=regions[0],
    state=states[0],
    **kwargs,
)

In [None]:
def make_figures_parall(subject, condition, state, region):
    """Make and save the pooled PSTH histogram for one (subject, condition, state, region).

    Uses the module-level `sorting_condition` and `kwargs` from the config cell. All open
    figures are closed afterwards so that long batches don't accumulate figures in memory
    (assumes make_pooled_psth_hist leaves its figure open — TODO confirm).
    """
    make_pooled_psth_hist(
        subject, condition, sorting_condition,
        region=region,
        state=state,
        **kwargs,
    )
    plt.close('all')


# One job per (subject, condition) x state x region. Both execution modes below iterate
# over exactly the same combinations, in the same order.
jobs = [
    (subject, condition, state, region)
    for (subject, condition), state, region in itertools.product(data, states, regions)
]

if n_jobs == 1:
    for job in jobs:
        make_figures_parall(*job)
else:
    # Imported lazily so joblib is only required for parallel runs.
    from joblib import Parallel, delayed

    Parallel(n_jobs=n_jobs, backend='multiprocessing')(
        delayed(make_figures_parall)(*job) for job in jobs
    )
print('done')