# Run pema

Authors:
 - Angevaare, Joran <j.angevaare@nikhef.nl> (based on the Peak_Classification_Tester) 


## This notebook ##

**Goal**
 - Study how peak clustering and classification can be improved using Kr simulations.

**Known issues**
 -


**Start up ``strax`` + load tools**

In [None]:
import pema
import os

In [None]:
if not os.path.exists('init.py'):
    init = os.path.join(pema.__path__[0], '..', 'bin', "pema_init.py")
    !ln -s $init init.py
%run init.py

In [None]:
# Storage locations for simulated raw data, processed data and figures.
# Set $PEMA_BASE_DIR to use another location than the author's default.
base_dir = os.environ.get('PEMA_BASE_DIR', '/mnt/d/pema/')
data_name = f'pema_test_{pema.__version__}'
fig_dir = os.path.join(base_dir, f'figures_summary_{data_name}')
data_dir = os.path.join(base_dir, 'processed_data')
raw_data_dir = os.path.join(base_dir, 'raw_data')
# Simulator instructions are written to the current working directory
instructions_csv = f"./inst_{data_name}.csv"

# Output naming
default_label = 'Normal clustering'
custom_label = 'Changed clustering'

**Initialize the waveform simulator with Kr instructions**

In [None]:
# Six consecutive simulated run ids, zero-padded to six digits ('011665' ... '011670').
run_list = [str(run_number).zfill(6) for run_number in range(11665, 11665 + 6)]
run_list

In [None]:
# Just some id which allows CMT to load
run_id = run_list[0]  # Just some id which allows CMT to load (the first simulated run)

In [None]:
# Settings from which the simulator instructions are generated (see the
# `pema.inst_to_csv(..., get_inst_from=pema.rand_instructions)` call below).
# Setting up instructions like this may take a while; NOTE(review): presumably a
# smaller `nchunk` gives a quicker test run — confirm.
instructions = dict(
    event_rate=100, # Don't make too large -> overlapping truth info
    chunk_size=5, # keep large -> less overhead but takes more RAM
    nchunk=100, # set to 100
    photons_low=1, #PE
    photons_high=100, #PE
    electrons_low=1, # presumably a number of electrons — confirm
    electrons_high=100,
    tpc_radius=straxen.tpc_r,
    tpc_length=148.1, # TPC length approx
    drift_field = 18.5, # V/cm VERIFY!
    timing = 'uniform', # Double S1 peaks uniform over time
)
# Simulator config overrides (passed as `fax_config_override`): S1 and S2
# pattern-map files resolved through ntauxfiles.
fax_override = {
    's1_pattern_map' : ntauxfiles.get_abspath('XENONnT_s1_xyz_patterns_LCE_corrected_qes_MCva43fa9b_wires.pkl'),
    's2_pattern_map' : ntauxfiles.get_abspath('XENONnT_s2_xy_patterns_LCE_corrected_qes_MCva43fa9b_wires.pkl')}

## Write instructions to CSV

In [None]:
# Generate the simulator instructions from `instructions` with
# `pema.rand_instructions` and write them to `instructions_csv`.
pema.inst_to_csv(instructions, instructions_csv,
                 get_inst_from=pema.rand_instructions)

In [None]:
# TODO can we add noise?
# Simulator settings for the pema context: use the instructions CSV written
# above and override the S1/S2 pattern maps.
config_update = {
    'detector': 'XENONnT',
    'fax_file': os.path.abspath(instructions_csv),
    'fax_config': 'fax_config_nt_low_field.json',
    'fax_config_override': fax_override,
}

In [None]:
# pema context with its raw/processed data directories under `base_dir`.
st = pema.pema_context(
    base_dir=base_dir,
    raw_dir=raw_data_dir,
    data_dir=data_dir,
    config_update=config_update,
)

In [None]:
# Context-level options for this session (not plugin options).
st.set_context_config(dict(
    allow_shm=True,
    allow_lazy=False,
    timeout=300,
    max_messages=10,
))

In [None]:
# Clustering settings applied to every job; these are the baseline values
# the scan below varies around.
summary_config = {
    's2_merge_max_duration': 30000,
    's2_merge_max_gap': 5000,
    'peaklet_gap_threshold': 525,
}
st.set_config(summary_config)

In [None]:
run_list  # runs to be processed

In [None]:
# Start one local processing job per run (up to `event_info`); keep the
# ProcessRun objects so the next cell can wait for them.
jobs = []
for r in run_list:
    script_writer = pema.ProcessRun(st, r,
                                    ('raw_records', 'records',
                                     'peaklets', 'peaks_matched',
                                     'event_info'))
    cmd, name = script_writer.make_cmd()
    script_writer.exec_local(cmd, name)
    jobs.append(script_writer)

In [None]:
# Block until every job started above has finished. Jobs without a `log_file`
# process have nothing to wait for (same check as the later wait cell).
for j in jobs:
    if j.log_file is not None:
        j.log_file.communicate()

In [None]:
# One ProcessRun spanning all runs, used below to check what is stored.
script_writer = pema.ProcessRun(
    st,
    run_id=run_list,
    target=('raw_records', 'records', 'peaklets', 'peaks_matched', 'event_info'),
)

In [None]:
# script_writer.log_file.communicate()

In [None]:
script_writer.all_stored()  # storage status of the targets for all runs in run_list

In [None]:
# Shell preamble that activates the `strax` conda environment (for batch jobs).
# Set $PEMA_CONDA_ROOT to point at another conda install; the default is the
# author's install location.
_conda_root = os.environ.get('PEMA_CONDA_ROOT', '/home/angevaare/software/Miniconda3')
environ_init = f'''eval "$({_conda_root}/bin/conda shell.bash hook)"
conda activate strax
export PATH={_conda_root}/envs/strax/bin:$PATH'''
environ_init

# Submit the scripts
First we need to process `records`; only after that can we process `peaklets` (which are affected by the clustering settings).

# Make configs for the parameters to change
Let's loop over the options we want to scan. We will write those to a dict and work from there.

In [None]:
# `peaks` options with their default and current values, as a record array
opts = st.show_config('peaks')[['option', 'default', 'current']].to_records()

In [None]:
# Effective value per option: the current value where it is not '<OMITTED>',
# otherwise the default.
options = opts['option']
values = opts['default'].copy()  # copy, so the defaults stored in `opts` are not overwritten
current = opts['current']
mask = current != '<OMITTED>'
values[mask] = current[mask]
res = dict(zip(options, values))

In [None]:
# Options to scan, grouped by option-name prefix.
keep_keys = [
    # peaklet / peak options
    'peaklet_gap_threshold',
    'peak_left_extension',
    'peak_right_extension',
    'peak_min_pmts',
    'peak_split_filter_wing_width',
    'peak_split_min_area',
    'peak_split_iterations',
    # tight-coincidence options
    'tight_coincidence_window_left',
    'tight_coincidence_window_right',
    # s1_* options
    's1_max_rise_time',
    's1_max_rise_time_post100',
    's1_min_coincidence',
    # s2_* options
    's2_min_pmts',
    's2_merge_max_area',
    's2_merge_max_gap',
    's2_merge_max_duration',
]

In [None]:
# Restrict to the options we scan; summary_config is an independent copy.
res = dict(filter(lambda item: item[0] in keep_keys, res.items()))
summary_config = dict(res)
res

In [None]:
# Build single-option configs: every scanned option multiplied by each factor
# (cast back to the option's type). Variants whose `events` lineage key was
# already seen (e.g. integer options that round back to the same value) are
# dropped; factor 1 is always kept so each option's scan includes its default.
scan_factors = (2, 1.5, 1.25, 1.1, 1, 1 / 1.1, 1 / 1.25, 1 / 1.5, 1 / 2)

conf_tot = [{}]  # the empty config is the unmodified baseline
keys_seen = set()
_st = st.new_context()
for k, v in summary_config.items():
    _type = type(v)
    for factor in scan_factors:
        conf = {k: _type(v * factor)}
        _st.set_config(conf)
        ev_key = str(_st.key_for('0', 'events'))
        if ev_key in keys_seen and factor != 1:
            continue
        conf_tot.append(conf)
        keys_seen.add(ev_key)
    # set_config accumulates: restore this option so the next option is scanned
    # around the defaults, not around this option's last (halved) value.
    _st.set_config({k: v})

confs = tuple(conf_tot)
len(conf_tot)

In [None]:
confs  # configs that will be processed (the first, empty one is the baseline)

# Submit the jobs

In [None]:
run_list  # all runs; the next cell selects those that already have records

In [None]:
# Want to start small? Only runs whose `records` are already stored are used.
selected_runs = list(filter(lambda run: st.is_stored(run, 'records'), run_list))
print(f'Doing runs:\n{selected_runs}\n{len(selected_runs)/len(run_list)*100:.1f}%')
all_runs = len(run_list) == len(selected_runs)

In [None]:
confs  # configs to submit in the next cells

In [None]:
import psutil
# psutil.cpu_percent()
def get_mem():
    """Return the machine's current RAM usage as a percentage (0-100).

    Used to throttle launching local jobs.
    """
    return psutil.virtual_memory().percent

In [None]:
import time  # for the throttle sleep below

# Create one ProcessRun per scan config and process (locally) whatever is not
# stored yet, throttling while the machine is busy.
job_registry = []
target = ('raw_records', 'records',
          'peaklets',
          'peak_basics',
          'truth_matched',
          'match_acceptance_extended',
          )
# Batch-queue (dali) settings; only needed for the `exec_dali` alternative
# noted in the loop, not for local processing.
RAM = 15000
queue_max = 200
check_que_after = 50
part = 'xenon1t'
# Don't launch the next local job while RAM or CPU usage (%) exceeds this.
max_load_percent = 90

for conf in tqdm(confs, desc='configs'):
    job = pema.ProcessRun(st, run_id=run_list, target=target, config=conf)
    cmd, job_name = job.make_cmd()
    job_registry.append(job)

    if job.all_stored(return_bool=True):
        continue  # everything for this config is already processed
    # To submit to the dali batch queue instead, use e.g.:
    # job.exec_dali(cmd, job_name, bash_activate=environ_init,
    #               ram=RAM, partition=part, max_hours='04:00:00')
    job.exec_local(cmd, job_name)
    while get_mem() > max_load_percent or psutil.cpu_percent() > max_load_percent:
        time.sleep(1)

In [None]:
psutil.cpu_percent()  # CPU load (%); without an interval psutil reports usage since the previous call

In [None]:
# Wait for every locally launched job to finish; jobs that were not launched
# have no `log_file` process to wait for.
for job in tqdm(job_registry):
    if job.log_file is not None:
        job.log_file.communicate()

In [None]:
pd.set_option('display.max_rows', 500)  # show the full storage overviews below

In [None]:
pd.concat([j.all_stored() for j in job_registry])  # storage status for every job/config

In [None]:
pd.concat([j.all_stored(show_key=True) for j in job_registry])  # same overview, presumably with storage keys shown

### Load the simulated data using strax ("default")

### Load the simulated data using different config parameters
Let's copy the old context and run with different settings

In [None]:
def compute_acceptance(data):
    """Return the mean acceptance of `data` and its binomial interval.

    :param data: structured array with an ``acceptance_fraction`` field
        (presumably one entry per simulated truth peak — TODO confirm).
    :return: ``(found / total, pema.binom_interval(found, total))``, where
        ``found`` is the summed ``acceptance_fraction`` and ``total`` is
        ``len(data)``. For empty ``data`` this returns ``(nan, (nan, nan))``
        instead of dividing by zero.
    """
    total = len(data)
    if total == 0:
        return np.nan, (np.nan, np.nan)
    found = np.sum(data['acceptance_fraction'])
    return found / total, pema.binom_interval(found, total)

In [None]:
from collections import defaultdict

In [None]:
# Collect S1/S2 acceptance for every fully processed config. `number` is the
# job's index in `job_registry` and serves as the x-position in the plots.
res = defaultdict(list)

for i, job in enumerate(tqdm(job_registry)):
    config = job.config
    if config in res['config']:
        continue  # this config was already collected
    # Same completeness check as used when submitting; summing per-target
    # columns counted across runs and could pass with runs still missing.
    if not job.all_stored(return_bool=True):
        continue

    try:
        data = job.st.get_array(run_list, 'match_acceptance_extended', progress_bar=False)
    except (AssertionError, TypeError) as e:
        print(f'Skipping config {config}: loading failed ({e!r})')
        continue
    res['number'].append(i)
    res['config'].append(config)
    res['config_type'].append(list(config.keys()))
    for si in (1, 2):  # peak type 1 = S1, 2 = S2 (per the s{si}_* column names)
        sel = data['type'] == si
        acceptance, (low, high) = compute_acceptance(data[sel])
        res[f's{si}_acc'].append(acceptance)
        res[f's{si}_err'].append([acceptance - low, high - acceptance])

In [None]:
# One row per collected config
df = pd.DataFrame(res)
df

In [None]:
# One figure per scanned option: S1 (left axis) and S2 (right axis) acceptance
# for every value of that option, with the baseline as dashed lines.
plot_kwargs = dict(markersize=5, ls='none', capsize=3, marker='o')
colors = ('#1f77b4', '#ff7f0e')
# The baseline is the row with the empty config. Look it up explicitly instead
# of assuming it is row 0, which is another config if the baseline is missing.
baseline = df[np.array([c == {} for c in df['config']], dtype=bool)]

for config_type in np.unique(df['config_type'].values):
    if not config_type:
        continue  # the baseline itself has no option to plot
    print(config_type)
    mask = np.array([c == config_type for c in df['config_type']])
    selected = df[mask]

    fig, ax_s1 = plt.subplots(figsize=(10, 6))
    ax_s1.set_title(f'S1/S2 acceptance - {config_type}')
    ax_s2 = ax_s1.twinx()  # S2 on a second y-axis sharing the x-axis
    for ax, si, color in zip((ax_s1, ax_s2), (1, 2), colors):
        ax.errorbar(selected['number'],
                    selected[f's{si}_acc'],
                    yerr=np.array([e for e in selected[f's{si}_err']]).T,
                    label=f'S{si} acceptance',
                    c=color,
                    **plot_kwargs,
                    )
        if len(baseline):
            ax.axhline(baseline[f's{si}_acc'].iloc[0], ls='--', c=color,
                       label=f'default S{si} acceptance')
        ax.set_ylabel(f'S{si} acceptance')
        ax.yaxis.label.set_color(color)
        ax.tick_params(axis='y', colors=color)
    ax_s1.set_xticks(selected['number'])
    ax_s1.set_xticklabels([str(c) for c in selected['config']], rotation=45, ha='right')
    fig.legend(loc=5)
    pema.save_canvas(f'{config_type}_scan', save_dir=os.path.join(fig_dir, 'scan'))
    plt.close(fig)  # don't accumulate open figures across the loop