In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import time

# 假设你的 QuantConv2d 等定义在 models.py 中
# 如果 models.py 也是你自己写的，请确保里面的 tensor 也是创建在正确的 device 上
from models import * # ==========================================
# 1. 设备配置 (MacBook 适配核心)
# ==========================================
def get_device():
    """Return the preferred torch device: MPS first, then CUDA, else CPU."""
    # Backends are probed in priority order; a later probe only runs when the
    # earlier one reports unavailable.
    if torch.backends.mps.is_available():
        backend = "mps"    # Apple Silicon GPU
    elif torch.cuda.is_available():
        backend = "cuda"   # NVIDIA GPU
    else:
        backend = "cpu"    # plain CPU fallback
    return torch.device(backend)

# Resolve the compute device once; the functions below read this module-level
# name directly instead of taking it as an argument.
device = get_device()
print(f"当前运行设备: {device}")

# ==========================================
# 2. 模型定义 (VGG16_Part1) - 已修改保留 BN
# ==========================================
class VGG16_Part1(nn.Module):
    """VGG16-style network whose fifth block passes through an 8-channel bottleneck.

    ``features`` is one flat ``nn.Sequential``. Blocks 1-4 are runs of
    QuantConv2d(3x3, padding=1) -> BatchNorm2d -> ReLU units, each run closed
    by a 2x2 max-pool. Block 5 is: one 512->512 unit, a 1x1 unit squeezing
    512->8, the bias-free 3x3 8->8 target unit (its BatchNorm is kept), a 1x1
    unit expanding 8->512, one more 512->512 unit, a 2x2 max-pool and a 1x1
    average-pool. ``classifier`` is a single ``nn.Linear(512, num_classes)``.

    Layers are instantiated in the order they appear in ``features``, so the
    ``features.<index>`` state-dict keys are the positional indices.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def unit(c_in, c_out, **conv_kwargs):
            # conv -> batch-norm -> ReLU, created left to right
            return [
                QuantConv2d(c_in, c_out, **conv_kwargs),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        conv3x3 = {"kernel_size": 3, "padding": 1}
        conv1x1 = {"kernel_size": 1}
        stack = []

        # --- Blocks 1-4: channel progression per block, then a 2x2 max-pool ---
        block_channels = (
            (3, 64, 64),
            (64, 128, 128),
            (128, 256, 256, 256),
            (256, 512, 512, 512),
        )
        for channels in block_channels:
            for c_in, c_out in zip(channels[:-1], channels[1:]):
                stack.extend(unit(c_in, c_out, **conv3x3))
            stack.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # --- Block 5 (modified region) ---
        stack.extend(unit(512, 512, **conv3x3))
        # Adapter: 1x1 conv reducing 512 -> 8 channels.
        stack.extend(unit(512, 8, **conv1x1))
        # Target layer: 3x3 conv 8 -> 8, no conv bias, BatchNorm retained.
        stack.extend(unit(8, 8, bias=False, **conv3x3))
        # Expand: 1x1 conv restoring 8 -> 512 channels.
        stack.extend(unit(8, 512, **conv1x1))
        # Remainder of block 5.
        stack.extend(unit(512, 512, **conv3x3))
        stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
        stack.append(nn.AvgPool2d(kernel_size=1, stride=1))

        self.features = nn.Sequential(*stack)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the feature stack, flatten per sample, and apply the linear head."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

# Instantiate the network with its default class count and move it to the
# device selected above.
model_part1 = VGG16_Part1().to(device)
print("VGG16_Part1 模型构建完成。目标 8x8 层已保留 BN。")

# ==========================================
# 3. 权重加载函数 (适配 Mac)
# ==========================================
def load_pretrained_weights(model, pretrained_path):
    """Copy every name- and shape-compatible tensor from a checkpoint into ``model``.

    The checkpoint is read with ``map_location`` set to the module-level
    ``device`` and must hold its weights under the ``'state_dict'`` key.
    Entries whose key is absent from the model, or whose shape differs, are
    left at the model's current values. If ``pretrained_path`` is not a file,
    a message is printed and the model is untouched.
    """
    if not os.path.isfile(pretrained_path):
        print(f"在 '{pretrained_path}' 未找到 checkpoint")
        return

    print(f"=> loading checkpoint '{pretrained_path}'")

    # map_location remaps tensors saved on another device onto ours.
    ckpt = torch.load(pretrained_path, map_location=device)
    source_state = ckpt['state_dict']
    target_state = model.state_dict()

    # Keep only entries that exist in the model with an identical shape.
    compatible = {}
    for name, tensor in source_state.items():
        if name in target_state and tensor.shape == target_state[name].shape:
            compatible[name] = tensor

    target_state.update(compatible)
    model.load_state_dict(target_state)

    skipped = len(model.state_dict()) - len(compatible)
    print(f"已加载预训练权重。忽略了 {skipped} 个不匹配的参数层。")

# Checkpoint used to initialise the layers that match by name and shape.
PRETRAINED_PATH = "result/VGG16_quant/model_best.pth.tar"
# A missing file is not an error: the script keeps running with the randomly
# initialised model so the rest of the logic can still be exercised.
if os.path.exists(PRETRAINED_PATH):
    load_pretrained_weights(model_part1, PRETRAINED_PATH)
else:
    print(f"提示: '{PRETRAINED_PATH}' 不存在，跳过加载权重，使用随机初始化。")

# ==========================================
# 4. 数据加载 (适配 Mac)
# ==========================================
# Per-channel normalisation applied to both the train and test pipelines.
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

# Training split: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, then tensor conversion and normalisation.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

# On macOS a large num_workers can trigger multiprocessing errors; 0 (main
# process only) or 2 is usually a safe choice.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)

# Test split: no augmentation, only tensor conversion and normalisation.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

# Evaluation order is fixed (shuffle=False).
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2
)

# ==========================================
# 5. 训练循环 (适配 MPS/CPU)
# ==========================================
# Module-level training objects; fine_tune_advanced reads these globals
# rather than receiving them as arguments.
criterion = nn.CrossEntropyLoss().to(device)
# The optimizer is bound to model_part1's parameters specifically.
optimizer = optim.SGD(model_part1.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# Learning rate is multiplied by gamma=0.1 at scheduler steps 20 and 40.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.1)

def fine_tune_advanced(model, train_loader, test_loader, epochs=50):
    """Fine-tune ``model`` and checkpoint the weights with the best test accuracy.

    Relies on the module-level ``device``, ``criterion``, ``optimizer`` and
    ``scheduler``. The optimizer must have been built over ``model``'s own
    parameters, otherwise the update step trains a different model.

    Each epoch runs one pass over ``train_loader``, steps the scheduler,
    measures top-1 accuracy on ``test_loader`` and, on a new best, writes
    ``result/vgg_16_part1_best.pth.tar``.

    Args:
        model: Network to train; moved to ``device`` first.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        epochs: Number of epochs to run.

    Returns:
        The best test accuracy seen, in percent.
    """
    best_acc = 0
    model.to(device)  # make sure the model lives on the selected device

    for epoch in range(epochs):
        # --- Training ---
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # --- Validation ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                output = model(inputs)
                _, predicted = output.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        # An empty test loader yields 0% instead of a ZeroDivisionError.
        acc = 100. * correct / total if total else 0.0

        print(f"Epoch [{epoch+1}/{epochs}] (LR: {current_lr:.5f}) -> Test Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            os.makedirs("result", exist_ok=True)
            # Store CPU copies so the checkpoint can be loaded on a machine
            # without this accelerator even when map_location is not given.
            cpu_state = {name: t.detach().cpu() for name, t in model.state_dict().items()}
            torch.save({'state_dict': cpu_state}, "result/vgg_16_part1_best.pth.tar")
            print(f"   New Best found: {best_acc:.2f}% (Saved)")

    print(f" Advanced Fine-tuning 完成！最终最佳精度: {best_acc:.2f}%")
    return best_acc

# Start training: 50 epochs of fine-tuning on the CIFAR-10 loaders above.
fine_tune_advanced(model_part1, trainloader, testloader, epochs=50)

当前运行设备: mps
VGG16_Part1 模型构建完成。目标 8x8 层已保留 BN。
=> loading checkpoint 'result/VGG16_quant/model_best.pth.tar'
已加载预训练权重。忽略了 30 个不匹配的参数层。
Epoch [1/50] (LR: 0.01000) -> Test Accuracy: 87.64%
   New Best found: 87.64% (Saved)
Epoch [2/50] (LR: 0.01000) -> Test Accuracy: 89.74%
   New Best found: 89.74% (Saved)
Epoch [3/50] (LR: 0.01000) -> Test Accuracy: 88.58%
Epoch [4/50] (LR: 0.01000) -> Test Accuracy: 88.74%
Epoch [5/50] (LR: 0.01000) -> Test Accuracy: 89.67%
Epoch [6/50] (LR: 0.01000) -> Test Accuracy: 87.74%
Epoch [7/50] (LR: 0.01000) -> Test Accuracy: 89.91%
   New Best found: 89.91% (Saved)
Epoch [8/50] (LR: 0.01000) -> Test Accuracy: 89.45%
Epoch [9/50] (LR: 0.01000) -> Test Accuracy: 89.57%
Epoch [10/50] (LR: 0.01000) -> Test Accuracy: 89.10%
Epoch [11/50] (LR: 0.01000) -> Test Accuracy: 89.77%
Epoch [12/50] (LR: 0.01000) -> Test Accuracy: 89.51%
Epoch [13/50] (LR: 0.01000) -> Test Accuracy: 90.27%
   New Best found: 90.27% (Saved)
Epoch [14/50] (LR: 0.01000) -> Test Accuracy: 89.

92.91

In [2]:
import os
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from models import *  # 确保有 QuantConv2d 定义


# ==========================================
# 1. 设备配置 (MacBook 友好)
# ==========================================
def get_device():
    """Return the first available device among MPS, CUDA, falling back to CPU."""
    # (device name, availability probe) in priority order.
    candidates = (
        ("mps", torch.backends.mps.is_available),   # Apple Silicon GPU
        ("cuda", torch.cuda.is_available),          # NVIDIA GPU
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device("cpu")


# Resolve the compute device once; later code in this cell reads this
# module-level name directly.
device = get_device()
print(f"[Device] 当前运行设备: {device}")


# ==========================================
# 2. 模型定义 (VGG16_Part1，包含 8x8 目标层 + BN)
# ==========================================
class VGG16_Part1(nn.Module):
    """VGG16-style network with an 8-channel squeeze/expand stage in block 5.

    ``features`` is a flat ``nn.Sequential`` built from a layer plan: each
    conv entry expands to QuantConv2d -> BatchNorm2d -> ReLU and each pool
    marker to a 2x2 max-pool. Block 5 holds the 1x1 512->8 adapter, the
    bias-free 3x3 8->8 target conv (followed by its BatchNorm) and the 1x1
    8->512 expansion. A trailing 1x1 average-pool ends the stack, and
    ``classifier`` is ``nn.Linear(512, num_classes)``.

    Modules are created in plan order, so ``features.<index>`` state-dict
    keys follow the positional layout.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        k3 = dict(kernel_size=3, padding=1)
        k1 = dict(kernel_size=1)
        pool = "pool"

        # (in_channels, out_channels, conv kwargs) or the pool marker.
        plan = [
            # --- Block 1 ---
            (3, 64, k3), (64, 64, k3), pool,
            # --- Block 2 ---
            (64, 128, k3), (128, 128, k3), pool,
            # --- Block 3 ---
            (128, 256, k3), (256, 256, k3), (256, 256, k3), pool,
            # --- Block 4 ---
            (256, 512, k3), (512, 512, k3), (512, 512, k3), pool,
            # --- Block 5 (modified region) ---
            (512, 512, k3),
            (512, 8, k1),                    # adapter: 512 -> 8
            (8, 8, dict(k3, bias=False)),    # target 8 -> 8 layer, BN kept
            (8, 512, k1),                    # expand: 8 -> 512
            (512, 512, k3), pool,            # remainder of block 5
        ]

        modules = []
        for entry in plan:
            if isinstance(entry, str):
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            c_in, c_out, conv_kwargs = entry
            modules.append(QuantConv2d(c_in, c_out, **conv_kwargs))
            modules.append(nn.BatchNorm2d(c_out))
            modules.append(nn.ReLU(inplace=True))
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))

        self.features = nn.Sequential(*modules)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Run the feature stack, flatten per sample, and apply the linear head."""
        hidden = self.features(x)
        hidden = hidden.view(hidden.size(0), -1)
        return self.classifier(hidden)


# Instantiate the network with its default class count and move it to the
# selected device.
model_part1 = VGG16_Part1().to(device)
print("[Model] VGG16_Part1 构建完成，8x8 目标层 + BN 就绪。")


# ==========================================
# 3. 加载预训练权重（可选）
#    可以用你 >90% 的 VGG16_quant 结果作为初始化
# ==========================================
def load_pretrained_weights(model, pretrained_path):
    """Initialise ``model`` from a checkpoint wherever key and shape both match.

    The checkpoint is read with ``map_location`` set to the module-level
    ``device`` and must hold its weights under ``'state_dict'``. Entries that
    do not match stay at the model's current (e.g. random) values. If the
    path is not a file, a warning is printed and nothing is loaded.
    """
    if not os.path.isfile(pretrained_path):
        print(f"[WARN] 预训练权重 '{pretrained_path}' 不存在，使用随机初始化。")
        return

    print(f"=> loading checkpoint '{pretrained_path}'")
    ckpt = torch.load(pretrained_path, map_location=device)
    source_state = ckpt['state_dict']
    merged_state = model.state_dict()

    # Only tensors matching by name AND shape are taken from the checkpoint.
    usable = {}
    for key, value in source_state.items():
        if key in merged_state and value.shape == merged_state[key].shape:
            usable[key] = value

    merged_state.update(usable)
    model.load_state_dict(merged_state)
    print(f"=> 预训练权重加载完成，匹配并加载了 {len(usable)} 层参数。")


# Point this path at an existing VGG16 checkpoint of your own if needed.
# load_pretrained_weights prints a warning and leaves the model unchanged
# when the file does not exist.
PRETRAINED_PATH = "result/VGG16_quant/model_best.pth.tar"
load_pretrained_weights(model_part1, PRETRAINED_PATH)

# ==========================================
# 4. CIFAR-10 数据（Mac 友好 num_workers）
# ==========================================
# Per-channel normalisation shared by the train and test pipelines.
normalize = transforms.Normalize(
    mean=[0.491, 0.482, 0.447],
    std=[0.247, 0.243, 0.262]
)

# Training split: random crop (4-pixel padding) and random horizontal flip
# before tensor conversion and normalisation.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
)

# Test split: no augmentation.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
)

# Training batches are shuffled; evaluation order is fixed. Both loaders use
# two worker processes.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2
)

# ==========================================
# 5. (可选) fine-tune，默认不调用
# ==========================================
# Module-level training objects; fine_tune_advanced reads these globals
# rather than receiving them as arguments.
criterion = nn.CrossEntropyLoss().to(device)
# The optimizer is bound to model_part1's parameters specifically.
optimizer = optim.SGD(
    model_part1.parameters(), lr=0.01,
    momentum=0.9, weight_decay=5e-4
)
# Learning rate is multiplied by gamma=0.1 at scheduler steps 20 and 40.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20, 40], gamma=0.1
)


def fine_tune_advanced(model, train_loader, test_loader, epochs=50):
    """Fine-tune ``model`` and checkpoint the weights with the best test accuracy.

    Relies on the module-level ``device``, ``criterion``, ``optimizer`` and
    ``scheduler``. The optimizer must have been built over ``model``'s own
    parameters, otherwise the update step trains a different model.

    Each epoch runs one pass over ``train_loader``, steps the scheduler,
    measures top-1 accuracy on ``test_loader`` and, on a new best, writes
    ``result/vgg_16_part1_best.pth.tar``.

    Args:
        model: Network to train; moved to ``device`` first.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` evaluation batches.
        epochs: Number of epochs to run.

    Returns:
        The best test accuracy seen, in percent.
    """
    best_acc = 0.0
    model.to(device)
    for epoch in range(epochs):
        model.train()
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        # An empty test loader yields 0% instead of a ZeroDivisionError.
        acc = 100. * correct / total if total else 0.0
        print(f"[Epoch {epoch + 1}/{epochs}] LR={current_lr:.5f}, TestAcc={acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            os.makedirs("result", exist_ok=True)
            # Store CPU copies so the checkpoint can be loaded on a machine
            # without this accelerator even when map_location is not given.
            cpu_state = {name: t.detach().cpu() for name, t in model.state_dict().items()}
            torch.save({'state_dict': cpu_state},
                       "result/vgg_16_part1_best.pth.tar")
            print(f"  -> New Best: {best_acc:.2f}% (saved)")
    print(f"[Train Done] Best Acc: {best_acc:.2f}%")
    return best_acc


# 如需重新训练 Part1，可以解开这一行：
# fine_tune_advanced(model_part1, trainloader, testloader, epochs=50)

# ==========================================
# 6. BN 融合 + weight-combine α + 导出 int input/weight/output
# ==========================================
class SaveInput:
    """Callable for ``register_forward_pre_hook`` that records layer inputs.

    Every call appends the detached first positional input to ``inputs``.
    Returning ``None`` from the hook leaves the module's input unchanged.
    """

    def __init__(self):
        self.inputs = list()

    def __call__(self, module, module_in):
        # module_in is the tuple of positional inputs; keep only the first.
        first_arg = module_in[0]
        self.inputs.append(first_arg.detach())

    def clear(self):
        """Drop everything recorded so far by rebinding to a fresh list."""
        self.inputs = list()


def find_8x8_conv_and_bn(model: nn.Module):
    """Locate the first 8-in/8-out QuantConv2d in ``model.features`` and its BatchNorm.

    Only the first matching conv is considered, and the BatchNorm2d must sit
    directly after it.

    Returns:
        ``(conv, bn, conv_index)``.

    Raises:
        RuntimeError: no such conv exists, or the layer right after it is not
            a BatchNorm2d.
    """
    layers = model.features

    # Index of the first conv with 8 input and 8 output channels, if any.
    conv_idx = next(
        (
            pos for pos, layer in enumerate(layers)
            if isinstance(layer, QuantConv2d)
            and layer.in_channels == 8
            and layer.out_channels == 8
        ),
        None,
    )

    successor = None
    if conv_idx is not None and conv_idx + 1 < len(layers):
        successor = layers[conv_idx + 1]

    if conv_idx is None or not isinstance(successor, nn.BatchNorm2d):
        raise RuntimeError("未找到 8x8 QuantConv2d 及其后续 BN(8)，请检查网络结构。")

    print(f"[INFO] Found target 8x8 conv at features[{conv_idx}] and BN at features[{conv_idx + 1}].")
    return layers[conv_idx], successor, conv_idx


def fuse_conv_bn_1x(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Fold an inference-mode BatchNorm2d into the convolution that precedes it.

    With ``s = gamma / sqrt(running_var + eps)`` the fused parameters are
    ``w_fused = w * s`` (per output channel) and
    ``b_fused = (b - running_mean) * s + beta``.

    A conv without bias is treated as having a zero bias. A BatchNorm built
    with ``affine=False`` (``weight``/``bias`` are ``None``) is treated as
    ``gamma = 1`` and ``beta = 0``.

    Args:
        conv: Convolution providing ``weight`` and optionally ``bias``.
        bn: BatchNorm2d providing running statistics and, if affine, its
            scale and shift.

    Returns:
        ``(w_fused, b_fused)`` as detached tensors.

    Raises:
        ValueError: ``bn`` has no running statistics
            (``track_running_stats=False``), so there is nothing to fold.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError("BatchNorm has no running statistics; cannot fuse it into the conv.")

    w = conv.weight.detach()
    if conv.bias is None:
        b = torch.zeros(w.size(0), device=w.device, dtype=w.dtype)
    else:
        b = conv.bias.detach()

    running_mean = bn.running_mean.detach()
    running_var = bn.running_var.detach()
    # affine=False leaves weight/bias as None: identity scale, zero shift.
    gamma = torch.ones_like(running_mean) if bn.weight is None else bn.weight.detach()
    beta = torch.zeros_like(running_mean) if bn.bias is None else bn.bias.detach()

    # Per-output-channel factor shared by the weight and the bias.
    factor = gamma / torch.sqrt(running_var + bn.eps)
    w_fused = w * factor.reshape(-1, 1, 1, 1)
    b_fused = (b - running_mean) * factor + beta
    return w_fused, b_fused


def quantize_unsigned(x: torch.Tensor, bits: int):
    """Quantize ``x`` to unsigned integer codes in ``[0, 2**bits - 1]``.

    Negative entries are clipped to zero first. The step size is the largest
    remaining value (floored at 1e-8) divided by the top code.

    Returns:
        ``(x_int, scale)``: the int32 codes and the 0-dim scale tensor.
    """
    top_code = 2 ** bits - 1
    nonneg = x.clamp(min=0.0)
    step = nonneg.max().clamp(min=1e-8) / top_code
    codes = (nonneg / step).round().clamp(0, top_code)
    return codes.to(torch.int32), step


def quantize_signed(x: torch.Tensor, bits: int):
    """Symmetrically quantize ``x`` to signed codes in ``[-2**(bits-1), 2**(bits-1) - 1]``.

    The step size is the largest absolute value (floored at 1e-8) divided by
    the largest positive code.

    Returns:
        ``(x_int, scale)``: the int32 codes and the 0-dim scale tensor.
    """
    half_range = 2 ** (bits - 1)
    hi = half_range - 1
    lo = -half_range
    step = x.abs().max().clamp(min=1e-8) / hi
    codes = (x / step).round().clamp(lo, hi)
    return codes.to(torch.int32), step


def export_8x8_layer_with_weight_combine_alpha(
        model: nn.Module,
        data_loader,
        device,
        bits_w: int = 4,
        bits_a: int = 4,
        bits_out: int = 16,
        alpha: float = 1.0,
        out_dir: str = "part3_8x8_fused"
):
    """Export integer input / weight / output files for the 8->8 target layer.

    Steps:
      1. Locate the 8->8 conv and its BatchNorm in ``model.features``.
      2. Run the first batch of ``data_loader`` through the model and capture
         the target conv's input with a forward-pre-hook; only the first
         sample of the batch is kept.
      3. Fold the BatchNorm into the conv and scale the fused weight by
         ``alpha``.
      4. Quantize the input (unsigned, ``bits_a``) and the scaled weight
         (signed, ``bits_w``), run a plain Conv2d + ReLU on the de-quantized
         values with the fused bias, and quantize that output (signed,
         ``bits_out``).
      5. Write ``input_int.txt``, ``weight_int.txt``, ``output_int.txt`` and
         ``scales_alpha.txt`` into ``out_dir``.

    NOTE(review): the fused weight comes from the conv's stored ``weight``;
    whether QuantConv2d applies its own quantization in ``forward`` is not
    visible here - confirm the export matches the layer's real arithmetic.

    Returns:
        Dict with the flattened int32 arrays (``input_int``, ``weight_int``,
        ``output_int``), ``alpha`` and the three scales.

    Raises:
        RuntimeError: the target layer is missing or the hook captured nothing.
    """
    model.eval()
    model.to(device)

    # 1) Locate the 8->8 conv and the BatchNorm that follows it.
    conv_8x8, bn_8x8, _ = find_8x8_conv_and_bn(model)

    # 2) Capture the target layer's input. The hook is removed in `finally`
    #    so a failing forward pass cannot leave it attached to the model.
    hook = SaveInput()
    handle = conv_8x8.register_forward_pre_hook(hook)
    try:
        images, _ = next(iter(data_loader))
        images = images.to(device)

        with torch.no_grad():
            _ = model(images)
    finally:
        handle.remove()

    if len(hook.inputs) == 0:
        raise RuntimeError("Hook 没捕获到目标层输入，请检查 hook。")

    # Keep only the first sample of the batch: [1, 8, H, W].
    x_in = hook.inputs[0][0:1].contiguous()
    print(f"[INFO] Captured target layer input shape: {x_in.shape}")

    # 3) Conv + BN fusion.
    w_fused, b_fused = fuse_conv_bn_1x(conv_8x8, bn_8x8)
    print(f"[INFO] Fused weight shape: {w_fused.shape}, bias shape: {b_fused.shape}")

    # 4) Apply the weight-combine factor alpha to the fused weight.
    w_fused_alpha = alpha * w_fused
    print(f"[INFO] Apply weight-combine alpha = {alpha}")

    # 5) Quantize activation and weight.
    x_int, scale_a = quantize_unsigned(x_in, bits=bits_a)
    w_int, scale_w = quantize_signed(w_fused_alpha, bits=bits_w)

    print(f"[INFO] scale_a={scale_a:.6e}, scale_w={scale_w:.6e}")

    # De-quantize back to real values for the reference convolution.
    x_q = (x_int.float() * scale_a).to(device)
    w_q = (w_int.float() * scale_w).to(device)
    b_q = b_fused.to(device)

    # 6) Build an equivalent plain Conv2d and run it.
    conv_sim = nn.Conv2d(
        in_channels=8,
        out_channels=8,
        kernel_size=3,
        padding=1,
        bias=True
    ).to(device)

    with torch.no_grad():
        conv_sim.weight.copy_(w_q)
        conv_sim.bias.copy_(b_q)

        y_ref = conv_sim(x_q)
        y_ref = torch.relu(y_ref)
        y_int, scale_y = quantize_signed(y_ref, bits=bits_out)

    print(f"[INFO] Output shape: {y_ref.shape}, scale_y={scale_y:.6e}")

    # 7) Export as integer text files, one value per line.
    os.makedirs(out_dir, exist_ok=True)

    input_np = x_int.cpu().numpy().astype(np.int32).ravel()
    weight_np = w_int.cpu().numpy().astype(np.int32).ravel()
    output_np = y_int.cpu().numpy().astype(np.int32).ravel()

    np.savetxt(os.path.join(out_dir, "input_int.txt"), input_np, fmt='%d')
    np.savetxt(os.path.join(out_dir, "weight_int.txt"), weight_np, fmt='%d')
    np.savetxt(os.path.join(out_dir, "output_int.txt"), output_np, fmt='%d')

    with open(os.path.join(out_dir, "scales_alpha.txt"), "w") as f:
        f.write(f"alpha_weight_combine = {alpha}\n")
        f.write(f"scale_activation    = {scale_a:.8e}\n")
        f.write(f"scale_weight        = {scale_w:.8e}\n")
        f.write(f"scale_output        = {scale_y:.8e}\n")

    print("====================================================")
    print(f"[DONE] Exported to folder: {out_dir}")
    print("  - input_int.txt  (len = {})".format(len(input_np)))
    print("  - weight_int.txt (len = {})".format(len(weight_np)))
    print("  - output_int.txt (len = {})".format(len(output_np)))
    print("  - scales_alpha.txt")
    print("====================================================")

    return {
        "input_int": input_np,
        "weight_int": weight_np,
        "output_int": output_np,
        "alpha": alpha,
        "scale_a": scale_a,
        "scale_w": scale_w,
        "scale_y": scale_y,
    }


# ==========================================
# 7. 实际调用一次，输出“最后的结果”
# ==========================================
# Put the model in inference mode on the selected device before exporting.
model_part1.eval()
model_part1.to(device)

# Export the target layer with 4-bit weights/activations and a 16-bit output,
# using the first batch of the test loader.
export_info = export_8x8_layer_with_weight_combine_alpha(
    model_part1,
    testloader,
    device,
    bits_w=4,
    bits_a=4,
    bits_out=16,
    alpha=0.8,  # the weight-combine factor under experiment
    out_dir="part3_8x8_alpha0.8"
)

# Print the scales and the first 16 entries of each exported array.
print("\n======== 最终导出结果 Summary ========")
print(f"alpha_weight_combine = {export_info['alpha']}")
print(f"scale_a (activation) = {export_info['scale_a']:.6e}")
print(f"scale_w (weight)     = {export_info['scale_w']:.6e}")
print(f"scale_y (output)     = {export_info['scale_y']:.6e}")
print("input_int[0:16]  =", export_info["input_int"][:16])
print("weight_int[0:16] =", export_info["weight_int"][:16])
print("output_int[0:16] =", export_info["output_int"][:16])
print("======================================")


[Device] 当前运行设备: mps
[Model] VGG16_Part1 构建完成，8x8 目标层 + BN 就绪。
=> loading checkpoint 'result/VGG16_quant/model_best.pth.tar'
=> 预训练权重加载完成，匹配并加载了 107 层参数。
[INFO] Found target 8x8 conv at features[40] and BN at features[41].
[INFO] Captured target layer input shape: torch.Size([1, 8, 2, 2])
[INFO] Fused weight shape: torch.Size([8, 8, 3, 3]), bias shape: torch.Size([8])
[INFO] Apply weight-combine alpha = 0.8
[INFO] scale_a=1.880996e-01, scale_w=1.344979e-02
[INFO] Output shape: torch.Size([1, 8, 2, 2]), scale_y=1.412921e-05
[DONE] Exported to folder: part3_8x8_alpha0.8
  - input_int.txt  (len = 32)
  - weight_int.txt (len = 576)
  - output_int.txt (len = 32)
  - scales_alpha.txt

alpha_weight_combine = 0.8
scale_a (activation) = 1.880996e-01
scale_w (weight)     = 1.344979e-02
scale_y (output)     = 1.412921e-05
input_int[0:16]  = [0 0 0 0 2 1 0 0 0 0 7 5 0 0 0 0]
weight_int[0:16] = [ 2 -5 -3  5  4 -6 -2  1 -3 -5  4  3  4 -4 -1  5]
output_int[0:16] = [    0  3939  3044     0 12176     0

In [3]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# 如果前面已经定义过 SaveInput 可以删掉这个重复定义
class SaveInput:
    """Forward-pre-hook callable that keeps a detached copy of each layer input.

    Each invocation appends the first positional input (detached) to
    ``inputs``; the hook returns ``None`` so the module input is not altered.
    """

    def __init__(self):
        self.inputs = list()

    def __call__(self, module, module_in):
        # module_in is a tuple of positional inputs; record element 0 only.
        captured = module_in[0].detach()
        self.inputs.append(captured)

    def clear(self):
        """Forget all recorded inputs by rebinding to a new empty list."""
        self.inputs = list()


def find_8x8_conv_and_bn(model: nn.Module):
    """Find the first 8-in/8-out QuantConv2d in ``model.features`` and the BatchNorm after it.

    Only the first matching conv is examined; the BatchNorm2d has to be the
    very next layer. The match reads the layer's ``in_channels`` and
    ``out_channels`` attributes.

    Returns:
        ``(conv, bn, conv_index)``.

    Raises:
        RuntimeError: no matching conv exists, or it is not immediately
            followed by a BatchNorm2d.
    """
    sequence = model.features

    def is_target(layer):
        return (
            isinstance(layer, QuantConv2d)
            and layer.in_channels == 8
            and layer.out_channels == 8
        )

    conv_idx = next((pos for pos, layer in enumerate(sequence) if is_target(layer)), None)

    # The layer directly after the conv must be its BatchNorm.
    has_next = conv_idx is not None and conv_idx + 1 < len(sequence)
    candidate = sequence[conv_idx + 1] if has_next else None

    if conv_idx is None or not isinstance(candidate, nn.BatchNorm2d):
        raise RuntimeError("未找到 8x8 的 QuantConv2d + BN(8)，请检查 VGG16_Part1 的 Block5 结构。")

    print(f"[INFO] Found target 8x8 conv at features[{conv_idx}] and BN at features[{conv_idx + 1}].")
    return sequence[conv_idx], candidate, conv_idx


def fuse_conv_bn_1x(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Standard Conv+BN fusion: return the equivalent ``(w_fused, b_fused)``.

    With ``s = gamma / sqrt(running_var + eps)`` the fused parameters are
    ``w_fused = w * s`` (per output channel) and
    ``b_fused = (b - running_mean) * s + beta``.

    A conv without bias is treated as having a zero bias. A BatchNorm built
    with ``affine=False`` (``weight``/``bias`` are ``None``) is treated as
    ``gamma = 1`` and ``beta = 0``.

    Args:
        conv: Convolution providing ``weight`` and optionally ``bias``.
        bn: BatchNorm2d providing running statistics and, if affine, its
            scale and shift.

    Returns:
        ``(w_fused, b_fused)`` as detached tensors.

    Raises:
        ValueError: ``bn`` has no running statistics
            (``track_running_stats=False``), so there is nothing to fold.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError("BatchNorm has no running statistics; cannot fuse it into the conv.")

    w = conv.weight.detach()
    if conv.bias is None:
        b = torch.zeros(w.size(0), device=w.device, dtype=w.dtype)
    else:
        b = conv.bias.detach()

    running_mean = bn.running_mean.detach()
    running_var = bn.running_var.detach()
    # affine=False leaves weight/bias as None: identity scale, zero shift.
    gamma = torch.ones_like(running_mean) if bn.weight is None else bn.weight.detach()
    beta = torch.zeros_like(running_mean) if bn.bias is None else bn.bias.detach()

    # Per-output-channel factor shared by the weight and the bias.
    factor = gamma / torch.sqrt(running_var + bn.eps)
    w_fused = w * factor.reshape(-1, 1, 1, 1)
    b_fused = (b - running_mean) * factor + beta
    return w_fused, b_fused


def quantize_unsigned(x: torch.Tensor, bits: int):
    """Quantize to unsigned codes in ``[0, 2**bits - 1]``; return the codes and the scale.

    Negative entries are clipped to zero before scaling. The scale is the
    largest clipped value (floored at 1e-8) divided by the top code. Codes
    are returned as int32.
    """
    levels = 2 ** bits - 1
    clipped = x.clamp(min=0.0)
    scale = clipped.max().clamp(min=1e-8) / levels
    as_codes = (clipped / scale).round().clamp(0, levels)
    return as_codes.to(torch.int32), scale


def quantize_signed(x: torch.Tensor, bits: int):
    """Symmetric signed quantization to ``[-2**(bits-1), 2**(bits-1) - 1]``.

    The scale is the largest absolute value (floored at 1e-8) divided by the
    largest positive code. Returns the int32 codes and the scale.
    """
    magnitude = 2 ** (bits - 1)
    upper = magnitude - 1
    lower = -magnitude
    scale = x.abs().max().clamp(min=1e-8) / upper
    as_codes = (x / scale).round().clamp(lower, upper)
    return as_codes.to(torch.int32), scale


def int_to_bin_unsigned(val: int, bits: int) -> str:
    """Render ``val`` as a zero-padded unsigned binary string of width ``bits``.

    The value is truncated to an int and saturated to ``[0, 2**bits - 1]``.
    """
    ceiling = (1 << bits) - 1
    saturated = min(max(int(val), 0), ceiling)
    return f"{saturated & ceiling:0{bits}b}"


def int_to_bin_twos_complement(val: int, bits: int) -> str:
    """Render a signed integer as a ``bits``-wide two's-complement binary string.

    The value is truncated to an int and saturated to
    ``[-2**(bits-1), 2**(bits-1) - 1]`` before encoding.
    """
    modulus = 1 << bits
    highest = 2 ** (bits - 1) - 1
    lowest = -2 ** (bits - 1)
    saturated = max(lowest, min(int(val), highest))
    # Negative values wrap around the modulus to get their bit pattern.
    encoded = saturated + modulus if saturated < 0 else saturated
    return f"{encoded & (modulus - 1):0{bits}b}"


def export_hw7_style_files_with_alpha(
        model: nn.Module,
        data_loader,
        device,
        alpha: float = 1.0,
        bits_a: int = 4,
        bits_w: int = 4,
        bits_psum: int = 16,
        nij_start: int = 0
):
    """Write activation / weight / psum bit-string files for the 8->8 target layer.

    Steps:
      1. Capture the input of the 8->8 Conv+BN pair for the first sample of
         the first batch of ``data_loader``.
      2. Fold the BatchNorm into the conv to get ``w_fused`` / ``b_fused``.
      3. Multiply ``w_fused`` by the weight-combine factor ``alpha``.
      4. Quantize the activation (unsigned, ``bits_a``) and the weight
         (signed, ``bits_w``).
      5. Write, into the current working directory:
           - ``activation.txt``: one line per time step, rows 7 -> 0, each
             value as ``bits_a`` unsigned bits.
           - ``weight.txt``: one line per column 0..7, rows 7 -> 0 within a
             line, each value as ``bits_w`` two's-complement bits.
           - ``psum.txt``: one line per time step, columns 7 -> 0, each value
             as ``bits_psum`` two's-complement bits.
           - ``scales_alpha.txt``: alpha and the two scales.
         The partial sums are the 8x8 weight slice at kernel position
         ``kij = 0`` multiplied with the 8-channel activation vector of each
         time step. Time steps index the flattened zero-padded input,
         starting at ``nij_start``; at most 64 are used.

    Returns:
        Dict with ``alpha``, ``scale_a``, ``scale_w`` and ``time_steps``.

    Raises:
        RuntimeError: the target layer is missing, the hook captured nothing,
            or ``nij_start`` plus the time-step count exceeds the number of
            padded positions.
    """
    model.eval()
    model.to(device)

    # 1) Locate the 8->8 conv and its BatchNorm.
    conv_8x8, bn_8x8, _ = find_8x8_conv_and_bn(model)

    # 2) Capture this layer's input. The hook is removed in `finally` so a
    #    failing forward pass cannot leave it attached to the model.
    hook = SaveInput()
    handle = conv_8x8.register_forward_pre_hook(hook)
    try:
        images, _ = next(iter(data_loader))
        images = images.to(device)

        with torch.no_grad():
            _ = model(images)
    finally:
        handle.remove()

    if len(hook.inputs) == 0:
        raise RuntimeError("Hook 没有捕获到 8x8 层的输入，请检查 hook 注册位置。")

    # First sample of the batch: shape [1, 8, H, W].
    x_in = hook.inputs[0][0:1].contiguous()
    _, C, height, width = x_in.shape
    assert C == 8, f"目标层输入通道数不是 8，而是 {C}"

    print(f"[INFO] Captured target layer input x_in shape: {x_in.shape}")

    # 3) Conv + BN fusion.
    w_fused, b_fused = fuse_conv_bn_1x(conv_8x8, bn_8x8)
    print(f"[INFO] Fused weight shape: {w_fused.shape}, bias shape: {b_fused.shape}")

    # 4) Apply the weight-combine factor alpha.
    w_fused_alpha = alpha * w_fused
    print(f"[INFO] Apply weight-combine alpha = {alpha}")

    # 5) Quantize activation and weight to integer tensors.
    x_int, scale_a = quantize_unsigned(x_in, bits=bits_a)
    w_int, scale_w = quantize_signed(w_fused_alpha, bits=bits_w)

    print(f"[INFO] scale_a (activation) = {scale_a:.6e}, scale_w (weight) = {scale_w:.6e}")

    # ==== Build the activation tile, weight tile and partial sums ====
    padding = 1
    array_size = 8

    # a_int: [8, H, W] unsigned integer codes.
    a_int = x_int[0]
    # Zero-pad spatially, then flatten the padded positions.
    a_pad = torch.zeros(C, height + 2 * padding, width + 2 * padding,
                        device=a_int.device, dtype=a_int.dtype)
    a_pad[:, padding:padding + height, padding:padding + width] = a_int
    a_pad_flat = a_pad.view(C, -1)  # [8, P]
    P = a_pad_flat.size(1)
    print(f"[INFO] a_pad shape: {a_pad.shape}, flattened positions P = {P}")

    # Only one input-channel tile exists here.
    a_tile = a_pad_flat.unsqueeze(0)  # [1, 8, P]

    # Use up to 64 time steps, fewer when there are fewer padded positions.
    time_steps = min(64, P)
    if nij_start + time_steps > P:
        raise RuntimeError(f"nij_start({nij_start}) + time_steps({time_steps}) 超过了 a_pad_flat 的长度 {P}。")
    # X: [8, time_steps]
    X = a_tile[0, :, nij_start:nij_start + time_steps].to(torch.int32)

    print(f"[INFO] Will use nij in [{nij_start}, {nij_start + time_steps}) as time steps, X shape: {X.shape}")

    # Weights: [8, 8, 3, 3] -> [out_c=8, in_c=8, kij=9]
    w_int_flat = w_int.view(8, 8, -1)
    kij = 0  # only the first kernel position is exported
    # Named w_mat (not W) so it cannot be confused with the spatial width.
    w_mat = w_int_flat[:, :, kij].to(torch.int32)  # [8, 8]
    print(f"[INFO] Selected kij={kij}, W shape: {w_mat.shape}")

    # Partial sums for that kernel position: [8, 8] @ [8, T] -> [8, T].
    psum_tile = w_mat @ X
    print(f"[INFO] psum_tile shape: {psum_tile.shape}")

    # Pull the integer data to Python lists once instead of one .item()
    # device sync per element.
    x_vals = X.tolist()
    w_vals = w_mat.tolist()
    p_vals = psum_tile.tolist()
    last = array_size - 1

    # ==== Write activation.txt / weight.txt / psum.txt ====

    # ---------- activation.txt ----------
    with open('activation.txt', 'w') as f_act:
        f_act.write('#time0row7[msb-lsb],time0row6[msb-lst],....,time0row0[msb-lst]#\n')
        f_act.write('#time1row7[msb-lsb],time1row6[msb-lst],....,time1row0[msb-lst]#\n')
        f_act.write('#................#\n')

        for t in range(time_steps):
            # One line per time step, rows emitted from 7 down to 0.
            line = "".join(
                int_to_bin_unsigned(x_vals[last - j][t], bits_a)
                for j in range(array_size)
            )
            f_act.write(line)
            f_act.write('\n')

    print("[SAVE] activation.txt generated.")

    # ---------- weight.txt ----------
    with open('weight.txt', 'w') as f_w:
        f_w.write('#col0row7[msb-lsb],col0row6[msb-lst],....,col0row0[msb-lst]#\n')
        f_w.write('#col1row7[msb-lsb],col1row6[msb-lst],....,col1row0[msb-lst]#\n')
        f_w.write('#................#\n')

        for col in range(array_size):
            # One line per column, rows emitted from 7 down to 0.
            line = "".join(
                int_to_bin_twos_complement(w_vals[col][last - row], bits_w)
                for row in range(array_size)
            )
            f_w.write(line)
            f_w.write('\n')

    print("[SAVE] weight.txt generated.")

    # ---------- psum.txt ----------
    with open('psum.txt', 'w') as f_p:
        f_p.write('#time0col7[msb-lsb],time0col6[msb-lst],....,time0col0[msb-lst]#\n')
        f_p.write('#time1col7[msb-lsb],time1col6[msb-lst],....,time1col0[msb-lst]#\n')
        f_p.write('#................#\n')

        for t in range(time_steps):
            # One line per time step, columns emitted from 7 down to 0.
            line = "".join(
                int_to_bin_twos_complement(p_vals[last - col][t], bits_psum)
                for col in range(array_size)
            )
            f_p.write(line)
            f_p.write('\n')

    print("[SAVE] psum.txt generated.")

    # Also store alpha and the scales, handy for debugging.
    with open('scales_alpha.txt', 'w') as f_s:
        f_s.write(f"alpha_weight_combine = {alpha}\n")
        f_s.write(f"scale_activation     = {scale_a:.8e}\n")
        f_s.write(f"scale_weight         = {scale_w:.8e}\n")

    print("==============================================")
    print("Export summary:")
    print(f"  alpha = {alpha}")
    print(f"  bits_a (activation) = {bits_a}, bits_w (weight) = {bits_w}, bits_psum = {bits_psum}")
    print(f"  time_steps (nij count) = {time_steps}")
    print("  Files: activation.txt, weight.txt, psum.txt, scales_alpha.txt")
    print("==============================================")

    return {
        "alpha": alpha,
        "scale_a": scale_a,
        "scale_w": scale_w,
        "time_steps": time_steps,
    }


# ===== Run the export once =====
# model_part1, testloader and device must already be defined by earlier cells.
model_part1.eval()
model_part1.to(device)

info = export_hw7_style_files_with_alpha(
    model_part1,
    testloader,
    device,
    alpha=0.8,  # the weight-combine factor to use
    bits_a=4,
    bits_w=4,
    bits_psum=16,
    nij_start=0  # may be raised (e.g. to 200) only if P is large enough
)

# Echo the returned alpha / scale information.
print("\n[INFO] α / scale 信息：")
print("  alpha =", info["alpha"])
print("  scale_a =", info["scale_a"])
print("  scale_w =", info["scale_w"])
print("  time_steps =", info["time_steps"])


[INFO] Found target 8x8 conv at features[40] and BN at features[41].
[INFO] Captured target layer input x_in shape: torch.Size([1, 8, 2, 2])
[INFO] Fused weight shape: torch.Size([8, 8, 3, 3]), bias shape: torch.Size([8])
[INFO] Apply weight-combine alpha = 0.8
[INFO] scale_a (activation) = 1.880996e-01, scale_w (weight) = 1.344979e-02
[INFO] a_pad shape: torch.Size([8, 4, 4]), flattened positions P = 16
[INFO] Will use nij in [0, 16) as time steps, X shape: torch.Size([8, 16])
[INFO] Selected kij=0, W shape: torch.Size([8, 8])
[INFO] psum_tile shape: torch.Size([8, 16])
[SAVE] activation.txt generated.
[SAVE] weight.txt generated.
[SAVE] psum.txt generated.
Export summary:
  alpha = 0.8
  bits_a (activation) = 4, bits_w (weight) = 4, bits_psum = 16
  time_steps (nij count) = 16
  Files: activation.txt, weight.txt, psum.txt, scales_alpha.txt

[INFO] α / scale 信息：
  alpha = 0.8
  scale_a = tensor(0.1881, device='mps:0')
  scale_w = tensor(0.0134, device='mps:0')
  time_steps = 16
