<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/Wide_Resnet_28_10_Relu6_2bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from wide_resnet import Wide_ResNet
import torch.quantization as quantization
from torch.quantization import FakeQuantize, default_qconfig, QuantStub, DeQuantStub
from torch.quantization.qconfig import QConfig
# from torch.quantization import MinMaxObserver


from torch.quantization import HistogramObserver

class CustomHistogramObserverActivation(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverActivation, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = 0
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

class CustomHistogramObserverWeight(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverWeight, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = int(-min_val / scale)
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

def custom_qconfig(num_bits):
    return QConfig(
        activation=FakeQuantize.with_args(observer=CustomHistogramObserverActivation, num_bits=num_bits, dtype=torch.quint8),
        weight=FakeQuantize.with_args(observer=CustomHistogramObserverWeight, num_bits=num_bits, dtype=torch.qint8),
    )
# class CustomMinMaxObserver(MinMaxObserver):
#     def __init__(self, num_bits, **kwargs):
#         super().__init__(**kwargs)
#         self.num_bits = num_bits
#         self.qmin = 0
#         self.qmax = 2 ** num_bits - 1

#     def forward(self, x):
#         self.min_val = torch.min(x)
#         self.max_val = torch.max(x)

#         scale, zero_point = self.calculate_qparams()
#         new_min = torch.min(torch.tensor([self.min_val, scale * (self.qmin - zero_point)]))
#         new_max = torch.max(torch.tensor([self.max_val, scale * (self.qmax - zero_point)]))

#         self.min_val = torch.min(new_min, self.min_val)
#         self.max_val = torch.max(new_max, self.max_val)
#         return x



def main():

    batch_size = 128
    learning_rate = 0.1
    epochs = 1
    weight_decay = 0.0005
    momentum = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_bits = 5

    # 数据预处理
    transform = transforms.Compose([
        # transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # 加载CIFAR-10数据集
    full_train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

    # 将训练数据集分为训练集、验证集和测试集
    train_ratio = 0.7
    validation_ratio = 0.1
    test_ratio = 0.2
    num_train_samples = int(len(full_train_dataset) * train_ratio)
    num_validation_samples = int(len(full_train_dataset) * validation_ratio)
    num_test_samples = len(full_train_dataset) - num_train_samples - num_validation_samples

    train_dataset, validation_dataset, test_dataset_from_train = random_split(full_train_dataset, [num_train_samples, num_validation_samples, num_test_samples])
    test_dataset = torch.utils.data.ConcatDataset([test_dataset, test_dataset_from_train])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # 创建Wide ResNet模型
    model = Wide_ResNet(depth=10, widen_factor=3, num_classes=10, dropout_rate=0.3).to(device)

    # 创建num_bits位量化配置
    # num_bit_qconfig = QConfig(
    #     activation=FakeQuantize.with_args(observer=CustomMinMaxObserver, dtype=torch.quint8, qscheme=torch.per_tensor_affine, num_bits=num_bits),
    #     weight=FakeQuantize.with_args(observer=CustomMinMaxObserver, dtype=torch.qint8, qscheme=torch.per_tensor_affine, num_bits=num_bits)
    # )

    # 准备QAT
    model.qconfig = custom_qconfig(num_bits)
    qat_model = quantization.prepare_qat(model, inplace=False).to(device)

    # 设置损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(qat_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)

    # 设置学习率调度器
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120], gamma=0.1)

    # 训练和测试函数
    def train_epoch(model, dataloader, criterion, optimizer, device):
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        return running_loss / len(dataloader)


    def test(model, dataloader, criterion, device):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total


    # 训练循环
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(qat_model, train_loader, criterion, optimizer, device)
        validation_accuracy = test(qat_model, validation_loader, criterion, device)
        scheduler.step()
        print(f"Epoch: {epoch}, Loss: {train_loss:.4f}, Validation Accuracy: {validation_accuracy * 100:.2f}%")

        # 每隔一定的epoch数量（如5个），冻结统计数据并执行一次量化
        if epoch % 5 == 4:
            qat_model.apply(quantization.disable_observer)
            qat_model.apply(quantization.enable_fake_quant)

        # 每隔一定的epoch数量（如10个），解冻统计数据
        if epoch % 10 == 9:
            qat_model.apply(quantization.enable_observer)
    print(1)
    # 量化训练完成后，将QAT模型转换为量化模型
    def calibrate(model, loader, device):
        model.eval()
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(device)
                _ = model(images)
    qat_model.to('cpu') 
    calibrate(qat_model, train_loader, 'cpu')           
    quantized_model = quantization.convert(qat_model, inplace=False)

    test_accuracy = test(quantized_model, test_loader, criterion, 'cpu')
    print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
    print(2)
    def check_quantization(model):
        quantized = True
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                if not isinstance(module, (torch.quantization.QuantizedConv2d, torch.quantization.QuantizedLinear)):
                    quantized = False
                    print(f"{name} is not quantized")
        return quantized

    is_quantized = check_quantization(quantized_model)
    if is_quantized:
        print("The model has been successfully quantized.")
    else:
        print("The model has not been fully quantized.")
    # 保存最终量化模型
    quantized_model_path = f"WRN_Relu6_{num_bits}bit.pth"
    # torch.save(quantized_model.state_dict(), quantized_model_path)
    print(f"{num_bits}-bit Relu6 Quantized WRN model saved as {quantized_model_path}")

    t_device = torch.device("cuda")
    t_model = Wide_ResNet(depth=28, widen_factor=10, num_classes=10, dropout_rate=0.3).to(t_device)

    def test_quantized(model, dataloader, device):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                # 量化输入数据
                targets = targets.to('cpu')
                inputs = inputs.to('cpu')
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total
    # 准备QAT
    t_model.qconfig = num_bit_qconfig
    t_qat_model = quantization.prepare_qat(t_model, inplace=False).to(t_device)    
    t_quantized_model = quantization.convert(t_qat_model, inplace=False)
    state_dict = torch.load(f"WRN_Relu6_{num_bits}bit.pth", map_location='cpu')
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    t_quantized_model.load_state_dict(state_dict)


    test_accuracy = test_quantized(t_quantized_model, test_loader, t_device)
    print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

if __name__ == '__main__':
    main()


Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 10x3
Is GPU available? True
Current device: 0
Epoch: 1, Loss: 2.1574, Validation Accuracy: 19.38%
1


NotImplementedError: ignored

In [9]:
qfrom google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [15]:
!cp WRN_Relu6_2bit.pth /content/drive/MyDrive/model