In [1]:
import numpy as np
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import h5py
import pickle
import pandas
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.models import Model,Sequential
from keras.layers import Input, Dense, Dropout
from keras.utils import plot_model
from keras.models import load_model
from sklearn.preprocessing import scale, normalize
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.utils import check_random_state
import onnx
import onnxruntime as ort
import pandas as pd

Using TensorFlow backend.


In [2]:
tf.__version__

'2.1.0'

In [3]:
ort.__version__

'1.1.1'

In [4]:
conv = tf.keras.models.load_model('/lcg/storage13/atlas/gupta/checkpoint.h5')

In [5]:
conv.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
em_barrel_Lr1 (InputLayer)      [(None, 56, 11)]     0                                            
__________________________________________________________________________________________________
em_barrel_Lr0 (InputLayer)      [(None, 7, 11)]      0                                            
__________________________________________________________________________________________________
em_barrel_Lr2 (InputLayer)      [(None, 7, 11)]      0                                            
__________________________________________________________________________________________________
em_barrel_Lr3 (InputLayer)      [(None, 7, 11)]      0                                            
______________________________________________________________________________________________

**Keras to ONNX conversion:**

In [5]:
os.environ['TF_KERAS'] = '1'
import keras2onnx
keras2onnx.convert_keras(conv,conv.name)

**Data Preparation:**

In [4]:
images    = ['em_barrel_Lr0', 'em_barrel_Lr1_fine', 'em_barrel_Lr2'  , 'em_barrel_Lr3',
             'tile_barrel_Lr1', 'tile_barrel_Lr2', 'tile_barrel_Lr3', 'tracks_image']
scalars   = ['p_Eratio', 'p_Reta'   , 'p_Rhad'     , 'p_Rphi'  , 'p_TRTPID' , 'p_numberOfSCTHits'  ,
             'p_ndof'  , 'p_dPOverP', 'p_deltaEta1', 'p_f1'    , 'p_f3'     , 'p_deltaPhiRescaled2',
             'p_weta2' , 'p_d0'     , 'p_d0Sig'    , 'p_qd0Sig', 'p_nTracks', 'p_sct_weight_charge']
others    = ['eventNumber', 'p_TruthType', 'p_iffTruth', 'p_LHTight', 'p_LHMedium', 'p_LHLoose',
             'p_eta', 'p_et_calo','p_LHValue']
train_var = {'images' :images, 'tracks':[],'scalars':scalars}
all_var   = {**train_var, 'others':others}; scalars = train_var['scalars']

In [5]:
def make_sample(data_file, all_var, idx, n_tracks, n_classes, cuts='', p='p_', upscale=False):
    var_list = np.sum(list(all_var.values()))
    with h5py.File(data_file, 'r') as data:
        sample = {key:data[key][idx[0]:idx[1]] for key in var_list if key != 'tracks_image'}
        if 'tracks_image' in var_list or 'tracks' in var_list:
            n_tracks    = min(n_tracks, data[p+'tracks'].shape[1])
            tracks_data = data[p+'tracks'][idx[0]:idx[1]][:,:n_tracks,:]
            tracks_data = np.concatenate((abs(tracks_data[...,0:5]), tracks_data[...,5:13]), axis=2)
    if 'tracks_image' in var_list: sample.update({'tracks_image':tracks_data})
    if 'tracks'       in var_list: sample['tracks'] = tracks_data
    if tf.__version__ < '2.1.0': sample = {key:np.float32(sample[key]) for key in sample}
    if upscale:
        for n in all_var['images']: sample[n] = resize_images(np.float32(sample[n]), target_shape=(56,11))
    #if idx[1]-idx[0] > 1: print('(', '\b'+format(time.time() - start_time, '2.1f'), '\b'+' s)')
    return sample, make_labels(sample, n_classes)

In [6]:
def make_labels(sample, n_classes):
    MC_type, IFF_type = sample['p_TruthType'], sample['p_iffTruth']
    if n_classes == 2:
        labels = np.where(IFF_type <= 1                                 , -1, IFF_type)
        labels = np.where(IFF_type == 2                                 ,  0, labels  )
        return   np.where(IFF_type >= 3                                 ,  1, labels  )
    elif n_classes == 6:
        labels = np.where(np.logical_or (IFF_type <= 1, IFF_type == 4)  , -1, IFF_type)
        labels = np.where(np.logical_or (IFF_type == 6, IFF_type == 7)  , -1, labels  )
        labels = np.where(IFF_type == 2                                 ,  0, labels  )
        labels = np.where(IFF_type == 3                                 ,  1, labels  )
        labels = np.where(IFF_type == 5                                 ,  2, labels  )
        labels = np.where(np.logical_or (IFF_type == 8, IFF_type == 9)  ,  3, labels  )
        labels = np.where(np.logical_and(IFF_type ==10,  MC_type == 4)  ,  4, labels  )
        labels = np.where(np.logical_and(IFF_type ==10,  MC_type ==16)  ,  4, labels  )
        labels = np.where(np.logical_and(IFF_type ==10,  MC_type ==17)  ,  5, labels  )
        return   np.where(  labels == 10                                , -1, labels  )
    elif n_classes == 9:
        labels = np.where(IFF_type == 9                                 ,  4, IFF_type)
        return   np.where(IFF_type ==10                                 ,  6, labels  )
    else: print('\nERROR:', n_classes, 'classes not supported -> exiting program\n'); sys.exit()

In [7]:
def load_scaler(sample, scalars, scaler_file):
    print('CLASSIFIER: loading scaler transform from ' + scaler_file)
    scaler         = pickle.load(open(scaler_file, 'rb'))
    #start_time     = time.time()
    scalars_scaled = np.hstack([np.expand_dims(sample[key], axis=1) for key in scalars])
    print('CLASSIFIER: applying scaler transform to scalar variables', end=' ... ', flush=True)
    scalars_scaled = scaler.transform(scalars_scaled)
    for n in np.arange(len(scalars)): sample[scalars[n]] = scalars_scaled[:,n]
    #print('(', '\b'+format(time.time() - start_time, '2.1f'), '\b'+' s)')
    return sample

In [8]:
data_file = '/lcg/storage13/atlas/gupta/el_data.h5' 
scaler_file = '/lcg/storage13/atlas/gupta/scaler.pkl' 

In [9]:
sample, labels = make_sample(data_file, all_var, [0,1], 4, 2)

In [13]:
#sample

In [12]:
labels

array([1], dtype=int32)

In [10]:
sample = load_scaler(sample, scalars, scaler_file)

CLASSIFIER: loading scaler transform from /lcg/storage13/atlas/gupta/scaler.pkl
CLASSIFIER: applying scaler transform to scalar variables ... 



In [14]:
#sample

In [11]:
ini_list= ['em_barrel_Lr0', 'em_barrel_Lr1','em_barrel_Lr2', 'em_barrel_Lr3', 'tile_barrel_Lr1', 
'tile_barrel_Lr2', 'tile_barrel_Lr3', 'p_Eratio', 'p_Reta', 'p_Rhad', 'p_Rphi', 
'p_TRTPID', 'p_numberOfSCTHits', 'p_ndof', 'p_dPOverP', 'p_deltaEta1', 'p_f1', 'p_f3',
'p_deltaPhiRescaled2', 'p_weta2', 'p_d0', 'p_d0Sig', 'p_qd0Sig', 'p_nTracks', 
'p_sct_weight_charge', 'eventNumber', 'p_TruthType', 'p_iffTruth', 'p_LHTight', 
'p_LHMedium', 'p_LHLoose', 'p_eta', 'p_et_calo', 'p_LHValue', 'tracks_image']
final_sample0 = dict(zip(ini_list, list(sample.values()))) 

**This part is Important: I needed to convert float16 (default) to float32 to be fed in .ONNX**

In [12]:
final_sample1 = {}
ini_list=['em_barrel_Lr0', 'em_barrel_Lr1','em_barrel_Lr2', 'em_barrel_Lr3', 'tile_barrel_Lr1', 
'tile_barrel_Lr2', 'tile_barrel_Lr3', 'tracks_image', 'p_Eratio', 'p_Reta', 'p_Rhad', 'p_Rphi', 
'p_TRTPID', 'p_numberOfSCTHits', 'p_ndof', 'p_dPOverP', 'p_deltaEta1', 'p_f1', 'p_f3',
'p_deltaPhiRescaled2', 'p_weta2', 'p_d0', 'p_d0Sig', 'p_qd0Sig', 'p_nTracks', 
'p_sct_weight_charge']
for key in ini_list:
    final_sample1[key] = final_sample0[key].astype(np.float32)

In [13]:
final_sample1

{'em_barrel_Lr0': array([[[ 0.0000000e+00,  0.0000000e+00, -4.0054321e-05, -4.0054321e-05,
          -4.0054321e-05, -4.0054321e-05,  3.8862228e-04,  3.8862228e-04,
           3.8862228e-04,  3.8862228e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  6.3776970e-05,  6.3776970e-05,
           6.3776970e-05,  6.3776970e-05, -1.0418892e-04, -1.0418892e-04,
          -1.0418892e-04, -1.0418892e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  1.1473894e-04,  1.1473894e-04,
           1.1473894e-04,  1.1473894e-04,  5.5253506e-05,  5.5253506e-05,
           5.5253506e-05,  5.5253506e-05,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  2.7441978e-04,  2.7441978e-04,
           2.7441978e-04,  2.7441978e-04,  5.6314468e-04,  5.6314468e-04,
           5.6314468e-04,  5.6314468e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  9.8288059e-05,  9.8288059e-05,
           9.8288059e-05,  9.8288059e-05,  1.7571449e-04,  1.7571449e-04,
       

In [40]:
#sample['em_barrel_Lr1'] = sample.pop('em_barrel_Lr1_fine')

In [14]:
#sample
len(ini_list)

26

In [19]:
for i in range(len(ini_list)):
    print(ini_list[i],":",final_sample1[ini_list[i]].shape)

em_barrel_Lr0 : (1, 7, 11)
em_barrel_Lr1 : (1, 56, 11)
em_barrel_Lr2 : (1, 7, 11)
em_barrel_Lr3 : (1, 7, 11)
tile_barrel_Lr1 : (1, 7, 11)
tile_barrel_Lr2 : (1, 7, 11)
tile_barrel_Lr3 : (1, 7, 11)
tracks_image : (1, 4, 13)
p_Eratio : (1,)
p_Reta : (1,)
p_Rhad : (1,)
p_Rphi : (1,)
p_TRTPID : (1,)
p_numberOfSCTHits : (1,)
p_ndof : (1,)
p_dPOverP : (1,)
p_deltaEta1 : (1,)
p_f1 : (1,)
p_f3 : (1,)
p_deltaPhiRescaled2 : (1,)
p_weta2 : (1,)
p_d0 : (1,)
p_d0Sig : (1,)
p_qd0Sig : (1,)
p_nTracks : (1,)
p_sct_weight_charge : (1,)


In [1]:
conv.predict(final_sample1)

**The ONNX part starts here:**

In [15]:
sess_ort = ort.InferenceSession("/lcg/storage13/atlas/gupta/checkpoint.onnx")

In [16]:
for i in range(len(sess_ort.get_inputs())):
    print(sess_ort.get_inputs()[i].name,":",sess_ort.get_inputs()[i].shape)

em_barrel_Lr0 : ['N', 7, 11]
em_barrel_Lr1 : ['N', 56, 11]
em_barrel_Lr2 : ['N', 7, 11]
em_barrel_Lr3 : ['N', 7, 11]
tile_barrel_Lr1 : ['N', 7, 11]
tile_barrel_Lr2 : ['N', 7, 11]
tile_barrel_Lr3 : ['N', 7, 11]
tracks_image : ['N', 4, 13]
p_Eratio : ['N']
p_Reta : ['N']
p_Rhad : ['N']
p_Rphi : ['N']
p_TRTPID : ['N']
p_numberOfSCTHits : ['N']
p_ndof : ['N']
p_dPOverP : ['N']
p_deltaEta1 : ['N']
p_f1 : ['N']
p_f3 : ['N']
p_deltaPhiRescaled2 : ['N']
p_weta2 : ['N']
p_d0 : ['N']
p_d0Sig : ['N']
p_qd0Sig : ['N']
p_nTracks : ['N']
p_sct_weight_charge : ['N']


In [17]:
input_name = sess_ort.get_inputs()[0].name
print("input name", input_name)
input_shape = sess_ort.get_inputs()[0].shape
print("input shape", input_shape)
input_type = sess_ort.get_inputs()[0].type
print("input type", input_type)

input name em_barrel_Lr0
input shape ['N', 7, 11]
input type tensor(float)


In [18]:
len(sess_ort.get_inputs())

26

In [19]:
output_name = sess_ort.get_outputs()[0].name
print("output_name", output_name)
output_shape = sess_ort.get_outputs()[0].shape
print("output shape", output_shape)
output_type = sess_ort.get_outputs()[0].type
print("output type", output_type)

output_name dense_2
output shape ['N', 2]
output type tensor(float)


In [20]:
m = onnx.load('/lcg/storage13/atlas/gupta/checkpoint.onnx')

In [21]:
output = []
for o in m.graph.output:
        output.append(o.name)

In [22]:
output

['dense_2']

In [23]:
output_name

'dense_2'

In [24]:
input_feeds = {}
key = ['em_barrel_Lr0', 'em_barrel_Lr1', 'em_barrel_Lr2', 'em_barrel_Lr3', 'tile_barrel_Lr1', 
'tile_barrel_Lr2', 'tile_barrel_Lr3', 'tracks_image','p_Eratio', 'p_Reta', 'p_Rhad', 
'p_Rphi', 'p_TRTPID', 'p_numberOfSCTHits', 'p_ndof', 'p_dPOverP', 
'p_deltaEta1', 'p_f1', 'p_f3', 'p_deltaPhiRescaled2', 'p_weta2', 
'p_d0', 'p_d0Sig', 'p_qd0Sig', 'p_nTracks', 'p_sct_weight_charge']
for i in range(26):
    input_feeds[sess_ort.get_inputs()[i].name] = final_sample1[key[i]]

In [25]:
input_feeds

{'em_barrel_Lr0': array([[[ 0.0000000e+00,  0.0000000e+00, -4.0054321e-05, -4.0054321e-05,
          -4.0054321e-05, -4.0054321e-05,  3.8862228e-04,  3.8862228e-04,
           3.8862228e-04,  3.8862228e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  6.3776970e-05,  6.3776970e-05,
           6.3776970e-05,  6.3776970e-05, -1.0418892e-04, -1.0418892e-04,
          -1.0418892e-04, -1.0418892e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  1.1473894e-04,  1.1473894e-04,
           1.1473894e-04,  1.1473894e-04,  5.5253506e-05,  5.5253506e-05,
           5.5253506e-05,  5.5253506e-05,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  2.7441978e-04,  2.7441978e-04,
           2.7441978e-04,  2.7441978e-04,  5.6314468e-04,  5.6314468e-04,
           5.6314468e-04,  5.6314468e-04,  0.0000000e+00],
         [ 0.0000000e+00,  0.0000000e+00,  9.8288059e-05,  9.8288059e-05,
           9.8288059e-05,  9.8288059e-05,  1.7571449e-04,  1.7571449e-04,
       

In [26]:
result = sess_ort.run(output, input_feeds)

In [27]:
result

[array([[2.8377053e-04, 9.9971622e-01]], dtype=float32)]