# Partie 8 : Post training quantization from scratch

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import copy

# Reproducibility: seed PyTorch's global RNG (torch.manual_seed also seeds the
# CUDA generators) so that weight initialisation and the shuffled batch order
# are the same after every "Restart Kernel -> Run All". Some CUDA kernels are
# still non-deterministic, so GPU results may differ in the last digits.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Utilisation du device : {device}")

# Single-channel mean / std handed to Normalize.
# NOTE(review): presumably the usual MNIST training-set statistics - confirm.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Convert PIL images to tensors, then standardise them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD)
])

# MNIST is downloaded into ./data on first use.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffled mini-batches for training; large fixed-order batches for evaluation.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

Utilisation du device : cuda


### Partie 0 : Quantization naive

In [3]:
class MLP(nn.Module):
    """Small fully connected classifier: flatten -> 784-to-128 linear -> ReLU -> 128-to-10 linear.

    The sub-modules are registered as ``flatten`` and ``layers`` (an
    ``nn.Sequential`` holding the two linear layers around the ReLU), so the
    ReLU is reachable under the module name ``layers.1``.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Built in forward order so parameter initialisation order is unchanged.
        hidden = nn.Linear(28 * 28, 128)
        activation = nn.ReLU()
        head = nn.Linear(128, 10)
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Return the 10 raw class scores (logits) for each sample in ``x``."""
        flat = self.flatten(x)
        logits = self.layers(flat)
        return logits


def train(model, loader, criterion, optimizer, epochs=5):
    """Train ``model`` in place and print the mean batch loss after each epoch.

    Args:
        model: ``nn.Module`` to optimise; it is left in training mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: optimiser already bound to ``model``'s parameters.
        epochs: number of full passes over ``loader``.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Move batches to wherever the model lives rather than relying on a
    # notebook-level ``device`` variable, so the function also works for a
    # model placed on another device. A parameter-less model falls back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: works for loaders without __len__ and avoids
        # a ZeroDivisionError when the loader yields nothing.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    print("Entraînement terminé.")


def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Args:
        model: ``nn.Module`` producing one score per class; switched to eval mode.
        loader: iterable yielding ``(inputs, labels)`` tensor batches.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` when ``loader`` yields no
        sample (instead of raising ``ZeroDivisionError``).
    """
    # Move batches to wherever the model lives rather than relying on a
    # notebook-level ``device`` variable. A parameter-less model falls back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            # Index of the highest score per row (first index on ties).
            predicted = model(inputs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total


def quantize(model):
    """Return a copy of ``model`` whose parameters are rounded to the nearest integer.

    This is the deliberately naive baseline: no scale and no zero-point are
    used, every parameter value is simply rounded. ``model`` itself is left
    untouched and the returned copy is put in eval mode.
    """
    rounded_model = copy.deepcopy(model)
    rounded_model.eval()

    with torch.no_grad():
        for tensor in rounded_model.parameters():
            tensor.data = torch.round(tensor.data)

    return rounded_model

In [4]:
# Train the float32 reference model. `model_fp32` and `original_accuracy`
# are reused by the later quantization cells.
print("--- Entraînement du modèle original (Float32) ---")
model_fp32 = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)

train(model_fp32, train_loader, criterion, optimizer, epochs=5)

# Baseline accuracy on the test set.
original_accuracy = evaluate(model_fp32, test_loader)
print(f"\nAccuracy du modèle original (Float32): {original_accuracy:.2f} %")

# Naive quantization: every parameter is rounded to the nearest integer.
print("\n--- Quantification naïve du modèle ---")
model_int_naive = quantize(model_fp32)

quantized_accuracy = evaluate(model_int_naive, test_loader)
print(f"Accuracy du modèle quantifié (Naïf): {quantized_accuracy:.2f} %")

# Side-by-side summary of both accuracies.
print("\n--- Comparaison ---")
print(f"Original (FP32) : {original_accuracy:.2f} %")
print(f"Quantifié (Naïf) : {quantized_accuracy:.2f} %")

--- Entraînement du modèle original (Float32) ---
Epoch 1/5, Loss: 0.2573
Epoch 2/5, Loss: 0.1125
Epoch 3/5, Loss: 0.0783
Epoch 4/5, Loss: 0.0597
Epoch 5/5, Loss: 0.0479
Entraînement terminé.

Accuracy du modèle original (Float32): 97.67 %

--- Quantification naïve du modèle ---
Accuracy du modèle quantifié (Naïf): 9.80 %

--- Comparaison ---
Original (FP32) : 97.67 %
Quantifié (Naïf) : 9.80 %


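On remarque que la quantification naïve a échoué : l'accuracy tombe à 9,80 %, soit le niveau du hasard pour 10 classes. En arrondissant directement à l'entier le plus proche, tout poids de valeur absolue inférieure à 0,5 devient nul, ce qui détruit très probablement l'essentiel de l'information apprise.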

### Partie 1 : Quantization statique - weights only

In [5]:
def quantize_weights_single_range(model, num_bits=8):
    """
    Quantifie les poids du modèle en utilisant une plage unique pour tous les paramètres
    """
    quantized_model = copy.deepcopy(model)
    quantized_model.eval()

    global_min = float('inf')
    global_max = float('-inf')
    
    with torch.no_grad():
        for param in quantized_model.parameters():
            global_min = min(global_min, param.data.min().item())
            global_max = max(global_max, param.data.max().item())
    
    scale = (global_max - global_min) / (2**num_bits - 1)
    zero_point = -round(global_min / scale)
    
    with torch.no_grad():
        for param in quantized_model.parameters():
            # Quantification
            param.data = torch.round(param.data / scale + zero_point)
            # Déquantification
            param.data = (param.data - zero_point) * scale
            
    return quantized_model


def quantize_weights_per_layer(model, num_bits=8):
    """
    Quantifie les poids du modèle en utilisant une plage différente pour chaque couche
    """
    quantized_model = copy.deepcopy(model)
    quantized_model.eval()
    
    with torch.no_grad():
        for name, module in quantized_model.named_modules():
            if isinstance(module, nn.Linear):
                weight_min = module.weight.data.min().item()
                weight_max = module.weight.data.max().item()
                
                scale = (weight_max - weight_min) / (2**num_bits - 1)
                zero_point = -round(weight_min / scale)
                
                module.weight.data = torch.round(module.weight.data / scale + zero_point)
                module.weight.data = (module.weight.data - zero_point) * scale
                
                if module.bias is not None:
                    bias_min = module.bias.data.min().item()
                    bias_max = module.bias.data.max().item()
                    bias_scale = (bias_max - bias_min) / (2**num_bits - 1)
                    bias_zero_point = -round(bias_min / bias_scale)
                    
                    module.bias.data = torch.round(module.bias.data / bias_scale + bias_zero_point)
                    module.bias.data = (module.bias.data - bias_zero_point) * bias_scale
    
    return quantized_model

In [6]:
# Weight-only quantization (default num_bits=8) with one range shared by all parameters.
print("--- Test de la quantification avec plage unique ---")
model_single_range = quantize_weights_single_range(model_fp32)
single_range_accuracy = evaluate(model_single_range, test_loader)

# Same experiment with a separate range per Linear weight / bias tensor.
print("--- Test de la quantification avec plage par couche ---")
model_per_layer = quantize_weights_per_layer(model_fp32)
per_layer_accuracy = evaluate(model_per_layer, test_loader)

# Summary against the float32 baseline computed in an earlier cell.
print("\n--- Comparaison des performances ---")
print(f"Original (FP32) : {original_accuracy:.2f} %")
print(f"Quantifié (plage unique) : {single_range_accuracy:.2f} %")
print(f"Quantifié (plage par couche) : {per_layer_accuracy:.2f} %")

--- Test de la quantification avec plage unique ---
--- Test de la quantification avec plage par couche ---

--- Comparaison des performances ---
Original (FP32) : 97.67 %
Quantifié (plage unique) : 97.68 %
Quantifié (plage par couche) : 97.66 %


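Cela fonctionne nettement mieux : avec une quantification affine sur 8 bits (échelle + point zéro), l'accuracy reste pratiquement identique à celle du modèle FP32 (97,68 % et 97,66 % contre 97,67 %), que la plage soit unique ou définie par couche.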

### Quantization des activations

In [7]:
def collect_activation_stats(model, loader, num_batches=100):
    """Record the min / max output of every ``nn.ReLU`` over a calibration set.

    Forward hooks are attached to each ReLU module, at most ``num_batches``
    batches are run through the model in eval mode without gradients, and the
    hooks are removed again - even when a forward pass raises.

    Args:
        model: ``nn.Module`` to observe; it is switched to eval mode.
        loader: iterable yielding ``(inputs, _)`` batches; the second item is ignored.
        num_batches: maximum number of batches to run.

    Returns:
        Dict mapping each ReLU's name (as given by ``model.named_modules()``)
        to ``{"min": float, "max": float}``. ReLUs that never ran are absent.
    """
    activation_stats = {}
    hooks = []

    def hook_fn(name):
        def hook(module, inputs, output):
            stats = activation_stats.setdefault(
                name, {"min": float('inf'), "max": float('-inf')}
            )
            stats["min"] = min(stats["min"], output.min().item())
            stats["max"] = max(stats["max"], output.max().item())
        return hook

    # Move batches to wherever the model lives rather than relying on a
    # notebook-level ``device`` variable. A parameter-less model falls back to CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    try:
        for name, module in model.named_modules():
            if isinstance(module, nn.ReLU):
                hooks.append(module.register_forward_hook(hook_fn(name)))

        with torch.no_grad():
            for i, (inputs, _) in enumerate(loader):
                if i >= num_batches:
                    break
                model(inputs.to(target_device))
    finally:
        # Always detach the hooks; a leaked hook would keep updating stale
        # statistics on every later forward pass of the model.
        for hook in hooks:
            hook.remove()

    return activation_stats

In [8]:
class QuantizedMLP(nn.Module):
    """Copy of a model whose Linear parameters and ReLU outputs go through simulated quantization.

    Every ``nn.Linear`` found in ``original_model.layers`` is copied and its
    weight / bias are replaced by their quantize-then-dequantize version, each
    tensor using its own min/max range. All other layers are reused as they
    are, so layer indices stay aligned with ``original_model.layers``.
    During ``forward`` the output of each ReLU is quantized/dequantized with
    the calibration range found under the key ``"layers.<index>"`` of
    ``activation_stats``.

    Args:
        original_model: module exposing ``flatten`` and an iterable ``layers``.
        num_bits: code width; codes live in ``[0, 2**num_bits - 1]``.
        activation_stats: dict ``{"layers.<i>": {"min": float, "max": float}}``
            or ``None`` to leave activations untouched.

    Raises:
        ValueError: if ``num_bits`` is smaller than 1.
    """

    def __init__(self, original_model, num_bits=8, activation_stats=None):
        super().__init__()
        if num_bits < 1:
            raise ValueError(f"num_bits must be >= 1, got {num_bits}")
        self.num_bits = num_bits
        self.activation_stats = activation_stats

        self.flatten = original_model.flatten
        self.layers = nn.ModuleList()

        for layer in original_model.layers:
            if isinstance(layer, nn.Linear):
                # deepcopy keeps device, dtype and the presence/absence of a
                # bias; a freshly built nn.Linear would invent a random bias
                # for bias-free layers.
                new_layer = copy.deepcopy(layer)
                new_layer.weight.data = self._fake_quantize_own_range(layer.weight.data)
                if layer.bias is not None:
                    new_layer.bias.data = self._fake_quantize_own_range(layer.bias.data)
                self.layers.append(new_layer)
            else:
                # Keep every other layer so the network is not silently
                # altered and "layers.<i>" keys still match.
                self.layers.append(layer)

    def _fake_quantize(self, x, min_val, max_val):
        """Quantize ``x`` over ``[min_val, max_val]`` and map it back to floats."""
        max_code = 2**self.num_bits - 1
        scale = (max_val - min_val) / max_code
        # Zero-width, NaN or infinite range: no usable scale, keep values as is
        # (the division below would otherwise raise ZeroDivisionError).
        if not (0 < scale < float('inf')):
            return x.clone()

        zero_point = -round(min_val / scale)
        # Clamp so values outside the range saturate instead of producing
        # codes that do not fit in num_bits.
        codes = torch.round(x / scale + zero_point).clamp_(0, max_code)
        return (codes - zero_point) * scale

    def _fake_quantize_own_range(self, tensor):
        """Quantize/dequantize ``tensor`` using its own minimum and maximum."""
        return self._fake_quantize(tensor, tensor.min().item(), tensor.max().item())

    def quantize_activation(self, x, stats):
        """Quantize/dequantize activations ``x`` with the collected ``stats`` range.

        Values outside ``[stats["min"], stats["max"]]`` saturate at the ends
        of the range.
        """
        return self._fake_quantize(x, stats["min"], stats["max"])

    def forward(self, x):
        """Run the layers in order, quantizing each ReLU output when stats are available."""
        x = self.flatten(x)
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if isinstance(layer, nn.ReLU) and self.activation_stats is not None:
                x = self.quantize_activation(x, self.activation_stats[f"layers.{i}"])
        return x

#### Collecte des statistiques d'activation et comparaisons finales

In [9]:
# Calibration: record ReLU output ranges of the float32 model on 50 training batches.
print("Collecte des statistiques d'activation...")
activation_stats = collect_activation_stats(model_fp32, train_loader, num_batches=50)

# Build the model with 8-bit simulated quantization of weights and activations, then evaluate it.
print("\n--- Test de la quantification complète (poids + activations) ---")
model_full_quant = QuantizedMLP(model_fp32, num_bits=8, activation_stats=activation_stats)
full_quant_accuracy = evaluate(model_full_quant, test_loader)

# Final summary; the "weights only" row reuses the per-layer result from an earlier cell.
print("\n--- Comparaison finale des performances ---")
print(f"Original (FP32) : {original_accuracy:.2f} %")
print(f"Quantifié (poids uniquement) : {per_layer_accuracy:.2f} %")
print(f"Quantifié (poids + activations) : {full_quant_accuracy:.2f} %")

Collecte des statistiques d'activation...

--- Test de la quantification complète (poids + activations) ---

--- Comparaison finale des performances ---
Original (FP32) : 97.67 %
Quantifié (poids uniquement) : 97.66 %
Quantifié (poids + activations) : 97.64 %
