# PREAMBLE
<script
  src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML"
  type="text/javascript">
</script>

In [None]:

import gc

import numpy as np
import pandas as pd
from validphys.api import API
from validphys.loader import FallbackLoader

%matplotlib inline

# Instantiate a validphys FallbackLoader.
# NOTE(review): `l` is not referenced again in the visible cells — confirm
# whether it is still needed.
l = FallbackLoader()

# Definition of the input

In [None]:
# Name of the fit to analyse; it is handed to API.fit(fit=fitname).
fitname = "240509-rs-alphas-tcm-closure"

# NOTE(review): in the visible cells this flag is only mentioned inside a
# commented-out block, so it currently has no effect on the result.
mhou_fit = False

# COMPUTATION OF $\alpha_s$

In [None]:
# Load the fit and read the prior PDF from its theorycovmatconfig section.
fit = API.fit(fit=fitname)
prior_pdf = fit.as_input()["theorycovmatconfig"]["pdf"]
# prior_pdf = "240409-01-rs-symm_pos_pseudodata"

# Keyword arguments shared by the API calls in this notebook.
common_dict = {
    "dataset_inputs": {"from_": "fit"},
    "fit": fit.name,
    "fits": [fit.name],
    "use_cuts": "fromfit",
    "metadata_group": "nnpdf31_process",
}

# Namespace specification handed to API.theoryids / API.t0theoryid:
# everything is taken from the fit, plus the shared arguments above.
theoryids_dict = {
    "point_prescription": {"from_": "theorycovmatconfig"},
    "theoryids": {"from_": "scale_variation_theories"},
    "t0theoryid": {"from_": "theory"},
    "theoryid": {"from_": "theory"},
    "theory": {"from_": "fit"},
    "theorycovmatconfig": {"from_": "fit"},
    **common_dict,
}
theoryids = API.theoryids(**theoryids_dict)

# Positions 0, 1 and 2 of the returned theories are used as the central,
# plus and minus theory respectively.
theory_mid = theoryids[0].id
theory_plus = theoryids[1].id
theory_min = theoryids[2].id

# Prior-PDF inputs at the central, plus and minus theories
# (used to construct the alphas covmat).
inps_central = {"theoryid": theory_mid, "pdf": prior_pdf, **common_dict}
inps_plus = {"theoryid": theory_plus, "pdf": prior_pdf, **common_dict}
inps_minus = {"theoryid": theory_min, "pdf": prior_pdf, **common_dict}

# Inputs for the prediction of the fit with cov=C+S, where S is computed from
# the inps_central, inps_plus and inps_minus dictionaries.
inps_central_fit = {"theoryid": theory_mid, "pdf": {"from_": "fit"}, **common_dict}

In [None]:
# Prior-PDF prediction at the central theory: drop the first two columns of
# the result table and average the remaining columns row by row.
# NOTE(review): presumably columns 2+ hold the per-replica predictions — confirm.
prior_theorypreds_central = API.group_result_table_no_table(**inps_central).iloc[:, 2:].mean(axis=1)

In [None]:
# Prior-PDF prediction at the plus theory: drop the first two columns of the
# result table and average the remaining columns row by row.
# NOTE(review): presumably columns 2+ hold the per-replica predictions — confirm.
prior_theorypreds_plus = API.group_result_table_no_table(**inps_plus).iloc[:, 2:].mean(axis=1)

In [None]:
# Prior-PDF prediction at the minus theory: drop the first two columns of the
# result table and average the remaining columns row by row.
# NOTE(review): presumably columns 2+ hold the per-replica predictions — confirm.
prior_theorypreds_minus = API.group_result_table_no_table(**inps_minus).iloc[:, 2:].mean(axis=1)

In [None]:
# Second finite difference of the prior predictions: plus + minus - 2 * central.
# NOTE(review): gamma is only mentioned in the commented-out formula below and
# is not used by any other visible cell.
gamma = prior_theorypreds_plus + prior_theorypreds_minus - 2 * prior_theorypreds_central

# z = 1 / (gamma @ np.linalg.inv(C) @ (prior_theorypreds_central - dat_central) + 1)
# Rescaling factor for the alphas covmat; falls back to 1.0 when the fit input
# has no "theorycovmatconfig" section or no "rescale_alphas_covmat" entry in it.
covmat_scaling_factor = fit.as_input().get("theorycovmatconfig",{}).get("rescale_alphas_covmat",1.0)

In [None]:
# Read alphas for the plus, central and minus theories (in that order).
alphas_plus, alphas_central, alphas_min = (
    API.theory_info_table(theory_db_id=theory_id).loc["alphas"].iloc[0]
    for theory_id in (theory_plus, theory_mid, theory_min)
)

# The upward and downward steps must agree within 1e-6; stop otherwise.
delta_alphas_plus = alphas_plus - alphas_central
delta_alphas_min = alphas_central - alphas_min
if abs(delta_alphas_min - delta_alphas_plus) > 1e-6:
    raise ValueError("alphas shifts in both directions is not symmetric")
alphas_step_size = delta_alphas_min

In [None]:
# Shift vector in alphas: (+step, -step), each multiplied by
# sqrt(covmat_scaling_factor) / sqrt(2).
# NOTE(review): assumes alphas_step_size is a scalar, so beta_tilde has length 2.
beta_tilde = np.sqrt(covmat_scaling_factor) * (alphas_step_size / np.sqrt(2)) * np.array([1, -1])
# Inner product of beta_tilde with itself; it enters P_tilde in the final cell.
S_tilde = beta_tilde @ beta_tilde

In [None]:
# Shifts of the prior predictions at the plus and minus theories relative to
# the central one, each scaled by sqrt(covmat_scaling_factor) / sqrt(2).
delta_plus, delta_minus = (
    (np.sqrt(covmat_scaling_factor) / np.sqrt(2)) * (shifted - prior_theorypreds_central)
    for shifted in (prior_theorypreds_plus, prior_theorypreds_minus)
)

# Contract the alphas shifts (beta_tilde) with the prediction shifts (beta).
beta = [delta_plus, delta_minus]
S_hat = beta_tilde @ beta

# Covmat built from the two shifts, labelled with the index of the predictions.
S = pd.DataFrame(
    np.outer(delta_plus, delta_plus) + np.outer(delta_minus, delta_minus),
    index=delta_minus.index,
    columns=delta_minus.index,
)

In [None]:
# beta_tilde = np.sqrt(covmat_scaling_factor) * (alphas_step_size)
# S_tilde = beta_tilde * beta_tilde

# beta = np.sqrt(covmat_scaling_factor) * (
#     prior_theorypreds_plus - prior_theorypreds_minus
# )

# S_hat = beta_tilde * beta

# S = np.outer(beta, beta)
# S = pd.DataFrame(S, index=beta.index, columns=beta.index)

In [None]:
# Load the covmat stored with the fit so it can be compared with the
# reconstructed S. The "user_covmat" table is tried first; if that file does
# not exist, the "theory_covmat_custom" table is read instead.
try:
    # NOTE(review): in this branch the table is not reindexed to S.index, so the
    # comparison at the end relies on the stored row/column order matching S.
    stored_covmat = pd.read_csv(
        fit.path / "tables/datacuts_theory_theorycovmatconfig_user_covmat.csv",
        sep="\t",
        encoding="utf-8",
        index_col=2,
        header=3,
        skip_blank_lines=False,
    )
except FileNotFoundError:
    # Fallback table: three header rows and three index columns. The separator
    # is a regex accepting a tab or a comma, which requires the python engine.
    # Missing entries are filled with 0.
    stored_covmat = pd.read_csv(
        fit.path / "tables/datacuts_theory_theorycovmatconfig_theory_covmat_custom.csv",
        index_col=[0, 1, 2],
        header=[0, 1, 2],
        sep="\t|,",
        engine="python",
    ).fillna(0)
    storedcovmat_index = pd.MultiIndex.from_tuples(
        [(aa, bb, np.int64(cc)) for aa, bb, cc in stored_covmat.index],
        names=["group", "dataset", "id"],
    )  # cast the third index level ("id") to np.int64 — presumably to match the index of S
    # Rebuild the frame with the same (group, dataset, id) labels on both axes.
    stored_covmat = pd.DataFrame(
        stored_covmat.values, index=storedcovmat_index, columns=storedcovmat_index
    )
    # Put rows and columns in the same order as S.
    stored_covmat = stored_covmat.reindex(S.index).T.reindex(S.index)

# NOTE(review): a mismatch is only reported with print(); the later cells still
# run with the reconstructed S.
if not np.allclose(S, stored_covmat):
    print("Reconstructed theory covmat, S, is not the same as the stored covmat!")

In [None]:
# Predictions of the fit's own PDF at the central theory; the first two columns
# of the result table are dropped and the remaining columns are kept as-is.
# NOTE(review): presumably one column per PDF replica — confirm.
theorypreds_fit = API.group_result_table_no_table(**inps_central_fit).iloc[:, 2:]

In [None]:
# ID of the t0 theory, resolved from the same specification used for theoryids.
t0theoryid = API.t0theoryid(**theoryids_dict).id

# Experimental covariance matrix, requested with use_t0=True; the data cuts
# come from the fit and the t0 PDF set from those data cuts.
C = API.groups_covmat(
    use_t0=True,
    datacuts={"from_": "fit"},
    t0pdfset={"from_": "datacuts"},
    theoryid=t0theoryid,
    **common_dict
)

In [None]:
# # MHOU covmat saved as user uncertainties
# try:
#     mhou_fit = fit.as_input()["theorycovmatconfig"]["use_user_uncertainties"]
#     if mhou_fit:
#         mhou_covmat = API.user_covmat(**(inps_central_fit|fit.as_input()['theorycovmatconfig']))
#         exp_covmat = C # we don't use exp_covmat, but may be useful to keep
#         C = C + mhou_covmat
# except:
#     pass

In [None]:
# Different from the prediction of the mean PDF (i.e. replica0)
mean_prediction = theorypreds_fit.mean(axis=1)

# Covariance of the fit predictions over the columns of theorypreds_fit,
# normalised by the number of columns N (not N - 1):
#     X = (1 / N) * sum_i (T_i - <T>) (T_i - <T>)^T
# A single matrix product replaces the Python loop that accumulated one full
# np.outer temporary per column; the result is the same up to floating-point
# rounding. The result no longer takes its shape and dtype from C.
deviations = theorypreds_fit.values - mean_prediction.values[:, np.newaxis]
X = (deviations @ deviations.T) / theorypreds_fit.shape[1]

In [None]:
# Pseudodata associated with the fit, read through the validphys API.
# Each element is expected to expose a `.pseudodata` attribute (used by the
# cell that builds dat_reps).
pseudodata = API.read_pdf_pseudodata(**common_dict)

In [None]:
# Collect the pseudodata of every entry side by side (one block of columns per
# entry), with rows ordered like the prior theory predictions.
dat_reps = pd.concat(
    [
        replica.pseudodata.reindex(prior_theorypreds_central.index)
        for replica in pseudodata
    ],
    axis=1,
)
# dat_central = API.group_result_central_table_no_table(**inps_central)["data_central"]

In [None]:
# Run a full garbage-collection pass before the matrix inversion in the next
# cell. `gc` is imported with the other imports at the top of the notebook.
gc.collect()

In [None]:
# Final determination of alphas from the fit.
# Inverse of the total covariance: experimental (C) plus alphas theory covmat (S).
invcov = np.linalg.inv(C + S)
# Shift of alphas away from the central value, driven by the difference between
# the mean fit prediction and the replica-averaged pseudodata.
delta_T_tilde = -S_hat @ invcov @ (mean_prediction - dat_reps.mean(axis=1))
# Variance of alphas: the prediction-covariance term (through X) plus S_tilde,
# minus the part removed by the fit.
P_tilde = S_hat.T @ invcov @ X @ invcov @ S_hat + S_tilde - S_hat.T @ invcov @ S_hat
pred = alphas_central + delta_T_tilde
unc = np.sqrt(P_tilde)
# Use a real plus-minus sign (U+00B1) in the printed result.
print(rf"Prediction for $\alpha_s$: {pred:.4f} ± {unc:.4f}")