<a href="https://colab.research.google.com/github/Auniik/VGG19-optimization/blob/main/CIFAR10_VGG19_optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchinfo tqdm

# Initial Setup

In [None]:
import os
import gc
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.onnx

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchinfo import summary
from tqdm import tqdm

In [None]:
# Fix the random seeds of PyTorch and NumPy with one shared value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Data Loader Function

This function prepares data loaders for the CIFAR-10 dataset, tailored for use with models like VGG19 that expect RGB images of size 224x224. It includes transformations that resize and normalize the 32x32 RGB CIFAR-10 images, and it creates fixed-size subsets for training and validation.

In [None]:
def get_data_loaders(num_samples=5000, batch_size=64):
    """Build CIFAR-10 training and validation loaders over fixed-size subsets.

    Every image goes through ``Resize(224)``, ``ToTensor()`` and a per-channel
    ``Normalize``. The training subset is the first ``num_samples`` images of
    the CIFAR-10 training split and the validation subset is the first
    ``num_samples // 5`` images of the test split. Both counts are capped at
    the size of the corresponding split.

    Args:
        num_samples: Number of training images to use.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2023, 0.1994, 0.2010])
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform
    )

    # Cap the requested sizes so an oversized num_samples cannot create
    # subset indices that lie outside the underlying dataset.
    n_train = min(num_samples, len(train_dataset))
    n_val = min(num_samples // 5, len(test_dataset))
    train_subset = Subset(train_dataset, range(n_train))
    val_subset = Subset(test_dataset, range(n_val))

    # Pinned host memory only matters for host-to-GPU copies; requesting it
    # without CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True,
        num_workers=2, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_subset, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=pin_memory
    )

    return train_loader, val_loader

### VGG19 Model Preparation Function

This function loads a VGG19 model pre-trained on the ImageNet dataset and replaces its final classification layer so that it can be used on a dataset with a different number of classes.


In [None]:
def get_model(classes=10):
    """Return a VGG19 loaded with ``IMAGENET1K_V1`` weights and a new output layer.

    The last entry of the classifier (index 6) is replaced by a freshly
    created ``nn.Linear(4096, classes)``; all other layers keep their
    loaded weights.
    """
    vgg = torchvision.models.vgg19(weights='IMAGENET1K_V1')
    replacement_head = nn.Linear(4096, classes)
    vgg.classifier[6] = replacement_head
    return vgg

### FLOPs Estimation Function
**Returns:**
- `total_flops` (int): The estimated total number of floating-point operations for both the forward and backward passes.


In [None]:
def count_flops(model, input_size=(1, 3, 224, 224)):
    """Estimate the forward-pass FLOPs of ``model`` for an input of ``input_size``.

    A dummy tensor of shape ``input_size`` is pushed through the model while
    forward hooks on every convolution and linear layer accumulate
    ``2 * (number of output elements) * (weights feeding one output element)``,
    i.e. each multiply-accumulate counts as two operations. Bias additions,
    activations, pooling and normalization layers are not counted.

    The forward pass runs in eval mode under ``torch.no_grad()`` so that
    dropout and batch-norm statistics are not touched. Afterwards the model
    is switched back to its previous top-level training flag, which
    ``Module.train`` applies to all submodules.

    Args:
        model: Module to profile.
        input_size: Full input shape, including the batch dimension.

    Returns:
        int: Estimated number of floating-point operations.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    total_flops = 0

    def _conv_hook(module, inputs, output):
        nonlocal total_flops
        # weight[0] is one filter: (in_channels / groups) * kernel elements.
        total_flops += 2 * output.numel() * module.weight[0].numel()

    def _linear_hook(module, inputs, output):
        nonlocal total_flops
        # Read the width from the weight itself rather than `in_features`,
        # which may be stale if the layer was resized by hand.
        total_flops += 2 * output.numel() * module.weight.size(1)

    handles = []
    for module in model.modules():
        if isinstance(module, conv_types):
            handles.append(module.register_forward_hook(_conv_hook))
        elif isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(_linear_hook))

    # Build the dummy input on the same device/dtype as the model weights.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.is_floating_point():
        dummy = torch.zeros(*input_size, device=first_param.device, dtype=first_param.dtype)
    else:
        dummy = torch.zeros(*input_size)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(dummy)
    finally:
        # Always detach the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
        model.train(was_training)

    return total_flops

def get_model_file_size(model):
    """Return the size, in MiB, of the model's ``state_dict`` when saved to disk.

    The state dict is written with ``torch.save`` into a private temporary
    directory, so no file in the current working directory is created,
    overwritten or deleted, and the temporary file is removed even if saving
    fails.

    Args:
        model: Module whose ``state_dict()`` is serialized.

    Returns:
        float: File size in units of 1024 * 1024 bytes.
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / (1024 * 1024)

### Training Function for One Epoch

This function performs one epoch of training for a given model using the provided data loader, loss function, and optimizer. It includes essential features such as gradient clipping and NaN loss handling to ensure stable and effective training.

**Key Features:**
- **Model Training Mode:** Puts the model in training mode to enable features like dropout and batch normalization.
- **Loss Computation:** Uses the specified criterion (loss function) to calculate the training loss.
- **Gradient Clipping:** Prevents exploding gradients by clipping them to a maximum norm (`max_grad_norm`).
- **NaN Loss Handling:** Detects and skips batches with NaN loss values to maintain training stability.
- **Real-Time Progress Monitoring:** Uses `tqdm` to display real-time progress, including current loss and accuracy.
- **Accuracy Calculation:** Tracks the number of correct predictions to compute accuracy for the epoch.


In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``train_loader``.

    For each batch the loss is computed, back-propagated, the gradient norm is
    clipped to ``max_grad_norm`` and the optimizer takes a step. Batches whose
    loss is NaN are reported and skipped before any backward pass; they do not
    contribute to the loss or accuracy statistics.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.
        max_grad_norm: Maximum gradient norm used for clipping.

    Returns:
        tuple: ``(mean_batch_loss, accuracy_percent)`` over the batches that
        were actually used. If every batch was skipped the result is
        ``(nan, 0.0)``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batches_used = 0  # batches that were not skipped because of a NaN loss

    pbar = tqdm(train_loader, desc='Training')
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Check for NaN loss
        if torch.isnan(loss):
            print("Warning: NaN loss detected. Skipping batch.")
            continue

        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        running_loss += loss.item()
        batches_used += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # running_loss is a sum of per-batch losses, so average it over the
        # number of batches (not the number of samples).
        pbar.set_postfix({
            'loss': f'{running_loss/batches_used:.3f}',
            'acc': f'{100.*correct/total:.1f}%'
        })

    if batches_used == 0:
        # Nothing was trained on (empty loader or every batch skipped).
        return float('nan'), 0.0

    return running_loss / batches_used, 100.*correct/total

### Model Evaluation Function

This function evaluates a trained model on a validation dataset, providing key performance metrics such as loss and accuracy.

**Key Features:**
- **Evaluation Mode:** Sets the model to evaluation mode, disabling features like dropout and batch normalization updates.
- **No Gradient Calculation:** Uses `torch.no_grad()` for efficient evaluation by disabling gradient computations.
- **Top-1 and Top-5 Accuracy:** Computes both standard accuracy (Top-1) and Top-5 accuracy, useful for multi-class classification tasks.

In [None]:
def evaluate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    The model is moved to ``device`` and put in eval mode. Loss is averaged
    over batches; accuracies are computed over all samples.

    Args:
        model: Module to evaluate.
        val_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the model and batches are moved to.

    Returns:
        dict with keys ``'loss'`` (mean batch loss), ``'acc'`` and
        ``'top1_acc'`` (both the top-1 accuracy in percent) and ``'top5_acc'``
        (top-5 accuracy in percent; when the model has fewer than five
        outputs, all of them are considered).
    """
    model.to(device)
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    correct_top5 = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_loss += loss.item()

            # Top-1 accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Top-5 accuracy; topk raises if k exceeds the number of classes,
            # so clamp k for small label spaces.
            top_k = min(5, outputs.size(1))
            _, pred5 = outputs.topk(top_k, 1, True, True)
            correct_top5 += pred5.eq(targets.view(-1, 1).expand_as(pred5)).sum().item()

    top1_acc = 100. * correct / total
    return {
        'loss': running_loss / len(val_loader),
        'acc': top1_acc,
        'top1_acc': top1_acc,
        'top5_acc': 100. * correct_top5 / total
    }

In [None]:
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    Whenever the loss improves on the best value seen so far by at least
    ``min_delta``, a detached CPU copy of the model's ``state_dict`` is
    stored. After ``patience`` consecutive calls without such an improvement,
    ``early_stop`` becomes ``True``.

    Attributes:
        patience: Number of non-improving calls tolerated before stopping.
        min_delta: Minimum decrease in loss that counts as an improvement.
        counter: Current number of consecutive non-improving calls.
        best_score: Negated best validation loss seen so far, or ``None``.
        early_stop: Whether the patience budget has been exhausted.
        best_model_state: Snapshot of the best weights, or ``None``.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model_state = None
        self._model = None  # model the snapshot belongs to

    @property
    def best_model(self):
        """The tracked model with its best recorded weights loaded back in.

        Reading this attribute restores the snapshot into the model object
        that was last recorded as best (in place) and returns that object.
        Returns ``None`` if the instance has never been called.
        """
        if self._model is not None and self.best_model_state is not None:
            self._model.load_state_dict(self.best_model_state)
        return self._model

    @best_model.setter
    def best_model(self, model):
        # Kept so that code assigning the attribute directly keeps working.
        self._model = model

    def _record_best(self, score, model):
        """Remember ``score`` and snapshot the current weights of ``model``."""
        self.best_score = score
        self._model = model
        # state_dict() returns references to the live tensors, so copy them;
        # otherwise later optimizer steps would overwrite the "best" weights.
        self.best_model_state = {
            key: value.detach().cpu().clone()
            for key, value in model.state_dict().items()
        }

    def __call__(self, val_loss, model):
        """Update the stopping state with this epoch's validation loss."""
        score = -val_loss

        if self.best_score is None:
            self._record_best(score, model)
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self._record_best(score, model)
            self.counter = 0

In [None]:
def plot_metrics(train_losses, train_accs, val_losses, val_accs):
    """Show training/validation loss and accuracy curves side by side.

    Each argument is a per-epoch sequence. The left panel plots the two loss
    series and the right panel the two accuracy series.
    """
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, training series, validation series, metric name, y-axis label)
    panels = [
        (loss_ax, train_losses, val_losses, 'Loss', 'Loss'),
        (acc_ax, train_accs, val_accs, 'Accuracy', 'Accuracy (%)'),
    ]
    for axis, train_series, val_series, metric, y_label in panels:
        axis.plot(train_series, label=f'Train {metric}')
        axis.plot(val_series, label=f'Val {metric}')
        axis.set_title(f'{metric} vs Epoch')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def print_and_get_metrics(technique, epoch, metrics, model_size, flops, *args, **kwargs):
    """Print a final performance report and return it as a flat dict.

    Args:
        technique: Label of the experiment.
        epoch: Number of epochs that were run.
        metrics: Mapping with ``'acc'``, ``'top1_acc'`` and ``'top5_acc'``.
        model_size: Model size in MB.
        flops: Operation count; reported in units of 1e9.
        *args: Extra positional values; those that are dicts are merged into
            the result, in order. Non-dict values are ignored.
        **kwargs: Extra entries merged into the result last.

    Returns:
        dict: Formatted report fields plus any merged extras.
    """
    def _percent(key):
        # Format one accuracy entry of `metrics` as a percentage string.
        return f"{metrics[key]:.2f}%"

    print("\nFinal Model Performance:")

    print(f"Result of {technique}:")
    print(f"Total epochs: {epoch}")
    print(f"Model Size: {model_size:.2f} MB")
    print(f"FLOPs: {flops / 1e9:.2f} GFLOPs")
    print(f"Accuracy: {_percent('acc')}")
    print(f"Top-1 Accuracy: {_percent('top1_acc')}")
    print(f"Top-5 Accuracy: {_percent('top5_acc')}")

    report = dict(
        technique=technique,
        epoch=epoch,
        model_size=f"{model_size:.2f} MB",
        flops=f"{flops / 1e9:.2f} GFLOPs",
        acc=_percent('acc'),
        top1=_percent('top1_acc'),
        top5=_percent('top5_acc'),
    )

    # Positional dict extras first, keyword extras last (so they win on clashes).
    for extra in (item for item in args if isinstance(item, dict)):
        report.update(extra)
    report.update(kwargs)

    return report

# Baseline Model Training

In [None]:
# Experiment configuration passed to every run_* call in this notebook
num_epochs, batch_size, num_samples = 30, 32, 5000

# Use CUDA when it is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

# Report the dataset setup, then build the loaders
print("\nDataset Information:")
print(f"Training samples: {num_samples}")
print(f"Validation samples: {num_samples // 5}")
print(f"Batch size: {batch_size}")
train_loader, val_loader = get_data_loaders(num_samples=num_samples, batch_size=batch_size)

In [None]:
def run_baseline(num_epochs, batch_size, num_samples):
    """Fine-tune the unmodified VGG19 and report its final metrics.

    Uses the module-level ``device``, ``train_loader`` and ``val_loader``.
    Training runs for at most ``num_epochs`` epochs with SGD and stops early
    via ``EarlyStopping(patience=3, min_delta=0.001)`` on the validation loss.
    The model taken from the early-stopping helper is scripted with
    TorchScript, saved to ``baseline_cifar10.pt`` and evaluated once more.

    Args:
        num_epochs: Maximum number of training epochs.
        batch_size: Batch dimension used for the printed model summary.
        num_samples: Accepted for a uniform call signature; not read here.

    Returns:
        tuple: ``(model, result)`` where ``result`` is the dict produced by
        ``print_and_get_metrics``.
    """
    net = get_model().to(device)

    print("\nModel Summary:")
    print(summary(net, input_size=(batch_size, 3, 224, 224)))

    # Parameter footprint in MB and the FLOPs estimate, taken before training
    param_bytes = sum(p.numel() * p.element_size() for p in net.parameters())
    model_size = param_bytes / (1024 * 1024)
    flops = count_flops(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    early_stopping = EarlyStopping(patience=3, min_delta=0.001)

    print("\nStarting training Baseline...")
    epochs_performed = 0
    for epoch_number in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch_number}/{num_epochs}")

        train_loss, train_acc = train_epoch(net, train_loader, criterion, optimizer, device)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        eval_metrics = evaluate(net, val_loader, criterion, device)
        history['val_loss'].append(eval_metrics['loss'])
        history['val_acc'].append(eval_metrics['acc'])
        epochs_performed = epoch_number

        # Let the early-stopping helper decide whether to continue
        early_stopping(eval_metrics['loss'], net)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    plot_metrics(
        history['train_loss'], history['train_acc'],
        history['val_loss'], history['val_acc'],
    )

    net = early_stopping.best_model
    torch.jit.script(net).save("baseline_cifar10.pt")

    print("\nPerforming final evaluation for Baseline cifar10...")
    gpu_metrics = evaluate(net, val_loader, criterion, device)

    result = print_and_get_metrics("Baseline - CIFAR-10", epochs_performed, gpu_metrics,  model_size, flops)

    return net, result

In [None]:
# Train and evaluate the baseline; keep the model and its metrics for later comparison
baseline_cifar10, baseline_result = run_baseline(
    num_epochs=num_epochs, batch_size=batch_size, num_samples=num_samples
)

# Model Pruning

## Structured Pruning

In [None]:
!pip install ptflops
import torch.nn.utils.prune as prune
from ptflops import get_model_complexity_info

### StructuredPruner Class for Convolutional Neural Networks

The `StructuredPruner` class implements **structured pruning** for convolutional neural networks, focusing on removing entire filters (or channels) based on their importance. This pruning technique reduces model complexity and improves inference speed while minimizing the impact on model performance.

**Key Features:**
- **Filter Importance Scoring:** Uses the **L1-norm** of convolutional filters to measure their importance, prioritizing the removal of less significant filters.
- **Layer-Wise Pruning:** Prunes filters from convolutional layers and adjusts associated batch normalization layers to maintain consistency.
- **Automatic Adjustment of Fully Connected Layers:** After pruning convolutional layers, the class adjusts the first fully connected layer to match the reduced feature map dimensions.
- **Pruning History Tracking:** Records pruning statistics, including the number of filters before and after pruning for each layer.
- **Flexible Pruning Ratio:** Allows users to specify the fraction of filters to prune, providing control over the trade-off between model size and accuracy.

**Benefits of Structured Pruning:**
- **Model Compression:** Reduces the number of parameters and model size for more efficient storage and deployment.
- **Faster Inference:** Decreases computational load, leading to faster inference times, especially on resource-constrained devices.
- **Energy Efficiency:** Reduces power consumption by lowering the number of computations.

**Parameters:**
- `model` (torch.nn.Module): The neural network model to be pruned.

**Methods:**
- `compute_filter_importance(conv_layer)`: Computes the importance score for each filter in a convolutional layer using the L1-norm.
- `prune_conv_layer(conv_layer, bn_layer, prune_ratio)`: Prunes the least important filters from a convolutional layer and adjusts the batch normalization layer if present.
- `prune_model(prune_ratio, input_shape=(3, 224, 224))`: Applies structured pruning across the entire model and adjusts subsequent layers accordingly.
- `_adjust_fc_layer(input_shape)`: Adjusts the first fully connected (FC) layer to match the reduced output size from the convolutional layers after pruning.
- `print_pruning_history()`: Displays a summary of the pruning operations performed on the model.

**Returns:**
- A pruned version of the original model with reduced parameters and computational complexity.


In [None]:
class StructuredPruner:
    """Remove whole convolution filters from a ``features``/``classifier`` model.

    Filters are ranked by the L1 norm of their weights and the lowest-ranked
    fraction of each ``nn.Conv2d`` in ``model.features`` is removed. The input
    channels of the following convolution and the input columns of
    ``model.classifier[0]`` are trimmed to match.

    NOTE(review): the channel bookkeeping assumes the convolutions in
    ``model.features`` are chained one after another with ``groups == 1`` and
    that the classifier consumes a channel-major flatten of the (pooled)
    feature map, as torchvision's VGG does - confirm before using it on other
    architectures.

    Attributes:
        model: The model being pruned in place.
        pruning_history: One dict per pruned layer with the keys ``'layer'``,
            ``'filters_before'`` and ``'filters_after'``.
    """

    def __init__(self, model):
        self.model = model
        self.pruning_history = []

    def compute_filter_importance(self, conv_layer):
        """Compute L1-norm of each filter as importance score"""
        weights = conv_layer.weight.data
        importance = torch.sum(torch.abs(weights.view(weights.size(0), -1)), dim=1)
        return importance

    def prune_conv_layer(self, conv_layer, bn_layer, prune_ratio):
        """Drop the least important output filters of ``conv_layer``.

        Args:
            conv_layer: Convolution whose output filters are pruned in place.
            bn_layer: Batch-norm layer fed by ``conv_layer`` or ``None``; its
                per-channel parameters and running statistics are trimmed
                with the same indices.
            prune_ratio: Fraction of filters to remove. At least one filter
                is always kept.

        Returns:
            1-D index tensor (on the weight's device, ascending) of the
            original filters that were kept.
        """
        n_filters = conv_layer.weight.size(0)
        device = conv_layer.weight.device
        # Never remove every filter: a layer without output channels cannot run.
        n_prune = min(int(n_filters * prune_ratio), n_filters - 1)

        if n_prune <= 0:
            # Nothing to prune: every filter is kept.
            return torch.arange(n_filters, device=device)

        importance = self.compute_filter_importance(conv_layer)
        _, order = torch.sort(importance)
        # Keep the most important filters, listed in their original order so
        # the surviving channels are not permuted.
        indices_to_keep, _ = torch.sort(order[n_prune:])

        # Prune conv layer
        conv_layer.weight.data = torch.index_select(conv_layer.weight.data, 0, indices_to_keep)
        if conv_layer.bias is not None:
            conv_layer.bias.data = torch.index_select(conv_layer.bias.data, 0, indices_to_keep)
        conv_layer.out_channels = len(indices_to_keep)

        # Prune batch norm layer if present
        if bn_layer is not None:
            bn_layer.weight.data = torch.index_select(bn_layer.weight.data, 0, indices_to_keep)
            bn_layer.bias.data = torch.index_select(bn_layer.bias.data, 0, indices_to_keep)
            bn_layer.running_mean = torch.index_select(bn_layer.running_mean, 0, indices_to_keep)
            bn_layer.running_var = torch.index_select(bn_layer.running_var, 0, indices_to_keep)
            bn_layer.num_features = len(indices_to_keep)

        return indices_to_keep

    def _pre_prune(self):
        """Measure and print the model complexity before pruning."""
        self.flops_original, self.params_original = get_model_complexity_info(
            self.model, (3, 224, 224), as_strings=False, print_per_layer_stat=False
        )
        # The numeric values are replaced by display strings.
        self.flops_original = f"{self.flops_original / 1e9:.2f} GFLOPs"
        self.params_original = f"{self.params_original / 1e6:.2f} Million"
        print(f"Original Model: {self.flops_original}, {self.params_original}")

    def _post_prune(self):
        """Measure and print the model complexity after pruning."""
        self.flops_pruned, self.params_pruned = get_model_complexity_info(
            self.model, (3, 224, 224), as_strings=False, print_per_layer_stat=False
        )

        self.flops_pruned = f"{self.flops_pruned / 1e9:.2f} GFLOPs"
        self.params_pruned = f"{self.params_pruned / 1e6:.2f} Million"

        print(f"Pruned Model: {self.flops_pruned}, {self.params_pruned}")

    def prune_model(self, prune_ratio, input_shape=(3, 224, 224)):
        """Prune every convolution in ``model.features`` and fix up the classifier.

        Args:
            prune_ratio: Fraction of filters removed from each convolution.
            input_shape: ``(C, H, W)`` of a single input, used to determine
                the classifier's new input width.
        """
        self._pre_prune()
        prev_indices = None

        # Pair each convolution with the batch-norm layer that directly
        # follows it, if any, instead of pairing two lists by position.
        layers = list(self.model.features)
        conv_bn_pairs = []
        for position, module in enumerate(layers):
            if isinstance(module, nn.Conv2d):
                following = layers[position + 1] if position + 1 < len(layers) else None
                bn = following if isinstance(following, nn.BatchNorm2d) else None
                conv_bn_pairs.append((module, bn))
        n_bn_layers = sum(isinstance(module, nn.BatchNorm2d) for module in layers)

        print(f"Found {len(conv_bn_pairs)} convolutional layers and {n_bn_layers} batch norm layers.")

        for i, (conv, bn) in enumerate(conv_bn_pairs):
            filters_before = conv.weight.size(0)

            # Prune current layer
            indices = self.prune_conv_layer(conv, bn, prune_ratio)

            # Drop the input channels that the previous layer no longer produces
            if prev_indices is not None:
                conv.weight.data = torch.index_select(conv.weight.data, 1, prev_indices)
                conv.in_channels = len(prev_indices)

            prev_indices = indices

            # Record pruning statistics
            self.pruning_history.append({
                'layer': i,
                'filters_before': filters_before,
                'filters_after': len(indices)
            })

        # Adjust the first FC layer based on the channels the last conv kept
        self._adjust_fc_layer(input_shape, prev_indices)
        self._post_prune()

    def _adjust_fc_layer(self, input_shape, kept_channels=None):
        """Shrink ``model.classifier[0]`` to the pruned feature size.

        A dummy input is run through ``model.features`` (and ``model.avgpool``
        when the model has one) in eval mode to find the flattened feature
        size.

        Args:
            input_shape: ``(C, H, W)`` of a single input.
            kept_channels: Indices of the channels of the last convolution
                that survived pruning. When given, exactly the weight columns
                belonging to those channels are kept. When ``None``, the
                first ``flattened_size`` columns are kept.
        """
        features = self.model.features
        device = next(self.model.parameters()).device
        dummy_input = torch.randn(1, *input_shape, device=device)

        # Eval mode keeps batch-norm running statistics untouched by the
        # random dummy input.
        was_training = features.training
        features.eval()
        try:
            with torch.no_grad():
                feature_map = features(dummy_input)
                avgpool = getattr(self.model, "avgpool", None)
                if avgpool is not None:
                    feature_map = avgpool(feature_map)
        finally:
            features.train(was_training)
        flattened_size = feature_map.reshape(1, -1).size(1)

        # Update the FC layer to match the new flattened size
        fc = self.model.classifier[0]
        if kept_channels is None:
            fc.weight.data = fc.weight.data[:, :flattened_size]
        else:
            # Each channel owns a contiguous run of `per_channel` columns in
            # the flattened layout; keep the runs of the surviving channels.
            per_channel = flattened_size // feature_map.size(1)
            weight_device = fc.weight.device
            offsets = torch.arange(per_channel, device=weight_device)
            columns = (kept_channels.to(weight_device).unsqueeze(1) * per_channel + offsets).reshape(-1)
            fc.weight.data = torch.index_select(fc.weight.data, 1, columns)
        fc.in_features = flattened_size
        print(f"Adjusted FC layer input size to: {flattened_size}")

    def get_pruned_stats(self):
        """Return the formatted before/after complexity figures.

        Only valid after ``prune_model`` has run, which sets these attributes.
        """
        return {
            'flops_original': self.flops_original,
            'flops_pruned': self.flops_pruned,
            'params_original': self.params_original,
            'params_pruned': self.params_pruned
        }

    def print_pruning_history(self):
        """Print the recorded filter counts before and after pruning, per layer."""
        if not self.pruning_history:
            print("No pruning has been recorded.")
        for entry in self.pruning_history:
            print(f"Layer {entry['layer']}: Filters before = {entry['filters_before']}, Filters after = {entry['filters_after']}")

### VGG19 Model with Structured Pruning

This function initializes a pre-trained **VGG19** model, modifies it for a specified number of classes, and applies **structured pruning** to reduce the model's complexity. The pruning process removes less important filters from convolutional layers, optimizing the model for faster inference and reduced memory usage.

**Key Features:**
- **Pre-trained VGG19 Backbone:** Loads a VGG19 model pre-trained on ImageNet and modifies the final classification layer for the target dataset.
- **Structured Pruning Integration:** Applies structured pruning using the `StructuredPruner` class, which removes less important filters based on their L1-norm importance scores.
- **Customizable Pruning Ratio:** Allows users to define the proportion of filters to prune from each convolutional layer.
- **Pruning History Tracking:** Returns the pruner object, which contains a history of pruning operations, enabling analysis of model modifications.

**Benefits of Structured Pruning:**
- **Model Compression:** Reduces the number of parameters, leading to a smaller model size and faster inference times.
- **Performance Optimization:** Optimized for deployment on resource-constrained devices, such as mobile and embedded systems.
- **Energy Efficiency:** Reduces computational requirements, lowering energy consumption during inference.


In [None]:
def get_model_with_pruning(classes=10, prune_ratio=0.3):
    """Build the VGG19 classifier and apply structured pruning to it.

    Args:
        classes: Number of output classes passed to ``get_model``.
        prune_ratio: Fraction passed to ``StructuredPruner.prune_model``.

    Returns:
        tuple: ``(model, pruner)`` - the pruned model and the pruner that
        modified it.
    """
    pruned_net = get_model(classes)
    structured_pruner = StructuredPruner(pruned_net)
    structured_pruner.prune_model(prune_ratio)
    return pruned_net, structured_pruner

In [None]:
def run_structured_pruning(num_epochs, batch_size, num_samples, prune_ratio=0.3):
    """Prune VGG19 structurally, fine-tune it and report its final metrics.

    Uses the module-level ``device``, ``train_loader`` and ``val_loader``.
    Training runs for at most ``num_epochs`` epochs with SGD and stops early
    via ``EarlyStopping(patience=5, min_delta=0.001)`` on the validation loss.
    The model taken from the early-stopping helper is scripted with
    TorchScript, saved to ``sp_cifar10.pt`` and evaluated once more.

    Args:
        num_epochs: Maximum number of training epochs.
        batch_size: Batch dimension used for the printed model summary.
        num_samples: Accepted for a uniform call signature; not read here.
        prune_ratio: Fraction of filters removed from each convolution.

    Returns:
        tuple: ``(model, result)`` where ``result`` is the dict produced by
        ``print_and_get_metrics``, including the pruner's statistics.
    """
    net, pruner = get_model_with_pruning(classes=10, prune_ratio=prune_ratio)
    net = net.to(device)

    print("\nModel Summary after structured pruning:")
    print(summary(net, input_size=(batch_size, 3, 224, 224)))
    print("\nPruning Statistics:")
    pruner.print_pruning_history()
    stats = pruner.get_pruned_stats()

    # Parameter footprint in MB and the FLOPs estimate of the pruned model
    param_bytes = sum(p.numel() * p.element_size() for p in net.parameters())
    model_size = param_bytes / (1024 * 1024)
    flops = count_flops(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    # Training loop
    print("\nStarting training structure pruned model...")
    epochs_performed = 0
    for epoch_number in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch_number}/{num_epochs}")

        train_loss, train_acc = train_epoch(net, train_loader, criterion, optimizer, device)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        eval_metrics = evaluate(net, val_loader, criterion, device)
        history['val_loss'].append(eval_metrics['loss'])
        history['val_acc'].append(eval_metrics['acc'])
        epochs_performed = epoch_number

        early_stopping(eval_metrics['loss'], net)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    plot_metrics(
        history['train_loss'], history['train_acc'],
        history['val_loss'], history['val_acc'],
    )

    net = early_stopping.best_model
    torch.jit.script(net).save("sp_cifar10.pt")

    print("\nPerforming final evaluation for Structured cifar10...")
    gpu_metrics = evaluate(net, val_loader, criterion, device)

    result = print_and_get_metrics(
        "Structured Pruning - CIFAR-10", epochs_performed,
        gpu_metrics, model_size, flops, stats
    )

    # Release Python garbage and cached CUDA memory before the next experiment
    gc.collect()
    torch.cuda.empty_cache()

    return net, result

In [None]:
# Run the structured-pruning experiment with the shared configuration
sp_cifar10, sp_result = run_structured_pruning(
    num_epochs=num_epochs, batch_size=batch_size, num_samples=num_samples
)

## Unstructured Pruning

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

### UnstructuredPruner Class for Neural Network Pruning

The `UnstructuredPruner` class implements **unstructured pruning** for neural networks, which selectively removes individual weights based on their importance (typically by magnitude). This fine-grained pruning method allows for more flexible reduction of model parameters without altering the network's structure.

**Key Features:**
- **L1-Norm Unstructured Pruning:** Prunes weights with the smallest absolute values using **L1-norm** as the criterion, targeting less significant connections.
- **Gradual Pruning:** Supports gradual pruning across multiple steps, incrementally increasing sparsity to reduce the risk of accuracy degradation.
- **Sparsity Tracking:** Calculates and tracks the sparsity (percentage of zero weights) for each pruned layer, providing detailed pruning statistics.
- **Permanent Pruning:** Removes the pruning reparametrization, making the weight removals permanent and reducing computational overhead.
- **Pruning Summary Reporting:** Provides comprehensive statistics, including total parameters, remaining parameters, and overall sparsity.

**Benefits of Unstructured Pruning:**
- **Fine-Grained Control:** Allows selective pruning at the individual weight level, providing greater flexibility compared to structured pruning.
- **Model Compression:** Reduces the number of active parameters, leading to smaller models and faster inference times.
- **Performance Optimization:** Can improve model efficiency, especially on hardware that supports sparse matrix operations.


In [None]:
class UnstructuredPruner:
    """Magnitude-based (L1) unstructured pruning of Conv2d and Linear weights.

    Attributes:
        model: The model being pruned in place.
        pruning_stats: Per-layer dicts with ``'sparsity'`` (percent of zero
            weights), ``'total_params'`` and ``'remaining_params'``, keyed by
            module name and refreshed by ``update_pruning_stats``.
    """

    def __init__(self, model):
        self.model = model
        self.pruning_stats = {}

    def _prunable_modules(self):
        """Yield ``(name, module)`` for every Conv2d and Linear layer."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                yield name, module

    def calculate_sparsity(self, module):
        """Return the percentage of exactly-zero entries in ``module.weight``."""
        total_params = module.weight.nelement()
        if total_params == 0:
            return 0.0
        zero_params = torch.sum(module.weight == 0).item()
        return (zero_params / total_params) * 100

    def apply_gradual_pruning(self, initial_amount=0.1, final_amount=0.3, steps=3):
        """Prune in ``steps`` rounds so the sparsity ends at ``final_amount``.

        The cumulative target rises linearly from ``initial_amount`` in the
        first round to ``final_amount`` in the last one (a single round goes
        straight to ``final_amount``). Repeated calls to
        ``prune.l1_unstructured`` act only on the weights that are still
        unpruned, so each round converts its cumulative target into the
        matching fraction of the remaining weights. Targets are relative to
        the weights that are unpruned when this method is called.

        Args:
            initial_amount: Fraction of weights pruned after the first round.
            final_amount: Fraction of weights pruned after the last round.
            steps: Number of pruning rounds (at least 1).

        Raises:
            ValueError: If ``steps < 1`` or the amounts do not satisfy
                ``0 <= initial_amount <= final_amount <= 1``.
        """
        if steps < 1:
            raise ValueError("steps must be at least 1")
        if not 0.0 <= initial_amount <= final_amount <= 1.0:
            raise ValueError("amounts must satisfy 0 <= initial_amount <= final_amount <= 1")

        reached = 0.0  # cumulative fraction pruned so far by this call
        for step in range(steps):
            if steps == 1:
                target = final_amount
            else:
                target = initial_amount + (final_amount - initial_amount) * step / (steps - 1)

            print(f"\nApplying pruning step {step + 1}/{steps} (target sparsity: {target:.2f})")

            # Convert the cumulative target into a fraction of what is left.
            remaining = 1.0 - reached
            step_amount = (target - reached) / remaining if remaining > 0 else 0.0
            step_amount = min(max(step_amount, 0.0), 1.0)

            if step_amount > 0:
                for _, module in self._prunable_modules():
                    prune.l1_unstructured(module, name='weight', amount=step_amount)

            reached = target
            self.update_pruning_stats()

    def update_pruning_stats(self):
        """Recompute the per-layer statistics from the current weights."""
        for name, module in self._prunable_modules():
            self.pruning_stats[name] = {
                'sparsity': self.calculate_sparsity(module),
                'total_params': module.weight.nelement(),
                'remaining_params': torch.sum(module.weight != 0).item()
            }

    def make_permanent(self):
        """Remove pruning reparametrization and make pruning permanent"""
        for _, module in self._prunable_modules():
            # prune.remove raises for modules that carry no pruning hook, so
            # skip layers that were never pruned or were already finalized.
            if prune.is_pruned(module):
                prune.remove(module, 'weight')

    def get_pruning_summary(self):
        """Aggregate the recorded per-layer statistics.

        Returns:
            dict with ``'total_params'``, ``'remaining_params'``,
            ``'overall_sparsity'`` (percent; ``0.0`` when nothing has been
            recorded) and ``'layer_stats'``.
        """
        total_params = 0
        remaining_params = 0

        for stats in self.pruning_stats.values():
            total_params += stats['total_params']
            remaining_params += stats['remaining_params']

        if total_params:
            overall_sparsity = ((total_params - remaining_params) / total_params) * 100
        else:
            overall_sparsity = 0.0

        return {
            'total_params': total_params,
            'remaining_params': remaining_params,
            'overall_sparsity': overall_sparsity,
            'layer_stats': self.pruning_stats
        }

    def print_stats(self):
        """Print overall and per-layer statistics and return the headline figures.

        Returns:
            dict with the formatted ``'parameter_reduction'`` and
            ``'overall_sparsity'`` percentages.
        """
        pruning_summary = self.get_pruning_summary()
        print(f"Total parameters: {pruning_summary['total_params']:,}")
        print(f"Remaining parameters: {pruning_summary['remaining_params']:,}")
        # The parameter reduction is the same ratio as the overall sparsity.
        pr = f"{pruning_summary['overall_sparsity']:.2f}%"
        print(f"Parameters Reduction: {pr}")
        print(f"Overall sparsity: {pruning_summary['overall_sparsity']:.2f}%")

        for feature, stats in self.pruning_stats.items():
            print(feature , f": Total params: {stats['total_params']},  Remaining params: {stats['remaining_params']}, Sparsity: {stats['sparsity']}")

        return {
            'parameter_reduction': pr,
            'overall_sparsity': f"{pruning_summary['overall_sparsity']:.2f}%"
        }

### VGG19 Model with Unstructured Pruning

This function initializes a **VGG19** model with pretrained weights, modifies the final classification layer, and applies **unstructured pruning** to reduce model complexity. Unstructured pruning removes individual weights with the least importance, leading to a sparse model that maintains accuracy while reducing the number of active parameters.

**Key Features:**
- **Pretrained VGG19 Backbone:** Loads a VGG19 model pre-trained on ImageNet and modifies the final classification layer for the specified number of classes.
- **Unstructured Pruning:** Applies gradual unstructured pruning using the `UnstructuredPruner` class, removing individual weights based on their L1-norm.
- **Gradual Pruning Steps:** Prunes the model incrementally across multiple steps, reducing the risk of sudden performance degradation.
- **Permanent Pruning:** Removes the pruning reparametrization, making the weight removals permanent and optimizing inference performance.

**Benefits of Unstructured Pruning:**
- **Fine-Grained Compression:** Provides flexible parameter reduction without altering the architecture of the network.
- **Inference Efficiency:** Reduces the number of computations, which can speed up inference on hardware that supports sparse matrix operations.
- **Energy and Memory Savings:** Reduces the computational load and memory footprint of the model, making it ideal for deployment on resource-constrained devices.


In [None]:
def get_model_with_unstructured_pruning(classes=10, prune_amount=0.3):
    """Build the VGG19 classifier and apply unstructured pruning to it.

    The new output layer is re-initialized (Kaiming-normal weights, zero
    bias) before the model is pruned in three gradual steps and the pruning
    is made permanent.

    Args:
        classes: Number of output classes passed to ``get_model``.
        prune_amount: Passed to the pruner as ``final_amount``.

    Returns:
        tuple: ``(model, pruner)``.
    """
    # Pretrained backbone with a fresh output layer
    net = get_model(classes)

    # Explicitly initialize the replaced output layer
    head = net.classifier[6]
    nn.init.kaiming_normal_(head.weight)
    nn.init.constant_(head.bias, 0)

    # Prune gradually, then strip the pruning reparametrization
    weight_pruner = UnstructuredPruner(net)
    weight_pruner.apply_gradual_pruning(initial_amount=0.1, final_amount=prune_amount, steps=3)
    weight_pruner.make_permanent()

    return net, weight_pruner

In [None]:
def run_unstructured_pruning(num_epochs, batch_size, num_samples):
    """Prune VGG19 weight-wise, fine-tune it and report its final metrics.

    Uses the module-level ``device``, ``train_loader`` and ``val_loader``.
    Training runs for at most ``num_epochs`` epochs with SGD, reduces the
    learning rate when the validation loss plateaus, and stops early via
    ``EarlyStopping(patience=5, min_delta=0.001)``. The model taken from the
    early-stopping helper is scripted with TorchScript, saved to
    ``usp_cifar10.pt`` and evaluated once more.

    Args:
        num_epochs: Maximum number of training epochs.
        batch_size: Accepted for a uniform call signature; not read here.
        num_samples: Accepted for a uniform call signature; not read here.

    Returns:
        tuple: ``(model, result)`` where ``result`` is the dict produced by
        ``print_and_get_metrics``, including the pruning summary
        (``'parameter_reduction'`` and ``'overall_sparsity'``).
    """
    model, pruner = get_model_with_unstructured_pruning()
    model = model.to(device)

    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)
    flops = count_flops(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # NOTE(review): `verbose` is reported as deprecated by recent PyTorch
    # releases - confirm against the installed version.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

    # NOTE(review): these statistics are taken before fine-tuning; the
    # training below updates every weight without a mask, so the reported
    # sparsity presumably does not persist in the trained model - confirm.
    print("\nPruning Statistics:")
    pruning_summary = pruner.print_stats()

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    # Training loop
    print("\nStarting training unstructure pruned model...")
    epochs_performed = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        eval_metrics = evaluate(model, val_loader, criterion, device)
        val_losses.append(eval_metrics['loss'])
        val_accs.append(eval_metrics['acc'])

        # Learning rate scheduling
        scheduler.step(eval_metrics['loss'])
        epochs_performed += 1
        # Early stopping logic
        early_stopping(eval_metrics['loss'], model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    plot_metrics(train_losses, train_accs, val_losses, val_accs)

    model = early_stopping.best_model
    scripted_model = torch.jit.script(model)
    scripted_model.save("usp_cifar10.pt")

    print("\nPerforming final evaluation for Unstructured cifar10...")
    gpu_metrics = evaluate(model, val_loader, criterion, device)

    # Merge the pruning summary into the result, as the structured run does
    # with its pruner statistics.
    result = print_and_get_metrics(
        "Unstructured Pruning - CIFAR-10",
        epochs_performed, gpu_metrics, model_size, flops,
        pruning_summary,
    )

    gc.collect()
    torch.cuda.empty_cache()
    return model, result

In [None]:
# Run the unstructured-pruning experiment with the shared configuration
usp_cifar10, usp_result = run_unstructured_pruning(
    num_epochs=num_epochs, batch_size=batch_size, num_samples=num_samples
)

In [None]:
def comparison_diagram_extended(baseline_model, pruned_model, layer_name, pruning_type):
    """Show a 1x4 figure comparing one parameter tensor before and after pruning.

    Panels, left to right: overlaid weight histograms, L2 norm, L1 norm and
    element count of the tensor in each model.

    Args:
        baseline_model: Module whose ``named_parameters()`` contains ``layer_name``.
        pruned_model: Module whose ``named_parameters()`` contains ``layer_name``.
        layer_name: Parameter name used as the lookup key in both models.
        pruning_type: Label interpolated into the figure title.

    Raises:
        KeyError: If ``layer_name`` is missing from either model.
    """
    def _flat_weights(model):
        # Look the tensor up by name and flatten it to a 1-D NumPy array.
        return dict(model.named_parameters())[layer_name].detach().cpu().numpy().flatten()

    # Insertion order (baseline first) fixes both lookup order and bar order.
    weights = {
        'Baseline': _flat_weights(baseline_model),
        'Pruned': _flat_weights(pruned_model),
    }
    labels = list(weights)
    colors = ['blue', 'green']

    # (title, y-label, bar heights) for the three bar-chart panels.
    bar_panels = (
        ('L2 Norm', 'L2 Norm', [np.linalg.norm(w) for w in weights.values()]),
        ('L1 Norm', 'L1 Norm', [np.sum(np.abs(w)) for w in weights.values()]),
        ('Parameter Count', 'Count', [w.size for w in weights.values()]),
    )

    fig, axs = plt.subplots(1, 4, figsize=(22, 5))

    # Panel 1: both weight distributions drawn on the same axes.
    hist_ax = axs[0]
    for label, color in zip(labels, colors):
        hist_ax.hist(weights[label], bins=30, alpha=0.6, label=label, color=color)
    hist_ax.set_title('Weight Distribution')
    hist_ax.set_xlabel('Weight values')
    hist_ax.set_ylabel('Frequency')
    hist_ax.legend()

    # Panels 2-4: one baseline-vs-pruned bar pair per metric.
    for ax, (title, ylabel, heights) in zip(axs[1:], bar_panels):
        ax.bar(labels, heights, color=colors)
        ax.set_title(title)
        ax.set_ylabel(ylabel)

    plt.suptitle(f'Layer Comparison of {pruning_type} : {layer_name}', fontsize=16)
    # Leave the top 5% of the canvas free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

def find_max_improved_layer(baseline_model, pruned_model, pruning_type):
    """Find the weight tensor whose magnitude dropped the most after pruning.

    For every baseline parameter whose name contains ``'weight'`` and that also
    exists in ``pruned_model``, the relative reduction of the L2 and L1 norms is
    computed. For structured pruning the relative reduction in element count is
    averaged in as a third term.

    Args:
        baseline_model: Reference module (iterated via ``named_parameters()``).
        pruned_model: Pruned module; parameters are matched by name.
        pruning_type: Either the bare keyword (``"structured"``) or a display
            label such as ``'CIFAR-10 - Structured'``. Any label containing
            "structured" but not "unstructured" (case-insensitive) selects the
            structured formula.

    Returns:
        ``(layer_name, composite_score)``; ``(None, -inf)`` when no parameter
        could be compared.
    """
    # Build the name -> parameter lookup once instead of once per baseline layer.
    pruned_params = dict(pruned_model.named_parameters())

    # Callers pass figure labels, so an exact == "structured" test never matched
    # and the parameter-count term was silently skipped.
    kind = str(pruning_type).lower()
    is_structured = "structured" in kind and "unstructured" not in kind

    max_composite = -np.inf
    max_layer = None
    for name, param in baseline_model.named_parameters():
        if 'weight' not in name or name not in pruned_params:
            continue

        base_w = param.detach().cpu().numpy().flatten()
        pruned_w = pruned_params[name].detach().cpu().numpy().flatten()

        l2_base = np.linalg.norm(base_w)
        l1_base = np.sum(np.abs(base_w))

        # Avoid division by zero. An empty tensor also has zero norm, so this
        # guard covers the element-count denominator below as well.
        if l2_base == 0 or l1_base == 0:
            continue

        l2_improve = (l2_base - np.linalg.norm(pruned_w)) / l2_base
        l1_improve = (l1_base - np.sum(np.abs(pruned_w))) / l1_base

        if is_structured:
            param_improve = (base_w.size - pruned_w.size) / base_w.size
            composite = (l2_improve + l1_improve + param_improve) / 3.0
        else:
            composite = (l2_improve + l1_improve) / 2.0

        if composite > max_composite:
            max_composite = composite
            max_layer = name

    return max_layer, max_composite

def comparison_diagram_max_improved_layer(baseline_model, pruned_model, pruning_type):
    """Locate the layer with the highest composite score and plot its comparison.

    Delegates the search to ``find_max_improved_layer`` and the drawing to
    ``comparison_diagram_extended``; prints a message instead of plotting when
    the search returns no layer.
    """
    max_layer, composite = find_max_improved_layer(baseline_model, pruned_model, pruning_type)

    # Nothing comparable was found: report it and skip the figure.
    if max_layer is None:
        print("No layer found with improvement.")
        return

    print(f"Maximum improved layer: {max_layer}\nComposite Improvement: {composite:.3f}")
    comparison_diagram_extended(baseline_model, pruned_model, max_layer, pruning_type)

In [None]:
# Plot the most-changed layer for each pruned model against the baseline.
# The third argument is a label that ends up in the figure title
# (the second label previously read 'Utructured').
comparison_diagram_max_improved_layer(baseline_cifar10, sp_cifar10,  'CIFAR-10 - Structured')
comparison_diagram_max_improved_layer(baseline_cifar10, usp_cifar10, 'CIFAR-10 - Unstructured')

# Quantization

In [None]:
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.fake_quantize import FakeQuantize

In [None]:
class QuantizedVGG19(nn.Module):
    """VGG19 wrapper that brackets the network with quantization stubs.

    The feature extractor, pooling layer and classifier come from
    ``get_model(num_classes)``; the final classifier layer is replaced by a
    freshly initialised ``nn.Linear(4096, num_classes)``.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        backbone = get_model(num_classes)

        # Reuse the backbone's sub-networks directly.
        self.features = backbone.features
        self.avgpool = backbone.avgpool
        self.classifier = backbone.classifier

        # Swap in a new output layer at classifier index 6.
        head = nn.Linear(4096, num_classes)
        self.classifier[6] = head

        # Stubs marking where tensors enter and leave the quantized region.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

        # Initialise the new output layer: Kaiming-normal weights, zero bias.
        nn.init.kaiming_normal_(head.weight)
        nn.init.constant_(head.bias, 0)

    def forward(self, x):
        """Run quant stub -> features -> avgpool -> flatten -> classifier -> dequant stub."""
        x = self.quant(x)
        x = torch.flatten(self.avgpool(self.features(x)), 1)
        return self.dequant(self.classifier(x))

    def fuse_model(self):
        """
        Fuse every adjacent Conv2d -> ReLU pair inside ``self.features`` in place.

        Only pairs where a ``nn.Conv2d`` is immediately followed by a ``nn.ReLU``
        are collected, so combinations such as ReLU + MaxPool are never fused.
        """
        layers = list(self.features)
        # Children are addressed by their positional index rendered as a string.
        fuse_list = [
            [str(idx), str(idx + 1)]
            for idx, (current, following) in enumerate(zip(layers, layers[1:]))
            if isinstance(current, nn.Conv2d) and isinstance(following, nn.ReLU)
        ]

        torch.quantization.fuse_modules(self.features, fuse_list, inplace=True)

## 8 bit Quantization

In [None]:
def prepare_model_for_quantization(model, device='cpu', bit=8):
    """Attach a QAT qconfig to ``model``, fuse it and insert fake-quant modules.

    The model is modified in place: it is switched to training mode, given a
    ``qconfig``, fused through its own ``fuse_model()`` method, passed through
    ``prepare_qat`` and finally moved to ``device``.

    Args:
        model: Module exposing a ``fuse_model()`` method.
        device: Target for ``model.to``. For 8-bit, the exact value ``'cpu'``
            selects the ``'fbgemm'`` default qconfig and anything else selects
            ``'qnnpack'``.
        bit: ``8`` uses PyTorch's default QAT qconfig; ``4`` uses fake-quantize
            modules restricted to 16 levels (activations 0..15, weights -8..7).

    Returns:
        The same ``model`` object, prepared for quantization-aware training.

    Raises:
        ValueError: If ``bit`` is neither 4 nor 8. Previously such a value left
            the model without a qconfig and it was returned without any
            fake-quant modules.
    """
    if bit not in (4, 8):
        raise ValueError(f"Unsupported bit width {bit!r}; expected 4 or 8")

    model.train()

    if bit == 4:
        model.qconfig = QConfig(
            activation=FakeQuantize.with_args(
                observer=MinMaxObserver,
                quant_min=0, quant_max=15,  # 4-bit unsigned range
                dtype=torch.quint8,
                qscheme=torch.per_tensor_affine
            ),
            weight=FakeQuantize.with_args(
                observer=MinMaxObserver,
                quant_min=-8, quant_max=7,  # 4-bit signed range (-8 to 7)
                dtype=torch.qint8,
                qscheme=torch.per_tensor_affine
            )
        )
    else:
        # NOTE(review): any non-'cpu' device (e.g. a CUDA device) lands in the
        # 'qnnpack' branch, which the original comment labelled "ARM CPU" --
        # confirm that this selection is intended.
        backend = 'fbgemm' if device == 'cpu' else 'qnnpack'
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)

    # Fusion must happen before prepare_qat swaps in the QAT module variants.
    model.fuse_model()

    torch.quantization.prepare_qat(model, inplace=True)

    model.to(device)

    return model

In [None]:
def run_8bit_quantization(num_epochs, batch_size, num_samples):
    """Train a QuantizedVGG19 with 8-bit QAT, convert it, save it and evaluate it.

    Steps: build the model, record its size with ``get_model_file_size``,
    prepare it for 8-bit QAT on CPU, train for at most ``num_epochs`` epochs
    with early stopping on validation loss, convert the model held in
    ``early_stopping.best_model`` to a quantized model, save a TorchScript copy
    to ``qat_8bit_cifar10.pt`` and evaluate the converted model on CPU.

    Args:
        num_epochs: Upper bound on the number of training epochs.
        batch_size: Only used for the batch dimension of the summary input size.
        num_samples: Not referenced in the body.

    Returns:
        ``(quantized_model, result)`` where ``result`` is whatever
        ``print_and_get_metrics`` returns.

    Note:
        ``train_loader``, ``val_loader`` and ``device`` are neither parameters
        nor locals; they are read from the enclosing scope -- presumably
        notebook globals, verify they are defined before calling.
    """

    model = QuantizedVGG19(num_classes=10)
    # Size is measured on the float model, before QAT preparation.
    original_size = get_model_file_size(model)
    model = prepare_model_for_quantization(model, device='cpu', bit=8)

    print("\nModel Summary:")
    print(summary(model, input_size=(batch_size, 3, 224, 224)))
    # FLOPs are counted on the prepared (fake-quantized) model.
    flops = count_flops(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001,  momentum=0.9)

    # Per-epoch history, consumed by plot_metrics after the loop.
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    print("\nStarting training with 8 bit quantized model...")
    epochs_performed = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        # Training phase
        # NOTE(review): the model was prepared with device='cpu' but is trained
        # with the outer `device` -- confirm train_epoch handles the placement.
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase
        eval_metrics = evaluate(model, val_loader, criterion, device)
        val_losses.append(eval_metrics['loss'])
        val_accs.append(eval_metrics['acc'])
        epochs_performed = epochs_performed + 1

        # Early stopping logic (driven by validation loss)
        early_stopping(eval_metrics['loss'], model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    plot_metrics(train_losses, train_accs, val_losses, val_accs)
    # Continue with the model tracked by EarlyStopping rather than the last-epoch
    # one; presumably the lowest-validation-loss snapshot -- EarlyStopping is
    # defined outside this chunk, verify.
    model = early_stopping.best_model
    print("\nConverting to quantized model...")
    # Conversion is done in eval mode on CPU and returns a new module.
    quantized_model = torch.quantization.convert(model.eval().cpu(), inplace=False)
    scripted_model = torch.jit.script(quantized_model)
    scripted_model.save("qat_8bit_cifar10.pt")

    quantized_size = get_model_file_size(quantized_model)

    # The converted model is evaluated on CPU regardless of `device`.
    cpu_metrics = evaluate(quantized_model, val_loader, criterion, 'cpu')

    # Relative size saving versus the float model, formatted as a percentage string.
    size_reduction = f"{(original_size - quantized_size) / original_size * 100:.2f}%"

    result = print_and_get_metrics(
        "Quantization 8bit - CIFAR-10",
        epochs_performed, cpu_metrics, quantized_size, flops,
        size_reduction=size_reduction
    )

    return quantized_model, result

In [None]:
# Run the 8-bit QAT experiment; keeps the converted model and its metrics.
# num_epochs, batch_size and num_samples are not defined in this cell --
# presumably notebook globals set in an earlier cell; verify.
qat_8bit_cifar10, qat_8bit_result = run_8bit_quantization(num_epochs, batch_size, num_samples)

## 4bit Quantization

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert

In [None]:
def run_4bit_quantization(num_epochs, batch_size, num_samples):
    """Warm up a QuantizedVGG19 in float, then train it with 4-bit QAT and convert it.

    Steps: five float warm-up epochs, QAT preparation with ``bit=4``, up to
    ``num_epochs`` QAT epochs with a StepLR schedule and early stopping on
    validation loss, conversion of ``early_stopping.best_model``, a TorchScript
    save to ``qat_4bit_cifar10.pt`` and a final CPU evaluation.

    Args:
        num_epochs: Upper bound on the number of QAT epochs (warm-up excluded).
        batch_size: Only used for the batch dimension of the summary input size.
        num_samples: Not referenced in the body.

    Returns:
        ``(quantized_model, result)`` where ``result`` is whatever
        ``print_and_get_metrics`` returns.

    Note:
        ``train_loader``, ``val_loader`` and ``device`` are neither parameters
        nor locals; they are read from the enclosing scope -- presumably
        notebook globals, verify they are defined before calling.
    """
    model = QuantizedVGG19(num_classes=10).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)

    # Float warm-up: a fixed five epochs before any fake-quant modules exist.
    for epoch in range(5):
        print(f"[Warm-up Epoch {epoch + 1}/5]")
        train_epoch(model, train_loader, criterion, optimizer, device)
        eval_metrics = evaluate(model, val_loader, criterion, device)
        print(f"Validation Accuracy: {eval_metrics['acc']:.4f}")

    # Size is measured on the warmed-up float model, before QAT preparation.
    original_size = get_model_file_size(model)
    # NOTE(review): the optimizer above was created before this call replaces
    # modules -- confirm it still tracks the prepared model's parameters.
    model = prepare_model_for_quantization(model, device, bit=4)
    # Observers start disabled and are switched on again at loop index 3 below.
    # NOTE(review): QAT recipes usually collect statistics first and freeze the
    # observers later -- confirm this reversed order is intentional.
    model.apply(torch.quantization.disable_observer)

    print("\nModel Summary:")
    print(summary(model, input_size=(batch_size, 3, 224, 224)))
    flops = count_flops(model)

    # Per-epoch history, consumed by plot_metrics after the loop.
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    best_acc = 0.0

    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    print("\nStarting training with 4bit quantized model...")
    epochs_performed = 0
    # Multiply the learning rate by 0.1 every 10 scheduler steps.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Re-enable observers from the fourth QAT epoch on; with num_epochs <= 3
        # this branch is never reached and observers stay disabled.
        if epoch == 3:
            model.apply(torch.quantization.enable_observer)

        # Training phase
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation phase
        eval_metrics = evaluate(model, val_loader, criterion, device)
        val_losses.append(eval_metrics['loss'])
        val_accs.append(eval_metrics['acc'])

        scheduler.step()

        epochs_performed = epochs_performed + 1

        # Only the best accuracy value is tracked here; despite the message,
        # nothing is written to disk in this branch.
        if eval_metrics['acc'] > best_acc:
            best_acc = eval_metrics['acc']
            print(f"New best model saved with accuracy: {best_acc:.4f}")

        # Early stopping logic (driven by validation loss, not accuracy)
        early_stopping(eval_metrics['loss'], model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    plot_metrics(train_losses, train_accs, val_losses, val_accs)

    # Continue with the model tracked by EarlyStopping; presumably the
    # lowest-validation-loss snapshot -- EarlyStopping is defined outside this
    # chunk, verify.
    model = early_stopping.best_model
    # Conversion is done in eval mode on CPU and returns a new module.
    quantized_model = torch.quantization.convert(model.eval().cpu(), inplace=False)
    scripted_model = torch.jit.script(quantized_model)
    scripted_model.save("qat_4bit_cifar10.pt")

    quantized_size = get_model_file_size(quantized_model)

    # The converted model is evaluated on CPU regardless of `device`.
    cpu_metrics = evaluate(quantized_model, val_loader, criterion, 'cpu')

    # Relative size saving versus the float model, formatted as a percentage string.
    size_reduction = f"{(original_size - quantized_size) / original_size * 100:.2f}%"

    result = print_and_get_metrics(
        "Quantization 4bit - CIFAR-10",
        epochs_performed, cpu_metrics, quantized_size, flops,
        size_reduction=size_reduction
    )

    return quantized_model, result

In [None]:
# Run the 4-bit QAT experiment; keeps the converted model and its metrics.
# num_epochs, batch_size and num_samples are not defined in this cell --
# presumably notebook globals set in an earlier cell; verify.
qat_4bit_cifar10, qat_4bit_result = run_4bit_quantization(num_epochs, batch_size, num_samples)

In [None]:
def get_quantized_weights_mapping(model):
    """Map each quantized Conv2d/Linear module name to its dequantized weight.

    Walks ``model.named_modules()`` and keeps only modules that are instances
    of ``torch.nn.quantized.Conv2d`` or ``torch.nn.quantized.Linear``. Weights
    are dequantized so they can be compared as floating-point values.

    Returns:
        dict of module name -> dequantized weight tensor (empty when the model
        contains no such modules).
    """
    import torch.nn.quantized as nnq

    quantized_layer_types = (nnq.Conv2d, nnq.Linear)
    return {
        layer_name: layer.weight().dequantize()
        for layer_name, layer in model.named_modules()
        if isinstance(layer, quantized_layer_types)
    }

def find_max_changed_layer_mapping(baseline_model, quantized_weights_map):
    """Find the baseline weight whose norms changed the most after quantization.

    For each baseline parameter whose name contains ``"weight"``, the matching
    entry in ``quantized_weights_map`` is looked up under the parameter name
    with its trailing ``".weight"`` removed. The score of a layer is the mean
    of the relative absolute changes of its L2 and L1 norms.

    Args:
        baseline_model: Float reference module.
        quantized_weights_map: dict of module name -> dequantized weight tensor.

    Returns:
        ``(parameter_name, composite_change)`` for the highest-scoring layer, or
        ``(None, -inf)`` when no layer could be compared. Layers missing from
        the mapping are reported with a printed message and skipped.
    """
    max_composite = -np.inf
    max_layer = None
    for name, param in baseline_model.named_parameters():
        if "weight" not in name:
            continue

        # Mapping keys are module names, i.e. the parameter name without '.weight'.
        quant_name = name[:-7] if name.endswith(".weight") else name

        if quant_name not in quantized_weights_map:
            print(f"Layer {name} not found in quantized weights mapping")
            continue

        baseline_weights = param.detach().cpu().numpy().flatten()
        quantized_weights = quantized_weights_map[quant_name].detach().cpu().numpy().flatten()

        l2_baseline = np.linalg.norm(baseline_weights)
        l1_baseline = np.sum(np.abs(baseline_weights))

        # Skip layers with zero norm to avoid division by zero.
        if l2_baseline == 0 or l1_baseline == 0:
            continue

        l2_change = np.abs(l2_baseline - np.linalg.norm(quantized_weights)) / l2_baseline
        l1_change = np.abs(l1_baseline - np.sum(np.abs(quantized_weights))) / l1_baseline

        composite = (l2_change + l1_change) / 2.0

        if composite > max_composite:
            max_composite = composite
            max_layer = name
    return max_layer, max_composite

def comparison_diagram_extended_quantization_mapping(baseline_model, quantized_weights_map, layer_name, quantization_bits):
    """Show a 1x5 figure comparing one layer's float and quantized weights.

    Panels, left to right: baseline histogram, quantized histogram, L2 norm
    bars, L1 norm bars, and a histogram of the element-wise difference
    ``baseline - quantized``.

    Args:
        baseline_model: Module whose ``named_parameters()`` contains ``layer_name``.
        quantized_weights_map: dict of module name -> dequantized weight tensor.
        layer_name: Baseline parameter name; a trailing ``".weight"`` is removed
            to form the mapping key.
        quantization_bits: Value interpolated into labels and titles.

    Raises:
        KeyError: If the layer is missing from the model or from the mapping.
    """
    baseline_weights = dict(baseline_model.named_parameters())[layer_name].detach().cpu().numpy().flatten()

    # Mapping keys are module names, so drop the '.weight' suffix if present.
    map_key = layer_name[:-7] if layer_name.endswith(".weight") else layer_name
    quantized_weights = quantized_weights_map[map_key].detach().cpu().numpy().flatten()

    # All numeric work happens before the figure is created.
    l2_pair = [np.linalg.norm(baseline_weights), np.linalg.norm(quantized_weights)]
    l1_pair = [np.sum(np.abs(baseline_weights)), np.sum(np.abs(quantized_weights))]
    quant_error = baseline_weights - quantized_weights

    fig, axs = plt.subplots(1, 5, figsize=(30, 5))

    # (axes, data, colour, title, x-label) for the three histogram panels.
    histogram_panels = (
        (axs[0], baseline_weights, 'blue', 'Baseline Weight Distribution', 'Weight values'),
        (axs[1], quantized_weights, 'orange', f'{quantization_bits}-bit Quantized Distribution', 'Weight values'),
        (axs[4], quant_error, 'green', 'Quantization Error Distribution', 'Error value'),
    )
    for ax, data, color, title, xlabel in histogram_panels:
        ax.hist(data, bins=30, alpha=0.7, color=color)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    # (axes, title / y-label, bar heights) for the two norm panels.
    bar_labels = ['Baseline', f'{quantization_bits}-bit']
    for ax, metric, heights in ((axs[2], 'L2 Norm', l2_pair), (axs[3], 'L1 Norm', l1_pair)):
        ax.bar(bar_labels, heights, color=['blue', 'orange'])
        ax.set_title(metric)
        ax.set_ylabel(metric)

    plt.suptitle(f'Layer {layer_name} Quantization Visualization: {quantization_bits}-bit', fontsize=16)
    # Leave the top 5% of the canvas free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

def get_overall_improvement(baseline_model, quantized_weights_map):
    """Average, over all comparable layers, how closely quantized norms match the baseline.

    For each baseline parameter whose name contains ``"weight"`` and whose module
    name is a key of ``quantized_weights_map``, the relative absolute change of
    the L2 and L1 norms is turned into a score ``(1 - change) * 100``. The three
    averages (L2, L1, composite) are printed and returned.

    Returns:
        ``(avg_l2, avg_l1, avg_composite)``; ``(0, 0, 0)`` when no layer could
        be compared.
    """
    l2_scores, l1_scores, composite_scores = [], [], []

    for name, param in baseline_model.named_parameters():
        if "weight" not in name:
            continue

        # Mapping keys are module names, i.e. the parameter name without '.weight'.
        quant_name = name[:-7] if name.endswith(".weight") else name

        if quant_name not in quantized_weights_map:
            print(f"Layer {name} not found in quantized weights mapping")
            continue

        reference = param.detach().cpu().numpy().flatten()
        quantized = quantized_weights_map[quant_name].detach().cpu().numpy().flatten()

        l2_reference = np.linalg.norm(reference)
        l2_quantized = np.linalg.norm(quantized)
        l1_reference = np.sum(np.abs(reference))
        l1_quantized = np.sum(np.abs(quantized))

        # A zero baseline norm would make the relative change undefined.
        if l2_reference == 0 or l1_reference == 0:
            continue

        l2_change = np.abs(l2_reference - l2_quantized) / l2_reference
        l1_change = np.abs(l1_reference - l1_quantized) / l1_reference
        composite_change = (l2_change + l1_change) / 2.0

        l2_scores.append((1 - l2_change) * 100)
        l1_scores.append((1 - l1_change) * 100)
        composite_scores.append((1 - composite_change) * 100)

    total_layers = len(composite_scores)
    if total_layers > 0:
        avg_l2_improvement = sum(l2_scores) / total_layers
        avg_l1_improvement = sum(l1_scores) / total_layers
        avg_composite_improvement = sum(composite_scores) / total_layers
    else:
        avg_l2_improvement, avg_l1_improvement, avg_composite_improvement = 0, 0, 0

    print("Overall Improvement Metrics:")
    print(f"  Average L2 Improvement: {avg_l2_improvement:.2f}%")
    print(f"  Average L1 Improvement: {avg_l1_improvement:.2f}%")
    print(f"  Average Composite Improvement: {avg_composite_improvement:.2f}%")

    return avg_l2_improvement, avg_l1_improvement, avg_composite_improvement

def comparison_diagram_max_changed_layer_mapping(baseline_model, quantized_weights_map, quantization_bits):
    """Report overall quantization metrics and plot the most-changed layer.

    Args:
        baseline_model: Float reference module.
        quantized_weights_map: dict of module name -> dequantized weight tensor.
        quantization_bits: Value interpolated into the figure labels.
    """
    # get_overall_improvement prints the averaged metrics itself; printing its
    # return values again here duplicated the whole report in the cell output.
    get_overall_improvement(baseline_model, quantized_weights_map)

    max_layer, composite = find_max_changed_layer_mapping(baseline_model, quantized_weights_map)
    if max_layer is None:
        print("No layer found with significant change.")
    else:
        print(f"Layer with maximum change: {max_layer}\nComposite Change: {composite:.3f}")
        comparison_diagram_extended_quantization_mapping(baseline_model, quantized_weights_map, max_layer, quantization_bits)


In [None]:
# Compare the baseline model's weights with the dequantized weights of the
# 8-bit model and plot the layer with the largest composite change.
quantized_weights_map = get_quantized_weights_mapping(qat_8bit_cifar10)
comparison_diagram_max_changed_layer_mapping(baseline_cifar10, quantized_weights_map, quantization_bits=8)

In [None]:
# Same comparison for the 4-bit model. This rebinds quantized_weights_map, so
# any mapping previously stored under that name is replaced.
quantized_weights_map = get_quantized_weights_mapping(qat_4bit_cifar10)
comparison_diagram_max_changed_layer_mapping(baseline_cifar10, quantized_weights_map, quantization_bits=4)

In [None]:
import pandas as pd

def flatten_dict(d, parent_key='', sep='_'):
    """Collapse a nested dict into a single level with joined keys.

    Nested keys are joined to their parent path with ``sep``. Top-level keys
    are kept unchanged when ``parent_key`` is empty. An empty nested dict
    contributes no entries.

    Args:
        d: Dictionary to flatten; values that are dicts are descended into.
        parent_key: Prefix prepended (with ``sep``) to every key.
        sep: Separator placed between path components.

    Returns:
        A new flat dict.
    """
    flat = {}
    for key, value in d.items():
        full_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        # Recurse with the accumulated path as the new prefix.
        flat.update(flatten_dict(value, full_key, sep=sep))
    return flat

# One flattened row per experiment, in the order the experiments were run.
flat_data = list(map(flatten_dict, (
    baseline_result,
    sp_result, usp_result,
    qat_8bit_result, qat_4bit_result,
)))

# Tabulate the rows; keys missing from a row become NaN in that row.
df = pd.DataFrame(flat_data)

# A bare expression as the last statement of a notebook cell renders the table.
df