# Notebook to test commands for the semantic segmentation data interface and network

## Import Modules

In [None]:
# python
import os,sys,commands,time

# numpy
import numpy as np

# ROOT/larcv
import ROOT
from larcv import larcv

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
#import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# Set path to larcvdataset repository
# to get it: clone https://github.com/deeplearnphysics/larcvdataset
# uncommen the next two lines to use it
#_PATH_TO_LARCVDATASET_REPO_="location of "
#sys.path.append(_PATH_TO_LARCVDATASET_REPO_)
import larcvdataset

# Set path to pytorch-uresnet
# to get it: clone https://github.com/deeplearnphysics/pytorch-uresnet
# uncommen the next two lines to use it
#_PATH_TO_PYTORCHURESNET_REPO_="location of "
#sys.path.append(_PATH_TO_PYTORCHURESNET_REPO_)
import uresnet

%matplotlib notebook
import matplotlib.pyplot as plt

### Definitions of LArCVDataset and UResNet

If helpful, you can dump out information about the LArCVDataset and UResNet classes.

* LArCVDataset: provides interface to data within a larcv root file
* UResNet: a U-Net (Ronneberger et al., 2015) built from ResNet modules (He et al., 2016).


In [None]:
# Uncomment to print the built-in documentation of the two classes used in this notebook.
#help(larcvdataset.LArCVDataset)
#help(uresnet.UResNet)

## Create an instance of the network

In [None]:
# Build the segmentation network: one input channel, three output classes.
# NOTE(review): showsizes=True presumably makes the net report tensor sizes
# during the forward pass -- confirm in uresnet.UResNet.
net = uresnet.UResNet(
    inplanes=16,
    input_channels=1,
    num_classes=3,
    showsizes=True
)

In [None]:
# uncomment to print the network definition
# print net

In [None]:
# Move the network parameters onto the GPU.
# Module.cuda() returns the module itself; the trailing ';' stops the notebook
# from echoing the full network repr as the cell output.
net.cuda();

In [None]:
# create loss function
class PixelWiseNLLLoss(nn.modules.loss._WeightedLoss):
    """Negative log-likelihood loss with an individual weight for every pixel.

    The per-pixel NLL is computed without reduction, multiplied by a per-pixel
    weight map, and averaged over all (batch, h, w) entries.
    """
    def __init__(self,weight=None, size_average=True, ignore_index=-100 ):
        """
        weight: optional per-class weight tensor, forwarded to F.nll_loss
        size_average: forwarded to F.nll_loss (no effect there, since the
                      reduction is disabled; averaging is done in forward)
        ignore_index: target value that F.nll_loss should ignore
        """
        super(PixelWiseNLLLoss,self).__init__(weight,size_average)
        self.ignore_index = ignore_index
        # Newer torch versions do not keep size_average as an attribute of
        # _Loss, so store it explicitly for the F.nll_loss call in forward().
        self.size_average = size_average
        # Ask F.nll_loss for one loss value per pixel instead of a scalar.
        self.reduce = False

    def forward(self,predict,target,pixelweights):
        """
        predict: (b,c,h,w) tensor with output from logsoftmax
        target:  (b,h,w) tensor with correct class (integer/long type)
        pixelweights: (b,h,w) tensor with weights for each pixel

        Returns the mean over all pixels of (per-pixel NLL * pixel weight).
        Note: pixels equal to ignore_index contribute zero loss but are still
        counted in the mean's denominator.

        Raises ValueError if target or pixelweights require gradients.
        """
        # Labels and weights are fixed inputs; gradients must not flow into them.
        if getattr(target, 'requires_grad', False) or getattr(pixelweights, 'requires_grad', False):
            raise ValueError("PixelWiseNLLLoss: target and pixelweights must not require gradients")
        # reduce is False, so this returns the unreduced (b,h,w) loss
        pixelloss = F.nll_loss(predict, target,
                               weight=self.weight,
                               size_average=self.size_average,
                               ignore_index=self.ignore_index,
                               reduce=self.reduce)
        return torch.mean(pixelloss*pixelweights)

lossfcn = PixelWiseNLLLoss()

## Set Up the Configuration File

The LArCVDataset class is basically a wrapper around larcv.ThreadFiller. 
To configure it, one needs to provide a configuration file, which we write here.

Remember to set `InputFiles` to the location of your input larcv ROOT file.

In [None]:
# write threadfiller io configuration file

# Location of the input larcv ROOT file -- edit this to point at your copy.
INPUT_LARCV_FILE = "/home/taritree/working/dlphysics/pytorch-uresnet/practice_test_2k.root"
# Name of the configuration file written by this cell; LArCVDataset must be
# created with this same file name.
IOCONFIG_FILENAME = "test_threadfiller.cfg"

# The input path is substituted with %-formatting (not str.format) because the
# config syntax itself uses { } braces.
ioconfig = """ThreadProcessorTest: {
  Verbosity:3
  NumThreads: 2
  NumBatchStorage: 2
  RandomAccess: true
  InputFiles: ["%s"]
  ProcessName: ["imagetest","segmenttest"]
  ProcessType: ["BatchFillerImage2D","BatchFillerImage2D"]
  ProcessList: {
    imagetest: {
      Verbosity:3
      ImageProducer: "data"
      Channels: [2]
      EnableMirror: false
    }
    segmenttest: {
      Verbosity:3
      ImageProducer: "segment"
      Channels: [2]
      EnableMirror: false
    }
  }
}
""" % INPUT_LARCV_FILE
# f.write works under both Python 2 and 3; the extra "\n" reproduces the
# trailing newline that the old `print >> f` statement appended.
with open(IOCONFIG_FILENAME,'w') as f:
    f.write(ioconfig + "\n")

## Create an instance of LArCVDataset using our configuration file

In [None]:
# create larcvdataset instance
# The file name must match the configuration file written by the
# configuration cell ("test_threadfiller.cfg"), otherwise the reader is
# pointed at a file that is never created.
io = larcvdataset.LArCVDataset("test_threadfiller.cfg","ThreadProcessorTest")

### Start up the LArCVDataset

When started, the object will launch threads that are responsible for taking data from the root file and putting it into a dictionary of numpy arrays.

When we start it, we need to pass in the batch size.

In [None]:
# number of entries per batch requested from the reader threads
BATCHSIZE = 1
io.start(BATCHSIZE)

## Get a batch

We use the `[ ]` operator to get our first batch. Note: the index argument is currently ignored.

In [None]:
# get the batch: returns a dictionary of numpy arrays
data = io[0]
print(data.keys())

# get the individual elements
# img: numpy array with the image
# seg: numpy array with the class labels
img = data["imagetest"]
seg = data["segmenttest"]

# Reshape into (batch, channels, H, W) for the image and (batch, H, W) for the
# labels and weights. -1 lets the batch dimension follow the batch size given
# to io.start() instead of assuming 1.
img = img.reshape((-1,1,256,256)).astype(np.float32)
# F.nll_loss requires integer (int64) class labels
seg = seg.reshape((-1,256,256)).astype(np.int64)
# float32 weights so they can multiply the (default float32) per-pixel loss;
# np.ones would otherwise produce float64
wgt = np.ones( seg.shape, dtype=np.float32 )

## Define a function to plot the images

In [None]:
def showImgAndLabels(image2d,label2d):
    """Draw an image and its label map side by side and show the figure.

    image2d: 2D array drawn in the left panel, titled 'Data'.
    label2d: 2D array of class labels drawn in the right panel, titled
             'Label'; its colour scale is fixed to [0, 3.1] so label colours
             do not rescale from image to image.
    """
    title_style = dict(fontsize=20, fontname='Georgia', fontweight='bold')
    fig, (data_ax, label_ax) = plt.subplots(1, 2, figsize=(10,10), facecolor='w')

    data_ax.imshow(image2d, interpolation='none', cmap='jet', origin='lower')
    data_ax.set_title('Data', **title_style)

    label_ax.imshow(label2d, interpolation='none', cmap='jet', origin='lower', vmin=0., vmax=3.1)
    label_ax.set_title('Label', **title_style)

    plt.show()

## Plot the data

In [None]:
# Show the first entry of the batch: channel 0 of the image and its label map.
# Indexing (rather than reshaping to 256x256) also works when the batch size is > 1.
showImgAndLabels(img[0,0], seg[0])

In [None]:
# Convert the numpy arrays to torch tensors on the GPU, casting explicitly to
# the types the network and loss need: float image, long (int64) class labels
# for nll_loss, and float weights matching the float loss.
timage  = torch.from_numpy(img).float().cuda()
ttarget = torch.from_numpy(seg).long().cuda()
tweight = torch.from_numpy(wgt).float().cuda()

# convert to torch autograd variable (requires_grad defaults to False, so no
# gradients flow into the labels or weights)
image_var = torch.autograd.Variable(timage)
target_var = torch.autograd.Variable(ttarget)
weight_var = torch.autograd.Variable(tweight)

In [None]:
# push through the net to test it

s = time.time()
output = net(image_var)
s = time.time()-s
print "forward time: ",s
# note: first time is slow, about 600 ms, (as network allocating mem?)
#       next forward pass is about 15 ms

In [None]:
# loss
s = time.time()
loss = lossfcn(output,target_var,weight_var)

## Stop the LArCVDataset interface

When stopped, the threads responsible for reading in data are terminated.

In [None]:
# stop the reader threads launched by io.start()
io.stop()