In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
import copy
from pathlib import Path
from collections import Counter

# Configuration
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CLASSES = 4  # Person, Helmet, Safety-vest, No-PPE
BATCH_SIZE = 32  # batch size passed to the train/validation DataLoaders in main()
EPOCHS = 10  # NOTE(review): presumably baseline training epochs; not referenced in the visible code
IMG_SIZE = 224  # images/crops are resized to IMG_SIZE x IMG_SIZE by the transform in main()

class PPEDataset(Dataset):
    """Classification dataset built from YOLO-style PPE annotations.

    Every image in ``images_dir`` is paired with ``labels_dir/<stem>.txt``.
    Each label line is read as ``class_id x_center y_center width height``;
    the box values are multiplied by the image width/height when cropping,
    i.e. they are treated as fractions of the image size.

    Modes:
        'crop_objects'   -- one sample per annotated box: (crop, class_id).
        'dominant_class' -- one sample per image, labelled with the most
                            frequent class in its label file (0 if none).
        anything else    -- handled as 'multi_label': one sample per image
                            with a multi-hot vector of length NUM_CLASSES.

    Class ids outside ``[0, NUM_CLASSES)`` are ignored in every mode, so a
    stray annotation cannot produce a target the classifier has no output for.
    """

    def __init__(self, images_dir, labels_dir, transform=None, mode='crop_objects'):
        """Index the image files and, in 'crop_objects' mode, every labelled box.

        Args:
            images_dir: directory containing .jpg/.png/.jpeg images.
            labels_dir: directory containing one .txt label file per image.
            transform: optional callable applied to the RGB numpy image/crop.
            mode: 'crop_objects', 'dominant_class' or 'multi_label'.
        """
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.transform = transform
        self.mode = mode  # 'multi_label' or 'dominant_class' or 'crop_objects'

        # Get all image files
        self.image_files = sorted(list(self.images_dir.glob('*.jpg')) +
                                  list(self.images_dir.glob('*.png')) +
                                  list(self.images_dir.glob('*.jpeg')))

        print(f"Found {len(self.image_files)} images in {images_dir}")

        # Only populated in 'crop_objects' mode: one entry per annotated box.
        self.samples = []
        if self.mode == 'crop_objects':
            self._create_object_crops()

    def _label_path(self, img_path):
        """Return the label file path that belongs to ``img_path``."""
        return self.labels_dir / (img_path.stem + '.txt')

    @staticmethod
    def _load_image(img_path):
        """Read an image from disk and return it as an RGB numpy array.

        Raises:
            OSError: if OpenCV cannot read the file (cv2.imread returns None
                instead of raising, which would otherwise surface as an
                opaque error inside cvtColor).
        """
        image = cv2.imread(str(img_path))
        if image is None:
            raise OSError(f"Failed to read image: {img_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _read_class_ids(self, img_path):
        """Return the in-range class ids listed in the image's label file.

        Blank lines are skipped; a missing label file yields an empty list.
        """
        label_path = self._label_path(img_path)
        if not label_path.exists():
            return []
        with open(label_path, 'r') as f:
            ids = [int(parts[0]) for parts in (line.split() for line in f) if parts]
        return [class_id for class_id in ids if 0 <= class_id < NUM_CLASSES]

    def _create_object_crops(self):
        """Create individual samples for each object in images"""
        skipped = 0
        for img_path in self.image_files:
            label_path = self._label_path(img_path)
            if not label_path.exists():
                continue
            with open(label_path, 'r') as f:
                for line in f:
                    parts = line.split()
                    if len(parts) < 5:
                        # Not a full "class x y w h" record.
                        continue
                    class_id = int(parts[0])
                    if not 0 <= class_id < NUM_CLASSES:
                        skipped += 1
                        continue
                    x_center, y_center, width, height = map(float, parts[1:5])
                    self.samples.append({
                        'image_path': img_path,
                        'class_id': class_id,
                        'bbox': (x_center, y_center, width, height)
                    })
        print(f"Created {len(self.samples)} object samples from images")
        if skipped:
            print(f"Skipped {skipped} annotations with class id outside [0, {NUM_CLASSES})")

    def __len__(self):
        """Number of boxes in 'crop_objects' mode, number of images otherwise."""
        if self.mode == 'crop_objects':
            return len(self.samples)
        return len(self.image_files)

    def __getitem__(self, idx):
        """Dispatch to the loader matching ``self.mode``."""
        if self.mode == 'crop_objects':
            return self._get_cropped_object(idx)
        elif self.mode == 'dominant_class':
            return self._get_dominant_class(idx)
        else:  # multi_label
            return self._get_multi_label(idx)

    def _get_cropped_object(self, idx):
        """Return ``(crop, class_id)`` for the idx-th annotated box."""
        sample = self.samples[idx]
        image = self._load_image(sample['image_path'])
        h, w = image.shape[:2]

        # Convert the fractional centre/size box to pixel corners,
        # clamped to the image so the slice below never wraps around.
        x_center, y_center, width, height = sample['bbox']
        x1 = max(0, int((x_center - width/2) * w))
        y1 = max(0, int((y_center - height/2) * h))
        x2 = min(w, int((x_center + width/2) * w))
        y2 = min(h, int((y_center + height/2) * h))

        cropped = image[y1:y2, x1:x2]

        # A degenerate box yields an empty slice; fall back to the full image.
        if cropped.size == 0:
            cropped = image

        if self.transform:
            cropped = self.transform(cropped)

        return cropped, sample['class_id']

    def _get_dominant_class(self, idx):
        """Return ``(image, label)`` where label is the image's most common class."""
        img_path = self.image_files[idx]
        image = self._load_image(img_path)

        classes = self._read_class_ids(img_path)

        # Get most common class, default to 0
        if classes:
            label = Counter(classes).most_common(1)[0][0]
        else:
            label = 0

        if self.transform:
            image = self.transform(image)

        return image, label

    def _get_multi_label(self, idx):
        """Return ``(image, multi_hot)`` with one slot per class present in the image."""
        img_path = self.image_files[idx]
        image = self._load_image(img_path)

        label_vector = torch.zeros(NUM_CLASSES)
        for class_id in self._read_class_ids(img_path):
            label_vector[class_id] = 1

        # If no labels, set first class as default
        if label_vector.sum() == 0:
            label_vector[0] = 1

        if self.transform:
            image = self.transform(image)

        return image, label_vector

def create_model(multi_label=False):
    """Create an ImageNet-pretrained MobileNetV2 with a NUM_CLASSES-way head.

    Args:
        multi_label: kept for interface compatibility. The head is the same
            linear layer in both cases; it outputs raw logits, and the
            sigmoid/softmax lives in the loss (fine_tune picks
            BCEWithLogitsLoss vs CrossEntropyLoss).

    Returns:
        The model, moved to DEVICE.
    """
    # Newer torchvision deprecates `pretrained=True` in favour of a weights
    # enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` maps to.
    # Fall back to the legacy flag on versions that lack the enum.
    weights_enum = getattr(models, 'MobileNet_V2_Weights', None)
    if weights_enum is not None:
        model = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.mobilenet_v2(pretrained=True)

    # Swap the 1000-class ImageNet classifier for the PPE classes.
    model.classifier[1] = nn.Linear(model.last_channel, NUM_CLASSES)

    return model.to(DEVICE)

def train_epoch(model, dataloader, criterion, optimizer, multi_label=False):
    """Train for one epoch.

    Args:
        model: network to optimise (switched to train mode here).
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model's parameters.
        multi_label: if True, accuracy is element-wise over the multi-hot
            targets (sigmoid > 0.5); otherwise it is top-1 accuracy.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; ``(0.0, 0.0)`` when the
        dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        if multi_label:
            # Each class slot of each sample counts as one prediction.
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.numel()
            correct += (predicted == labels).sum().item()
        else:
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    # Guard the averages so an empty loader reports zeros instead of raising.
    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    return mean_loss, accuracy

def evaluate(model, dataloader, multi_label=False):
    """Evaluate model accuracy and time the forward pass.

    Args:
        model: network to evaluate (switched to eval mode here).
        dataloader: iterable of ``(images, labels)`` batches.
        multi_label: if True, accuracy is element-wise over the multi-hot
            targets (sigmoid > 0.5); otherwise it is top-1 accuracy.

    Returns:
        ``(accuracy_percent, mean_forward_time_ms)`` where the time is the
        mean per *batch* (not per image). Both are 0.0 when the dataloader
        yields no batches.
    """
    model.eval()
    correct = 0
    total = 0
    inference_times = []
    # CUDA kernels launch asynchronously; without a synchronize on both
    # sides of the forward pass the timer only measures the launch.
    on_cuda = DEVICE.type == 'cuda'

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            outputs = model(images)
            if on_cuda:
                torch.cuda.synchronize()
            inference_times.append(time.perf_counter() - start)

            if multi_label:
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += labels.numel()
                correct += (predicted == labels).sum().item()
            else:
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

    # Guard the averages so an empty loader reports zeros instead of raising/NaN.
    accuracy = 100. * correct / total if total else 0.0
    avg_time = float(np.mean(inference_times)) * 1000 if inference_times else 0.0  # Convert to ms
    return accuracy, avg_time

def get_model_size(model):
    """Return the model size in MB, counting only parameters above 1e-8 in magnitude.

    Buffers are always counted in full. The figure is an estimate of the
    storage a sparse format would need for the surviving weights, not the
    size of the dense tensors held in memory.
    """
    eps = 1e-8

    live_param_bytes = sum(
        (p.abs() > eps).sum().item() * p.element_size()
        for p in model.parameters()
    )
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())

    return (live_param_bytes + buffer_bytes) / 1024**2

def get_sparsity(model):
    """Return the percentage of parameter entries whose magnitude is below 1e-8.

    All parameters are included (weights and biases of every layer), so the
    value is relative to the whole model rather than to the pruned layers.
    """
    eps = 1e-8  # entries smaller than this are treated as pruned
    params = list(model.parameters())

    near_zero = sum((p.abs() < eps).sum().item() for p in params)
    n_entries = sum(p.nelement() for p in params)

    return 100. * near_zero / n_entries

def count_parameters(model):
    """Return ``(total, non_zero)`` parameter counts; non-zero means |value| > 1e-8."""
    eps = 1e-8
    n_total = 0
    n_live = 0
    for p in model.parameters():
        n_total += p.numel()
        n_live += (p.abs() > eps).sum().item()
    return n_total, n_live

# ==============================================================================
# PRUNING HEURISTIC 1: Magnitude-based Pruning (L1)
# ==============================================================================
def magnitude_pruning(model, amount):
    """
    Globally prune the smallest-magnitude (L1) weights.

    The weights of every Conv2d and Linear layer are pooled and ranked
    together, so layers may end up with different individual sparsities.
    `amount` is forwarded unchanged to prune.global_unstructured.
    Returns the same model object, pruned in place.
    """
    targets = [
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]

    # One ranking across all collected tensors.
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)

    return model

# ==============================================================================
# PRUNING HEURISTIC 2: Random Pruning
# ==============================================================================
def random_pruning(model, amount):
    """
    Prune randomly chosen weights in every Conv2d and Linear layer.

    Serves as a baseline: comparing it with the magnitude-based heuristics
    shows how much the choice of *which* weights to remove matters.
    `amount` is applied to each layer separately. Returns the same model
    object, pruned in place.
    """
    prunable = (nn.Conv2d, nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, prunable):
            continue
        prune.random_unstructured(layer, name='weight', amount=amount)

    return model

# ==============================================================================
# PRUNING HEURISTIC 3: Layer-wise Proportional Pruning
# ==============================================================================
def layerwise_pruning(model, amount):
    """
    Apply L1 magnitude pruning to each Conv2d/Linear layer on its own.

    Every layer loses the same proportion of its weights, so the relative
    capacity of the layers is kept (unlike global ranking, where a layer
    with small weights can be pruned far more than the others).
    Returns the same model object, pruned in place.
    """
    prunable = (nn.Conv2d, nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, prunable):
            continue
        # Ranking is local to this layer's weight tensor.
        prune.l1_unstructured(layer, name='weight', amount=amount)

    return model

# ==============================================================================
# Making Pruning Permanent
# ==============================================================================
def make_pruning_permanent(model):
    """
    Remove the pruning reparameterization from every Conv2d/Linear layer.

    For each pruned layer this bakes the mask into `weight` (pruned entries
    become real zeros) and drops the `weight_orig` parameter and
    `weight_mask` buffer. The weight tensor stays dense, so memory use only
    shrinks by the removed mask/original copies. Layers that were never
    pruned are left untouched.

    Returns the same model object.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            try:
                prune.remove(module, 'weight')
            except ValueError:
                # prune.remove raises ValueError when 'weight' carries no
                # pruning hook; that simply means there is nothing to undo.
                continue
    return model

# ==============================================================================
# Fine-tuning Functions
# ==============================================================================
def fine_tune(model, train_loader, val_loader, epochs=3, multi_label=False, lr=0.001):
    """Fine-tune a (pruned) model with Adam and report accuracy per epoch.

    Args:
        model: network to fine-tune; updated in place.
        train_loader: batches used for the optimisation steps.
        val_loader: batches used for the per-epoch validation accuracy.
        epochs: number of passes over train_loader.
        multi_label: selects BCEWithLogitsLoss (True) or CrossEntropyLoss
            (False) and the matching accuracy definition.
        lr: Adam learning rate.

    Returns:
        The same model, with the weights from the last epoch.
    """
    criterion = nn.BCEWithLogitsLoss() if multi_label else nn.CrossEntropyLoss()

    # Built after pruning so it optimises whatever parameters the model
    # currently exposes.
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        _, train_acc = train_epoch(model, train_loader, criterion, optimizer, multi_label)
        val_acc, _ = evaluate(model, val_loader, multi_label)

        print(f"  Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

    return model

# ==============================================================================
# Iterative Pruning Strategy
# ==============================================================================
def _reattach_pruning_masks(model):
    """Mask the already-zero weights of every unpruned Conv2d/Linear layer.

    After the masks have been removed, a further pruning call would rank
    the old zeros together with the surviving weights and spend most of its
    budget re-pruning them. Re-attaching a mask for the zeros keeps them
    pruned during fine-tuning and makes the next pruning amount apply to
    the surviving weights only.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and not hasattr(module, 'weight_mask'):
            keep = (module.weight.detach() != 0).to(module.weight.dtype)
            prune.custom_from_mask(module, name='weight', mask=keep)


def iterative_pruning(model, train_loader, val_loader, target_sparsity, steps=5, 
                     multi_label=False, pruning_fn=magnitude_pruning):
    """
    Iteratively prune and fine-tune in multiple steps.

    Sparsity is raised linearly towards `target_sparsity` over `steps`
    rounds; each round prunes, fine-tunes for 2 epochs, makes the pruning
    permanent and records metrics, letting the model adapt gradually.

    Args:
        model: network to prune; modified in place.
        train_loader / val_loader: data for fine-tuning and evaluation.
        target_sparsity: final fraction (0-1) of prunable weights to remove.
        steps: number of prune/fine-tune rounds.
        multi_label: forwarded to fine_tune/evaluate.
        pruning_fn: callable ``(model, amount) -> model`` applying one
            pruning pass through torch.nn.utils.prune.

    Returns:
        ``(model, results)`` where results holds one metrics dict per step.
        The reported sparsity comes from get_sparsity, which covers all
        parameters, so it can sit slightly below the target that applies to
        Conv2d/Linear weights only.
    """
    results = []
    
    # Per-step amounts are fractions of the weights still *remaining*, so
    # that the cumulative sparsity after step k is (k / steps) * target.
    remaining = 1.0
    prune_amounts = []
    for step in range(steps):
        target_remaining = 1.0 - ((step + 1) / steps * target_sparsity)
        step_amount = 1.0 - (target_remaining / remaining)
        prune_amounts.append(step_amount)
        remaining = target_remaining
    
    for step, step_amount in enumerate(prune_amounts):
        print(f"\nIterative Pruning - Step {step+1}/{steps}")
        print(f"Pruning {step_amount*100:.1f}% of remaining weights...")
        
        # The previous step removed its masks to measure the model; put
        # them back so `step_amount` is taken from the surviving weights
        # rather than from all weights (which would mostly re-prune zeros
        # and never reach the target sparsity).
        if step > 0:
            _reattach_pruning_masks(model)
        
        # Prune
        model = pruning_fn(model, step_amount)
        
        # Fine-tune with fewer epochs per step
        model = fine_tune(model, train_loader, val_loader, epochs=2, multi_label=multi_label, lr=0.0005)
        
        # Make pruning permanent before evaluation so the metric helpers
        # see the real zeros instead of the unmasked `weight_orig` tensors.
        model = make_pruning_permanent(model)
        
        # Evaluate
        sparsity = get_sparsity(model)
        accuracy, inf_time = evaluate(model, val_loader, multi_label)
        size_mb = get_model_size(model)
        total_params, non_zero_params = count_parameters(model)
        
        results.append({
            'step': step + 1,
            'sparsity': sparsity,
            'accuracy': accuracy,
            'size_mb': size_mb,
            'inf_time': inf_time,
            'total_params': total_params,
            'non_zero_params': non_zero_params
        })
        
        print(f"Step {step+1} Results:")
        print(f"  Sparsity: {sparsity:.2f}%")
        print(f"  Accuracy: {accuracy:.2f}%")
        print(f"  Size: {size_mb:.2f} MB")
        print(f"  Parameters: {non_zero_params:,} / {total_params:,}")
    
    return model, results

# ==============================================================================
# One-shot Pruning Strategy
# ==============================================================================
def oneshot_pruning(model, train_loader, val_loader, target_sparsity, 
                   multi_label=False, pruning_fn=magnitude_pruning):
    """
    Prune straight to `target_sparsity` in a single pass, then fine-tune.

    Cheaper than the iterative schedule (one pruning call, 5 recovery
    epochs) but the model gets no chance to adapt between pruning rounds.

    Returns:
        ``(model, results)`` where results is a one-element list holding
        the metrics dict measured after the pruning was made permanent.
    """
    print(f"\nOne-shot Pruning - Target Sparsity: {target_sparsity*100:.1f}%")

    # prune -> recover -> bake the masks into the weights
    model = pruning_fn(model, target_sparsity)
    model = fine_tune(model, train_loader, val_loader, epochs=5, multi_label=multi_label, lr=0.0005)
    model = make_pruning_permanent(model)

    # Measure the final model.
    sparsity = get_sparsity(model)
    accuracy, inf_time = evaluate(model, val_loader, multi_label)
    size_mb = get_model_size(model)
    total_params, non_zero_params = count_parameters(model)

    metrics = {
        'sparsity': sparsity,
        'accuracy': accuracy,
        'size_mb': size_mb,
        'inf_time': inf_time,
        'total_params': total_params,
        'non_zero_params': non_zero_params
    }

    report = (
        "Results:",
        f"  Sparsity: {sparsity:.2f}%",
        f"  Accuracy: {accuracy:.2f}%",
        f"  Size: {size_mb:.2f} MB",
        f"  Parameters: {non_zero_params:,} / {total_params:,}",
    )
    for line in report:
        print(line)

    return model, [metrics]

# ==============================================================================
# Visualization Functions
# ==============================================================================
def _plot_sparsity_curves(ax, series, metric, linewidth):
    """Draw one sparsity-vs-`metric` line for every non-empty result list."""
    for results, fmt, label, color, markersize in series:
        if not results:
            continue
        xs = [entry['sparsity'] for entry in results]
        ys = [entry[metric] for entry in results]
        ax.plot(xs, ys, fmt, label=label, linewidth=linewidth,
                markersize=markersize, color=color)


def _draw_accuracy_bars(ax, entries, title):
    """Draw a labelled accuracy bar chart from (name, accuracy, color) tuples."""
    names = [name for name, _, _ in entries]
    values = [value for _, value, _ in entries]
    colors = [color for _, _, color in entries]

    bars = ax.bar(names, values, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
    ax.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_ylim([min(values)-5, 100])
    ax.grid(True, alpha=0.3, axis='y', linestyle='--')

    # Print the value on top of each bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.2f}%',
                ha='center', va='bottom', fontsize=10, fontweight='bold')


def _summary_table_row(name, result):
    """Format the final metrics of one experiment as a summary-table row."""
    return [
        name,
        f"{result['sparsity']:.1f}",
        f"{result['accuracy']:.2f}",
        f"{result['size_mb']:.2f}",
        f"{result['inf_time']:.2f}",
        f"{result.get('non_zero_params', 0):,}"
    ]


def plot_comprehensive_results(baseline, mag_results, rand_results, layer_results, 
                              iter_results, oneshot_results):
    """Render the seven-panel overview figure of all pruning experiments.

    `baseline` is a metrics dict; every `*_results` argument is a list of
    metrics dicts (possibly empty/None, in which case its curves, bars and
    table row are omitted). The figure is saved to
    'pruning_comprehensive_results.png' and then shown.
    """
    fig = plt.figure(figsize=(20, 12))
    grid = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    baseline_style = dict(color='green', linestyle='--', label='Baseline',
                          linewidth=2.5, alpha=0.7)

    # ---------- Panel 1: accuracy vs sparsity, every method ----------
    ax_acc = fig.add_subplot(grid[0, :2])
    ax_acc.axhline(y=baseline['accuracy'], **baseline_style)
    _plot_sparsity_curves(ax_acc, [
        (mag_results, 'o-', 'Magnitude Pruning', '#2E86AB', 8),
        (rand_results, 's-', 'Random Pruning', '#A23B72', 8),
        (layer_results, '^-', 'Layer-wise Pruning', '#F18F01', 8),
        (iter_results, 'D-', 'Iterative (Magnitude)', '#C73E1D', 8),
        (oneshot_results, 'P-', 'One-shot (Magnitude)', '#6A4C93', 12),
    ], 'accuracy', linewidth=2.5)
    ax_acc.set_xlabel('Sparsity (%)', fontsize=13, fontweight='bold')
    ax_acc.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
    ax_acc.set_title('Accuracy vs Sparsity - All Pruning Methods', fontsize=15, fontweight='bold')
    ax_acc.legend(fontsize=10, loc='best')
    ax_acc.grid(True, alpha=0.3, linestyle='--')
    ax_acc.set_xlim(left=-5)

    # ---------- Panels 2 and 3: model size / inference time vs sparsity ----------
    compact_series = [
        (mag_results, 'o-', 'Magnitude', '#2E86AB', 7),
        (iter_results, 'D-', 'Iterative', '#C73E1D', 7),
    ]
    metric_panels = (
        (grid[0, 2], 'size_mb', 'Model Size (MB)', 'Model Size Reduction'),
        (grid[1, 0], 'inf_time', 'Inference Time (ms)', 'Inference Time vs Sparsity'),
    )
    for cell, metric, ylabel, title in metric_panels:
        ax = fig.add_subplot(cell)
        ax.axhline(y=baseline[metric], **baseline_style)
        _plot_sparsity_curves(ax, compact_series, metric, linewidth=2)
        ax.set_xlabel('Sparsity (%)', fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3, linestyle='--')

    # ---------- Panel 4: final accuracy of each heuristic ----------
    ax_heur = fig.add_subplot(grid[1, 1])
    heuristic_bars = [('Baseline', baseline['accuracy'], 'green')]
    for results, name, color in (
        (mag_results, 'Magnitude\n(70%)', '#2E86AB'),
        (rand_results, 'Random\n(70%)', '#A23B72'),
        (layer_results, 'Layer-wise\n(70%)', '#F18F01'),
    ):
        if results:
            heuristic_bars.append((name, results[-1]['accuracy'], color))
    _draw_accuracy_bars(ax_heur, heuristic_bars, 'Pruning Heuristics Comparison (70% Sparsity)')

    # ---------- Panel 5: iterative vs one-shot ----------
    ax_sched = fig.add_subplot(grid[1, 2])
    schedule_bars = [('Baseline', baseline['accuracy'], 'green')]
    for results, name, color in (
        (iter_results, 'Iterative\nPruning', '#C73E1D'),
        (oneshot_results, 'One-shot\nPruning', '#6A4C93'),
    ):
        if results:
            schedule_bars.append((name, results[-1]['accuracy'], color))
    _draw_accuracy_bars(ax_sched, schedule_bars, 'Iterative vs One-shot (70% Target)')

    # ---------- Panel 6: iterative progress (sparsity and accuracy per step) ----------
    ax_prog = fig.add_subplot(grid[2, :2])
    if iter_results:
        step_ids = [entry['step'] for entry in iter_results]
        step_sparsity = [entry['sparsity'] for entry in iter_results]
        step_accuracy = [entry['accuracy'] for entry in iter_results]

        # Accuracy gets its own y-axis on the right.
        ax_prog_acc = ax_prog.twinx()

        sparsity_line = ax_prog.plot(step_ids, step_sparsity, 'o-', label='Sparsity',
                                     linewidth=3, markersize=10, color='#C73E1D')
        accuracy_line = ax_prog_acc.plot(step_ids, step_accuracy, 's-', label='Accuracy',
                                         linewidth=3, markersize=10, color='#2E86AB')

        ax_prog.set_xlabel('Pruning Step', fontsize=12, fontweight='bold')
        ax_prog.set_ylabel('Sparsity (%)', fontsize=12, fontweight='bold', color='#C73E1D')
        ax_prog_acc.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold', color='#2E86AB')
        ax_prog.set_title('Iterative Pruning: Progressive Sparsity & Accuracy',
                          fontsize=13, fontweight='bold')

        ax_prog.tick_params(axis='y', labelcolor='#C73E1D')
        ax_prog_acc.tick_params(axis='y', labelcolor='#2E86AB')

        # One legend covering the lines of both y-axes.
        handles = sparsity_line + accuracy_line
        ax_prog.legend(handles, [h.get_label() for h in handles], loc='center left', fontsize=10)
        ax_prog.grid(True, alpha=0.3, linestyle='--')

    # ---------- Panel 7: summary table ----------
    ax_table = fig.add_subplot(grid[2, 2])
    ax_table.axis('off')

    header = ['Method', 'Sparsity\n(%)', 'Accuracy\n(%)', 'Size\n(MB)', 'Time\n(ms)', 'Params\nRemaining']
    total_baseline_params = baseline.get('total_params', 0)
    baseline_row = [
        'Baseline',
        '0.0',
        f"{baseline['accuracy']:.2f}",
        f"{baseline['size_mb']:.2f}",
        f"{baseline['inf_time']:.2f}",
        f"{total_baseline_params:,}" if total_baseline_params > 0 else "N/A"
    ]
    table_data = [header, baseline_row]
    for name, results in (
        ('Magnitude', mag_results),
        ('Random', rand_results),
        ('Layer-wise', layer_results),
        ('Iterative', iter_results),
        ('One-shot', oneshot_results),
    ):
        if results and len(results) > 0:
            table_data.append(_summary_table_row(name, results[-1]))

    table = ax_table.table(cellText=table_data, cellLoc='center', loc='center',
                           colWidths=[0.18, 0.12, 0.12, 0.12, 0.12, 0.16])
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2.5)

    n_cols = len(header)

    # Header row: white bold text on blue.
    for col in range(n_cols):
        header_cell = table[(0, col)]
        header_cell.set_facecolor('#2E86AB')
        header_cell.set_text_props(weight='bold', color='white', fontsize=9)

    # Data rows: zebra striping.
    for row in range(1, len(table_data)):
        shade = '#F0F0F0' if row % 2 == 0 else 'white'
        for col in range(n_cols):
            table[(row, col)].set_facecolor(shade)

    plt.suptitle('PPE Detection Model Pruning - Comprehensive Results',
                 fontsize=18, fontweight='bold', y=0.995)

    plt.savefig('pruning_comprehensive_results.png', dpi=300, bbox_inches='tight')
    banner = "="*70
    print("\n" + banner)
    print("Comprehensive results saved to 'pruning_comprehensive_results.png'")
    print(banner)
    plt.show()

def print_summary(baseline, mag_results, rand_results, layer_results, 
                 iter_results, oneshot_results):
    """Print a text table of all experiments followed by the key findings.

    `baseline` is a metrics dict with 'accuracy', 'size_mb' and 'inf_time';
    every `*_results` argument is a list of metrics dicts that additionally
    carry 'sparsity' (empty/None lists are skipped). The three heuristic
    lists are printed in full; for the iterative and one-shot runs only the
    final entry is shown.
    """
    rule = "="*70
    baseline_time = baseline['inf_time']

    def print_row(label, r):
        # A measured time of 0 (below timer resolution) would divide by
        # zero; report an unbounded speedup instead of crashing.
        speedup = baseline_time / r['inf_time'] if r['inf_time'] else float('inf')
        print(f"{label:<20} {r['sparsity']:>5.1f}% {r['accuracy']:>11.2f}% {r['size_mb']:>9.2f} {speedup:>9.2f}x")

    print("\n" + rule)
    print("PRUNING EXPERIMENTS SUMMARY")
    print(rule)
    
    print(f"\n{'Method':<20} {'Sparsity':<12} {'Accuracy':<12} {'Size (MB)':<12} {'Speedup':<10}")
    print("-"*70)
    
    print(f"{'Baseline':<20} {'0.0%':<12} {baseline['accuracy']:>6.2f}% {baseline['size_mb']:>9.2f} {'1.00x':<10}")
    
    # One row per sparsity level for each pruning heuristic.
    for prefix, results in (
        ('Magnitude', mag_results),
        ('Random', rand_results),
        ('Layer-wise', layer_results),
    ):
        for r in results or ():
            print_row(f"{prefix} {int(r['sparsity'])}%", r)
    
    # Only the final state of the two schedules.
    if iter_results:
        print_row("Iterative (final)", iter_results[-1])
    
    if oneshot_results:
        print_row("One-shot", oneshot_results[-1])
    
    print(rule)
    
    # Key findings
    print("\nKEY FINDINGS:")
    print("-"*70)
    
    if mag_results and rand_results:
        diff = mag_results[-1]['accuracy'] - rand_results[-1]['accuracy']
        # Name whichever method actually came out ahead rather than
        # claiming magnitude wins with a negative margin.
        if diff >= 0:
            print(f"1. Magnitude pruning outperforms random pruning by {diff:.2f}%")
        else:
            print(f"1. Random pruning outperforms magnitude pruning by {abs(diff):.2f}%")
    
    if iter_results and oneshot_results:
        diff = iter_results[-1]['accuracy'] - oneshot_results[-1]['accuracy']
        if diff > 0:
            print(f"2. Iterative pruning achieves {diff:.2f}% higher accuracy than one-shot")
        else:
            print(f"2. One-shot pruning achieves {abs(diff):.2f}% higher accuracy than iterative")
    
    if mag_results:
        final = mag_results[-1]
        final_sparsity = final['sparsity']
        size_reduction = (1 - final['size_mb'] / baseline['size_mb']) * 100
        acc_drop = baseline['accuracy'] - final['accuracy']
        print(f"3. At {final_sparsity:.1f}% sparsity: {size_reduction:.1f}% size reduction with {acc_drop:.2f}% accuracy drop")
    
    print(rule + "\n")

def _run_pruning_experiment(base_model, prune_fn, prune_amount, title,
                            train_loader, val_loader, multi_label,
                            fine_tune_epochs=5):
    """Prune a copy of ``base_model``, fine-tune it, and measure the result.

    The base model is deep-copied first, so it is never modified.

    Args:
        base_model: Trained model to copy and prune.
        prune_fn: Callable invoked as ``prune_fn(model, prune_amount)`` that
            returns the pruned model.
        prune_amount: Pruning amount forwarded to ``prune_fn``.
        title: Banner line printed before the experiment starts.
        train_loader: Training loader forwarded to ``fine_tune``.
        val_loader: Validation loader forwarded to ``fine_tune`` and
            ``evaluate``.
        multi_label: Forwarded to ``fine_tune`` and ``evaluate``.
        fine_tune_epochs: Number of fine-tuning epochs after pruning.

    Returns:
        Tuple ``(pruned_model, result)`` where ``result`` is a dict with the
        keys ``sparsity``, ``accuracy``, ``size_mb``, ``inf_time``,
        ``total_params`` and ``non_zero_params``.
    """
    print(f"\n{'*'*70}")
    print(title)
    print(f"{'*'*70}")

    pruned_model = copy.deepcopy(base_model)
    pruned_model = prune_fn(pruned_model, prune_amount)
    pruned_model = fine_tune(pruned_model, train_loader, val_loader,
                             epochs=fine_tune_epochs, multi_label=multi_label)
    pruned_model = make_pruning_permanent(pruned_model)

    sparsity = get_sparsity(pruned_model)
    accuracy, inf_time = evaluate(pruned_model, val_loader, multi_label)
    size_mb = get_model_size(pruned_model)
    total_params, non_zero_params = count_parameters(pruned_model)

    result = {
        'sparsity': sparsity,
        'accuracy': accuracy,
        'size_mb': size_mb,
        'inf_time': inf_time,
        'total_params': total_params,
        'non_zero_params': non_zero_params
    }

    print("\nResults:")
    print(f"  Achieved Sparsity: {sparsity:.2f}%")
    print(f"  Accuracy: {accuracy:.2f}%")
    print(f"  Model Size: {size_mb:.2f} MB")
    print(f"  Inference Time: {inf_time:.2f} ms")
    print(f"  Parameters: {non_zero_params:,} / {total_params:,}")

    return pruned_model, result


def main():
    """Run the full pruning study end to end.

    Steps, in order:
      1. Build the preprocessing transform and the train/validation loaders
         from the relative ``train/`` and ``valid/`` directories.
      2. Train a baseline model for ``EPOCHS`` epochs, record its metrics and
         save it to ``baseline_model.pth``.
      3. Run magnitude, random and layer-wise pruning at each sparsity level,
         fine-tuning after every pruning step.
      4. Compare iterative against one-shot pruning at a 0.7 target.
      5. Save the pruned checkpoints, plot the results and print a summary.
    """
    print("="*70)
    print("PPE DETECTION MODEL PRUNING EXPERIMENTS")
    print("="*70)
    print(f"Device: {DEVICE}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Image Size: {IMG_SIZE}x{IMG_SIZE}")
    print(f"Number of Classes: {NUM_CLASSES}")
    print("="*70)

    # Setup data transforms (the mean/std are the values the original code
    # passes to Normalize)
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Load datasets - UPDATE THESE PATHS TO YOUR DATA
    train_images = 'train/images'
    train_labels = 'train/labels'
    val_images = 'valid/images'
    val_labels = 'valid/labels'

    # Dataset mode
    dataset_mode = 'crop_objects'  # Recommended for object detection
    multi_label = (dataset_mode == 'multi_label')

    # Sparsity levels shared by the magnitude, random and layer-wise sweeps.
    prune_levels = (0.3, 0.5, 0.7)

    print("\nDataset Configuration:")
    print(f"  Mode: {dataset_mode}")
    print(f"  Multi-label: {multi_label}")

    print("\nLoading datasets...")
    train_dataset = PPEDataset(train_images, train_labels, transform=transform, mode=dataset_mode)
    val_dataset = PPEDataset(val_images, val_labels, transform=transform, mode=dataset_mode)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print(f"  Train samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")

    # ==============================================================================
    # BASELINE MODEL TRAINING
    # ==============================================================================
    print("\n" + "="*70)
    print("PHASE 1: TRAINING BASELINE MODEL")
    print("="*70)

    model = create_model(multi_label=multi_label)

    if multi_label:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("\nTraining baseline model...")
    for epoch in range(EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, multi_label)
        # The per-epoch timing is not reported; only the final evaluation's is.
        val_acc, _ = evaluate(model, val_loader, multi_label)
        print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

    # Baseline metrics
    baseline_acc, baseline_time = evaluate(model, val_loader, multi_label)
    baseline_size = get_model_size(model)
    baseline_sparsity = get_sparsity(model)
    total_params, non_zero_params = count_parameters(model)

    baseline = {
        'accuracy': baseline_acc,
        'size_mb': baseline_size,
        'inf_time': baseline_time,
        'sparsity': baseline_sparsity,
        'total_params': total_params,
        'non_zero_params': non_zero_params
    }

    print("\n" + "-"*70)
    print("BASELINE MODEL METRICS:")
    print("-"*70)
    print(f"Accuracy: {baseline_acc:.2f}%")
    print(f"Model Size: {baseline_size:.2f} MB")
    print(f"Inference Time: {baseline_time:.2f} ms")
    print(f"Parameters: {total_params:,}")
    print(f"Initial Sparsity: {baseline_sparsity:.2f}%")
    print("-"*70)

    # Save baseline model
    torch.save(model.state_dict(), 'baseline_model.pth')
    print("Baseline model saved to 'baseline_model.pth'")

    # ==============================================================================
    # EXPERIMENT 1: MAGNITUDE PRUNING (Different Levels)
    # ==============================================================================
    print("\n" + "="*70)
    print("PHASE 2: MAGNITUDE-BASED PRUNING EXPERIMENTS")
    print("="*70)

    mag_results = []
    # Holds the model from the last (highest) level so the checkpoint saved
    # later is exactly the model whose metrics are reported; rebinding it each
    # iteration lets earlier models be garbage-collected.
    model_mag_best = None
    for prune_amount in prune_levels:
        model_mag_best, result = _run_pruning_experiment(
            model, magnitude_pruning, prune_amount,
            f"Magnitude Pruning: {prune_amount*100:.0f}% Target Sparsity",
            train_loader, val_loader, multi_label
        )
        mag_results.append(result)

    # ==============================================================================
    # EXPERIMENT 2: RANDOM PRUNING
    # ==============================================================================
    print("\n" + "="*70)
    print("PHASE 3: RANDOM PRUNING EXPERIMENTS")
    print("="*70)

    rand_results = []
    for prune_amount in prune_levels:
        _, result = _run_pruning_experiment(
            model, random_pruning, prune_amount,
            f"Random Pruning: {prune_amount*100:.0f}% Target Sparsity",
            train_loader, val_loader, multi_label
        )
        rand_results.append(result)

    # ==============================================================================
    # EXPERIMENT 3: LAYER-WISE PRUNING
    # ==============================================================================
    print("\n" + "="*70)
    print("PHASE 4: LAYER-WISE PRUNING EXPERIMENTS")
    print("="*70)

    layer_results = []
    for prune_amount in prune_levels:
        _, result = _run_pruning_experiment(
            model, layerwise_pruning, prune_amount,
            f"Layer-wise Pruning: {prune_amount*100:.0f}% Per Layer",
            train_loader, val_loader, multi_label
        )
        layer_results.append(result)

    # ==============================================================================
    # EXPERIMENT 4: ITERATIVE VS ONE-SHOT PRUNING
    # ==============================================================================
    print("\n" + "="*70)
    print("PHASE 5: ITERATIVE VS ONE-SHOT PRUNING COMPARISON")
    print("="*70)

    # Iterative Pruning
    print(f"\n{'*'*70}")
    print("Iterative Pruning (5 steps to 70% sparsity)")
    print(f"{'*'*70}")
    model_iter = copy.deepcopy(model)
    model_iter, iter_results = iterative_pruning(
        model_iter, train_loader, val_loader,
        target_sparsity=0.7, steps=5, multi_label=multi_label
    )

    # One-shot Pruning
    print(f"\n{'*'*70}")
    print("One-shot Pruning (Direct to 70% sparsity)")
    print(f"{'*'*70}")
    model_oneshot = copy.deepcopy(model)
    model_oneshot, oneshot_results = oneshot_pruning(
        model_oneshot, train_loader, val_loader,
        target_sparsity=0.7, multi_label=multi_label
    )

    # ==============================================================================
    # SAVE BEST MODELS
    # ==============================================================================
    print("\n" + "="*70)
    print("SAVING PRUNED MODELS")
    print("="*70)

    torch.save(model_iter.state_dict(), 'model_iterative_pruned.pth')
    print("Iterative pruned model saved to 'model_iterative_pruned.pth'")

    torch.save(model_oneshot.state_dict(), 'model_oneshot_pruned.pth')
    print("One-shot pruned model saved to 'model_oneshot_pruned.pth'")

    # Save the model already produced by the magnitude sweep instead of
    # pruning and fine-tuning a second one: the checkpoint then matches the
    # last entry of mag_results and no extra training is spent.
    if model_mag_best is not None:
        torch.save(model_mag_best.state_dict(), 'model_magnitude_pruned.pth')
        print("Magnitude pruned model saved to 'model_magnitude_pruned.pth'")

    # ==============================================================================
    # VISUALIZATION AND SUMMARY
    # ==============================================================================
    print("\n" + "="*70)
    print("GENERATING COMPREHENSIVE VISUALIZATION")
    print("="*70)

    plot_comprehensive_results(baseline, mag_results, rand_results, layer_results,
                               iter_results, oneshot_results)

    print_summary(baseline, mag_results, rand_results, layer_results,
                  iter_results, oneshot_results)

    print("\n" + "="*70)
    print("EXPERIMENTS COMPLETED SUCCESSFULLY!")
    print("="*70)
    print("\nGenerated Files:")
    print("  1. baseline_model.pth - Original trained model")
    print("  2. model_iterative_pruned.pth - Best iterative pruned model")
    print("  3. model_oneshot_pruned.pth - One-shot pruned model")
    print("  4. model_magnitude_pruned.pth - Magnitude pruned model")
    print("  5. pruning_comprehensive_results.png - Visualization of all results")
    print("="*70 + "\n")

# Script entry point: run the experiment pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()