# CIFAR-10 TRAINING
# Imports

In [None]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

# NNI package for model quantization aware training (QAT).
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

# Import NaiveModel from naive.py
from naive_cifar import NaiveModel

## Training Functions

In [None]:
# Model training function
def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))


# Model testing function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Loss: {}  Accuracy: {}%)\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))

## Train and Test

In [1]:
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split dataset into train/test sets.
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=True, download=True, transform=trans),
    batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('data', train=False, transform=trans),
    batch_size=1000, shuffle=True)


idim = next(iter(train_loader))[0][0].size()[1]
ifmap = next(iter(train_loader))[0][0].size()[0]

# Two things should be kept in mind when set this configure_list:
# 1. When deploying model on backend, some layers will be fused into one layer. For example, the consecutive
# conv + bn + relu layers will be fused into one big layer. If we want to execute the big layer in quantization
# mode, we should tell the backend the quantization information of the input, output, and the weight tensor of
# the big layer, which correspond to conv's input, conv's weight and relu's output.
# 2. Same tensor should be quantized only once. For example, if a tensor is the output of layer A and the input
# of the layer B, you should configure either {'quant_types': ['output'], 'op_names': ['a']} or
# {'quant_types': ['input'], 'op_names': ['b']} in the configure_list.

# Quantization configuration -
# conv1, conv2 layers have INT8 weight, INT8 input activations.
# relu layers have INT8 output activations
# FC layers have INT8 weight, input, output activations

# INT8 weight, INT8 activations
num_bits = 8

configure_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': num_bits, 'input': num_bits},
    'quant_start_step': 2,
    'op_names': ['conv1', 'conv2']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': num_bits},
    'quant_start_step': 2,
    'op_names': ['relu1', 'relu2', 'relu3']
}, {
    'quant_types': ['output', 'weight', 'input'],
    'quant_bits': {'output': num_bits, 'weight': num_bits, 'input': num_bits},
    'quant_start_step': 2,
    'op_names': ['fc1', 'fc2'],
}]

# you can also set the quantization dtype and scheme layer-wise through configure_list like:
# configure_list = [{
#         'quant_types': ['weight', 'input'],
#         'quant_bits': {'weight': 8, 'input': 8},
#         'op_names': ['conv1', 'conv2'],
#         'quant_dtype': 'int',
#         'quant_scheme': 'per_channel_symmetric'
#       }]
# For now quant_dtype's options are 'int' and 'uint. And quant_scheme's options are per_tensor_affine,
# per_tensor_symmetric, per_channel_affine and per_channel_symmetric.

# per_tensor_symmetric quantization scheme - see [Jacob et. al]
set_quant_scheme_dtype('weight', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
set_quant_scheme_dtype('input', 'per_tensor_symmetric', 'int')

model = NaiveModel().to(device)
dummy_input = torch.randn(1, ifmap, idim, idim).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# To enable batch normalization folding in the training process, you should
# pass dummy_input to the QAT_Quantizer.

quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input=dummy_input)
quantizer.compress()

# Train and test/evaluate for 10 epoch
model.to(device)
for epoch in range(10):
    print('# Epoch {} #'.format(epoch))
    train(model, device, train_loader, optimizer)
    test(model, device, test_loader)

# Save quantized model parameters
model_path = "cifar_model.pth"
calibration_path = "cifar_calibration.pth"
onnx_path = "cifar_model.onnx"
input_shape = (1, ifmap, idim, idim)
torch.save(model.state_dict(), model_path)

# Quantize model with QAT_Quantizer (see NNI package)
#https://nni.readthedocs.io/en/stable/compression/quantization.html
# https://github.com/microsoft/nni

qmodel = NaiveModel().to(device)
dummy_input = torch.randn(1, ifmap, idim, idim).to(device)
optimizer = torch.optim.SGD(qmodel.parameters(), lr=0.01, momentum=0.5)
# To enable batch normalization folding in the training process, you should
# pass dummy_input to the QAT_Quantizer.
quantizer = QAT_Quantizer(qmodel, configure_list, optimizer, dummy_input=dummy_input)
quantizer.compress()
state = torch.load(model_path, map_location='cpu')
qmodel.load_state_dict(state, strict=True)
test(qmodel, device, test_loader)

  warn(


Files already downloaded and verified
# Epoch 0 #
 0%  Loss 2.300363302230835
13%  Loss 2.264918565750122
26%  Loss 2.1490731239318848
38%  Loss 1.9851025342941284
51%  Loss 2.090935230255127
64%  Loss 1.8688970804214478
77%  Loss 1.8867095708847046
90%  Loss 1.5262688398361206
Loss: 1.69958466796875  Accuracy: 39.04%)

# Epoch 1 #
 0%  Loss 1.6992064714431763
13%  Loss 1.7389558553695679
26%  Loss 1.4938193559646606
38%  Loss 1.8420473337173462
51%  Loss 1.4800643920898438
64%  Loss 1.512792944908142
77%  Loss 1.2945753335952759
90%  Loss 1.6110529899597168
Loss: 1.4807634155273437  Accuracy: 46.43%)

# Epoch 2 #
 0%  Loss 1.527542233467102
13%  Loss 1.456183671951294
26%  Loss 1.3507345914840698
38%  Loss 1.7552413940429688
51%  Loss 1.521737813949585
64%  Loss 1.8493103981018066
77%  Loss 1.2394335269927979
90%  Loss 1.2515230178833008
Loss: 1.4326294555664063  Accuracy: 47.48%)

# Epoch 3 #
 0%  Loss 1.444076418876648
13%  Loss 1.4969987869262695
26%  Loss 1.2856460809707642
38%  L