In [225]:
import torch.nn.functional as F
from brevitas.nn import QuantConv2d, QuantLinear, QuantReLU, QuantIdentity
from brevitas.quant import Int8WeightPerTensorFixedPoint, Int8ActPerTensorFixedPoint
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
import torch
import onnx
from finn.util.test import get_test_model_trained
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
from qonnx.transformation.infer_shapes import InferShapes
from qonnx.transformation.fold_constants import FoldConstants
from qonnx.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames, RemoveStaticGraphInputs
from finn.util.basic import make_build_dir
from finn.util.visualization import showInNetron
import os


In [226]:
# All artifacts of this case study live below the directory the kernel was
# started in (`finn_root` is simply the current working directory).
finn_root = os.getcwd()
notebook_name = "/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST"
_case_study_root = finn_root + notebook_name

build_dir = _case_study_root + "/build"
model_dir = _case_study_root + "/model"
data_dir = _case_study_root + "/data"

# Create the folders up front; folders that already exist are left untouched.
for _folder in (build_dir, model_dir, data_dir):
    os.makedirs(_folder, exist_ok=True)

print(f"Data directory: {data_dir}")
print(f"Finn root directory: {finn_root}")
print(f"Build directory: {build_dir}")
print(f"Model directory: {model_dir}")

Data directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST/data
Finn root directory: /home/changhong/prj/finn/notebooks
Build directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST/build
Model directory: /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST/model


In [227]:
# Hyperparameters
EPOCHS = 50
LR = 0.001
BATCH_SIZE = 128
# Misspelled alias kept on purpose: the data-loader cell refers to this name.
BATACH_SIZE = BATCH_SIZE
RANDOM_SEED = 1998

# Seed PyTorch's global RNG as well, so that everything drawing from it
# (e.g. DataLoader shuffling) starts from the same state on every run.
# The trailing semicolon keeps the returned generator out of the cell output.
torch.manual_seed(RANDOM_SEED);

In [228]:
def analyze_model_sparsity(model):
    """Print a per-parameter sparsity table for ``model`` followed by totals.

    Every parameter owned directly by a module (including the root module) is
    listed with its element count, non-zero count and sparsity percentage.
    A TOTAL row and two compression estimates follow. Nothing is printed when
    the model has no parameters.

    Args:
        model: The ``torch.nn.Module`` to inspect.

    Returns:
        None. The report is written to stdout.
    """
    # One tuple per parameter: (path, layer type, count, non-zero, sparsity %).
    rows = []
    for module_name, module in model.named_modules():
        layer_type = module.__class__.__name__
        for param_name, param in module.named_parameters(recurse=False):
            # The root module has an empty name; use the bare parameter name.
            path = f"{module_name}.{param_name}" if module_name else param_name
            count = param.numel()
            nonzero = torch.count_nonzero(param).item()
            sparsity = 100 * (1 - nonzero / count) if count > 0 else 0
            rows.append((path, layer_type, count, nonzero, sparsity))

    if not rows:
        return

    total_params = sum(row[2] for row in rows)
    total_nonzero = sum(row[3] for row in rows)
    total_sparsity = 100 * (1 - total_nonzero / total_params) if total_params > 0 else 0

    # Column widths must fit the header label, every data cell and the TOTAL
    # row; otherwise the table is misaligned whenever a label is wider than
    # the data underneath it.
    headers = ("Parameter Path", "Layer Type", "Param Count", "Non-zero", "Sparsity (%)")
    w_path = max(len(headers[0]), len("TOTAL"), *(len(row[0]) for row in rows))
    w_type = max(len(headers[1]), *(len(row[1]) for row in rows))
    w_count = max(len(headers[2]), len(f"{total_params:,}"), *(len(f"{row[2]:,}") for row in rows))
    w_nonzero = max(len(headers[3]), len(f"{total_nonzero:,}"), *(len(f"{row[3]:,}") for row in rows))
    w_sparsity = len(headers[4])
    w_pct = w_sparsity - 1  # one character is reserved for the trailing '%'

    def format_row(path, layer_type, count, nonzero, sparsity):
        return (
            f"{path:<{w_path}} | {layer_type:<{w_type}} | {count:>{w_count},} | "
            f"{nonzero:>{w_nonzero},} | {sparsity:>{w_pct}.2f}%"
        )

    header = (
        f"{headers[0]:<{w_path}} | {headers[1]:<{w_type}} | {headers[2]:>{w_count}} | "
        f"{headers[3]:>{w_nonzero}} | {headers[4]:>{w_sparsity}}"
    )
    rule = "-" * len(header)

    print(header)
    print(rule)
    for row in rows:
        print(format_row(*row))
    print(rule)
    print(format_row("TOTAL", "-", total_params, total_nonzero, total_sparsity))

    compression_ratio = total_params / total_nonzero if total_non_zero_guard(total_nonzero) else float('inf')
    print(f"\nCompression Ratio (pruning only): {compression_ratio:.2f}x")

    # Assumes every parameter is stored in 8 bits (1 byte) instead of FP32
    # (4 bytes), hence the extra factor 4 = 32 / 8.
    effective_compression = 4 * compression_ratio
    print(f"Effective Compression (pruning + quantization): {effective_compression:.2f}x")


def total_non_zero_guard(total_nonzero):
    """Return True when the non-zero count can safely be used as a divisor."""
    return total_nonzero > 0


In [229]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split



# Preprocessing: convert images to tensors, then standardise them with the
# MNIST mean (0.1307) and standard deviation (0.3081).
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Official 60,000-sample training split (downloaded on first use).
full_train_dataset = torchvision.datasets.MNIST(
    root=data_dir, train=True, download=True, transform=transform
)

# Hold out part of the training split for validation:
# 50,000 train / 10,000 validation, next to the 10,000-sample official test set.
train_size, val_size = 50000, 10000

# A dedicated, seeded generator makes the train/validation split reproducible.
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

# Official test split.
test_dataset = torchvision.datasets.MNIST(
    root=data_dir, train=False, download=True, transform=transform
)

# All three loaders use the same batch size; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATACH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATACH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATACH_SIZE, shuffle=False)

# Summary of the resulting split sizes.
print("Dataset split complete:")
print(f"Training set: {len(train_dataset)} samples")
print(f"Validation set: {len(val_dataset)} samples")
print(f"Test set: {len(test_dataset)} samples")


Dataset split complete:
Training set: 50000 samples
Validation set: 10000 samples
Test set: 10000 samples


In [230]:
from torch.nn import Module
import torch.nn as nn
import torch.nn.functional as F
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
import torch


total_bits = 8   # bit width shared by the 8-bit weights and all activations
n = 7            # fractional bits; quantized ReLUs are clipped at 1 - 2**-n


class LeNet5(Module):
    """LeNet-style MNIST classifier assembled from Brevitas quantized layers.

    Layout: input quantizer -> [3x3 conv -> QuantReLU -> 2x2 max-pool] x 2
    -> flatten -> fc1 -> QuantReLU -> fc2 -> log-softmax.

    Attribute names and registration order are kept stable because saved
    ``state_dict`` files are keyed by them.
    """

    def __init__(self):
        super(LeNet5, self).__init__()

        # Settings shared by every quantized weight tensor: integer
        # quantization with a constant scale of 1.0, restricted to a power of two.
        weight_quant = dict(
            weight_quant_type=QuantType.INT,
            weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
            weight_scaling_impl_type=ScalingImplType.CONST,
            weight_scaling_const=1.0,
        )
        # Settings shared by the three quantized ReLUs. With n = 7 the upper
        # clipping value 1 - 2**-n equals the former literal 1 - 1/128.
        relu_quant = dict(
            quant_type=QuantType.INT,
            bit_width=total_bits,
            max_val=1 - 2.0 ** -n,
            restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
            scaling_impl_type=ScalingImplType.CONST,
        )

        # Input quantizer: one scale for the whole tensor, initialised to 1.0
        # and declared as a parameter so it can be adjusted during training.
        self.quant_input = qnn.QuantIdentity(
            quant_type=QuantType.INT,
            bit_width=total_bits,
            scaling_init=1.0,
            scaling_impl_type=ScalingImplType.PARAMETER,
            scaling_per_output_channel=False,
        )

        self.conv1 = qnn.QuantConv2d(in_channels=1,
                                     out_channels=20,
                                     kernel_size=3,
                                     padding=1,
                                     bias=False,
                                     weight_bit_width=total_bits,
                                     **weight_quant)
        self.relu1 = qnn.QuantReLU(**relu_quant)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = qnn.QuantConv2d(in_channels=20,
                                     out_channels=50,
                                     kernel_size=3,
                                     padding=1,
                                     bias=False,
                                     weight_bit_width=total_bits,
                                     **weight_quant)
        self.relu2 = qnn.QuantReLU(**relu_quant)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fixed-point fully connected layers.
        # NOTE(review): fc1 is declared with 32-bit weights while every other
        # layer uses total_bits -- confirm this is intended.
        self.fc1 = qnn.QuantLinear(7 * 7 * 50, 500,
                                   bias=True,
                                   weight_bit_width=32,
                                   **weight_quant)
        self.relu3 = qnn.QuantReLU(**relu_quant)
        self.fc2 = qnn.QuantLinear(500, 10,
                                   bias=True,
                                   weight_bit_width=total_bits,
                                   **weight_quant)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        out = self.quant_input(x)  # quantize the input first
        # pool1/pool2 are 2x2, stride-2 max-pools, i.e. the same operation as
        # the F.max_pool2d(out, 2) calls they replace.
        out = self.pool1(self.relu1(self.conv1(out)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.size(0), -1)  # flatten to (batch, features) for fc1
        out = self.relu3(self.fc1(out))
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)

In [231]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

Using device: cuda


In [232]:
from sklearn.metrics import accuracy_score
def test(model, test_loader, device):
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
    
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs, target = inputs.to(device), target.to(device)
            outputs = model(inputs)
            
            # For multi-class classification (10 classes for MNIST)
            _, predicted = torch.max(outputs.data, 1)  # get the index of the max log-probability
            
            y_true.extend(target.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    print(f"Test Top-1 accuracy(%): {accuracy_score(y_true, y_pred) * 100:.2f}%")
    print(f"Test Top-1 error rate(%):) {(1 - accuracy_score(y_true, y_pred)) * 100:.2f}%")

In [233]:
# Restore the trained LeNet-5 weights, then report sparsity and accuracy of
# the unpruned baseline.
model = LeNet5()
# map_location="cpu" lets the checkpoint load even when it was saved from a
# GPU that is not available here; the model is moved to `device` below.
model.load_state_dict(torch.load(model_dir + "/lenet_mnist_int8.pth", map_location="cpu"))
analyze_model_sparsity(model)
model.to(device)
test(model, test_loader, device)

Parameter Path                                                                     | Layer Type       | Param Count |  Non-zero | Sparsity (%)
--------------------------------------------------------------------------------------------------------------------------------------------------
quant_input.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value | ParameterScaling |         1 |         1 |       0.00%
conv1.weight                                                                       | QuantConv2d      |       180 |       180 |       0.00%
conv2.weight                                                                       | QuantConv2d      |     9,000 |     9,000 |       0.00%
fc1.weight                                                                         | QuantLinear      | 1,225,000 | 1,225,000 |       0.00%
fc1.bias                                                                           | QuantLinear      |       500 |       500 |       0.00%
fc2.weight

  return F.conv2d(input, weight, bias, self.stride,
  output_tensor = linear(x, quant_weight, quant_bias)


Test Top-1 accuracy(%): 99.13%
Test Top-1 error rate(%):) 0.87%


In [234]:
import copy
import torch.nn.utils.prune as prune


def apply_l1_pruning_type(model, layer_type, pruning_percentage=0.5, min_params=10):
    """Return a copy of ``model`` with every ``layer_type`` module L1-pruned.

    The copy is moved to the notebook-level ``device``. For each matching
    module, the ``weight`` entries with the smallest absolute value are set to
    zero and the pruning re-parametrization is removed again, so the zeros end
    up directly in ``weight``.

    Args:
        model: Source module; it is deep-copied and never modified.
        layer_type: Module class (or tuple of classes) to prune, e.g.
            ``QuantLinear`` or ``QuantConv2d``.
        pruning_percentage: ``amount`` forwarded to ``prune.l1_unstructured``.
        min_params: Modules whose ``weight`` has at most this many elements
            are left unpruned.

    Returns:
        The pruned copy of ``model``.
    """
    model_tmp = copy.deepcopy(model).to(device)
    for module in model_tmp.modules():
        if not isinstance(module, layer_type):
            continue
        # Very small weight tensors are skipped entirely.
        if module.weight.numel() <= min_params:
            continue
        prune.l1_unstructured(module, name='weight', amount=pruning_percentage)
        # Make the pruning permanent: drops the mask and `weight_orig`.
        prune.remove(module, 'weight')
    return model_tmp



def apply_l1_pruning_by_param_path(model, param_path, pruning_percentage=0.5, verbose=False):
    """Return a copy of ``model`` with a single parameter pruned by L1 magnitude.

    Args:
        model: Source module; it is deep-copied and never modified.
        param_path: Dotted path to the parameter, e.g. ``'conv1.weight'``.
            A path without dots refers to a parameter of ``model`` itself.
        pruning_percentage: ``amount`` forwarded to ``prune.l1_unstructured``.
        verbose: When true, print the shape and element count of the target.

    Returns:
        The pruned copy, moved to the notebook-level ``device``.

    Raises:
        ValueError: If the module path or parameter does not exist, or the
            attribute found at ``param_path`` is not a tensor (for example the
            ``bias`` of a layer created with ``bias=False``).
    """
    model_copy = copy.deepcopy(model)
    model_copy.to(device)

    # Split "a.b.weight" into the owning module path "a.b" and the name "weight".
    module_path, _, param_name = param_path.rpartition('.')

    # Walk down the module tree one attribute at a time.
    target_module = model_copy
    if module_path:
        for part in module_path.split('.'):
            try:
                target_module = getattr(target_module, part)
            except AttributeError:
                raise ValueError(f"Module path '{module_path}' not found (failed at '{part}')")

    if not hasattr(target_module, param_name):
        raise ValueError(f"Parameter '{param_name}' not found in module '{module_path}'")

    param = getattr(target_module, param_name)
    # hasattr() is also true for parameters registered as None; reject those
    # here instead of failing later with an unrelated error.
    if not isinstance(param, torch.Tensor):
        raise ValueError(
            f"Parameter '{param_name}' in module '{module_path}' is not a tensor "
            f"(got {type(param).__name__})"
        )

    if verbose:
        print(f"Pruning {param_path} | Shape: {tuple(param.shape)} | Elements: {param.numel()}")

    # Prune regardless of tensor size, then make the result permanent.
    prune.l1_unstructured(target_module, name=param_name, amount=pruning_percentage)
    prune.remove(target_module, param_name)

    return model_copy





# Pruning amount per parameter, applied in this order. An amount of 0 leaves
# the parameter untouched.
pruning_plan = [
    ('conv1.weight', 0),
    ('conv2.weight', 0.4),
    ('fc1.weight', 0.995),
    ('fc1.bias', 0),
    ('fc2.weight', 0.5),
]
# NOTE(review): this cell used to prune 'fc2.weight' twice with amount 0.5.
# The second pass could only select the weights that were already zero, so it
# changed nothing and is dropped. If 'fc2.bias' was meant instead, add it to
# the plan above.

# Each call returns a fresh pruned copy, so `model` itself stays unpruned.
pmodel = model
for param_path, amount in pruning_plan:
    pmodel = apply_l1_pruning_by_param_path(pmodel, param_path, amount)

print("="*50 + " After Pruning " + "="*50)

analyze_model_sparsity(pmodel)
test(pmodel, test_loader, device)

Parameter Path                                                                     | Layer Type       | Param Count | Non-zero | Sparsity (%)
----------------------------------------------------------------------------------------------------------------------------------------------
quant_input.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value | ParameterScaling |         1 |     1 |       0.00%
conv1.weight                                                                       | QuantConv2d      |       180 |   180 |       0.00%
conv2.weight                                                                       | QuantConv2d      |     9,000 | 5,400 |      40.00%
fc1.weight                                                                         | QuantLinear      | 1,225,000 | 6,125 |      99.50%
fc1.bias                                                                           | QuantLinear      |       500 |   500 |       0.00%
fc2.bias                           

  return F.conv2d(input, weight, bias, self.stride,
  output_tensor = linear(x, quant_weight, quant_bias)


Test Top-1 accuracy(%): 64.38%
Test Top-1 error rate(%):) 35.62%


In [235]:
def freeze_zero_weights(model):
    """Zero the gradient of weights that are exactly zero when this is called.

    For every trainable parameter whose name contains ``"weight"``, a gradient
    hook is registered that multiplies the incoming gradient by a 0/1 mask of
    the entries that are currently non-zero.

    Args:
        model: Module whose weight parameters receive the hooks.

    Returns:
        A list with one hook handle per masked parameter. Call ``.remove()``
        on a handle to lift the freeze for that parameter.
    """
    handles = []
    for name, param in model.named_parameters():
        if "weight" in name and param.requires_grad:
            # Snapshot of the sparsity pattern at call time, kept in the
            # parameter's own dtype so the masked gradient keeps that dtype.
            mask = (param.detach() != 0).to(param.dtype)
            # Bind the mask as a default argument so each hook keeps its own
            # mask instead of the last one created in this loop.
            handles.append(param.register_hook(lambda grad, mask=mask: grad * mask))
    return handles

def retrain_model(model, train_loader, num_epochs=5, lr=0.001):
    """Fine-tune a pruned model while keeping its zero weights at zero.

    Trains ``model`` in place with Adam and cross-entropy loss on the
    notebook-level ``device`` and prints mean loss and training accuracy after
    every epoch. ``freeze_zero_weights`` is applied once before training; the
    hooks it registers are not removed afterwards.

    Args:
        model: Module to fine-tune. It is left in training mode.
        train_loader: Iterable of ``(inputs, target)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to Adam.

    Returns:
        The same ``model`` object, after training.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Mask the gradients of weights that are zero right now.
    freeze_zero_weights(model)

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0

        for inputs, target in train_loader:
            inputs, target = inputs.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, target)
            loss.backward()

            # Parameters that are entirely zero get no gradient at all, so the
            # optimizer skips them in this step.
            for param in model.parameters():
                if param.grad is not None and torch.all(param == 0):
                    param.grad = None

            optimizer.step()

            # Running statistics for the epoch summary.
            running_loss += loss.item()
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        epoch_loss = running_loss / num_batches
        epoch_acc = 100 * correct / total
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

    return model


# Fine-tune the pruned network (retrain_model trains `pmodel` and returns it),
# then report the resulting sparsity and test accuracy.
_banner = "="*50
retrained_model = retrain_model(pmodel, train_loader, num_epochs=30)
print(_banner + " After Retraining " + _banner)
analyze_model_sparsity(retrained_model)   # sparsity after fine-tuning
test(retrained_model, test_loader, device)

  return F.conv2d(input, weight, bias, self.stride,
  output_tensor = linear(x, quant_weight, quant_bias)


Epoch 1/30 | Loss: 0.1120 | Acc: 97.38%
Epoch 2/30 | Loss: 0.0379 | Acc: 98.98%
Epoch 3/30 | Loss: 0.0283 | Acc: 99.23%
Epoch 4/30 | Loss: 0.0230 | Acc: 99.34%
Epoch 5/30 | Loss: 0.0194 | Acc: 99.46%
Epoch 6/30 | Loss: 0.0163 | Acc: 99.58%
Epoch 7/30 | Loss: 0.0145 | Acc: 99.64%
Epoch 8/30 | Loss: 0.0122 | Acc: 99.73%
Epoch 9/30 | Loss: 0.0107 | Acc: 99.77%
Epoch 10/30 | Loss: 0.0094 | Acc: 99.80%
Epoch 11/30 | Loss: 0.0082 | Acc: 99.83%
Epoch 12/30 | Loss: 0.0072 | Acc: 99.86%
Epoch 13/30 | Loss: 0.0065 | Acc: 99.89%
Epoch 14/30 | Loss: 0.0056 | Acc: 99.91%
Epoch 15/30 | Loss: 0.0047 | Acc: 99.95%
Epoch 16/30 | Loss: 0.0042 | Acc: 99.94%
Epoch 17/30 | Loss: 0.0036 | Acc: 99.97%
Epoch 18/30 | Loss: 0.0033 | Acc: 99.97%
Epoch 19/30 | Loss: 0.0028 | Acc: 99.98%
Epoch 20/30 | Loss: 0.0024 | Acc: 99.98%
Epoch 21/30 | Loss: 0.0021 | Acc: 99.99%
Epoch 22/30 | Loss: 0.0020 | Acc: 99.99%
Epoch 23/30 | Loss: 0.0020 | Acc: 99.99%
Epoch 24/30 | Loss: 0.0013 | Acc: 100.00%
Epoch 25/30 | Loss: 0.00

In [236]:
# Result recorded by the notebook author: accuracy 99.28% -> 99.01%, a drop of
# 0.27% (< 0.3%), which is treated as lossless compression.
# NOTE(review): the baseline accuracy printed earlier in this notebook is
# 99.13% -- confirm the figures above against a fresh run.
# NOTE(review): "minst" in the file name below looks like a typo for "mnist";
# it is left unchanged because the path may be relied on elsewhere -- confirm.
pruned_weights_path = model_dir +"/"+'lenet_minst_int8_prune' + ".pth"
torch.save(retrained_model.state_dict(), pruned_weights_path)

# Export to ONNX from the CPU, using a random 1x1x28x28 example input and a
# dynamic batch dimension on both input and output.
retrained_model.cpu()
onnx_model_path = model_dir + "/lenet_mnist_int8_pruned.onnx"
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(
    retrained_model,
    dummy_input,
    onnx_model_path,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    opset_version=11,
)

  signed = torch.tensor(signed, dtype=torch.bool)
  training = torch.tensor(training, dtype=torch.bool)
  return F.conv2d(input, weight, bias, self.stride,
  output_tensor = linear(x, quant_weight, quant_bias)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [237]:
import numpy as np

ready_model_filename = model_dir + "/lenet_mnist_int8_prune_ready.onnx"

input_shape = (1, 1, 28, 28)

# Example input handed to export_qonnx: random values in {-1, +1}.
# randint's upper bound is exclusive, so it has to be 2 to draw from {0, 1};
# with an upper bound of 1 every draw is 0 and the tensor is a constant -1.
input_a = np.random.randint(0, 2, size=input_shape).astype(np.float32)
input_a = 2 * input_a - 1
scale = 1.0
input_t = torch.from_numpy(input_a * scale)

# Move to CPU before export
retrained_model.cpu()

# Export to QONNX
export_qonnx(
    retrained_model, export_path=ready_model_filename, input_t=input_t
)

# Clean up the exported graph in place (same file is read and overwritten).
qonnx_cleanup(ready_model_filename, out_file=ready_model_filename)

print("Model saved to %s" % ready_model_filename)

Model saved to /home/changhong/prj/finn/notebooks/EF-US-Engine-Free-Unstructured-Sparsity-Design-Alleviates-Accelerator-Bottlenecks/casestudy1/LeNet_MNIST/model/lenet_mnist_int8_prune_ready.onnx
