In [1]:
#from torch.utils.data import *
from sklearn.metrics import roc_curve, auc
from torch.utils.data import ConcatDataset, Dataset, DataLoader, sampler, DistributedSampler

import pyarrow.parquet as pq
import pyarrow as pa # pip install pyarrow==0.7.1
import ROOT
import numpy as np
np.random.seed(0)
import glob, os

import dask.array as da

#from scipy.misc import imresize

import matplotlib.pyplot as plt
#%matplotlib inline
from matplotlib.colors import LogNorm, ListedColormap, LinearSegmentedColormap
import matplotlib.ticker as ticker
from matplotlib.ticker import MultipleLocator
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches

from skimage.measure import block_reduce
from numpy.lib.stride_tricks import as_strided

Welcome to JupyROOT 6.24/02


In [2]:
# Input parquet file of jet images, and the directory where per-jet PNGs are written.
fileStr, outDir = (
    '/home/ruchi/e2e_tau/MLAnalyzer/convertRootFiles/DYToTauTau_jets.parquet.1',
    '/home/ruchi/e2e_tau/MLAnalyzer/plottingMacros/croppedJetIMG',
)
# Paths matching fileStr (it contains no wildcard, so this has at most one entry).
f0s = glob.glob(fileStr)

In [3]:
class ParquetDataset(Dataset):
    """Map-style dataset over a parquet file, one row group per item.

    Each item is the row group's columns as a dict (``to_pydict`` output) with:
      * ``X_jet`` -- the first row's image as a float32 array, with pixels
        below 1e-3 set to 0, channel 4 multiplied by 25, and the whole array
        divided by 100;
      * ``m0`` / ``pt`` -- float32 arrays built from ``jetM`` / ``jetPt``.
    """
    def __init__(self, filename):
        self.parquet = pq.ParquetFile(filename)
        self.cols = None  # None -> read every column

    def __len__(self):
        # One sample per row group.
        return self.parquet.num_row_groups

    def __getitem__(self, index):
        row_group = self.parquet.read_row_group(index, columns=self.cols)
        record = row_group.to_pydict()

        image = np.float32(record['X_jet'][0])
        record['m0'] = np.float32(record['jetM'])
        record['pt'] = np.float32(record['jetPt'])

        # Preprocessing, applied in place to the float32 image.
        image[image < 1.e-3] = 0.       # zero-suppress faint pixels
        image[4, ...] *= 25.            # HCAL: match the pixel intensity distribution of the other layers
        record['X_jet'] = image / 100.  # overall standardisation
        return dict(record)

In [4]:
def custom_div_cmap(numcolors=11, name='custom_div_cmap',
                    mincol='blue', midcol='white', maxcol='red'):
    """Return a colormap interpolated linearly through three colours.

    ``mincol``, ``midcol`` and ``maxcol`` are the colours at the low end, the
    centre and the high end of the scale; ``numcolors`` is forwarded to
    matplotlib as ``N``, the number of entries in the lookup table.
    """
    anchors = [mincol, midcol, maxcol]
    return LinearSegmentedColormap.from_list(name, anchors, N=numcolors)

# White -> light pink -> deep pink, 50 levels.
pink_map = custom_div_cmap(numcolors=50, mincol='#FFFFFF', midcol='#F699CD', maxcol='#FF1694')

In [5]:
def plotJet(img, mins, maxs, str_):
    """Overlay the pixel-detector layers of one jet image and save the figure.

    Parameters
    ----------
    img : array-like
        Jet image indexed as ``img[channel, row, col]`` (a torch tensor in this
        notebook). Channels 5, 6 and 7 are drawn, labelled BPix L1, L2 and L3.
    mins, maxs : sequence of scalars
        Per-layer colour-scale limits, ordered the way the caller builds them:
        channels (0, 3, 4, 5, 6, 7, 8). A layer is drawn only when its max > 0.
    str_ : str
        Output path; the figure is written in PNG format.

    The current pyplot figure is cleared after saving so it can be reused.
    """
    # (channel, imshow colormap, legend colour, legend label, index into mins/maxs).
    # maxs is [ch0, ch3, ch4, ch5, ch6, ch7, ch8], so channel 7 -> 5, 6 -> 4, 5 -> 3.
    layers = (
        (7, 'Purples', 'purple', 'BPix L3', 5),
        (6, 'Blues',   'blue',   'BPix L2', 4),
        (5, 'Greens',  'green',  'BPix L1', 3),
    )

    # Draw into the current axes. A bare plt.axes() adds a second, empty axes on
    # top of the images in matplotlib >= 3.6 (older versions warn and reuse it).
    ax = plt.gca()
    for channel, layer_cmap, _, _, k in layers:
        vmax = float(maxs[k])
        if not vmax > 0:  # empty layer (also skips NaN); LogNorm needs a positive max
            continue
        # Clamp so a very faint layer (max below the floor) cannot give vmin > vmax.
        vmin = min(float(mins[k]), vmax)
        # Limits go on the norm: newer matplotlib rejects norm= together with vmin/vmax.
        ax.imshow(img[channel, :, :], cmap=layer_cmap, alpha=0.9,
                  norm=LogNorm(vmin=vmin, vmax=vmax))

    # Both axes span 0-125; y is inverted so row 0 is at the top.
    ticks = np.arange(0, 125 + 25, 25)
    ax.set_xlim(0., 125.)
    ax.set_ylim(125., 0.)
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_xlabel(r"$\mathrm{i\varphi}'$", size=28)
    ax.set_ylabel(r"$\mathrm{i\eta}'$", size=28)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_tick_params(direction='in', which='major', length=6.)
        axis.set_tick_params(direction='in', which='minor', length=3.)

    # The legend lists exactly the layers this function can draw.
    handles = [mpatches.Patch(color=color, label=label)
               for _, _, color, label, _ in reversed(layers)]
    ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5))

    plt.savefig(str_, bbox_inches='tight', format='png')
    plt.clf()

In [6]:
dset_train = ParquetDataset(fileStr)
train_cut = 50
idxs = np.random.permutation(len(dset_train))
# NOTE(review): train_sampler is built but never passed to the DataLoader below, so
# the loop walks every row group in file order, not a train_cut-sized subset.
# Wiring it in would also break the consecutive batching the event/jet naming uses.
train_sampler = sampler.SubsetRandomSampler(idxs[:train_cut])
# shuffle=False + batch_size=2 groups consecutive row groups; each batch is reported
# as one "event" below -- presumably the two jets of an event are stored back to
# back in the file (TODO confirm).
train_loader = DataLoader(dataset=dset_train, batch_size=2, num_workers=0, shuffle=False, pin_memory=True)

# Plot styling is loop-invariant, so set it once.
plt.rcParams["font.family"] = "Helvetica"
plt.rcParams["figure.figsize"] = (12,12)
plt.rcParams.update({'font.size': 26})
# Not used by the loop below; kept defined in case later cells reference it.
cmap = ['Oranges','Blues','Greys','Reds',pink_map,'Purples','Greens']
min_ = 0.0001

# Channels whose per-jet maxima are passed to plotJet, in the order it expects.
plot_channels = (0, 3, 4, 5, 6, 7, 8)

# savefig does not create missing directories.
os.makedirs(outDir, exist_ok=True)

for i, data in enumerate(train_loader):
    X_train = data['X_jet']
    y_train = data['y']

    # Samples in this batch (the last batch may hold fewer than batch_size).
    nJets = X_train.shape[0]
    print("there are ", nJets, "jets in the event")

    for jet in range(nJets):
        img = X_train[jet,:,:,:]
        print("JET LABEL IS  ", jet)
        # To plot only tau-labelled jets, skip when y_train[jet] == 0.

        mins = [min_]*len(plot_channels)
        maxs = [img[ch,:,:].max() for ch in plot_channels]
        print ("Min = ", mins, " | Max = ", maxs)
        plotJet(img, mins, maxs, '%s/tau_event%d_jet%d.png'%(outDir,i,jet))

there are  2 jets in the event
JET LABEL IS   0
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.2292), tensor(0.0006), tensor(0.1893), tensor(0.2200), tensor(0.0900), tensor(0.0700), tensor(0.1000)]


  ax = plt.axes()
findfont: Font family ['Helvetica'] not found. Falling back to DejaVu Sans.
findfont: Font family ['Helvetica'] not found. Falling back to DejaVu Sans.


JET LABEL IS   1
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.1882), tensor(0.0390), tensor(0.1241), tensor(0.1800), tensor(0.1300), tensor(0.1100), tensor(0.0600)]
there are  2 jets in the event
JET LABEL IS   0
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.2088), tensor(0.1453), tensor(0.0265), tensor(0.5900), tensor(0.2300), tensor(0.1500), tensor(0.0800)]
JET LABEL IS   1
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.1745), tensor(0.0435), tensor(0.0232), tensor(0.7600), tensor(0.2100), tensor(0.3200), tensor(0.1700)]
there are  2 jets in the event
JET LABEL IS   0
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.2318), tensor(0.1136), tensor(0.0418), tensor(0.4800), tensor(0.1900), tensor(0.1000), tensor(0.0800)]
JET LABEL IS   1
Min =  [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001]  | Max =  [tensor(0.2229), tensor(0.0281)

<Figure size 864x864 with 0 Axes>