In [2]:
import numpy as np
import os
import torch
import uproot
from torch.utils.data import Dataset

In [75]:
class ProductionModeDataset(Dataset):
    """
    Load a minitree ROOT file and prepare the array that is fed to the networks.

    Intended pipeline:

    0) load the minitree
    1) drop excess events of the non-qqbar production modes so that no mode
       has more events than qqbar (production mode 1)
    2) normalize all data                                   (not implemented yet)
    3) keep the data in a numpy array together with the matching list of
       branch names
    4) split into training and evaluation datasets          (not implemented yet)

    Args:
        :param root (string): path (directory and file name) of the minitree
        :param split (boolean): whether to split into training and eval;
            currently accepted but unused
        :param normalize (boolean): whether to normalize the data;
            currently accepted but unused
        :param remove (boolean): whether to remove the excess events of the
            non-qqbar production modes

    Attributes:
        events: the object returned by ``uproot.open(root)``
        branch_names: list of branch names, one per row of ``events_array``
        events_array: numpy array with one row per branch and one column per
            event (after balancing, when ``remove`` is true)

    Raises:
        ValueError: if ``remove`` is true and the tree has no
            ``production_mode`` branch, or the loaded array is not 2-D.
    """

    # Key of the tree inside the ROOT file, as listed by ``self.events.keys()``.
    TREE_NAME = "ttBar_treeVariables_step8;4;1"
    # Branch holding the production-mode label of each event.
    MODE_BRANCH = "production_mode"
    # Label of the qqbar production mode, used as the reference class size.
    QQBAR_MODE = 1

    def __init__(self, root, split=True, normalize=True, remove=True):
        self.events = uproot.open(root)
        print(self.events.keys())

        # TODO: make less complex when not needed later, e.g. by restricting
        # the branches to the phi / eta / pt / production ones and skipping
        # the "gen" branches.
        data_list = list(self.events[self.TREE_NAME].keys())
        self.branch_names = data_list

        # One row per branch, one column per event.
        self.events_array = np.array(
            [self.events[self.TREE_NAME + "/" + k].array(library="np") for k in data_list]
        )
        print(len(data_list))

        if remove:
            # Balancing on any other branch would silently produce garbage,
            # so refuse to continue when the label branch is missing.
            if self.MODE_BRANCH not in data_list:
                raise ValueError(
                    "branch '%s' not found in tree '%s'" % (self.MODE_BRANCH, self.TREE_NAME)
                )
            index = data_list.index(self.MODE_BRANCH)

            self.events_array = self._balance_production_modes(
                self.events_array, index, self.QQBAR_MODE
            )
            print(self.events_array[index, :])

        # normalize here

        # transpose and split here

    @staticmethod
    def _balance_production_modes(events_array, mode_row, target_mode=1):
        """
        Sort events by production mode and cap every mode at the target count.

        Args:
            events_array: 2-D array with one row per branch and one column
                per event.
            mode_row: index of the row holding the production-mode labels.
            target_mode: label whose event count is the cap for every mode.

        Returns:
            A new array whose columns are ordered by production mode (original
            order preserved inside each mode) and in which each mode keeps at
            most as many events as ``target_mode`` has. Modes with fewer
            events are kept whole. The number of rows is unchanged. If no
            event carries ``target_mode`` the result has zero columns.

        Raises:
            ValueError: if ``events_array`` is not 2-D.
        """
        events_array = np.asarray(events_array)
        if events_array.ndim != 2:
            raise ValueError(
                "events_array must be 2-D (branches x events), got %d-D" % events_array.ndim
            )

        # Stable sort keeps the original event order inside each mode, so the
        # selection below is deterministic.
        order = np.argsort(events_array[mode_row, :], kind="stable")
        sorted_events = events_array[:, order]
        sorted_modes = sorted_events[mode_row, :]

        num_target = int(np.count_nonzero(sorted_modes == target_mode))

        # Position of each event inside its own mode group: column index minus
        # the index at which that mode's run starts in the sorted labels.
        group_start = np.searchsorted(sorted_modes, sorted_modes, side="left")
        rank_in_group = np.arange(sorted_modes.shape[0]) - group_start

        # Select along the event axis (columns), never along the branch axis.
        return sorted_events[:, rank_in_group < num_target]
        
        

In [76]:
# Directory that contains the input ROOT files.
root_path = "/depot-new/cms/top/mcnama20/TopSpinCorr-Run2-Entanglement/CMSSW_10_2_22/src/TopAnalysis/Configuration/analysis/diLeptonic/three_files"

# Full path of the minitree file to load.
file = f"{root_path}/ee_modified_root.root"

# Build the dataset; as the last expression of the cell its repr is displayed.
ProductionModeDataset(file)

['ttBar_treeVariables_step8;4;1']
95
(95, 97722)
[0. 0. 0. ... 2. 2. 2.]


<__main__.ProductionModeDataset at 0x7f4c497e6a90>