In [5]:
import numpy as np
import os
import torch
import uproot
from torch.utils.data import Dataset

In [7]:
class ProductionModeDataset(Dataset):
    """
    Load a ROOT minitree and expose its events as a PyTorch ``Dataset``.

    ``__init__`` runs this pipeline:

    0) load every branch of the first tree in the file into a 2-D numpy array
       (one row per event, one column per branch, columns in ``tree.keys()`` order)

    1) if ``remove``: sort by ``production_mode`` and keep every production_mode 1
       (qqbar) event, plus at most that many events from below and from above
       mode 1 (i.e. the 0 and 2 types), so the classes are balanced

    2) if ``normalize``: min-max scale every column except the last
       ``N_UNNORMALIZED_TRAILING_COLUMNS`` to [0, 1]; constant columns become 0

    3) if ``split``: shuffle and split into training and evaluation arrays

    Args:
        :param root (string): path (directory and file name) of the minitree ROOT file
        :param split (boolean): shuffle and split into training and eval arrays; if False,
            both arrays are the same, unshuffled event array
        :param normalize (boolean): min-max normalize the leading columns
        :param remove (boolean): drop excess non-qqbar events (nothing is duplicated)
        :param train (boolean): if True, indexing and len() use the training array,
            otherwise the eval array
        :param seed (int or None): seed for the pre-split shuffle; None keeps using the
            global ``np.random`` state
        :param train_percent (number): percentage of events placed in the training array

    Raises:
        ValueError: if ``remove`` is True and the tree has no ``production_mode`` branch,
            or it contains no production_mode 1 events.
    """

    # Columns at the end of the array that are never normalized.
    # NOTE(review): the notebook output suggests production_mode is one of these trailing
    # columns — confirm against the tree's branch order, otherwise the labels get rescaled.
    N_UNNORMALIZED_TRAILING_COLUMNS = 3
    MODE_BRANCH = "production_mode"
    QQBAR_MODE = 1

    def __init__(self, root, split=True, normalize=True, remove=True, train=True,
                 seed=None, train_percent=81):
        # The opened file is kept as an attribute so it can be inspected interactively.
        self.events = uproot.open(root)
        self.training = train

        key = self.events.keys()[0]
        data_list = self.events[key].keys()

        # One 1-D array per branch; transpose so rows are events and columns are branches.
        self.events_array = np.transpose(
            np.array([self.events[key + "/" + k].array(library="np") for k in data_list])
        )
        print(data_list)

        if remove:
            self.events_array = self._balance_production_modes(self.events_array, data_list)

        if normalize:
            self.events_array = self._min_max_normalize(
                self.events_array, self.N_UNNORMALIZED_TRAILING_COLUMNS
            )

        if split:
            self.train_array, self.eval_array = self._shuffle_split(
                self.events_array, seed, train_percent
            )
            print("training " + str(self.train_array.shape))
            print("evaluating " + str(self.eval_array.shape))
        else:
            self.train_array = self.events_array
            self.eval_array = self.events_array

    @classmethod
    def _balance_production_modes(cls, events_array, data_list):
        """Sort rows by production mode and keep all qqbar rows plus at most as many
        rows from each side of them; returns a new array."""
        names = list(data_list)
        if cls.MODE_BRANCH not in names:
            raise ValueError("no '%s' branch in the tree; cannot balance production modes"
                             % cls.MODE_BRANCH)
        col = names.index(cls.MODE_BRANCH)

        events_array = events_array[np.argsort(events_array[:, col])]
        modes = events_array[:, col]
        # The array is sorted by mode, so rows [first, end) are exactly the qqbar events.
        first = int(np.searchsorted(modes, cls.QQBAR_MODE, side="left"))
        end = int(np.searchsorted(modes, cls.QQBAR_MODE, side="right"))
        num_qqbar = end - first
        print("num qqbar = " + str(num_qqbar))
        if num_qqbar == 0:
            raise ValueError("no %s == %d (qqbar) events; cannot balance production modes"
                             % (cls.MODE_BRANCH, cls.QQBAR_MODE))

        return np.concatenate([
            events_array[:min(num_qqbar, first)],   # leading events with mode below qqbar
            events_array[first:end],                # every qqbar event
            events_array[end:end + num_qqbar],      # leading events with mode above qqbar
        ])

    @staticmethod
    def _min_max_normalize(events_array, n_skip):
        """Scale all but the last ``n_skip`` columns to [0, 1]; returns the scaled array."""
        n_cols = events_array.shape[1] - n_skip
        if n_cols <= 0 or len(events_array) == 0:
            return events_array
        # Integer storage would truncate the scaled values to 0/1.
        if not np.issubdtype(events_array.dtype, np.floating):
            events_array = events_array.astype(np.float64)
        cols = events_array[:, :n_cols]
        mins = cols.min(axis=0)
        spans = cols.max(axis=0) - mins
        spans[spans == 0] = 1  # constant column -> all zeros instead of 0/0 = NaN
        events_array[:, :n_cols] = (cols - mins) / spans
        return events_array

    @staticmethod
    def _shuffle_split(events_array, seed, train_percent):
        """Shuffle rows in place, then split into (train, eval) at ``train_percent`` percent."""
        if seed is None:
            np.random.shuffle(events_array)
        else:
            np.random.default_rng(seed).shuffle(events_array)
        train_size = int(len(events_array) * train_percent / 100)
        train_array, eval_array = np.split(events_array, [train_size])
        return train_array, eval_array

    def _active_array(self):
        """Array selected by ``self.training``."""
        return self.train_array if self.training else self.eval_array

    def __getitem__(self, index):
        return self._active_array()[index]

    def __len__(self):
        return len(self._active_array())

    def get_eval_data(self):
        """Return the evaluation array (the full array when ``split`` was False)."""
        return self.eval_array

In [8]:
# root_path = "/depot/darkmatter/data/jupyterhub/Physics_Undergrads/Steve/things"

# file = root_path + "/all_1.root"

# ProductionModeDataset(file)

['lb_delta_eta', 'lbbar_delta_eta', 'lnu_delta_eta', 'lnubar_delta_eta', 'lbarb_delta_eta', 'lbarbbar_delta_eta', 'lbarnu_delta_eta', 'lbarnubar_delta_eta', 'bnu_delta_eta', 'bnubar_delta_eta', 'bbarnu_delta_eta', 'bbarnubar_delta_eta', 'lb_delta_phi', 'lbbar_delta_phi', 'lnu_delta_phi', 'lnubar_delta_phi', 'lbarb_delta_phi', 'lbarbbar_delta_phi', 'lbarnu_delta_phi', 'lbarnubar_delta_phi', 'bnu_delta_phi', 'bnubar_delta_phi', 'bbarnu_delta_phi', 'bbarnubar_delta_phi', 'wplusb_delta_eta', 'wplusbbar_delta_eta', 'wminusb_delta_eta', 'wminusbbar_delta_eta', 'wplusb_delta_phi', 'wplusbbar_delta_phi', 'wminusb_delta_phi', 'wminusbbar_delta_phi', 'top_eta', 'top_boosted_eta', 'tbar_eta', 'tbar_boosted_eta', 'ttbar_delta_eta', 'ttbar_eta', 'llbar_delta_eta', 'bbbar_delta_eta', 'nunubar_delta_eta', 'top_phi', 'tbar_phi', 'ttbar_phi', 'ttbar_delta_phi', 'llbar_phi', 'llbar_delta_phi', 'bbbar_phi', 'bbbar_delta_phi', 'nunubar_phi', 'nunubar_delta_phi', 'l_eta', 'lbar_eta', 'l_phi', 'lbar_phi', '

<__main__.ProductionModeDataset at 0x7fdbcbf582b0>

In [15]:
# data_o = ProductionModeDataset(file)

['ttBar_treeVariables_step8;5;1']
['lb_delta_eta', 'lbbar_delta_eta', 'lnu_delta_eta', 'lnubar_delta_eta', 'lbarb_delta_eta', 'lbarbbar_delta_eta', 'lbarnu_delta_eta', 'lbarnubar_delta_eta', 'bnu_delta_eta', 'bnubar_delta_eta', 'bbarnu_delta_eta', 'bbarnubar_delta_eta', 'lb_delta_phi', 'lbbar_delta_phi', 'lnu_delta_phi', 'lnubar_delta_phi', 'lbarb_delta_phi', 'lbarbbar_delta_phi', 'lbarnu_delta_phi', 'lbarnubar_delta_phi', 'bnu_delta_phi', 'bnubar_delta_phi', 'bbarnu_delta_phi', 'bbarnubar_delta_phi', 'wplusb_delta_eta', 'wplusbbar_delta_eta', 'wminusb_delta_eta', 'wminusbbar_delta_eta', 'wplusb_delta_phi', 'wplusbbar_delta_phi', 'wminusb_delta_phi', 'wminusbbar_delta_phi', 'top_eta', 'top_boosted_eta', 'tbar_eta', 'tbar_boosted_eta', 'ttbar_delta_eta', 'ttbar_eta', 'llbar_delta_eta', 'bbbar_delta_eta', 'nunubar_delta_eta', 'top_phi', 'tbar_phi', 'ttbar_phi', 'ttbar_delta_phi', 'llbar_phi', 'llbar_delta_phi', 'bbbar_phi', 'bbbar_delta_phi', 'nunubar_phi', 'nunubar_delta_phi', 'l_eta', 

In [18]:
# data = data_o.get_eval_data()
# data

array([[4.25484424e-02, 1.94240780e-01, 1.40045778e-01, ...,
        2.00000000e+00, 8.74388428e+01, 5.62070000e+04],
       [5.26011911e-01, 2.05888226e-01, 2.55710306e-01, ...,
        2.00000000e+00, 6.62610474e+01, 3.77360000e+04],
       [2.99117268e-01, 4.18714768e-01, 4.24535231e-01, ...,
        0.00000000e+00, 8.83249359e+01, 3.07602000e+05],
       ...,
       [3.94856636e-01, 9.05526598e-02, 3.97503147e-01, ...,
        2.00000000e+00, 7.72362442e+01, 5.08670000e+04],
       [5.40636140e-02, 2.46736763e-01, 1.53004308e-01, ...,
        0.00000000e+00, 7.36623306e+01, 3.35191000e+05],
       [9.10920545e-02, 3.42148793e-01, 3.55437631e-01, ...,
        0.00000000e+00, 7.73415146e+01, 3.35012000e+05]])