# PyTorch Classification Example

In this notebook, we're going to use ResNet-18 implemented in pyTorch to classify the 5-particle example training data.

This tutorial is meant to walk through some of the necessary steps to load images stored in LArCV files and train a network.  For more details on how to use pytorch, refer to the official pytorch tutorials.

This notebook will try to be self-contained in terms of code. 
However, you can find the code separated into different files in the following repositories

* LArCVDataset: concrete instance of pytorch Dataset class written for LArCV2 IO
* pytorch-classification-example: many of the files and scripts found in this tutorial

You will also need the training data. Go to the [open data page](http://deeplearnphysics.org/DataChallenge/) and download the either the 5k or 50k training/validation samples.


In [1]:
# Import our modules

# python
import os,sys
import shutil
import time
import traceback

# numpy
import numpy as np

# torch
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

# ROOT/LArCV
import ROOT
from larcv import larcv

Welcome to JupyROOT 6.12/04


## Set the GPU to use

In [2]:
torch.cuda.device( 1 )

<torch.cuda.device at 0x7f1fc0135cd0>

# Setup Data IO

## Location of data on your local machine

Set the path to the data files in this block.

In [3]:
path_to_train_data="/home/taritree/working/dlphysics/testset/train_50k.root"
path_to_test_data="/home/taritree/working/dlphysics/testset/test_40k.root"
if not os.path.exists(path_to_train_data):
    print "Could not find the training data file."
if not os.path.exists(path_to_test_data):
    print "Could not find the validation data file."

## Define LArCVDataset

First, we define a class that will load our data. There are many ways to do this. We create a concrete instance of pytorch's `Dataset` class, which can be used in the `DataLoader` class (which we do not use).

In [4]:
# from: https://github.com/deeplearnphysics/larcvdataset

larcv.PSet # touch this to force libBase to load, which has CreatePSetFromFile
from larcv.dataloader2 import larcv_threadio
from torch.utils.data import Dataset

class LArCVDataset(Dataset):
    """ LArCV data set interface for PyTorch"""

    def __init__( self, cfg, fillername, verbosity=0, loadallinmem=False, randomize_inmem_data=True, max_inmem_events=-1 ):
        self.verbosity = verbosity
        self.batchsize = None
        self.randomize_inmem_data = randomize_inmem_data
        self.max_inmem_events = max_inmem_events
        self.loadallinmem = loadallinmem
        self.cfg = cfg  

        # we setup the larcv threadfiller class, which handles io from larcv files
        # this follows steps from larcv tutorials
        
        # setup cfg dictionary needed for larcv_threadio      
        self.filler_cfg = {}
        self.filler_cfg["filler_name"] = fillername
        self.filler_cfg["verbosity"]   = self.verbosity
        self.filler_cfg["filler_cfg"]  = self.cfg
        if not os.path.exists(self.cfg):
            raise ValueError("Could not find filler configuration file: %s"%(self.cfg))

        # we read the first line of the config file, which should have name of config parameter set
        linepset = open(self.cfg,'r').readlines()
        self.cfgname = linepset[0].split(":")[0].strip()
        
        # we load the pset ourselves, as we want access to values in 'ProcessName' list
        # will use these as the names of the data products loaded. store in self.datalist
        self.pset = larcv.CreatePSetFromFile(self.cfg,self.cfgname).get("larcv::PSet")(self.cfgname)
        datastr_v = self.pset.get("std::vector<std::string>")("ProcessName")
        self.datalist = []
        for i in range(0,datastr_v.size()):
            self.datalist.append(datastr_v[i])
        
        # finally, configure io
        self.io = larcv_threadio()        
        self.io.configure(self.filler_cfg)
        
        if self.loadallinmem:
            self._loadinmem()

    def __len__(self):
        if not self.loadallinmem:
            return int(self.io.fetch_n_entries())
        else:
            return int(self.alldata[self.datalist[0]].shape[0])

    def __getitem__(self, idx):
        if not self.loadallinmem:
            self.io.next()
            out = {}
            for name in self.datalist:
                out[name] = self.io.fetch_data(name).data()
        else:
            indices = np.random.randint(len(self),size=self.batchsize)
            out = {}
            for name in self.datalist:
                out[name] = np.zeros( (self.batchsize,self.alldata[name].shape[1]), self.alldata[name].dtype )
                for n,idx in enumerate(indices):
                    out[name][n,:] = self.alldata[name][idx,:]
        return out
        
    def __str__(self):
        return dumpcfg()
    
    def _loadinmem(self):
        """load data into memory"""
        nevents = int(self.io.fetch_n_entries())
        if self.max_inmem_events>0 and nevents>self.max_inmem_events:
            nevents = self.max_inmem_events

        print "Attempting to load all ",nevents," into memory. good luck"
        start = time.time()

        # start threadio
        self.start(1)

        # get one data element to get shape
        self.io.next()
        firstout = {}
        for name in self.datalist:
            firstout[name] = self.io.fetch_data(name).data()
            self.alldata = {}
        for name in self.datalist:
            self.alldata[name] = np.zeros( (nevents,firstout[name].shape[1]), firstout[name].dtype )
            self.alldata[name][0] = firstout[name][0,:]
        for i in range(1,nevents):
            self.io.next()
            if i%100==0:
                print "loading event %d of %d"%(i,nevents)
            for name in self.datalist:
                out = self.io.fetch_data(name).data()
                self.alldata[name][i,:] = out[0,:]

        print "elapsed time to bring data into memory: ",time.time()-start,"sec"
        self.stop()

    def start(self,batchsize):
        """exposes larcv_threadio::start which is used to start the thread managers"""
        self.batchsize = batchsize
        self.io.start_manager(self.batchsize)

    def stop(self):
        """ stops the thread managers"""
        self.io.stop_manager()

    def dumpcfg(self):
        """dump the configuration file to a string"""
        print open(self.cfg).read()
        


## Write configuration files for the LArCV ThreadFiller class

We define the configurations in this block, then write to file. We will load the files later when we create LArCVDataset instances for both the training and test data.

In [5]:
train_cfg="""ThreadProcessor: {
  Verbosity:3
  NumThreads: 3
  NumBatchStorage: 3
  RandomAccess: true
  InputFiles: ["%s"]  
  ProcessName: ["image","label"]
  ProcessType: ["BatchFillerImage2D","BatchFillerPIDLabel"]
  ProcessList: {
    image: {
      Verbosity:3
      ImageProducer: "data"
      Channels: [2]
      EnableMirror: true
    }
    label: {
      Verbosity:3
      ParticleProducer: "mctruth"
      PdgClassList: [2212,11,211,13,22]
    }
  }
}
"""%(path_to_train_data)

test_cfg="""ThreadProcessorTest: {
  Verbosity:3
  NumThreads: 2
  NumBatchStorage: 2
  RandomAccess: true
  InputFiles: ["%s"]
  ProcessName: ["imagetest","labeltest"]
  ProcessType: ["BatchFillerImage2D","BatchFillerPIDLabel"]
  ProcessList: {
    imagetest: {
      Verbosity:3
      ImageProducer: "data"
      Channels: [2]
      EnableMirror: false
    }
    labeltest: {
      Verbosity:3
      ParticleProducer: "mctruth"
      PdgClassList: [2212,11,211,13,22]
    }
  }
}
"""%(path_to_test_data)

train_cfg_out = open("train_dataloader.cfg",'w')
print >> train_cfg_out,train_cfg
train_cfg_out.close()

test_cfg_out  = open("valid_dataloader.cfg",'w')
print >> test_cfg_out,test_cfg
test_cfg_out.close()

# Setup Network

## Define network

We use ResNet-18 as implemented in the torchvision module.  We reproduce it here and make a slight modification: we change the number of input channels from 3 to 1.  The original resnet expects an RGB image.  For our example, we only use the image from one plane from our hypothetical LAr TPC detector.

Original can be found [here](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py).

In [6]:
import torch.nn as nn
import math

# define convolution without bias that we will use throughout the network
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# implements one ResNet unit
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
    
# define the network. It provides options for 
class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, input_channels=3):
        """
        inputs
        ------
        block: type of resnet unit
        layers: list of 4 ints. defines number of basic block units in each set of resnet units
        num_classes: output classes
        input_channels: number of channels in input images
        """
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # had to change stride of avgpool from original from 1 to 2
        self.avgpool = nn.AvgPool2d(7, stride=2)

        # I've added dropout to the network
        self.dropout = nn.Dropout2d(p=0.5,inplace=True)

        #print "block.expansion=",block.expansion                                                                                                                                                           
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
    
    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        #print "avepool: ",x.data.shape                                                                                                                                                                     
        x = x.view(x.size(0), -1)
        #print "view: ",x.data.shape                                                                                                                                                                        
        x = self.fc(x)

        return x


    
# define a helper function for ResNet-18
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.                                                                                                                                                                        
                                                                                                                                                                                                            
    Args:                                                                                                                                                                                                   
        pretrained (bool): If True, returns a model pre-trained on ImageNet                                                                                                                                 
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


## Create instance of network

In [7]:
model = resnet18(pretrained=False,num_classes=5, input_channels=1)
model.cuda()

ResNet(
  (conv1): Conv2d (1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1))
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNo

## Define loss function

In [8]:
criterion = nn.CrossEntropyLoss().cuda()

## Define optimizer and set training parameters

In [9]:
lr = 1.0e-3
momentum = 0.9
weight_decay = 1.0e-3
batchsize = 50
batchsize_valid = 500
start_epoch = 0
epochs      = 1500
nbatches_per_epoch = 10000/batchsize
nbatches_per_valid = 1000/batchsize_valid

# We use SGD
optimizer = torch.optim.SGD(model.parameters(), lr, momentum=momentum, weight_decay=weight_decay)

# Define training and validation steps

We define functions and classes to help us perform training.

### Define an object that will help us track averages

In [10]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

### Training step

In [11]:
def train(train_loader, model, criterion, optimizer, nbatches, epoch, print_freq):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    format_time = AverageMeter()
    train_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode                                                                                                                                                                                  
    model.train()

    for i in range(0,nbatches):
        #print "epoch ",epoch," batch ",i," of ",nbatches                                                                                                                                                   
        batchstart = time.time()

        end = time.time()
        data = train_loader[i]
        # measure data loading time                                                                                                                                                                         
        data_time.update(time.time() - end)

        end = time.time()
        img = data["image"]
        lbl = data["label"]
        img_np = np.zeros( (img.shape[0], 1, 256, 256), dtype=np.float32 )
        lbl_np = np.zeros( (lbl.shape[0] ), dtype=np.int )
        # batch loop                                                                                                                                                                                        
        for j in range(img.shape[0]):
            imgtmp = img[j].reshape( (256,256) )
            img_np[j,0,:,:] = padandcropandflip(imgtmp) # data augmentation                                                                                                                                 
            lbl_np[j] = np.argmax(lbl[j])
        input  = torch.from_numpy(img_np).cuda()
        target = torch.from_numpy(lbl_np).cuda()

        # measure data formatting time                                                                                                                                                                      
        format_time.update(time.time() - end)


        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output                                                                                                                                                                                    
        end = time.time()
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss                                                                                                                                                                  
        prec1 = accuracy(output.data, target, topk=(1,))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        
        # compute gradient and do SGD step                                                                                                                                                                  
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_time.update(time.time()-end)

        # measure elapsed time                                                                                                                                                                              
        batch_time.update(time.time() - batchstart)
        
        if i % print_freq == 0:
            status = (epoch,i,nbatches,
                      batch_time.val,batch_time.avg,
                      data_time.val,data_time.avg,
                      format_time.val,format_time.avg,
                      train_time.val,train_time.avg,
                      losses.val,losses.avg,
                      top1.val,top1.avg)
            print "Epoch: [%d][%d/%d]\tTime %.3f (%.3f)\tData %.3f (%.3f)\tFormat %.3f (%.3f)\tTrain %.3f (%.3f)\tLoss %.3f (%.3f)\tPrec@1 %.3f (%.3f)"%status
            
    return losses.avg,top1.avg

### Validation step

Here we process the test data and accumilate the accuracy.

In [21]:
def validate(val_loader, model, criterion, nbatches, print_freq):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode                                                                                                                                                                               
    model.eval()

    end = time.time()
    for i in range(0,nbatches):
        data = val_loader[i]
        img = data["imagetest"]
        lbl = data["labeltest"]
        img_np = np.zeros( (img.shape[0], 1, 256, 256), dtype=np.float32 )
        lbl_np = np.zeros( (lbl.shape[0] ), dtype=np.int )
        for j in range(img.shape[0]):
            img_np[j,0,:,:] = img[j].reshape( (256,256) )
            lbl_np[j] = np.argmax(lbl[j])
        input  = torch.from_numpy(img_np).cuda()
        target = torch.from_numpy(lbl_np).cuda()

        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output                                                                                                                                                                                    
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss                                                                                                                                                                  
        prec1 = accuracy(output.data, target, topk=(1,))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))

        # measure elapsed time                                                                                                                                                                              
        batch_time.update(time.time() - end)
        end = time.time()
        if i % print_freq == 0:
            status = (i,nbatches,batch_time.val,batch_time.avg,losses.val,losses.avg,top1.val,top1.avg)
            print "Test: [%d/%d]\tTime %.3f (%.3f)\tLoss %.3f (%.3f)\tPrec@1 %.3f (%.3f)"%status
 
    print "Test:Result* Prec@1 %.3f\tLoss %.3f"%(top1.avg,losses.avg)
    
    return float(top1.avg),float(losses.avg)

### utility functions

In [13]:
def adjust_learning_rate(optimizer, epoch, lr):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    #lr = lr * (0.5 ** (epoch // 300))                                                                                                                                                                      
    lr = lr
    #lr = lr*0.992                                                                                                                                                                                          
    #print "adjust learning rate to ",lr                                                                                                                                                                    
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def dump_lr_schedule( startlr, numepochs ):
    for epoch in range(0,numepochs):
        lr = startlr*(0.5**(epoch//300))
        if epoch%10==0:
            print "Epoch [%d] lr=%.3e"%(epoch,lr)
    print "Epoch [%d] lr=%.3e"%(epoch,lr)
    return

def padandcropandflip(npimg2d):
    imgpad  = np.zeros( (264,264), dtype=np.float32 )
    imgpad[4:256+4,4:256+4] = npimg2d[:,:]
    if np.random.rand()>0.5:
        imgpad = np.flip( imgpad, 0 )
    if np.random.rand()>0.5:
        imgpad = np.flip( imgpad, 1 )
    randx = np.random.randint(0,8)
    randy = np.random.randint(0,8)
    return imgpad[randx:randx+256,randy:randy+256]

def save_checkpoint(state, is_best, p, filename='checkpoint.pth.tar'):
    if p>0:
        filename = "checkpoint.%dth.tar"%(p)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


# Load the datasets and start data loading threads


### Training data

For the training data, we ask that all the data is loaded into memory. Since we need to get many, many batches to train the network, reducing the time to get a batch of images will pay off in the long run.

However, we first pay an upgront cost: this step takes a LONG time.

In [14]:
capevents = 100 # for debugging, to keep the time this step takes to a minimum
iotrain = LArCVDataset("train_dataloader.cfg", "ThreadProcessor", loadallinmem=True, max_inmem_events=capevents)
iotrain.start(batchsize)

Attempting to load all  100  into memory. good luck
elapsed time to bring data into memory:  5.44157385826 sec
ThreadProcessor : {
  InputFiles : ["/home/taritree/working/dlphysics/testset/train_50k.root"]
  NumBatchStorage : 3
  NumThreads : 3
  ProcessName : ["image","label"]
  ProcessType : ["BatchFillerImage2D","BatchFillerPIDLabel"]
  RandomAccess : true
  Verbosity : 3
  ProcessList : {
    image : {
      Channels : [2]
      EnableMirror : true
      ImageProducer : "data"
      Verbosity : 3
    }

    label : {
      ParticleProducer : "mctruth"
      PdgClassList : [2212,11,211,13,22]
      Verbosity : 3
    }

  }

}

[93m setting verbosity [00m3


Error in <TProtoClass::FindDataMember>: data member with index 0 is not found in class thread
Error in <CreateRealData>: Cannot find data member # 0 of class thread for parent larcv::ThreadProcessor!


### Validation data

For the validation data, we do not load data into memory all at once. We will use the validation only periodically, in between many training steps. During those training steps, the thread filler will load data into memory.

In [15]:
iovalid = LArCVDataset("valid_dataloader.cfg", "ThreadProcessorTest")
iovalid.start(batchsize_valid)

[93m setting verbosity [00m3


# Training Loop

In [28]:
best_prec1 = 0.0

# we store output from the training loop
traininglog = open("log_training.txt",'w')

for epoch in range(start_epoch, epochs):
    if epoch%10==0:
        print "Epoch ",epoch
    adjust_learning_rate(optimizer, epoch, lr)
    epochout = "Epoch [%d]: "%(epoch)
    for param_group in optimizer.param_groups:
        epochout += "lr=%.3e"%(param_group['lr'])
    traininglog.write(epochout+'\n')

    # train for one epoch                                                                                                                                                                               
    try:
        train_ave_loss, train_ave_acc = train(iotrain, model, criterion, optimizer, nbatches_per_epoch, epoch, 50)
    except Exception,e:
        print "Error in training routine!"
        print e.message
        print e.__class__.__name__
        traceback.print_exc(e)
        break
    traininglog.write( "Epoch [%d] train aveloss=%.3f aveacc=%.3f\n"%(epoch,train_ave_loss,train_ave_acc) )

    # evaluate on validation set                                                                                                                                                                        
    try:
        prec1,valid_loss = validate(iovalid, model, criterion, nbatches_per_valid, 1)
    except Exception,e:
        print "Error in validation routine!"
        print e.message
        print e.__class__.__name__
        traceback.print_exc(e)
        break
    traininglog.write( "Test[%d]:Result* Prec@1 %.3f\tLoss %.3f\n"%(epoch,prec1,valid_loss) )
        
    # remember best prec@1 and save checkpoint                                                                                                                                                          
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer' : optimizer.state_dict(),
    }, is_best, -1)
    if epoch==5*50:
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, False, epoch)


Epoch  0
Epoch: [0][0/200]	Time 0.096 (0.096)	Data 0.005 (0.005)	Format 0.013 (0.013)	Train 0.077 (0.077)	Loss 0.002 (0.002)	Prec@1 100.000 (100.000)
Epoch: [0][50/200]	Time 0.212 (0.206)	Data 0.003 (0.004)	Format 0.134 (0.126)	Train 0.075 (0.076)	Loss 0.004 (0.013)	Prec@1 100.000 (99.765)
Epoch: [0][100/200]	Time 0.212 (0.208)	Data 0.003 (0.004)	Format 0.133 (0.128)	Train 0.075 (0.076)	Loss 0.008 (0.012)	Prec@1 100.000 (99.762)
Epoch: [0][150/200]	Time 0.212 (0.209)	Data 0.004 (0.004)	Format 0.129 (0.129)	Train 0.079 (0.076)	Loss 0.005 (0.012)	Prec@1 100.000 (99.775)
Test: [0/2]	Time 0.741 (0.741)	Loss 2.303 (2.303)	Prec@1 50.200 (50.200)
Test: [1/2]	Time 0.736 (0.738)	Loss 2.461 (2.382)	Prec@1 51.200 (50.700)
Test:Result* Prec@1 50.700	Loss 2.382
Epoch: [1][0/200]	Time 0.094 (0.094)	Data 0.004 (0.004)	Format 0.011 (0.011)	Train 0.079 (0.079)	Loss 0.005 (0.005)	Prec@1 100.000 (100.000)
Epoch: [1][50/200]	Time 0.220 (0.215)	Data 0.004 (0.004)	Format 0.139 (0.132)	Train 0.077 (0.079)	Lo

KeyboardInterrupt: 

In [23]:
import os,sys,re


def make_training_plot( logfile, outputpath ):

    loglines = open(logfile,'r').readlines()

    # store tuples (epoch,loss,acc)                                                                                                                                                                         
    test_pts  = []
    train_pts = []
    lr_pts    = []
    lr_max = 0
    lr_min = 1.0e6

    epoch_scale = 0.2

    current_epoch = 0
    for l in loglines:
        l = l.strip()
        data = l.split()
        if "train aveloss" in l:
            pt = ( int(filter(str.isdigit,data[1])), float(re.findall("\d+\.\d+",data[3])[0]), float(re.findall("\d+\.\d+",data[4])[0]) )
            current_epoch = pt[0]
            train_pts.append(pt)
        if "Test:Result*" in l:
            pt = ( current_epoch, float(data[4]), float(data[2]) )
            test_pts.append(pt)
        if "lr=" in l:
            pt = ( int(filter(str.isdigit,data[1])), float( data[-1].split("=")[-1] ) )
            if pt[1]>lr_max:
                lr_max = pt[1]
            if pt[1]<lr_min:
                lr_min = pt[1]
            lr_pts.append( pt )


    sys.argv.append("-b")
    import ROOT as rt
    rt.gStyle.SetOptStat(0)

    graphs = {}
    graphs["trainacc"]  = rt.TGraph( len(train_pts) )
    graphs["trainloss"] = rt.TGraph( len(train_pts) )
    graphs["testacc"]   = rt.TGraph( len(test_pts) )
    graphs["testloss"]  = rt.TGraph( len(test_pts) )
    graphs["lr"]        = rt.TGraph( len(lr_pts) )

    accmax = 0
    accmin = 1.0e6
    lossmax = 0
    lossmin = 1.0e6
    for ipt,pt in enumerate(train_pts):
        graphs["trainacc"].SetPoint( ipt, pt[0]*epoch_scale, pt[2] )
        graphs["trainloss"].SetPoint( ipt, pt[0]*epoch_scale, pt[1] )
        if accmax<pt[2]:
            accmax = pt[2]
        if accmin>pt[2]:
            accmin = pt[2]
        if lossmax<pt[1]:
            lossmax = pt[1]
        if lossmin>pt[1]:
            lossmin = pt[1]

    for ipt,pt in enumerate(test_pts):
        graphs["testacc"].SetPoint( ipt, pt[0]*epoch_scale, pt[2] )
        graphs["testloss"].SetPoint( ipt, pt[0]*epoch_scale, pt[1] )
        if accmax<pt[2]:
            accmax = pt[2]
        if accmin>pt[2]:
            accmin = pt[2]
        if lossmax<pt[1]:
            lossmax = pt[1]
        if lossmin>pt[1]:
            lossmin = pt[1]


    c = rt.TCanvas("c","",1400,600)
    c.Divide(2,1)

    # hitogram to set scales                                                                                                                                                                                
    hloss = rt.TH1D("hloss",";epoch;loss",100, 0,train_pts[-1][0]*epoch_scale*1.1)
    hloss.SetMinimum( 0.5*lossmin )
    hloss.SetMaximum( 5.0*lossmax )

    hacc = rt.TH1D("hacc",";epoch;accuracy (percent)",100, 0,train_pts[-1][0]*epoch_scale*1.1)
    hacc.SetMinimum( 0.0 )
    hacc.SetMaximum( 100.0 )
    
        # Loss                                                                                                                                                                                                  
    c.cd(1).SetLogy(1)
    c.cd(1).SetGridx(1)
    c.cd(1).SetGridy(1)
    hloss.Draw()
    graphs["trainloss"].SetLineColor(rt.kBlack)
    graphs["testloss"].SetLineColor(rt.kBlue)
    graphs["lr"].SetLineColor(rt.kRed)
    graphs["trainloss"].Draw("LP")
    graphs["testloss"].Draw("LP")

    # superimpose lr graph                                                                                                                                                                                  
    rightmax = 1.1*lr_max
    rightmin = 0.9*lr_min
    scale    = rt.gPad.GetUymax()/rightmax
    for ipt,pt in enumerate(lr_pts):
        graphs["lr"].SetPoint( ipt, pt[0]*epoch_scale, pt[1]*scale )
    graphs["lr"].Draw("LPsame")
    lraxis = rt.TGaxis( rt.gPad.GetUxmax(), rt.gPad.GetUymin(), rt.gPad.GetUxmax(), rt.gPad.GetUymax(), rightmin, rightmax, 510, "+LG" )
    lraxis.SetLineColor(rt.kRed)
    lraxis.SetLabelColor(rt.kRed)
    lraxis.Draw()

    # Accuracy                                                                                                                                                                                              
    c.cd(2).SetLogy(0)
    c.cd(2).SetGridx(1)
    c.cd(2).SetGridy(1)
    hacc.Draw()
    graphs["trainacc"].SetLineColor(rt.kBlack)
    graphs["testacc"].SetLineColor(rt.kBlue)
    graphs["trainacc"].Draw("LP")
    graphs["testacc"].Draw("LP")

    c.Update()
    c.Draw()

    c.SaveAs(outputpath)



In [None]:
# make a plot

make_training_plot( "log_training.txt", "training.png")

In [None]:
# once the training is over. stop the fillers
iotrain.stop()
iovalid.stop()