In [3]:
import numpy as np
import os
import torch
import uproot
from torch.utils.data import Dataset

In [13]:
class ProductionModeDataset(Dataset):
    """
    Load a minitree and turn it into a dataset for the production-mode neural networks.

    Steps performed in ``__init__``:

    0) load every branch of the first tree in the file into a 2-D numpy array
       (one row per event, one column per branch);

    1) optionally drop the input columns flagged by a correlation cut;

    2) optionally downsample the production_mode 0 and 2 events so that each keeps
       (at most) as many rows as there are qqbar (production_mode == 1) events;

    3) optionally min-max normalize the feature columns and scale the weight column;

    4) optionally shuffle and split into training and evaluation arrays.

    Args:
        :param root (string): path (directory and file name) of the minitree
        :param split (boolean): whether to split into training and evaluation sets
        :param normalize (boolean): whether to normalize the data
        :param remove (boolean): whether to remove excess non-qqbar events
        :param train (boolean): whether indexing returns training (True) or evaluation (False) rows
        :param correlation_cut (float): if > 0, load the column indices to drop from
            ``<correlation_dir>/inputs_to_remove_cut_<correlation_cut>.npy``
        :param correlation_dir (string): directory holding the correlation-cut ``.npy`` files
        :param train_percent (number): percentage of rows assigned to the training set
        :param seed (int or None): if given, shuffle with ``np.random.default_rng(seed)``;
            otherwise shuffle with the global ``np.random`` state
    """

    def __init__(self, root, split=True, normalize=True, remove=True, train=True, correlation_cut=-1.0,
                 correlation_dir="../analysis_code/results", train_percent=81, seed=None):
        # load the column indices removed by the correlation cut, if one was requested
        to_remove = np.array((), dtype=int)
        if correlation_cut > 0:
            cut_file = os.path.join(correlation_dir, "inputs_to_remove_cut_" + str(correlation_cut) + ".npy")
            to_remove = np.load(cut_file)
            print("loaded correlations... shape is " + str(to_remove.shape))

        # NOTE: the uproot file handle is kept open and exposed as ``self.events``
        self.events = uproot.open(root)
        self.training = train

        # read every branch of the first tree in the file
        key = self.events.keys()[0]
        data_list = list(self.events[key].keys())
        self.events_array = np.array([self.events[key + "/" + k].array(library="np") for k in data_list])
        print(data_list)
        self.events_array = np.transpose(self.events_array)  # one row per event, one column per branch

        if to_remove.size > 0:
            self.events_array, data_list = self._drop_columns(self.events_array, data_list, to_remove)

        if remove:
            mode_col = self._column_index(data_list, "production_mode")
            self.events_array = self._balance_production_modes(self.events_array, mode_col)

        if normalize:
            self.events_array = self._normalize(self.events_array)

        if split:
            # shuffle so the rows are no longer sorted by production mode before splitting
            if seed is None:
                np.random.shuffle(self.events_array)
            else:
                np.random.default_rng(seed).shuffle(self.events_array)

            train_size = int(len(self.events_array) * train_percent / 100)
            self.train_array, self.eval_array = np.split(self.events_array, [train_size])
            print("training " + str(self.train_array.shape))
            print("evaluating " + str(self.eval_array.shape))
        else:
            self.train_array = self.events_array
            self.eval_array = self.events_array

    @staticmethod
    def _column_index(names, name):
        """Return the position of column ``name`` in ``names``; raise ValueError if it is absent."""
        try:
            return list(names).index(name)
        except ValueError:
            raise ValueError("column " + repr(name) + " not found in minitree branches") from None

    @staticmethod
    def _drop_columns(events, names, to_remove):
        """Remove the columns at indices ``to_remove`` from ``events`` and from ``names``.

        Indices may be stored as floats or contain duplicates; they are cast to int and
        de-duplicated so ``names`` stays aligned with the array columns.
        An out-of-range index raises IndexError (from ``np.delete``).

        Returns:
            (events_without_columns, remaining_names)
        """
        drop = np.unique(np.asarray(to_remove).astype(int).ravel())
        events = np.delete(events, drop, axis=1)
        dropped = set(drop.tolist())
        names = [n for j, n in enumerate(names) if j not in dropped]
        return events, names

    @staticmethod
    def _balance_production_modes(events, mode_col):
        """Downsample non-qqbar rows to the number of qqbar (production_mode == 1) rows.

        Rows are stably sorted by the production-mode column. Rows with mode < 1 and
        rows with mode >= 2 each keep their first ``num_qqbar`` rows in sorted order; all
        mode-1 rows are kept. Works when a mode group is missing entirely (e.g. no
        mode-2 events), in which case nothing is dropped from that group.

        Returns:
            A new array containing the kept rows, sorted by production mode.
        """
        events = events[np.argsort(events[:, mode_col], kind="stable")]
        modes = events[:, mode_col]
        first = int(np.searchsorted(modes, 1, side="left"))  # first qqbar row
        stop = int(np.searchsorted(modes, 2, side="left"))   # one past the last qqbar row
        num_qqbar = stop - first
        print("num qqbar = " + str(num_qqbar))

        n_rows = events.shape[0]
        # excess low-mode rows sit in [num_qqbar, first); excess high-mode rows in [stop + num_qqbar, n_rows)
        drop = np.concatenate((np.arange(num_qqbar, first), np.arange(stop + num_qqbar, n_rows)))
        return np.delete(events, drop, axis=0)

    @staticmethod
    def _normalize(events, n_trailing=3):
        """Min-max scale the feature columns to [0, 1] and divide the weight column by its max.

        The last ``n_trailing`` columns are excluded from min-max scaling, and the
        second-to-last column is divided by its maximum. This follows the original
        column layout (presumably production mode, weight, event id — TODO confirm).
        Constant columns map to 0 instead of NaN, and integer input is promoted to
        float64 so the scaled values are not truncated.

        Returns:
            The normalized array; it may be the same object as ``events``.
        """
        if not np.issubdtype(events.dtype, np.floating):
            events = events.astype(np.float64)
        if events.shape[0] == 0:
            return events

        n_cols = events.shape[1]
        n_features = max(n_cols - n_trailing, 0)
        features = events[:, :n_features]
        mins = features.min(axis=0)
        ranges = features.max(axis=0) - mins
        # a constant column has zero range; dividing by 1 maps it to 0 rather than 0/0 = NaN
        events[:, :n_features] = (features - mins) / np.where(ranges == 0, 1, ranges)

        if n_cols >= 2:
            weight_col = n_cols - 2
            weight_max = np.max(events[:, weight_col])
            if weight_max != 0:
                events[:, weight_col] /= weight_max
        return events

    def _active_array(self):
        """Return the training array if ``self.training`` is set, otherwise the evaluation array."""
        return self.train_array if self.training else self.eval_array

    def __getitem__(self, index):
        return self._active_array()[index]

    def __len__(self):
        return len(self._active_array())

    def get_eval_data(self):
        return self.eval_array

    def get_training_data(self):
        return self.train_array

In [14]:
# Directory on the cluster that holds the merged minitree.
root_path = "/depot/darkmatter/data/jupyterhub/Physics_Undergrads/Steve/things"
file = os.path.join(root_path, "all_1.root")

# Build the dataset, dropping the inputs flagged by the 0.8 correlation cut.
data_o = ProductionModeDataset(file, correlation_cut=0.8)

loaded correlations... shape is (21,)
['lb_delta_eta', 'lbbar_delta_eta', 'lnu_delta_eta', 'lnubar_delta_eta', 'lbarb_delta_eta', 'lbarbbar_delta_eta', 'lbarnu_delta_eta', 'lbarnubar_delta_eta', 'bnu_delta_eta', 'bnubar_delta_eta', 'bbarnu_delta_eta', 'bbarnubar_delta_eta', 'lb_delta_phi', 'lbbar_delta_phi', 'lnu_delta_phi', 'lnubar_delta_phi', 'lbarb_delta_phi', 'lbarbbar_delta_phi', 'lbarnu_delta_phi', 'lbarnubar_delta_phi', 'bnu_delta_phi', 'bnubar_delta_phi', 'bbarnu_delta_phi', 'bbarnubar_delta_phi', 'wplusb_delta_eta', 'wplusbbar_delta_eta', 'wminusb_delta_eta', 'wminusbbar_delta_eta', 'wplusb_delta_phi', 'wplusbbar_delta_phi', 'wminusb_delta_phi', 'wminusbbar_delta_phi', 'top_eta', 'top_boosted_eta', 'tbar_eta', 'tbar_boosted_eta', 'ttbar_delta_eta', 'ttbar_eta', 'llbar_delta_eta', 'bbbar_delta_eta', 'nunubar_delta_eta', 'top_phi', 'tbar_phi', 'ttbar_phi', 'ttbar_delta_phi', 'llbar_phi', 'llbar_delta_phi', 'bbbar_phi', 'bbbar_delta_phi', 'nunubar_phi', 'nunubar_delta_phi', 'l_et

In [11]:
# first row of the active (training) split; use indexing rather than calling the dunder directly
data_o[0]

array([3.72715515e-01, 2.58894116e-02, 3.45254717e-02, 1.99258939e-01,
       1.17393606e-01, 9.01361103e-02, 1.81657309e-03, 2.44576740e-01,
       3.05554165e-01, 8.84354442e-01, 6.65447651e-01, 5.72729012e-01,
       2.87516328e-01, 2.91670475e-01, 7.07460533e-02, 1.89763779e-02,
       3.56920362e-01, 2.67904959e-01, 2.19759834e-01, 3.09048836e-01,
       8.92343914e-02, 2.80172142e-01, 3.01306707e-01, 2.76023044e-01,
       6.24736169e-01, 4.80638790e-02, 5.27162548e-01, 1.07177735e-01,
       4.17956806e-01, 1.17960623e-01, 3.54601881e-01, 2.98531575e-04,
       4.81094593e-01, 6.23694164e-02, 2.60965951e-01, 8.37408404e-01,
       2.35444133e-01, 8.18387333e-01, 8.35135937e-01, 8.48587194e-01,
       3.49537976e-01, 1.78700005e-01, 9.57125714e-01, 3.66311318e-01,
       5.82054723e-01, 2.34551295e-01, 6.52294333e-01, 7.65744429e-02,
       2.95814529e-01, 3.85171319e-01, 3.51287262e-01, 2.86013961e-02,
       3.40297469e-02, 7.06269939e-02, 7.25539916e-02, 2.20674442e-02,
      

In [6]:
# data = data_o.get_eval_data()
# data

array([[3.13525153e-01, 9.91798033e-02, 4.54658588e-01, ...,
        0.00000000e+00, 6.25184078e-01, 3.23028000e+05],
       [2.96012653e-01, 9.31698965e-02, 2.77613574e-02, ...,
        2.00000000e+00, 5.97872822e-01, 1.22317000e+05],
       [6.80722665e-01, 2.12740390e-01, 2.79750833e-01, ...,
        2.00000000e+00, 6.40147838e-01, 2.79218000e+05],
       ...,
       [2.89145781e-01, 5.40879933e-01, 1.00308518e-01, ...,
        2.00000000e+00, 6.23233559e-01, 2.87315000e+05],
       [1.30235214e-01, 3.13092280e-01, 3.24186193e-02, ...,
        1.00000000e+00, 6.43382026e-01, 1.01885000e+05],
       [1.59179573e-01, 2.32185964e-01, 1.51536304e-01, ...,
        0.00000000e+00, 6.78623185e-01, 3.15352000e+05]])

In [7]:
# data[:,81]

array([0.62518408, 0.59787282, 0.64014784, ..., 0.62323356, 0.64338203,
       0.67862319])