In [None]:
import numpy as np
import tensorflow as tf

from refiner.data import create_data_function, prepare_data_natural
from refiner.model import (
    SimpleModel,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
    resample,
)
from refiner.plotting import plot_raw, plot_n_ratio, plot_w, plot_w2, plot_training

In [None]:
# Notebook-wide configuration.
# When False, the training cells load previously saved models from
# `output_dir` instead of fitting new ones.
retrain = False
# Every figure and model file path below is built by appending to this prefix.
output_dir = "results/weight_shape/"
# 120 evenly spaced edges on [0, 3] (i.e. 119 histogram bins).
bins = np.linspace(0, 3, num=120)

In [None]:
def weight_function(x):
    """Piecewise weight profile evaluated at ``x``.

    The value is a constant baseline of 0.1 plus three localized shapes:

    * a box of height 0.5 on the open interval (0, 0.5);
    * a triangle rising linearly from 0 at x = 0.75 to 0.5 at x = 1.25 and
      falling back to 0 at x = 1.75;
    * a parabolic bump ``2 * (0.5**2 - (x - 2.5)**2)`` on (2, 3), which peaks
      at 0.5 for x = 2.5.

    Only comparisons, ``&`` and arithmetic are used, so the function accepts
    a Python scalar as well as a NumPy array (evaluated element-wise).

    Args:
        x: Position(s) at which to evaluate the profile.

    Returns:
        The weight at ``x``, with the same shape as ``x``.
    """
    box = 0.5 * ((x > 0) & (x < 0.5))
    rising_flank = (x - 0.75) * ((x > 0.75) & (x < 1.25))
    # The falling flank is closed on the left so the apex x == 1.25 belongs to
    # the triangle; with strict inequalities on both flanks the apex fell
    # through to the bare baseline. All other interval boundaries need no such
    # care because their terms are exactly zero there.
    falling_flank = (-x + 1.75) * ((x >= 1.25) & (x < 1.75))
    bump = 2 * (0.5**2 - (x - 2.5) ** 2) * ((x > 2) & (x < 3))
    return 0.1 + box + rising_flank + falling_flank + bump

In [None]:
# Generate the dataset, passing `weight_function` as the positive-sample function.
# `data` keeps the returned object whole so it can be star-unpacked into the
# helper functions below; the four components are also bound individually.
data = create_data_function(
    n_pos=10_000_000,
    function_pos=weight_function,
)
pos, neg, pos_weights, neg_weights = data

# Plot the raw samples with the shared binning.
plot_raw(
    data=data,
    bins=bins,
    path=output_dir + "raw.pdf",
)

In [None]:
# Training hyperparameters shared by the reweighter and the refiner.
epochs = 10
batch_size = 1024

In [None]:
# Reweighter: fit a fresh model when `retrain` is set, otherwise restore the
# Keras model previously saved under `output_dir`.
reweighter = SimpleModel()
if retrain:
    # Run the (potentially expensive) preparation once and derive both splits
    # from the very same arrays, instead of preparing the full dataset twice.
    # `tuple(...)` materialises the result so it can be star-unpacked twice.
    prepared_reweighter = tuple(prepare_data_reweighter(*data))
    x_train, y_train, w_train = get_train(*prepared_reweighter)
    validation_data = get_val(*prepared_reweighter)

    reweighter.compile(
        n_train=x_train.shape[0],
        epochs=epochs,
        batch_size=batch_size,
        # NOTE(review): presumably the (initial, final) learning rate of a
        # schedule — confirm against SimpleModel.compile.
        learning_rate=(0.001, 0.000001),
    )
    logger = reweighter.fit(
        x_train,
        y_train,
        sample_weight=w_train,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
    )

    plot_training(logger.history, title="Reweighter", path=output_dir + "training_reweighter.pdf")
    # Persist the trained network so later runs can skip training.
    reweighter.model.save(output_dir + "reweighter.keras")
else:
    reweighter.model = tf.keras.models.load_model(output_dir + "reweighter.keras")

In [None]:
# Refiner: fit a fresh model when `retrain` is set, otherwise restore the
# Keras model previously saved under `output_dir`.
refiner = SimpleModel()
if retrain:
    # Run the (potentially expensive) preparation once and derive both splits
    # from the very same arrays, instead of preparing the full dataset twice.
    # `tuple(...)` materialises the result so it can be star-unpacked twice.
    prepared_refiner = tuple(prepare_data_refiner(*data))
    x_train, y_train, w_train = get_train(*prepared_refiner)
    validation_data = get_val(*prepared_refiner)

    refiner.compile(
        n_train=x_train.shape[0],
        epochs=epochs,
        batch_size=batch_size,
        # NOTE(review): presumably the (initial, final) learning rate of a
        # schedule — confirm against SimpleModel.compile.
        learning_rate=(0.001, 0.000001),
    )
    logger = refiner.fit(
        x_train,
        y_train,
        sample_weight=w_train,
        validation_data=validation_data,
        epochs=epochs,
        batch_size=batch_size,
    )

    plot_training(logger.history, title="Refiner", path=output_dir + "training_refiner.pdf")
    # Persist the trained network so later runs can skip training.
    refiner.model.save(output_dir + "refiner.keras")
else:
    refiner.model = tf.keras.models.load_model(output_dir + "refiner.keras")

In [None]:
# Build the three validation-split datasets that the comparison plots consume:
# the natural data, the reweighter output and the refiner output.
# The tuple elements are evaluated left to right, i.e. in the same order as
# three consecutive assignments would be.
data_natural, data_reweighter, data_refiner = (
    get_val(*prepare_data_natural(*data)),
    get_val(*apply_reweighter(*data, reweighter=reweighter)),
    get_val(*apply_refiner(*data, refiner=refiner)),
)

In [None]:
# Compare the natural, reweighted and refined datasets using the shared
# binning, and write the figure to counts.pdf.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    path=output_dir + "counts.pdf",
)

In [None]:
# Weight plot for the three datasets, written to weights.pdf.
# The binning has 22 edges from -0.025 to 1.025, i.e. 21 bins of width 0.05.
plot_w(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=np.linspace(-0.025, 1.025, 22),
    path=output_dir + "weights.pdf",
)