# Single Inference

> **Warning:** This really should be run as a script and **not** in a notebook. That said, a notebook is the clearest way to introduce the concepts here.

In [1]:
import sys
import time

import numpy as np

import caffe
from larcv import larcv

Welcome to ROOTaaS 6.06/08


The cell below instantiates the network from the architecture described in the prototxt file and the weights stored in the saved pre-trained snapshot.

In [None]:
# Network definition (prototxt) and pre-trained weights (HDF5 caffemodel snapshot).
PROTOTXT = '/workspace/plainresnet10b/sp_plainresnet10b_test.prototxt'
WEIGHTS  = '/workspace/plainresnet10b/sp_plainresnet10b_iter_195750.caffemodel.h5'

# Build the network in the TEST phase.
net = caffe.Net(PROTOTXT, WEIGHTS, caffe.TEST)

In [None]:
filler_name = 'ThreadDatumFiller'

# The larcv IO processor must exist and be registered in the factory; continuing
# without it would only fail later with a less helpful error, so stop here.
if not larcv.ThreadFillerFactory.exist_filler(filler_name):
    raise RuntimeError('\033[93mFiller %s does not exist...\033[00m' % filler_name)

# get the IO instance (a ThreadDatumFiller) from the factory
filler = larcv.ThreadFillerFactory.get_filler(filler_name)

# number of TTree entries available for processing
num_events = filler.get_n_entries()

In [None]:
# force random access to be false for an inference (entries are visited in order)
filler.set_random_access(False)

# Mini-batch size as configured in the network: the first axis of the "data" blob.
# (The original referenced `self._batch_size`, which does not exist in a notebook.)
batch_size = net.blobs['data'].data.shape[0]

# Construct our own IO to fetch ROI objects for the physics analysis, over the same
# input files as the filler. Mode 0 is presumably larcv's read mode -- TODO confirm.
myio = larcv.IOManager(0, "AnaIO")
for f in filler.pd().io().file_list():
    myio.add_in_file(f)
myio.initialize()

print('')
print('\033[95mTotal number of events\033[00m: %d' % num_events)
print('\033[95mBatch size\033[00m: %d' % batch_size)
print('')


In [None]:
      # Run batched inference over every input entry and write one CSV row per event.
      # Uses objects from earlier cells: net, filler, filler_name, myio, num_events.

      # Mini-batch size as configured in the network: the first axis of the "data" blob.
      batch_size = net.blobs["data"].data.shape[0]

      # NOTE(review): placeholder output path -- adjust as needed.
      OUTPUT_CSV = 'inference_result.csv'

      # Columns written for each event, in this order. Every one must be filled
      # below; a missing one aborts the loop so a stale list is caught immediately.
      CSV_VARS = ['entry', 'npx', 'label', 'prediction',
                  'eminus', 'gamma', 'muminus', 'piminus', 'proton',
                  'nparticle', 'ndecay', 'energy_dep',
                  'energy_start', 'mass', 'mom_start',
                  'dcosx_start', 'dcosy_start', 'dcosz_start']

      # Column name for each index of the softmax output vector.
      CLASS_NAMES = ['eminus', 'gamma', 'muminus', 'piminus', 'proton']

      # (|pdg|, |parent pdg|) pairs counted as decays: mu from pi, e from mu.
      DECAY_PAIRS = set([(13, 211), (11, 13)])


      def fill_roi_vars(csv_vals, event_roi):
          """Fill the physics columns of `csv_vals` from one event's ROI collection.

          An ROI whose MCSTIndex() equals larcv.kINVALID_USHORT (from simb::MCTruth)
          sets the "start" kinematics: energy, mass, momentum magnitude and direction
          cosines. Every other ROI (from sim::MCShower / sim::MCTrack) increments
          'nparticle'; it overwrites 'energy_dep' when ParentTrackID == TrackID, and
          otherwise increments 'ndecay' for the (pdg, parent pdg) pairs in DECAY_PAIRS.
          """
          csv_vals['nparticle'] = 0
          csv_vals['ndecay'] = 0
          csv_vals['energy_dep'] = 0.
          for roi in event_roi.ROIArray():
              if roi.MCSTIndex() == larcv.kINVALID_USHORT:
                  # ROI from simb::MCTruth
                  csv_vals['energy_start'] = roi.EnergyInit()
                  csv_vals['mass'] = larcv.ParticleMass(roi.PdgCode())
                  px, py, pz = roi.Px(), roi.Py(), roi.Pz()
                  ptot = np.sqrt(np.power(px, 2) + np.power(py, 2) + np.power(pz, 2))
                  csv_vals['mom_start'] = ptot
                  csv_vals['dcosx_start'] = px / ptot
                  csv_vals['dcosy_start'] = py / ptot
                  csv_vals['dcosz_start'] = pz / ptot
              else:
                  # ROI from sim::MCShower and sim::MCTrack
                  csv_vals['nparticle'] += 1
                  if roi.ParentTrackID() == roi.TrackID():
                      csv_vals['energy_dep'] = roi.EnergyDeposit()
                  elif (abs(roi.PdgCode()), abs(roi.ParentPdgCode())) in DECAY_PAIRS:
                      csv_vals['ndecay'] += 1


      done_list     = set()  # TTree entries already recorded by a previous run; empty for a fresh run
      event_counter = 0      # TTree entry the loop is currently at
      stop_counter  = 1e10   # optional early stop: max number of entries to process
      correct_count = 0      # events whose argmax prediction equals the label
      terminate     = False  # set on an inconsistency; aborts the loop

      fout = open(OUTPUT_CSV, 'w')
      fout.write(','.join(CSV_VARS) + '\n')
      try:
          # Bounded by num_events so we never request an entry past the end of the
          # input (the original looped once more when num_events % batch_size == 0,
          # and spun forever if the skip path reached an all-done tail).
          while event_counter < num_events and event_counter < stop_counter and not terminate:

              # skip entries recorded by a previous run
              if event_counter in done_list:
                  event_counter += 1
                  continue

              # force the filler to move the next-event-to-read pointer to this entry
              filler.set_next_index(event_counter)

              # number of entries we expect to use from this mini-batch (last may be partial)
              num_entries = min(num_events - event_counter, batch_size)

              # run the network for one mini-batch; sleep while the filler thread is running
              net.forward()
              while filler.thread_running():
                  time.sleep(0.001)

              # Producer name of product type 1, so the same ROI products can be read via
              # myio. Presumably 1 is the ROI product type (larcv.kProductROI) -- TODO confirm.
              roi_producer = filler.producer(1)

              # TTree entry numbers processed in this mini-batch
              entries = filler.processed_entries()
              if entries.size() != batch_size:
                  raise RuntimeError('Batch counter mis-match! expected %d entries, got %d'
                                     % (batch_size, entries.size()))

              # data already read and stored in memory in the caffe blobs
              adcimgs = net.blobs["data"].data     # input images
              labels  = net.blobs["label"].data    # true labels
              scores  = net.blobs["softmax"].data  # final softmax output vectors

              for index in range(num_entries):

                  if entries[index] != event_counter:
                      print('\033[93mLogic error... inconsistency found in expected entry (%d) vs. processing entry (%d)\033[00m'
                            % (event_counter, entries[index]))
                      terminate = True
                      break

                  # skip entries recorded by a previous run
                  if event_counter in done_list:
                      event_counter += 1
                      continue

                  label = int(labels[index])
                  score = scores[index]
                  prediction = score.argmax()
                  if label == prediction:
                      correct_count += 1

                  # values that can be filled from the caffe blobs
                  csv_vals = {'entry'     : entries[index],
                              'npx'       : (adcimgs[index] > 0).sum(),
                              'label'     : label,
                              'prediction': prediction}
                  for class_index, class_name in enumerate(CLASS_NAMES):
                      csv_vals[class_name] = score[class_index]

                  # physics parameters from our separate IO handle (myio)
                  myio.read_entry(entries[index])
                  fill_roi_vars(csv_vals, myio.get_data(1, roi_producer))

                  # Explicit check that every CSV column was filled (catches a CSV_VARS
                  # update that was not mirrored above); never write a partial row.
                  missing = [v for v in CSV_VARS if v not in csv_vals]
                  if missing:
                      print('\033[93mCould not locate field\033[00m: %s' % ', '.join(missing))
                      terminate = True
                      break
                  fout.write(','.join(str(csv_vals[v]) for v in CSV_VARS) + '\n')

                  event_counter += 1

                  # tell the user which entry we are processing
                  sys.stdout.write('Processed entry %d ... accuracy @ %g    \r'
                                   % (event_counter, float(correct_count) / event_counter))
                  sys.stdout.flush()

                  # break if the stop counter is met
                  if event_counter >= stop_counter:
                      break

          print('')
          if terminate:
              print('\033[93mAborting upon the error reported above...\033[00m')
      finally:
          # close output and input IO even on error, then destroy the thread filler
          # via the factory, its owner
          fout.close()
          myio.finalize()
          larcv.ThreadFillerFactory.destroy_filler(filler_name)