In [48]:
import deepcalo as dpcal
import json
import keras as ks
import tensorflow as tf
import sys
import onnxruntime as ort
import os
import tensorflow as tf
import tf2onnx
import matplotlib.pyplot as plt
from importlib import import_module
from tqdm import tqdm
import numpy as np
import onnx

In [49]:
print(dpcal.__version__)
print(sys.version)
print(ort.__version__)
print(tf2onnx.__version__)

0.2.3
3.8.6 (v3.8.6:db455296be, Sep 23 2020, 13:31:39) 
[Clang 6.0 (clang-600.0.57)]
1.7.0
1.8.5


In [50]:
!ls ../Downloads/Zee_mc_1000_epochs_3_8_5

combine_model.0008-31.3847.hdf5 model.0046-2.9464.json
dataparams.pkl                  model.h5
dataparams.txt                  weights.0046-2.9464.hdf5


In [51]:
model_path = '../Downloads/Zee_mc_1000_epochs_3_8_5/model.0046-2.9464.json'
print(model_path)

../Downloads/Zee_mc_1000_epochs_3_8_5/model.0046-2.9464.json


In [52]:
with open(model_path, 'r') as model_json:
    arch = json.load(model_json)
    model = ks.models.model_from_json(arch, custom_objects={'FiLM': dpcal.layers.FiLM()})

In [53]:
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
em_barrel (InputLayer)          [(None, 56, 11, 4)]  0                                            
__________________________________________________________________________________________________
scalars (InputLayer)            [(None, 16)]         0                                            
__________________________________________________________________________________________________
up_sampling2d (UpSampling2D)    (None, 56, 55, 4)    0           em_barrel[0][0]                  
__________________________________________________________________________________________________
scalar_net (Functional)         (None, 256)          5120        scalars[0][0]                    
____________________________________________________________________________________________

In [54]:
model.load_weights("../Downloads/Zee_mc_1000_epochs_3_8_5/weights.0046-2.9464.hdf5")

In [55]:
spec = (

    tf.TensorSpec((None,  5), tf.float32, name="event_info"),

    tf.TensorSpec((None,  10, 13, 1), tf.float32, name="tracks"),

    tf.TensorSpec((None,  56, 11, 4), tf.float32, name="em_barrel"),

    tf.TensorSpec((None,  1), tf.float32, name="multiply_output_with"),

    tf.TensorSpec((None,  16), tf.float32, name="scalars"),

)

## Convert keras model to Onnx

In [56]:
model_proto, _ = tf2onnx.convert.from_keras(

    model, input_signature=spec, opset=13, output_path="../Downloads/model.0046_2.9464.onnx")

## Validating the Onnx model

In [57]:
sess_ort = ort.InferenceSession("../Downloads/model.0046_2.9464.onnx")

In [58]:
for i in range(len(sess_ort.get_inputs())):
    print(sess_ort.get_inputs()[i].name,":",sess_ort.get_inputs()[i].shape)

event_info : ['unk__451', 5]
tracks : ['unk__452', 10, 13, 1]
em_barrel : ['unk__453', 56, 11, 4]
multiply_output_with : ['unk__454', 1]
scalars : ['unk__455', 16]


In [59]:
output_name = sess_ort.get_outputs()[0].name
print("output_name", output_name)
output_shape = sess_ort.get_outputs()[0].shape
print("output shape", output_shape)
output_type = sess_ort.get_outputs()[0].type
print("output type", output_type)

output_name multiply
output shape ['unk__456', 1]
output type tensor(float)


In [60]:
m = onnx.load('../Downloads/model.0046_2.9464.onnx')

In [61]:
output = []
for o in m.graph.output:
        output.append(o.name)

In [62]:
output

['multiply']

In [63]:
def load_dataset(filenames, tag,data_params, batch_size=2048, load_single_file=False, shuffle=True, merge=True,
                 additional_info=True, cache=False, multiple_with=True, cut_ratio=None,
                 training_in_data=False, autotune = tf.data.experimental.AUTOTUNE,time_lr= False, ext_energy_range=None): 
    options = tf.data.Options()
    
    # print(autotune)
    options.experimental_deterministic = not shuffle!=0
    options.experimental_threading.private_threadpool_size = 16
    
    file_path = tf.data.Dataset.list_files(filenames, shuffle=shuffle!=0
                                           ).with_options(options)
    dataset = tf.data.TFRecordDataset(
        file_path, compression_type='GZIP', num_parallel_reads=1 if load_single_file else autotune #autotune
        ).prefetch(buffer_size=autotune)  # automatically interleaves reads from multiple files tf.data.experimental.AUTOTUNE
    
    dataset = dataset.map(lambda x: read_tfrecord_new(x, tag, merge=merge,data_params=data_params),
                                                          num_parallel_calls=autotune)

    if shuffle:
        if isinstance(shuffle, bool):
            dataset = dataset.shuffle(12)
        else:
            dataset = dataset.shuffle(shuffle)
    if cut_ratio != None:
        dataset = dataset.filter(lambda x, label: tf.math.abs(x['event_info'][0]/label-1) < cut_ratio)
    if ext_energy_range != None:
        if ext_energy_range ==0:
            dataset = dataset.filter(lambda x, label: label/tf.math.cosh(x['event_info'][1]) > ext_energy_range)
        elif ext_energy_range > 0:
            dataset = dataset.filter(lambda x, label: label/tf.math.cosh(x['event_info'][1]) < ext_energy_range)
        elif ext_energy_range < 0:
            dataset = dataset.filter(lambda x, label: label/tf.math.cosh(x['event_info'][1]) > -ext_energy_range)

    dataset = dataset.batch(batch_size, drop_remainder=False)
    if cache:
        dataset = dataset.cache()    
        # pass
    return dataset

In [64]:
def read_tfrecord_new(example, tag, data_params, merge=True):
    data ={}
    tfrecord_format = {
                    "em_barrel_Lr0": tf.io.FixedLenFeature([56,11,], tf.float32),
                    "em_barrel_Lr1": tf.io.FixedLenFeature([56,11,], tf.float32),
                    "em_barrel_Lr2": tf.io.FixedLenFeature([56,11,], tf.float32),
                    "em_barrel_Lr3": tf.io.FixedLenFeature([56,11,], tf.float32),
                    "targets": tf.io.FixedLenFeature([], tf.float32),
                    "multiply_output_name": tf.io.FixedLenFeature([], tf.float32)
                        }
    for sca in data_params['scalar_names']:
        tfrecord_format[sca] = tf.io.FixedLenFeature([], tf.float32)
    for tra in data_params['track_names']:
        tfrecord_format[tra] = tf.io.FixedLenFeature([10,], tf.float32)
    for gate in data_params['gate_img_prefix']:
        names = ['_Lr0', '_Lr1', '_Lr2', '_Lr3']
        for i in names:
            tfrecord_format[gate+i] = tf.io.FixedLenFeature([56,11,], tf.float32)
    if data_params['additional_info']:
        if (tag=='Zmumugam'):
            tfrecord_format['event_info'] = tf.io.FixedLenFeature([15,], tf.float32)
        else:
            tfrecord_format['event_info'] = tf.io.FixedLenFeature([5,], tf.float32)
    dataset = tf.io.parse_single_example(example, tfrecord_format)
    data = {name: tf.cast(dataset[name], tf.float32, name=name) for name in tfrecord_format.keys()}
    label = data['targets']
    data.pop('targets')

    if merge: # merge all the images together so the channel = 4
        data['em_barrel']  = tf.stack([data['em_barrel_Lr0'], data['em_barrel_Lr1'], data['em_barrel_Lr2'], data['em_barrel_Lr3']], axis=-1)
        [data.pop(i) for i in ["em_barrel_Lr0", "em_barrel_Lr1", "em_barrel_Lr2", "em_barrel_Lr3"]]
        # if 'time_em_barrel_Lr0' in tfrecord_format.keys():     
        for gate in data_params['gate_img_prefix']:
            data[gate]  = tf.stack([data[gate+'_Lr0'], data[gate+'_Lr1'], data[gate+'_Lr2'], data[gate+'_Lr3']], axis=-1)
            [data.pop(i) for i in [gate+'_Lr0', gate+'_Lr1', gate+'_Lr2', gate+'_Lr3']]
            # data[gate] = tf.math.abs(data[gate])# <= 0.05 ## abs
    for i in ['track_names', 'scalar_names']:
        if data_params[i]:
            tracks = [data[i] for i in data_params[i]]
            data[i.split('_')[0]+'s']  = tf.stack(tracks, axis=-1 if i=='track_names' else 0)
            [data.pop(i) for i in data_params[i]]
    return data, label


In [65]:
import glob
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # -1 disable
if True:
    data = {}

    files = glob.glob('*.tfrecords') ## change dir
            
    tag = 'Zee'
        
    # data_path = f'../tfrecords_data/{tag}'
    particle = 'electrons' if tag=='Zee' else 'photons'
    data_conf = import_module(f'..{particle}_variables_conf',  'variables_params.subpkg')
    data_params = data_conf.get_params()
    if True:
        size = 0
        target = []
            
        time_image=[]
        em_image=[]
        scalar = []
        event_info=[]
        tracks=[]
        multiply_output_name=[]
        nr_of_cores= 8
        train= load_dataset(files[:],#files[:1],
                            tag=tag, shuffle=False, merge=True, load_single_file=True,
                            additional_info=True, autotune = nr_of_cores,
                            batch_size=500, time_lr = data_params['gate_img_prefix'], 
                            data_params=data_params, cut_ratio=None,
                            ext_energy_range=None).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

        n=1
        nr =0
        time=[]
        for t, e in tqdm(train.take(n), total = n):
            size += len(e)

            target.append(e.numpy())
            event_info.append(t['event_info'].numpy())
            scalar.append(t['scalars'].numpy())
            tracks.append(t['tracks'].numpy())
            multiply_output_name.append(t['multiply_output_name'].numpy())
            time_image.append(t['time_em_barrel'].numpy())
            em_image.append(t['em_barrel'].numpy())
            nr +=1
for i,j in zip([event_info, scalar, tracks, multiply_output_name,
                    time_image, em_image],['event_info', 'scalars', 'tracks',
                                           'multiply_output_name', 'time_em_barrel',
                                           'em_barrel']):
    data[j] = np.array(i)
        
    '''It is MC. Now data should be the input to the model and target the target'''
    '''You should be able to just import tfrecord_load_data'''
    '''and then run the line above to import the data'''

100%|██████████| 1/1 [00:00<00:00,  5.14it/s]


In [66]:
data.keys()

dict_keys(['event_info', 'scalars', 'tracks', 'multiply_output_name', 'time_em_barrel', 'em_barrel'])

In [67]:
print("event_info: ",data['event_info'].shape)
print("scalars: ",data['scalars'].shape)
print("tracks: ",data['tracks'].shape)
print("multiply_output_name: ",data['multiply_output_name'].shape)
print("time_em_barrel: ",data['time_em_barrel'].shape)
print("em_barrel: ",data['em_barrel'].shape)

event_info:  (1, 500, 5)
scalars:  (1, 500, 16)
tracks:  (1, 500, 10, 13)
multiply_output_name:  (1, 500)
time_em_barrel:  (1, 500, 56, 11, 4)
em_barrel:  (1, 500, 56, 11, 4)


In [68]:
data['event_info'] = data['event_info'].reshape(500,5)
data['scalars'] = data['scalars'].reshape(500,16)
data['tracks'] = data['tracks'].reshape(500,10, 13,1)
data['multiply_output_name'] = data['multiply_output_name'].reshape(500,1)
data['time_em_barrel'] = data['time_em_barrel'].reshape(500,56, 11, 4)
data['em_barrel'] = data['em_barrel'].reshape(500,56, 11, 4)

In [69]:
print("event_info: ",data['event_info'].shape)
print("scalars: ",data['scalars'].shape)
print("tracks: ",data['tracks'].shape)
print("multiply_output_name: ",data['multiply_output_name'].shape)
print("time_em_barrel: ",data['time_em_barrel'].shape)
print("em_barrel: ",data['em_barrel'].shape)

event_info:  (500, 5)
scalars:  (500, 16)
tracks:  (500, 10, 13, 1)
multiply_output_name:  (500, 1)
time_em_barrel:  (500, 56, 11, 4)
em_barrel:  (500, 56, 11, 4)


In [70]:
len(sess_ort.get_inputs())

5

In [71]:
input_feeds = {}
key = ['event_info', 'tracks', 'time_em_barrel','multiply_output_name', 'scalars']
for i in range(5):
    input_feeds[sess_ort.get_inputs()[i].name] = data[key[i]]

In [72]:
!input_feeds

/bin/bash: input_feeds: command not found


In [73]:
kerasPredict = model.predict(input_feeds)
kerasPredict_list = list(kerasPredict)
len(kerasPredict_list)

500

In [74]:
onnxPredict = sess_ort.run(output, input_feeds)

In [75]:
onnxPredict_np = np.array(onnxPredict)

In [76]:
onnxPredict_np = onnxPredict_np.reshape(500,1)
onnxPredict_list = list(onnxPredict_np)
len(onnxPredict_list)

500

In [78]:
kerasPredict_list[:10]

[array([3947.143], dtype=float32),
 array([380.76245], dtype=float32),
 array([0.], dtype=float32),
 array([241.66402], dtype=float32),
 array([266.27936], dtype=float32),
 array([433.56155], dtype=float32),
 array([462.17758], dtype=float32),
 array([0.], dtype=float32),
 array([260.5142], dtype=float32),
 array([91.381966], dtype=float32)]

In [79]:
onnxPredict_list[:10]

[array([3947.1467], dtype=float32),
 array([380.76288], dtype=float32),
 array([0.], dtype=float32),
 array([241.66383], dtype=float32),
 array([266.27927], dtype=float32),
 array([433.56192], dtype=float32),
 array([462.1776], dtype=float32),
 array([0.], dtype=float32),
 array([260.5143], dtype=float32),
 array([91.38199], dtype=float32)]

In [47]:
plt.figure()
plt.plot(kerasPredict_list, label='kerasModel')
plt.plot(onnxPredict_list, label='onnxModel')
plt.title('output from inference')
plt.ylim(0, 600)
plt.xlim(0, 10)
plt.xlabel('batchSize')
plt.ylabel('Modeloutpt')
plt.legend(loc='upper right')
plt.rcParams["figure.figsize"] = (10,6)
plt.savefig('../Downloads/kerasVsOnnx_1.png', dpi = 100)