# Imports and Helper Functions

In [None]:
import multiprocessing
import os

# Number of CPUs this process is actually allowed to run on.
# multiprocessing.cpu_count() reports every CPU in the machine, which
# over-counts when the job is pinned to a subset of cores (batch scheduler,
# taskset, cgroups). This value is later handed to torch.set_num_threads, so
# over-counting would oversubscribe the cores we really have.
if hasattr(os, "sched_getaffinity"):
    num_available_cpus = len(os.sched_getaffinity(0))
else:
    # sched_getaffinity is not provided on every platform (e.g. macOS, Windows).
    num_available_cpus = multiprocessing.cpu_count()

print("Number of available CPUs:", num_available_cpus)

import sys

import math
import time
import tqdm

import numpy as np
import scipy as sp
from scipy import stats
from scipy.spatial import ConvexHull

import itertools
import logging
import matplotlib.pyplot as plt

import pandas as pd
import h5py

from sklearn import metrics

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
import torch.utils.data as utils

from argparse import ArgumentParser
import re

sys.path.append("../new_flows")
from flows import RealNVP, Planar, MAF
from models import NormalizingFlowModel

In [None]:
from nflows.flows.base import Flow
from nflows.flows.autoregressive import MaskedAutoregressiveFlow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform, MaskedPiecewiseQuadraticAutoregressiveTransform, MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation

In [None]:
from helper_functions import *

In [None]:
torch.cuda.empty_cache()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device =", device)

# When a GPU is in use, make new float tensors default to CUDA; otherwise just
# report that we are on the CPU. (An explicit if-statement rather than a
# conditional expression evaluated only for its side effects.)
# NOTE(review): newer PyTorch releases deprecate set_default_tensor_type in
# favour of torch.set_default_device / torch.set_default_dtype; switch once the
# minimum supported torch version allows it.
if device.type == "cuda":
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    print('cpu')

torch.set_num_threads(num_available_cpus)

print("Number of threads:", torch.get_num_threads())
print("Number of interop threads:", torch.get_num_interop_threads())

# Load Models

In [None]:
# Flow architecture. These values are also encoded in the checkpoint filename
# that the model-loading cell builds.
num_features, hidden_features = 14, 56
num_layers, num_blocks_per_layer = 4, 4

# Training-loop settings (an earlier configuration used num_iter = 10000).
num_iter, print_interval = 1000, 20

# Supported flow_type values: 'MAF', 'NSQUAD' (neural spline quadratic),
# 'NSRATQUAD' (neural spline rational quadratic).
flow_type = 'NSQUAD'

# 'BB' is special-cased by the loading cells; any other value is used as a
# subdirectory name under new_sample_flows/.
study = 'BB'

In [None]:
filename = 'Pure_NF_%s_k%s_hf%s_nbpl%s' % (flow_type, num_layers, hidden_features, num_blocks_per_layer)

# 'BB' checkpoints live directly in pure_flows/; every other study has its own
# subdirectory under new_sample_flows/.
model_dir = "pure_flows" if study == 'BB' else "new_sample_flows/%s" % study

# map_location remaps stored tensors onto the device chosen above, so
# checkpoints saved on a GPU still load on a CPU-only machine.
# SECURITY: torch.load unpickles whole model objects, which can execute
# arbitrary code -- only load checkpoints from a trusted source.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects pickled model objects; pass weights_only=False on those versions.
bkg_model = torch.load("%s/bkg_%s.pt" % (model_dir, filename), map_location=device)
sig_model = torch.load("%s/sig_%s.pt" % (model_dir, filename), map_location=device)

# Load and Process Samples

In [None]:
# num_batches arguments for the QCD background (LAPS_test) and CMS
# (LAPS_test_CMS) loaders, plus the percentage of CMS events to keep.
num_bkg_batches, num_batches, sampling_percentage = 2, 111, 100

In [None]:
# QCD background reference sample. Its per-feature mean/std (computed on the
# unnormalised array) are passed as inp_meanstd to the later loader calls --
# presumably so those samples are normalised the same way; verify in helper_functions.
bkg_data, bkg_unnorm_data, bkg_masses = LAPS_test(sample_type='qcdbkg', num_batches=num_bkg_batches)
print(bkg_data.shape)
bkg_mean, bkg_std = np.mean(bkg_unnorm_data, axis=0), np.std(bkg_unnorm_data, axis=0)

# Loss = negated first element of log_prob's return value, first for the
# background-trained flow, then for the signal-trained flow.
bkgtr_bkg_losses, sigtr_bkg_losses = (
    -flow.log_prob(bkg_data)[0].detach().cpu().numpy()
    for flow in (bkg_model, sig_model)
)
sigtr_bkg_losses = -sig_model.log_prob(bkg_data)[0].detach().cpu().numpy()

In [None]:
CMS_data, CMS_unnorm_data, CMS_masses = LAPS_test_CMS(num_batches = num_batches, inp_meanstd = (bkg_mean, bkg_std))

# Keep sampling_percentage % of the events, drawn WITHOUT replacement so each
# event appears at most once. Sorting the indices preserves the original event
# order, so sampling_percentage == 100 keeps the full sample unchanged.
# (sampling_percentage must therefore be <= 100.)
CMS_num_samples = CMS_data.shape[0]
num_kept = int(sampling_percentage * CMS_num_samples / 100)
sampling_indices = np.sort(np.random.choice(CMS_num_samples, size=num_kept, replace=False))
CMS_data = CMS_data[sampling_indices, :]
CMS_unnorm_data = CMS_unnorm_data[sampling_indices, :]
CMS_masses = CMS_masses[sampling_indices, :]

print(CMS_data.shape)
# Losses of the kept events under the background- and signal-trained flows.
bkgtr_CMS_losses = -bkg_model.log_prob(CMS_data)[0].detach().cpu().numpy()
sigtr_CMS_losses = -sig_model.log_prob(CMS_data)[0].detach().cpu().numpy()

In [None]:
# Jet masses taken from the unnormalised features: column 0 is M_j1 and
# column 7 is M_j2 (matching the order of the histogram titles).
CMS_mj1, CMS_mj2 = CMS_unnorm_data[:, 0], CMS_unnorm_data[:, 7]

In [None]:
import os

# Write one-column CSVs of the CMS jet mass and of both model losses.
# 'CMS_masses.csv' currently holds M_j1 (the ttbar-search choice); substitute
# CMS_mj2 (also used for ttbar) or CMS_masses to export a different mass.
# DataFrame.to_csv does not create missing directories, so make sure the
# output directory exists first.
os.makedirs('csv_files', exist_ok=True)

df_mass = pd.DataFrame(CMS_mj1.tolist())
df_mass.to_csv('csv_files/CMS_masses.csv')

df_bkgloss = pd.DataFrame(bkgtr_CMS_losses.tolist())
df_bkgloss.to_csv('csv_files/bkgtr_CMS_losses.csv')

df_sigloss = pd.DataFrame(sigtr_CMS_losses.tolist())
df_sigloss.to_csv('csv_files/sigtr_CMS_losses.csv')

# Master QUAK Spaces

In [None]:
# Upper axis limits (bkg-model loss on x, sig-model loss on y) for the QUAK-space plots.
x_bad_loss_cutoff = y_bad_loss_cutoff = 100

In [None]:
plt.rcParams.update({"figure.figsize": (10, 10)})  # large square figures for the QUAK-space plots

In [None]:
# Append one extra 0 entry to each CMS loss array for the heat map.
# NOTE(review): presumably this forces the hist2d binning to start at 0 so it
# lines up with the (0, cutoff) axis limits -- confirm.
temp_bkgtr_CMS_losses, temp_sigtr_CMS_losses = (
    np.append(losses, [0]) for losses in (bkgtr_CMS_losses, sigtr_CMS_losses)
)

In [None]:
# Scatter of every CMS event's loss under the two flows, viewed up to the cutoffs.
plt.scatter(bkgtr_CMS_losses, sigtr_CMS_losses, s=2, label='13 TeV CMS data')
plt.title('Testing Data QUAK Space (Scatter Plot)')
plt.xlabel('QCD Bkg Model Loss')
# For the W'->tB' (M=2000) signal model, use the label
# r"W'$\rightarrow$tB' (M=2000) Sig Model Loss" instead.
plt.ylabel(r"W'$\rightarrow$WZ Sig Model Loss")
plt.xlim(0, x_bad_loss_cutoff)
plt.ylim(0, y_bad_loss_cutoff)
plt.legend()
plt.show()

In [None]:
num_bins = 2000

# Heat map of the zero-padded CMS losses. NOTE(review): the returned counts and
# edges are stored under bkg-prefixed names even though the input is CMS data.
h_bkg, bkg_xedges, bkg_yedges, _ = plt.hist2d(
    temp_bkgtr_CMS_losses,
    temp_sigtr_CMS_losses,
    bins=num_bins,
    cmap=plt.cm.jet,
)
plt.colorbar()
plt.title('Testing Data QUAK Space (Heat Map)')
plt.xlim(0, x_bad_loss_cutoff)
plt.ylim(0, y_bad_loss_cutoff)
plt.xlabel('QCD Bkg Model Loss')
plt.ylabel(r"W'$\rightarrow$WZ Sig Model Loss")
plt.show()

# Normalized Input Variable Density Histograms

In [None]:
# num_batches for the signal LAPS_test call: 5 for the 'BB' study, 1 otherwise.
num_sig_batches = 5 if study == 'BB' else 1

In [None]:
sig_data, sig_unnorm_data, sig_masses = LAPS_test(sample_type = 'wprimesig', num_batches = num_sig_batches, inp_meanstd = (bkg_mean, bkg_std))

# Fraction of the signal sample to keep.
# NOTE(review): 0.158 ('BB') and 0.05 (other studies) are unexplained here --
# confirm what they were chosen to match.
sig_fraction = 0.158 if study == 'BB' else 0.05

# Draw WITHOUT replacement so no signal event is duplicated; sorting keeps the
# original event order.
sig_num_samples = sig_data.shape[0]
sampling_indices = np.sort(np.random.choice(sig_num_samples, size=int(sig_fraction * sig_num_samples), replace=False))

sig_data = sig_data[sampling_indices, :]
sig_unnorm_data = sig_unnorm_data[sampling_indices, :]
sig_masses = sig_masses[sampling_indices, :]

print(sig_data.shape)
# Losses of the kept signal events under the background- and signal-trained flows.
bkgtr_sig_losses = -bkg_model.log_prob(sig_data)[0].detach().cpu().numpy()
sigtr_sig_losses = -sig_model.log_prob(sig_data)[0].detach().cpu().numpy()

In [None]:
plt.rcParams.update({"figure.figsize": (5, 5)})  # smaller figures for the per-feature histograms

In [None]:
# Per-feature density histograms of the background, signal and CMS samples
# (the non-"unnorm" arrays returned by the loaders). Signal and CMS reuse the
# background's bin edges so all three curves share the same binning.
plot_titles = [
    r'$M_{j1}$', r'Jet 1 $\tau_{21}$', r'Jet 1 $\tau_{32}$', r'Jet 1 $\tau_{43}$',
    r'Jet 1 $\tau_s$', r'Jet 1 $P_b$', r'Jet 1 $n_{pf}$',
    r'$M_{j2}$', r'Jet 2 $\tau_{21}$', r'Jet 2 $\tau_{32}$', r'Jet 2 $\tau_{43}$',
    r'Jet 2 $\tau_s$', r'Jet 2 $P_b$', r'Jet 2 $n_{pf}$',
]

overlay_samples = (
    (sig_data, r"W'$\rightarrow$WZ sig samples"),
    (CMS_data, '13 TeV CMS data'),
)

for index in range(num_features):
    n, bins, patches = plt.hist(bkg_data[:, index], bins=30, histtype='step', density=True, label='QCD bkg samples')
    for overlay_data, overlay_label in overlay_samples:
        plt.hist(overlay_data[:, index], bins=bins, histtype='step', density=True, label=overlay_label)
    # Legend (placed outside the axes, to the right) only on the panels with
    # index % 7 == 4, i.e. indices 4 and 11 for the 14 features here.
    if index % 7 == 4:
        plt.legend(loc=(1.04, 0.73))
    plt.title(plot_titles[index])
    plt.show()