# PREAMBLE
<script
  src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
  type="text/javascript">
</script>

In [None]:
import numpy as np
import pandas as pd
import numpy.linalg as la
from validphys.api import API
from validphys.loader import FallbackLoader as Loader
from matplotlib import pyplot as plt
import yaml

# NOTE(review): `l` is not referenced by any other cell visible in this notebook;
# confirm whether the loader instance is still needed before keeping this line.
l = Loader()

# Definition of the input

In [None]:
# Name of the fit to analyse; it is resolved through validphys in the cells below.
fit = "240209-rs-01-nnpdf40-alphas-tcm"

# Central value of alpha_s and the step separating the varied theories.
# NOTE(review): presumably theory_plus / theory_min are the theories at
# alphas_central +/- alphas_step_size -- confirm against the theory database.
alphas_central = 0.118
alphas_step_size = 0.002

# Theory IDs used for the central, upward-shifted and downward-shifted predictions.
theory_mid = 200
theory_plus = 207
theory_min = 201

# Overall factor applied to the alpha_s covariance matrix; its square root multiplies
# every shift vector built below.
covmat_scaling_factor = 1

# COMPUTATION OF $\alpha_s$

In [None]:
# Locate the fit on disk and read back its filter.yml runcard.
fitpath = API.fit(fit=fit).path
filterpath = fitpath / "filter.yml"

with open(filterpath) as filter_stream:
    filterfile = yaml.safe_load(filter_stream)

# PDF recorded under `theorycovmatconfig` in the runcard of the fit.
pdf_ori = filterfile["theorycovmatconfig"]["pdf"]

# Keys shared by every validphys call below: same datasets and cuts as the fit.
common_dict = {
    "dataset_inputs": {"from_": "fit"},
    "fit": fit,
    "fits": [fit],
    "use_cuts": "fromfit",
    "metadata_group": "nnpdf31_process",
}

# Central theory (used to construct the alphas covmat). This is the only input set
# carrying the t0 settings, and it is also the one passed to `groups_covmat` later.
inps_central = {
    "theoryid": theory_mid,
    "pdf": pdf_ori,
    "use_t0": True,
    "datacuts": {"from_": "fit"},
    "t0pdfset": {"from_": "datacuts"},
    **common_dict,
}

# Plus theory (used to construct the alphas covmat).
inps_plus = {"theoryid": theory_plus, "pdf": pdf_ori, **common_dict}

# Minus theory (used to construct the alphas covmat).
inps_minus = {"theoryid": theory_min, "pdf": pdf_ori, **common_dict}

# Predictions of the fit itself, i.e. the fit performed with cov = C + S, where S is
# built from the inps_central, inps_plus and inps_minus dictionaries above. Here the
# PDF is taken from the fit rather than from the runcard.
inps_central_fit = {"theoryid": theory_mid, "pdf": {"from_": "fit"}, **common_dict}

In [None]:
# Central values obtained with the central theory (theory_mid) and pdf_ori; these are
# the reference from which the plus/minus shifts are measured further down.
datth_central = API.groups_central_values_no_table(**inps_central)

In [None]:
# Same quantity as for the central theory, evaluated with theory_plus and pdf_ori.
datth_plus = API.groups_central_values_no_table(**inps_plus)

In [None]:
# Same quantity as for the central theory, evaluated with theory_min and pdf_ori.
datth_minus = API.groups_central_values_no_table(**inps_minus)

In [None]:
# Shift of alpha_s itself for the (plus, minus) variations: +/- alphas_step_size,
# normalised by 1/sqrt(2) and scaled by sqrt(covmat_scaling_factor).
alphas_shift_norm = np.sqrt(covmat_scaling_factor) * (alphas_step_size / np.sqrt(2))
beta_tilde = alphas_shift_norm * np.array([1, -1])

# Scalar product of the shift vector with itself.
S_tilde = beta_tilde @ beta_tilde

In [None]:
# Common normalisation of both shifts: sqrt(covmat_scaling_factor) / sqrt(2).
shift_norm = np.sqrt(covmat_scaling_factor) / np.sqrt(2)

# Shifts of the predictions when moving from the central theory to the plus / minus one.
delta_plus = shift_norm * (datth_plus - datth_central)
delta_minus = shift_norm * (datth_minus - datth_central)

# Contraction of the alpha_s shifts (beta_tilde) with the prediction shifts (beta).
beta = [delta_plus, delta_minus]
S_hat = beta_tilde @ beta

# Covariance induced by the alpha_s variation: sum of the outer products of the two
# shifts, labelled with the same index as the shifts themselves.
S = pd.DataFrame(
    np.outer(delta_plus, delta_plus) + np.outer(delta_minus, delta_minus),
    index=delta_minus.index,
    columns=delta_minus.index,
)

In [None]:
# Candidate locations of the covmat that was stored with the fit. The first file name
# is tried first and the second one is the fallback when the first does not exist.
user_covmat_path = fitpath / "tables/datacuts_theory_theorycovmatconfig_user_covmat.csv"
custom_covmat_path = (
    fitpath / "tables/datacuts_theory_theorycovmatconfig_theory_covmat_custom.csv"
)

try:
    stored_covmat = pd.read_csv(
        user_covmat_path,
        sep="\t",
        encoding="utf-8",
        index_col=2,
        header=3,
        skip_blank_lines=False,
    )
except FileNotFoundError:
    raw_custom_covmat = pd.read_csv(
        custom_covmat_path,
        index_col=[0, 1, 2],
        header=[0, 1, 2],
        sep="\t|,",
        engine="python",
    ).fillna(0)
    # Rebuild the three-level index with the last level ("id") cast to an integer,
    # so that its entries have the same type as in S.
    storedcovmat_index = pd.MultiIndex.from_tuples(
        [
            (group, dataset, np.int64(point_id))
            for group, dataset, point_id in raw_custom_covmat.index
        ],
        names=["group", "dataset", "id"],
    )
    relabelled_covmat = pd.DataFrame(
        raw_custom_covmat.values, index=storedcovmat_index, columns=storedcovmat_index
    )
    # Bring rows and columns into the ordering of S before comparing element-wise.
    stored_covmat = relabelled_covmat.reindex(S.index).T.reindex(S.index)

# Only a warning is printed on mismatch; the notebook carries on with the rebuilt S.
covmats_agree = np.allclose(S, stored_covmat)
if not covmats_agree:
    print("Reconstructed theory covmat, S, is not the same as the stored covmat!")

In [None]:
# in case we'd like to save the covmat to be used in a fit
# S.to_csv("alphascovmat_01190_extended_nnpdf40_without_nuclearuncs_ernfits_fixed.csv")

In [None]:
# Result table for the fit itself (PDF taken from the fit, central theory).
datth_central_fit = API.group_result_table_no_table(**inps_central_fit)
# Keep everything after the first two columns as a plain array.
# NOTE(review): presumably the first two columns hold the data and theory central
# values and the remaining ones the individual PDF replicas -- confirm with validphys.
th_replicas_fit = datth_central_fit.iloc[:, 2:].to_numpy()

# Experimental covariance matrix
C = API.groups_covmat(**inps_central)

In [None]:
# Average of the per-replica predictions (rows: data points, columns: replicas).
# Different from the prediction of the mean PDF (i.e. replica0)
mean_prediction = np.mean(th_replicas_fit, axis=1)

# Covariance of the replica predictions around their mean, normalised by the number
# of replicas: X = (1 / N_rep) * sum_r (T_r - <T>) (T_r - <T>)^T.
# Written as one matrix product instead of accumulating an outer product per replica:
# this avoids allocating an N x N temporary for every replica and no longer ties the
# shape or dtype of X to the experimental covmat C.
replica_deviations = th_replicas_fit - mean_prediction[:, np.newaxis]
X = (replica_deviations @ replica_deviations.T) / th_replicas_fit.shape[1]

In [None]:
# Pseudodata associated with the fit; the next cell iterates over this object and
# reads the `.pseudodata` attribute of each entry.
# NOTE(review): presumably one entry per fitted replica -- confirm with validphys.
pseudodata = API.read_pdf_pseudodata(**common_dict)

In [None]:
# Ordering of the data points used by the theory predictions; every pseudodata entry
# is reindexed to it so that all vectors line up element by element.
central_index = datth_central.index.to_list()

replica_pseudodata = [
    replica.pseudodata.reindex(central_index).to_numpy().flatten() for replica in pseudodata
]

# Central data taken as the average of the pseudodata over all entries.
dat_central = np.mean(replica_pseudodata, axis=0)
# dat_central = datth_central["data_central"]

In [None]:
# Inverse of the total covariance: experimental part C plus the alpha_s-induced part S.
invcov = la.inv(C + S)

# S_hat^T (C + S)^-1 enters both pieces of P_tilde, so it is evaluated once.
S_hat_invcov = S_hat.T @ invcov

# Shift of alpha_s away from alphas_central, driven by the data-theory difference.
data_theory_diff = dat_central - mean_prediction
delta_T_tilde = S_hat @ invcov @ data_theory_diff

# Variance of alpha_s: the replica covariance X propagated through the same
# projection, plus the part of S_tilde that is left after subtracting S_hat^T invcov S_hat.
propagated_term = S_hat_invcov @ X @ invcov @ S_hat
remaining_term = S_tilde - S_hat_invcov @ S_hat
P_tilde = propagated_term + remaining_term

pred = alphas_central + delta_T_tilde
unc = np.sqrt(P_tilde)
print(rf"Prediction for $\alpha_s$: {pred:.5f} ± {unc:.5f}")