In [None]:
import os

import numpy as np
import tensorflow as tf

from refiner.data import load_data_tt, prepare_data_natural
from refiner.model import (
    simple_model,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
    resample,
)
from refiner.plotting import plot_raw, plot_n_ratio, plot_w, plot_w2, plot_training

In [None]:
# All figures are written below this directory; create it up front so the
# plotting calls do not fail on a fresh checkout.
output_dir = "results/tt/"
os.makedirs(output_dir, exist_ok=True)

# Seed the global RNGs so training (and any NumPy-based sampling) is
# reproducible under Restart & Run All.
# NOTE(review): assumes refiner.model's shuffling/resampling draws from
# NumPy's global RNG — confirm.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Histogram bin edges for the plotted observable. np.arange excludes the stop
# value, so the final edge lies below 3.2.
bins = np.arange(0, 3.2, 0.2)


def transform(x):
    """Return column 0 of a 2-D array (the observable that gets histogrammed)."""
    return x[:, 0]

In [None]:
# Load the tt sample. load_data_tt yields four items: they are kept together
# as `data` (splatted into the helpers below) and also bound individually.
# NOTE(review): the names suggest positive/negative events and their
# weights — confirm against refiner.data.
data = load_data_tt()
pos, neg, pos_weights, neg_weights = data
plot_raw(
    data=data,
    bins=bins,
    transform=transform,
    path=f"{output_dir}raw.pdf",
)

In [None]:
# Reweighter: a binary classifier fitted on the reweighter-prepared training
# split; it is later consumed by apply_reweighter.
x_train_rw, y_train_rw, w_train_rw = get_train(*prepare_data_reweighter(*data))

# Size the input layer from the data instead of hard-coding the feature count.
reweighter = simple_model(input_shape=tuple(np.shape(x_train_rw)[1:]))
# from_logits=False: assumes simple_model ends in a sigmoid — TODO confirm.
reweighter.compile(
    optimizer="rmsprop",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
# NOTE(review): Keras takes validation_split from the *last* 20% of the arrays
# before shuffling; if get_train returns class-ordered data the validation
# metrics will be skewed — confirm get_train shuffles.
history_reweighter = reweighter.fit(
    x_train_rw,
    y_train_rw,
    sample_weight=w_train_rw,
    epochs=5,
    validation_split=0.2,
    batch_size=256,
)
plot_training(history_reweighter, title="Reweighter", path=f"{output_dir}training_reweighter.pdf")

In [None]:
# Refiner: a binary classifier fitted on the refiner-prepared training split;
# it is later consumed by apply_refiner.
x_train_rf, y_train_rf, w_train_rf = get_train(*prepare_data_refiner(*data))

# Size the input layer from the data instead of hard-coding the feature count.
refiner = simple_model(input_shape=tuple(np.shape(x_train_rf)[1:]))
# from_logits=False: assumes simple_model ends in a sigmoid — TODO confirm.
refiner.compile(
    optimizer="rmsprop",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
# NOTE(review): Keras takes validation_split from the *last* 20% of the arrays
# before shuffling; if get_train returns class-ordered data the validation
# metrics will be skewed — confirm get_train shuffles.
history_refiner = refiner.fit(
    x_train_rf,
    y_train_rf,
    sample_weight=w_train_rf,
    epochs=5,
    validation_split=0.2,
    batch_size=256,
)
plot_training(history_refiner, title="Refiner", path=f"{output_dir}training_refiner.pdf")

In [None]:
# Build three validation-split views of the same sample: the untouched one,
# one passed through the trained reweighter, and one passed through the
# trained refiner. Evaluated left to right, as in separate statements.
data_natural, data_reweighter, data_refiner = (
    get_val(*prepare_data_natural(*data)),
    get_val(*apply_reweighter(*data, reweighter=reweighter)),
    get_val(*apply_refiner(*data, refiner=refiner)),
)

In [None]:
# Compare natural / reweighted / refined validation samples on the observable
# (presumably binned counts plus ratios) — written to counts.pdf.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    transform=transform,
    path=f"{output_dir}counts.pdf",
)

In [None]:
# Weight distributions of the three validation samples. These bins span weight
# values (np.arange excludes 3.0, so the last edge is below it), not the
# observable, hence no transform is passed.
plot_w(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=np.arange(-1.2, 3.0, 0.2),
    path=f"{output_dir}weights.pdf",
)

In [None]:
# Per-bin variance comparison (presumably sums of squared weights) of the three
# validation samples — written to variances.pdf.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    transform=transform,
    path=f"{output_dir}variances.pdf",
)

In [None]:
# Resample the reweighted and refined validation sets (reweighter first).
# NOTE(review): resample is presumably stochastic — confirm it is seeded.
data_reweighter_resampled, data_refiner_resampled = (
    resample(*data_reweighter),
    resample(*data_refiner),
)
# Displayed: length of the first element of each resampled set.
tuple(len(resampled[0]) for resampled in (data_reweighter_resampled, data_refiner_resampled))

In [None]:
# Same comparison as counts.pdf, but using the resampled reweighter/refiner
# sets against the untouched natural sample.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    transform=transform,
    path=f"{output_dir}counts_resampled.pdf",
)

In [None]:
# Weight distributions after resampling; same weight-value bins as
# weights.pdf (np.arange excludes 3.0), no observable transform.
plot_w(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=np.arange(-1.2, 3.0, 0.2),
    path=f"{output_dir}weights_resampled.pdf",
)

In [None]:
# Variance comparison after resampling, against the untouched natural sample.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    transform=transform,
    path=f"{output_dir}variances_resampled.pdf",
)