Fax Tutorial
=============

Hello friend.
Welcome to the basic tutorial on how to simulate waveforms with the latest fax version in strax.
Here we'll just demonstrate the basic functionality. For more in-depth analysis, check out the straxen tutorials.

In [None]:
import numpy as np
import strax
import straxen
import wfsim

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from multihist import Histdd, Hist1d
from scipy import stats

Setting everything up
=================

First we need to define the right context. What differs from a normal context is the plugin that provides the raw records. By default this is the DAQ reader. We do not want that here, so instead we register wfsim.RawRecordsFromFax, which, as the name suggests, tells strax to get its raw records from fax.

In [None]:
url_base = 'https://raw.githubusercontent.com/XENONnT/strax_auxiliary_files/master/fax_files/'

In [None]:
st = strax.Context(
    register=wfsim.RawRecordsFromFax,
    config=dict(detector="XENONnT",
                fax_config=url_base + 'fax_config_nt.json'),
    **straxen.contexts.common_opts)

If you want to use the XENONnT configuration, you'll need to make a small modification to straxen. Straxen has the number of TPC channels (and of top-array PMTs) hardcoded, and you'll need to overwrite these with the correct nT numbers.
To do so, run:

In [None]:
straxen.plugins.pulse_processing.n_tpc = 494
straxen.plugins.PeakPositions.n_top_pmts = 253

Additionally, you'll want strax to use the right PMT gains and an nT neural net for position reconstruction.
For this, do:

In [None]:
st.set_config(dict(to_pe_file=url_base + 'to_pe_nt.npy',
                   nn_architecture=url_base + 'mlp_model.json',
                   nn_weights=url_base + 'mlp_model.h5'))

Now we need to define a run id. What you choose doesn't really matter: strax looks for existing data for that run id and, if it doesn't find any, makes new data, which is exactly what you want here.
Strax uses the run id to get the electron lifetime and PMT gains from a database, and returns placeholders if the run doesn't exist. (Currently the electron lifetime doesn't return a placeholder; this should be fixed.)

In [None]:
run_id = '1'

Strax has a built-in timeout that we need to modify. Simulating will probably take longer than the timeout allows, so we need to increase it:

In [None]:
strax.Mailbox.DEFAULT_TIMEOUT=1000

Defining instructions
===============

One last detail before we can start: the instructions for fax. You have two options. Either read in an MC output file and let a very basic nestpy-based conversion turn it into instructions, or have the instructions generated randomly.
First I'll show how to read from a file:

In [None]:
file = '/dali/lgrandi/pgaemers/fax_files/Xenon1T_WholeLXe_Pb212_00008_g4mc_G4.root'
st.set_config(dict(fax_file=file))

The alternative is to let fax generate random instructions for you. This calls the function strax_interface.rand_instructions, in case you want to change it up a little.
We need to give fax three parameters. nchunk tells strax over how many chunks (files) to spread the data; currently it is highly advised to set this to 1 to avoid crashes. event_rate determines how many events per second to make, so it sets, approximately, the spacing between events. Finally, chunk_size defines the length of a chunk in seconds.
The total number of events generated is the product of all three numbers.

In [None]:
st.set_config(dict(fax_file=None))
st.set_config(dict(nchunk=2, event_rate=1, chunk_size=50))
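As a quick worked example of how these three numbers combine, the sketch below plugs in the values from the cell above (nchunk=2, event_rate=1, chunk_size=50). It mirrors the nevents and total_time formulas used in the custom instruction function further down; the variable names here are just for illustration.

```python
# Values from the set_config call above
nchunk, event_rate, chunk_size = 2, 1, 50   # chunks, events per second, seconds per chunk

total_time = nchunk * chunk_size             # simulated time span in seconds
n_events = event_rate * chunk_size * nchunk  # total number of events
mean_spacing = 1 / event_rate                # approximate seconds between events

print(f"{n_events} events over {total_time} s, about {mean_spacing} s apart")
```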

What you can also do is define your own instruction-generating function and use it to override the default one.
You can do this as follows:

In [None]:
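def super_awesome_custom_instruction(c):
    # Number of events and total simulated time (s), both derived from the config
    n = c['nevents'] = c['event_rate'] * c['chunk_size'] * c['nchunk']
    c['total_time'] = c['chunk_size'] * c['nchunk']

    # Two instructions per event: an S1 (type 1) followed by an S2 (type 2)
    instructions = np.zeros(2 * n, dtype=wfsim.strax_interface.instruction_dtype)

    # Evenly spaced event times in seconds, converted to ns; S1 and S2 share the time
    uniform_times = c['total_time'] * (np.arange(n) + 0.5) / n
    instructions['time'] = np.repeat(uniform_times, 2) * int(1e9)

    # Index of the chunk that each instruction falls in
    instructions['event_number'] = np.digitize(
        instructions['time'], 1e9 * np.arange(c['nchunk']) * c['chunk_size']) - 1
    instructions['type'] = np.tile([1, 2], n)
    instructions['recoil'] = ['er' for i in range(n * 2)]

    # Positions: uniform in area inside r = 50, uniform in z between -100 and 0
    r = np.sqrt(np.random.uniform(0, 2500, n))
    t = np.random.uniform(-np.pi, np.pi, n)
    instructions['x'] = np.repeat(r * np.cos(t), 2)
    instructions['y'] = np.repeat(r * np.sin(t), 2)
    instructions['z'] = np.repeat(np.random.uniform(-100, 0, n), 2)

    # Signal sizes: 2000-2050 S1 photons, log-uniform 10 to 10^4 S2 electrons
    nphotons = np.random.uniform(2000, 2050, n)
    nelectrons = 10 ** (np.random.uniform(1, 4, n))
    # Interleave per event: [S1_0, S2_0, S1_1, S2_1, ...], matching 'type'
    instructions['amp'] = np.vstack([nphotons, nelectrons]).T.flatten().astype(int)

    return instructions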

The instruction function gets called with the config, from which it takes the event rate, chunk size and number of chunks to figure out how many events you want.
The things you are (probably?) most interested in playing with are the positions and signal sizes.
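If you change the positions or signal sizes, keep the pairing of rows intact: every event is two consecutive instructions, an S1 and then an S2. The toy sketch below shows the numpy layout trick the custom function relies on, using made-up numbers and no wfsim, so it can be run on its own.

```python
import numpy as np

# Toy per-event values for three events (illustrative only)
nphotons = np.array([2000, 2010, 2040])
nelectrons = np.array([10, 100, 1000])
z_event = np.array([-10.0, -50.0, -90.0])

# Rows (photons, electrons) -> transpose to one row per event -> flatten:
# the result alternates the S1 and S2 amplitude of each event.
amp = np.vstack([nphotons, nelectrons]).T.flatten()

# Per-event quantities (like z) are repeated once for the S1 and once for the S2,
# and the type pattern [1, 2] is tiled once per event.
z = np.repeat(z_event, 2)
types = np.tile([1, 2], len(nphotons))

print(amp)
print(types)
print(z)
```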

Finally, evaluate the following line, which tells the simulator to execute your custom function rather than the predefined one:

In [None]:
wfsim.strax_interface.rand_instructions = super_awesome_custom_instruction
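Overriding a function this way works because Python looks up module-level names when a function is called, not when it is defined. The flip side is that the assigned name has to match exactly the function the plugin calls (the text above names it strax_interface.rand_instructions). The sketch below demonstrates this with a hypothetical stand-in module; it assumes, without checking wfsim's source, that the plugin calls the function through the module's global name.

```python
import textwrap
import types

# Hypothetical stand-in for a module such as wfsim.strax_interface
source = textwrap.dedent("""
    def rand_instructions(c):
        return "default"

    def make_instructions(c):
        # The global name is resolved each time this function runs
        return rand_instructions(c)
""")
interface = types.ModuleType("interface")
exec(source, interface.__dict__)


def custom(c):
    return "custom"


print(interface.make_instructions({}))  # default

interface.rand_instruction = custom      # misspelled: adds a new, unused attribute
print(interface.make_instructions({}))  # still default

interface.rand_instructions = custom     # exact name: replaces the default
print(interface.make_instructions({}))  # custom
```

Note that this would not work if a plugin had imported the function directly into its own namespace (from ... import rand_instructions), since it would then hold its own reference to the original.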

What actually happens?
===


Behind the scenes, the instructions are first grouped into chunks. We then loop over the instructions chunk by chunk, and each full chunk is returned before work on the next one starts.
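The grouping step can be pictured with the toy generator below. It is a sketch, not wfsim's actual code: it assumes chunk k covers the time window from k * chunk_size to (k + 1) * chunk_size seconds and uses np.digitize the same way the custom instruction function does. The name iter_chunks is hypothetical.

```python
import numpy as np


def iter_chunks(times_ns, chunk_size_s, nchunk):
    """Yield (chunk index, instruction indices) for each chunk, in time order.

    Chunk k covers [k * chunk_size, (k + 1) * chunk_size) seconds, and a chunk
    is yielded completely before the next one is started.
    """
    edges = np.arange(nchunk + 1) * chunk_size_s * int(1e9)
    chunk_id = np.digitize(times_ns, edges) - 1
    for k in range(nchunk):
        yield k, np.flatnonzero(chunk_id == k)


times = np.array([0.5, 1.5, 50.5, 99.5]) * 1e9  # instruction times in ns
for k, idx in iter_chunks(times, chunk_size_s=50, nchunk=2):
    print(k, idx)
```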

We use an S1 and an S2 class to calculate the arrival times of the photons and the channels that have been hit. Then we hand these over to the Pulse class to calculate the currents in the channels. Finally, the currents go to a RawData class, where we fake the digitizer response.

S1
==
For S1s, we start by calculating the light yield based on the position of the interaction, and draw the number of detected photons from a Poisson distribution.
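A minimal sketch of this step, assuming a made-up light-yield map: the hypothetical toy_light_yield below returns a detection probability per emitted photon that grows towards the bottom of the TPC, whereas the real simulation interpolates a map from its configuration files. The expected count is the emitted number times the light yield, and the detected count is a Poisson draw around it.

```python
import numpy as np

rng = np.random.default_rng(seed=42)


def toy_light_yield(x, y, z):
    """Hypothetical light-yield map (detection probability per photon).

    Ignores x and y and rises linearly with depth; illustrative only.
    """
    return 0.03 + 0.02 * np.clip(-z / 100.0, 0.0, 1.0)


# Positions and emitted photon counts for a few toy S1s
x = np.array([0.0, 10.0, -20.0])
y = np.array([0.0, 5.0, 15.0])
z = np.array([-10.0, -50.0, -90.0])
n_emitted = np.array([2000, 2020, 2040])

# Expected number of detected photons, then a Poisson draw around it
expected = n_emitted * toy_light_yield(x, y, z)
n_detected = rng.poisson(expected)
print(expected, n_detected)
```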

Second, we calculate the arrival times of the photons. These are based on the scintillation of the xenon atoms and depend on the recombination time and on the singlet and triplet fractions.
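The singlet/triplet part can be sketched as a two-component exponential: each photon comes from a singlet with some probability and otherwise from a triplet, and its delay is exponential with that state's lifetime. This toy version ignores the recombination delay, and the lifetimes and singlet fraction are illustrative values, not wfsim's configuration.

```python
import numpy as np

rng = np.random.default_rng(seed=1)


def toy_s1_photon_times(n_photons, singlet_fraction, tau_singlet_ns, tau_triplet_ns):
    """Draw photon emission delays (ns) from a singlet/triplet exponential mixture."""
    is_singlet = rng.random(n_photons) < singlet_fraction
    tau = np.where(is_singlet, tau_singlet_ns, tau_triplet_ns)
    # One exponential draw per photon, each with its own lifetime
    return rng.exponential(tau)


times = toy_s1_photon_times(10_000, singlet_fraction=0.2,
                            tau_singlet_ns=3.1, tau_triplet_ns=24.0)
print(times.mean())  # expected mean: 0.2 * 3.1 + 0.8 * 24.0 = 19.82 ns
```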

Finally, the channels are calculated. From the pattern map we use an interpolating map to get, for an S1 at the position of the interaction, a probability distribution over the channels to be hit, and then we draw from this distribution for every photon.
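Drawing one channel per photon from such a distribution looks roughly like this. The raw pattern values below are hypothetical stand-ins for what an interpolated pattern map might return at one position; they are normalized so they can be used as probabilities.

```python
import numpy as np

rng = np.random.default_rng(seed=7)

# Hypothetical relative hit probabilities for 5 channels at one position
raw_pattern = np.array([8.0, 6.0, 3.0, 2.0, 1.0])
pattern = raw_pattern / raw_pattern.sum()   # normalize to a probability distribution
n_channels = len(pattern)

# Draw a channel for every detected photon
n_photons = 1000
channels = rng.choice(n_channels, size=n_photons, p=pattern)

# Photons per channel
hits_per_channel = np.bincount(channels, minlength=n_channels)
print(hits_per_channel)
```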

S2
===
S2s are slightly more complicated. First we need to drift the electrons up, and while doing so we lose some of them.
To get the photon timings, we first need the arrival times of the electrons at the gas interface, based on a diffusion model. Then we can calculate the photon timings based on a luminescence model for every individual electron. For the channels we use the same trick with the interpolating map.
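The three S2 steps can be sketched as a toy chain. It assumes simple textbook models rather than wfsim's: binomial electron survival with probability exp(-t_drift / lifetime), Gaussian longitudinal diffusion with sigma_t = sqrt(2 * D * t_drift) / v_drift, and a Poisson number of photons per electron spread uniformly over its crossing of the gas gap. Every default value is illustrative, and toy_s2_photon_times is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(seed=3)


def toy_s2_photon_times(n_electrons, z_cm, v_drift_cm_per_us=0.13, lifetime_us=600.0,
                        diffusion_cm2_per_us=4.5e-5, photons_per_electron=25,
                        gas_gap_time_us=0.5):
    """Toy S2 chain for one interaction; returns photon times in microseconds."""
    # 1. Drift: electrons are lost to impurities along the way
    t_drift = -z_cm / v_drift_cm_per_us          # z < 0 below the liquid surface
    n_survived = rng.binomial(n_electrons, np.exp(-t_drift / lifetime_us))

    # 2. Diffusion: arrival times at the gas interface are smeared out
    sigma_t = np.sqrt(2 * diffusion_cm2_per_us * t_drift) / v_drift_cm_per_us
    t_electrons = rng.normal(t_drift, sigma_t, size=n_survived)

    # 3. Luminescence: photons per electron, spread over the gas-gap crossing
    n_photons = rng.poisson(photons_per_electron, size=n_survived)
    t_emit = np.repeat(t_electrons, n_photons)   # one entry per photon
    return t_emit + rng.uniform(0.0, gas_gap_time_us, size=t_emit.size)


shallow = toy_s2_photon_times(n_electrons=200, z_cm=-10.0)
deep = toy_s2_photon_times(n_electrons=200, z_cm=-90.0)
print(shallow.size, deep.size)  # fewer photons for the deeper event: more electrons lost
```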


Pulse
===
When we have our lists of channels and timings, we can generate actual pulses. First we add a PMT transit time. Then we loop over all channels, calculate the double photoelectron emission probabilities, and add a current in the PMT channel based on the arrival time. This is all stored in a big dictionary, which is afterwards passed to our fake digitizers, which then return your very own pretty data.
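The pulse step can be pictured with the toy function below, which is not wfsim's Pulse class. It assumes a Gaussian-smeared transit time, a fixed probability that a photon yields two photoelectrons instead of one, and a short made-up single-PE template that is added into a per-channel current array at the photon's arrival sample. The result is a dictionary keyed by channel, as described above. The function name, template and all parameter values are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(seed=11)


def toy_channel_currents(photon_times_ns, photon_channels, n_channels, n_samples,
                         dt_ns=10.0, transit_time_ns=30.0, transit_spread_ns=3.0,
                         p_double_pe=0.2, template=(0.2, 1.0, 0.6, 0.3, 0.1)):
    """Return {channel: current array} for a toy PMT response."""
    template = np.asarray(template, dtype=float)
    # PMT transit time (with a Gaussian spread) delays every photon
    t = photon_times_ns + rng.normal(transit_time_ns, transit_spread_ns,
                                     size=len(photon_times_ns))
    currents = {}
    for ch in range(n_channels):
        in_channel = photon_channels == ch
        if not in_channel.any():
            continue
        # Each photon gives 1 PE, or 2 PE with probability p_double_pe
        n_pe = 1 + (rng.random(in_channel.sum()) < p_double_pe)
        current = np.zeros(n_samples)
        for t_i, pe in zip(t[in_channel], n_pe):
            start = int(t_i // dt_ns)
            if 0 <= start < n_samples:
                stop = min(start + len(template), n_samples)
                current[start:stop] += pe * template[: stop - start]
        currents[ch] = current
    return currents


times = np.array([0.0, 5.0, 100.0, 400.0])  # photon arrival times in ns
chans = np.array([0, 0, 1, 1])
currents = toy_channel_currents(times, chans, n_channels=2, n_samples=100)
print({ch: c.sum() for ch, c in currents.items()})
```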


Getting down to business
---

Now we have access to all the normal strax data types, plus another one called 'truth', which holds the simulation instructions. Loading it follows the normal strax convention.

In [None]:
# Remove any previously simulated data, if such exists
# !rm -r strax_data

records = st.get_array(run_id,'records')
peaks = st.get_array(run_id, ['peaks','peak_classification'])
data = st.get_df(run_id, 'event_info')

truth = st.get_df(run_id, 'truth')

Now it is time to make pretty plots and see whether what we made actually makes any sense.

In [None]:
r = records[records['channel']==10]
ns = np.arange(len(records['data'][0]))
# drawstyle (not linestyle) sets the step style in current matplotlib versions
plt.plot(ns, r['data'][:10].mean(axis=0), drawstyle='steps-mid')
plt.fill_between(
    ns,
    np.percentile(r['data'], 25, axis=0),
    np.percentile(r['data'], 75, axis=0),
    step='mid', alpha=0.3, linewidth=0)
plt.xlabel("Sample in record")
plt.ylabel("Amplitude (ADC counts)")
plt.title('Average amplitude in a channel')
plt.show()

In [None]:
plt.plot(peaks[peaks['type']==1]['data'][:10].T)
plt.xlabel("Sample in peak")
plt.ylabel("Amplitude (PE/sample)")
plt.title("Some S1s")
plt.show()

In [None]:
plt.plot(peaks[peaks['type']==2]['data'][:10].T)
plt.xlabel("Sample in peak")
plt.ylabel("Amplitude (PE/sample)")
plt.title("Some S2s")
plt.show()

(Since the instructions are generated randomly, I do not know what your results look like, so I'll make some assumptions.)

Wow, they really look amazing, right? For further analysis, we'd like to check whether the reconstructed peaks more or less match the instructions.

In [None]:
# Each event has two truth rows (an S1 and an S2 instruction), hence the integer division
print(f"The number of found events is {len(data)}, while the number of events in the instructions was {len(truth) // 2}")

In [None]:
truth = st.get_df(run_id, 'truth')
data = st.get_df(run_id, 'event_info')

# Keep S1 instructions that produced photons; sort_values returns a new frame,
# which avoids modifying a filtered view in place
truth = truth[(truth['type'] == 1) & (truth['n_photon'] > 0)].sort_values(by='t_first_photon')
timing_grid = truth['t_first_photon']

# Map each truth S1 and each event to the interval of the first-photon grid it falls in
truth['merge_index'] = np.digitize(truth['time'], timing_grid)
data['merge_index'] = np.digitize(data['time'], timing_grid)

truth = truth.drop(columns='event_number')
data = data.merge(truth, how='outer', on='merge_index')
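The matching above relies on np.digitize: it returns, for each time, the index i such that timing_grid[i-1] <= time < timing_grid[i]. A truth row and an event therefore get the same merge_index when both of their times fall between the previous S1's first photon and their own. The toy example below shows this with made-up numbers; it assumes each event starts after the previous S1's first photon and before its own, and uses suffixes so the two 'time' columns stay distinguishable.

```python
import numpy as np
import pandas as pd

# Toy truth (S1 instructions) and reconstructed events; times in ns
truth = pd.DataFrame({'time': [900, 4_900, 8_900],
                      't_first_photon': [1_000, 5_000, 9_000],
                      'n_photon': [40, 55, 70]})
data = pd.DataFrame({'time': [950, 4_950, 8_950],
                     's1_area': [38.0, 57.5, 69.0]})

truth = truth.sort_values(by='t_first_photon')   # digitize needs monotonic bins
timing_grid = truth['t_first_photon'].to_numpy()

truth['merge_index'] = np.digitize(truth['time'], timing_grid)
data['merge_index'] = np.digitize(data['time'], timing_grid)

merged = data.merge(truth, how='outer', on='merge_index', suffixes=('', '_truth'))
print(merged[['merge_index', 'time', 'time_truth', 's1_area', 'n_photon']])
```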

In [None]:
plt.rcParams['figure.figsize'] = (10, 5)
fig = plt.figure()

ax = fig.add_subplot(121)
mh = Histdd(data.s1_area, data.n_photon,
            bins=[np.logspace(0, 2.2, 101), np.logspace(0, 2.2, 101)])
plt.pcolormesh(mh.bin_edges[0], mh.bin_edges[1], mh.histogram.T, norm=LogNorm())
plt.xscale('log'); plt.yscale('log')
plt.xlabel('S1 (reconstructed)')
plt.ylabel('S1 (truth)')

ax = fig.add_subplot(122)
mh = Histdd(data.n_photon, (data.s1_area-data.n_photon)/data.n_photon, 
            bins=[np.logspace(0, 2.2, 101), np.linspace(-0.5, 0.5, 101)])
plt.pcolormesh(mh.bin_edges[0], mh.bin_edges[1], mh.histogram.T, norm=LogNorm())
plt.xscale('log')
plt.xlabel('S1 (truth)')
plt.ylabel('Bias')

plt.show()

In [None]:
truth = st.get_df(run_id, 'truth')
data = st.get_df(run_id, 'event_info')

# Keep S2 instructions that produced photons; sort_values returns a new frame
truth = truth[(truth['type'] == 2) & (truth['n_photon'] > 0)].sort_values(by='t_first_photon')
timing_grid = truth['t_first_photon']

# Map each truth S2 and each event to the interval of the first-photon grid it falls in
truth['merge_index'] = np.digitize(truth['time'], timing_grid)
data['merge_index'] = np.digitize(data['time'], timing_grid)
truth = truth.drop(columns='event_number')
data = data.merge(truth, how='outer', on='merge_index')

In [None]:
plt.rcParams['figure.figsize'] = (10, 5)
fig = plt.figure()

ax = fig.add_subplot(121)
mh = Histdd(data.s2_area, data.n_photon,
            bins=[np.logspace(2, 4.5, 121), np.logspace(2, 4.5, 121)])
plt.pcolormesh(mh.bin_edges[0], mh.bin_edges[1], mh.histogram.T, norm=LogNorm())
plt.xscale('log'); plt.yscale('log')
plt.xlabel('S2 (reconstructed)')
plt.ylabel('S2 (truth)')

ax = fig.add_subplot(122)
mh = Histdd(data.n_photon, (data.s2_area-data.n_photon)/data.n_photon, 
            bins=[np.logspace(2, 4.5, 121), np.linspace(-0.5, 0.5, 121)])
plt.pcolormesh(mh.bin_edges[0], mh.bin_edges[1], mh.histogram.T, norm=LogNorm())
plt.xscale('log')
plt.xlabel('S2 (truth)')
plt.ylabel('Bias')

plt.tight_layout()
plt.show()