In [None]:
import os

import numpy as np
import tensorflow as tf

from data import create_data_gaussian, prepare_data_natural
from model import (
    simple_model,
    get_train,
    get_val,
    prepare_data_reweighter,
    prepare_data_refiner,
    apply_reweighter,
    apply_refiner,
    resample,
)
from plotting import plot_raw, plot_n_ratio, plot_w, plot_w2, plot_training

In [None]:
# Directory for every figure of this run. The label must match the
# dimensionality passed as ``shape=`` in the data cell (currently (3,));
# otherwise results of different configurations overwrite each other.
output_dir = "results/gauss_3dim/"
os.makedirs(output_dir, exist_ok=True)

# Histogram edges from -3 to 3 in steps of 0.1 (61 edges, 60 bins).
# np.linspace pins both endpoints exactly; np.arange with a float step can
# gain or lose an edge through rounding.
bins = np.linspace(-3, 3, 61)

# Fix the global RNGs so Restart-&-Run-All reproduces data and training.
# NOTE(review): assumes create_data_gaussian draws from NumPy's global RNG —
# verify in data.py.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
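# # Calculate normalization
# # Normalization of an isotropic k-dim Gaussian with std sigma: (2*pi*sigma**2)**(-k/2).
# A = lambda k, sigma: 1 / np.sqrt((2*np.pi)**k*sigma**(k*2))
# # A(3, 0.5) / A(3, 1) = 0.5**-3 = 8, so the expression below gives 1 / (8 + 1) = 1/9,
# # the neg_frac used in the data cell. NOTE(review): presumably the largest negative
# # fraction for which the net density stays >= 0 at the origin, assuming pos ~ N(0, 1),
# # neg ~ N(0, neg_scale**2) and equal-magnitude weights — confirm against data.py.
# neg_frac = 1 / (A(3, 0.5) / A(3, 1) + 1)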

In [None]:
# Generate the toy sample; the returned 4-tuple is kept whole as ``data``
# (passed on with ``*data``) and also unpacked into its components.
# Alternative 5-dim setting used previously:
#   10_000_000, neg_frac=0.07, neg_scale=0.5, shape=(5,)
data = create_data_gaussian(
    10_000_000,
    neg_frac=1/9,
    neg_scale=0.5,
    shape=(3,),
)
pos, neg, pos_weights, neg_weights = data
plot_raw(data=data, bins=bins, path=output_dir + "raw.pdf")

In [None]:
# Reweighter: binary classifier trained on the reweighter-prepared sample;
# it is consumed later by apply_reweighter.
x_train, y_train, w_train = get_train(*prepare_data_reweighter(*data))

# Take the input shape from the training data rather than hard-coding it, so
# changing ``shape=`` in the data cell cannot silently desync the model.
# NOTE(review): assumes x_train is array-like of shape (n_samples, *features).
reweighter = simple_model(input_shape=tuple(x_train.shape[1:]))
reweighter.compile(
    optimizer="rmsprop",
    # from_logits=False: presumably simple_model ends in a sigmoid — confirm.
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# NOTE(review): Keras carves validation_split from the *end* of the arrays
# before shuffling; if get_train returns class-sorted rows, the validation
# set is one-sided — confirm that get_train shuffles.
history_reweighter = reweighter.fit(
    x_train,
    y_train,
    sample_weight=w_train,
    epochs=1,
    validation_split=0.2,
    batch_size=256,
)
plot_training(history_reweighter, title="Reweighter", path=output_dir + "training_reweighter.pdf")

In [None]:
# Refiner: binary classifier trained on the refiner-prepared sample;
# it is consumed later by apply_refiner.
x_train, y_train, w_train = get_train(*prepare_data_refiner(*data))

# Take the input shape from the training data rather than hard-coding it, so
# changing ``shape=`` in the data cell cannot silently desync the model.
# NOTE(review): assumes x_train is array-like of shape (n_samples, *features).
refiner = simple_model(input_shape=tuple(x_train.shape[1:]))
refiner.compile(
    optimizer="rmsprop",
    # from_logits=False: presumably simple_model ends in a sigmoid — confirm.
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# NOTE(review): Keras carves validation_split from the *end* of the arrays
# before shuffling; if get_train returns class-sorted rows, the validation
# set is one-sided — confirm that get_train shuffles.
history_refiner = refiner.fit(
    x_train,
    y_train,
    sample_weight=w_train,
    epochs=1,
    validation_split=0.2,
    batch_size=256,
)
plot_training(history_refiner, title="Refiner", path=output_dir + "training_refiner.pdf")

In [None]:
# Validation portion of the sample under three treatments: the raw weighted
# data, the reweighter applied, and the refiner applied. Elements are
# evaluated left to right, in the same order as three separate statements.
data_natural, data_reweighter, data_refiner = (
    get_val(*prepare_data_natural(*data)),
    get_val(*apply_reweighter(*data, reweighter=reweighter)),
    get_val(*apply_refiner(*data, refiner=refiner)),
)

In [None]:
# Compare binned counts of the three treatments against the natural sample.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    path=output_dir + "counts.pdf",
)

In [None]:
# Weight distributions; these get their own bin edges (-1 upward in steps of
# 0.1) instead of the shared ``bins``.
plot_w(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=np.arange(-1, 1.2, 0.1),
    path=output_dir + "weights.pdf",
)

In [None]:

def check_refiner(vals, refiner, analytic_ratio=None, ax=None):
    """Plot the ratio implied by the refiner, optionally next to a reference.

    The refiner is compiled with ``BinaryCrossentropy(from_logits=False)``,
    so its ``predict`` output is treated as a probability ``p`` and the
    implied ratio is ``1 / p - 1``.

    Args:
        vals: Points passed to ``refiner.predict`` and used as the x-axis.
            NOTE(review): the refiner in this notebook is built for 3-dim
            inputs, while a plot against ``vals`` is only readable for a 1-D
            slice — TODO confirm intended usage.
        refiner: Any object with a ``predict(vals)`` method (e.g. the Keras
            model trained above).
        analytic_ratio: Optional reference ratio evaluated at the same
            ``vals`` (e.g. an analytic |neg|/pos density ratio). When
            omitted, only the DNN ratio is drawn.
        ax: Optional matplotlib Axes to draw on; a new figure is created
            when omitted.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        # pyplot is not imported at the top of this notebook; import it only
        # when a new figure actually has to be created.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots()

    vals = np.asarray(vals)
    ratio = 1 / np.asarray(refiner.predict(vals)) - 1

    ax.plot(vals, ratio, label="DNN Ratio")
    if analytic_ratio is not None:
        # The reference must be evaluated at ``vals``: raw sample arrays such
        # as the notebook's global ``pos``/``neg`` are not aligned with them.
        ax.plot(vals, np.asarray(analytic_ratio), label="Analytic Ratio")
    ax.set(xlabel="x", ylabel="ratio")
    ax.legend()
    return ax


In [None]:
# Per-bin variances of the three treatments, on the shared ``bins``.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter,
    refiner=data_refiner,
    bins=bins,
    path=output_dir + "variances.pdf",
)

In [None]:
# Resample both weighted validation sets (reweighter first, then refiner)
# and display the length of each result's first component as a tuple.
data_reweighter_resampled, data_refiner_resampled = (
    resample(*weighted) for weighted in (data_reweighter, data_refiner)
)
tuple(len(resampled[0]) for resampled in (data_reweighter_resampled, data_refiner_resampled))

In [None]:
# Same counts comparison as before, now with the resampled datasets.
plot_n_ratio(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    path=output_dir + "counts_resampled.pdf",
)

In [None]:
# Weight distributions after resampling; wider custom edges (-1 upward in
# steps of 0.2) than the pre-resampling weight plot.
plot_w(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=np.arange(-1, 3.0, 0.2),
    path=output_dir + "weights_resampled.pdf",
)

In [None]:
# Per-bin variances after resampling, on the shared ``bins``.
plot_w2(
    data=data_natural,
    reweighter=data_reweighter_resampled,
    refiner=data_refiner_resampled,
    bins=bins,
    path=output_dir + "variances_resampled.pdf",
)