In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.quantization
from torch.quantization import QuantStub
import matplotlib.pyplot as plt
import numpy as np
import copy
import os
import time
import torch.nn as nn
import torch.nn.functional as F
import psutil

import torch.optim as optim

from torch.profiler import profile, record_function, ProfilerActivity

In [2]:
"""1. LOAD AND NORMALIZE CIFAR10"""

transform = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))])

batch_size = 4

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)
testsize = len(testset)

classes = ('0','1','2','3','4','5','6','7','8','9')



In [3]:
"""1.1 SHOW SOME TRAINING IMAGES JUST FOR FUN"""
# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [4]:
"""2. DEFINE A CONVOLUTIONAL NEURAL NETWORK"""
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.AvgPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        #self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(16, 120, 5)
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)
        self.dequant = torch.quantization.DeQuantStub()


    def forward(self, x):

        x = self.quant(x)
        x = self.pool(torch.tanh(self.conv1(x)))
        x = self.pool(torch.tanh(self.conv2(x)))
        
        x = torch.tanh(self.conv3(x))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = torch.tanh(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)

        return x

net = Net()

"""3. DEFINE A LOSS FUNCTION AND OPTIMIZER"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [5]:
"""4. TRAIN THE NETWORK"""
for epoch in range(1):  # loop over the dataset multiple times


            
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
              (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

[1,  2000] loss: 0.982
[1,  4000] loss: 0.307
[1,  6000] loss: 0.230
[1,  8000] loss: 0.181
[1, 10000] loss: 0.162
[1, 12000] loss: 0.130
[1, 14000] loss: 0.112
Finished Training


In [6]:
"""Accuracy per Class"""

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))


model = torch.quantization.quantize_dynamic(
    net,
    {torch.nn.Linear},
    dtype=torch.qint8)

def validation_model(model, criterion, optimizer, device):
    model.eval()
    running_corrects = 0
    running_loss = 0.0
    time_start = time.perf_counter()
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
    time_elapsed = (time.perf_counter() - time_start)
    
    val_acc = running_corrects.double() / testsize
    print('Test accuracy: {:4f}'.format(val_acc))
    
    print('Inference complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss/1024.0/1024.0
    print("report mem usage %5.3f MB" % mem)

def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

# compare the sizes
f=print_size_of_model(net,"fp32")
q=print_size_of_model(model,"int8")
print("{0:.2f} times smaller".format(f/q))

# compare the performance
print("Floating point FP32")
get_ipython().run_line_magic('timeit', 'net.forward(inputs)')

print("Quantized INT8")
get_ipython().run_line_magic('timeit', 'model.forward(inputs)')

# run the float model
out1 = net(inputs)
mag1 = torch.mean(abs(out1)).item()
print('mean absolute value of output tensor values in the FP32 model is {0:.5f} '.format(mag1))

# run the quantized model
out2 = model(inputs)
mag2 = torch.mean(abs(out2)).item()
print('mean absolute value of output tensor values in the INT8 model is {0:.5f}'.format(mag2))

# compare them
mag3 = torch.mean(abs(out1-out2)).item()
print('mean absolute value of the difference between the output tensors is {0:.5f} or {1:.2f} percent'.format(mag3,mag3/mag1*100))

validation_model(model, criterion, optimizer,'cpu')



Accuracy for class 0     is: 98.7 %
Accuracy for class 1     is: 98.8 %
Accuracy for class 2     is: 97.3 %
Accuracy for class 3     is: 95.9 %
Accuracy for class 4     is: 95.9 %
Accuracy for class 5     is: 97.8 %
Accuracy for class 6     is: 97.2 %
Accuracy for class 7     is: 98.1 %
Accuracy for class 8     is: 95.9 %
Accuracy for class 9     is: 96.1 %
model:  fp32  	 Size (KB): 249.735
model:  int8  	 Size (KB): 218.363
1.14 times smaller
Floating point FP32
1.82 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Quantized INT8
2.44 ms ± 217 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
mean absolute value of output tensor values in the FP32 model is 2.68151 
mean absolute value of output tensor values in the INT8 model is 2.68336
mean absolute value of the difference between the output tensors is 0.01019 or 0.38 percent
Test accuracy: 0.971767
Inference complete in 0m 43s
report mem usage 254.391 MB
