<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/WRN_28_10_Relu6_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only: mount Google Drive at /content/drive so that checkpoints written
# by later cells can be copied somewhere that outlives the runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.quantization import FakeQuantize, QConfig
from torch.quantization import HistogramObserver
from WRN_quant import Wide_ResNet as WRN_quant
from Wide_Resnet import Wide_ResNet
from torch.quantization import get_default_qconfig

class CustomHistogramObserverActivation(HistogramObserver):
    """Histogram observer that emits zero-anchored ``num_bits``-wide qparams.

    The scale is the observed ``max_val - min_val`` span divided by
    ``2 ** num_bits - 1``; the zero point is always 0.

    Args:
        num_bits: Number of quantization bits; must be at least 1.
        **kwargs: Forwarded unchanged to ``HistogramObserver``.

    Raises:
        ValueError: If ``num_bits`` is smaller than 1 (the step count
            ``2 ** num_bits - 1`` would be zero).
    """

    def __init__(self, num_bits, **kwargs):
        if num_bits < 1:
            raise ValueError(f"num_bits must be >= 1, got {num_bits}")
        super(CustomHistogramObserverActivation, self).__init__(**kwargs)
        self.num_bits = num_bits

    def calculate_qparams(self):
        """Return ``(scale, zero_point)`` as one-element CPU tensors.

        Returns:
            A float tensor holding the scale and an int64 tensor holding the
            zero point (always 0).
        """
        min_val, max_val = self.min_val, self.max_val

        # No data observed yet (buffers empty or still at their +/-inf
        # sentinels): fall back to neutral qparams instead of a -inf scale.
        has_range = (
            min_val.numel() > 0
            and max_val.numel() > 0
            and torch.isfinite(min_val).all().item()
            and torch.isfinite(max_val).all().item()
        )
        if not has_range:
            return torch.tensor([1.0]), torch.tensor([0], dtype=torch.int64)

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        # A constant tensor gives max == min; keep the scale strictly positive
        # so that later divisions by it stay finite.
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        zero_point = 0
        return torch.tensor([scale.item()]), torch.tensor([zero_point], dtype=torch.int64)


class CustomHistogramObserverWeight(HistogramObserver):
    """Histogram observer that emits ``num_bits``-wide affine qparams for weights.

    The scale is the observed ``max_val - min_val`` span divided by
    ``2 ** num_bits - 1``; the zero point is ``-min_val / scale`` truncated
    toward zero.

    Args:
        num_bits: Number of quantization bits; must be at least 1.
        **kwargs: Forwarded unchanged to ``HistogramObserver``.

    Raises:
        ValueError: If ``num_bits`` is smaller than 1 (the step count
            ``2 ** num_bits - 1`` would be zero).
    """

    def __init__(self, num_bits, **kwargs):
        if num_bits < 1:
            raise ValueError(f"num_bits must be >= 1, got {num_bits}")
        super(CustomHistogramObserverWeight, self).__init__(**kwargs)
        self.num_bits = num_bits

    def calculate_qparams(self):
        """Return ``(scale, zero_point)`` as one-element CPU tensors.

        Returns:
            A float tensor holding the scale and an int64 tensor holding the
            zero point.
        """
        min_val, max_val = self.min_val, self.max_val

        # No data observed yet (buffers empty or still at their +/-inf
        # sentinels): int() of the resulting inf/nan would raise, so return
        # neutral qparams instead.
        has_range = (
            min_val.numel() > 0
            and max_val.numel() > 0
            and torch.isfinite(min_val).all().item()
            and torch.isfinite(max_val).all().item()
        )
        if not has_range:
            return torch.tensor([1.0]), torch.tensor([0], dtype=torch.int64)

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        # A constant weight tensor gives max == min; without this floor the
        # division below yields nan/inf and int() raises.
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)

        # Keep the zero point inside what the observer's dtype can store so a
        # near-degenerate range cannot produce a value that is rejected later.
        zp_bounds = torch.iinfo(self.dtype)
        zero_point = int(torch.clamp(-min_val / scale, zp_bounds.min, zp_bounds.max))
        return torch.tensor([scale.item()]), torch.tensor([zero_point], dtype=torch.int64)


def custom_qconfig(num_bits):
    """Build a ``QConfig`` whose fake-quantizers use the custom histogram observers.

    Args:
        num_bits: Bit width handed to both the activation and weight observer.

    Returns:
        A ``QConfig`` with a ``quint8`` activation fake-quantizer and a
        ``qint8`` weight fake-quantizer.
    """
    activation_fake_quant = FakeQuantize.with_args(
        observer=CustomHistogramObserverActivation,
        num_bits=num_bits,
        dtype=torch.quint8,
    )
    weight_fake_quant = FakeQuantize.with_args(
        observer=CustomHistogramObserverWeight,
        num_bits=num_bits,
        dtype=torch.qint8,
    )
    return QConfig(activation=activation_fake_quant, weight=weight_fake_quant)





def main():
    """Train a full-precision Wide-ResNet on CIFAR-10, then quantize it.

    Steps:
        1. Build train / validation / test loaders from CIFAR-10.
        2. Train the full-precision model, reporting validation accuracy
           after every epoch.
        3. Evaluate on the test loader, save the checkpoint and back it up.
        4. For every bit width in ``[min_bits, max_bits]``: calibrate on the
           validation loader, convert to a quantized model, save, evaluate
           and back up the checkpoint.
    """
    # Function-scope imports: only the checkpoint backup step needs them.
    import os
    import shutil

    # Training hyper-parameters
    batch_size = 128
    learning_rate = 0.1
    epochs = 100
    weight_decay = 0.0005
    momentum = 0.9

    # Model hyper-parameters
    depth = 28
    widen_factor = 10
    num_classes = 10
    dropout_rate = 0.3

    # Quantization bit-width range (inclusive)
    min_bits = 1  # lowest supported value is 1
    max_bits = 4  # highest supported value is 8

    # Directory that receives a copy of every saved checkpoint.
    backup_dir = '/content/drive/MyDrive/model'
    # Fixed seed so the train/validation/test partition is the same on every run.
    split_seed = 0

    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Is GPU available?", torch.cuda.is_available())

    # Pre-processing. Random augmentation is applied to training data only;
    # validation, calibration and test data use a deterministic pipeline so
    # that reported accuracies are repeatable.
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Load CIFAR-10. The training images are wrapped twice: once with the
    # augmenting transform and once with the deterministic one.
    full_train_dataset = datasets.CIFAR10(root="./data", train=True, transform=train_transform, download=True)
    full_train_eval_dataset = datasets.CIFAR10(root="./data", train=True, transform=eval_transform, download=False)
    test_dataset = datasets.CIFAR10(root="./data", train=False, transform=eval_transform, download=True)

    # Partition the official training set into train / validation / extra-test.
    # Whatever is left after the train and validation shares goes to the test side.
    train_ratio = 0.7
    validation_ratio = 0.1
    num_train_samples = int(len(full_train_dataset) * train_ratio)
    num_validation_samples = int(len(full_train_dataset) * validation_ratio)
    num_test_samples = len(full_train_dataset) - num_train_samples - num_validation_samples

    train_dataset, validation_split, held_out_split = random_split(
        full_train_dataset,
        [num_train_samples, num_validation_samples, num_test_samples],
        generator=torch.Generator().manual_seed(split_seed),
    )
    # Re-point the non-training partitions at the deterministic view of the same images.
    validation_dataset = torch.utils.data.Subset(full_train_eval_dataset, validation_split.indices)
    test_dataset_from_train = torch.utils.data.Subset(full_train_eval_dataset, held_out_split.indices)
    test_dataset = torch.utils.data.ConcatDataset([test_dataset, test_dataset_from_train])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Build the Wide ResNet model
    model = Wide_ResNet(depth=depth, widen_factor=widen_factor, num_classes=num_classes, dropout_rate=dropout_rate).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    # Learning-rate schedule.
    # NOTE(review): with epochs = 100 the second milestone (120) is never reached.
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120], gamma=0.1)

    def train_epoch(model, dataloader, criterion, optimizer, device):
        """Run one optimisation pass over ``dataloader``; return the mean batch loss."""
        model.train()
        running_loss = 0.0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        # max(..., 1) avoids dividing by zero on an empty loader.
        return running_loss / max(len(dataloader), 1)

    def evaluate(model, dataloader, device):
        """Return top-1 accuracy (a fraction in [0, 1]) of ``model`` on ``dataloader``."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        return correct / total if total else 0.0

    def calibrate(model, loader, device):
        """Feed ``loader`` through ``model`` in eval mode so its observers record statistics."""
        model.eval()
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(device)
                _ = model(images)

    def check_quantization_after_convert(model):
        """Print the quantization status of every conv/linear layer; return the quantized count."""
        quantized_layers = 0
        for layer_name, layer in model.named_modules():
            if isinstance(layer, (nn.quantized.Conv2d, nn.quantized.Linear)):
                quantized_layers += 1
                print(f"Layer {layer_name} is quantized.")
            elif isinstance(layer, (nn.Conv2d, nn.Linear)):
                print(f"Layer {layer_name} is not quantized.")
        return quantized_layers

    def quantize_num_bits_model(o_model, num_bits, device, depth, widen_factor, num_classes, dropout_rate, validation_loader, test_loader):
        """Quantize a copy of ``o_model`` to ``num_bits`` bits, save it and evaluate it.

        Returns:
            Path of the saved quantized checkpoint, so the caller can back it up.
        """
        model = WRN_quant(depth=depth, widen_factor=widen_factor, num_classes=num_classes, dropout_rate=dropout_rate).to(device)
        model.load_state_dict(o_model.state_dict())
        model.eval()

        # 8 bits uses the stock fbgemm configuration; every other width uses
        # the custom observers.
        if num_bits == 8:
            model.qconfig = get_default_qconfig('fbgemm')
        else:
            model.qconfig = custom_qconfig(num_bits)
        torch.quantization.prepare(model, inplace=True)
        calibrate(model, validation_loader, device)
        # The model is moved to the CPU before conversion and evaluated there.
        model.to('cpu')
        torch.quantization.convert(model, inplace=True)

        # Report how many layers were actually converted
        separator = '- ' * 25
        print(separator)
        print(separator)
        print(f"Check if {num_bits} bits model quantize successfully")
        print(separator)
        quantized_layers = check_quantization_after_convert(model)
        print(f"Number of quantized layers: {quantized_layers}")
        print(separator)
        print(separator)

        # Save the quantized model
        quantized_model_path = f"WRN_{depth}_{widen_factor}_Relu6_{num_bits}bit.pth"
        torch.save(model.state_dict(), quantized_model_path)
        print(f"{num_bits}-bit Quantized WRN_{depth}_{widen_factor}_Relu6 model saved as {quantized_model_path}")

        # Evaluate the quantized model on the test set
        quantized_accuracy = evaluate(model, test_loader, 'cpu')
        print(f"Quantized test accuracy: {quantized_accuracy * 100:.2f}%")
        return quantized_model_path

    def backup_checkpoint(checkpoint_path, label):
        """Copy ``checkpoint_path`` into ``backup_dir``; report failure instead of raising."""
        drive_root = os.path.dirname(backup_dir)
        if not os.path.isdir(drive_root):
            print(f'skip copying {label} model: {drive_root} is not available')
            return
        try:
            os.makedirs(backup_dir, exist_ok=True)
            shutil.copy(checkpoint_path, backup_dir)
        except OSError as err:
            print(f'failed to copy {label} model to google drive: {err}')
        else:
            print(f'copy {label} model to google drive successfully!')

    # Training loop
    for epoch in range(1, epochs + 1):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        validation_accuracy = evaluate(model, validation_loader, device)
        scheduler.step()
        print(f"Epoch: {epoch}, Loss: {train_loss:.4f}, Validation Accuracy: {validation_accuracy * 100:.2f}%")

    # Evaluate the full-precision model on the test set
    test_accuracy = evaluate(model, test_loader, device)
    print(f"full_precision model Test Accuracy: {test_accuracy * 100:.2f}%")

    # Save the full-precision model
    model_path = f"WRN_{depth}_{widen_factor}_Relu6_full_precision.pth"
    torch.save(model.state_dict(), model_path)
    print(f"full_precision WRN_{depth}_{widen_factor}_Relu6 model saved as {model_path}")
    backup_checkpoint(model_path, 'full_precision')

    # Produce one quantized model per bit width from min_bits to max_bits
    for num_bits in range(min_bits, max_bits + 1):
        quantized_model_path = quantize_num_bits_model(model, num_bits, device, depth, widen_factor, num_classes, dropout_rate, validation_loader, test_loader)
        backup_checkpoint(quantized_model_path, f'{num_bits}bits')


# Entry point: calls main() when this cell/script is executed directly.
if __name__ == '__main__':
    main()


Is GPU available? True
Files already downloaded and verified
Files already downloaded and verified
Creat WRN_full_precision successfully!
Epoch: 1, Loss: 1.7819, Validation Accuracy: 40.52%
Epoch: 2, Loss: 1.3196, Validation Accuracy: 47.50%
Epoch: 3, Loss: 1.0553, Validation Accuracy: 54.16%
Epoch: 4, Loss: 0.9284, Validation Accuracy: 31.32%
Epoch: 5, Loss: 0.8315, Validation Accuracy: 51.82%
Epoch: 6, Loss: 0.7451, Validation Accuracy: 62.28%
Epoch: 7, Loss: 0.6789, Validation Accuracy: 51.76%
Epoch: 8, Loss: 0.6315, Validation Accuracy: 63.30%
Epoch: 9, Loss: 0.5778, Validation Accuracy: 64.98%
Epoch: 10, Loss: 0.5562, Validation Accuracy: 54.38%
Epoch: 11, Loss: 0.5312, Validation Accuracy: 68.40%
Epoch: 12, Loss: 0.5028, Validation Accuracy: 60.98%
Epoch: 13, Loss: 0.4923, Validation Accuracy: 75.28%
Epoch: 14, Loss: 0.4756, Validation Accuracy: 75.44%
Epoch: 15, Loss: 0.4574, Validation Accuracy: 72.20%
Epoch: 16, Loss: 0.4511, Validation Accuracy: 71.62%
Epoch: 17, Loss: 0.4432

In [11]:
    # Manual copy of the full-precision checkpoint to the mounted Google Drive
    # folder (IPython shell escape; works only inside a notebook).
    !cp /content/WRN_28_10_Relu6_full_precision.pth /content/drive/MyDrive/model