In [5]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from copy import deepcopy

from n3jet.utils import FKSPartition
from n3jet.utils.general_utils import (
    bool_convert,
    dot
)

In [36]:
# original NJet to compare against
nj_compare_dir = '/mt/batch/jbullock/Sherpa_NJet/runs/diphoton/4g2A/RAMBO/parallel_fixed/integration_grid/NJet_NJet_unit_grid_2/'
hepmc_mom_file_nj_compare = nj_compare_dir + 'full_momenta_events_1.5M_new_sherpa_cuts_PDF.npy'
events_file_nj_compare = nj_compare_dir + 'original_weights_events_1.5M_new_sherpa_cuts_PDF.npy'
trials_file_nj_compare = nj_compare_dir + 'trials_events_1.5M_new_sherpa_cuts_PDF.npy'

# NN files for reweighting
nn_dir = '/mt/batch/jbullock/Sherpa_NJet/runs/diphoton/4g2A/RAMBO/parallel_fixed/integration_grid/nn_NJet_unit_grid_2_delta_0001/'
reweight_raw_dir = nn_dir + 'reweight_raw/'
reweigth_points = nn_dir + 'reweight_near_05.npy'
mom_file_nn = nn_dir + 'unshuffle_momenta_events_1.5M_new_sherpa_cuts_PDF.npy'
me_file_nn = nn_dir + 'unshuffle_events_NN_1.5M_new_sherpa_cuts_PDF_loop.npy'
hepmc_mom_file_nn = nn_dir + 'full_momenta_events_1.5M_new_sherpa_cuts_PDF.npy'
events_file_nn = nn_dir + 'original_weights_events_1.5M_new_sherpa_cuts_PDF.npy'
trials_file_nn = nn_dir + 'trials_events_1.5M_new_sherpa_cuts_PDF.npy'

In [12]:
all_legs = True
delta_cut = 0.0
delta_near = 0.001

In [34]:
def len_check(x):
    if x is None:
        raise ValueError('Value is None')
    elif len(x) > 1:
        raise ValueError('length of array is > 1')
    elif len(x) < 1:
        raise ValueError('length of array is < 1')
    else:
        return x[0]
    
def reconcile(to_match, match_to):
    indices = []
    for idx, i in tqdm(enumerate(to_match)):
        try_all = np.where(np.all(match_to==i, axis=(1, 2)))[0]
        if len(try_all) == 0:
            try_2 = np.where(np.all(match_to[:,2]==i[2], axis=(1)))[0]
            if len(try_2) == 0:
                try_3 = np.where(np.all(match_to[:,3]==i[3], axis=(1)))[0]
                if len(try_3) == 0:
                    try_4 = np.where(np.all(match_to[:,4]==i[4], axis=(1)))[0]
                    if len(try_4) == 0:
                        print ('Struggling on index: {}'.format(idx))
                        indices.append(None)
                    else:
                        indices.append(try_4)
                else:
                    indices.append(try_3)
            else:
                indices.append(try_2)
        else:
            indices.append(try_all)
    return indices

def index_process(x):
    indices = []
    for i in tqdm(x):
        index = len_check(i)
        indices.append(index)
        
    return indices

## Load data

In [13]:
chosen_max = np.load(reweigth_points)

In [15]:
reweight_data = []
for i in os.listdir(reweight_raw_dir):
    if i.split(".")[-1] == "part":
        with open(reweight_raw_dir + i, 'r') as f:
            reweight_data.append(float(f.read().split(" ")[0]))

In [16]:
assert len(chosen_max) == len(reweight_data)

In [37]:
hepmc_mom_nj_compare = np.load(hepmc_mom_file_nj_compare, allow_pickle=True)
events_nj_compare = np.load(events_file_nj_compare, allow_pickle=True)
trials_nj_compare = np.load(trials_file_nj_compare, allow_pickle=True)

mom_nn = np.load(mom_file_nn, allow_pickle=True)
me_nn = np.load(me_file_nn, allow_pickle=True)
hepmc_mom_nn = np.load(hepmc_mom_file_nn, allow_pickle=True)
events_nn = np.load(events_file_nn, allow_pickle=True)
trials_nn = np.load(trials_file_nn, allow_pickle=True)

## Cross-section check

In [21]:
xs_compare = np.sum(events_nj_compare)/np.sum(trials_nj_compare)
print ("Comparison XS is {}".format(xs_compare))

Comparison XS is 4.18227928531e-06


In [33]:
nn_xs = np.sum(events_nn)/np.sum(trials_nn)
print ("NN XS is {}".format(nn_xs)) 

NN XS is 4.38744671886e-06


## Split data

In [18]:
nn_fks_partition = FKSPartition(
    momenta = list(hepmc_mom_nn),
    labels = events_nn,
    all_legs = all_legs
)

nn_cut_momenta, nn_near_momenta, nn_cut_labels, nn_near_labels = nn_fks_partition.cut_near_split(
    delta_cut = delta_cut,
    delta_near = delta_near
)

100%|██████████| 1500000/1500000 [02:05<00:00, 11978.69it/s]


In [38]:
nn_me_fks_partition = FKSPartition(
    momenta = list(mom_nn),
    labels = me_nn,
    all_legs = all_legs
)

nn_me_cut_momenta, nn_me_near_momenta, nn_me_cut_labels, nn_me_near_labels = nn_me_fks_partition.cut_near_split(
    delta_cut = delta_cut,
    delta_near = delta_near
)


  0%|          | 0/1500340 [00:00<?, ?it/s][A
  0%|          | 2086/1500340 [00:00<01:11, 20851.87it/s][A
  0%|          | 4258/1500340 [00:00<01:10, 21104.38it/s][A
  0%|          | 6331/1500340 [00:00<01:11, 20989.12it/s][A
  1%|          | 8417/1500340 [00:00<01:11, 20948.13it/s][A
  1%|          | 10546/1500340 [00:00<01:10, 21047.22it/s][A
  1%|          | 12594/1500340 [00:00<01:11, 20873.41it/s][A
  1%|          | 14695/1500340 [00:00<01:11, 20913.19it/s][A
  1%|          | 16834/1500340 [00:00<01:10, 20900.76it/s][A
  1%|▏         | 18951/1500340 [00:00<01:10, 20979.37it/s][A
  1%|▏         | 20985/1500340 [00:01<01:11, 20782.49it/s][A
  2%|▏         | 23005/1500340 [00:01<01:11, 20557.05it/s][A
  2%|▏         | 25061/1500340 [00:01<01:11, 20555.30it/s][A
  2%|▏         | 27174/1500340 [00:01<01:11, 20722.14it/s][A
  2%|▏         | 29341/1500340 [00:01<01:10, 20995.96it/s][A
  2%|▏         | 31428/1500340 [00:01<01:10, 20723.58it/s][A
  2%|▏         | 33524/150

 19%|█▊        | 277766/1500340 [00:13<00:58, 20889.26it/s][A
 19%|█▊        | 279857/1500340 [00:13<00:58, 20849.78it/s][A
 19%|█▉        | 281944/1500340 [00:13<00:58, 20840.53it/s][A
 19%|█▉        | 284085/1500340 [00:13<00:57, 21007.30it/s][A
 19%|█▉        | 286257/1500340 [00:13<00:57, 21213.65it/s][A
 19%|█▉        | 288391/1500340 [00:13<00:57, 21249.18it/s][A
 19%|█▉        | 290528/1500340 [00:13<00:56, 21284.71it/s][A
 20%|█▉        | 292671/1500340 [00:13<00:56, 21325.74it/s][A
 20%|█▉        | 294805/1500340 [00:14<00:56, 21314.54it/s][A
 20%|█▉        | 296937/1500340 [00:14<00:56, 21288.34it/s][A
 20%|█▉        | 299067/1500340 [00:14<00:56, 21174.17it/s][A
 20%|██        | 301213/1500340 [00:14<00:56, 21255.90it/s][A
 20%|██        | 303367/1500340 [00:14<00:56, 21338.82it/s][A
 20%|██        | 305502/1500340 [00:14<00:56, 21093.26it/s][A
 21%|██        | 307633/1500340 [00:14<00:56, 21155.28it/s][A
 21%|██        | 309750/1500340 [00:14<00:56, 21043.53i

 37%|███▋      | 554996/1500340 [00:26<00:45, 20888.45it/s][A
 37%|███▋      | 557166/1500340 [00:26<00:44, 21122.75it/s][A
 37%|███▋      | 559280/1500340 [00:26<00:45, 20908.32it/s][A
 37%|███▋      | 561412/1500340 [00:26<00:44, 21028.11it/s][A
 38%|███▊      | 563562/1500340 [00:26<00:44, 21167.10it/s][A
 38%|███▊      | 565680/1500340 [00:26<00:44, 21128.97it/s][A
 38%|███▊      | 567856/1500340 [00:27<00:43, 21310.64it/s][A
 38%|███▊      | 570048/1500340 [00:27<00:43, 21487.40it/s][A
 38%|███▊      | 572198/1500340 [00:27<00:43, 21480.56it/s][A
 38%|███▊      | 574347/1500340 [00:27<00:43, 21383.60it/s][A
 38%|███▊      | 576548/1500340 [00:27<00:42, 21567.67it/s][A
 39%|███▊      | 578719/1500340 [00:27<00:42, 21609.21it/s][A
 39%|███▊      | 580881/1500340 [00:27<00:42, 21436.10it/s][A
 39%|███▉      | 583026/1500340 [00:27<00:43, 21310.42it/s][A
 39%|███▉      | 585158/1500340 [00:27<00:43, 21140.84it/s][A
 39%|███▉      | 587273/1500340 [00:27<00:43, 21038.47i

 55%|█████▌    | 830113/1500340 [00:39<00:31, 21432.85it/s][A
 55%|█████▌    | 832322/1500340 [00:39<00:30, 21625.28it/s][A
 56%|█████▌    | 834515/1500340 [00:39<00:30, 21715.15it/s][A
 56%|█████▌    | 836688/1500340 [00:39<00:30, 21653.51it/s][A
 56%|█████▌    | 838855/1500340 [00:40<00:30, 21649.89it/s][A
 56%|█████▌    | 841021/1500340 [00:40<00:30, 21574.82it/s][A
 56%|█████▌    | 843179/1500340 [00:40<00:30, 21516.86it/s][A
 56%|█████▋    | 845332/1500340 [00:40<00:30, 21356.55it/s][A
 56%|█████▋    | 847469/1500340 [00:40<00:30, 21242.43it/s][A
 57%|█████▋    | 849594/1500340 [00:40<00:31, 20824.80it/s][A
 57%|█████▋    | 851711/1500340 [00:40<00:30, 20926.17it/s][A
 57%|█████▋    | 853806/1500340 [00:40<00:31, 20328.46it/s][A
 57%|█████▋    | 855974/1500340 [00:40<00:31, 20713.87it/s][A
 57%|█████▋    | 858051/1500340 [00:40<00:31, 20539.29it/s][A
 57%|█████▋    | 860109/1500340 [00:41<00:31, 20427.73it/s][A
 57%|█████▋    | 862236/1500340 [00:41<00:30, 20670.56i

 74%|███████▎  | 1103857/1500340 [00:52<00:19, 20421.34it/s][A
 74%|███████▎  | 1105900/1500340 [00:52<00:19, 20237.12it/s][A
 74%|███████▍  | 1107925/1500340 [00:52<00:19, 20051.58it/s][A
 74%|███████▍  | 1109950/1500340 [00:53<00:19, 20107.82it/s][A
 74%|███████▍  | 1111962/1500340 [00:53<00:19, 19871.98it/s][A
 74%|███████▍  | 1113951/1500340 [00:53<00:19, 19480.30it/s][A
 74%|███████▍  | 1116009/1500340 [00:53<00:19, 19796.12it/s][A
 75%|███████▍  | 1118091/1500340 [00:53<00:19, 20090.76it/s][A
 75%|███████▍  | 1120282/1500340 [00:53<00:18, 20603.82it/s][A
 75%|███████▍  | 1122416/1500340 [00:53<00:18, 20818.97it/s][A
 75%|███████▍  | 1124589/1500340 [00:53<00:17, 21081.71it/s][A
 75%|███████▌  | 1126723/1500340 [00:53<00:17, 21155.96it/s][A
 75%|███████▌  | 1128844/1500340 [00:53<00:17, 21169.53it/s][A
 75%|███████▌  | 1130963/1500340 [00:54<00:17, 20802.36it/s][A
 76%|███████▌  | 1133047/1500340 [00:54<00:17, 20488.05it/s][A
 76%|███████▌  | 1135184/1500340 [00:54<

 92%|█████████▏| 1374663/1500340 [01:05<00:05, 20982.46it/s][A
 92%|█████████▏| 1376815/1500340 [01:05<00:05, 21139.88it/s][A
 92%|█████████▏| 1378930/1500340 [01:05<00:05, 20945.60it/s][A
 92%|█████████▏| 1381043/1500340 [01:05<00:05, 20998.46it/s][A
 92%|█████████▏| 1383191/1500340 [01:06<00:05, 21131.97it/s][A
 92%|█████████▏| 1385305/1500340 [01:06<00:05, 20899.50it/s][A
 92%|█████████▏| 1387411/1500340 [01:06<00:05, 20946.33it/s][A
 93%|█████████▎| 1389550/1500340 [01:06<00:05, 21076.48it/s][A
 93%|█████████▎| 1391684/1500340 [01:06<00:05, 21154.48it/s][A
 93%|█████████▎| 1393852/1500340 [01:06<00:04, 21308.94it/s][A
 93%|█████████▎| 1396040/1500340 [01:06<00:04, 21475.21it/s][A
 93%|█████████▎| 1398226/1500340 [01:06<00:04, 21586.78it/s][A
 93%|█████████▎| 1400387/1500340 [01:06<00:04, 21592.05it/s][A
 93%|█████████▎| 1402547/1500340 [01:06<00:04, 21444.52it/s][A
 94%|█████████▎| 1404693/1500340 [01:07<00:04, 21393.30it/s][A
 94%|█████████▍| 1406833/1500340 [01:07<

In [39]:
nn_near_round = np.round(nn_near_momenta, 6)
nn_me_near_round = np.round(nn_me_near_momenta, 6)

## Reweight

In [40]:
reweight_props = np.arange(0,0.6,0.1)

In [42]:
nn_near_chosen = nn_near_round[chosen_max]

In [43]:
nn_near_chosen[0]

array([[ 183.173961,    0.      ,    0.      ,  183.173961],
       [ 170.022834,    0.      ,    0.      , -170.022834],
       [ 128.988407,  -20.091576,  102.495194,   75.690639],
       [ 114.587362,  -20.016646, -107.04783 ,   35.642101],
       [ 107.445019,   41.714001,    3.308502,  -98.961751],
       [   2.176007,   -1.605778,    1.244134,    0.780137]])

In [45]:
nn_me_indices = reconcile(nn_near_chosen, nn_me_near_round)


0it [00:00, ?it/s][A
35it [00:00, 343.20it/s][A
81it [00:00, 370.75it/s][A
132it [00:00, 402.87it/s][A
189it [00:00, 441.17it/s][A
248it [00:00, 477.19it/s][A
308it [00:00, 506.35it/s][A
368it [00:00, 529.73it/s][A
420it [00:00, 525.89it/s][A
483it [00:00, 552.75it/s][A
552it [00:01, 585.55it/s][A
620it [00:01, 610.22it/s][A
682it [00:01, 589.23it/s][A
742it [00:01, 538.97it/s][A
799it [00:01, 546.85it/s][A
857it [00:01, 556.03it/s][A
916it [00:01, 563.34it/s][A
973it [00:01, 555.53it/s][A
1029it [00:01, 539.26it/s][A
1084it [00:01, 527.42it/s][A
1139it [00:02, 533.06it/s][A
1206it [00:02, 567.16it/s][A
1273it [00:02, 594.01it/s][A
1340it [00:02, 613.71it/s][A
1407it [00:02, 627.29it/s][A
1485it [00:02, 666.38it/s][A
1567it [00:02, 704.85it/s][A
1645it [00:02, 725.24it/s][A
1719it [00:02, 718.32it/s][A
1795it [00:02, 728.54it/s][A
1876it [00:03, 750.96it/s][A
1958it [00:03, 769.62it/s][A
2036it [00:03, 758.57it/s][A
2113it [00:03, 733.22it/s][A
2187it

In [46]:
assert len(nn_me_indices) == len(chosen_max)

In [47]:
nn_me_indices = index_process(nn_me_indices)


  0%|          | 0/17972 [00:00<?, ?it/s][A
100%|██████████| 17972/17972 [00:00<00:00, 444641.25it/s][A

In [123]:
nn_near_labels_reweights = []
nn_me_near_labels_max = np.sort(nn_me_near_labels)[-100]
for i in reweight_props:
    nn_near_labels_new = deepcopy(nn_near_labels)
    for jdx, j in enumerate(chosen_max[:int(i*len(nn_near_labels_new))]):
        if reweight_data[jdx] <= nn_me_near_labels_max:
        #if reweight_data[jdx]/nn_me_near_labels[nn_me_indices[jdx]] < 100:
            nn_near_labels_new[j] = nn_near_labels[j]*(reweight_data[jdx]/nn_me_near_labels[nn_me_indices[jdx]])
    nn_near_labels_reweights.append(nn_near_labels_new)

In [124]:
nn_xs_reweights = []
for i in nn_near_labels_reweights:
    new_xs = np.sum(np.concatenate((nn_cut_labels, i)))/np.sum(trials_nn)
    nn_xs_reweights.append(new_xs)
    print ("New XS is {}".format(new_xs))

New XS is 4.38744671886e-06
New XS is 4.11704632045e-06
New XS is 3.87949718042e-06
New XS is 3.65725826098e-06
New XS is 3.43824342958e-06
New XS is 3.26084784065e-06
