# AlexNet in PyTorch

In [None]:
import torch, torchvision, os

def train_alexnet(num_steps=100000):
    """Build an untrained AlexNet, train it for `num_steps` iterations, return it.

    Training uses SGD with momentum 0.9, init_lr 2e-2 and weight decay 5e-4.
    Progress reporting, validation and checkpointing are delegated to the
    callback built by make_checkpointing_function, pointed at the relative
    directory 'checkpoints'.
    """
    print("Making alexnet...")
    net = make_untrained_alexnet()
    net.train()

    print("Loading datasets...")
    loaders = get_train_and_val_data_loaders()
    train_loader, val_loader = loaders

    print("Training classifier...")
    monitor = make_checkpointing_function(val_loader, checkpoint_dir='checkpoints')
    sgd_settings = {
        'momentum': 0.9,
        'init_lr': 2e-2,
        'weight_decay': 5e-4,
    }
    train_classifier(net, train_loader, max_iter=num_steps,
                     monitor=monitor, **sgd_settings)
    return net



In [None]:
from torch import nn
from collections import OrderedDict
def make_untrained_alexnet():
    """Build a randomly initialised AlexNet classifier.

    The layer sizes assume 3-channel 227x227 inputs: conv1 (k=11, s=4) and the
    three 3x3/stride-2 max-pools reduce 227 -> 55 -> 27 -> 13 -> 6, which is why
    fc6 takes 256 * 6 * 6 features. The final layer emits 365 class logits.
    conv2, conv4 and conv5 split their channels into two groups.

    Returns:
        nn.Sequential: the model in training mode, moved to the GPU when CUDA
        is available and left on the CPU otherwise.
    """
    # channel widths: input, conv1..conv5, fc6, fc7, fc8 (= number of classes)
    w = [3, 96, 256, 384, 384, 256, 4096, 4096, 365]
    # AlexNet splits channels into groups (one entry per conv layer)
    groups = [1, 2, 1, 2, 2]
    model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(w[0], w[1], kernel_size=11,
            stride=4,
            groups=groups[0], bias=True)),
        ('relu1', nn.ReLU(inplace=True)),
        ('pool1', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('conv2', nn.Conv2d(w[1], w[2], kernel_size=5, padding=2,
            groups=groups[1], bias=True)),
        ('relu2', nn.ReLU(inplace=True)),
        ('pool2', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('conv3', nn.Conv2d(w[2], w[3], kernel_size=3, padding=1,
            groups=groups[2], bias=True)),
        ('relu3', nn.ReLU(inplace=True)),
        ('conv4', nn.Conv2d(w[3], w[4], kernel_size=3, padding=1,
            groups=groups[3], bias=True)),
        ('relu4', nn.ReLU(inplace=True)),
        ('conv5', nn.Conv2d(w[4], w[5], kernel_size=3, padding=1,
            groups=groups[4], bias=True)),
        ('relu5', nn.ReLU(inplace=True)),
        ('pool5', nn.MaxPool2d(kernel_size=3, stride=2)),
        ('flatten', nn.Flatten()),
        ('fc6', nn.Linear(w[5] * 6 * 6, w[6], bias=True)),
        ('relu6', nn.ReLU(inplace=True)),
        ('dropout6', nn.Dropout()),
        ('fc7', nn.Linear(w[6], w[7], bias=True)),
        ('relu7', nn.ReLU(inplace=True)),
        ('dropout7', nn.Dropout()),
        ('fc8', nn.Linear(w[7], w[8]))
    ]))
    # Random initial parameters: zero biases, Kaiming-normal weights (ReLU gain).
    for n, p in model.named_parameters():
        if 'bias' in n:
            torch.nn.init.zeros_(p)
        else:
            torch.nn.init.kaiming_normal_(p, nonlinearity='relu')
    # An unconditional .cuda() crashes on CPU-only runtimes (e.g. Colab without
    # a GPU), so only move the model when CUDA is actually available.
    if torch.cuda.is_available():
        model.cuda()
    model.train()
    return model

Display the network's parameter names and shapes

In [None]:
# Build a fresh instance only to list each parameter's name and shape.
a = make_untrained_alexnet()
for n, p in a.named_parameters():
    shape = tuple(p.size())
    print(n, shape)

Dataset

In [None]:
def get_places_data_set(split, crop_size=227, download=True):
    """Return the training split (split == 'train') or the held-out split.

    Despite the name, this loads torchvision's MNIST digits (stored under
    ./data). The network built by make_untrained_alexnet expects 3-channel
    227x227 inputs, so each single-channel digit is resized to
    crop_size x crop_size and replicated to three channels before
    normalisation. Without this, the tiny feature maps would make the
    network's pooling layers fail.

    Args:
        split: 'train' selects the MNIST training set; anything else selects
            the MNIST test set.
        crop_size: side length, in pixels, of the square images produced.
        download: download the data into ./data if it is not already present.
    """
    transform = torchvision.transforms.Compose([
        # Upsample to the square spatial size the network was sized for.
        torchvision.transforms.Resize((crop_size, crop_size)),
        # Replicate the single grey channel into three identical channels.
        torchvision.transforms.Grayscale(num_output_channels=3),
        torchvision.transforms.ToTensor(),  # convert image to a PyTorch tensor
        # Commonly quoted MNIST mean/std, applied to each replicated channel.
        torchvision.transforms.Normalize((0.1307,) * 3, (0.3081,) * 3),
    ])
    return torchvision.datasets.MNIST(
        root='./data',
        train=(split == 'train'),
        download=download,
        transform=transform,
    )
def get_train_and_val_data_loaders():
    """Return [train_loader, val_loader] with batch size 32.

    Only the training loader shuffles its data.
    """
    loaders = []
    for split in ('train', 'val'):
        dataset = get_places_data_set(split)
        is_training = split == 'train'
        loaders.append(
            torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=is_training))
    return loaders

Training

In [None]:
def train_classifier(model, train_data_loader, max_iter,
                     momentum=0.9, init_lr=2e-2, weight_decay=5e-4,
                     monitor=None):
    """Train `model` with SGD and a one-cycle LR schedule for `max_iter` steps.

    The data loader is iterated repeatedly (epoch after epoch) until exactly
    `max_iter` optimisation steps have been taken.

    Args:
        model: classifier producing logits. Batches are moved to the device
            its parameters live on.
        train_data_loader: iterable of (inputs, integer class targets) batches.
        max_iter: total number of SGD steps; also the schedule length.
        momentum: SGD momentum. NOTE(review): OneCycleLR cycles momentum by
            default (cycle_momentum=True), which overrides this value while
            training — confirm that is intended.
        init_lr: peak learning rate of the one-cycle schedule (its max_lr).
        weight_decay: weight decay passed to SGD.
        monitor: optional callback monitor(model, iter_num, loss, accuracy,
            batch_size). It is called once with zeros before training, then
            after every step with the batch loss (a Python float) and the
            batch accuracy.

    Raises:
        ValueError: if the data loader yields no batches, which would
            otherwise loop forever.
    """
    if monitor is not None:
        monitor(model, 0, 0.0, 0.0, 0)

    device = next(model.parameters()).device
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=init_lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, init_lr, max_iter)
    iter_num = 0

    while iter_num < max_iter:
        saw_batch = False
        for t_input, t_target in train_data_loader:
            saw_batch = True
            # Copy data onto the model's device
            input_var, target_var = t_input.to(device), t_target.to(device)
            # Evaluate model
            output = model(input_var)
            loss = torch.nn.functional.cross_entropy(output, target_var)
            # Perform one step of SGD
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # Learning rate schedule
            # Training-set accuracy on this batch
            _, pred = output.max(1)
            batch_size = len(t_input)
            accuracy = target_var.eq(pred).float().sum().item() / batch_size
            iter_num += 1
            if monitor is not None:
                # Hand over a plain float: passing the loss tensor would let
                # any accumulator keep each step's autograd graph alive.
                monitor(model, iter_num, loss.item(), accuracy, batch_size)
            if iter_num >= max_iter:
                break
        if not saw_batch:
            raise ValueError('train_data_loader yielded no batches')

Evaluation

In [None]:
def measure_val_accuracy_and_loss(model, val_data_loader):
    '''
    Evaluate the model in inference mode on held-out data.

    Switches the model to eval mode (e.g. disabling dropout) and leaves it
    there; callers that continue training must call model.train() afterwards.
    Batches are moved to the device of the model's parameters.

    Returns:
        (val_acc, val_loss): AverageMeter objects whose .avg fields hold the
        sample-weighted mean accuracy and mean cross-entropy loss.
    '''
    model.eval()
    device = next(model.parameters()).device
    val_loss, val_acc = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for inputs, targets in val_data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            loss = torch.nn.functional.cross_entropy(output, targets)
            _, pred = output.max(1)
            batch_size = inputs.size(0)
            accuracy = targets.eq(pred).float().sum().item() / batch_size
            # Weight each batch mean by its size so a short last batch
            # does not skew the average.
            val_acc.update(accuracy, batch_size)
            val_loss.update(loss.item(), batch_size)
    return val_acc, val_loss

def save_model_iteration(model, iter_num, checkpoint_dir):
    '''
    Save the model's current parameters (its state_dict) to
    <checkpoint_dir>/iter_<iter_num>.pth, creating checkpoint_dir if needed.
    '''
    # torch.save does not create missing directories, so the first checkpoint
    # into a fresh directory would otherwise raise FileNotFoundError.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'iter_{iter_num}.pth'))

def make_checkpointing_function(val_data_loader, checkpoint_dir=None, checkpoint_freq=100):
    '''
    Make a callback that tracks training stats, validates and checkpoints.

    The returned monitor(model, iter_num, loss, accuracy, batch_size):
      - Folds the batch loss/accuracy into running training averages on every call.
      - When iter_num is a multiple of checkpoint_freq (including 0):
          - evaluates on val_data_loader;
          - saves a checkpoint if checkpoint_dir is given;
          - prints a summary;
          - resets the training averages, so each summary covers only the
            steps since the previous one;
          - puts the model back into training mode.
    '''
    avg_train_accuracy, avg_train_loss = AverageMeter(), AverageMeter()

    def monitor(model, iter_num, loss, accuracy, batch_size):
        # Accept a loss tensor as well as a float: accumulating a tensor that
        # requires grad would keep every step's autograd graph alive.
        if isinstance(loss, torch.Tensor):
            loss = loss.item()
        avg_train_accuracy.update(accuracy, batch_size)
        avg_train_loss.update(loss, batch_size)
        if iter_num % checkpoint_freq == 0:
            val_accuracy, val_loss = measure_val_accuracy_and_loss(model, val_data_loader)
            if checkpoint_dir is not None:
                save_model_iteration(model, iter_num, checkpoint_dir)
            print(f'Iter {iter_num}, ' +
                  f'train acc {avg_train_accuracy.avg:.3g} loss {avg_train_loss.avg:.3g}, ' +
                  f'val acc {val_accuracy.avg:.3g}, loss {val_loss.avg:.3g}')
            # Report per-interval training stats rather than a lifetime average.
            avg_train_accuracy.reset()
            avg_train_loss.reset()
            # Validation left the model in eval mode; resume training mode.
            model.train()
    return monitor

class AverageMeter:
    '''
    Maintain a weighted running mean of a stream of values.

    Attributes:
        val: the value passed to the most recent update().
        sum: running total of val * n over all updates.
        count: running total of the weights n.
        avg: sum / count as of the last update with a non-zero total count
            (0.0 until then).
    '''

    def __init__(self):
        self.reset()

    def reset(self):
        '''Discard all accumulated history.'''
        self.val, self.avg, self.sum, self.count = 0., 0., 0., 0

    def update(self, val, n=1):
        '''Record `val` with weight `n` (e.g. a batch mean and the batch size).'''
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid dividing by zero while no weight has been recorded.
        if not self.count:
            return
        self.avg = self.sum / self.count

In [None]:
# Short demonstration run; train_alexnet itself defaults to 100000 steps.
num_iterations = 100
train_alexnet(num_steps=num_iterations)