<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/%E2%80%9CWide_Resnet_28_10_Relu6_5bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from wide_resnet import Wide_ResNet
import torch.quantization as quantization
from torch.quantization import FakeQuantize, default_qconfig, QuantStub, DeQuantStub
from torch.quantization.qconfig import QConfig




def main():
    """Train Wide-ResNet-28-10 on CIFAR-10 with low-bit quantization-aware training.

    The function builds the CIFAR-10 data loaders, wraps the model with
    fake-quantization modules restricted to ``num_bits`` bits, trains it with
    SGD and a step learning-rate schedule, converts the trained model to a
    quantized model and saves its ``state_dict`` to
    ``WRN_Relu6_{num_bits}bit.pth`` in the current working directory.

    All hyper-parameters are local constants at the top of the function.
    """
    batch_size = 128
    learning_rate = 0.1
    epochs = 200
    weight_decay = 0.0005
    momentum = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_bits = 5

    # Data preprocessing. Random augmentation is applied to the training split
    # only; the test split must be deterministic so that the reported accuracy
    # is comparable from epoch to epoch.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load the CIFAR-10 dataset.
    train_dataset = datasets.CIFAR10(root="./data", train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=test_transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Create the Wide ResNet model.
    model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0).to(device)

    # Build the low-bit quantization config. FakeQuantize has no ``num_bits``
    # argument; the bit width is expressed through the integer range
    # [quant_min, quant_max] that values are clamped to.
    act_quant_min = 0
    act_quant_max = 2 ** num_bits - 1
    weight_quant_min = -(2 ** (num_bits - 1))
    weight_quant_max = 2 ** (num_bits - 1) - 1
    five_bit_qconfig = QConfig(
        activation=FakeQuantize.with_args(
            observer=torch.quantization.MinMaxObserver,
            quant_min=act_quant_min,
            quant_max=act_quant_max,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
        ),
        weight=FakeQuantize.with_args(
            observer=torch.quantization.MinMaxObserver,
            quant_min=weight_quant_min,
            quant_max=weight_quant_max,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
        ),
    )

    # Prepare for QAT with the low-bit config (not the default 8-bit one).
    model.qconfig = five_bit_qconfig
    qat_model = quantization.prepare_qat(model, inplace=False).to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        qat_model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        weight_decay=weight_decay,
    )

    # Learning-rate scheduler: multiply the LR by 0.1 after epochs 60 and 120.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120], gamma=0.1)

    def train_epoch(model, dataloader, criterion, optimizer, device):
        """Run one training pass over ``dataloader`` and return the mean batch loss."""
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        return running_loss / len(dataloader)

    def test(model, dataloader, device):
        """Return the top-1 accuracy (a fraction in [0, 1]) of ``model`` on ``dataloader``."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total

    # Training loop.
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(qat_model, train_loader, criterion, optimizer, device)
        test_accuracy = test(qat_model, test_loader, device)
        scheduler.step()
        print(f"Epoch: {epoch}, Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy * 100:.2f}%")

        # After epochs 4, 9, 14, ...: freeze the observer statistics while
        # keeping fake quantization active.
        if epoch % 5 == 4:
            qat_model.apply(quantization.disable_observer)
            qat_model.apply(quantization.enable_fake_quant)

        # After epochs 9, 19, 29, ...: unfreeze the observer statistics. These
        # epochs also satisfy the condition above, so the net effect is that
        # observers are frozen for epochs 5-9, 15-19, ... and active otherwise.
        if epoch % 10 == 9:
            qat_model.apply(quantization.enable_observer)

    # After training, convert the QAT model into a quantized model. Quantized
    # kernels run on CPU only, and conversion expects an eval-mode model.
    qat_model.cpu().eval()
    quantized_model = quantization.convert(qat_model, inplace=False)

    # Save the final quantized model.
    quantized_model_path = f"WRN_Relu6_{num_bits}bit.pth"
    torch.save(quantized_model.state_dict(), quantized_model_path)
    print(f"{num_bits}-bit Relu6 Quantized WRN model saved as {quantized_model_path}")

# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    main()


Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10
Is GPU available? True
Current device: 0




Epoch: 1, Loss: 1.6921, Test Accuracy: 49.58%
Epoch: 2, Loss: 1.1754, Test Accuracy: 59.02%


KeyboardInterrupt: ignored

In [None]:
# Print the installed PyTorch version string.
import torch
print(torch.__version__)

2.0.0+cu118


In [None]:
!git clone https://github.com/Zhaogui/modules.git

Cloning into 'modules'...
fatal: could not read Username for 'https://github.com': No such device or address
