# Supernova Signal - WFSim

Create waveform simulator instruction file based on a supernova model.

```
Created : October 2021
Last Update : 28-10-2021
Melih Kara kara@kit.edu
Ricardo Peres
```

**ToDo**<br>
The recoil spectrum is tabulated on a sparse grid, so sampling from it directly yields discrete values. Interpolate the final spectrum and sample from the interpolated curve instead.

## Table of Contents
- None

In [None]:
# import sys
# sys.modules['admix'] = None

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import LogNorm
import multihist as mh
import _pickle as pickle
import scipy.interpolate as itp
import click

import strax, straxen, wfsim, cutax
import nestpy
import datetime, os
from tqdm.notebook import tqdm

In [None]:
# Load the four objects stored together in the pickle: the rates versus recoil
# energy, the rates versus time, and the two corresponding bin arrays.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(r"rates_combined.pickle", "rb") as input_file:
    rates_Er, rates_t, recoil_energy_bins, timebins  = pickle.load(input_file)

In [None]:
from Supernova_Models import SN_lightcurve
from SN_plotter import Plotter
# Build the supernova light-curve model used in this notebook.
# NOTE(review): the units of the arguments are not visible here -- confirm them
# against Supernova_Models.SN_lightcurve.
M30sn_model  = SN_lightcurve(progenitor_mass = 30,
                             metallicity= 0.02,
                             time_of_revival = 100, 
                             distance = 10)

# Compute the 1D and 2D recoil spectra on the model object.
M30sn_model.get_recoil_spectra1D()
M30sn_model.get_recoil_spectra2D()

In [None]:
# Plotting helper bound to the model; the example calls below are left disabled.
plotter = Plotter(M30sn_model)
# fig, ax, samples = plotter.plot_sampled_energies()
# plotter.plot_sampled_energies(x='time', xscale='log');

#### Object attributes

In [None]:
# tot_rates1D = M30sn_model.total_rate1D # integrated(summed) between 0-10s 
# tot_rates2D = M30sn_model.total_rate2D
# recoil_energy_bins = M30sn_model.recoil_en
# timebins = M30sn_model.t
# nu_energies = M30sn_model.mean_E
# rates_Er, rates_t = M30sn_model._get_1Drates_from2D()
# Er_sample_E = M30sn_model.sample_from_recoil_spectrum(N_sample=100000)
# Er_sample_t = M30sn_model.sample_from_recoil_spectrum(x='time',N_sample=100000)

## dump to a pickle
# combined_dicts = [rates_Er,rates_t, recoil_energy_bins, timebins]

# with open(r"rates_combined.pickle", "wb") as output_file:
#     pickle.dump(combined_dicts, output_file)

`sample_from_recoil_spectrum` can return samples equal to 0, which might be a problem, as there would be no signal at these times (or, worse, it can return negative times).

In [None]:
def _inverse_transform_sampling( x_vals, y_vals, n_samples):
    cum_values = np.zeros(x_vals.shape)
    y_mid = (y_vals[1:]+y_vals[:-1])*0.5
    cum_values[1:] = np.cumsum(y_mid*np.diff(x_vals))
    inv_cdf = itp.interp1d(cum_values/np.max(cum_values), x_vals)
    r = np.random.rand(n_samples)
    return inv_cdf(r)

def sample_from_recoil_spectrum(x='energy', N_sample=1):
    """Sample recoil energies or times from the total rates.

    Parameters
    ----------
    x : {'energy', 'time'}
        Quantity to sample (case-insensitive).
    N_sample : int
        Number of samples to draw.

    Returns
    -------
    array
        ``N_sample`` sampled values.

    Raises
    ------
    ValueError
        If ``x`` is neither 'energy' nor 'time'. Raising (instead of printing
        and returning None) prevents callers from silently storing NaN when
        they assign the result into a float array.
    """
    axis = x.lower()
    if axis not in ('energy', 'time'):
        raise ValueError(f'choose x=time or x=energy, got x={x!r}')

    if axis == 'energy':
        # The energy grid is sparse: interpolate onto 200 evenly spaced points
        # so the sampled energies are not restricted to a few discrete values.
        intrp_rates = itp.interp1d(recoil_energy_bins, rates_Er['Total'],
                                   kind="cubic", fill_value="extrapolate")
        xaxis = np.linspace(recoil_energy_bins.min(), recoil_energy_bins.max(), 200)
        # A cubic spline can undershoot below zero; negative rates would make
        # the cumulative distribution non-monotonic, so clip them to zero.
        spectrum = np.clip(intrp_rates(xaxis), 0, None)
    else:
        spectrum = rates_t['Total']
        xaxis = timebins
    return _inverse_transform_sampling(xaxis, spectrum, N_sample)

In [None]:
# Draw 10k times and 10k recoil energies and histogram them as a quick check.
sample_t = sample_from_recoil_spectrum(x='time',N_sample=10000)
sample_E = sample_from_recoil_spectrum(N_sample=10000)
fig, (ax1,ax2) = plt.subplots(ncols=2, figsize=(14,3))
# Left panel: only the range from the earliest sampled time up to 0.1 is shown.
ax1.hist(sample_t, range=(min(sample_t),0.1), bins=100);
ax1.set_xlabel('sampled times');
# Right panel: sampled recoil energies.
ax2.hist(sample_E, bins=200);

In [None]:
# kwargs = dict(norm=LogNorm(), bins=100) # range=[[0, 3_000], [0, 20_000]],

# mh_cut = mh.Histdd(sample_t, sample_E, **kwargs)

# mh_cut.plot(log_scale=True,
#             cblabel='count',
#             cmap=plt.get_cmap('jet'),
#             alpha=0.7,  colorbar_kwargs=dict(orientation="vertical", 
#                                              pad=0.05,
#                                              aspect=30, 
#                                              fraction=0.1));
# plt.ylabel('Recoil Energy [keV]')
# plt.xlabel('Time [s]');

In [None]:
# Integrated total rate above each threshold: entry i is the trapezoidal
# integral of the total spectrum from recoil_energy_bins[i] to the last bin.
E_thr =  np.array([np.trapz(rates_Er['Total'][i:], recoil_energy_bins[i:])
                   for i in range(len(recoil_energy_bins))])

# Plot against the threshold energies themselves so that the x axis matches
# its label (previously the x axis was the bin index). 4.7 is the volume scaling.
plt.plot(recoil_energy_bins, E_thr*4.7);
# NOTE(review): the curve is an energy-integrated rate, so 'events/keVnr' may
# not be the right unit -- confirm.
plt.ylabel('events/keVnr');
plt.xlabel('Recoil Energy threshold');
plt.axvline(0, color='k');

Take the fiducial volume as 4.7 tonnes for now. There should be ~75 events in total.

---
## Import WFSim

In [None]:
# Keep the package versions as a string and also display them in the notebook.
versions_str = straxen.print_versions(('strax','straxen','cutax','wfsim'), return_string=True)
straxen.print_versions(('strax','straxen','cutax','wfsim'))

In [None]:
# wfsim.instruction_dtype

Based on nest's naming convention [here](https://github.com/XENONnT/WFSim/blob/2c614b0f7b0d7c7adc516f6188e281857e8d7e22/wfsim/core.py#L22)

Ricardo's script for instruction generation https://github.com/XENONnT/analysiscode/blob/master/S2_shape_width/simulation/make_submit_files.py

In [None]:
# Fiducial volume in tonnes; scales the integrated rate to a number of events.
volume = 4.7

In [None]:
def instructions_SNdata_magnificient(nevents, interaction_type=0,
                                     dump_csv=False, filename=None,
                                     below_cathode=False,
                                     outdir='/dali/lgrandi/melih/sn_wfsim/instructions'):
    ''' WFsim instructions to simulate Supernova NR peak.

        Every event yields two instruction rows (type 1 and type 2, i.e. the
        S1 and the S2) sharing time, position and deposited energy.

        Parameters
        ----------
        nevents : `int`
            total number of events desired; rounded up to a whole number
            of supernovae
        interaction_type : `int`
            value passed to ``nestpy.INTERACTION_TYPE``
        dump_csv : `bool`
            if True, also write the instructions to a csv file in `outdir`
        filename : `str` or None
            csv file name; defaults to ``<today>_instructions.csv``
        below_cathode : `bool`
            if True, z is sampled down to ``-straxen.tpc_z-12`` instead of
            ``-straxen.tpc_z``
        outdir : `str`
            directory the csv file is written to

        Returns
        -------
        `pandas.DataFrame`
            instructions with ``amp > 0``, sorted by time

        Raises
        ------
        ValueError
            if the expected number of events per supernova is not positive

        Notes
        -----
        For each SN signal there is ceil(E_thr[0]*volume) events generated.
        To generate nevents, ceil(nevents/(number of single SN events))
        supernovae are sampled, each shifted by a further 20 in time
        (seconds, presumably) with respect to the previous one.
    '''
    A, Z = 131.293, 54
    lxe_density = 2.862  # g/cm^3
    drift_field = 200  # V/cm

    number_of_events = int(np.ceil(E_thr[0] * volume))
    if number_of_events < 1:
        raise ValueError('E_thr[0]*volume must be positive to generate events')
    nr_iterations = int(np.ceil(nevents / number_of_events))
    n = nr_iterations * number_of_events

    sample_E = np.empty(n)
    sample_t = np.empty(n)
    ## shifted time sampling
    for i in range(nr_iterations):
        chunk = slice(i * number_of_events, (i + 1) * number_of_events)
        time_shift = i * 20  # add 20 sec to each iteration
        sample_E[chunk] = sample_from_recoil_spectrum(N_sample=number_of_events)
        sample_t[chunk] = sample_from_recoil_spectrum(x='time', N_sample=number_of_events) + time_shift

    # SN signal also has pre-SN neutrino, so if there are negative times boost them.
    # This is done once, after every slot is filled: doing it inside the loop
    # let not-yet-filled placeholder values determine the minimum, which moved
    # the first supernova relative to all later ones.
    if n and sample_t.min() < 0:
        sample_t -= sample_t.min()

    instructions = np.ones(2 * n, dtype=wfsim.instruction_dtype)
    instructions[:] = -1
    # NOTE(review): scale factor 1e8 and offset 1e6 are kept unchanged; if
    # sample_t is in seconds and WFSim expects ns the factor would be 1e9 -- confirm.
    instructions['time'] = (1e8 * sample_t).repeat(2) + 1000000

    instructions['event_number'] = np.arange(0, n).repeat(2)
    instructions['type'] = np.tile([1, 2], n)

    # Uniform positions inside a cylinder: sqrt of a uniform r^2 gives a
    # uniform density in the x-y plane.
    r = np.sqrt(np.random.uniform(0, straxen.tpc_r**2, n))
    t = np.random.uniform(-np.pi, np.pi, n)
    instructions['x'] = np.repeat(r * np.cos(t), 2)
    instructions['y'] = np.repeat(r * np.sin(t), 2)
    z_min = -straxen.tpc_z - 12 if below_cathode else -straxen.tpc_z
    instructions['z'] = np.repeat(np.random.uniform(z_min, 0, n), 2)

    # The interaction type does not change between events: convert it once.
    interaction = nestpy.INTERACTION_TYPE(interaction_type)
    nc = nestpy.nestpy.NESTcalc(nestpy.nestpy.VDetector())

    quanta, exciton = [], []
    for energy_deposit in tqdm(sample_E, desc='generating instructions from nest'):
        y = nc.GetYields(interaction, energy_deposit, lxe_density, drift_field, A, Z)
        q = nc.GetQuanta(y, lxe_density)
        # both S1 and S2: photons for the type-1 row, electrons for the type-2 row
        quanta += [q.photons, q.electrons]
        exciton += [q.excitons, 0]

    instructions['amp'] = quanta
    instructions['local_field'] = drift_field
    instructions['n_excitons'] = exciton
    instructions['recoil'] = [interaction] * (2 * n)
    instructions['e_dep'] = np.repeat(sample_E, 2)

    instructions_df = pd.DataFrame(instructions)
    # Rows without any quanta produce no signal; drop them and order by time.
    instructions_df = instructions_df[instructions_df['amp'] > 0].sort_values('time')
    if dump_csv:
        filename = filename or f'{datetime.date.today()}_instructions.csv'
        path = os.path.join(outdir, filename)
        instructions_df.to_csv(path, index=False)
        print(f'Saved in -> {path}')
    return instructions_df

In [None]:
# Generate two instruction sets (at least 10k and 100k events; the count is
# rounded up to whole supernovae) and save each one as a csv file.
a = instructions_SNdata_magnificient(10000, dump_csv=True, filename='SN_wfsim_instructions.csv')
b = instructions_SNdata_magnificient(100000, dump_csv=True, filename='SN_wfsim_instructions_100k.csv')

In [None]:
# plt.hist(a['time']);
# plt.hist(a['e_dep']);

In [None]:
def clean_repos(pattern='*', base_dir='/dali/lgrandi/melih/sn_wfsim', assume_yes=False):
    """Delete entries matching ``pattern`` from the three data folders.

    The folders are ``instructions``, ``logs`` and ``strax_data`` below
    ``base_dir``. Matching files and directories are removed recursively.

    Parameters
    ----------
    pattern : str
        Glob pattern, relative to each of the three folders.
    base_dir : str
        Directory containing the three folders.
    assume_yes : bool
        If True, skip the interactive confirmation.

    Returns
    -------
    None
        Nothing is deleted unless the prompt is answered with 'y'
        (case-insensitive) or ``assume_yes`` is True.
    """
    import glob
    import shutil

    targets = [os.path.join(base_dir, sub, pattern)
               for sub in ('instructions', 'logs', 'strax_data')]
    if not assume_yes:
        # Show exactly what is about to be removed, including the pattern.
        listing = ''.join(f'\t{target}\n' for target in targets)
        answer = input(f'Are you sure to delete all the data?\n{listing}>>>')
        if answer.lower() != 'y':
            return
    # Expand the pattern in Python rather than passing it to a shell, so that
    # special characters in the pattern cannot be run as a command.
    for target in targets:
        for path in glob.glob(target):
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        
def see_repos():
    """Print the contents of the instructions, logs and strax_data folders.

    The logs and strax_data folders are created first if they do not exist.
    Each listing is produced by ``ls -r`` and preceded by a highlighted header.
    """
    base = '/dali/lgrandi/melih/sn_wfsim'
    # makedirs also creates missing parents and does not fail if the folder
    # appears between the check and the creation.
    for sub in ('logs', 'strax_data'):
        os.makedirs(f'{base}/{sub}/', exist_ok=True)

    sections = (('Instructions', 'instructions'),
                ('Logs', 'logs'),
                ('Existing data', 'strax_data'))
    for title, sub in sections:
        # `fg` sets the text colour; click's `color` argument is only the
        # on/off switch for ANSI output and does not accept a colour name.
        click.secho(f'\n >>{title}\n', bg='blue', fg='white')
        os.system(f'ls -r {base}/{sub}/')

In [None]:
# Deleting the existing data is left disabled; only list what is there.
# clean_repos()
see_repos()

In [None]:
# Simulation context that stores its output in the strax_data folder.
st = cutax.contexts.xenonnt_sim_SR0v0_cmt_v5(output_folder='/dali/lgrandi/melih/sn_wfsim/strax_data')
st.set_config({'fax_config_override': {'field_distortion_on': False}})
rid = 'SN_wfsimdata'
# Report, per data kind, whether it is already stored for this run id.
for kind in ['truth', 'raw_records', 'peaks']:
    if st.is_stored(rid, kind):
        status = click.style(" stored ", bold=True, bg="green")
    else:
        status = click.style(" not stored ",bold=True,bg="red")
    click.echo(f'{kind:15s} is {status}')

In [None]:
# Load the simulated data as DataFrames.
event_info = st.get_df(rid, 'event_info')
# `df` holds the same event_info data as an independent frame; copying the
# frame already in memory avoids loading it a second time.
df = event_info.copy()
truth = st.get_df(rid,'truth')
peak_basics = st.get_df(rid, 'peak_basics')

In [None]:
def quality_plot(ev):
    """Draw a 3x2 grid of diagnostic plots for the events in ``ev``.

    Panels: S1 area, S2 area, S2 width (``s2_range_50p_area``), number of
    events above an S2-area threshold, z of the events with z < 0, and a 2D
    histogram of S1 area against S2 area / 100.

    Parameters
    ----------
    ev : table with the columns ``s1_area``, ``s2_area``,
        ``s2_range_50p_area`` and ``z``
        NOTE(review): presumably the event_info DataFrame -- confirm.
    """
    step_hist = dict(bins = 200, histtype = 'step')

    plt.figure(figsize=(15,10))
    plt.subplots_adjust(wspace=0.2, hspace=0.5)

    plt.subplot(321)
    plt.title('s1 area')
    plt.hist(ev['s1_area'], **step_hist)

    plt.subplot(322)
    plt.title('s2 area')
    plt.hist(ev['s2_area'], **step_hist)

    plt.subplot(323)
    plt.title('s2 width')
    plt.hist(ev['s2_range_50p_area'], **step_hist)

    # Survival curve: how many events have an S2 area above each threshold,
    # scanned from the smallest S2 area up to 1500.
    plt.subplot(324)
    plt.title('counts above area threshold')
    thresholds = np.linspace(ev['s2_area'].min(), 1500, 50)
    n_above = [len(ev['s2_area'][ev['s2_area']>cut]) for cut in thresholds]
    plt.semilogy(thresholds, n_above)
    plt.xlabel('S2 area')

    # Depth distribution, with a dashed line at -straxen.tpc_z.
    plt.subplot(325)
    plt.title('z [cm]')
    plt.hist(ev['z'][ev['z']<0], **step_hist)
    plt.axvline(-straxen.tpc_z, ls = '--', color = 'k')

    # S1 versus S2 for events where both areas are positive.
    plt.subplot(326)
    plt.gca().set_aspect('equal')
    s1_area = ev['s1_area']
    s2_area = ev['s2_area']
    both_positive = (s1_area>0) & (s2_area>0)
    area_hist = mh.Histdd(s1_area[both_positive], s2_area[both_positive]/100,
                          norm=LogNorm(), bins=100)
    colorbar_style = dict(orientation="vertical", pad=0.05, aspect=30, fraction=0.1)
    area_hist.plot(log_scale=True, cblabel='count',
                   cmap=plt.get_cmap('jet'), alpha=0.7,
                   colorbar_kwargs=colorbar_style)
    plt.xlabel('s1 area [P.E.]')
    plt.ylabel('s2 area [100 P.E.]')
    
def plot_xy_rz(evt):
    """Show the event positions as two 2D histograms: x-y and r^2-z.

    Parameters
    ----------
    evt : table with the columns ``x``, ``y``, ``r`` and ``z``
        Bins without entries are left blank (``cmin = 1``).
    """
    plt.figure(figsize=(12,4))
    plt.subplots_adjust(wspace=0.9)
    plt.subplot(121)
    plt.hist2d(evt['x'],evt['y'], range = ((-70,70),(-70,70)),cmin = 1,bins = 100) # , norm = LogNorm()
    plt.xlabel('x [cm]')
    plt.ylabel('y [cm]')
    plt.gca().set_aspect('equal')
    plt.colorbar()
    plt.subplot(122)
    plt.hist2d(np.power(evt['r'],2),evt['z'], range = ((0,70**2),(-160,10)),cmin = 1,bins = 100) # , norm = LogNorm()
    plt.colorbar()
    # r is squared, so the unit of this axis is cm^2, not cm.
    plt.xlabel('r$^2$ [cm$^2$]')
    plt.ylabel('z [cm]')
    plt.show()

In [None]:
# Diagnostic overview plots for the event_info data.
quality_plot(event_info)

In [None]:
# Spatial distribution (x-y and r^2-z) of the event_info data.
plot_xy_rz(event_info)

In [None]:
# Show the 'Field name' column of the event_info description.
st.data_info('event_info')['Field name'].values

In [None]:
# Compare the number of rows in the generated instructions, the truth and the
# event_info data; the last element displays peak_basics itself.
len(a['z']), len(truth['z']), len(event_info['z']), len(event_info['z_naive']), peak_basics

In [None]:
# Histogram of the 'time' column of the truth data.
plt.hist(truth['time']);

---