In [None]:
import os

import numpy as np
import tensorflow as tf

from refiner.data import create_data_gaussian, prepare_data_natural
from refiner.model import (
    SimpleModel,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
    resample,
)
from refiner.plotting import plot_raw, plot_n_ratio, plot_w, plot_w2, plot_training

In [None]:
# Notebook-wide configuration.
# output_dir: prefix (with trailing slash) for every PDF path built in this
#             notebook via `output_dir + "<name>.pdf"`.
# bins:       100 evenly spaced values on [-3, 3], passed as `bins` to the
#             plotting helpers.
output_dir = "results/gauss_hard/"
bins = np.linspace(-3, 3, 100)

# Create the output directory up front so the first plot does not fail on a
# fresh checkout after the (expensive) data generation has already run.
os.makedirs(output_dir, exist_ok=True)

# NOTE(review): no random seed is set in this notebook, so data generation,
# training and resampling are presumably not reproducible run-to-run.
# TODO confirm how the refiner helpers draw random numbers and seed accordingly.

In [None]:
# Generate the toy sample. `data` keeps the full return value so it can be
# star-unpacked into the prepare_*/apply_* helpers later; the four components
# are also bound to individual names.
data = create_data_gaussian(10_000_000, neg_frac=0.09, neg_scale=0.1)
pos, neg, pos_weights, neg_weights = data

# Plot the raw sample with the shared binning and save it as raw.pdf.
plot_raw(
    data=data,
    bins=bins,
    path=output_dir + "raw.pdf",
)

In [None]:
# Training hyperparameters, passed to both compile() and fit() of each model.
epochs = 10
batch_size = 1024

In [None]:
# Train the reweighter: a SimpleModel fitted on the reweighter-prepared data.
#
# Prepare the inputs once and derive both splits from the same result. Calling
# prepare_data_reweighter separately for each split repeats the work on the
# full sample and, if that helper is not deterministic, would not guarantee
# that the training and validation splits come from the same arrays.
prepared = tuple(prepare_data_reweighter(*data))  # tuple() so it can be unpacked twice
x_train, y_train, w_train = get_train(*prepared)
validation_data = get_val(*prepared)
del prepared  # temporary only; keep the notebook namespace clean

reweighter = SimpleModel()
reweighter.compile(
    n_train=x_train.shape[0],
    epochs=epochs,
    batch_size=batch_size,
    # presumably (initial, final) learning rate of a schedule -- TODO confirm
    # against SimpleModel.compile
    learning_rate=(0.001, 0.00001),
)
logger = reweighter.fit(
    x_train,
    y_train,
    sample_weight=w_train,
    validation_data=validation_data,
    epochs=epochs,
    batch_size=batch_size,
)

# Save the training curves recorded in logger.history.
plot_training(logger.history, title="Reweighter", path=output_dir + "training_reweighter.pdf")

In [None]:
# Train the refiner: a SimpleModel fitted on the refiner-prepared data.
#
# Prepare the inputs once and derive both splits from the same result. Calling
# prepare_data_refiner separately for each split repeats the work on the full
# sample and, if that helper is not deterministic, would not guarantee that
# the training and validation splits come from the same arrays.
prepared = tuple(prepare_data_refiner(*data))  # tuple() so it can be unpacked twice
x_train, y_train, w_train = get_train(*prepared)
validation_data = get_val(*prepared)
del prepared  # temporary only; keep the notebook namespace clean

refiner = SimpleModel()
refiner.compile(
    n_train=x_train.shape[0],
    epochs=epochs,
    batch_size=batch_size,
    # presumably (initial, final) learning rate of a schedule -- TODO confirm
    # against SimpleModel.compile
    learning_rate=(0.001, 0.00001),
)
logger = refiner.fit(
    x_train,
    y_train,
    sample_weight=w_train,
    validation_data=validation_data,
    epochs=epochs,
    batch_size=batch_size,
)

# Save the training curves recorded in logger.history.
plot_training(logger.history, title="Refiner", path=output_dir + "training_refiner.pdf")

In [None]:
# Build the three samples compared in the plots that follow, each reduced to
# its validation part via get_val:
#   - the natural sample,
#   - the sample after applying the trained reweighter,
#   - the sample after applying the trained refiner.
data_natural = get_val(
    *prepare_data_natural(*data)
)
data_reweighter = get_val(
    *apply_reweighter(*data, reweighter=reweighter)
)
data_refiner = get_val(
    *apply_refiner(*data, refiner=refiner)
)

In [None]:
# plot_n_ratio for the natural, reweighted and refined samples using the
# shared binning; saved as counts.pdf.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    path=output_dir + "counts.pdf",
)

In [None]:
# plot_w for the three samples; uses its own binning from -1 up to (but not
# including) 1.2 in steps of 0.1. Saved as weights.pdf.
plot_w(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=np.arange(-1, 1.2, 0.1),
    path=output_dir + "weights.pdf",
)

In [None]:
# plot_w2 for the three samples using the shared binning; saved as
# variances.pdf.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    path=output_dir + "variances.pdf",
)

In [None]:
# Resample the reweighted and the refined validation samples.
data_reweighter_resampled = resample(*data_reweighter)
data_refiner_resampled = resample(*data_refiner)

# Display the length of the first component of each resampled sample as a
# (reweighter, refiner) pair.
tuple(
    len(sample[0])
    for sample in (data_reweighter_resampled, data_refiner_resampled)
)

In [None]:
# plot_n_ratio again, now with the resampled reweighter/refiner samples
# against the unchanged natural sample; saved as counts_resampled.pdf.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    path=output_dir + "counts_resampled.pdf",
)

In [None]:
# plot_w for the resampled samples; the binning here runs from -1 up to (but
# not including) 3.0 in steps of 0.2. Saved as weights_resampled.pdf.
plot_w(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=np.arange(-1, 3.0, 0.2),
    path=output_dir + "weights_resampled.pdf",
)

In [None]:
# plot_w2 for the resampled samples using the shared binning; saved as
# variances_resampled.pdf.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    path=output_dir + "variances_resampled.pdf",
)