<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/Wide_Resnet_28_10_1_5bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from wide_resnet import Wide_ResNet
import copy

def quantize_tensor(tensor, n_bits):
    """Simulate uniform affine quantization of ``tensor`` to ``n_bits`` bits.

    The tensor's ``[min, max]`` range is mapped linearly onto the signed
    integer grid ``[-(2**(n_bits - 1)), 2**(n_bits - 1) - 1]``, every value is
    rounded to the nearest grid point, and the result is mapped back to the
    original floating-point range ("fake quantization").

    Args:
        tensor (torch.Tensor): Values to quantize.
        n_bits (int): Number of quantization bits; expected to be >= 1.

    Returns:
        torch.Tensor: A new tensor with the same shape as ``tensor`` whose
        values take at most ``2**n_bits`` distinct levels spanning
        ``[tensor.min(), tensor.max()]``. Empty and constant tensors are
        returned as an unchanged copy.
    """
    # An empty tensor has no min/max to derive a range from.
    if tensor.numel() == 0:
        return tensor.clone()

    qmin = -(2 ** (n_bits - 1))
    qmax = qmin + 2 ** n_bits - 1

    t_min = tensor.min()
    scale = (tensor.max() - t_min) / (qmax - qmin)
    # A constant tensor has zero range; dividing by a zero scale would
    # produce inf/NaN, and a single value needs no quantization anyway.
    if scale == 0:
        return tensor.clone()

    zero_point = qmin - t_min / scale
    quantized = torch.round(torch.clamp(tensor / scale + zero_point, qmin, qmax))
    # Invert q = x / scale + zero_point: subtract the zero point BEFORE
    # rescaling so the output lands back in the input's value range.
    return (quantized - zero_point) * scale

def quantize_grads(model_copy, n_bits):
    """Replace each existing parameter gradient with its quantized version.

    Args:
        model_copy (nn.Module): Model whose parameter gradients are rewritten
            in place; parameters without a gradient are skipped.
        n_bits (int): Bit-width forwarded to ``quantize_tensor``.
    """
    params_with_grad = (
        p for p in model_copy.parameters() if p.grad is not None
    )
    for p in params_with_grad:
        grad = p.grad
        grad.data = quantize_tensor(grad.data, n_bits)

def main():
    """Train five Wide-ResNet-28-10 copies on CIFAR-10, one per bit-width 1..5.

    Each copy starts from the same initial weights and is trained and
    evaluated on inputs passed through ``quantize_tensor`` with its own
    bit-width. Every copy has its own SGD optimizer and MultiStepLR scheduler.
    After training, each copy's ``state_dict`` is written to
    ``wide_resnet_{n_bits}_bit.pth``.
    """
    batch_size = 128
    learning_rate = 0.1
    epochs = 200
    weight_decay = 0.0005
    momentum = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data preprocessing. Random augmentation is applied to training data
    # only; the test set gets a deterministic pipeline so that accuracy is
    # reproducible and comparable between epochs.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the CIFAR-10 dataset.
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=test_transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Create the Wide ResNet model that serves as the shared initialization.
    model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0).to(device)

    # Loss function.
    criterion = nn.CrossEntropyLoss()

    # Create five model copies (bit-widths 1..5), all with identical
    # initial weights.
    model_copies = [copy.deepcopy(model) for _ in range(5)]

    # One optimizer per copy: an optimizer only updates the parameters it was
    # constructed with, so sharing a single optimizer built from `model`
    # would leave every copy untrained.
    optimizers = [
        optim.SGD(
            model_copy.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        for model_copy in model_copies
    ]

    # One learning-rate scheduler per optimizer, wrapping the optimizer that
    # actually performs the updates.
    lr_schedulers = [
        optim.lr_scheduler.MultiStepLR(opt, milestones=[60, 120], gamma=0.1)
        for opt in optimizers
    ]

    # Training and evaluation helpers.
    def train_epoch(model, dataloader, criterion, optimizer, device, bit):
        """
        Train the model for one epoch and return the average loss.

        Args:
            model (nn.Module): Model to train.
            dataloader (DataLoader): DataLoader over the training set.
            criterion (nn.Module): Loss function.
            optimizer (optim.Optimizer): Optimizer that owns ``model``'s
                parameters.
            device (torch.device): Device to run on.
            bit (int): Number of quantization bits applied to the inputs.

        Returns:
            float: Loss averaged over the batches of the epoch.
        """
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Quantize the input batch (not the gradients) to `bit` bits.
            inputs = quantize_tensor(inputs, bit)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        return running_loss / len(dataloader)

    def test(model, dataloader, criterion, device, bit):
        """
        Evaluate the model once and return its accuracy.

        Args:
            model (nn.Module): Model to evaluate.
            dataloader (DataLoader): DataLoader over the test set.
            criterion (nn.Module): Loss function (unused; kept for a
                signature parallel to ``train_epoch``).
            device (torch.device): Device to run on.
            bit (int): Number of quantization bits applied to the inputs.

        Returns:
            float: Fraction of correctly classified samples.
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)

                # Quantize the input batch (not the gradients) to `bit` bits.
                inputs = quantize_tensor(inputs, bit)

                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total

    # Training loop.
    for epoch in range(1, epochs + 1):
        for n_bits, (model_copy, optimizer) in enumerate(zip(model_copies, optimizers), start=1):
            # Train first, then evaluate, so the reported accuracy belongs to
            # the same epoch as the reported loss.
            train_loss = train_epoch(model_copy, train_loader, criterion, optimizer, device, n_bits)
            test_accuracy = test(model_copy, test_loader, criterion, device, n_bits)
            print(f"Epoch: {epoch}, {n_bits}-bit Model Loss: {train_loss:.4f}, Accuracy: {test_accuracy * 100:.2f}%")

        # Advance every learning-rate scheduler once per epoch.
        for scheduler in lr_schedulers:
            scheduler.step()

    # After training, save one checkpoint per bit-width.
    # NOTE: only the inputs were quantized during training; the saved weights
    # themselves are stored in full precision.
    for n_bits, model_copy in enumerate(model_copies, start=1):
        quantized_model_path = f"wide_resnet_{n_bits}_bit.pth"
        torch.save(model_copy.state_dict(), quantized_model_path)
        print(f"{n_bits}-bit Quantized model saved as {quantized_model_path}")

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()


Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10
Is GPU available? True
Current device: 0
Epoch: 1, 1-bit Model Loss: 2.3309, Accuracy: 10.19%
Epoch: 1, 2-bit Model Loss: 2.3310, Accuracy: 9.73%
Epoch: 1, 3-bit Model Loss: 2.3324, Accuracy: 10.31%
Epoch: 1, 4-bit Model Loss: 2.3323, Accuracy: 10.42%
Epoch: 1, 5-bit Model Loss: 2.3324, Accuracy: 10.43%
Epoch: 2, 1-bit Model Loss: 2.3307, Accuracy: 10.00%
Epoch: 2, 2-bit Model Loss: 2.3317, Accuracy: 10.01%
Epoch: 2, 3-bit Model Loss: 2.3327, Accuracy: 10.00%
Epoch: 2, 4-bit Model Loss: 2.3327, Accuracy: 10.00%
Epoch: 2, 5-bit Model Loss: 2.3322, Accuracy: 10.00%
Epoch: 3, 1-bit Model Loss: 2.3309, Accuracy: 10.00%
Epoch: 3, 2-bit Model Loss: 2.3312, Accuracy: 10.00%
Epoch: 3, 3-bit Model Loss: 2.3326, Accuracy: 10.00%
Epoch: 3, 4-bit Model Loss: 2.3328, Accuracy: 10.00%
Epoch: 3, 5-bit Model Loss: 2.3327, Accuracy: 9.99%
Epoch: 4, 1-bit Model Loss: 2.3308, Accuracy: 10.00%
Epoch: 4, 2-bit M