In [5]:
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.templates.default = 'plotly_white'

import pandas as pd
import numpy as np
import pickle
import os
import re

In [6]:
_nsre = re.compile('([0-9]+)')
def natural_sort_key(s):
    """Return a key that sorts strings "naturally" and case-insensitively.

    For example ``'f2'`` sorts before ``'f10'``, and ``'F1'`` before ``'f2'``.

    ``_nsre`` has exactly one capturing group, so ``split`` returns the text
    between digit runs at even positions and the ASCII digit runs themselves
    at odd positions. Converting by position (instead of ``str.isdigit``)
    keeps the int/str alternation identical for every key, so keys stay
    comparable element-wise. It also avoids a ``ValueError`` on characters
    such as ``'²'``, which ``isdigit`` accepts but ``int`` rejects.

    Args:
        s: the string (e.g. a file name) to build a key for.

    Returns:
        A list alternating lower-cased text chunks (even indices) and
        integers (odd indices).
    """
    return [int(part) if idx % 2 else part.lower()
            for idx, part in enumerate(_nsre.split(s))]

In [7]:
def conditions(name_of_file, model, mode):
    """Tell whether ``name_of_file`` looks like a pickled fit result for ``model`` and ``mode``.

    Every token must appear somewhere in the file name: the model name,
    ``'fit'``, ``'pkl'`` and the mode. These are plain substring checks,
    evaluated in that order and stopping at the first miss.

    Returns:
        True if all tokens are present, otherwise False.
    """
    required_tokens = (model, 'fit', 'pkl', mode)
    return all(token in name_of_file for token in required_tokens)

In [8]:
# Collect the fitresult['juno']['fun'] value (presumably the minimised χ²,
# TODO confirm) from every matching pickled fit result.
# all_xi2 is ordered by the loops below:
#   [neural_network/all, neural_network/none, xgboost/all, xgboost/none]
# Each inner list follows the natural sort order of the file names.
path = 'plots/fv_old/fitst_junotao_v05d'
all_files = os.listdir(path)
all_files.sort(key=natural_sort_key)
models = ['neural_network', 'xgboost']
modes = ['all', 'none']

all_xi2 = []
for model in models:
    for mode in modes:
        matching_files = [f for f in all_files if conditions(f, model, mode)]
        # Fail loudly: an empty series would otherwise become a blank trace in the plot.
        if not matching_files:
            raise FileNotFoundError(
                f"no fit result files for model={model!r}, mode={mode!r} in {path!r}"
            )
        xi2_array = []
        for name_of_file in matching_files:
            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(os.path.join(path, name_of_file), 'rb') as file:
                xi2_array.append(pickle.load(file)['fitresult']['juno']['fun'])
        all_xi2.append(xi2_array)

In [35]:
# (legend label, line colour) for each trace, paired by position with all_xi2.
_trace_styles = [
    ('NN w. systematics', 'darkviolet'),
    ('NN w/o systematics', 'darkviolet'),
    ('BDT w. systematics', 'darkred'),
    ('BDT w/o systematics', 'darkred'),
]
names = [label for label, _ in _trace_styles]
colors = [colour for _, colour in _trace_styles]

In [68]:
# Grid of R values for the x axis (labelled "R, m"): 14.0 to 17.2 in steps of 0.1.
# Each series in all_xi2 is assumed to hold one value per R, in the same
# (natural sort) order as the result files. This is checked below so that a
# mismatch fails instead of producing a silently misaligned plot.
Rs = np.linspace(14, 17.2, 33)
fig = go.Figure()

if not (len(all_xi2) == len(names) == len(colors)):
    raise ValueError(
        f"got {len(all_xi2)} χ² series but {len(names)} names and {len(colors)} colors"
    )

for i, (xi2_values, trace_name, trace_color) in enumerate(zip(all_xi2, names, colors)):
    if len(xi2_values) != len(Rs):
        raise ValueError(
            f"{trace_name!r}: expected {len(Rs)} values (one per R), got {len(xi2_values)}"
        )
    # Odd positions hold the "w/o systematics" entries of `names`. They get a
    # dashed line, so each model's pair of curves shares a colour but not a dash style.
    dash = 'dash' if i % 2 else 'solid'
    fig.add_trace(
        go.Scatter(
            x=Rs,
            y=xi2_values,
            mode='markers+lines',
            line=dict(
                dash=dash,
                color=trace_color,
            ),
            marker=dict(
                symbol='x',
                size=7
            ),
            name=trace_name,
        )
    )

# Look shared by both axes: black axis line mirrored to the opposite side,
# ticks drawn outside, and a thin grey grid.
_common_axis_style = {
    'showline': True,
    'ticks': 'outside',
    'mirror': True,
    'linecolor': 'black',
    'showgrid': True,
    'gridcolor': 'grey',
    'gridwidth': 0.25,
}

# x axis: a tick every 0.5 starting at 14.
xaxis = {**_common_axis_style, 'tick0': 14, 'dtick': 0.5}

# y axis: unit ticks over a fixed window. The zero line is styled, although 0
# lies outside the [4.75, 11] range.
yaxis = {
    **_common_axis_style,
    'dtick': 1,
    'range': [4.75, 11],
    'zeroline': True,
    'zerolinecolor': 'black',
    'zerolinewidth': 0.25,
}

# Boxed legend placed near the top-left of the plotting area.
legend_style = dict(
    x=0.025,
    y=0.99,
    bordercolor="Black",
    borderwidth=1,
)
font_style = dict(family="Times New Roman", size=18)

fig.update_layout(
    xaxis_title="R, m",
    yaxis_title="Δχ²",
    xaxis=xaxis,
    yaxis=yaxis,
    legend=legend_style,
    font=font_style,
)

fig.show()
# Static PDF export (relies on plotly's image-export backend being installed).
pio.write_image(fig, 'plots/fv_old/xi2.pdf', width=900, height=600)