In [None]:
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm

from refiner.data import load_data_tt, prepare_data_natural
from refiner.model import (
    SimpleModel,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
)
from refiner.plotting import plot_raw, plot_n_ratio_multi, plot_w, plot_w2, plot_w_2d_scatter

In [None]:
# Run configuration shared by the cells of this notebook.

# Path prefix (note the trailing slash) that model and figure file names are appended to.
output_dir = "results/tt_ensemble/"

# 20 equally spaced bin edges spanning [0, 1].
bins = np.linspace(0, 1, 20)


def transform(x):
    """Return column 0 of the 2-D array ``x`` multiplied by 100.

    NOTE(review): the factor 100 presumably converts the stored feature back
    to GeV so that it matches ``xlabel`` -- confirm against the data loader.
    """
    return 100 * x[:, 0]


xlabel = "Lepton $p_T$ [GeV]"

# Alternative observable (disabled): column 30 with its own axis label.
# transform = lambda x: 100 * x[:, 30]
# xlabel = "Leading ISR jet $p_T$ [GeV]"

# The networks take a flat input of n_jets * n_features values.
n_jets = 15
n_features = 5

# True: train the ensembles and save them under output_dir.
# False: load the previously saved models from output_dir instead.
retrain = False

# Keyword arguments forwarded to every get_train / get_val call.
data_kwargs = {"test_size": 0.2, "random_state": 42}

In [None]:
# Load the sample once; `data` keeps the full return value so it can be
# star-unpacked into the prepare_* / apply_* helpers in later cells.
data = load_data_tt(n_jets=n_jets)
# The loader is expected to return exactly four items, also exposed by name.
pos, neg, pos_weights, neg_weights = data

In [None]:
# Call plot_raw on the loaded sample in the transformed observable and write
# the figure to output_dir.
plot_raw(
    data=data,
    bins=np.linspace(0, 300, 20),  # 20 bin edges over [0, 300] on the transformed axis
    transform=transform,
    path=output_dir + "raw_0.pdf",  # plain literal: the name has no placeholders
    xlabel=xlabel,
)

In [None]:
# Training hyper-parameters passed to both compile() and fit() of every model.
epochs = 10
batch_size = 1024

In [None]:
# Ensemble of 10 reweighter models, each taking a flat input of
# n_jets * n_features values.
reweighters = [SimpleModel(input_shape=(n_jets * n_features,)) for _ in range(10)]

if retrain:
    # Prepare the training inputs once so that the train and validation
    # splits are guaranteed to be cut from the very same arrays (and the
    # preparation work is not repeated). tuple() makes the result safe to
    # star-unpack twice even if the helper returns a one-shot iterable.
    prepared = tuple(prepare_data_reweighter(*data))
    x_train, y_train, w_train = get_train(*prepared, **data_kwargs)
    validation_data = get_val(*prepared, **data_kwargs)

    for i, reweighter in enumerate(tqdm(reweighters)):
        reweighter.compile(
            n_train=x_train.shape[0],
            epochs=epochs,
            batch_size=batch_size,
            # NOTE(review): presumably (initial, final) learning rates of a
            # schedule -- confirm in SimpleModel.compile.
            learning_rate=(0.001, 0.000001),
        )
        # `logger` holds the fit() return value of the most recently trained
        # model only; kept as a notebook global for interactive inspection.
        logger = reweighter.fit(
            x_train,
            y_train,
            sample_weight=w_train,
            validation_data=validation_data,
            epochs=epochs,
            batch_size=batch_size,
            verbose=0,
        )
        # Save right after training so an interruption later in the loop
        # does not discard the models that are already finished.
        reweighter.model.save(output_dir + f"reweighter_{i}.keras")
else:
    # Reuse the models written by an earlier run with retrain=True.
    for i, reweighter in enumerate(reweighters):
        reweighter.model = tf.keras.models.load_model(output_dir + f"reweighter_{i}.keras")


In [None]:
# Ensemble of 10 refiner models, each taking a flat input of
# n_jets * n_features values.
refiners = [SimpleModel(input_shape=(n_jets * n_features,)) for _ in range(10)]

if retrain:
    # Prepare the training inputs once so that the train and validation
    # splits are guaranteed to be cut from the very same arrays (and the
    # preparation work is not repeated). tuple() makes the result safe to
    # star-unpack twice even if the helper returns a one-shot iterable.
    prepared = tuple(prepare_data_refiner(*data))
    x_train, y_train, w_train = get_train(*prepared, **data_kwargs)
    validation_data = get_val(*prepared, **data_kwargs)

    for i, refiner in enumerate(tqdm(refiners)):
        refiner.compile(
            n_train=x_train.shape[0],
            epochs=epochs,
            batch_size=batch_size,
            # NOTE(review): presumably (initial, final) learning rates of a
            # schedule -- confirm in SimpleModel.compile.
            learning_rate=(0.001, 0.000001),
        )
        # `logger` holds the fit() return value of the most recently trained
        # model only; kept as a notebook global for interactive inspection.
        logger = refiner.fit(
            x_train,
            y_train,
            sample_weight=w_train,
            validation_data=validation_data,
            epochs=epochs,
            batch_size=batch_size,
            verbose=0,
        )
        # Save right after training so an interruption later in the loop
        # does not discard the models that are already finished.
        refiner.model.save(output_dir + f"refiner_{i}.keras")
else:
    # Reuse the models written by an earlier run with retrain=True.
    for i, refiner in enumerate(refiners):
        refiner.model = tf.keras.models.load_model(output_dir + f"refiner_{i}.keras")


In [None]:
# Validation-split inputs for the comparison plots. Every get_val call uses
# the same data_kwargs as the training cells.

# Baseline built by prepare_data_natural.
data_natural = get_val(*prepare_data_natural(*data), **data_kwargs)

# One entry per ensemble member, in the same order as `reweighters`.
data_reweighters = [
    get_val(*apply_reweighter(*data, reweighter=member), **data_kwargs)
    for member in reweighters
]

# One entry per ensemble member, in the same order as `refiners`.
data_refiners = [
    get_val(*apply_refiner(*data, refiner=member), **data_kwargs)
    for member in refiners
]

In [None]:
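# Disabled diagnostic: one ratio plot per input feature column (60 bins each),
# written to counts_{i}.pdf under output_dir. Uncomment to run.
# NOTE: for i == 0 this writes counts_0.pdf, the same path that the
# plot_n_ratio_multi cell for the main observable writes to.
# for i in range(n_jets * n_features):
#     plot_n_ratio_multi(
#         data=data_natural,
#         reweighter=data_reweighters,
#         refiner=data_refiners,
#         transform=lambda x: x[:, i],
#         bins=60,
#         ratio_unc="std",
#         path=output_dir + f"counts_{i}.pdf",
#     )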

In [None]:
# Compare the natural validation sample with the full reweighter and refiner
# ensembles in the configured observable and save the figure under output_dir.
plot_n_ratio_multi(
    data=data_natural,
    reweighter=data_reweighters,  # list: one validation set per reweighter
    refiner=data_refiners,  # list: one validation set per refiner
    transform=transform,
    bins=np.linspace(0, 300, 20),  # 20 bin edges over [0, 300] on the transformed axis
    ratio_unc="std",
    path=output_dir + "counts_0.pdf",  # plain literal: the name has no placeholders
    xlabel=xlabel,
)

In [None]:
# Weight plot using only the first member of each ensemble.
plot_w(
    data=data_natural,
    reweighter=data_reweighters[0],
    refiner=data_refiners[0],
    bins=np.linspace(-1.3, 1.3, 23),  # 23 bin edges over [-1.3, 1.3]
    path=output_dir + "weights.pdf",
)