# PyTorch Classification Example

In this notebook, we use ResNet-18, implemented in PyTorch, to classify the 5-particle example training data.

This tutorial walks through the steps needed to load images stored in LArCV files and train a network. For more details on how to use PyTorch, refer to the official PyTorch tutorials.

This notebook tries to be self-contained in terms of code.
However, you can also find the code separated into different files in the following repositories:

* LArCVDataset: concrete implementation of the PyTorch `Dataset` class written for LArCV2 IO
* pytorch-classification-example: many of the files and scripts found in this tutorial

You will also need the training data. Go to the [open data page](http://deeplearnphysics.org/DataChallenge/) and download either the 5k or 50k training/validation samples.


In [1]:
# Import our modules

# python
import os,sys
import shutil
import time
import traceback

# numpy
import numpy as np

# torch
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# ROOT/LArCV
import ROOT
from larcv import larcv

Welcome to JupyROOT 6.12/04


# Setup Data IO

## Location of data on your local machine

Set the path to the data files in this block.

In [2]:
path_to_train_data="/home/taritree/working/dlphysics/testset/train_50k.root"
path_to_test_data="/home/taritree/working/dlphysics/testset/test_40k.root"
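
Before moving on, it can help to confirm that the files are where the paths say they are, since a wrong path otherwise only shows up later when the IO threads start. The snippet below is a small sketch of such a check. The helper `missing_paths` and the placeholder file name are made up for illustration; in the notebook you would pass `path_to_train_data` and `path_to_test_data` instead.

```python
import os

def missing_paths(paths):
    """Return the subset of paths that do not exist on disk (hypothetical helper)."""
    return [p for p in paths if not os.path.exists(p)]

# Placeholder paths for illustration: the current directory exists,
# the second entry is a file name that is assumed not to exist.
example_paths = [
    os.getcwd(),
    os.path.join(os.getcwd(), "this_file_should_not_exist.root"),
]

missing = missing_paths(example_paths)
for p in missing:
    print("missing data file: %s" % p)
```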

## Define LArCVDataset

First, we define a class that will load our data. There are many ways to do this. Here we create a concrete subclass of PyTorch's `Dataset` class. Such a class can be passed to PyTorch's `DataLoader`, although we do not use `DataLoader` in this notebook.
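
For readers new to the `Dataset` interface, here is a minimal sketch of the general pattern, independent of LArCV: a subclass only needs `__len__` and `__getitem__`. `ToyDataset` and its random contents are invented for illustration. The sketch also shows the `DataLoader` route that this notebook does not take.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical in-memory dataset of random single-channel images."""

    def __init__(self, nentries=10):
        super(ToyDataset, self).__init__()
        self.images = torch.rand(nentries, 1, 8, 8)
        self.labels = torch.arange(nentries) % 5  # five made-up classes

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        # return one entry as a dict keyed by product name
        return {"image": self.images[idx], "label": self.labels[idx]}

dataset = ToyDataset()
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# the DataLoader stacks four single entries into one batch
batch = next(iter(loader))
print(tuple(batch["image"].shape), batch["label"].tolist())
```

In contrast, the `LArCVDataset` class in this notebook delegates batching to the LArCV thread filler (via `start(batchsize)`), so each `__getitem__` call returns a whole batch rather than a single entry.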

In [3]:
# from: https://github.com/deeplearnphysics/larcvdataset

larcv.PSet # touch this to force libBase to load, which has CreatePSetFromFile
from larcv.dataloader2 import larcv_threadio
from torch.utils.data import Dataset

class LArCVDataset(Dataset):
    """ LArCV data set interface for PyTorch"""

    def __init__( self, cfg, verbosity=0 ):
        self.verbosity = verbosity
        self.batchsize = None

        # we setup the larcv threadfiller class, which handles io from larcv files
        # this follows steps from larcv tutorials
        
        # setup cfg dictionary needed for larcv_threadio
        self.cfg = cfg        
        self.filler_cfg = {}
        self.filler_cfg["filler_name"] = "ThreadProcessor"
        self.filler_cfg["verbosity"]   = self.verbosity
        self.filler_cfg["filler_cfg"]  = self.cfg
        if not os.path.exists(self.cfg):
            raise ValueError("Could not find filler configuration file: %s"%(self.cfg))

        # we read the first line of the config file, which should have name of config parameter set
        linepset = open(self.cfg,'r').readlines()
        self.cfgname = linepset[0].split(":")[0].strip()
        
        # we load the pset ourselves, as we want access to values in 'ProcessName' list
        # will use these as the names of the data products loaded. store in self.datalist
        self.pset = larcv.CreatePSetFromFile(self.cfg,self.cfgname).get("larcv::PSet")(self.cfgname)
        datastr_v = self.pset.get("std::vector<std::string>")("ProcessName")
        self.datalist = []
        for i in range(0,datastr_v.size()):
            self.datalist.append(datastr_v[i])
        
        # finally, configure io
        self.io = larcv_threadio()        
        self.io.configure(self.filler_cfg)

    def __len__(self):
        return int(self.io.fetch_n_entries())

    def __getitem__(self, idx):
        # note: idx is ignored. each call advances the thread filler to its next batch
        self.io.next()
        out = {}
        for name in self.datalist:
            out[name] = self.io.fetch_data(name).data()
        return out
        
    def __str__(self):
        return self.dumpcfg()

    def start(self,batchsize):
        """exposes larcv_threadio::start which is used to start the thread managers"""
        self.batchsize = batchsize
        self.io.start_manager(self.batchsize)

    def stop(self):
        """ stops the thread managers"""
        self.io.stop_manager()

    def dumpcfg(self):
        """return the contents of the configuration file as a string"""
        with open(self.cfg,'r') as f:
            return f.read()
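
The constructor takes the parameter-set name from the first line of the configuration file, so that line must start with the name followed by a colon. Here is a minimal sketch of just that step, using a throwaway file and a made-up configuration (`ExampleProcessor` is hypothetical):

```python
import os
import tempfile

# A made-up, minimal configuration; only the first line matters here
example_cfg = """ExampleProcessor: {
  Verbosity: 3
}
"""

# write it to a temporary file, as the notebook does for the real configurations
handle, cfg_path = tempfile.mkstemp(suffix=".cfg")
with os.fdopen(handle, "w") as f:
    f.write(example_cfg)

# same parsing as in LArCVDataset.__init__: text before the first colon, stripped
with open(cfg_path, "r") as f:
    first_line = f.readline()
cfgname = first_line.split(":")[0].strip()

# clean up the temporary file
os.remove(cfg_path)

print(cfgname)
```

A leading blank line or comment in the file would break this lookup, which is why the configuration strings in the next section begin with the name immediately after the opening quotes.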

## Write configuration files for the LArCV ThreadFiller class

We define the configurations in this block, then write to file. We will load the files later when we create LArCVDataset instances for both the training and test data.

In [4]:
train_cfg="""ThreadProcessor: {
  Verbosity:3
  NumThreads: 3
  NumBatchStorage: 3
  RandomAccess: true
  InputFiles: ["%s"]  
  ProcessName: ["image","label"]
  ProcessType: ["BatchFillerImage2D","BatchFillerPIDLabel"]
  ProcessList: {
    image: {
      Verbosity:3
      ImageProducer: "data"
      Channels: [2]
      EnableMirror: true
    }
    label: {
      Verbosity:3
      ParticleProducer: "mctruth"
      PdgClassList: [2212,11,211,13,22]
    }
  }
}

"""%(path_to_train_data)

test_cfg="""ThreadProcessorTest: {
  Verbosity:3
  NumThreads: 2
  NumBatchStorage: 2
  RandomAccess: true
  InputFiles: ["%s"]
  ProcessName: ["imagetest","labeltest"]
  ProcessType: ["BatchFillerImage2D","BatchFillerPIDLabel"]
  ProcessList: {
    imagetest: {
      Verbosity:3
      ImageProducer: "data"
      Channels: [2]
      EnableMirror: false
    }
    labeltest: {
      Verbosity:3
      ParticleProducer: "mctruth"
      PdgClassList: [2212,11,211,13,22]
    }
  }
}

"""%(path_to_test_data)

# write the configurations to file (works in both python 2 and python 3)
with open("train_dataloader.cfg",'w') as train_cfg_out:
    train_cfg_out.write(train_cfg+"\n")

with open("test_dataloader.cfg",'w') as test_cfg_out:
    test_cfg_out.write(test_cfg+"\n")
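
The `PdgClassList` entries are PDG particle codes: 2212 is the proton, 11 the electron, 211 the charged pion, 13 the muon, and 22 the photon. The sketch below assumes that the class index of each label follows the order of that list. This is an assumption about `BatchFillerPIDLabel` that is worth checking against the LArCV documentation. The names `PDG_NAMES` and `class_index` are made up for illustration.

```python
# PDG codes in the same order as PdgClassList in the configurations above
pdg_class_list = [2212, 11, 211, 13, 22]

# human-readable names for the five particle types (standard PDG numbering)
PDG_NAMES = {2212: "proton", 11: "electron", 211: "pion", 13: "muon", 22: "gamma"}

# ASSUMPTION: class index == position in PdgClassList
class_index = dict((pdg, i) for i, pdg in enumerate(pdg_class_list))

for pdg in pdg_class_list:
    print("class %d: pdg=%d (%s)" % (class_index[pdg], pdg, PDG_NAMES[pdg]))
```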

# Setup Network

## Define network

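We use ResNet-18 as implemented in the torchvision module. We reproduce it here with a few modifications. The main one is changing the number of input channels from 3 to 1: the original ResNet expects an RGB image, while for our example we use the image from only one plane of our hypothetical LAr TPC detector. As noted in the code comments, we also change the stride of the final average pooling layer from 1 to 2 and add dropout before the fully connected layer.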

The original can be found [here](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py).

In [5]:
import torch.nn as nn
import math

# define convolution without bias that we will use throughout the network
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# implements one ResNet unit
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
    
# define the network. It provides options for the number of output classes and input channels
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, input_channels=3):
        """
        inputs
        ------
        block: type of resnet unit
        layers: list of 4 ints. defines number of basic block units in each set of resnet units
        num_classes: output classes
        input_channels: number of channels in input images
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # had to change stride of avgpool from original from 1 to 2
        self.avgpool = nn.AvgPool2d(7, stride=2)

        # I've added dropout to the network
        self.dropout = nn.Dropout2d(p=0.5,inplace=True)

        #print "block.expansion=",block.expansion
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        #print "avepool: ",x.data.shape
        x = x.view(x.size(0), -1)
        #print "view: ",x.data.shape
        x = self.fc(x)

        return x


    
# define a helper function for ResNet-18
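def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # model_zoo and the weights URL come from the original torchvision resnet.py.
        # The ImageNet weights only fit the defaults input_channels=3 and num_classes=1000.
        import torch.utils.model_zoo as model_zoo
        url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(model_zoo.load_url(url))
    return model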
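
The comment on `avgpool` in the `ResNet` class can be checked with the standard output-size formula for a convolution or pooling layer, `floor((n + 2*p - k)/s) + 1`, for input size `n`, kernel `k`, stride `s` and padding `p`. The sketch below traces the spatial size through the network. The 256x256 input size is an assumption about the training images, not something stated in this notebook; 224x224 is the usual ImageNet size that the original ResNet targets. The helper names are made up for illustration.

```python
def out_size(n, kernel, stride, padding):
    """Spatial output size of a conv/pool layer: floor((n + 2p - k)/s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

def size_before_avgpool(n):
    """Trace the image size through conv1, maxpool and layer1..layer4."""
    n = out_size(n, 7, 2, 3)   # conv1: 7x7, stride 2, padding 3
    n = out_size(n, 3, 2, 1)   # maxpool: 3x3, stride 2, padding 1
    n = out_size(n, 3, 1, 1)   # layer1: 3x3 convolutions with stride 1
    for _ in range(3):         # layer2, layer3, layer4 each halve the size
        n = out_size(n, 3, 2, 1)
    return n

for image_size in (224, 256):
    n = size_before_avgpool(image_size)
    for stride in (1, 2):
        pooled = out_size(n, 7, stride, 0)   # avgpool: 7x7, no padding
        features = 512 * pooled * pooled     # flattened input to the fc layer
        print("input %d: %dx%d before avgpool, stride %d -> %d features"
              % (image_size, n, n, stride, features))
```

Under that assumption, a 256x256 input reaches the pooling layer as 8x8, so stride 1 would yield 2048 features and no longer match the 512 inputs of `fc`, while stride 2 yields 512.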
