# Used for testing different scaling schemes for the signals

In [None]:
import os
#import cWGANGP_model_def
#import pandas as pd
import tensorflow as tf
from tensorflow import keras
import numpy as np
#from sklearn.preprocessing import MinMaxScaler
from NuRadioReco.utilities import units
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import AutoMinorLocator
# Load the matplotlib style sheet from a file path relative to the working directory.
plt.style.use('plot_style.txt')

### Load data

In [None]:
DATA_DIR = '/mnt/md0/aholmberg/data'
# Row layout as used in this notebook (assumed — confirm against the file writer):
#   column 0: shower energy, column 1: viewing angle (read this way further down),
#   column 2: shower number, columns 3+: signal trace samples.
data = np.load(os.path.join(DATA_DIR, 'signal_had_14_10deg.npy'))
condition = data[:, :2]
# Column 2 is the only column between the conditions and the trace; reading
# column 3 here would just duplicate the first signal sample.
shower_n = data[:, 2]
signals = data[:, 3:]
# Filtered variant of the signals. Later cells index it with the same row
# numbers as `condition`, so the row order is presumably identical — confirm.
signals_filtered = np.load(os.path.join(DATA_DIR, 'signal_had_14_filtered_10deg.npy'))

### Define a normalisation for the signals based on the parametrisation in the Alvarez model

In [None]:
def get_time_normalized(theta, trace, n_index=1.78, R=1 * units.km):
    """Divide *trace* by the angular and distance amplitude factor of the Alvarez model.

    The factor is ``sin(theta) / sin(theta_c) * exp(-ln 2 * (theta - theta_c)**2) / R``
    with ``theta_c = arccos(1 / n_index)``. The Gaussian term here has no
    frequency-dependent angular width (nothing is divided by a cone width).

    Parameters
    ----------
    theta : float or array_like
        Viewing angle(s) in radians (compared directly with ``np.arccos`` output).
        For an array of shape (n,), *trace* should have n rows.
    trace : array_like
        Signal trace(s); the factor is broadcast along the last axis.
    n_index : float
        Refractive index used to compute the Cherenkov angle.
    R : float
        Distance normalisation.

    Returns
    -------
    numpy.ndarray
        *trace* divided by the factor.
    """
    theta_c = np.arccos(1. / n_index)
    geometry = np.sin(theta) / np.sin(theta_c)
    off_cone = np.exp(-np.log(2) * (theta - theta_c) ** 2)
    # Trailing length-1 axis lets a per-event factor broadcast over each trace's samples.
    factor = np.expand_dims(geometry * off_cone / R, axis=-1)
    return trace / factor


# See what impact the different parts of the Alvarez model have

In [None]:
# Parameters for reconstructing a single event with the Alvarez parametrisation.
N = 896  # number of time samples; also defines the grid from np.fft.rfftfreq(N, dt)
dt = 1e-10 * units.second  # sampling interval, 0.1 ns
n_index = 1.78  # refractive index (same value as the get_time_normalized default)
R = 1 * units.km  # distance normalisation (same value as the get_time_normalized default)
condition = data[:, :2]  # re-derived so this cell can be re-run on its own
index = 9290  # row of `data` / `signals_filtered` to inspect
print(condition[index,:])
energy = condition[index, 0] * units.eV  # assumes column 0 is stored in eV — TODO confirm
theta = condition[index, 1] * units.rad  # assumes column 1 is stored in radians — TODO confirm
# Energy in PeV and angle in degrees. NOTE(review): this divides the raw column by
# units.PeV, which matches `energy / units.PeV` only if units.eV == 1 — confirm.
print(condition[index,0] / units.PeV, condition[index, 1] / units.deg)

In [None]:
# Frequency grid without the DC bin, so `500 MHz / freqs` below never divides by zero.
freqs = np.fft.rfftfreq(N, dt)[1:]

cherenkov_angle = np.arccos(1. / n_index)

# Frequency-dependent angular width of the hadronic Cherenkov cone.
# epsilon = log10(E / TeV) selects which piecewise fit applies.
epsilon = np.log10(energy / units.TeV)
if 0 <= epsilon <= 2:
    dThetaHad = 500 * units.MHz / freqs * (2.07 - 0.33 * epsilon + 7.5e-2 * epsilon ** 2) * units.deg
elif 2 < epsilon <= 5:
    dThetaHad = 500 * units.MHz / freqs * (1.74 - 1.21e-2 * epsilon) * units.deg
elif 5 < epsilon <= 7:
    dThetaHad = 500 * units.MHz / freqs * (4.23 - 0.785 * epsilon + 5.5e-2 * epsilon ** 2) * units.deg
elif epsilon > 7:
    # Beyond epsilon = 7 the epsilon = 7 width is scaled linearly in epsilon.
    dThetaHad = 500 * units.MHz / freqs * (4.23 - 0.785 * 7 + 5.5e-2 * 7 ** 2) * \
        (1 + (epsilon - 7) * 0.075) * units.deg
else:
    # The fits only cover E >= 1 TeV (epsilon >= 0). Later cells divide by
    # dThetaHad, so a zero width would silently turn into inf/nan there.
    raise ValueError(
        f"energy {energy} is below 1 TeV; outside the parametrisation range "
        f"(epsilon = {epsilon:.3f})"
    )
print(cherenkov_angle/units.deg)

In [None]:
# Spectral amplitude of the parametrisation for this event, before any
# angular factors: linear in energy, with a roll-off above f0.
f0 = 1.15 * units.GHz
E = 2.53e-7 * energy / units.TeV * freqs / f0 / (1 + (freqs / f0) ** 1.44)
E *= units.V / units.m / units.MHz
plt.plot(E)
# Apply the sin(theta) / sin(theta_c) factor and plot again for comparison.
E *= np.sin(theta) / np.sin(cherenkov_angle)
plt.plot(E)

In [None]:
# Full rfft-length spectra; bin 0 (DC) stays zero because `freqs` excludes it.
tmp = np.zeros(len(freqs) + 1)
tmp3 = np.zeros(len(freqs) + 1)
# tmp: Gaussian cone factor with the frequency-dependent width dThetaHad, and 1/R.
tmp[1:] = E * np.exp(-np.log(2) * ((theta - cherenkov_angle) / dThetaHad) ** 2) / R
# tmp3: same Gaussian but without the width (the form get_time_normalized uses).
tmp3[1:] = E * np.exp(-np.log(2) * ((theta - cherenkov_angle)) ** 2) / R
plt.plot(tmp)
def missing_energy_factor(E_0):
    """Return the missing-energy factor for hadronic cascades.

    With epsilon = log10(E_0 / TeV) and x = epsilon + 3, the factor is
    ``-1.27e-2 - 4.76e-2 * x - 2.07e-3 * x**2 + 0.52 * sqrt(x)``.
    Source: DOI 10.1016/S0370-2693(98)00905-8.

    Evaluated with NumPy ufuncs, so it works element-wise on scalars or arrays.
    For E_0 below 1e-3 TeV, x is negative and the square root yields NaN.
    """
    x = np.log10(E_0 / units.TeV) + 3
    polynomial_part = -1.27e-2 - 4.76e-2 * x
    return polynomial_part + (-2.07e-3 * x ** 2 + 0.52 * np.sqrt(x))

# Only tmp gets the missing-energy factor; tmp3 is left without it for comparison.
tmp[1:] *= missing_energy_factor(energy)
plt.plot(tmp)
plt.plot(tmp3)

In [None]:
# tmp2 keeps the fully-weighted spectrum. tmp has the geometric factors divided
# back out. NOTE: this modifies tmp in place; re-run the cell that builds tmp
# before re-running this one.
tmp2 = tmp.copy()
tmp[1:] /= np.sin(theta) / np.sin(cherenkov_angle)
# Divide by exactly the Gaussian cone factor that was applied to tmp: the width
# dThetaHad scales the angular offset *before* squaring.
tmp[1:] /= np.exp(-np.log(2) * ((theta - cherenkov_angle) / dThetaHad) ** 2) / R

tmp2 *= 0.5
tmp *= 0.5  # the factor 0.5 is introduced to compensate the unusual fourier transform normalization used in the ZHS code

# Multiply by exp(i*pi/2) to give every frequency a 90 deg phase, go to the time
# domain, then roll by half the trace length (moves sample 0 to the centre).
trace = np.fft.irfft(tmp * np.exp(0.5j * np.pi)) / dt  # set phases to 90deg
trace = np.roll(trace, len(trace) // 2)

trace2 = np.fft.irfft(tmp2 * np.exp(0.5j * np.pi)) / dt  # set phases to 90deg
trace2 = np.roll(trace2, len(trace2) // 2)

# Stored filtered signal for the same event, copied so later in-place edits
# do not touch signals_filtered.
arztrace = signals_filtered[index,:].copy()

In [None]:
# Apply the width-free normalisation (same form as get_time_normalized) to trace2
# and to the stored signal. In place: re-running this cell divides again, so
# re-run the previous cell first.
trace2 /= np.sin(theta) / np.sin(cherenkov_angle)
trace2 /= np.exp(-np.log(2) * ((theta - cherenkov_angle)) ** 2) / R # actual norm  / dThetaHad
arztrace /= np.sin(theta) / np.sin(cherenkov_angle)
arztrace /= np.exp(-np.log(2) * ((theta - cherenkov_angle)) ** 2) / R # actual norm 

In [None]:
# Width-free scale factor for this event (the quantity get_time_normalized divides
# by), shown as the cell output.
factor = np.sin(theta) / np.sin(cherenkov_angle) * np.exp(-np.log(2) * ((theta - cherenkov_angle)) ** 2) / R
factor

In [None]:
# Overlay: exact un-normalisation, width-free normalisation, and the stored signal.
plt.plot(trace)
plt.plot(trace2)
plt.plot(arztrace)

In [None]:
# Rescale all three traces by 10**(17 - log10(energy / units.eV)) before plotting.
# The factor is the same for every trace, so it is computed once.
energy_rescale = np.power(10, (17 - np.log10(energy/units.eV)))
for curve in (trace, trace2, arztrace):
    plt.plot(curve * energy_rescale)

# Compare normalisation techniques

In [None]:
# Subsample: every 10th event, at most n events in total.
n = 1024*32
stride = 10
test_signals = signals_filtered[0:n*stride:stride, :]
test_condition = condition[0:n*stride:stride, :]

# Scheme 1 (energy only): multiply each event by 1e19 / column 0.
energy_weight = np.expand_dims(1e19/test_condition[:, 0], axis=-1)
test_signals_escale = test_signals*energy_weight

# Scheme 2 (angle only): divide out the width-free angular/distance factor.
test_signals_anglescale = get_time_normalized(test_condition[:, 1], test_signals)

# Scheme 3 (active): energy scaling times the empirical angular weight
# (delta_deg**4 + 1) / 6, with delta_deg the offset from the Cherenkov angle in degrees.
delta_deg = test_condition[:, 1]/units.deg - cherenkov_angle/units.deg
angle_weight = np.expand_dims((delta_deg**4 + 1)/6, axis=-1)
test_signals_scaled = test_signals_escale * angle_weight

# Alternative schemes that were tried (kept for reference):
#test_signals_scaled = test_signals*np.expand_dims(1e19/test_condition[:,0], axis=-1) * np.expand_dims((((test_condition[:, 1]/units.deg - cherenkov_angle/units.deg))**4 + 1)/3, axis=-1)
#test_signals_scaled = test_signals_anglescale * np.expand_dims(1e19/test_condition[:,0], axis=-1)
#test_signals_scaled = test_signals_anglescale * np.expand_dims(np.power(10, (16 - np.log10(test_condition[:, 0]))), axis=-1)
#test_signals_scaled = test_signals_anglescale * np.expand_dims(np.power(10, 1.6*np.tanh(16.5 - np.log10(test_condition[:, 0]))), axis=-1)
#test_signals_scaled = test_signals_anglescale * np.expand_dims(np.power(10, 0.001*(17 - np.log10(test_condition[:, 0]))**3), axis=-1)

### Make a histogram of the peak amplitudes

In [None]:
# Peak absolute amplitude per event under each scaling scheme.
# Named max_raw (not `max`) so the builtin max() is not shadowed for the rest
# of the notebook session.
max_raw = np.max(np.abs(test_signals), axis=1)
max_escale = np.max(np.abs(test_signals_escale), axis=1)
max_ascale = np.max(np.abs(test_signals_anglescale), axis=1)
max_scaled = np.max(np.abs(test_signals_scaled), axis=1)

# Histogram of the combined-scheme peaks with a logarithmic count axis.
g = sns.histplot(max_scaled, bins=40, log_scale=(False, True))
g.axes.xaxis.set_minor_locator(AutoMinorLocator(5))
g.axes.set_xlim(0, 5.3)
g.axes.set_xlabel(r'max(abs(signal)) [V/m]')
g.figure.savefig('thesis/Exjobb-rapport/figures/scaled-dist.pdf', dpi=300)

# Disabled diagnostics, kept as comments. A bare string literal here would be
# the cell's last expression and get echoed as output.
# arg = np.argmax(max_raw)
# print(test_condition[arg,0], test_condition[arg,1]/units.deg, arg)
# x = range(len(max_raw))
# fig, ax = plt.subplots(2,2,figsize=(10,10))
# ax[0, 0].scatter(x, max_raw)
# ax[0, 1].scatter(x, max_escale)
# ax[1, 0].scatter(x, max_ascale)
# ax[1, 1].scatter(x, max_scaled)
#ax[1, 0].set_ylim(0,2)
#ax[0, 1].set_ylim(0,2)
#ax[1, 1].set_ylim(0,2)