In [None]:
# Automatically reload edited Python modules (e.g. the ecephys_analyses package)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2


In [None]:
import itertools
from pathlib import Path

import matplotlib.pyplot as plt

In [None]:
from ecephys_analyses.psth import make_psth_figures
from ecephys_analyses.data import paths

In [None]:
# Global matplotlib defaults applied to every figure generated below.
# A dpi of 200 gives sharp output but renders more slowly; lower it for faster iteration.
plt.rcParams['figure.figsize'] = [8, 6]
plt.rcParams['figure.dpi'] = 200

# Font-size tiers reused by the rc settings that follow.
SMALL_SIZE = 12
MEDIUM_SIZE = 15
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)   # axes title + x/y labels
plt.rc('xtick', labelsize=SMALL_SIZE)                          # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)                          # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure (suptitle) title

# All conditions: batch PSTH figure generation

In [None]:
OUTPUT_DIR = None  # Save in subject/condition/plots/sorting_condition

#### Data

# sorting_condition = 'ks2_5_catgt_df'
sorting_condition = 'ks2_5_catgt_df_postpro_1'

good_only_values = [
    False,
#     True
]

states = [
    "Wake", "N2", "REM"
#     None
]

regions = [
    'cortex',
    'thalamus',
    'all',
]

subject = 'Eugene'
condition_group = ''

conditions = paths.get_conditions(subject, condition_group)
print(f"N={len(conditions)} conditions: {conditions}")

#### PSTH

norm_window = [-10000, 2000]
# norm_window = [-4000, 2000]  # For Allan only
plot_window = [-500, 500]
binsize = 10

#### Plot

clim_values = [
    [-3, 3],
]
draw_region_limits = True

#### Run

n_jobs = 13

In [None]:
# Display the resolved conditions via the cell's rich output (last expression), not print().
conditions

In [None]:
# Example figure: preview one PSTH (first condition, Wake, all regions) inline,
# without saving, using the same bin size / color limits / region-limit setting
# as the batch run below so the preview matches what gets saved.
fig, ax = make_psth_figures(
    subject, conditions[0], sorting_condition,
    good_only=False,
    # NOTE(review): the batch cell does not pass `normalize`, so it relies on
    # make_psth_figures' default — confirm that default is 'baseline_zscore'.
    normalize='baseline_zscore',
    region='all',
    binsize=binsize,  # same bin size as the batch run (was left at the function default)
    norm_window=norm_window,
    plot_window=plot_window,
    state='Wake',
    clim=clim_values[0],
    save=False,
    show=True,
    output_dir=None,  # irrelevant here since save=False
    draw_region_limits=draw_region_limits,
)

In [None]:
def make_figures_parall(condition, state, region, good_only, clim):
    """Generate and save the PSTH figures for one parameter combination.

    Thin wrapper around ``make_psth_figures`` that fills in the fixed settings
    from the config cell (subject, sorting_condition, binsize, norm/plot windows,
    draw_region_limits, OUTPUT_DIR) with save=True/show=False, so the same
    function is used for both the serial loop and the joblib workers.

    Parameters
    ----------
    condition, state, region, good_only, clim
        One element each of ``conditions``, ``states``, ``regions``,
        ``good_only_values`` and ``clim_values``.
    """
    make_psth_figures(
        subject,
        condition,
        sorting_condition,
        state=state,
        region=region,
        good_only=good_only,
        binsize=binsize,
        norm_window=norm_window,
        plot_window=plot_window,
        clim=clim,
        save=True,
        show=False,
        draw_region_limits=draw_region_limits,
        output_dir=OUTPUT_DIR,
    )


# Every (condition, state, region, good_only, clim) combination to render,
# in the argument order expected by make_figures_parall.
param_grid = list(itertools.product(
    conditions, states, regions, good_only_values, clim_values
))

if n_jobs == 1:
    # NOTE(review): if make_psth_figures does not close its figures when show=False,
    # open figures accumulate in this process across iterations — confirm.
    for params in param_grid:
        make_figures_parall(*params)
else:
    # Imported only here so joblib is needed only for parallel runs.
    from joblib import Parallel, delayed

    Parallel(n_jobs=n_jobs, backend='multiprocessing')(
        delayed(make_figures_parall)(*params) for params in param_grid
    )
print('done')