This is a demo of the KFAC solver on CIFAR10/100 using several convolutional networks. It requires `torch` and `torchvision` to be installed before running. Tested on Python 3.7 with Torch 1.7 and on Python 3.6 with Torch 1.4.

In [1]:
import argparse
import random
import time
from math import floor
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100

from tr_kfac_opt import KFACOptimizer
from nn_models import *  # contains several common networks
from utils import fmt_args, pp_time

To start, let's first set up the argument parser for the experiments.

In [2]:
class ExpArgParser(argparse.ArgumentParser):
    """Argument parser for the KFAC experiments.

    Besides declaring the experiment options, :meth:`parse_args` seeds the
    ``random``, ``numpy`` and ``torch`` random number generators from ``--seed``
    and records CUDA availability on the returned namespace as ``cuda``.
    """

    def __init__(self, *args, **kwargs):
        """Create the parser and register all experiment options.

        All positional and keyword arguments are forwarded to
        ``argparse.ArgumentParser``. A default description is supplied only
        when the caller did not pass one, so ``ExpArgParser(description=...)``
        works instead of raising a duplicate-keyword ``TypeError``.
        """
        kwargs.setdefault('description', 'FITRE experiment -- using KFAC')
        super().__init__(*args, **kwargs)

        self.add_argument('--benchmark', type=str, choices=['cifar10', 'cifar100'], default='cifar10',
                          help='the benchmark to run')

        self.add_argument('--model', type=str, default='QAlexNetS',
                          choices=['QAlexNetS', 'QAlexNetSb', 'VGG16', 'VGG16b'],
                          help='neural network model')
        self.add_argument('--init', type=str, default='def',
                          choices=['def', 'km', 'xavier', '0', '1'],
                          help='init network model')

        self.add_argument('--batch-size', type=int, default=200,
                          help='input batch size for training')
        self.add_argument('--test-batch-size', type=int, default=200,
                          help='input batch size for testing')
        self.add_argument('--epochs', type=int, default=10,
                          help='number of epochs to train (default: 10)')
        self.add_argument('--da', type=int, default=1,
                          help='data augmentation switch: 0 disables it, any other value '
                               'enables it (default: 1)')

        self.add_argument('--seed', type=int, default=1,
                          help='random seed (default: 1)')

        # kfac related arguments
        self.add_argument('--weight-decay', type=float, default=0, metavar='weight',
                          help='weight decay (default: 0)')
        self.add_argument('--damp', type=float, default=0.01, metavar='damp',
                          help='damping (default: 0.01)')
        self.add_argument('--max-delta', type=float, default=100, metavar='maxdelta',
                          help='max delta (default: 100)')
        self.add_argument('--check-grad', action='store_true', default=False,
                          help='set check_grad=True on the KFAC optimizer')

    def parse_args(self, args=None, namespace=None):
        """Parse arguments, seed every RNG, and attach the ``cuda`` flag.

        Args:
            args: Argument strings to parse; ``None`` lets argparse read
                ``sys.argv``. Pass ``[]`` to get the defaults (notebook usage).
            namespace: Optional namespace object to populate.

        Returns:
            The parsed namespace, extended with ``cuda`` (``True`` when
            ``torch.cuda.is_available()``).
        """
        res = super().parse_args(args, namespace)

        # Seed every RNG in use from the single --seed option.
        random.seed(res.seed)
        np.random.seed(res.seed)
        torch.manual_seed(res.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(res.seed)

        res.cuda = torch.cuda.is_available()
        return res

Next, let's generate networks of specific architectures to train.

In [3]:
def get_net(args: argparse.Namespace) -> nn.Module:
    """Build the network selected by ``args.model`` and optionally re-initialize it.

    The number of output classes follows ``args.benchmark``. When ``args.init``
    is one of ``km``, ``xavier``, ``0`` or ``1`` the matching initializer is
    applied to the model via ``model.apply``; any other value (such as
    ``def``) leaves the model's own initialization untouched.

    Raises:
        KeyError: if ``args.benchmark`` or ``args.model`` is not a known key.
    """
    class_counts = {'cifar10': 10, 'cifar100': 100}
    num_classes = class_counts[args.benchmark]

    # A trailing "b" in the model name means the batch-normalization variant.
    arch = args.model
    if arch == 'QAlexNetS':
        model = QAlexNetS(num_classes=num_classes)
    elif arch == 'QAlexNetSb':
        model = QAlexNetSb(num_classes=num_classes)
    elif arch == 'VGG16':
        model = VGG('VGG16', num_class=num_classes)
    elif arch == 'VGG16b':
        model = VGGb('VGG16', num_class=num_classes)
    else:
        raise KeyError(arch)

    # Thunks keep the initializer names unresolved until one is actually selected.
    initializers = {
        'km': lambda: init_params,
        'xavier': lambda: normal_init,
        '0': lambda: zeros_init,
        '1': lambda: ones_init,
    }
    pick = initializers.get(args.init)
    if pick is not None:
        model.apply(pick())
    return model

Then, let's prepare the datasets for training.

In [4]:
def get_data_loader(args: argparse.Namespace, train: bool) -> Dataset:
    """Return the CIFAR dataset for ``args.benchmark`` with its preprocessing attached.

    Despite the name, this returns a ``Dataset``; callers wrap it in a
    ``DataLoader`` themselves.

    Args:
        args: Parsed experiment options. ``args.benchmark`` selects CIFAR-10 or
            CIFAR-100. ``args.da == 0`` disables data augmentation; any other
            value adds a random crop and a random horizontal flip to the
            training pipeline.
        train: ``True`` for the training split, ``False`` for the test split.
            The test split is never augmented.

    Returns:
        The dataset rooted at ``./data`` (downloaded if missing).

    Raises:
        NotImplementedError: if ``args.benchmark`` is not supported.
    """
    if args.benchmark == 'cifar10':
        dataset_cls = CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        # NOTE(review): the std used without augmentation differs from the one used
        # with augmentation; both are kept as found so earlier runs stay reproducible.
        std = (0.247, 0.243, 0.261) if args.da == 0 else (0.2023, 0.1994, 0.2010)
    elif args.benchmark == 'cifar100':
        dataset_cls = CIFAR100
        # The same statistics are used for train and test in every setting.
        # Previously the augmented setting normalized the test split with the
        # CIFAR-10 statistics, so train and test inputs were scaled differently.
        mean = [0.507, 0.487, 0.441]
        std = [0.267, 0.256, 0.276]
    else:
        raise NotImplementedError(f'Benchmark {args.benchmark} unsupported yet.')

    steps = []
    if train and args.da != 0:
        # Augmentation is applied to training images only.
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean, std))

    return dataset_cls('./data', train=train, download=True,
                       transform=transforms.Compose(steps))

Let's now run the experiments using the default hyper-parameters of the KFAC optimizer. Outside the notebook, customized runs can be launched from the command line with commands of the following form:
```
python3 kfac_exp.py --batch-size 200 --epochs 200 --model QAlexNetS --seed 1 --init def --damp=0.01 --check-grad
```

In [5]:
# Build the experiment configuration and print it for the record.
parser = ExpArgParser()
# For command-line usage, use this line instead of the one below:
# args = parser.parse_args()
args = parser.parse_args([])  # notebook demo: an empty list selects every default value
print(fmt_args(args))


===== configuration =====
  benchmark: cifar10
  model: QAlexNetS
  init: def
  batch_size: 200
  test_batch_size: 200
  epochs: 10
  da: 1
  seed: 1
  weight_decay: 0
  damp: 0.01
  max_delta: 100
  check_grad: False
  cuda: True
===== end of configuration =====



In [6]:
# Place the model on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if args.cuda else 'cpu')
model = get_net(args).to(device)

# Worker/pinned-memory options are only passed to the loaders when running on CUDA.
loader_kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
# Only the dataset is kept here; a fresh training DataLoader is built from it in every epoch.
train_ds = get_data_loader(args, True)
# The test loader honours --test-batch-size (it previously reused --batch-size,
# which made the dedicated option have no effect).
test_loader = DataLoader(get_data_loader(args, False),
                         batch_size=args.test_batch_size, shuffle=False, **loader_kwargs)

criterion = nn.CrossEntropyLoss()

def eval_test(model):
    """Evaluate ``model`` on the module-level ``test_loader``.

    The model is switched to evaluation mode for the duration of the pass and
    is put back into whichever mode (train or eval) it was in on entry.

    Args:
        model: The network to evaluate; it is called on batches moved to ``device``.

    Returns:
        A ``(test_loss, acc)`` tuple: the per-sample average of ``criterion``
        over the test set, and the fraction of correctly classified samples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.
    correct = 0
    total_batches = len(test_loader)
    total_pts = len(test_loader.dataset)
    with torch.no_grad():
        # Count batches from 1 so the progress line ends at total/total.
        for done, (data, target) in enumerate(test_loader, start=1):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            # criterion averages over the batch; re-weight by batch size so the
            # final division by total_pts yields a true per-sample mean.
            loss_sum += criterion(output, target).item() * len(data)
            correct += (pred == target).sum().item()
            print(f'Evaluated {done}/{total_batches} batches...', end='\r')

    test_loss = loss_sum / total_pts
    acc = 1.0 * correct / total_pts
    # Restore the caller's mode rather than unconditionally forcing train mode.
    model.train(was_training)
    return test_loss, acc

# Baseline: evaluate the untrained network to get a reference point for the epochs below.
loss, acc = eval_test(model)
# The leading '\r' overwrites the progress line that eval_test leaves on the terminal.
print(f'\rBefore any training, the test set loss is {loss}, accuracy is {acc}.')

Files already downloaded and verified
Files already downloaded and verified
Before any training, the test set loss is 2.30317054271698, accuracy is 0.0965.


We can now create the KFACOptimizer and run multiple epochs of training.

In [7]:
# Create the KFAC optimizer. damping, weight_decay, check_grad and max_delta come from
# the parsed arguments; the remaining hyper-parameters are fixed for this demo.
kfac_opt = KFACOptimizer(model=model,
                         momentum=0.0,
                         stat_decay=0.8,
                         kl_clip=1e-0,
                         damping=args.damp,
                         weight_decay=args.weight_decay,
                         check_grad=args.check_grad,
                         max_delta=args.max_delta,
                         Tf=1)
t0 = time.time()  # reference point for the elapsed-time prefix of every log line
for epoch in range(args.epochs):
    # A fresh, shuffled loader over the training dataset is built for every epoch.
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    tot_batches = len(train_loader)

    model.train()
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Round 1 backprop: the loss is computed against the model's own predictions
        # (argmax of the outputs) instead of the true labels.
        # NOTE(review): presumably acc_stats=True makes the optimizer collect its
        # curvature statistics during this backward pass -- confirm in tr_kfac_opt.
        outputs = model(inputs)
        kfac_opt.zero_grad()
        kfac_opt.acc_stats = True
        obj = criterion(outputs, outputs.argmax(dim=1))  # argmax labels: requires a classification output
        # retain_graph=True keeps the graph alive because `outputs` is back-propagated again below.
        obj.backward(retain_graph=True)

        # Round 2 backprop: clear the round-1 gradients, then back-propagate the
        # loss against the true labels with acc_stats switched off.
        kfac_opt.zero_grad()
        loss = criterion(outputs, labels)
        kfac_opt.acc_stats = False
        # NOTE(review): create_graph=True also builds a graph of the backward pass;
        # presumably the optimizer differentiates through the gradients -- confirm.
        loss.backward(create_graph=True)

        # Closure handed to the optimizer: returns the current loss on this same batch.
        def _batch_loss():
            with torch.no_grad():
                # Only the loss value is needed here, so no gradients are tracked.
                model.eval()
                outputs = model(inputs)
                _loss = criterion(outputs, labels).item()
            # Switch back to training mode before returning to the optimizer.
            model.train()
            return _loss

        kfac_opt.step(closure=_batch_loss)
        # end='\r' keeps the per-batch progress on a single, continuously overwritten line.
        print(f'[{pp_time(time.time() - t0)}] Epoch {epoch}, batch {i} / {tot_batches}, loss {loss.item()}',
              end='\r')

    # Report test-set loss and accuracy after every epoch.
    test_loss, test_acc = eval_test(model)
    print(f'[{pp_time(time.time() - t0)}] Epoch {epoch} -- Test loss: {test_loss}, Test accuracy: {test_acc}.')

[1m 34s (94.096 seconds)] Epoch 0 -- Test loss: 1.0403913414478303, Test accuracy: 0.6421.
[3m 8s (188.136 seconds)] Epoch 1 -- Test loss: 0.9435971391201019, Test accuracy: 0.6757.
[4m 42s (282.982 seconds)] Epoch 2 -- Test loss: 0.8827460837364197, Test accuracy: 0.6969.
[6m 17s (377.338 seconds)] Epoch 3 -- Test loss: 0.8288209521770478, Test accuracy: 0.7157.
[7m 51s (471.877 seconds)] Epoch 4 -- Test loss: 0.8044604575634002, Test accuracy: 0.7276.
[9m 26s (566.796 seconds)] Epoch 5 -- Test loss: 0.8292316210269928, Test accuracy: 0.7239.
[11m 1s (661.799 seconds)] Epoch 6 -- Test loss: 0.7942041862010956, Test accuracy: 0.7329.
[12m 36s (756.352 seconds)] Epoch 7 -- Test loss: 0.7785956192016602, Test accuracy: 0.7408.
[14m 10s (850.842 seconds)] Epoch 8 -- Test loss: 0.7730200099945068, Test accuracy: 0.7438.
[15m 45s (945.526 seconds)] Epoch 9 -- Test loss: 0.7478364408016205, Test accuracy: 0.7529.


Now training has finished.

In [8]:
# Final marker so the notebook output shows that the run reached its end.
print('Training finished.')

Training finished.
