# Batch Inference

> Warning: Much like the previous step, this suffers from running on a CPU and from within an IPython notebook.

In [1]:
import matplotlib 
# This can also be ipython magic such as `matplotlib inline` or `matplotlib notebook`
matplotlib.use('Agg')

import logging
import os,commands,signal,sys,tempfile
os.environ['GLOG_minloglevel'] = '2' # set message level to warning
import pandas
import caffe
import numpy as np
import ROOT
import lmdb
import time
#imports the larcv tools from ROOT
from larcv import larcv
import matplotlib.pyplot as plt

Welcome to ROOTaaS 6.06/08


This function finds all of the snapshot files whose iteration number falls within a given range. The idea is to re-run the accuracy metric on each snapshot and plot the accuracy as a function of iteration, or of epoch, where `epoch := iteration/200000`.

In [None]:
def find_iter(directory, start, stop):
    """Return the snapshot files in `directory` whose iteration lies in [start, stop].

    A snapshot is any entry of `directory` whose name ends with
    'caffemodel.h5'. Its iteration number is the integer found between the
    last '_' and the first following '.' of the file name, e.g.
    'net_iter_500.caffemodel.h5' -> 500.

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive).
    start, stop : int
        Inclusive bounds on the iteration number.

    Returns
    -------
    list of str
        `directory` joined with each matching file name, sorted by ascending
        iteration number so results can be plotted directly as a line.
    """
    logging.debug("Finding iteration files for {} in [{},{}]".format(directory,\
 start, stop))
    matches = []
    for name in os.listdir(directory):
        if not name.endswith('caffemodel.h5'):
            continue
        try:
            iter_num = int(name.split('_')[-1].split('.')[0])
        except ValueError:
            # Name has the right extension but no numeric "_<iteration>" suffix;
            # skip it instead of aborting the whole scan.
            logging.debug("Skipping {}: no iteration number in name".format(name))
            continue
        if start <= iter_num <= stop:
            matches.append((iter_num, os.path.join(directory, name)))
    # os.listdir returns entries in arbitrary order; sort by iteration number.
    matches.sort()
    ret = [path for _, path in matches]
    logging.debug("Found: [{}] Entries".format(len(ret)))
    return ret

This cell defines a function that returns a tuple of two lists. The first list holds the iteration number of each snapshot and the second holds the accuracy of that snapshot.

In [None]:
def _snapshot_iteration(path):
    """Parse the iteration number out of a '<prefix>_<iteration>.caffemodel.h5' path."""
    return int(os.path.basename(path).split('_')[-1].split('.')[0])


def _snapshot_accuracy(model, weights, num_iterations, ana_file, verbose):
    """Run `num_iterations` forward passes of one snapshot and return the
    fraction of passes whose first mini-batch entry was classified correctly."""
    net = caffe.Net(model, weights, caffe.TEST)
    filler_name = 'DataFiller'
    # check if larcv IO processor does in fact exist and registered in a factory
    # NOTE(review): presumably the filler is registered while `model` is loaded
    # above -- confirm against the prototxt.
    if not larcv.ThreadFillerFactory.exist_filler(filler_name):
        print('\033[93mFiller {} does not exist...\033[00m'.format(filler_name))

    # get IO instance, ThreadDatumFiller instance, from the factory
    filler = larcv.ThreadFillerFactory.get_filler(filler_name)
    filler.pd().random_access(False)

    # construct our own IO to fetch ROI object for physics analysis
    # NOTE(review): nothing is read from `myio` in this function; it is only
    # opened and finalized, kept to preserve the original side effects.
    myio = larcv.IOManager(0, "AnaIO")
    myio.add_in_file(ana_file)
    myio.initialize()
    try:
        # force the filler to move the next event-to-read pointer to entry 0
        filler.set_next_index(0)

        correct = 0
        for _ in range(num_iterations):
            # This will take some time. Esp w/o a GPU
            net.forward()

            # Wait until the filler is done filling the buffer
            while filler.thread_running():
                time.sleep(0.001)

            # vector of integers recording the entry numbers of this mini-batch
            entries = filler.processed_entries()

            # retrieve data already read-and-stored-in-memory from caffe blobs
            adcimgs = net.blobs["data"].data    # this is image
            labels  = net.blobs["label"].data   # this is label
            scores  = net.blobs["softmax"].data # this is final output softmax vector

            # Only the first entry of each mini-batch is scored, as before.
            # Everything is indexed by `index` so that a batch size other than
            # one no longer raises in int(labels) or mixes entries in argmax().
            index = 0
            # assumes one label value per batch entry -- TODO confirm blob shape
            label = int(np.ravel(labels[index])[0])
            prediction = int(scores[index].argmax())
            if verbose:
                print("Entry:  {}".format(entries[index]))
                print("npx:  {}".format((adcimgs[index] > 0).sum()))
                print("Label:  {}".format(label))
                print("Prediction:  {}".format(prediction))
                print("EMinus:  {}".format(scores[index][0]))
                print("Gamma:  {}".format(scores[index][1]))
                print("MuMinus:  {}".format(scores[index][2]))
                print("PiMinus:  {}".format(scores[index][3]))
                print("Proton:  {}".format(scores[index][4]))

            if label == prediction:
                correct += 1
    finally:
        # Always release the IO manager and the thread filler, even when a
        # forward pass fails, so the next snapshot starts from a clean state.
        try:
            myio.finalize()
        finally:
            # destroy thread filler via factory, an owner
            larcv.ThreadFillerFactory.destroy_filler(filler_name)
    return float(correct) / num_iterations


def accuracy_as_a_function_of_iteration(directory, start, stop, model,
                                        num_iterations=1000,
                                        ana_file='../test.root',
                                        verbose=True):
    """Measure the accuracy of every snapshot in an iteration range.

    Parameters
    ----------
    directory : str
        Directory containing '*_<iteration>.caffemodel.h5' snapshots.
    start, stop : int
        Inclusive iteration range passed to `find_iter`.
    model : str
        Path to the network definition (prototxt) used for every snapshot.
    num_iterations : int, optional
        Number of forward passes per snapshot; one entry is scored per pass.
        Defaults to 1000, the previously hard-coded value.
    ana_file : str, optional
        File opened by the auxiliary larcv IOManager. Defaults to the
        previously hard-coded '../test.root'.
    verbose : bool, optional
        Print the per-pass label, prediction and class scores (default True,
        matching the original output).

    Returns
    -------
    tuple of (list of int, list of float)
        Iteration numbers and the matching accuracies, ordered by iteration.

    Raises
    ------
    ValueError
        If `num_iterations` is smaller than 1.
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be >= 1")

    # Process snapshots in iteration order so (x, y) plots as a clean line.
    files = sorted(find_iter(directory, start, stop), key=_snapshot_iteration)
    out = ([], [])
    for _file in files:
        accuracy = _snapshot_accuracy(model, _file, num_iterations,
                                      ana_file, verbose)
        out[0].append(_snapshot_iteration(_file))
        out[1].append(accuracy)
    return out

In [8]:
# Evaluate every snapshot saved between iteration 0 and 10000 and plot
# accuracy against iteration number.
x,y = accuracy_as_a_function_of_iteration('/workspace/plainresnet10b/', 0, 10000, 
                                          '/workspace/plainresnet10b_untrained/sp_plainresnet10b_test.prototxt')
# Explicit figure/axes so the plot is labelled and stands alone when skimmed.
fig, ax = plt.subplots()
ax.plot(x, y, marker='o')
ax.set_xlabel('Iteration')
ax.set_ylabel('Accuracy')
ax.set_title('Snapshot accuracy vs. training iteration')
# NOTE(review): the import cell selects the non-interactive 'Agg' backend, so
# this call may not display anything unless an IPython matplotlib magic is used.
plt.show()