# Training ResNeXt

This notebook walks through training the ResNeXt model with an architecture optimized for the CIFAR datasets (as specified in [1] https://arxiv.org/pdf/1611.05431.pdf).

In [1]:
import os
import sys
import json
import shutil
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import time

In [2]:
sys.path.append("..") # Note, this line is needed only since this notebook is being run from inside 'bharat' directory
                      # If you run this notebook in the same folder as the 'models-py-code' directory, remove this line
from pymodels.resnext import CifarResNeXt # This is where we actually refer to the model we wish to train

## Set all required parameters
Remember to change the `session_id` for each new model you wish to train, or whenever you change any of the other parameters.

In [3]:
# Every tunable setting for this run lives in one dictionary. The key order is
# kept stable because the dictionary is serialised into the log via epoch_state.
args = {
    # Input data parameters
    'data_path': '/datasets/ee285s-public/',
    'dataset': 'cifar100',  # The other option is 'cifar10'

    # Optimization options
    'epochs': 150,
    'start_epoch': 0,  # Change only if necessary. Used for resuming sessions
    'batch_size': 96,
    'learning_rate': 0.1,
    'momentum': 0.9,  # Momentum
    'decay': 0.0005,  # Weight decay (L2 penalty)
    'test_bs': 64,  # Test batch size
    'schedule': [75, 115],  # Decrease learning rate at these epochs
    'gamma': 0.1,  # LR is multiplied by gamma on schedule

    # Checkpoints and session parameters
    'save_dir': './',  # Folder to save checkpoint file and best model file
    'load': 'resnext-0-checkpoint.pth',  # Path of the checkpoint file to load ('' = start fresh)
    'session_id': '1',  # Remember to change the session id for each new model you train

    # Architecture
    'depth': 29,  # Model depth
    'cardinality': 8,  # Model cardinality (group)
    'base_width': 64,  # Number of channels in each group
    'widen_factor': 4,  # Widen factor. 4 -> 64, 8 -> 128, ...

    # Acceleration
    'ngpu': 1,  # 0 = CPU
    'prefetch': 2,  # Pre-fetching threads

    # i/o
    'log_dir': './log/',  # Log folder
}

# Per-epoch state record; it carries the full configuration alongside the
# metrics that the train/test functions add later.
epoch_state = {'args': args}

### Initialize Logging

In [4]:
# Make sure the log folder exists before opening the session log inside it.
if not os.path.isdir(args['log_dir']):
    os.makedirs(args['log_dir'])

log_file_name = 'resnext-{}-log.txt'.format(args['session_id'])
log_path = os.path.join(args['log_dir'], log_file_name)
if args['load'] == '':
    # No checkpoint requested: start this session with an empty log file.
    log = open(log_path, 'w')
    print('New log {} created.'.format(log_file_name))
else:
    # A checkpoint path is set: append so earlier entries of the log are kept.
    log = open(log_path, 'a')
    print('Existing log {} loaded.'.format(log_file_name))

Exisiting log resnext-0-log.txt loaded.


### Setup the Input Data

In [5]:
# Normalisation constants, scaled down by 255 before being handed to Normalize.
mean = [channel / 255 for channel in (125.3, 123.0, 113.9)]
std = [channel / 255 for channel in (63.0, 62.1, 66.7)]

# Training images get a random flip and a padded random crop; test images are
# only converted to tensors and normalised.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Pick the dataset class and its label count; anything other than 'cifar10'
# falls back to CIFAR-100.
if args['dataset'] == 'cifar10':
    dataset_cls, nlabels = dset.CIFAR10, 10
else:
    dataset_cls, nlabels = dset.CIFAR100, 100

train_data = dataset_cls(args['data_path'], train=True, transform=train_transform, download=False)
test_data = dataset_cls(args['data_path'], train=False, transform=test_transform, download=False)

# Only the training loader shuffles; both share the worker / pinned-memory settings.
loader_options = dict(num_workers=args['prefetch'], pin_memory=True)
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args['batch_size'], shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args['test_bs'], shuffle=False, **loader_options)

### Setup the Model, Criterion, and Optimizer

In [6]:
# Build the network from the architecture settings and show its layout.
resnext_model = CifarResNeXt(
    args['cardinality'],
    args['depth'],
    nlabels,
    args['base_width'],
    args['widen_factor'],
)
print(resnext_model)

# Wrap for multi-GPU use when more than one device is requested, then move
# the model to the GPU whenever at least one is requested.
gpu_count = args['ngpu']
if gpu_count > 1:
    resnext_model = torch.nn.DataParallel(resnext_model, device_ids=list(range(gpu_count)))
if gpu_count > 0:
    resnext_model.cuda()

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    resnext_model.parameters(),
    lr=args['learning_rate'],
    momentum=args['momentum'],
    weight_decay=args['decay'],
    nesterov=True,
)

CifarResNeXt(
  (conv_1_3x3): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (stage_1): Sequential(
    (stage_1_bottleneck_0): ResNeXtBottleneck(
      (conv_reduce): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential(
        (shortcut_conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      )
    )
    (stage_1_bottleneck_1): ResNeXtBottleneck(
      (

### Setup the Train and Test Functions

In [7]:
def train(epoch):
    """Run one training pass over ``train_loader`` and record the loss.

    Uses the module-level ``resnext_model``, ``optimizer``, ``train_loader``,
    ``args`` and ``epoch_state``. Batches are moved to the GPU only when
    ``args['ngpu'] > 0``, matching how the model itself is placed.

    Args:
        epoch: Index of the current epoch; used only in the progress messages.

    Side effects:
        Updates the model parameters through ``optimizer``, prints one line per
        batch and one summary line, and stores the exponentially weighted loss
        in ``epoch_state['train_loss']``.
    """
    resnext_model.train()
    use_gpu = args['ngpu'] > 0
    loss_avg = 0.0
    epoch_start = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        batch_start = time.time()
        if use_gpu:
            data, target = data.cuda(), target.cuda()
        data, target = torch.autograd.Variable(data), torch.autograd.Variable(target)

        # forward
        output = resnext_model(data)

        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # float() works on 0-dim loss tensors, where indexing loss.data[0] fails.
        loss_value = float(loss)

        # exponential moving average (weights the newest batch at 0.8)
        loss_avg = loss_avg * 0.2 + loss_value * 0.8

        print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Timer: {:.4f} sec.".format(
            epoch, batch_idx, len(train_loader), loss_value, time.time() - batch_start))

    epoch_state['train_loss'] = loss_avg
    # Measured from before the first batch, so this is the whole epoch's duration.
    print("===> Epoch {} Complete: Avg. Loss: {:.4f} | Time: {:.4f}".format(
        epoch, loss_avg, time.time() - epoch_start))

# test function (forward only)
def test():
    resnext_model.eval()
    loss_avg = 0.0
    correct = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = torch.autograd.Variable(data.cuda()), torch.autograd.Variable(target.cuda())

        # forward
        output = resnext_model(data)
        loss = F.cross_entropy(output, target)

        # accuracy
        pred = output.data.max(1)[1]
        correct += float(pred.eq(target.data).sum())

        # test loss average
        loss_avg += float(loss)

    epoch_state['test_loss'] = loss_avg / len(test_loader)
    epoch_state['test_accuracy'] = correct / len(test_loader.dataset)

### Define a function to save a checkpoint

In [8]:
def save_checkpoint(state, is_best, filename='resnext-{}-checkpoint.pth'.format(args['session_id'])):
    """Write ``state`` to the session checkpoint file in ``args['save_dir']``.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When true, the checkpoint just written is also copied to
            ``model-best-resnext-<session_id>.pth`` in the same folder.
        filename: Checkpoint file name; the default is built from
            ``args['session_id']`` once, when this function is defined.
    """
    checkpoint_path = os.path.join(args['save_dir'], filename)
    torch.save(state, checkpoint_path)
    if not is_best:
        return

    best_name = 'model-best-resnext-{}.pth'.format(args['session_id'])
    shutil.copyfile(checkpoint_path, os.path.join(args['save_dir'], best_name))

### Load and Resume from checkpoint (if set)

In [None]:
# Resume only when a checkpoint path has been configured.
checkpoint_path = args['load']
if checkpoint_path != '':
    if not os.path.isfile(checkpoint_path):
        print("=> No checkpoint found at '{}'".format(checkpoint_path))
    else:
        print("=> Loading Checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)

        # Continue with the epoch after the saved one, at the saved learning rate.
        args['start_epoch'] = checkpoint['epoch'] + 1
        args['learning_rate'] = checkpoint['learning_rate']

        # Restore the model weights and the optimizer state.
        resnext_model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))

### Main Loop

In [None]:
# Best test accuracy seen by this run of the loop.
# NOTE: this starts from 0.0 even when resuming from a checkpoint, so the first
# epoch after a resume is treated as the best so far.
best_accuracy = 0.0

for epoch in range(args['start_epoch'], args['epochs']):
    # Step the learning rate down at the scheduled epochs and push the new
    # value into every optimizer parameter group.
    if epoch in args['schedule']:
        args['learning_rate'] *= args['gamma']
        for param_group in optimizer.param_groups:
            param_group['lr'] = args['learning_rate']

    epoch_state['epoch'] = epoch
    train(epoch)
    test()

    # Recompute the flag every epoch: it must be False when accuracy did not
    # improve, otherwise the "best" model file is overwritten by every later
    # checkpoint (and the name is undefined if the first epoch scores 0).
    isbest = epoch_state['test_accuracy'] > best_accuracy
    if isbest:
        best_accuracy = epoch_state['test_accuracy']

    save_checkpoint({
        'epoch' : epoch,
        'learning_rate' : args['learning_rate'],
        'state_dict' : resnext_model.state_dict(),
        'optimizer' : optimizer.state_dict()
    }, isbest)

    # One JSON record per epoch; flush so the log survives an interrupted run.
    log.write('%s\n' % json.dumps(epoch_state))
    log.flush()
    print("[Epoch {}] Best accuracy: {:.4f}".format(epoch, best_accuracy))

log.close()