# Run pema

Authors:
 - Angevaare, Joran <j.angevaare@nikhef.nl> (based on the Peak_Classification_Tester) 


## This notebook ##

**Goal**
 - Investigate how clustering and classification may be improved using Kr simulations.

**Known issues**
 -


**Start up ``strax`` + load tools**

In [None]:
import pema
import os

In [None]:
# Make pema's bundled init script available in the working directory by
# symlinking it, unless an 'init.py' is already present.
if not os.path.exists('init.py'):
    # The script is looked up in the 'bin' folder one level above the pema package.
    init = os.path.join(pema.__path__[0], '..', 'bin', "pema_init.py")
    !ln -s $init init.py
# NOTE(review): later cells use names such as ``tqdm``, ``np``, ``plt``, ``pd``
# and ``time`` that are not imported in this notebook; presumably init.py (or
# the setup script) provides them - confirm.
%run init.py

### Initialize the waveform simulator with your instructions
Initialize either of
 - [Low E](setup_lowe.py)
 - [High E](setup_highe.py)
 - [Kr](setup_kr.py)

In [None]:
%run setup_lowe.py

## write instructions to CSV

In [None]:
def get_context(skip=None):
    """
    Create the pema context used throughout this notebook.

    The context is built from the notebook-level ``base_dir``,
    ``config_update``, ``raw_data_dir`` and ``data_dir``.

    :param skip: optional container of option names that must not be
        applied from the (currently empty) set of extra overrides
    :return: the configured context
    """
    context = pema.pema_context(
        base_dir=base_dir,
        config_update=config_update,
        raw_dir=raw_data_dir,
        data_dir=data_dir,
    )
    # Extra per-option overrides; none are defined at the moment, so the
    # loop below is a no-op until entries are added here.
    overrides = {}
    for option, value in overrides.items():
        if skip is None or option not in skip:
            context.set_config({option: value})
            continue
        print(f'skip {option}')
    context.set_context_config(dict(
        allow_shm=True,
        allow_lazy=False,
        timeout=300,
        max_messages=10,
    ))
    return context


st = get_context()

In [None]:
# Submit one batch job per run in ``run_list`` and remember the job objects
# so that their logs and storage status can be inspected afterwards.
job_registry=[]
for r in run_list:
    print(r)
    # Every job targets 'records' and 'peaklets' with an empty config override.
    job = pema.ProcessRun(st, 
                          run_id=r, 
                          target=('records', 'peaklets'), 
                          config={})
    job_registry.append(job)

    # make_cmd yields the command and a job name; both are handed to
    # exec_dali together with ``environ_init`` and the resource requests.
    cmd, job_name = job.make_cmd()
    job.exec_dali(cmd, job_name, environ_init, mem=30_000, max_hours='08:00:00')


In [None]:
# When working locally
# Print the last ten lines of the log file of every registered job.
for j in job_registry:
    log=j.log_file
    !tail -10 $log

In [None]:
script_writer = pema.ProcessRun(
    st,
    run_list,    
    ('raw_records', 'records', 'peaklets', 'peaks_matched', 'event_info'))

In [None]:
script_writer.all_stored()

# Submit the scripts
First we need to do records, then we can do peaklets (affected by clustering) after that.

# Make configs for the parameters to change
Let's loop over the options we want to scan. We will write those to a dict and work from there.

In [None]:
# List the configuration of 'peaks' and keep only the options whose
# 'applies_to' entry equals ('peaklets', 'lone_hits').
opts = st.show_config('peaks')
opts = opts[opts['applies_to'] == ('peaklets', 'lone_hits')]
# Convert to a record array so the columns can be accessed by field name.
opts = opts.to_records()

In [None]:
# Determine the effective value of every option: start from the defaults and
# overwrite them wherever a current value is listed (i.e. not '<OMITTED>').
options, values, current = (
    opts[field] for field in ('option', 'default', 'current'))
mask = current != '<OMITTED>'
values[mask] = current[mask]
# Map option name -> effective value.
res = dict(zip(options, values))

In [None]:
list(res.keys())

In [None]:
keep_keys = [
 'peaklet_gap_threshold',
 'peak_left_extension',
 'peak_right_extension',
 'peak_split_filter_wing_width',
 'peak_split_min_area',
 'peak_split_iterations',
 'tight_coincidence_window_left',
 'tight_coincidence_window_right',
 's1_max_rise_time',
 's1_max_rise_time_post100',
 's2_merge_max_area',
 's2_merge_max_gap',
 's2_merge_max_duration'
]

In [None]:
# Restrict the options to the ones listed in keep_keys.
res = {k: v for (k,v )in res.items() if k in keep_keys}
# Separate copy of the selection; the scan further down iterates over it.
summary_config = res.copy()
res

In [None]:
def validate_conf(conf):
    """
    Tell whether the gap threshold exceeds the summed peak extensions.

    Each of the three options is looked up in ``conf`` first, then in the
    notebook-level ``res`` dict, and finally falls back to a hard-coded
    value (350 for the gap threshold, 30 and 350 for the left and right
    extension).

    :param conf: dict with option overrides to check
    :return: True if peaklet_gap_threshold is larger than
        peak_left_extension + peak_right_extension
    """
    def lookup(option, fallback):
        return conf.get(option, res.get(option, fallback))

    summed_extensions = (lookup('peak_left_extension', 30)
                         + lookup('peak_right_extension', 350))
    return lookup('peaklet_gap_threshold', 350) > summed_extensions

# Build the list of configurations to scan. It starts with an empty dict,
# i.e. a configuration without any override.
conf_tot = [{}]
i=0
# String representations of the 'events' data keys that were already added.
keys_seen = []
# Scratch context that is only used to compute data keys.
# NOTE(review): set_config is called on this one context for every option, so
# presumably values set for earlier options are still active when the keys of
# later options are computed - confirm that this is intended.
_st = st.new_context()
for k, v in summary_config.items():
    if k == 'peak_split_gof_threshold': 
        continue
    # Scaled values are cast back to the type of the original value.
    _type = type(v)
    for factor in [
        #2.5, 2, 1.5, 1.25, 
                   1.1, 1, 1/1.1,
#                    1/1.25, 1/1.5, 1/2, 1/2.5
    ]:
        value = _type(v*factor)
        conf = {k: value}
        # Drop settings for which validate_conf returns False.
        if not validate_conf(conf):
            print(f'skip {conf}')
            continue
        _st.set_config(conf)
        ev_key = _st.key_for('0', 'events')

        # A key that was seen before is skipped, except for factor 1: that
        # config is always added, so every option gets an entry with its
        # unscaled value.
        if str(ev_key) in keys_seen and factor != 1:
            continue
        else:
            conf_tot.append(conf)
            keys_seen.append(str(ev_key))

# factors = [1.5, 
#            1.25, 1, 0.8, 
#            0.5
# ]
# few_factors = [1.25, 1, 0.8]
# for fa in factors:
#     for fb in factors:
#         for fc in few_factors:
#             for fd in few_factors:
#                 for fx in few_factors:
#                     for fy in factors:
#                         for fz in few_factors:
#                             for fw in few_factors:
#                                 if not fa+fb+fc+fd+fx+fy+fz+fw in 7 + np.array(factors):
#                                     continue
#                                 if fa >1 or fb>1 or fc>1 or fd>1:
#                                     continue
# #                                 if fc!=1 or fw != 1:
# #                                     continue
#                                 value = (None, 
#                                          ((0.5*fx, 1*fa), (6.0*fy, 0.4*fb)), 
#                                          ((2.*fz, 1*fc), (4.5*fw, 0.4*fd)))
#                                 conf = {'peak_split_gof_threshold': value}
#                                 _st.set_config(conf)
#                                 ev_key = _st.key_for('0', 'events')

#                                 if str(ev_key) in keys_seen:
#                                     continue
#                                 else:
#                                     conf_tot.append(conf)
#                                     keys_seen.append(str(ev_key))

confs = tuple(conf_tot.copy())
len(conf_tot)

In [None]:
confs

# Submit the logs

In [None]:
run_list

In [None]:
# Start small: only process the runs whose 'records' are already stored.
selected_runs = [run for run in run_list if st.is_stored(run, 'records')]
n_selected, n_total = len(selected_runs), len(run_list)
print(f'Doing runs:\n{selected_runs}\n{n_selected/n_total*100:.1f}%')
# True when every run in run_list was selected.
all_runs = n_selected == n_total

In [None]:
confs

In [None]:
iter_confs = iter(confs)
job_registry = []

In [None]:
target = ('raw_records',
          'records', 
          'peaklets',
          'peak_basics',
          'match_acceptance_extended',
)

RAM = 15000
queue_max = 110
check_que_after = 10
part = 'xenon1t'

for i, conf in enumerate(tqdm(iter_confs, 
                              total = len(confs) - len(job_registry ), 
                              desc='configs')):
    job = pema.ProcessRun(st, run_id=selected_runs, target=target, config=conf)
    cmd, job_name = job.make_cmd()

    job_registry.append(job)
    if not job.all_stored(return_bool=True):
#         print(job_name)
#         job.purge_below('peak_basics')
#         job.purge_below('match_acceptance_extended')
        job.exec_dali(cmd, job_name, environ_init, mem=6_000, 
                      max_hours='01:00:00', partition=part)

    if i % check_que_after:
        q = !squeue -u `echo $USER`
        while len(q)> queue_max:
            q = !squeue -u `echo $USER`
            print(f'waiting 10s, queue is full. {len(q)}')
            time.sleep(10)

In [None]:
pd.set_option('display.max_rows', 500)

In [None]:
# Poll the batch queue once per second until fewer than five lines are left.
while True:
    # Number of lines squeue prints for the current user.
    # NOTE(review): this presumably includes squeue's header line - confirm.
    nrun = !squeue -u $USER  | wc -l
    nrun = int(nrun[0])
    print(nrun)
    if nrun < 5:
        break
    time.sleep(1)

In [None]:
df = pd.concat([j.all_stored(show_key=True) for j in job_registry])
df

### Load the simulated data using strax ("default")

### Load the simulated data using different config parameters
Let's copy the old context and run with different settings

In [None]:
def compute_acceptance(data):
    """
    Compute the average acceptance of ``data`` and its binomial interval.

    :param data: array with an 'acceptance_fraction' field
    :return: tuple of (sum of 'acceptance_fraction' divided by len(data),
        result of pema.binom_interval for that sum and length)
    """
    n_entries = len(data)
    n_accepted = np.sum(data['acceptance_fraction'])
    fraction = n_accepted / n_entries
    return fraction, pema.binom_interval(n_accepted, n_entries)

In [None]:
def compute_bias(data):
    """
    Compute mean and spread of the positive 'rec_bias' entries.

    Only entries with ``rec_bias > 0`` are taken into account.

    :param data: array with a 'rec_bias' field
    :return: tuple (mean, standard deviation) of the selected 'rec_bias'
        values; (nan, nan) if no entry has a positive 'rec_bias'
    """
    rec_bias = data['rec_bias']
    positive = rec_bias[rec_bias > 0]
    if not len(positive):
        # np.mean / np.std of an empty selection also give nan, but only
        # after emitting RuntimeWarnings; return the result directly instead.
        return np.nan, np.nan
    return np.mean(positive), np.std(positive)

In [None]:
# Summarise acceptance and bias for every distinct, fully processed config.
# NOTE: this rebinds ``res``, which earlier held the option -> value mapping.
res = defaultdict(list)

for i, job in enumerate(tqdm(job_registry)):
    config = job.config
    # Every distinct config is summarised only once.
    if config in res['config']:
        continue
    # Jobs for which not all targets are stored cannot be summarised yet.
    if not job.all_stored(return_bool=True):
        print(f'skip {job}')
        continue

    data = job.st.get_array(selected_runs, 'match_acceptance_extended', progress_bar=False)
    res['number'].append(i)
    res['config'].append(config)
    res['config_type'].append(list(config))
    # Entries are split on their 'type' field: 1 and 2 (labelled S1 and S2).
    for si in (1, 2):
        sel = data['type'] == si
        of_type = data[sel]

        acceptance, (low, high) = compute_acceptance(of_type)
        res[f's{si}_acc'].append(acceptance)
        # Stored as [lower error, upper error] around the acceptance.
        res[f's{si}_acc_err'].append([acceptance - low, high - acceptance])

        bias_mean, bias_err = compute_bias(of_type)
        res[f's{si}_bias'].append(bias_mean)
        res[f's{si}_bias_err'].append(bias_err)

In [None]:
df = pd.DataFrame(res)
df

In [None]:
# One figure per scanned option: S1 acceptance on the left y-axis and S2
# acceptance on a twinned right y-axis, with one x-tick per scanned config.
for config_type in np.unique(df['config_type'].values):
    # Entries with an empty config_type (no option changed) get no figure.
    if not config_type:
        continue
    print(config_type)
    mask = np.array([c == config_type for c in df['config_type']])
    scan = df[mask]
    plot_kwargs = {'markersize': 5, 'ls': 'none', 'capsize': 3, 'marker': 'o'}
    colors = ('#1f77b4', '#ff7f0e')
    # The figure gets one unit of width per plotted config.
    fig = plt.figure(figsize=(1 * np.sum(mask), 6))
    plt.title(f'S1/S2 acceptance - {config_type}')
    for axi, si in enumerate([1, 2]):
        color = colors[axi]
        column = f's{si}_acc'
        if axi == 1:
            # The second signal type is drawn on its own y-axis.
            plt.sca(plt.gca().twinx())
            plt.xticks()
        # Transposed so that the lower and upper errors are the two rows.
        asymmetric_errors = np.array(list(scan[f'{column}_err'])).T
        plt.errorbar(scan['number'],
                     scan[column],
                     yerr=asymmetric_errors,
                     label=f'S{si} acceptance',
                     c=color,
                     **plot_kwargs)
        # Reference line at the value of the entry labelled 0.
        plt.axhline(df[column][0], ls='--', c=color,
                    label=f'default S{si} acceptance')
        plt.ylabel(f'S{si} acceptance')
        current_axes = plt.gca()
        current_axes.yaxis.label.set_color(color)
        current_axes.tick_params(axis='y', colors=color)
        if axi == 0:
            plt.xticks(scan['number'], scan['config'], rotation=45, ha='right')
    fig.legend(loc=5)
    pema.save_canvas(f'{config_type}_update_lowe_scan', save_dir=os.path.join(fig_dir, 'update_scan'))

In [None]:
# One figure per scanned option: S1 bias on the left y-axis and S2 bias on a
# twinned right y-axis, with one x-tick per scanned config.
for config_type in np.unique(df['config_type'].values):
    # Entries with an empty config_type (no option changed) get no figure.
    if not config_type:
        continue
    print(config_type)
    mask = np.array([c == config_type for c in df['config_type']])
    plot_kwargs = dict(markersize=5, ls='none', capsize=3, marker='o')
    # The figure gets one unit of width per plotted config.
    fig = plt.figure(figsize=(1*np.sum(mask), 6))
    # Title and reference-line label used to say "acceptance" (copied from
    # the acceptance plots) although this cell shows the bias.
    plt.title(f'S1/S2 bias - {config_type}')
    colors = ('#1f77b4', '#ff7f0e')
    for axi, si in enumerate([1, 2]):
        if axi == 1:
            # The second signal type is drawn on its own y-axis.
            plt.sca(plt.gca().twinx())
            plt.xticks()
        plt.errorbar(df[mask]['number'],
                     df[mask][f's{si}_bias'],
                     yerr=df[mask][f's{si}_bias_err'],
                     label=f'S{si} bias',
                     c=colors[axi],
                     **plot_kwargs,
                     )
        # Reference line at the value of the entry labelled 0.
        plt.axhline(df[f's{si}_bias'][0], ls='--', c=colors[axi],
                    label=f'default S{si} bias')
        plt.ylabel(f'S{si} bias')
        plt.gca().yaxis.label.set_color(colors[axi])
        plt.gca().tick_params(axis='y', colors=colors[axi])
        if axi == 0:
            plt.xticks(df[mask]['number'],
                       df[mask]['config'],
                       rotation=45, ha='right')
    fig.legend(loc=5)
    pema.save_canvas(f'{config_type}_update_lowe_bias_scan', save_dir=os.path.join(fig_dir, 'update_scan'))

In [None]:
dat = st.get_array(selected_runs, 'match_acceptance_extended', progress_bar=False)

# Select config and work from there

In [None]:
# Hand-picked settings to compare against the default configuration.
# NOTE: this rebinds summary_config, which earlier held the scanned options.
summary_config = {
    's2_merge_max_duration': 30000,
    's2_merge_max_gap': 5000,
    'peaklet_gap_threshold': 500
                 }
# Apply the settings to a new context derived from ``st``
# (presumably leaving ``st`` itself untouched - confirm).
st2 = st.new_context()
st2.set_config(summary_config)

In [None]:
default_acceptence = st.get_array(selected_runs, 'match_acceptance_extended',progress_bar=False)
custom_acceptence = st2.get_array(selected_runs, 'match_acceptance_extended',progress_bar=False)

In [None]:
def si_acceptance(si, binedges, on_axis='n_photon', nbins=50):
    """
    Overlay the acceptance of the default and the custom configuration.

    Uses the notebook-level ``default_acceptence`` / ``custom_acceptence``
    arrays and their ``default_label`` / ``custom_label``.

    :param si: peak type to select (compared with the 'type' field)
    :param binedges: forwarded to pema.summary_plots.acceptance_plot
    :param on_axis: name of the field the acceptance is plotted against
    :param nbins: forwarded to acceptance_plot as ``nbins``
    """
    labelled_results = (
        (default_acceptence, default_label),
        (custom_acceptence, custom_label),
    )
    for result, label in labelled_results:
        of_type = result[result['type'] == si]
        pema.summary_plots.acceptance_plot(
            of_type,
            on_axis,
            binedges,
            nbins=nbins,
            plot_label=label,
        )
    plt.ylabel('Arb. Acceptance')
    plt.title(f"S{si} acceptance")
    plt.legend()

def _matching_histogram_panel(ax, data, label, si, on_axis, nbins, plot_range):
    """Draw the peak-matching histogram of ``data`` on ``ax``; return the number of selected entries."""
    plt.sca(ax)
    # Entries of the requested type, strictly inside the plot range.
    sel = ((data['type'] == si)
           & (data[on_axis] > plot_range[0])
           & (data[on_axis] < plot_range[1])
           )
    pema.summary_plots.plot_peak_matching_histogram(data[sel], on_axis, bin_edges=nbins)
    # Tag the panel with the name of the configuration in its top-left corner.
    plt.text(0.05, 0.95,
             label,
             transform=plt.gca().transAxes,
             ha='left',
             va='top',
             bbox=dict(boxstyle="round", fc="w")
             )
    plt.legend(loc=(1.01, 0))
    plt.xlim(*plot_range)
    return np.sum(sel)


def acceptance_summary(si, on_axis, axis_label, nbins=100, plot_range=(0, 200), save_name=''):
    """
    Three-panel comparison of the default and the custom configuration.

    The top and middle panel show the peak-matching histogram of the
    default and the custom result, the bottom panel overlays both
    acceptance curves. The figure is saved in ``fig_dir`` as
    '<si>_acceptance_detailed_<save_name>'.

    Uses the notebook-level ``default_acceptence`` / ``custom_acceptence``
    arrays and their ``default_label`` / ``custom_label``.

    :param si: peak type to select (compared with the 'type' field)
    :param on_axis: name of the field to plot against
    :param axis_label: x-axis label of the bottom panel
    :param nbins: number of bins, forwarded to the pema plotting functions
    :param plot_range: (low, high) range of ``on_axis`` that is shown
    :param save_name: suffix of the name under which the figure is saved
    """
    f, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)

    _matching_histogram_panel(
        axes[0], default_acceptence, default_label, si, on_axis, nbins, plot_range)
    n_custom = _matching_histogram_panel(
        axes[1], custom_acceptence, custom_label, si, on_axis, nbins, plot_range)
    print(f'cust {n_custom}')

    # Bottom panel: acceptance curves. Unlike the panels above, the selection
    # is on type only; the range is handed to acceptance_plot.
    plt.sca(axes[2])
    for data, label in ((default_acceptence, default_label),
                        (custom_acceptence, custom_label)):
        mask = data['type'] == si
        pema.summary_plots.acceptance_plot(data[mask], on_axis, plot_range, nbins=nbins,
                                           plot_label=label)
    plt.legend(loc=(1.01, 0))
    plt.ylabel('Arb. acceptance fraction')
    plt.xlim(*plot_range)
    plt.xlabel(axis_label)
    plt.ylim(0, 1)

    # The panels share the x-axis, so stack them without vertical gaps.
    plt.subplots_adjust(hspace=0)
    plt.suptitle(f'S{si} Acceptance', y=0.9)
    pema.save_canvas(f'{si}_acceptance_detailed_{save_name}', save_dir=fig_dir)

In [None]:
# bias_recons(1)
# bias_recons(2)

In [None]:
s1_max = default_acceptence['n_photon'][default_acceptence['type']==1].max()
s2_max = default_acceptence['n_photon'][default_acceptence['type']==2].max()

In [None]:
pema.summary_plots.rec_diff(
    default_acceptence,
    custom_acceptence,
    s1_kwargs=dict(bins=[100, 100], range=[[0,s1_max], [0.6, 1.1]]),
    s2_kwargs=dict(bins=[100, 100], range=[[0,s2_max], [0.6, 1.1]])
        )

In [None]:
acceptance_summary(si = 1, 
                   on_axis = 'n_photon',
                   axis_label = 'N photons simulated', 
                   nbins = 100, 
                   plot_range = (0, 50),
                   save_name = 'tot_compare',)

acceptance_summary(si = 2, 
                   on_axis = 'n_photon',
                   axis_label = 'N photons simulated', 
                   nbins = 100, 
                   plot_range = (0, 250),
                  save_name = 'tot_compare')

acceptance_summary(si = 2, 
                   on_axis = 'z',
                   axis_label = 'z (simulated) [cm]', 
                   nbins = 75, 
                   plot_range = (-160, 10),
                   save_name = 'tot_compare')

In [None]:
pema.compare_outcomes(st, default_acceptence,                 
                      st2, custom_acceptence,
                      only_different=False,
                      plot_fuzz=3000,
                      fuzz=0,
#                       fig_dir=os.path.join(fig_dir, 'total_config'),
                      max_peaks=10)