In [None]:
import datetime
import os
import sys
import time
import collections
import wandb

import torch
import torch.utils.data
from torch import nn
import torch.nn.functional as F
from utils.benchmark import *
import torch
import os
from models.vgg_tiny import VGG
# Imports
import copy
import random
import numpy as np
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
assert torch.cuda.is_available()
from tqdm import tqdm

import torchvision
from torchvision import transforms

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from absl import logging
logging.set_verbosity(logging.FATAL)  # Disable logging as they are too noisy in notebook



In [None]:
# Start a Weights & Biases run in the "QAT" project; the returned handle is kept in `run`.
run = wandb.init(project="QAT")

## Set the default QuantDescriptor to use histogram-based calibration for activations

In [None]:
# Input (activation) quantizers of conv and linear layers share one histogram-calibrated descriptor.
quant_desc_input = QuantDescriptor(calib_method='histogram')
for quant_layer_cls in (quant_nn.QuantConv2d, quant_nn.QuantLinear):
    quant_layer_cls.set_default_quant_desc_input(quant_desc_input)

## Initialize quantized modules

In [None]:
# NOTE(review): per the pytorch-quantization documentation, initialize() swaps the
# standard torch.nn layers for their quantized counterparts, so it presumably has to
# run before any model is constructed further down — confirm.
from pytorch_quantization import quant_modules
quant_modules.initialize()

## Calibrate the model

In [None]:
def collect_stats(model, data_loader, num_batches):
    """Feed calibration batches to the network so its quantizers collect statistics.

    Every ``quant_nn.TensorQuantizer`` that owns a calibrator is switched from
    quantization to calibration for the duration of the pass; quantizers without
    a calibrator are disabled. Afterwards quantization is re-enabled (and
    calibration disabled) on all of them, even if a forward pass raises.

    Args:
        model: Network containing ``quant_nn.TensorQuantizer`` modules. Inputs are
            moved with ``.cuda()``, so the model is expected to live on the GPU.
        data_loader: Iterable yielding ``(image, label)`` pairs; labels are ignored.
        num_batches: Maximum number of batches to run through the model.
    """
    from itertools import islice

    quantizers = [
        module
        for _, module in model.named_modules()
        if isinstance(module, quant_nn.TensorQuantizer)
    ]

    # Enable calibrators
    for quantizer in quantizers:
        if quantizer._calibrator is not None:
            quantizer.disable_quant()
            quantizer.enable_calib()
        else:
            quantizer.disable()

    try:
        # islice stops after exactly `num_batches` batches; checking the counter
        # after the forward pass would run one batch more than requested.
        # Calibration only needs forward activations, so no autograd graph is built.
        with torch.no_grad():
            for image, _ in tqdm(islice(data_loader, num_batches), total=num_batches):
                model(image.cuda())
    finally:
        # Disable calibrators (also on failure, so the model is never left in
        # calibration mode).
        for quantizer in quantizers:
            if quantizer._calibrator is not None:
                quantizer.enable_quant()
                quantizer.disable_calib()
            else:
                quantizer.enable()
            
def compute_amax(model, **kwargs):
    """Load the collected calibration result into every calibrated quantizer.

    ``kwargs`` are forwarded to ``load_calib_amax`` for each calibrator that is
    not a ``calib.MaxCalibrator``; max calibrators are loaded without arguments.
    The model is moved to the GPU at the end.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is None:
            continue
        # Max calibration takes no options; other calibrators receive the caller's kwargs.
        options = {} if isinstance(calibrator, calib.MaxCalibrator) else kwargs
        quantizer.load_calib_amax(**options)
    model.cuda()

In [None]:
from registry import get_model
from torch2trt import torch2trt

# Prefix used when building checkpoint file names for the calibrated models.
path = "results_ptq/"

# Architectures covered by the experiments, grouped by family.
_RESNET_MODELS = ["resnet20", "resnet32", "resnet44", "resnet56"]
_VGG_MODELS = ["vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]
_REPVGG_MODELS = ["repvgg_a0", "repvgg_a1", "repvgg_a2"]
models = _RESNET_MODELS + _VGG_MODELS + _REPVGG_MODELS

# Dataset identifiers each architecture is paired with.
datasets = ["cifar10", "cifar100"]



from torchvision.datasets import *
from torchvision.transforms import *
from torch.utils.data import DataLoader
# Data pipelines for CIFAR-10 and CIFAR-100
NORMALIZE_DICT = {
    'cifar10': dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    'cifar100': dict(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2761)),
}
image_size = 32


def _cifar_transforms(stats):
    """Return the train/test transform pipelines normalized with ``stats``."""
    return {
        "train": Compose([
            RandomCrop(image_size, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize(**stats),
        ]),
        "test": Compose([
            ToTensor(),
            Normalize(**stats),
        ]),
    }


def _cifar_loaders(split_datasets):
    """Wrap each split in a DataLoader (batch size 128); only the train split is shuffled."""
    return {
        split: DataLoader(
            split_dataset,
            batch_size=128,
            shuffle=(split == 'train'),
            num_workers=0,
            pin_memory=True,
        )
        for split, split_dataset in split_datasets.items()
    }


transforms_cifar10 = _cifar_transforms(NORMALIZE_DICT['cifar10'])
transforms_cifar100 = _cifar_transforms(NORMALIZE_DICT['cifar100'])

# CIFAR-10: datasets first, then their loaders.
dataset = {
    split: CIFAR10(
        root="data/cifar10",
        train=(split == "train"),
        download=True,
        transform=transforms_cifar10[split],
    )
    for split in ("train", "test")
}
dataloaderc10 = _cifar_loaders(dataset)

# CIFAR-100: `dataset` is rebound on purpose, mirroring the CIFAR-10 step.
dataset = {
    split: CIFAR100(
        root="data/cifar100",
        train=(split == "train"),
        download=True,
        transform=transforms_cifar100[split],
    )
    for split in ("train", "test")
}
dataloaderc100 = _cifar_loaders(dataset)

import torch 
from tqdm import tqdm
import torch.nn as nn
# Single target device that evaluate() and the training loops move their batches to.
device = torch.device('cuda')
import torch.nn.functional as F
# Evaluation loop
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    verbose=True,
) -> tuple:
    """Compute top-1 accuracy and mean cross-entropy loss of ``model`` on ``dataloader``.

    The model is switched to eval mode and left in that mode. Batches are moved
    to the module-level ``device`` before inference.

    Args:
        model: Network to evaluate.
        dataloader: Yields ``(inputs, targets)`` batches.
        verbose: When false, the tqdm progress bar is disabled.

    Returns:
        ``(accuracy_percent, mean_loss)`` as two Python floats.

    Raises:
        ValueError: If the dataloader yields no samples (the averages would be
            undefined).
    """
    model.eval()

    num_samples = 0
    num_correct = 0
    loss = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False, disable=not verbose):
        # Move the data from CPU to GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference
        outputs = model(inputs)
        # Summed (not averaged) per batch so the final division yields a per-sample mean
        loss += F.cross_entropy(outputs, targets, reduction="sum")
        # Convert logits to class indices
        outputs = outputs.argmax(dim=1)
        # Update metrics
        num_samples += targets.size(0)
        num_correct += (outputs == targets).sum()

    if num_samples == 0:
        # Without this guard the division below fails with an opaque ZeroDivisionError.
        raise ValueError("evaluate() received a dataloader that yielded no samples")

    return (num_correct / num_samples * 100).item(), (loss / num_samples).item()




        

In [None]:
# Disabled PTQ calibration sweep, kept as a bare string literal so none of it runs.
# The quoted code calibrates every model/dataset pair with the entropy and percentile
# methods, evaluates it and saves each result under `path`.
# NOTE(review): these look like the checkpoints the QAT cell later loads from
# results_ptq/ — confirm before deleting this cell.
"""for model_name in models:
    for dataset in datasets:
        print(f'model name = {model_name} / dataset = {dataset}')
        model = get_model(model_name, dataset)
        model.eval()
        if dataset == 'cifar10':
            with torch.no_grad():
                collect_stats(model, dataloaderc10["train"], num_batches=80)
                for method in ["entropy", "percentile"]:
                    if method=="percentile":
                        for p in [99, 99.9, 99.99,99.999,100]:
                            print(F"{method} calibration + percentile {p}")
                            compute_amax(model, method=method, percentile=p)
                            acc, loss = evaluate(model, dataloaderc10['test'])
                            print(acc, loss)
                            torch.save(model, path + f"{model_name}_{dataset}_{method}_{p}.pth")
                        
                            
                    else:
                        print(F"{method} calibration")
                        compute_amax(model, method=method)
                        acc, loss = evaluate(model, dataloaderc10['test'])
                        print(acc, loss)
                        torch.save(model, path + f"{model_name}_{dataset}_{method}.pth")
                       
        else:
           with torch.no_grad():
                collect_stats(model, dataloaderc100["train"], num_batches=80)
                for method in ["entropy", "percentile"]:
                    if method=="percentile":
                        for p in [99, 99.9, 99.99,99.999,100]:
                            print(F"{method} calibration + percentile {p}")
                            compute_amax(model, method=method, percentile=p)
                            acc, loss = evaluate(model, dataloaderc100['test'])
                            print(acc, loss)
                            torch.save(model, path + f"{model_name}_{dataset}_{method}_{p}.pth")
                            
                    else:
                        print(F"{method} calibration")
                        compute_amax(model, method=method)
                        acc, loss = evaluate(model, dataloaderc100['test'])
                        print(acc, loss)
                        torch.save(model, path + f"{model_name}_{dataset}_{method}.pth")
                        
           """


In [None]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    # for pruning
    weight_decay=5e-4,
    pruner=None,
    callbacks=None,
    save=None,
    save_only_state_dict=False,
) -> None:
    """Train ``model`` with SGD + cosine annealing and keep the best-scoring weights.

    After every epoch the model is scored with ``evaluate`` on ``test_loader``;
    the state dict with the highest accuracy is restored into ``model`` at the
    end and optionally written to disk.

    Args:
        model: Network to train (modified in place).
        train_loader: Yields ``(inputs, targets)`` training batches.
        test_loader: Loader passed to ``evaluate`` after each epoch.
        epochs: Number of epochs; also used as ``T_max`` of the cosine schedule.
        lr: Initial learning rate for SGD (momentum 0.9).
        weight_decay: SGD weight decay; forced to 0 when ``pruner`` is given.
        pruner: Optional object whose ``regularize(model)`` is called after
            each backward pass.
        callbacks: Optional iterable of zero-argument callables run after every
            optimizer step.
        save: Optional path relative to ``<cwd>/results`` where the best model
            is stored.
        save_only_state_dict: Store only ``model.state_dict()`` instead of the
            whole module.
    """
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=weight_decay if pruner is None else 0,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40,80,100], gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    best_acc = -1
    best_checkpoint = dict()

    for epoch in range(epochs):
        model.train()
        for inputs, targets in tqdm(train_loader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward propagation
            loss.backward()

            # Pruner regularize for sparsity learning
            if pruner is not None:
                pruner.regularize(model)

            # Update optimizer
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()

        acc, val_loss = evaluate(model, test_loader)
        print(
            f'Epoch {epoch + 1}/{epochs} | Val acc: {acc:.2f} | Val loss: {val_loss:.4f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')
        # TODO: log these values to wandb (nothing is logged yet)

        if best_acc < acc:
            best_checkpoint['state_dict'] = copy.deepcopy(model.state_dict())
            best_acc = acc
        # Update LR scheduler
        scheduler.step()

    # No checkpoint exists when no epoch ran (epochs == 0); keep the current weights then.
    if 'state_dict' in best_checkpoint:
        model.load_state_dict(best_checkpoint['state_dict'])
    if save:
        # Persist the best model
        path = os.path.join(os.getcwd(), "results", save)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        if save_only_state_dict:
            torch.save(model.state_dict(), path)
        else:
            torch.save(model, path)
    print(f'Best val acc: {best_acc:.2f}')

# training loop
def train_kd(
    model_student: nn.Module,
    model_teacher: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    epochs: int,
    lr: float,
    weight_decay=5e-4,
    callbacks=None,
    save=None,
    save_only_state_dict=False,
) -> None:
    """Train ``model_student`` with knowledge distillation from ``model_teacher``.

    The loss is ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE`` with
    temperature ``T = 4`` and ``alpha = 0.9``. After every epoch the student is
    scored with ``evaluate`` on ``test_loader``; the best state dict is restored
    at the end and optionally written to disk.

    Args:
        model_student: Network being trained (modified in place).
        model_teacher: Network providing soft targets; it is only run forward,
            in eval mode and without gradient tracking.
        train_loader: Yields ``(inputs, targets)`` training batches.
        test_loader: Loader passed to ``evaluate`` after each epoch.
        epochs: Number of epochs; also the period of the cosine schedule.
        lr: Initial learning rate for SGD (momentum 0.9).
        weight_decay: SGD weight decay.
        callbacks: Optional iterable of zero-argument callables run after every
            optimizer step.
        save: Optional path relative to ``<cwd>/results`` where the best student
            is stored.
        save_only_state_dict: Store only the state dict instead of the whole module.
    """
    optimizer = torch.optim.SGD(
        model_student.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40,80], gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    best_acc = -1
    best_checkpoint = dict()

    # Distillation hyper-parameters: softmax temperature and weight of the KD term.
    kd_T = 4
    alpha = 0.9

    for epoch in range(epochs):
        model_student.train()
        # The teacher only supplies targets. Eval mode keeps its outputs deterministic
        # and stops train-mode layers (e.g. BatchNorm running statistics, dropout)
        # from being altered by the student's training data.
        model_teacher.eval()
        for inputs, targets in tqdm(train_loader, leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Reset the gradients (from the last iteration)
            optimizer.zero_grad()

            # Forward inference
            out_student = model_student(inputs)
            # The teacher is never optimized, so skip building its autograd graph.
            with torch.no_grad():
                out_teacher = model_teacher(inputs)

            # kd loss
            predict_student = F.log_softmax(out_student / kd_T, dim=1)
            predict_teacher = F.softmax(out_teacher / kd_T, dim=1)
            # 'batchmean' is the reduction that matches the KL-divergence definition;
            # the default 'mean' additionally divides by the number of classes and
            # silently shrinks the distillation term.
            kd_loss = F.kl_div(predict_student, predict_teacher, reduction="batchmean")
            loss = kd_loss * (alpha * kd_T * kd_T) + criterion(out_student, targets) * (1 - alpha)

            loss.backward()

            # Update optimizer
            optimizer.step()

            if callbacks is not None:
                for callback in callbacks:
                    callback()

        acc, val_loss = evaluate(model_student, test_loader)
        print(
            f'Epoch {epoch + 1}/{epochs} | Val acc: {acc:.2f} | Val loss: {val_loss:.4f} | LR: {optimizer.param_groups[0]["lr"]:.6f}')

        if best_acc < acc:
            best_checkpoint['state_dict'] = copy.deepcopy(model_student.state_dict())
            best_acc = acc
        # Update LR scheduler
        scheduler.step()

    # No checkpoint exists when no epoch ran (epochs == 0); keep the current weights then.
    if 'state_dict' in best_checkpoint:
        model_student.load_state_dict(best_checkpoint['state_dict'])
    if save:
        # Persist the best student model
        path = os.path.join(os.getcwd(), "results", save)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        if save_only_state_dict:
            torch.save(model_student.state_dict(), path)
        else:
            torch.save(model_student, path)
    print(f'Best val acc: {best_acc:.2f}')

In [None]:
# Best calibration method per "<dataset>/<architecture>" combination.
# Insertion order is the order in which the QAT loop processes the models.
combi = {
    # CIFAR-10
    "cifar10/resnet20": "percentile_99.999",
    "cifar10/resnet32": "percentile_99.9",
    "cifar10/resnet44": "percentile_99.99",
    "cifar10/resnet56": "entropy",
    "cifar10/vgg11_bn": "percentile_99.99",
    "cifar10/vgg13_bn": "entropy",
    "cifar10/vgg16_bn": "percentile_99.9",
    "cifar10/vgg19_bn": "percentile_100",
    # CIFAR-100
    "cifar100/repvgg_a0": "percentile_99.99",
    "cifar100/repvgg_a1": "percentile_99.99",
    "cifar100/repvgg_a2": "percentile_99.99",
    "cifar100/resnet20": "percentile_99.99",
    "cifar100/resnet32": "percentile_100",
    "cifar100/resnet44": "percentile_99.999",
    "cifar100/resnet56": "percentile_99.99",
    "cifar100/vgg11_bn": "entropy",
    "cifar100/vgg13_bn": "entropy",
    "cifar100/vgg16_bn": "percentile_100",
    "cifar100/vgg19_bn": "percentile_99.999",
}


In [None]:


# get best calibration method for each model and run QAT 
from registry import get_model
for names, method in combi.items():
    # Release cached GPU memory left over from the previous model
    torch.cuda.empty_cache()

    # Keys look like "<dataset>/<architecture>"; anything that is not "cifar10"
    # is handled as "cifar100".
    c = names.split('/')
    dataset = "cifar10" if c[0] == "cifar10" else "cifar100"
    name = c[1]
    print(F"model name = {name} / dataset = {dataset}")

    # Per-dataset settings: loaders, QAT learning rate and suffix of the output file.
    if dataset == 'cifar10':
        loaders, qat_lr, save_suffix = dataloaderc10, 0.0001, "c10"
    else:
        loaders, qat_lr, save_suffix = dataloaderc100, 0.001, "c100"

    # NOTE(review): torch.load unpickles a whole model object here — only use on trusted files.
    path_model = f"{name}_{dataset}_{method}.pth"
    model = torch.load(f'results_ptq/{path_model}')
    model.eval()
    acc, loss = evaluate(model, loaders['test'])

    model_teacher = get_model(name, dataset)
    print('modèle initial')
    print(acc, loss)

    train_kd(
        model,
        model_teacher,
        loaders['train'],
        loaders['test'],
        epochs=15,
        lr=qat_lr,
        save=f"results_qat_kd/{method}_{name}_{save_suffix}",
    )
        
        