In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import timm
import torch.ao.quantization.quantize_fx as quantize_fx
import time
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Tuple, Dict, List

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)


In [2]:
class PerformanceMetrics:
    """Class to track and compare model performance metrics"""
    
    def __init__(self):
        self.metrics = {}
    
    def add_model_metrics(self, model_name: str, accuracy: float, 
                         inference_time: float, model_size: float,
                         memory_usage: float = None):
        """Add metrics for a specific model"""
        self.metrics[model_name] = {
            'accuracy': accuracy,
            'inference_time': inference_time,
            'model_size': model_size,
            'memory_usage': memory_usage
        }
    
    def compare_models(self):
        """Generate comparison report"""
        if len(self.metrics) < 2:
            print("Need at least 2 models for comparison")
            return
        
        print("\n" + "="*80)
        print("MODEL PERFORMANCE COMPARISON")
        print("="*80)
        
        # Print tabular comparison
        print(f"{'Model':<15} {'Accuracy':<12} {'Inf. Time (ms)':<15} {'Size (MB)':<12} {'Memory (MB)':<12}")
        print("-" * 80)
        
        for name, metrics in self.metrics.items():
            memory_str = f"{metrics['memory_usage']:.2f}" if metrics['memory_usage'] else "N/A"
            print(f"{name:<15} {metrics['accuracy']:.4f}       "
                  f"{metrics['inference_time']*1000:.2f}           "
                  f"{metrics['model_size']:.2f}        {memory_str}")
        
        # Calculate improvements/degradations
        if 'FP32' in self.metrics and 'QAT_INT8' in self.metrics:
            fp32 = self.metrics['FP32']
            qat = self.metrics['QAT_INT8']
            
            acc_change = ((qat['accuracy'] - fp32['accuracy']) / fp32['accuracy']) * 100
            speed_improvement = ((fp32['inference_time'] - qat['inference_time']) / fp32['inference_time']) * 100
            size_reduction = ((fp32['model_size'] - qat['model_size']) / fp32['model_size']) * 100
            
            print("\n" + "="*50)
            print("QAT vs FP32 IMPROVEMENTS:")
            print("="*50)
            print(f"Accuracy change: {acc_change:+.2f}%")
            print(f"Speed improvement: {speed_improvement:+.2f}%")
            print(f"Model size reduction: {size_reduction:+.2f}%")
            
            if qat['memory_usage'] and fp32['memory_usage']:
                mem_reduction = ((fp32['memory_usage'] - qat['memory_usage']) / fp32['memory_usage']) * 100
                print(f"Memory usage reduction: {mem_reduction:+.2f}%")


In [3]:
def get_model_size(model):
    """Calculate model size in MB"""
    param_size = 0
    buffer_size = 0
    
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_all_mb = (param_size + buffer_size) / 1024**2
    return size_all_mb

def get_memory_usage():
    """Get current GPU memory usage in MB"""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024**2
    return None

def evaluate_model(model, data_loader, device, num_classes=10):
    """Evaluate model and return accuracy and detailed metrics"""
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    inference_times = []
    
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            
            # Measure inference time
            start_time = time.time()
            outputs = model(data)
            end_time = time.time()
            
            inference_times.append(end_time - start_time)
            
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    accuracy = correct / total
    avg_inference_time = np.mean(inference_times)
    
    return accuracy, avg_inference_time, all_preds, all_targets

def prepare_data(batch_size=64, num_workers=2):
    """Prepare CIFAR-100 dataset with appropriate transforms for ViT"""
    
    # Transforms for training and testing
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),  # ViT expects 224x224 images
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet normalization
    ])
    
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    
    # Load CIFAR-100 dataset
    train_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train
    )
    
    test_dataset = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    
    return train_loader, test_loader

def create_vit_model(num_classes=100):
    """Create and modify ViT Large model for CIFAR-100"""
    # Load pre-trained ViT Large model
    model = timm.create_model('vit_large_patch16_224', pretrained=True)
    
    # Modify the classifier for CIFAR-100 (100 classes)
    model.head = nn.Linear(model.head.in_features, num_classes)
    
    return model

def train_model(model, train_loader, criterion, optimizer, device, num_epochs=3):
    """Training function"""
    model.to(device)
    model.train()
    
    print(f"\nTraining on {device} for {num_epochs} epochs...")
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            if i % 100 == 99:  # Print every 100 batches
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], '
                      f'Loss: {running_loss/100:.4f}, '
                      f'Acc: {100*correct/total:.2f}%')
                running_loss = 0.0
        
        epoch_acc = 100 * correct / total
        print(f'Epoch [{epoch+1}/{num_epochs}] completed. Training Accuracy: {epoch_acc:.2f}%')

def plot_confusion_matrix(y_true, y_pred, class_names, model_name):
    """Plot confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.tight_layout()
    plt.show()


In [4]:
def main():
    """Main function to run the complete QAT pipeline"""
    
    print("Starting QAT ViT Large Training Pipeline on CIFAR-100...")
    print("="*70)
    
    # Configuration
    batch_size = 16   # Smaller batch size due to ViT Large memory requirements
    num_epochs = 2    # Reduced for demonstration
    learning_rate = 5e-6  # Lower learning rate for large model
    num_classes = 100
    
    # Prepare data
    print("Preparing CIFAR-100 dataset...")
    train_loader, test_loader = prepare_data(batch_size=batch_size)
    
    # CIFAR-100 class names (100 fine-grained classes)
    class_names = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
        'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
        'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
        'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
        'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
        'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
        'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
        'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
        'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
        'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
        'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
        'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman',
        'worm'
    ]
    
    # Initialize metrics tracker
    metrics_tracker = PerformanceMetrics()
    
    # Check device availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"ViT Model: Large (vit_large_patch16_224)")
    print(f"Dataset: CIFAR-100 ({num_classes} classes)")
    
    # ========================
    # PART 1: FP32 Model Training and Evaluation
    # ========================
    print("\n" + "="*70)
    print("PART 1: FP32 ViT LARGE MODEL TRAINING")
    print("="*70)
    
    # Create FP32 model
    model_fp32 = create_vit_model(num_classes)
    print(f"Model parameters: {sum(p.numel() for p in model_fp32.parameters())/1e6:.1f}M")
    
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer_fp32 = optim.AdamW(model_fp32.parameters(), lr=learning_rate, weight_decay=0.01)
    
    # Train FP32 model
    train_model(model_fp32, train_loader, criterion, optimizer_fp32, device, num_epochs)
    
    # Evaluate FP32 model
    print("\nEvaluating FP32 model...")
    model_fp32.eval()
    fp32_accuracy, fp32_inf_time, fp32_preds, fp32_targets = evaluate_model(
        model_fp32, test_loader, device, num_classes
    )
    
    # Get FP32 model metrics
    fp32_size = get_model_size(model_fp32)
    fp32_memory = get_memory_usage()
    
    metrics_tracker.add_model_metrics(
        'FP32', fp32_accuracy, fp32_inf_time, fp32_size, fp32_memory
    )
    
    print(f"FP32 ViT Large - Accuracy: {fp32_accuracy:.4f}, "
          f"Avg Inference Time: {fp32_inf_time*1000:.2f}ms, "
          f"Model Size: {fp32_size:.2f}MB")
    
    # ========================
    # PART 2: QAT Model Training
    # ========================
    print("\n" + "="*70)
    print("PART 2: QAT ViT LARGE MODEL TRAINING")
    print("="*70)
    
    # Move model to CPU for quantization preparation
    model_fp32.to("cpu")
    
    # Prepare example inputs for FX graph mode quantization
    example_inputs = (torch.randn(1, 3, 224, 224),)
    
    # Get QAT configuration
    qconfig_mapping = torch.ao.quantization.get_default_qat_qconfig_mapping("x86")
    
    # Prepare model for QAT
    print("Preparing ViT Large for QAT...")
    model_prepared = quantize_fx.prepare_qat_fx(model_fp32, qconfig_mapping, example_inputs)
    
    # Move back to training device
    model_prepared.to(device)
    
    # Create new optimizer for QAT training (even lower learning rate)
    optimizer_qat = optim.AdamW(model_prepared.parameters(), lr=learning_rate/10, weight_decay=0.01)
    
    # QAT Training (typically fewer epochs)
    print("Starting QAT fine-tuning for ViT Large...")
    train_model(model_prepared, train_loader, criterion, optimizer_qat, device, num_epochs)
    
    # ========================
    # PART 3: Convert to INT8 and Evaluate
    # ========================
    print("\n" + "="*70)
    print("PART 3: INT8 CONVERSION AND EVALUATION")
    print("="*70)
    
    # Convert to INT8 (must be done on CPU)
    model_prepared.to("cpu")
    model_prepared.eval()
    
    print("Converting ViT Large to INT8 model...")
    model_int8_qat = quantize_fx.convert_fx(model_prepared)
    
    # Evaluate INT8 model
    print("Evaluating INT8 QAT ViT Large model...")
    qat_accuracy, qat_inf_time, qat_preds, qat_targets = evaluate_model(
        model_int8_qat, test_loader, "cpu", num_classes
    )
    
    # Get QAT model metrics
    qat_size = get_model_size(model_int8_qat)
    qat_memory = get_memory_usage()
    
    metrics_tracker.add_model_metrics(
        'QAT_INT8', qat_accuracy, qat_inf_time, qat_size, qat_memory
    )
    
    print(f"QAT INT8 ViT Large - Accuracy: {qat_accuracy:.4f}, "
          f"Avg Inference Time: {qat_inf_time*1000:.2f}ms, "
          f"Model Size: {qat_size:.2f}MB")
    
    # ========================
    # PART 4: Performance Comparison and Analysis
    # ========================
    print("\n" + "="*70)
    print("PART 4: PERFORMANCE ANALYSIS - ViT LARGE on CIFAR-100")
    print("="*70)
    
    # Display comprehensive comparison
    metrics_tracker.compare_models()
    
    # Generate top-5 accuracy for CIFAR-100
    print("\n" + "="*50)
    print("TOP-5 ACCURACY ANALYSIS:")
    print("="*50)
    
    def calculate_top5_accuracy(model, data_loader, device):
        model.eval()
        correct_top5 = 0
        total = 0
        
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                _, pred = outputs.topk(5, 1, True, True)
                pred = pred.t()
                correct = pred.eq(target.view(1, -1).expand_as(pred))
                correct_top5 += correct[:5].reshape(-1).float().sum(0, keepdim=True).item()
                total += target.size(0)
        
        return correct_top5 / total
    
    # Calculate top-5 accuracy for both models
    model_fp32.to(device)
    fp32_top5 = calculate_top5_accuracy(model_fp32, test_loader, device)
    qat_top5 = calculate_top5_accuracy(model_int8_qat, test_loader, "cpu")
    
    print(f"FP32 ViT Large Top-5 Accuracy: {fp32_top5:.4f}")
    print(f"QAT INT8 ViT Large Top-5 Accuracy: {qat_top5:.4f}")
    print(f"Top-5 Accuracy Retention: {(qat_top5/fp32_top5)*100:.2f}%")
    
    # Generate classification reports (showing top 10 classes for brevity)
    print("\n" + "="*50)
    print("SAMPLE CLASSIFICATION METRICS (First 10 classes):")
    print("="*50)
    
    from sklearn.metrics import classification_report
    sample_classes = class_names[:10]
    
    # Filter predictions and targets for first 10 classes only
    fp32_sample_mask = [i for i, target in enumerate(fp32_targets) if target < 10]
    qat_sample_mask = [i for i, target in enumerate(qat_targets) if target < 10]
    
    if fp32_sample_mask and qat_sample_mask:
        fp32_sample_preds = [fp32_preds[i] for i in fp32_sample_mask]
        fp32_sample_targets = [fp32_targets[i] for i in fp32_sample_mask]
        qat_sample_preds = [qat_preds[i] for i in qat_sample_mask]
        qat_sample_targets = [qat_targets[i] for i in qat_sample_mask]
        
        print("\nFP32 ViT Large (Sample Classes):")
        print(classification_report(fp32_sample_targets, fp32_sample_preds, 
                                  target_names=sample_classes, labels=list(range(10))))
        
        print("\nQAT INT8 ViT Large (Sample Classes):")
        print(classification_report(qat_sample_targets, qat_sample_preds, 
                                  target_names=sample_classes, labels=list(range(10))))
    
    # Save models
    print("\nSaving models...")
    torch.save(model_fp32.state_dict(), 'vit_large_fp32_cifar100.pth')
    torch.save(model_int8_qat.state_dict(), 'vit_large_qat_int8_cifar100.pth')
    
    print("\nQAT ViT Large Training Pipeline on CIFAR-100 Completed!")
    print("Models saved successfully.")
    print(f"Final Model Size Reduction: {((fp32_size - qat_size) / fp32_size) * 100:.1f}%")
    
    return model_fp32, model_int8_qat, metrics_tracker

if __name__ == "__main__":
    # Run the complete pipeline
    fp32_model, qat_model, metrics = main()

Starting QAT ViT Large Training Pipeline on CIFAR-100...
Preparing CIFAR-100 dataset...


100%|██████████| 169M/169M [00:10<00:00, 15.9MB/s] 


Using device: cuda
ViT Model: Large (vit_large_patch16_224)
Dataset: CIFAR-100 (100 classes)

PART 1: FP32 ViT LARGE MODEL TRAINING


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

Model parameters: 303.4M

Training on cuda for 2 epochs...
Epoch [1/2], Step [100], Loss: 4.7780, Acc: 6.31%
Epoch [1/2], Step [200], Loss: 3.1825, Acc: 18.38%
Epoch [1/2], Step [300], Loss: 1.7653, Acc: 33.04%
Epoch [1/2], Step [400], Loss: 1.0600, Acc: 44.03%
Epoch [1/2], Step [500], Loss: 0.8330, Acc: 51.30%
Epoch [1/2], Step [600], Loss: 0.7020, Acc: 56.52%
Epoch [1/2], Step [700], Loss: 0.5989, Acc: 60.57%
Epoch [1/2], Step [800], Loss: 0.5546, Acc: 63.77%
Epoch [1/2], Step [900], Loss: 0.5199, Acc: 66.35%
Epoch [1/2], Step [1000], Loss: 0.4918, Acc: 68.45%
Epoch [1/2], Step [1100], Loss: 0.5095, Acc: 70.08%
Epoch [1/2], Step [1200], Loss: 0.4116, Acc: 71.71%
Epoch [1/2], Step [1300], Loss: 0.4674, Acc: 72.92%
Epoch [1/2], Step [1400], Loss: 0.4224, Acc: 74.02%
Epoch [1/2], Step [1500], Loss: 0.4000, Acc: 75.05%
Epoch [1/2], Step [1600], Loss: 0.3587, Acc: 75.97%
Epoch [1/2], Step [1700], Loss: 0.4117, Acc: 76.75%
Epoch [1/2], Step [1800], Loss: 0.4085, Acc: 77.42%
Epoch [1/2], St



Starting QAT fine-tuning for ViT Large...

Training on cuda for 2 epochs...


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 25.12 MiB is free. Process 2508 has 15.86 GiB memory in use. Of the allocated memory 14.84 GiB is allocated by PyTorch, and 741.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)