In [None]:
import os

import numpy as np
import tensorflow as tf

from refiner.data import load_data_tt, prepare_data_natural
from refiner.model import (
    SimpleModel,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
    resample,
)
from refiner.plotting import plot_raw, plot_n_ratio, plot_n_ratio_multi, plot_w, plot_w2, plot_w_2d_scatter

In [None]:
# Destination for every figure this notebook writes. Create it up front so a
# fresh checkout does not fail on the first save (in case the plotting helpers
# do not create missing directories themselves).
output_dir = "results/tt_ensemble/"
os.makedirs(output_dir, exist_ok=True)

# Histogram bin edges for the count / variance plots: 30 edges -> 29 bins on [0, 3].
bins = np.linspace(0, 3, 30)


def transform(x):
    """Return column 0 of a 2-D array; the default observable for the plots."""
    return x[:, 0]


# Event layout; the models below take n_jets * n_features inputs per event.
n_jets, n_features = 15, 5

In [None]:
# Load the `tt` samples; keep both the bundled tuple (`data`, splatted into the
# helpers below) and its four named components.
data = load_data_tt(n_jets=n_jets)
pos, neg, pos_weights, neg_weights = data
plot_raw(data=data, bins=bins, transform=transform, path=f"{output_dir}raw.pdf")

In [None]:
# Training hyper-parameters shared by every ensemble member.
epochs = 10
batch_size = 1024

In [None]:
# Train an ensemble of reweighter models on the same data; members differ only
# through their random seed, so each member (and the whole ensemble) is reproducible.
n_reweighters = 10
reweighter_seed = 0  # member i is seeded with reweighter_seed + i

# Prepare the inputs once and derive both splits from the same arrays. Calling
# prepare_data_reweighter twice could yield mismatched train/validation splits
# if it involves any randomness. NOTE(review): confirm whether it is stochastic.
reweighter_inputs = tuple(prepare_data_reweighter(*data))
x_train, y_train, w_train = get_train(*reweighter_inputs)
validation_data = get_val(*reweighter_inputs)

reweighters = []
for i in range(n_reweighters):
    # Seeds Python, NumPy and TensorFlow RNGs (weight init, shuffling, ...).
    tf.keras.utils.set_random_seed(reweighter_seed + i)
    reweighter = SimpleModel(input_shape=(n_jets * n_features,))
    reweighter.compile(
        n_train=x_train.shape[0],
        epochs=epochs,
        batch_size=batch_size,
        # Presumably (initial, final) rates of a decay schedule; TODO confirm in
        # SimpleModel.compile. NOTE(review): the final rate differs from the
        # refiner ensemble's (1e-5); confirm that is intentional.
        learning_rate=(0.001, 0.0000001),
    )
    # `logger` keeps the history returned by fit for the most recent member only.
    logger = reweighter.fit(
        x_train,
        y_train,
        sample_weight=w_train,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
    )
    reweighters.append(reweighter)

In [None]:
# Train an ensemble of refiner models on the same data; members differ only
# through their random seed, so each member (and the whole ensemble) is reproducible.
n_refiners = 10
# Offset from 0 so refiner members do not share seeds with the reweighter members.
refiner_seed = 1000  # member i is seeded with refiner_seed + i

# Prepare the inputs once and derive both splits from the same arrays. Calling
# prepare_data_refiner twice could yield mismatched train/validation splits
# if it involves any randomness. NOTE(review): confirm whether it is stochastic.
refiner_inputs = tuple(prepare_data_refiner(*data))
x_train, y_train, w_train = get_train(*refiner_inputs)
validation_data = get_val(*refiner_inputs)

refiners = []
for i in range(n_refiners):
    # Seeds Python, NumPy and TensorFlow RNGs (weight init, shuffling, ...).
    tf.keras.utils.set_random_seed(refiner_seed + i)
    refiner = SimpleModel(input_shape=(n_jets * n_features,))
    refiner.compile(
        n_train=x_train.shape[0],
        epochs=epochs,
        batch_size=batch_size,
        # Presumably (initial, final) rates of a decay schedule; TODO confirm in
        # SimpleModel.compile.
        learning_rate=(0.001, 0.00001),
    )
    # `logger` keeps the history returned by fit for the most recent member only.
    logger = refiner.fit(
        x_train,
        y_train,
        sample_weight=w_train,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
    )
    refiners.append(refiner)


In [None]:
# get_val views of the nominal data and of the output of every trained
# ensemble member (one entry per model, in training order).
data_natural = get_val(*prepare_data_natural(*data))
data_reweighters = [
    get_val(*apply_reweighter(*data, reweighter=model))
    for model in reweighters
]
data_refiners = [
    get_val(*apply_refiner(*data, refiner=model))
    for model in refiners
]

In [None]:
# Compare natural vs. reweighter-ensemble vs. refiner-ensemble counts, one
# figure per input column. Assumes x is 2-D (events, columns); TODO confirm.
n_plot_columns = 10
for i in range(n_plot_columns):
    plot_n_ratio_multi(
        data=data_natural,
        reweighter=data_reweighters,
        refiner=data_refiners,
        # Bind `i` now (default argument). A bare closure would read the loop
        # variable at call time, i.e. the last column, if the helper keeps the
        # callable around.
        transform=lambda x, i=i: x[:, i],
        ratio_unc="hilo",
        # NOTE(review): unlike the other plots, no `bins` is passed here;
        # confirm the helper's default binning is intended.
        path=output_dir + f"counts_{i}.pdf",
    )

In [None]:
# Weight distributions. Only the ensemble lists `data_reweighters` /
# `data_refiners` exist in this notebook, so plot the first member of each.
# NOTE(review): confirm whether plot_w should instead aggregate the full ensemble.
plot_w(
    data=data_natural,
    reweighter=data_reweighters[0],
    refiner=data_refiners[0],
    bins=np.arange(-1.2, 3.0, 0.2),
    path=output_dir + "weights.pdf",
)

In [None]:
# Variance (sum of squared weights) plots. Only the ensemble lists
# `data_reweighters` / `data_refiners` exist in this notebook, so plot the
# first member of each. NOTE(review): confirm whether the full ensemble is wanted.
plot_w2(
    data=data_natural,
    reweighter=data_reweighters[0],
    refiner=data_refiners[0],
    bins=bins,
    transform=transform,
    path=output_dir + "variances.pdf",
)

In [None]:
# Resample the first reweighter / refiner ensemble member's output. Only the
# ensemble lists `data_reweighters` / `data_refiners` exist in this notebook.
# NOTE(review): confirm whether resampling should cover the full ensemble.
data_reweighter_resampled = resample(*data_reweighters[0])
data_refiner_resampled = resample(*data_refiners[0])
# Display the resampled sizes (length of the first element of each result).
len(data_reweighter_resampled[0]), len(data_refiner_resampled[0])

In [None]:
# Counts and ratios after resampling, against the natural reference.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    transform=transform,
    path=f"{output_dir}counts_resampled.pdf",
)

In [None]:
# Weight distributions after resampling; bin edges span [-1.2, 3.0) in steps of 0.2.
plot_w(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=np.arange(-1.2, 3.0, 0.2),
    path=f"{output_dir}weights_resampled.pdf",
)

In [None]:
# Variance plots after resampling, against the natural reference.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    transform=transform,
    path=f"{output_dir}variances_resampled.pdf",
)

In [None]:
# 2-D scatter comparing the first reweighter and first refiner ensemble members.
plot_w_2d_scatter(
    data_reweighters[0],
    data_refiners[0],
    transform=transform,
    path=f"{output_dir}weights_2d.pdf",
)

In [None]:
# Removed stray scratch expression `np.max(hist.T)`: `hist` is never defined in
# this notebook, so the cell raised NameError under Restart & Run All.

In [None]:
# One-element integer array holding -1. Scratch value; not referenced elsewhere
# in the visible notebook.
arr = np.full(1, -1)
