<a href="https://colab.research.google.com/github/andreunifi/Deployment-of-Quantized-Neural-Networks-on-FPGA/blob/main/cnn_quant_qat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install brevitas



In [2]:
import os

# Specify the directory for saving checkpoints
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Save model and optimizer state
def save_checkpoint(epoch, model, optimizer, loss, checkpoint_dir, filename="checkpoint.pth.tar"):
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

In [3]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Brevitas imports
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU, QuantIdentity
from brevitas.quant import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint


In [4]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 20
batch_size = 4
learning_rate = 0.001

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


100%|██████████| 170M/170M [00:03<00:00, 43.0MB/s]


In [5]:
# Quantized CNN
class QuantCNN(torch.nn.Module):
    def __init__(self):
        super(QuantCNN, self).__init__()
        bit_width = 8
        self.quant_input = QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, bit_width=bit_width, return_quant_tensor=True)
        self.relu = QuantReLU(act_quant=Int8ActPerTensorFixedPoint, bit_width=bit_width, return_quant_tensor=True)

        self.conv1 = QuantConv2d(3, 32, 3, padding=1, weight_quant=Int8WeightPerTensorFixedPoint)
        self.conv2 = QuantConv2d(32, 64, 3, padding=1, weight_quant=Int8WeightPerTensorFixedPoint)
        self.conv3 = QuantConv2d(64, 128, 3, padding=1, weight_quant=Int8WeightPerTensorFixedPoint)
        self.conv4 = QuantConv2d(128, 256, 3, padding=1, weight_quant=Int8WeightPerTensorFixedPoint)

        self.pool = torch.nn.MaxPool2d(2, 2)
        self.dropout = torch.nn.Dropout(0.5)

        self._get_flatten_size()

        self.fc1 = QuantLinear(self.flatten_size, 256, weight_quant=Int8WeightPerTensorFixedPoint)

        self.fc2 = QuantLinear(256, 128, weight_quant=Int8WeightPerTensorFixedPoint)

        self.fc3 = QuantLinear(128, 10, weight_quant=Int8WeightPerTensorFixedPoint)

    def _get_flatten_size(self):
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 32, 32)
            dummy = self.pool(self.relu(self.conv1(dummy)))
            dummy = self.pool(self.relu(self.conv2(dummy)))
            dummy = self.pool(self.relu(self.conv3(dummy)))
            dummy = self.pool(self.relu(self.conv4(dummy)))
            self.flatten_size = dummy.view(1, -1).size(1)

    def forward(self, x):
        x = self.quant_input(x)
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        x = self.pool(self.relu(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


# Create and train the model
model = QuantCNN().to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 2000 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
    save_checkpoint(epoch + 1, model, optimizer, loss.item(), checkpoint_dir, filename=f'checkpoint_epoch_{epoch+1}.pth.tar')

print('Finished Training')

# Evaluation
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    n_class_correct = [0 for _ in range(10)]
    n_class_samples = [0 for _ in range(10)]
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        for i in range(len(labels)):
            label = labels[i]
            pred = predicted[i]
            if label == pred:
                n_class_correct[label] += 1
            n_class_samples[label] += 1

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc:.2f} %')

    for i in range(10):
        acc = 100.0 * n_class_correct[i] / n_class_samples[i]
        print(f'Accuracy of {classes[i]}: {acc:.2f} %')

  return super().rename(names)


Epoch [1/20], Step [2000/12500], Loss: 2.3140
Epoch [1/20], Step [4000/12500], Loss: 2.3144
Epoch [1/20], Step [6000/12500], Loss: 2.3152
Epoch [1/20], Step [8000/12500], Loss: 2.2910
Epoch [1/20], Step [10000/12500], Loss: 2.2900
Epoch [1/20], Step [12000/12500], Loss: 2.3062
Checkpoint saved at ./checkpoints/checkpoint_epoch_1.pth.tar
Epoch [2/20], Step [2000/12500], Loss: 2.2925
Epoch [2/20], Step [4000/12500], Loss: 2.2889
Epoch [2/20], Step [6000/12500], Loss: 2.3070
Epoch [2/20], Step [8000/12500], Loss: 2.2974
Epoch [2/20], Step [10000/12500], Loss: 2.2890
Epoch [2/20], Step [12000/12500], Loss: 2.3001
Checkpoint saved at ./checkpoints/checkpoint_epoch_2.pth.tar
Epoch [3/20], Step [2000/12500], Loss: 2.2967
Epoch [3/20], Step [4000/12500], Loss: 2.3012
Epoch [3/20], Step [6000/12500], Loss: 2.3101
Epoch [3/20], Step [8000/12500], Loss: 2.3046
Epoch [3/20], Step [10000/12500], Loss: 2.2937
Epoch [3/20], Step [12000/12500], Loss: 2.2015
Checkpoint saved at ./checkpoints/checkpoint

KeyboardInterrupt: 