In [1]:
import os
import tensorflow as tf
import sys
import numpy as np
import ROOT
import glob
import time

Welcome to JupyROOT 6.18/04


## TFRecord related methods

In [2]:
################################################
################################################
def _array_float32_feature(ndarray):
    return lambda array: tf.train.Feature(float_list=tf.train.FloatList(value=array.reshape(-1)))
################################################
##FIXME: reduce precision to 32 bits
################################################
def _array_int64_feature(ndarray):
    return lambda array: tf.train.Feature(int64_list=tf.train.Int64List(value=array.reshape(-1)))
################################################
################################################
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
################################################
################################################
def create_features(x, y):
    """Bundle a data array and a label array into ``tf.train.Features``.

    Parameters
    ----------
    x
        Array stored under the key ``'UVW_data'`` as a float list.
    y
        Array stored under the key ``'label'`` as an int64 list.

    Returns
    -------
    tf.train.Features
        Container holding the two features above.
    """
    to_float_feature = _array_float32_feature(x)
    to_int_feature = _array_int64_feature(y)
    feature_map = {
        'UVW_data': to_float_feature(x),
        'label': to_int_feature(y),
    }
    return tf.train.Features(feature=feature_map)
################################################
################################################
def createWriter(fileName):
    """Open a TFRecord writer for ``fileName`` with ``.tfrecords`` appended.

    Parameters
    ----------
    fileName : str
        Output path without the ``.tfrecords`` suffix.

    Returns
    -------
    tuple
        ``(writer, path)``: the ``tf.io.TFRecordWriter`` and the full path
        it writes to.
    """
    output_path = fileName + '.tfrecords'
    return tf.io.TFRecordWriter(output_path), output_path
################################################
################################################
def saveSingleExampleToTFRecord(writer, features):
    """Serialize one example and append it to an open TFRecord writer.

    Parameters
    ----------
    writer
        Object with a ``write`` method that accepts the serialized example.
    features
        Value passed as ``features`` to ``tf.train.Example``.
    """
    record = tf.train.Example(features=features).SerializeToString()
    writer.write(record)
################################################
################################################

## Methods reading the ROOT files

In [None]:
################################################
################################################
def get_numpy_from_histo(root_histo, nBinsY=92):
    """Copy the bin contents of a 2D ROOT histogram into a NumPy array.

    ROOT numbers the regular bins of each axis from 1 to N; index 0 is the
    underflow bin and N+1 the overflow bin. Array element ``[i, j]`` therefore
    holds the content of ROOT bin ``(i + 1, j + 1)``, so neither underflow nor
    overflow leaks into the frame and the last regular bin is not dropped.

    Parameters
    ----------
    root_histo
        Object providing ``GetNbinsX()``, ``GetNbinsY()`` and
        ``GetBinContent(binx, biny)``.
    nBinsY : int, optional
        Number of columns of the returned frame (default 92).
        Histograms with more Y bins are cropped; histograms with fewer are
        zero-padded on the right.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(root_histo.GetNbinsX(), nBinsY)``.
    """
    nBinsX = root_histo.GetNbinsX()
    ##FIXME: Need uniform frame size
    # Only read Y bins that really exist; the remaining columns stay zero.
    nFilledY = min(root_histo.GetNbinsY(), nBinsY)
    numpy_histo = np.zeros((nBinsX, nBinsY))
    for iBinX in range(nBinsX):
        for iBinY in range(nFilledY):
            # +1: skip the underflow bin that sits at ROOT index 0.
            numpy_histo[iBinX, iBinY] = root_histo.GetBinContent(iBinX + 1, iBinY + 1)
    return numpy_histo
################################################
##FIXME: provide source of labels
################################################
def read_root(fileName, normalize=False):
    """Yield ``(features, labels)`` for the events stored in a ROOT file.

    For each event the histograms ``hraw_U_vs_time_evt<id>``,
    ``hraw_V_vs_time_evt<id>`` and ``hraw_W_vs_time_evt<id>`` are converted to
    NumPy arrays and stacked along a new last axis in U, V, W order.

    Parameters
    ----------
    fileName : str
        Path of the ROOT file to read.
    normalize : bool, optional
        When True, each histogram is divided by its own maximum and negative
        entries are set to 0. A histogram whose maximum is not positive
        becomes all zeros (instead of NaN from a division by zero).

    Yields
    ------
    tuple
        ``(features, labels)``; ``labels`` is currently the placeholder
        ``np.array([0, 0])`` because no label source is available yet.

    Raises
    ------
    IOError
        If the file cannot be opened.

    Notes
    -----
    The key list is walked in steps of three -- this assumes the keys are
    stored as consecutive U/V/W triples, one triple per event (TODO confirm).
    The event id is whatever follows ``"evt"`` in the first key of a triple.
    """
    projection_histos = {"U": 0, "V": 0, "W": 0}
    rootFile = ROOT.TFile(fileName, "READ")
    if rootFile.IsZombie():
        raise IOError("Cannot open ROOT file: {}".format(fileName))
    try:
        keysList = rootFile.GetListOfKeys()
        numberOfHistos = len(keysList)
        for index in range(0, numberOfHistos, 3):
            objName = keysList[index].GetName()
            eventId = objName.split("evt")[1]
            for projName in projection_histos:
                histo_name = "hraw_" + projName + "_vs_time_evt" + eventId
                root_histo = rootFile.Get(histo_name)
                numpy_histo = get_numpy_from_histo(root_histo)
                if normalize:
                    maxValue = np.amax(numpy_histo)
                    if maxValue > 0:
                        numpy_histo = np.where(numpy_histo < 0, 0, numpy_histo / maxValue)
                    else:
                        # No positive entry: nothing to scale, avoid 0/0 -> NaN.
                        numpy_histo = np.zeros_like(numpy_histo)
                projection_histos[projName] = numpy_histo

            features = np.stack(list(projection_histos.values()), axis=2)
            labels = np.array([0, 0])
            if index == 0:
                print("features.shape: ", features.shape)
                print("labels.shape: ", labels.shape)
            yield features, labels
    finally:
        # Runs when the generator is exhausted, closed or garbage collected.
        rootFile.Close()
################################################
##Test histogram reading
################################################
fileName = "/scratch/akalinow/ProgrammingProjects/MachineLearning/ELITPC/data/UVWProjections_2018-06-19T15:13:33.941_0008.root"
# Smoke test: pull just the first (features, labels) pair out of the generator.
for item in read_root(fileName):
    print(item)
    break

# Show one raw W projection straight from the ROOT file.
f = ROOT.TFile(fileName)
h_W = f.Get("hraw_W_vs_time_evt15416")
c1 = ROOT.TCanvas()
# Draw the histogram first: in a notebook the canvas is rendered when
# c1.Draw() is called, so it has to come after the histogram is on the pad.
h_W.Draw("col")
c1.Draw()

## The final conversion methods

In [15]:
################################################
################################################
def ROOT_to_TFRecord(normalize=False,
                     path="/scratch/akalinow/ProgrammingProjects/MachineLearning/ELITPC/data/",
                     fileName="UVWProjections_2018-06-19T15:13:33.941_0008"):
    """Convert every ``*.root`` file in a directory into one TFRecord file.

    Parameters
    ----------
    normalize : bool, optional
        Forwarded to ``read_root``.
    path : str, optional
        Directory searched for ``*.root`` files.
    fileName : str, optional
        Output name; ``createWriter`` appends ``.tfrecords`` to it.

    Notes
    -----
    The writer is closed in a ``finally`` clause, so the output file is
    flushed even when no input file is found or reading one of them fails.
    Input files are processed in sorted order to make the record order
    reproducible between runs.
    """
    root_files = sorted(glob.glob(os.path.join(path, "*.root")))
    number_of_files = len(root_files)
    number_of_examples = 0
    print ("Found {} files".format(number_of_files))
    start_time = time.perf_counter()
    writer, result_tf_file = createWriter(fileName)
    try:
        for root_file in root_files:
            for numpy_histogram, labels in read_root(root_file, normalize=normalize):
                features = create_features(numpy_histogram, labels)
                saveSingleExampleToTFRecord(writer, features)
                number_of_examples += 1
    finally:
        writer.close()
    print ("Serializing {} examples into {} done!".format(number_of_examples,result_tf_file))
    print("Execution time: {:.2f} seconds".format(time.perf_counter() - start_time))
################################################
################################################
# Run the conversion with per-histogram normalization enabled.
# This is the slow step: the recorded run below took about 175 s.
ROOT_to_TFRecord(normalize=True)
################################################
################################################

Found 1 files
features.shape:  (512, 92, 3)
labels.shape:  (2,)
Serializing 1927 examples into UVWProjections_2018-06-19T15:13:33.941_0008.tfrecords done!
Execution time: 175.17 seconds


## Read the TFRecord format

In [None]:
# Shapes of the tensors stored in each serialized example.
featuresShape = (512, 92, 3)
cropped_featuresShape = (64, 64, 3)
labelsShape = (2,)
################################################
################################################
# Parsing schema: keys must match the ones written by create_features.
feature_description = dict(
    UVW_data=tf.io.FixedLenFeature(featuresShape, tf.float32),
    label=tf.io.FixedLenFeature(labelsShape, tf.int64),
)

def _parse_function(example_proto):
  return tf.io.parse_single_example(example_proto, feature_description)
################################################
################################################
def readTFRecordFile(fileNames):
    """Open TFRecord files and map ``_parse_function`` over their records.

    Parameters
    ----------
    fileNames
        Passed unchanged to ``tf.data.TFRecordDataset``.

    Returns
    -------
    The dataset returned by ``.map(_parse_function)``.
    """
    serialized_records = tf.data.TFRecordDataset(fileNames)
    parsed_records = serialized_records.map(_parse_function)
    return parsed_records
################################################
################################################
fileNames = ["UVWProjections_2018-06-19T15:13:33.941_0008.tfrecords"]
dataset = readTFRecordFile(fileNames)

# Read-back check: print the first parsed record only.
first_record_only = dataset.take(1)
for item in first_record_only:
    print(repr(item))