In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle

from matplotlib import gridspec
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Layer, concatenate
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras.layers import BatchNormalization


# Notebook-wide matplotlib defaults: base font size 20 and a serif font family.
plt.rc('font', size=20)
plt.rcParams["font.family"] = "serif"

In [None]:
# Observable to process. It is used as the key prefix when reading arrays from
# the .npz file (obs + stream suffix) and in the plot title / output filename.
# Presumably one of the names in `substructure_variables` — TODO confirm.
obs = 'q'

In [None]:
# Load the arrays for the chosen observable and standardise them.
data = np.load('npfiles/rawdata.npz')
substructure_variables = ['pT', 'w', 'q', 'm', 'r', 'tau1s', 'tau2s']
data_streams = ['_true', '_true_alt', '_reco', '_reco_alt']
n_variables = len(substructure_variables)

normalize = True
n_events = 150000  # only the first n_events entries of each stream are kept

# Bind the four streams to explicit names (same order as `data_streams`)
# rather than injecting them through globals(), so the definitions of
# x_true / x_true_alt / x_reco / x_reco_alt are visible to readers and linters.
x_true, x_true_alt, x_reco, x_reco_alt = (
    data[obs + stream][:n_events] for stream in data_streams
)

# Every stream is shifted and scaled with the mean / std of x_true_alt, so all
# four share one affine transform (undone later via `* xs + xm`).
xm, xs = (x_true_alt.mean(), x_true_alt.std()) if normalize else (0, 1)
x_true, x_true_alt, x_reco, x_reco_alt = (
    (values - xm) / xs for values in (x_true, x_true_alt, x_reco, x_reco_alt)
)

In [None]:
class MyLayer(Layer):
    """Two-parameter reweighting layer: ``exp(lambda0 * x + lambda1 * x**2)``.

    Both trainable scalars (``lambda0`` and ``lambda1``) start from the same
    constant ``myc``.

    Args:
        myc: Initial value used for both trainable parameters.
        **kwargs: Forwarded unchanged to the Keras ``Layer`` constructor.
    """

    def __init__(self, myc, **kwargs):
        # Initialise the Keras base class before attaching any state to self.
        super(MyLayer, self).__init__(**kwargs)
        self.myinit = myc

    def build(self, input_shape):
        """Create the two trainable scalar weights."""
        self._lambda0 = self.add_weight(name='lambda0',
                                        shape=(1,),
                                        initializer=tf.keras.initializers.Constant(self.myinit),
                                        trainable=True)
        self._lambda1 = self.add_weight(name='lambda1',
                                        shape=(1,),
                                        initializer=tf.keras.initializers.Constant(self.myinit),
                                        trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        """Return ``exp(lambda0 * x + lambda1 * x**2)`` element-wise."""
        return tf.exp(self._lambda0 * x + self._lambda1 * x**2)

    def get_config(self):
        """Return the layer config including ``myc``.

        Without this, a model containing the layer cannot be rebuilt from its
        config (saving / cloning), because ``myc`` is a required constructor
        argument.
        """
        config = super(MyLayer, self).get_config()
        config['myc'] = self.myinit
        return config

In [None]:
def weighted_mlc(y_true, y_pred):
    """Per-event weighted loss ``-w * (y * log(p) + (1 - y) * (1 - p))``.

    ``y_true`` carries two columns: column 0 is the class label ``y`` and
    column 1 is the event weight ``w``. ``y_pred`` is the classifier output
    ``p``, clipped away from 0 and 1 before the logarithm is taken.

    Returns the mean of the weighted per-event loss.
    """
    labels = y_true[:, 0:1]   # actual target used in the loss
    event_w = y_true[:, 1:2]  # per-event weights

    # Keep p strictly inside (0, 1) so the log cannot produce NaN / Inf.
    eps = K.epsilon()
    clipped = K.clip(y_pred, eps, 1. - eps)

    label1_term = labels * K.log(clipped)
    label0_term = (1 - labels) * (1 - clipped)
    return K.mean(-event_w * (label1_term + label0_term))

def weighted_mlc_GAN(y_true, y_pred):
    """Loss for the combined model: weighted mean of ``(1 - p)`` over label-0 events.

    ``y_pred`` carries two columns: column 0 is the discriminator output ``p``
    and column 1 is the event weight ``w``. Only events with ``y_true == 0``
    contribute; their term is ``w * (1 - p)`` divided by the summed weight of
    the label-0 events in the batch.

    Returns the mean of that per-event quantity over the batch.
    """
    weights = tf.gather(y_pred, [1], axis=1) # event weights
    y_pred = tf.gather(y_pred, [0], axis=1) # actual y_pred for loss

    # Total weight of the label-0 events, used to normalise the loss.
    # (The label-1 weight sum computed previously was never used and is gone.)
    weights_0 = K.sum((1 - y_true) * weights)

    # Clip the prediction value to prevent NaN's and Inf's
    epsilon = K.epsilon()
    y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
    t_loss = weights * ((1 - y_true) * (1 - y_pred) / weights_0)
    return K.mean(t_loss)

In [None]:
# Stack the "alt" sample on top of the nominal one, separately for the
# true-level (xvals_1) and reco-level (xvals_2) arrays.
xvals_1 = np.concatenate((x_true_alt, x_true))
xvals_2 = np.concatenate((x_reco_alt, x_reco))

# Labels: 1 for every x_true_alt entry, 0 for every x_true entry.
yvals = np.repeat([1., 0.], [len(x_true_alt), len(x_true)])

# A single call keeps the three arrays aligned event-by-event across the split.
(X_train_1, X_test_1,
 X_train_2, X_test_2,
 Y_train_1, Y_test_1) = train_test_split(xvals_1, xvals_2, yvals)

In [None]:
# Build the three Keras models used by the training loop:
#   model_generator    - a single MyLayer producing a per-event weight
#   model_discrimantor - (sic) a small dense classifier with sigmoid output
#   gan_model          - both of the above applied to one shared input

# Common starting value for both MyLayer parameters, drawn with
# np.random.normal() (no seed is set, so it differs from run to run).
myc = np.random.normal()

# Generator: one scalar in, exp(lambda0*x + lambda1*x**2) out.
mymodel_inputtest = Input(shape=(1,))
mymodel_test = MyLayer(myc)(mymodel_inputtest)
model_generator = Model(mymodel_inputtest, mymodel_test)

# Discriminator: three 50-unit ReLU layers followed by a sigmoid output unit.
inputs_disc = Input((1, ))
hidden_layer_1_disc = Dense(50, activation='relu')(inputs_disc)
hidden_layer_2_disc = Dense(50, activation='relu')(hidden_layer_1_disc)
hidden_layer_3_disc = Dense(50, activation='relu')(hidden_layer_2_disc)
outputs_disc = Dense(1, activation='sigmoid')(hidden_layer_3_disc)
model_discrimantor = Model(inputs=inputs_disc, outputs=outputs_disc)

# Compiled on its own first, with the weighted loss whose y_true carries
# [label, weight] columns.
model_discrimantor.compile(loss=weighted_mlc, optimizer='adam')

# The discriminator is marked non-trainable only after its own compile() and
# before the combined model is compiled.
# NOTE(review): this ordering presumably relies on Keras capturing `trainable`
# at compile time, so the standalone discriminator keeps training while it is
# frozen inside gan_model — confirm for the installed Keras version.
model_discrimantor.trainable = False
mymodel_gan = Input(shape=(1,))
# Output columns: [discriminator score, generator weight] — the order that
# weighted_mlc_GAN unpacks with tf.gather on indices 0 and 1.
gan_model = Model(inputs=mymodel_gan,outputs=concatenate([model_discrimantor(mymodel_gan),model_generator(mymodel_gan)]))

gan_model.compile(loss=weighted_mlc_GAN, optimizer='adam')

In [None]:
# Alternating training of the discriminator and the generator. After every
# epoch the generator's (lambda0, lambda1) are recorded as one pair.
n_epochs = 50
n_batch = 128*10
n_batches = len(X_train_1) // n_batch
lambdas = []  # one (lambda0, lambda1) tuple of Python floats per epoch

# Inputs of shape (1, 1) used to read the generator parameters back out.
probe_x1 = np.array([[1.]])
probe_x2 = np.array([[2.]])

for i in range(n_epochs):
    for j in range(n_batches):
        batch = slice(j*n_batch, (j+1)*n_batch)
        X_batch_1 = X_train_1[batch]
        # NOTE(review): X_batch_2 is sliced from X_train_1, so X_train_2 is
        # never used during training. Left unchanged because gan_model feeds
        # one shared input to both sub-models — confirm this is intended.
        X_batch_2 = X_train_1[batch]
        Y_batch = Y_train_1[batch]

        # Generator weights for the batch; label-1 events are reset to unit
        # weight, so only label-0 events carry a learned weight.
        # NOTE(review): the evaluation cell later applies the learned weights
        # to the label-1 sample instead — confirm which convention is intended.
        W_batch = np.array(model_generator(X_batch_1)).flatten()
        W_batch[Y_batch==1] = 1
        Y_batch_2 = np.stack((Y_batch, W_batch), axis=1)  # columns: [label, weight]

        model_discrimantor.train_on_batch(X_batch_2, Y_batch_2)

        # Generator step: label-0 events only, with all-zero targets.
        is_label0 = Y_batch==0
        gan_model.train_on_batch(X_batch_1[is_label0], np.zeros(np.count_nonzero(is_label0)))

    # log g(1) = lambda0 + lambda1 and log g(2) = 2*lambda0 + 4*lambda1,
    # which are solved for the two parameters below.
    lambdasum = np.log(model_generator.predict(probe_x1, verbose = 0)).item()
    lambdasum2 = np.log(model_generator.predict(probe_x2, verbose = 0)).item()
    mylambda1 = (lambdasum2-2*lambdasum)/2
    mylambda0 = lambdasum - mylambda1
    print("on epoch=",i,mylambda0,mylambda1)
    # Append ONE pair per epoch. `lambdas += [a, b]` flattened the list, so
    # later code could not unpack (lambda0, lambda1) pairs from it.
    lambdas.append((mylambda0, mylambda1))

In [None]:
# Pick the epoch whose (lambda0, lambda1) best reproduces the first two
# moments of `truth` when used to reweight `gen` on the held-out split.
truth = X_test_1[Y_test_1==0]
gen = X_test_1[Y_test_1==1]
data = X_test_2[Y_test_1==0]
sim = X_test_2[Y_test_1==1]


def _moment_weights(lambda0, lambda1):
    """Weights exp(lambda1*gen**2 + lambda0*gen), normalised to sum to len(data)."""
    unnormalised = np.exp(lambda1*gen**2 + lambda0*gen)
    return unnormalised*len(data)/np.sum(unnormalised)


# Accept both a list of (lambda0, lambda1) pairs and a flat alternating
# sequence [lambda0, lambda1, lambda0, ...]; entries may be scalars or
# single-element arrays.
lambda_pairs = np.asarray(lambdas, dtype=float).reshape(-1, 2)

total_err = []

for lambda0, lambda1 in lambda_pairs:
    weights = _moment_weights(lambda0, lambda1)
    # The subtraction belongs outside np.average: the previous lines had
    # unbalanced parentheses and passed `weights - <moment>` as the weights.
    mean_err = np.average(gen, weights=weights) - truth.mean()
    var_err = np.average(gen**2, weights=weights) - np.mean(truth**2)  # second-moment error
    # Both errors are squared so that a negative second-moment error cannot
    # lower the score.
    total_err += [mean_err**2 + var_err**2]

lambda0, lambda1 = lambda_pairs[np.argmin(total_err)]
weights = _moment_weights(lambda0, lambda1)

In [None]:
# Plot the particle-level distributions in original units, with the selected
# reweighting overlaid, and save the figure as a PDF.

# Undo the standardisation applied at load time.
truth = X_test_1[Y_test_1==0] * xs + xm
gen = X_test_1[Y_test_1==1] * xs + xm
data = X_test_2[Y_test_1==0] * xs + xm
sim = X_test_2[Y_test_1==1] * xs + xm

# Shared binning spanning the range of `truth`.
bins = np.linspace(truth.min(), truth.max(), 30)

fig, ax = plt.subplots(figsize=(14, 7))

ax.hist(truth, bins=bins, alpha=0.5, label="Truth", density=True)
ax.hist(gen, bins=bins, alpha=0.5, label="Generation", density=True)
# `weights` comes from the lambda-selection cell and is aligned with `gen`.
ax.hist(gen, bins=bins, weights=weights, histtype="step", color='black', ls=":", lw=4, label="Moment Unfolding", density=True)

ax.legend(fontsize=24)
ax.set_xlabel("z (particle level)", fontsize=24)
# NOTE(review): the histograms use density=True, so this axis shows a
# normalised density rather than raw counts — consider relabelling.
ax.set_ylabel("Counts", fontsize=24)
ax.set_title(f"Jet {obs} Data: Particle Level Histograms", fontsize=24)

# Create the output directory first so savefig does not fail on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig(f"figures/{obs}jetexample.pdf", bbox_inches='tight', transparent=True)
plt.show()