In [1]:
import onnx
import torch
import torch.nn as nn
from torch.nn import BatchNorm2d
from torch.nn import MaxPool2d
from torch.nn import Module
from brevitas.nn import QuantLinear
from brevitas.nn import QuantReLU
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantConv2d
from brevitas.core.quant import QuantType
import torchvision
import torchvision.transforms as transforms

In [2]:
batch_size = 64
num_classes = 10
learning_rate = 0.001
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
train_dataset = torchvision.datasets.MNIST(root = './data',
                                           train = True,
                                           transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                           download = True)


test_dataset = torchvision.datasets.MNIST(root = './data',
                                          train = False,
                                          transform = transforms.Compose([
                                                  transforms.Resize((32,32)),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1325,), std = (0.3105,))]),
                                          download=True)


train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)


test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True)

In [16]:
class Lenet_Quant(nn.Module):
    def __init__(self, num_classes):
        super(Lenet_Quant, self).__init__()
        self.layer1 = nn.Sequential(
        QuantConv2d(1, 6, kernel_size=(5,5), bias=False), #stride =1 padding = 0 by default
        BatchNorm2d(6),
        QuantReLU(),
        MaxPool2d(kernel_size =2, stride = 2))
        self.layer2 = nn.Sequential(QuantConv2d(6, 16, kernel_size = (5,5), bias = False),
        BatchNorm2d(16),
        QuantReLU(),
        MaxPool2d(kernel_size=2, stride =2))
        self.layer3 = nn.Sequential(
        QuantLinear(400, 120, bias = False),
        QuantReLU(),
        QuantLinear(120, 84, bias = False),
        QuantReLU(),
        QuantLinear(84, num_classes, bias = False))
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.layer3(out)
        return out

In [17]:
quant_model = Lenet_Quant(num_classes).to(device)

#Setting the loss function
cost = nn.CrossEntropyLoss()

#Setting the optimizer with the model parameters and learning rate
optimizer = torch.optim.Adam(quant_model.parameters(), lr=learning_rate)

#this is defined to print how many steps are remaining when training
total_step = len(train_loader)

In [18]:
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        images = images.to(device)
        labels = labels.to(device)
        
        #Forward pass
        outputs = quant_model(images)
        loss = cost(outputs, labels)
        	
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        		
        if (i+1) % 400 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
        		           .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

Epoch [1/10], Step [400/938], Loss: 0.1300
Epoch [1/10], Step [800/938], Loss: 0.0725
Epoch [2/10], Step [400/938], Loss: 0.0083
Epoch [2/10], Step [800/938], Loss: 0.0063
Epoch [3/10], Step [400/938], Loss: 0.0469
Epoch [3/10], Step [800/938], Loss: 0.1000
Epoch [4/10], Step [400/938], Loss: 0.0004
Epoch [4/10], Step [800/938], Loss: 0.0012
Epoch [5/10], Step [400/938], Loss: 0.0500
Epoch [5/10], Step [800/938], Loss: 0.0126
Epoch [6/10], Step [400/938], Loss: 0.0018
Epoch [6/10], Step [800/938], Loss: 0.0033
Epoch [7/10], Step [400/938], Loss: 0.0100
Epoch [7/10], Step [800/938], Loss: 0.0226
Epoch [8/10], Step [400/938], Loss: 0.0200
Epoch [8/10], Step [800/938], Loss: 0.0431
Epoch [9/10], Step [400/938], Loss: 0.0074
Epoch [9/10], Step [800/938], Loss: 0.0044
Epoch [10/10], Step [400/938], Loss: 0.0033
Epoch [10/10], Step [800/938], Loss: 0.0017


In [20]:
# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
  
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = quant_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

Accuracy of the network on the 10000 test images: 99.07 %


In [25]:
torch.save(quant_model.state_dict(), "qlenet5.pth")

In [24]:
#Export to Onnx
import brevitas.onnx as bo
from brevitas.export import FINN_MANAGER
export_onnx_path = "Lenet5_Quant.onnx"
input_shape = (1, 1, 32, 32)
bo.export_finn_onnx(quant_model, input_shape, export_onnx_path)

AttributeError: module 'brevitas.onnx' has no attribute 'export_finn_onnx'