In [168]:
import torch.nn.functional as F
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU, QuantIdentity
from brevitas.quant import Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
import torch
import onnx
from finn.util.test import get_test_model_trained
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.util.basic import make_build_dir
from finn.util.visualization import showInNetron
import os


In [169]:
! nvidia-smi
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Fri Jun 20 16:52:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.07             Driver Version: 570.133.07     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX 4000 Ada Gene...    Off |   00000000:02:00.0  On |                  Off |
| 30%   40C    P8             11W /  130W |    3309MiB /  20475MiB |     37%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [170]:
# Report the directory the kernel is running in.
current_dir = os.getcwd()
print(f"Current working directory: {current_dir}")

Current working directory: /home/changhong/prj/finn/notebooks


In [171]:
# Every artefact lives under <cwd><notebook_name>/... . The pieces are joined by
# plain string concatenation, which is why notebook_name starts with a slash.
notebook_name = "/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST_BNN"
finn_root = os.getcwd()
build_dir, model_dir, data_dir = (
    f"{finn_root}{notebook_name}/{leaf}" for leaf in ("build", "model", "data")
)

# Make sure each output location exists before anything is written to it.
for _dir in (build_dir, model_dir, data_dir):
    os.makedirs(_dir, exist_ok=True)

print(f"Data directory: {data_dir}")
print(f"Finn root directory: {finn_root}")
print(f"Build directory: {build_dir}")
print(f"Model directory: {model_dir}")

Data directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST_BNN/data
Finn root directory: /home/changhong/prj/finn/notebooks
Build directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST_BNN/build
Model directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST_BNN/model


# Model definition
This CNV model is adapted from the CNV model in the Brevitas `bnn_pynq` examples.

In [172]:
import torch
from torch.nn import BatchNorm1d
from torch.nn import BatchNorm2d
from torch.nn import MaxPool2d, AvgPool2d
from torch.nn import Module
from torch.nn import ModuleList

from brevitas.core.restrict_val import RestrictValueType
from brevitas.nn import QuantConv2d
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantLinear

from brevitas_examples.bnn_pynq.models.common import CommonActQuant
from brevitas_examples.bnn_pynq.models.common import CommonWeightQuant
from brevitas_examples.bnn_pynq.models.tensor_norm import TensorNorm




# Build a LeNet-5-style CNV (1W1A: 1-bit weights, 1-bit activations)

In [173]:
import configparser

# Larger layer configuration, kept commented out for comparison
# (presumably the settings of the Brevitas CNV this notebook adapts — confirm):
# CNV_OUT_CH_POOL = [(64, False), (64, True), (128, False), (128, True), (256, False), (256, False)]
# INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)]
# LAST_FC_IN_FEATURES = 512
# LAST_FC_PER_OUT_CH_SCALING = False
# POOL_SIZE = 2
# KERNEL_SIZE = 3

# LeNet-5
# One conv stage per entry: (output channels, whether a pooling layer follows).
CNV_OUT_CH_POOL = [(6, True), (16, True), (120, False)]
# One hidden fully connected stage per entry: (in_features, out_features).
INTERMEDIATE_FC_FEATURES = [(120, 84)]
# Input width of the final classification layer.
LAST_FC_IN_FEATURES = 84
# NOTE(review): not referenced by the CNV class in this notebook — confirm it is still needed.
LAST_FC_PER_OUT_CH_SCALING = False
# NOTE(review): the CNV class hard-codes its pooling kernel size to 2 and does not read this constant.
POOL_SIZE = 2
# Kernel size shared by every QuantConv2d layer.
KERNEL_SIZE = 5

# Base name of the exported ONNX file and of the saved checkpoints.
# NOTE(review): the name reads "2 conv / 3 FC", while the lists in this cell
# build 3 conv stages and 2 linear layers — confirm the naming is intended.
model_name = '2c3f1w1a_mnist'

class CNV(Module):
    """Quantized LeNet-style convolutional classifier built from Brevitas layers.

    The network consists of two ``ModuleList`` stacks applied one after the other:

    * ``conv_features``: an input ``QuantIdentity`` followed, for every entry of
      ``CNV_OUT_CH_POOL``, by ``QuantConv2d`` -> ``BatchNorm2d`` -> ``QuantIdentity``
      and, when the entry's flag is set, an ``AvgPool2d`` with kernel size 2.
    * ``linear_features``: for every entry of ``INTERMEDIATE_FC_FEATURES``,
      ``QuantLinear`` -> ``BatchNorm1d`` -> ``QuantIdentity``; then a final
      ``QuantLinear`` with ``num_classes`` outputs followed by ``TensorNorm``.

    Args:
        num_classes: Number of outputs of the last linear layer.
        weight_bit_width: Bit width handed to every conv/linear weight quantizer.
        act_bit_width: Bit width handed to the hidden activation quantizers.
        in_bit_width: Bit width handed to the input quantizer.
        in_ch: Number of channels of the input tensor.
    """

    def __init__(self, num_classes, weight_bit_width, act_bit_width, in_bit_width, in_ch):
        super().__init__()

        self.conv_features = ModuleList()
        self.linear_features = ModuleList()

        # Input quantizer (noted by the original author as "for Q1.7 input format").
        input_quant = QuantIdentity(
            act_quant=CommonActQuant,
            bit_width=in_bit_width,
            min_val=-1.0,
            max_val=1.0 - 2.0 ** (-7),
            narrow_range=False,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO)
        self.conv_features.append(input_quant)

        # Convolutional stages; each stage's output width feeds the next one.
        prev_ch = in_ch
        for out_ch, is_pool_enabled in CNV_OUT_CH_POOL:
            stage = [
                QuantConv2d(
                    kernel_size=KERNEL_SIZE,
                    in_channels=prev_ch,
                    out_channels=out_ch,
                    bias=False,
                    weight_quant=CommonWeightQuant,
                    weight_bit_width=weight_bit_width),
                BatchNorm2d(out_ch, eps=1e-4),
                QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width),
            ]
            if is_pool_enabled:
                # Average pooling is used here (a MaxPool2d variant was left disabled).
                stage.append(AvgPool2d(kernel_size=2))
            self.conv_features.extend(stage)
            prev_ch = out_ch

        # Hidden fully connected stages.
        for in_features, out_features in INTERMEDIATE_FC_FEATURES:
            self.linear_features.extend([
                QuantLinear(
                    in_features=in_features,
                    out_features=out_features,
                    bias=False,
                    weight_quant=CommonWeightQuant,
                    weight_bit_width=weight_bit_width),
                BatchNorm1d(out_features, eps=1e-4),
                QuantIdentity(act_quant=CommonActQuant, bit_width=act_bit_width),
            ])

        # Classification head.
        self.linear_features.extend([
            QuantLinear(
                in_features=LAST_FC_IN_FEATURES,
                out_features=num_classes,
                bias=False,
                weight_quant=CommonWeightQuant,
                weight_bit_width=weight_bit_width),
            TensorNorm(),
        ])

        # Re-initialise every quantized conv/linear weight uniformly in [-1, 1].
        for layer in self.modules():
            if isinstance(layer, (QuantConv2d, QuantLinear)):
                torch.nn.init.uniform_(layer.weight.data, -1, 1)

    def clip_weights(self, min_val, max_val):
        """Clamp, in place, the conv and linear weights to ``[min_val, max_val]``."""
        stacks = (
            (self.conv_features, QuantConv2d),
            (self.linear_features, QuantLinear),
        )
        for layers, layer_type in stacks:
            for layer in layers:
                if isinstance(layer, layer_type):
                    layer.weight.data.clamp_(min_val, max_val)

    def forward(self, x):
        """Apply ``2 * x - 1``, run the conv stack, flatten, run the linear stack."""
        # Affine input rescaling (maps an input in [0, 1] to [-1, 1]).
        x = 2.0 * x - torch.tensor([1.0], device=x.device)
        for layer in self.conv_features:
            x = layer(x)
        # Collapse everything except the batch dimension.
        x = x.view(x.shape[0], -1)
        for layer in self.linear_features:
            x = layer(x)
        return x


def cnv(cfg):
    """Build a ``CNV`` network from a ``configparser``-style configuration.

    Args:
        cfg: Object exposing ``getint(section, option)``. ``WEIGHT_BIT_WIDTH``,
            ``ACT_BIT_WIDTH`` and ``IN_BIT_WIDTH`` are read from section
            ``QUANT``; ``NUM_CLASSES`` and ``IN_CHANNELS`` from section ``MODEL``.

    Returns:
        The newly constructed ``CNV`` instance.
    """
    # (constructor keyword, config section, config option), in read order.
    fields = (
        ('weight_bit_width', 'QUANT', 'WEIGHT_BIT_WIDTH'),
        ('act_bit_width', 'QUANT', 'ACT_BIT_WIDTH'),
        ('in_bit_width', 'QUANT', 'IN_BIT_WIDTH'),
        ('num_classes', 'MODEL', 'NUM_CLASSES'),
        ('in_ch', 'MODEL', 'IN_CHANNELS'),
    )
    ctor_kwargs = {
        keyword: cfg.getint(section, option)
        for keyword, section, option in fields
    }
    return CNV(**ctor_kwargs)

# Network and quantization settings, laid out in the sections cnv() reads.
config = configparser.ConfigParser()
config['MODEL'] = {
    'NUM_CLASSES': '10',
    'IN_CHANNELS': '1',
    # Was misspelled 'DTASET', which no lookup of the 'DATASET' option could find.
    'DATASET': 'MNIST',
}
# 1-bit weights and activations, 8-bit input quantization.
config['QUANT'] = {
    'WEIGHT_BIT_WIDTH': '1',
    'ACT_BIT_WIDTH': '1',
    'IN_BIT_WIDTH': '8',
}

# Instantiate the quantized network described by `config`.
model = cnv(config)


In [174]:
# Export the model to QONNX with Brevitas' exporter (rather than
# torch.onnx.export), then clean the written file up in place.
onnx_model_path = f"{model_dir}/{model_name}.onnx"

# One random 1x1x32x32 tensor serves as the example input for the export.
# NOTE(review): this cell sits before the training cell, so presumably the
# exported file holds the freshly initialised weights — confirm this is intended.
dummy_input = torch.randn(1, 1, 32, 32)
export_qonnx(model, dummy_input, onnx_model_path)
qonnx_cleanup(onnx_model_path, out_file=onnx_model_path)



In [175]:
# Load the MNIST dataset
transform = transforms.Compose([
             transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()]
)


full_train_set = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
train_size = int(0.8 * len(full_train_set))
val_size = len(full_train_set) - train_size

train_set, val_set = torch.utils.data.random_split(full_train_set, [train_size, val_size])
test_set = datasets.MNIST(data_dir, train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False)

print("Dataset split complete:")
print(f"Training set: {len(train_set)} samples")
print(f"Validation set: {len(val_set)} samples")
print(f"Test set: {len(test_set)} samples")

Dataset split complete:
Training set: 48000 samples
Validation set: 12000 samples
Test set: 10000 samples


In [176]:
def train(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to optimise; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the mean of the per-batch losses and the
        share of correctly classified samples in percent.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        # Predicted class = index of the largest output along dimension 1.
        predicted = logits.detach().argmax(dim=1)
        n_seen += batch_labels.size(0)
        n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(loader), 100 * n_correct / n_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the mean of the per-batch losses and the
        share of correctly classified samples in percent.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += criterion(logits, batch_labels).item()

            predicted = logits.argmax(dim=1)
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(loader), 100 * n_correct / n_seen

def test(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    test_acc = 100 * correct / total
    return test_acc


In [None]:
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 500

# Choose one learning-rate schedule.
# 1. Step decay
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# 2. Exponential decay
# scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# 3. Cosine annealing
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# 4. Cosine annealing with warm restarts
# scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1)

# 5. Reduce the learning rate when a metric stops improving
#    (here: when the validation accuracy plateaus)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=10, verbose=True)


best_val_acc = 0.0

model_save_path = model_dir + f"/{model_name}.pth"
best_model_save_path = model_dir + f"/best_{model_name}.pth"

for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Keep the latent weights inside [-1, 1] when the model supports it.
    if hasattr(model, 'clip_weights'):
        model.clip_weights(-1.0, 1.0)

    # Learning rate that was in effect during this epoch.
    current_lr = optimizer.param_groups[0]['lr']

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'LR: {current_lr:.6f}, '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # Only ReduceLROnPlateau takes the monitored metric. Every other scheduler
    # interprets a positional argument as the (deprecated) epoch index, so
    # passing val_acc there would pin the schedule to the accuracy value
    # instead of advancing it by one epoch.
    if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(val_acc)
    else:
        scheduler.step()

    # Checkpoint the model whenever the validation accuracy improves.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_save_path)

# Evaluate the best checkpoint on the test set. map_location keeps the load
# working when the checkpoint was written on a different device.
model.load_state_dict(torch.load(best_model_save_path, map_location=device))
test_acc = test(model, test_loader, device)
print(f'Test Accuracy of the best model on the test images: {test_acc:.2f}%')


  return F.conv2d(input, weight, bias, self.stride,
  output_tensor = linear(x, quant_weight, quant_bias)
