# Training the model

Key points about the training procedure, quoted from the VGG paper (Simonyan & Zisserman):

-------------------------
Training is carried out by optimising the multinomial logistic regression objective (softmax) using mini-batch gradient descent (based on back-propagation (LeCun et al., 1989)) with momentum. 
The batch size was set to 256, momentum to 0.9. The learning rate was initially set to $10^{-2}$, and then decreased by a factor of 10 when the validation set accuracy stopped improving. In total, the learning rate was decreased 3 times, and the learning was stopped after 370K iterations (74 epochs).

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from VGG_base import VGG

# Select the compute device. Training stays on the CPU unless USE_CUDA is
# switched on; with it on, the first GPU is used when one is available.
USE_CUDA = False
device = torch.device("cuda:0" if USE_CUDA and torch.cuda.is_available() else "cpu")

def train(model, train_loader, criterion, optimizer, epoch, print_freq=10):
    """Train ``model`` for one epoch with mini-batch gradient descent.

    Args:
        model: network to optimise; switched to training mode here.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimiser whose ``step()`` updates ``model``'s parameters.
        epoch: zero-based epoch index; printed as ``epoch + 1``.
        print_freq: print the mean loss of every ``print_freq`` batches.

    Returns:
        Mean training loss over all batches of the epoch, or ``0.0`` if the
        loader yielded no batches.
    """
    model.train()
    # Move batches to wherever the model's weights live, instead of relying
    # on a module-level ``device`` global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0  # summed loss over the current reporting window
    epoch_loss = 0.0    # summed loss over the whole epoch
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # record loss
        loss_value = loss.item()
        running_loss += loss_value
        epoch_loss += loss_value
        num_batches += 1

        # Report only once a full window of ``print_freq`` batches has been
        # accumulated, so the printed value is a true window average.
        if (batch_idx + 1) % print_freq == 0:
            print('Epoch %d, Batch: %d, Loss: %f' % (epoch + 1, batch_idx + 1, running_loss / print_freq))
            running_loss = 0.0

    return epoch_loss / num_batches if num_batches else 0.0

def validate(model, val_loader, criterion, print_freq=10):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        val_loader: iterable yielding ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
            Assumed to average over the batch (the default ``reduction='mean'``
            of ``nn.CrossEntropyLoss``), since each batch loss is re-weighted
            by its batch size below.
        print_freq: print the mean loss of every ``print_freq`` batches.

    Returns:
        Mean validation loss per sample over the whole loader, or ``0.0`` if
        the loader yielded no batches. The training loop feeds this value to
        the learning-rate scheduler.
    """
    # ``nn.Module`` has no ``evaluate()`` method; ``eval()`` puts the module in
    # evaluation mode (affects layers such as Dropout and BatchNorm).
    model.eval()
    # Move batches to wherever the model's weights live, instead of relying
    # on a module-level ``device`` global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0   # summed batch losses in the current reporting window
    total_loss = 0.0     # batch losses weighted by batch size
    total_samples = 0

    with torch.no_grad():  # no need to track history
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # record loss
            loss_value = loss.item()
            batch_size = targets.size(0)
            running_loss += loss_value
            total_loss += loss_value * batch_size
            total_samples += batch_size

            # Report only once a full window of ``print_freq`` batches has
            # been accumulated, so the printed value is a true window average.
            if (batch_idx + 1) % print_freq == 0:
                print('Validation on Batch: %d, Loss: %f' % (batch_idx + 1, running_loss / print_freq))
                running_loss = 0.0

    return total_loss / total_samples if total_samples else 0.0

# Load CIFAR10 dataset
print('==> Preparing data...')
# ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 per channel
# then maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

cifar_trainset = datasets.CIFAR10(root='./data', train=True,
                                  download=True, transform=transform)
train_loader = DataLoader(cifar_trainset, batch_size=256,
                          shuffle=True, num_workers=4)

# Validate on the held-out CIFAR-10 test split (train=False). Evaluating on the
# training images would make the plateau scheduler track training loss.
cifar_valset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform)
# Shuffling is unnecessary for evaluation.
val_loader = DataLoader(cifar_valset, batch_size=256,
                        shuffle=False, num_workers=4)

# Model
print('==> Building model...')
model = VGG('D', input_size=32)  # VGG16 is configuration D (refer to paper)
model = model.to(device)

# ``device`` is a torch.device, so check its type instead of comparing it to a
# string.
if device.type == 'cuda':
    # Wrap the model being trained so every later use (optimizer, train,
    # validate) goes through the DataParallel wrapper.
    model = torch.nn.DataParallel(model)
    # Let cuDNN benchmark and pick the fastest kernels; worthwhile because the
    # input size (32x32) is fixed.
    cudnn.benchmark = True

# Training
num_epochs = 74  # the paper stopped training after 74 epochs
lr = 0.01        # initial learning rate 10^-2, as in the paper
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=5e-4)

print('==> Training...')
# Divide the learning rate by 10 whenever the monitored validation loss stops
# decreasing. The paper monitors validation accuracy instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1)
for epoch in range(num_epochs):
    # train one epoch
    train(model, train_loader, criterion, optimizer, epoch)
    # validate: the returned loss drives the scheduler below
    val_loss = validate(model, val_loader, criterion)
    # adjust learning rate with scheduler
    scheduler.step(val_loss)

print('==> Finished Training')

In [None]:
print(str(model))  # dump the module hierarchy of the trained network