In [None]:
import pandas as pd
import glob
import seaborn as sns
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import itertools
import warnings
warnings.filterwarnings('ignore')
from matplotlib.colors import LogNorm


# Directory holding the JSON result files. Keep the trailing '/': the path is
# joined to the glob pattern by plain string concatenation further down.
files_path = '../data/for_analysis/v2/'


def load_data(path):
    """Load one analysis result file and return the decoded object.

    The result files are double-encoded: the top-level JSON value is itself a
    JSON string, so the payload is decoded a second time. Files that already
    contain the object directly are returned as-is instead of raising.

    Args:
        path: path of the ``.json`` result file.

    Returns:
        The decoded result object (a dict for the files used in this notebook).
    """
    # JSON text is UTF-8 by specification; do not depend on the locale encoding.
    with open(path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)
    # Inner layer of the double encoding.
    if isinstance(payload, (str, bytes, bytearray)):
        return json.loads(payload)
    return payload
    

def name_change(name):
    """Parse particle type and nominal energy out of an image path.

    The tags are tried in order ('ER' first, then 'He'). For the first tag
    found in ``name``, the text after its LAST occurrence is split on '_' and
    the second field is taken as the energy string.

    Returns:
        ``[tag, energy_str]`` or ``None`` when neither tag occurs in ``name``.
    """
    for tag in ('ER', 'He'):
        pieces = name.split(tag)
        if len(pieces) > 1:
            tail = pieces[-1]
            return [tag, tail.split('_')[1]]
    return None
    
    
def _repeat_per_filter(values, n_filters):
    """Repeat every element of ``values`` ``n_filters`` times, preserving order."""
    return list(itertools.chain.from_iterable(
        itertools.repeat(value, n_filters) for value in values))


def _filter_table(input_dict, n_filters):
    """Skeleton table with one row per (image, filter) pair.

    Per-image metadata is repeated ``n_filters`` times, so row ``i`` belongs to
    image ``i // n_filters``. ``Filter_name`` / ``Filter_parameter`` are used
    as flat columns of the same length (presumably ordered image-major,
    filter-minor — TODO confirm against the producer of the JSON files).
    """
    return pd.DataFrame({'path': _repeat_per_filter(input_dict['Image_path'], n_filters),
                         'image': _repeat_per_filter(input_dict['Image_index'], n_filters),
                         'filter': input_dict['Filter_name'],
                         'parameter': input_dict['Filter_parameter']})


def _finalize_table(table):
    """Shared post-processing for both loaders.

    Replaces ``path`` by ``particle`` / ``energy`` columns (via ``name_change``)
    placed first. It then reduces each ``parameter`` entry to its first
    element, maps the string ``'n'`` to 0, and casts ``energy`` and
    ``parameter`` to int.
    """
    particle_energy = pd.DataFrame.from_records(table['path'].apply(name_change).tolist(),
                                                columns=['particle', 'energy'])
    table = particle_energy.join(table.drop(columns='path'))
    table['parameter'] = table['parameter'].apply(lambda value: value[0])
    # Builtin int: the ``np.int`` alias was removed in NumPy 1.24.
    table['energy'] = table['energy'].astype(int)
    # .loc instead of chained assignment, which is a silent no-op under pandas
    # copy-on-write and would leave 'n' in place for the int cast below.
    table.loc[table['parameter'] == 'n', 'parameter'] = 0
    table['parameter'] = table['parameter'].astype(int)
    return table


def get_f1_data_from_raw(input_dict):
    """Tidy table with the best F1-score of every (image, filter) pair.

    ``input_dict['ROC']['full']`` is indexed ``[image, 0|1, k, filter]``. Slot
    0 is treated as recall and slot 1 as precision (same convention as
    ``get_data_from_raw``). Axis ``k`` is presumably the threshold scan.
    F1 is maximised over that axis.

    Returns:
        DataFrame with columns particle, energy, image, filter, parameter, f1.
    """
    roc = np.array(input_dict['ROC']['full'])
    recall, precision = roc[:, 0, :, :], roc[:, 1, :, :]
    # Where recall + precision == 0 the ratio is 0/0 = NaN; nan_to_num maps it to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        f1 = (2 * recall * precision) / (recall + precision)
    f1 = np.nan_to_num(f1)
    table = _filter_table(input_dict, recall.shape[-1])
    # (image, filter) maxima raveled image-major / filter-minor, matching the rows.
    table['f1'] = f1.max(axis=1).ravel()
    return _finalize_table(table)


def get_data_from_raw(input_dict):
    """Tidy table holding the full recall / precision / energy curves.

    Each row is one (image, filter) pair. ``recall``, ``precision`` and
    ``energy_threshold`` hold lists over axis 2 of the ROC array, using the
    same ROC convention as ``get_f1_data_from_raw``.
    ``Energy['image_after_threshold']`` is assumed to share the ROC's
    (image, k, filter) layout — TODO confirm.

    Returns:
        DataFrame with columns particle, energy, image, filter, parameter,
        recall, precision, energy_threshold.
    """
    roc = np.array(input_dict['ROC']['full'])
    recall, precision = roc[:, 0, :, :], roc[:, 1, :, :]
    energy = np.array(input_dict['Energy']['image_after_threshold'])
    table = _filter_table(input_dict, recall.shape[2])
    # hstack(...).T turns an (image, k, filter) stack into (image*filter, k):
    # one row per (image, filter), image-major, holding the whole curve.
    table['recall'] = np.hstack(recall).T.tolist()
    table['precision'] = np.hstack(precision).T.tolist()
    table['energy_threshold'] = np.hstack(energy).T.tolist()
    return _finalize_table(table)

def fill_nan_nn(arr):
    """Forward-fill NaNs along each row of a 2-D array.

    Every NaN takes the nearest non-NaN value to its left in the same row.
    NaNs before a row's first valid value stay NaN, because they are "filled"
    from column 0, which is itself NaN. A new array is returned; ``arr`` is
    not modified.
    """
    valid = ~np.isnan(arr)
    # Column index where the entry is valid, 0 elsewhere; a running maximum
    # along each row then points at the most recent valid column.
    column_ids = np.arange(valid.shape[1])
    source_col = np.maximum.accumulate(np.where(valid, column_ids, 0), axis=1)
    row_ids = np.arange(source_col.shape[0]).reshape(-1, 1)
    return arr[row_ids, source_col]

## Carregando resultados

In [None]:
# glob() returns matches in arbitrary, filesystem-dependent order. Sort so every
# run concatenates the result files (and therefore tables and plots) the same way.
files = sorted(glob.glob(files_path + '*.json'))

In [None]:
# Best-F1 table of every result file, stacked (the index restarts for each file).
result_table = pd.concat([get_f1_data_from_raw(load_data(json_path)) for json_path in files])

In [None]:
# Preview a few rows. head(-2) would render every row except the last two.
result_table.head()

## Análise dos resultados
  * Desempenho dos filtros para cada tipo de partícula e valor de energia
  * Reconstrução da curva energia x integral dos clusters

### Análise por valor de Energia

In [None]:
# Distribution over images of each image's best F1 (max over a filter's
# parameter values), per energy, filter and particle type.
best_f1_per_image = (result_table
                     .groupby(['particle', 'filter', 'energy', 'image'])
                     .agg({'f1': 'max'})
                     .reset_index())
particle_titles = {'ER': 'Electron recoil', 'He': 'Nuclear recoil'}
# Pin the facet order so each title lands on the right panel, instead of
# depending on the order in which particles first appear in the data.
g = sns.catplot(x="energy", y="f1", hue="filter", col="particle",
                col_order=list(particle_titles), data=best_f1_per_image,
                kind="box", height=12, aspect=1)
axes = g.axes.ravel()
axes[0].set_ylabel('f1-score', fontsize=18)
for facet_ax, particle in zip(axes, particle_titles):
    facet_ax.grid()
    facet_ax.set_xlabel('Energy', fontsize=18)
    facet_ax.tick_params(axis='both', which='major', labelsize=18)
    facet_ax.set_ylim([0, 1])
    facet_ax.set_title(particle_titles[particle], fontsize=18)

In [None]:
# Median over images of each image's best F1 (max over a filter's parameters).
only_median = (result_table
               .groupby(['particle', 'filter', 'energy', 'image'])
               .agg({'f1': 'max'})
               .groupby(['particle', 'filter', 'energy'])
               .agg('median')
               .reset_index())

fig, ax = plt.subplots(1, 2, figsize=(20, 10))
panels = (('ER', 'Electron recoil'), ('He', 'Nuclear recoil'))
for panel_ax, (particle, title) in zip(ax, panels):
    sns.lineplot(x='energy', y='f1', hue='filter', lw=3, alpha=0.2,
                 data=only_median[only_median['particle'] == particle],
                 ax=panel_ax)
    panel_ax.grid()
    panel_ax.set_xlabel('Energy', fontsize=18)
    panel_ax.set_ylabel('f1-score (50%)', fontsize=18)
    panel_ax.tick_params(axis='both', which='major', labelsize=18)
    panel_ax.set_ylim([0, 1])
    panel_ax.set_title(title, fontsize=18)

### Estimação de energia

#### Energia perdida após inserção e remoção do pedestal

In [None]:
# Per-image energies from every result file: image_truth and image_real
# (presumably before / after pedestal insertion and removal).
energy_df = pd.concat([
    pd.DataFrame([raw['Image_path'], raw['Image_index'],
                  raw['Energy']['image_truth'], raw['Energy']['image_real']]).T
    for raw in map(load_data, files)
])
energy_df.columns = ['image_path', 'image_index', 'energy_truth', 'energy_real']

In [None]:
# Prepend particle / nominal energy parsed from each image path and drop the path
# and index bookkeeping. Not re-runnable: 'image_path' no longer exists afterwards.
energy_tidy = energy_df.reset_index(drop=True)
particle_energy = pd.DataFrame.from_records(energy_tidy['image_path'].map(name_change).tolist(),
                                            columns=['particle', 'energy'])
energy_df = particle_energy.join(energy_tidy.drop(columns=['image_path', 'image_index']))

In [None]:
# Quick look at the per-image energies next to their particle type / nominal energy.
energy_df.head()

In [None]:
# The columns arrive as strings / objects; convert each one to a numeric dtype.
for numeric_column in ["energy", "energy_truth", "energy_real"]:
    energy_df[numeric_column] = pd.to_numeric(energy_df[numeric_column])

In [None]:
# Cluster integral vs. nominal energy, one line style per energy column of energy_df.
energy_long = pd.melt(energy_df, id_vars=['particle', 'energy'])
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
sns.lineplot(data=energy_long, x="energy", y="value", hue="particle", style="variable",
             err_style="bars", palette=sns.color_palette("mako_r", 2), ax=ax)
ax.grid()
ax.set_xlabel('Energy', fontsize=18)
ax.set_ylabel('Cluster integral', fontsize=18)
ax.set_xlim([0, 60.1])
ax.tick_params(axis='both', which='major', labelsize=18)
ax.legend(fontsize=18)

#### Energia após filtragem
  * No caso anterior, temos uma rejeição de background de 100% e detecção de sinal de 100%. Assume-se uma clusterização perfeita, e o erro exibido é o de estimação do pedestal;
  * Com a filtragem (ou na ausência dela), pixels podem ser classificados de maneira incorreta (pixels de sinal podem ser considerados background e vice-versa).

  * Escolhendo os melhores filtros para cada partícula e energia

In [None]:
# Median F1 over images for every (particle, energy, filter, parameter).
# Aggregate only 'f1'. A median of the 'image' index is meaningless, and it would
# survive into agg_results and collide with full_result_table's 'image' column
# in the merge below (yielding image_x / image_y).
agg_results = result_table.groupby(['particle', 'energy', 'filter', 'parameter']).agg({'f1': 'median'})

In [None]:
# Turn the group keys back into ordinary columns. Not idempotent: re-running this
# cell on the already-flat frame adds a spurious 'index' column.
agg_results = agg_results.reset_index()

In [None]:
# Inspect the aggregated F1 of every filter / parameter combination.
agg_results

In [None]:
# For every (particle, filter, energy) keep only the row with the highest f1:
# after an ascending sort, the last duplicate is the best one.
agg_results = (agg_results
               .sort_values(by='f1')
               .drop_duplicates(subset=['particle', 'filter', 'energy'], keep='last'))

In [None]:
# The score was only needed for the selection; keep the selected keys.
agg_results = agg_results.drop(columns='f1')

In [None]:
# One row per (particle, filter, energy) with the parameter chosen for it.
agg_results

In [None]:
# Full recall / precision / energy curves of every result file, stacked.
full_result_table = pd.concat(list(map(lambda json_path: get_data_from_raw(load_data(json_path)), files)))

In [None]:
# Each row holds the recall / precision / energy_threshold lists of one image and filter.
full_result_table

In [None]:
# Keep only the curves of the parameter selected for each (particle, energy, filter).
filtered_full_result_table = agg_results.merge(full_result_table, how='inner',
                                               on=['particle', 'energy', 'filter', 'parameter'])

In [None]:
# Stack each column of per-row curves into a (rows x curve points) matrix and
# forward-fill its NaNs row by row.
precision_matrix, recall_matrix, energy_matrix = (
    fill_nan_nn(np.array(filtered_full_result_table[column].tolist()))
    for column in ('precision', 'recall', 'energy_threshold')
)

In [None]:
# (row, column) of every curve point whose precision exceeds 0.95. np.where
# yields them row by row, so splitting the column ids by the per-row counts
# gives one array of valid columns for each row that has at least one hit.
xx, yy = np.where(precision_matrix > 0.95)
position_array = np.column_stack((xx, yy))
rows_with_hits, hits_per_row = np.unique(position_array[:, 0], return_counts=True)
list_of_valid_index = np.split(position_array[:, 1], np.cumsum(hits_per_row)[:-1])

In [None]:
# For every row reaching the precision target, take the best recall and the
# largest energy among its qualifying curve points.
# NOTE(review): both maxima are taken independently, so they may come from
# different points of the curve — confirm this is intended.
# np.nanmax instead of builtin max: builtin max over an array containing NaN
# (fill_nan_nn can leave leading NaNs) depends on where the NaN sits.
best_recall = [
    [row,
     np.nanmax(recall_matrix[row, cols]),
     np.nanmax(energy_matrix[row, cols])]
    for row, cols in zip(np.unique(position_array[:, 0]), list_of_valid_index)
]

In [None]:
# Results at the precision target, indexed by the row of
# filtered_full_result_table they belong to (used by the join below).
p_results = (pd.DataFrame(np.array(best_recall), columns=['index', 'recall_at_p', 'energy_at_p'])
             .astype({'index': int})
             .set_index('index'))

In [None]:
# Left join on the row index: rows that never reach the precision target get NaN.
filtered_full_result_table = filtered_full_result_table.join(p_results, how='left')

In [None]:
# Check the attached recall_at_p / energy_at_p columns (NaN where no point of the
# curve has precision above 0.95).
filtered_full_result_table.head()

In [None]:
# Energy recovered at the precision target for each filter.
energy_filters_result = filtered_full_result_table.loc[:, ['particle', 'energy', 'filter', 'energy_at_p']]

In [None]:
# Reference curves reshaped so 'energy_truth' / 'energy_real' sit in the 'filter'
# column next to the real filters, with the value in 'energy_at_p'.
energy_concat = pd.melt(energy_df, id_vars=['particle', 'energy'],
                        var_name='filter', value_name='energy_at_p')
energy_concat

In [None]:
# Filters (plus the 'energy_real' reference curve) compared in the plots below.
filt = ['gaussian', 'energy_real', 'cygno']
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported
# equivalent and, like append, keeps the original indices by default.
data = pd.concat([energy_filters_result, energy_concat])
data = data[data['filter'].isin(filt)]

In [None]:
# Cluster integral recovered at the precision target, per filter and particle.
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
sns.lineplot(data=data, x="energy", y="energy_at_p", hue="filter", style="particle",
             err_style="bars", palette=sns.color_palette("hls", len(filt)), ax=ax)
ax.grid()
ax.set_xlabel('Energy', fontsize=18)
ax.set_ylabel('Cluster integral', fontsize=18)
ax.set_xlim([0, 60.1])
ax.tick_params(axis='both', which='major', labelsize=18)
ax.legend(fontsize=18)

In [None]:
# Same comparison as box plots, one panel per particle type.
particle_titles = {'ER': 'Electron recoil', 'He': 'Nuclear recoil'}
# Pin the facet order. Otherwise panels follow the order in which particles first
# appear in `data` (which depends on earlier sorts/concats), and positional
# titles can label the wrong panel.
g = sns.catplot(x="energy", y="energy_at_p", hue="filter", col="particle",
                col_order=list(particle_titles), data=data,
                kind="box", height=12, aspect=1)
axes = g.axes.ravel()
# Same quantity as the line plot above, so use the same axis label.
axes[0].set_ylabel('Cluster integral', fontsize=18)
for facet_ax, particle in zip(axes, particle_titles):
    facet_ax.grid()
    facet_ax.set_xlabel('Energy', fontsize=18)
    facet_ax.tick_params(axis='both', which='major', labelsize=18)
    facet_ax.set_title(particle_titles[particle], fontsize=18)