## Note that the .py version is what is actually imported. Only use this for testing.


In [2]:
# import statements
import numpy as np
import os
import torch
import uproot
from torch.utils.data import Dataset
import dgl

Using backend: pytorch


In [41]:
class GraphProductionModeDataset(Dataset):
    """
    This class will load the minitree and create the dataset loaded into the nns
    
    0) load the minitree properly
    
    1) remove 0 and 2 type events so we have the same number as qqbar
    
    2) normalize all data
    
    3) ensure this is stored in a numpy array with the correct arrangement of data and a corresponding list of datatypes
    
    4) split this up into training and evaluating datasets
    
    5) provide methods with which these sets can be accessed
    
    Args:
        :param root (string): is root directory and file name of minitree
        :param split (boolean): tells whether to split into training and eval
        :param normalize (boolean): tells whether to normalize data
        :param remove (boolean): tells whether we should remove excess data for non-qqbar events or duplicate qqbar
        :param train (boolean): tells whether we are training or evaluating (probably not needed)
        :param correlation_cut(float): if positive, removes values correlated more than the correlation cut. Requires existing 
                                analysis of the given cut so that data can be loaded
        : param cut_version (int): if positive, loads the specific cut version (otherwise loads unnumbere OG version)
        : param include_qg (boolean): default true. If false then remove stuff with production mode 2
    """
    
    def __init__(self, root, gen=True, split=True, normalize=True, remove=True, train=True, correlation_cut=-1.0, cut_version=-1, 
                 include_qg = True):
        # load a correlation cut if it exists:
        to_remove = np.array(())  # initialize the array of inputs to remove because of cut as empty
        if correlation_cut > 0:
            # load the inputs which must be removed due to cut if positive
            if cut_version > 0:
                to_remove = np.load("../analysis_code/results/inputs_to_remove_cut_" + str(correlation_cut)\
                                    + "v" + str(cut_version) + ".npy")
                print("loaded correlations... shape is " + str(to_remove.shape))
            else:
                to_remove = np.load("../analysis_code/results/inputs_to_remove_cut_" + str(correlation_cut) + ".npy")
                print("loaded correlations... shape is " + str(to_remove.shape))
        
        self.events = uproot.open(root)   # open the root file and load the events to be processed
        self.training = train    # set a variable which determines whether this is loaded for training. probably not needed
        
        key = self.events.keys()[0]   # get the key to make accessing events easier
        data_list = self.events[key].keys()   # get the names of the inputs
        
        # make a numpy array from events array
        self.events_array = np.array([self.events[key + "/" + k].array(library="np") for k in data_list])
        print(data_list)
    
        self.events_array = np.transpose(self.events_array)   # turn into columns (instead of rows) of data

        
        ################# remove any inputs in to_remove from the events array and the list of inputs ################
        if to_remove.shape[0] > 0:
            print(to_remove.shape)
            self.events_array = np.delete(self.events_array, to_remove, 1)
            to_remove = np.sort(to_remove)
            for i in range(len(to_remove)):
                to_remove[i] -= i
                data_list.pop(to_remove[i])
        ##############################################################################################################
        
        
        if remove:
            """
            Here is where we remove the excess data so that qqbar, gg, and gq/qg events are equally
            represented in the training and analysis.
            """
            # find index of production_mode:
            index = 0
            for i in range(len(data_list)):
                if data_list[i] == "production_mode":
                    index = i
                    break
                    
            # sort events array by production mode
            self.events_array = self.events_array[np.argsort(self.events_array[:, index])]           
            
            ############## find first and last index with production mode 1 #########################
            first = 0  # this will be the first index with qqbar
            last = 0   # this will be the last index with qqbar
            found_first = False   # a flag which allows us to stop looking for first once found
            
            for i in range(len(self.events_array[:,index])):
                if not found_first:
                    if self.events_array[i ,index] == 1:
                        first = i
                        found_first = True
                
                if self.events_array[i,index] == 2:
                    last = i - 1
                    break
            #######################################################################################
                    
            num_qqbar = last + 1 - first  # this is the total number of qqbar events
            print("num qqbar = " + str(num_qqbar))
            
            ####################### remove the extra gg and other #############################################
            max_len = len(self.events_array[:,0])
            
            if include_qg:
                """
                keep every production mode but remove excess events beyond qqbar amount
                """ 
                self.events_array = np.delete(self.events_array, list(range(num_qqbar, first)) + \
                                              list(range(max_len - (max_len - (last + 1) -num_qqbar), max_len)), 0)
                
            elif not include_qg:
                """
                Here we remove all quark-gluon events if we don't want them
                """
                self.events_array = np.delete(self.events_array, list(range(num_qqbar, first)) + \
                                             list(range(last, max_len)), 0)
                
            #################################################################################################
            
        if not remove:
            """
            Here we upsample qqbar and qg/gq instead of downsampling gg events
            """
            # find index of production_mode:
            index = 0
            for i in range(len(data_list)):
                if data_list[i] == "production_mode":
                    index = i
                    break
                    
            # sort events array by production mode
            self.events_array = self.events_array[np.argsort(self.events_array[:, index])]           
            
            ############## find number of qqbar and first and last indexes #########################
            first_q = 0  # this will be the first index with qqbar
            last_q = 0   # this will be the last index with qqbar
            found_first = False   # a flag which allows us to stop looking for first once found
            
            for i in range(len(self.events_array[:,index])):
                if not found_first:
                    if self.events_array[i ,index] == 1:
                        first_q = i
                        found_first = True
                
                if self.events_array[i,index] == 2:
                    last_q = i - 1
                    break
                    
            num_qqbar = last_q + 1 - first_q  # this is the total number of qqbar events
            print("num qqbar = " + str(num_qqbar))
            
            ############## find number of gg and first and last indexes #########################
            first_g = 0  # this will be the first index with qqbar
            last_g = 0   # this will be the last index with qqbar
            found_first = False   # a flag which allows us to stop looking for first once found
            
            for i in range(len(self.events_array[:,index])):
                if not found_first:
                    if self.events_array[i ,index] == 1:
                        first_g = i
                        found_first = True
                
                if self.events_array[i,index] == 2:
                    last_g = i - 1
                    break
                    
            num_gg = last_g + 1 - first_g  # this is the total number of qqbar events
            print("num gg = " + str(num_gg))
            
            ############## find number of qg/gq and first and last indexes #########################
            first_qg = 0  # this will be the first index with qqbar
            last_qg = 0   # this will be the last index with qqbar
            found_first = False   # a flag which allows us to stop looking for first once found
            
            for i in range(len(self.events_array[:,index])):
                if not found_first:
                    if self.events_array[i ,index] == 1:
                        first_qg = i
                        found_first = True
                
                if self.events_array[i,index] == 2:
                    last_qg = i - 1
                    break
                    
            num_qg = last_qg + 1 - first_qg  # this is the total number of qqbar events
            print("num qg = " + str(num_qg))
            
            ####################### remove the extra gg and other #############################################
            max_len = len(self.events_array[:,0])
            
            if include_qg:
                """
                keep every production mode but remove excess events beyond qqbar amount
                """ 
                self.events_array = np.delete(self.events_array, list(range(num_qqbar, first)) + \
                                              list(range(max_len - (max_len - (last + 1) -num_qqbar), max_len)), 0)
                
            elif not include_qg:
                """
                Here we remove all quark-gluon events if we don't want them
                """
                self.events_array = np.delete(self.events_array, list(range(num_qqbar, first)) + \
                                             list(range(last, max_len)), 0)
                
            #################################################################################################
            
            
  
        if normalize:
            """
            Here we normalize all the data by subtracting the minimum and dividing by the range (max-min). We do not normalize
            anything not going to be used as an input to the network (last three columns) except the weights which are only
            divided by the max to ensure we keep the signs.
            """
            # loop through all the inputs. the last three are not looped through because they are not input to the 
            # neural network. Hence, we subtract 3 from the length.
            for i in range(len(self.events_array[0,:])-3):
                ori_a = self.events_array[:,i] # original input column 

                min_a = np.min(self.events_array[:,i]) # min
                max_a = np.max(self.events_array[:,i]) # max
                range_a = max_a - min_a #range

                if range_a > 0:
                    self.events_array[:, i] = (ori_a - min_a) / (range_a) # replace the og array with the normalized array
                
            #################### normalize the weights here: ######################################################
            n = len(self.events_array[0,:])-2   # this is the index of the weights for reco level inputs
            self.events_array[:,n] /= np.max(self.events_array[:,n])   # replace weights with normalized weights
            #######################################################################################################
        
        """
        This is where the stuff is done to format for graphs. We need to set up all the nodes and add the edges.
        Each node represents a particle... they can be l, lbar, nu, nubar, b, bbar, wplus, wminus, t, or tbar. We will sort
        by these (using our gen input if needed) and group these up. If some do not exist then that will also 
        be ok. We will put all the information for a given variable in
        Eventually this should also encode "delta" variables along the edges. For now those will be dropped
        """
        
        ############### for now I will include building up and both similar and same parent TODO: add parameter to determine
        # this
        up_source = torch.tensor([6, 7, 8, 9, 2, 3, 4, 5])
        up_range =  torch.tensor([3, 3, 4, 4, 0, 0, 1, 1])
        
        # build down version
        down_source = torch.tensor([3, 3, 4, 4,  0, 0, 1, 1])
        down_range = torch.tensor([6, 7, 8, 9, 2, 3, 4, 5])
        
        # similar type connection
        sim_source = torch.tensor([6, 8, 7, 9, 2, 5, 3, 4, 0, 1])
        sim_range = torch.tensor([8, 6, 9, 7, 5, 2, 4, 3, 1, 0])
        
        # same parent
        child_source = torch.tensor([6, 7, 8, 9, 2, 3, 4, 5])
        child_range = torch.tensor([7, 6, 9, 8, 3, 2, 5, 4])
        
        ############### for now I will include building up and both similar and same parent TODO: add parameter to determine
        # this
        source_ids = torch.cat((up_source, sim_source, child_source))
        range_ids = torch.cat((up_range, sim_range, child_range))
        graph = dgl.graph((source_ids, range_ids), num_nodes=10)
        graph = dgl.add_self_loop(graph)
        
        ########## Now put together the dataset for each of the events... for now I am assuming I have the same data for
        ########## each node
        overall_count = 0
        print((len(self.events_array[0,:]) - 3)/10)
        new_array_data = torch.empty(len(self.events_array[:,0]),10, int((len(self.events_array[0,:]) - 3)/10))
        for event in self.events_array:
            event_data = torch.empty(10, int((len(event) - 3)/10))
            
            prepend = ""
            if gen:
                prepend = "gen_"
            
            count = 0
            for i in range(len(self.events_array[0,:]) -3):
                if (data_list[i][0:len("top") + len(prepend)]) == prepend + "top":
                    event_data[0,count] = event[i]
                    count+=1
            
            count = 0
            for i in range(len(self.events_array[0,:]) -3):
                if (data_list[i][0:len("tbar") + len(prepend)]) == prepend + "tbar":
                    event_data[1,count] = event[i]
                    count+=1
                    
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("b_") + len(prepend)]) == prepend + "b_":
                    event_data[2,count] = event[i]
                    count+=1
                    
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("wplus") + len(prepend)]) == prepend + "wplus":
                    event_data[3,count] = event[i]
                    count+=1
                    
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("wminus") + len(prepend)]) == prepend + "wminus":
                    event_data[4,count] = event[i]
                    count+=1
            
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("bbar") + len(prepend)]) == prepend + "bbar":
                    event_data[5,count] = event[i]
                    count+=1
            
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("lbar") + len(prepend)]) == prepend + "lbar":
                    event_data[6,count] = event[i]
                    count+=1
            
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("nu_") + len(prepend)]) == prepend + "nu_":
                    event_data[7,count] = event[i]
                    count+=1
                    
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("l_") + len(prepend)]) == prepend + "l_":
                    event_data[8,count] = event[i]
                    count+=1
                    
            count = 0
            for i in range(len(self.events_array[0,:])):
                if (data_list[i][0:len("nubar") + len(prepend)]) == prepend + "nubar":
                    event_data[9,count] = event[i]
                    count+=1
            
            ### Now we have event data and can make this a tensor which goes along with the graph for each event
            new_array_data[overall_count] = event_data
            overall_count+=1
        
        # out of the loop now
        # time to make the event array be new form
        self.weights = self.events_array[:,len(self.events_array[0,:]) - 2]
        self.production_mode = self.events_array[:,len(self.events_array[0,:]) - 3]
        self.events_array = new_array_data
        
               
        if split:
            """
            Here we split the events array into a validation dataset and a training dataset. 80% will be training and 
            20% will be validation.
            """
            # shuffle so no longer sorted by production modebefore splitting
            p = np.random.permutation(len(self.weights))
            self.weights = self.weights[p]
            self.production_mode = self.production_mode[p]
            self.events_array = self.events_array[p]
            
            train_size = int(len(self.events_array[:,0,0])*80/100)  # calculate the size of the training array
            
            # split into the two arrays
            self.train_array, self.eval_array = np.split(self.events_array, [train_size])
            self.train_weights, self.eval_weights = np.split(self.weights, [train_size])
            self.train_target, self. eval_target = np.split(self.production_mode, [train_size])
            print("training " + str(self.train_array.shape))
            print("evaluating " + str(self.eval_array.shape))
            
        else:
            """
            If not splitting, just set both the train_array and eval_array to be the entire dataset
            """
            self.train_array = self.events_array
            self.eval_array = self.events_array
            self.train_weights = self.weights
            self.eval_weights = self.weights
            self.train_target = self.production_mode
            self.eval_target = self.production_mode
            
        
                
    def __getitem__(self, index):
        """
        Return the training array (should never be eval... might delete stuff) in format needed for dataloader 
        """
        if self.training:
            return self.train_array[index], self.train_target[index], self.train_weights[index]
        return self.eval_array[index], self.eval_target[index], self.eval_weights[index]
    
    def __len__(self):
        """
        Return the length of the training array (once again, might want to delete extra code for eval)
        """
        if self.training:
            return len(self.train_array[:,0,0])
        return len(self.eval_array[:,0,0])
    
    def get_eval_data(self):
        """ Return the entire validation dataset """
        return self.eval_array, self.eval_target, self.eval_weights
    
    def get_training_data(self):
        """ Return the entire training dataset """
        return self.train_array, self.train_target, self.train_weights

# Testing code to ensure the class works

In [42]:
root_path = "/depot/cms/top/mcnama20/TopSpinCorr-Run2-Entanglement/CMSSW_10_2_22/src/TopAnalysis/Configuration/analysis/diLeptonic/three_files/Nominal"

file = root_path + "/ee_modified_root_1_allorentz_gen.root"

data_o = GraphProductionModeDataset(file, correlation_cut = -1, cut_version=-1, include_qg = False)

['gen_l_eta', 'gen_lbar_eta', 'gen_l_phi', 'gen_lbar_phi', 'gen_l_pt', 'gen_lbar_pt', 'gen_l_mass', 'gen_lbar_mass', 'gen_nu_eta', 'gen_nubar_eta', 'gen_nu_phi', 'gen_nubar_phi', 'gen_nu_pt', 'gen_nubar_pt', 'gen_nu_mass', 'gen_nubar_mass', 'gen_b_eta', 'gen_bbar_eta', 'gen_b_phi', 'gen_bbar_phi', 'gen_b_pt', 'gen_bbar_pt', 'gen_b_mass', 'gen_bbar_mass', 'gen_top_eta', 'gen_tbar_eta', 'gen_top_phi', 'gen_tbar_phi', 'gen_top_pt', 'gen_tbar_pt', 'gen_top_mass', 'gen_tbar_mass', 'gen_wplus_eta', 'gen_wminus_eta', 'gen_wplus_phi', 'gen_wminus_phi', 'gen_wplus_pt', 'gen_wminus_pt', 'gen_wminus_mass', 'gen_wplus_mass', 'production_mode', 'eventWeight', '__index__']
num qqbar = 52908
4.0
training torch.Size([84652, 10, 4])
evaluating torch.Size([21163, 10, 4])


In [43]:
data_o.__getitem__(0)

(tensor([[4.8760e-01, 5.1726e-01, 9.0414e-02, 1.7250e+02],
         [6.7380e-01, 2.2498e-02, 6.5248e-02, 1.7250e+02],
         [5.6652e-01, 5.3617e-01, 8.0513e-02, 6.1055e-01],
         [5.1515e-01, 4.8376e-01, 3.9854e-02, 5.2797e-01],
         [6.1145e-01, 9.7792e-01, 8.8921e-02, 5.0845e-01],
         [5.2334e-01, 3.7969e-01, 4.3108e-02, 6.4319e-01],
         [3.9618e-01, 7.1380e-01, 5.3003e-02, 5.1070e-01],
         [4.4443e-01, 3.4056e-01, 7.4200e-02, 6.0778e-01],
         [5.4001e-01, 9.0005e-01, 5.1843e-02, 5.8100e-01],
         [5.8979e-01, 3.6507e-02, 7.8783e-02, 4.8109e-01]]),
 0.0,
 0.8816012037936595)