In [66]:
# https://xilinx.github.io/brevitas/getting_started.html

from torch.nn import Module
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant

import os
import onnx
import torch
import numpy as np
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, Dataset
from brevitas.nn import QuantLinear, QuantReLU
import torch.nn as nn
from sklearn.metrics import accuracy_score
from tqdm import tqdm, trange

weight_bit_width = 2
act_bit_width = 2

torch.manual_seed(0)

class QuantWeightActLeNet(Module):
    def __init__(self):
        super(QuantWeightActLeNet, self).__init__()
        self.quant_inp = qnn.QuantIdentity(bit_width=4)
        self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_bit_width=weight_bit_width)
        self.relu1 = qnn.QuantReLU(bit_width=act_bit_width)
        self.conv2 = qnn.QuantConv2d(6, 16, 5, bias=True, weight_bit_width=weight_bit_width)
        self.relu2 = qnn.QuantReLU(bit_width=act_bit_width)
        self.fc1   = qnn.QuantLinear(16*5*5, 120, bias=True, weight_bit_width=weight_bit_width)
        self.relu3 = qnn.QuantReLU(bit_width=act_bit_width)
        self.fc2   = qnn.QuantLinear(120, 84, bias=True, weight_bit_width=weight_bit_width)
        self.relu4 = qnn.QuantReLU(bit_width=act_bit_width)
        self.fc3   = qnn.QuantLinear(84, 10, bias=True)

    def forward(self, x):
        out = self.quant_inp(x)
        out = self.relu1(self.conv1(out))
        out = F.max_pool2d(out, 2)
        out = self.relu2(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.reshape(out.shape[0], -1)
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        out = self.fc3(out)
        return out

model = QuantWeightActLeNet()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Target device: " + str(device))

model.to(device); # The semicolon is for not printing the model.


Target device: cpu


### Data processing

In [67]:
# This is for 4-bit quantization. The quantization value can be changed in line 11. 

import numpy as np
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os

quant_param = 4 # log2(quant_param) bits. #256 means no quanization, 2 means 1-bit quantization

def quantize_image(image):
    """Quantize and binarize an image."""
    image = image.astype(np.float32)
    image = np.floor(image / (256/quant_param))  # Example: reducing to 4-bit quantization
    return image.astype(np.float32)

def save_dataset_as_npz(data, labels, filename):
    """Save the dataset as a .npz file."""
    np.savez_compressed(filename, data=data, labels=labels)

# Load CIFAR-10 dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Quantize the images
train_images = np.array([quantize_image(image.numpy().transpose(1, 2, 0) * 255) for image, _ in train_dataset])
train_labels = np.array(train_dataset.targets)

test_images = np.array([quantize_image(image.numpy().transpose(1, 2, 0) * 255) for image, _ in test_dataset])
test_labels = np.array(test_dataset.targets)

# Save the datasets
os.makedirs('./quantized_data', exist_ok=True)
save_dataset_as_npz(train_images, train_labels, './quantized_data/cifar10_train.npz')
save_dataset_as_npz(test_images, test_labels, './quantized_data/cifar10_test.npz')

print('Saved the dataset as .npz files')


class CIFAR10QuantizedDataset(Dataset):
    def __init__(self, npz_file):
        data = np.load(npz_file)
        self.images = data['data']
        self.labels = data['labels']

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx].astype(np.float32) / 255.0
        label = self.labels[idx]
        image = torch.tensor(image.transpose(2, 0, 1))  # HWC to CHW format
        label = torch.tensor(label, dtype=torch.long)
        return image, label

# Load the quantized dataset
train_dataset = CIFAR10QuantizedDataset('./quantized_data/cifar10_train.npz')
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = CIFAR10QuantizedDataset('./quantized_data/cifar10_test.npz')
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified
Saved the dataset as .npz files


### Training and testing

In [68]:
import torch.optim as optim

def train(model, train_loader, optimizer, criterion):
    losses = []
    # ensure model is in training mode
    model.train()    
    
    for i, data in enumerate(train_loader, 0):        
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()   
                
        # forward pass
        output = model(images.float())
        # loss = criterion(output, labels.unsqueeze(1))
        loss = criterion(output, labels)
        
        # backward pass + run optimizer to update weights
        loss.backward()
        optimizer.step()
        
        # keep track of loss value
        losses.append(loss.data.cpu().numpy()) 
    return losses


def test(model, test_loader):    
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
   
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            output = model(images.float())
            # run the output through sigmoid
            # output = torch.sigmoid(output_orig)  
            # compare against a threshold of 0.5 to generate 0/1
            # pred = (output.detach().cpu().numpy() > 0.5) * 1
            _, pred = torch.max(output.data, 1)
            labels = labels.cpu().float()
            y_true.extend(labels.tolist()) 
            y_pred.extend(pred.reshape(-1).tolist())
            # y_pred.extend((pred == labels).sum().item())
        
    return accuracy_score(y_true, y_pred)

num_epochs = 10
lr = 0.001 

def display_loss_plot(losses, title="Training loss", xlabel="Iterations", ylabel="Loss"):
    x_axis = [i for i in range(len(losses))]
    plt.plot(x_axis,losses)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

# loss criterion and optimizer
# criterion = nn.BCEWithLogitsLoss().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Setting seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

running_loss = []
running_test_acc = []
t = trange(num_epochs, desc="Training loss", leave=True)

for epoch in t:
        loss_epoch = train(model, train_loader, optimizer,criterion)
        test_acc = test(model, test_loader)
        t.set_description("Training loss = %f test accuracy = %f" % (np.mean(loss_epoch), test_acc))
        t.refresh() # to show immediately the update           
        running_loss.append(loss_epoch)
        running_test_acc.append(test_acc)


Training loss = 1.666302 test accuracy = 0.387600: 100%|█| 10/10 [00:59<00:00,  


In [69]:
test_accuracy = test(model, test_loader)
print('test accuracy = ', test_accuracy)

test accuracy =  0.3876


In [57]:
# Save the Brevitas model to disk
torch.save(model.state_dict(), "state_dict_LeNet_WeightAct.pth")

####  Convert to ONNX model

In [60]:
# from brevitas.export import export_onnx_qcdq
# import torch

# # Weight-activation model
# export_onnx_qcdq(model, torch.randn(1, 3, 32, 32), export_path='4b_weight_act_lenet.onnx');




####  Convert to QONNX model

In [70]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

ready_model_filename = '4b_weight_act_lenet_qonnx.onnx'
input_shape = (1, 3, 32, 32)

# create a QuantTensor instance to mark input as bipolar during export
input_a = np.random.randint(0, 1, size=input_shape).astype(np.float32)
input_a = 2 * input_a - 1
scale = 1.0
input_t = torch.from_numpy(input_a * scale)

#Move to CPU before export
model.cpu()

# Export to ONNX
export_qonnx(
    model, export_path=ready_model_filename, input_t=input_t
)

# clean-up
qonnx_cleanup(ready_model_filename, out_file=ready_model_filename)

# ModelWrapper
model = ModelWrapper(ready_model_filename)
# Setting the input datatype explicitly because it doesn't get derived from the export function
model.set_tensor_datatype(model.graph.input[0].name, DataType["BIPOLAR"])
model = model.transform(ConvertQONNXtoFINN())
model.save(ready_model_filename)

print("Model saved to %s" % ready_model_filename)

Model saved to 4b_weight_act_lenet_qonnx.onnx




In [71]:
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

Stopping http://0.0.0.0:8081
Serving '4b_weight_act_lenet_qonnx.onnx' at http://0.0.0.0:8081


### Network surgery

In [62]:
# Move the model to CPU before surgery
# model = model.cpu()
# model

In [48]:
from copy import deepcopy

modified_model = deepcopy(model)

# W_orig = modified_model.conv1.weight#.data.detach().numpy()
# W_orig
# W_orig.shape

In [51]:
from brevitas.nn import QuantIdentity


class LeNetForExport(nn.Module):
    def __init__(self, my_pretrained_model):
        super(LeNetForExport, self).__init__()
        self.pretrained = my_pretrained_model
        self.qnt_output = QuantIdentity(
            quant_type='binary', 
            scaling_impl_type='const',
            bit_width=1, min_val=-1.0, max_val=1.0)
    
    def forward(self, x):
        # assume x contains bipolar {-1,1} elems
        # shift from {-1,1} -> {0,1} since that is the
        # input range for the trained network
        x = (x + torch.tensor([1.0]).to(x.device)) / 2.0  
        out_original = self.pretrained(x)
        out_final = self.qnt_output(out_original)   # output as {-1,1}     
        return out_final

model_for_export = LeNetForExport(modified_model)
model_for_export.to(device);

In [50]:
# W_orig = model_for_export.pretrained.conv1.weight#.data.detach().numpy()
# W_orig

Parameter containing:
tensor([[[[-0.0009,  0.0619, -0.0950, -0.0850, -0.0445],
          [ 0.0310, -0.0023,  0.0916, -0.0102,  0.0306],
          [-0.0349, -0.0227, -0.1103, -0.0765, -0.0476],
          [ 0.0043,  0.0456,  0.0693, -0.0783, -0.0503],
          [ 0.0419,  0.0959, -0.0238,  0.0864, -0.0186]],

         [[ 0.0122,  0.1046, -0.1071, -0.0727, -0.0292],
          [-0.0450,  0.0998, -0.0748, -0.0532, -0.0807],
          [-0.1081, -0.0674,  0.0993,  0.0515,  0.0560],
          [ 0.0061, -0.0592,  0.0195, -0.1078, -0.0834],
          [-0.0595,  0.0729,  0.0677, -0.0512, -0.0042]],

         [[ 0.0739,  0.1148,  0.0458,  0.0156,  0.0774],
          [-0.0680,  0.0215, -0.0895, -0.0800, -0.0596],
          [ 0.0522,  0.0464, -0.0684,  0.0349,  0.0634],
          [-0.0146,  0.0044,  0.0268,  0.0716,  0.1109],
          [-0.0890, -0.0423,  0.0454,  0.0957,  0.1005]]],


        [[[ 0.1019,  0.0230, -0.1004,  0.0106, -0.0722],
          [-0.1076,  0.1026,  0.0878, -0.1152,  0.0216],
 

In [45]:
# def test_padded_bipolar(model, test_loader):    
#     # ensure model is in eval mode
#     model.eval() 
#     y_true = []
#     y_pred = []
   
#     with torch.no_grad():
#         for data in test_loader:
#             inputs, target = data
#             inputs, target = inputs.to(device), target.to(device)
#             # pad inputs to 600 elements
#             input_padded = torch.nn.functional.pad(inputs, (0,7,0,0))
#             # convert inputs to {-1,+1}
#             input_scaled = 2 * input_padded - 1
#             # run the model
#             output = model(input_scaled.float())
#             y_pred.extend(list(output.flatten().cpu().numpy()))
#             # make targets bipolar {-1,+1}
#             expected = 2 * target.float() - 1
#             expected = expected.cpu().numpy()
#             y_true.extend(list(expected.flatten()))
        
#     return accuracy_score(y_true, y_pred)

def test_padded_bipolar(model, test_loader):    
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
   
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            output = model(images.float())
            # run the output through sigmoid
            # output = torch.sigmoid(output_orig)  
            # compare against a threshold of 0.5 to generate 0/1
            # pred = (output.detach().cpu().numpy() > 0.5) * 1
            _, pred = torch.max(output.data, 1)
            labels = labels.cpu().float()
            y_true.extend(labels.tolist()) 
            y_pred.extend(pred.reshape(-1).tolist())
            # y_pred.extend((pred == labels).sum().item())
        
    return accuracy_score(y_true, y_pred)

test_padded_bipolar(model_for_export, test_loader)

0.1