<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/%E2%80%9CWide_Resnet_28_10_Relu6_5bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from WRN_Relu import Wide_ResNet
import torch.quantization as quantization
from torch.quantization import FakeQuantize, default_qconfig, QuantStub, DeQuantStub
from torch.quantization.qconfig import QConfig
from torch.quantization import MinMaxObserver

class CustomMinMaxObserver(MinMaxObserver):
    def __init__(self, num_bits, **kwargs):
        super().__init__(**kwargs)
        self.num_bits = num_bits
        self.qmin = 0
        self.qmax = 2 ** num_bits - 1

    def forward(self, x):
        self.min_val = torch.min(x)
        self.max_val = torch.max(x)

        scale, zero_point = self.calculate_qparams()
        new_min = torch.min(torch.tensor([self.min_val, scale * (self.qmin - zero_point)]))
        new_max = torch.max(torch.tensor([self.max_val, scale * (self.qmax - zero_point)]))

        self.min_val = torch.min(new_min, self.min_val)
        self.max_val = torch.max(new_max, self.max_val)
        return x

def main():

    batch_size = 128
    learning_rate = 0.1
    epochs = 1
    weight_decay = 0.0005
    momentum = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_bits = 8

    # 数据预处理
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # 加载CIFAR-10数据集
    full_train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

    # 将训练数据集分为训练集、验证集和测试集
    train_ratio = 0.7
    validation_ratio = 0.1
    test_ratio = 0.2
    num_train_samples = int(len(full_train_dataset) * train_ratio)
    num_validation_samples = int(len(full_train_dataset) * validation_ratio)
    num_test_samples = len(full_train_dataset) - num_train_samples - num_validation_samples

    train_dataset, validation_dataset, test_dataset_from_train = random_split(full_train_dataset, [num_train_samples, num_validation_samples, num_test_samples])
    test_dataset = torch.utils.data.ConcatDataset([test_dataset, test_dataset_from_train])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    # 创建Wide ResNet模型
    model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.3).to(device)

    # 创建5位量化配置
    five_bit_qconfig = QConfig(
        activation=FakeQuantize.with_args(observer=CustomMinMaxObserver, dtype=torch.quint8, qscheme=torch.per_tensor_affine, num_bits=5),
        weight=FakeQuantize.with_args(observer=CustomMinMaxObserver, dtype=torch.qint8, qscheme=torch.per_tensor_affine, num_bits=5)
    )

    # 准备QAT
    model.qconfig = five_bit_qconfig
    qat_model = quantization.prepare_qat(model, inplace=False).to(device)

    # 设置损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(qat_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    # 设置学习率调度器
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120], gamma=0.1)

    # 训练和测试函数
    def train_epoch(model, dataloader, criterion, optimizer, device):
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        return running_loss / len(dataloader)


    def test(model, dataloader, criterion, device):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total


    # 训练循环
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(qat_model, train_loader, criterion, optimizer, device)
        validation_accuracy = test(qat_model, validation_loader, criterion, device)
        scheduler.step()
        print(f"Epoch: {epoch}, Loss: {train_loss:.4f}, Validation Accuracy: {validation_accuracy * 100:.2f}%")

        # 每隔一定的epoch数量（如5个），冻结统计数据并执行一次量化
        if epoch % 5 == 4:
            qat_model.apply(quantization.disable_observer)
            qat_model.apply(quantization.enable_fake_quant)

        # 每隔一定的epoch数量（如10个），解冻统计数据
        if epoch % 10 == 9:
            qat_model.apply(quantization.enable_observer)

    # 量化训练完成后，将QAT模型转换为量化模型
    quantized_model = quantization.convert(qat_model, inplace=False)
  
    # 最终在测试集上进行评估
    # quantized_model = quantized_model.to("cpu")
    # test_accuracy = test(quantized_model, test_loader, criterion, device)
    # print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

    # 保存最终量化模型
    quantized_model_path = f"WRN_Relu6_{num_bits}bit.pth"
    torch.save(quantized_model.state_dict(), quantized_model_path)
    print(f"{num_bits}-bit Relu6 Quantized WRN model saved as {quantized_model_path}")

    # # 创建与保存模型具有相同结构的新模型
    # q_model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0)

    # # 加载保存的模型状态字典
    # quantized_model_path = "WRN_Relu6_5bit.pth"
    # state_dict_quantized = torch.load(quantized_model_path)

    # # 创建一个新的状态字典，用于存储浮点格式的权重
    # state_dict_float = {}

    # # 将量化权重转换为浮点格式并存储在新状态字典中
    # for key, value in state_dict_quantized.items():
    #     if "scale" not in key and "zero_point" not in key:
    #         state_dict_float[key] = value.dequantize()

    # # 将加载的状态字典应用到新模型
    # q_model.load_state_dict(state_dict_float)

    # # 将新模型转移到所需的设备
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # q_model.to(device)

    # test_accuracy = test(q_model, test_loader, criterion, device)
    # print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

    # def load_quantized_model(model_path, device):
    #     model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0).to(device)
    #     model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    #     quantized_model = torch.quantization.prepare(model, inplace=False)
    #     quantized_model.load_state_dict(torch.load(model_path))
    #     quantized_model = torch.quantization.convert(quantized_model, inplace=False)
    #     return quantized_model

    # # 加载量化后的模型
    # num_bits = 5
    # quantized_model_path = f"WRN_Relu6_{num_bits}bit.pth"
    # quantized_model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.3).to(device)
    # quantized_model.load_state_dict(torch.load(quantized_model_path, map_location=device))

    # # 重新定义MinMaxObserver并进行初始化
    # observer = CustomMinMaxObserver(num_bits, dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    # quantized_model.conv1.activation_post_process = observer

    # for module in quantized_model.modules():
    #     if isinstance(module, nn.BatchNorm2d):
    #         module.activation_post_process = observer

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # quantized_model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0).to(device)

    # # quantized_model_path = "WRN_Relu6_5bit.pth"
    # # device = torch.device("cpu")
    # # quantized_model = load_quantized_model(quantized_model_path, device)

    # # 使用加载的量化模型进行验证
    # test_accuracy = test(quantized_model, test_loader, device)
    # print(f"Test Accuracy: {test_accuracy * 100:.2f}%")



    # # 创建Wide ResNet模型
    # q_model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.0)

    # # 加载保存的量化模型
    # quantized_model_path = "WRN_Relu6_5bit.pth"
    # state_dict_quantized = torch.load(quantized_model_path)

    # # 创建一个新的状态字典，用于存储浮点格式的权重
    # state_dict_float = {}

    # # 将量化权重转换为浮点格式并存储在新状态字典中
    # # for key, value in state_dict_quantized.items():
    # #     if "scale" not in key and "zero_point" not in key:
    # #         state_dict_float[key] = value.to('cpu:qnnpack').int_repr().float()

    # # 将加载的状态字典应用到新模型
    # q_model.load_state_dict(state_dict_float)

    # # 打印模型结构
    # print(q_model)


    # # 打印模型结构
    # print(q_model)

if __name__ == '__main__':
    main()


Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10
Is GPU available? True
Current device: 0
Epoch: 1, Loss: 1.7940, Validation Accuracy: 41.86%
8-bit Relu6 Quantized WRN model saved as WRN_Relu6_8bit.pth


In [55]:
!git config --global user.name "Carba6"
!git config --global user.email "1102046255@qq.com"
!git clone https://github.com/Carba6/deeplearning.git

%cd /content/deeplearning

!pwd
!cp /content/WRN_Relu.py /content/deeplearning
!git add WRN_Relu.py
!git commit -m "Add WRN_Relu.py"
!git remote set-url origin https://Carba6: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@github.com/Carba6/deeplearning.git
!git push -f

fatal: destination path 'deeplearning' already exists and is not an empty directory.
/content/deeplearning
/content/deeplearning
[main f51c961] Add WRN_Relu.py
 1 file changed, 1 insertion(+), 1 deletion(-)
Enumerating objects: 5, done.
Counting objects: 100% (5/5), done.
Delta compression using up to 12 threads
Compressing objects: 100% (3/3), done.
Writing objects: 100% (3/3), 292 bytes | 292.00 KiB/s, done.
Total 3 (delta 2), reused 0 (delta 0)
remote: Resolving deltas: 100% (2/2), completed with 2 local objects.[K
To https://github.com/Carba6/deeplearning.git
   0f31992..f51c961  main -> main


In [None]:
import torch
print(torch.__version__)

2.0.0+cu118


In [None]:
!git clone https://github.com/Zhaogui/modules.git

Cloning into 'modules'...
fatal: could not read Username for 'https://github.com': No such device or address
