# Installs

In [None]:
!pip install "numpy<2.0"
!pip install torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118

!pip install torch_geometric


Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch==2.1.0+cu118
  Downloading https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp311-cp311-linux_x86_64.whl (2325.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m486.1 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchvision==0.16.0+cu118
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp311-cp311-linux_x86_64.whl (6.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.2/6.2 MB[0m [31m110.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchaudio==2.1.0
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.1.0%2Bcu118-cp311-cp311-linux_x86_64.whl (3.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m102.3 MB/s[0m eta [36m0:00:00[0m
Collecting triton==2.1.0 (from torch==2.1.0+cu118)
  Downloading https://download.pytorch.org/whl/triton-2.1.0-0-cp311-cp311-

In [None]:
!pip install scikit-optimize

Collecting scikit-optimize
  Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl.metadata (9.7 kB)
Collecting pyaml>=16.9 (from scikit-optimize)
  Downloading pyaml-25.5.0-py3-none-any.whl.metadata (12 kB)
Downloading scikit_optimize-0.10.2-py2.py3-none-any.whl (107 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.8/107.8 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyaml-25.5.0-py3-none-any.whl (26 kB)
Installing collected packages: pyaml, scikit-optimize
Successfully installed pyaml-25.5.0 scikit-optimize-0.10.2


# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE, GCNConv, GATConv, GINConv, LayerNorm, BatchNorm
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, global_sort_pool
from torch_geometric.data import DataLoader
from torch_geometric.nn.aggr import SumAggregation, MeanAggregation, MaxAggregation, StdAggregation

import numpy as np
import pandas as pd
import time
import psutil
import os
import pickle
from typing import Dict, List, Tuple, Any
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.model_selection import StratifiedKFold
import warnings
from torch_geometric.datasets import TUDataset
from sklearn.model_selection import ParameterSampler

warnings.filterwarnings('ignore')

# Bayesian Optimization imports
try:
    from skopt import gp_minimize
    from skopt.space import Real, Integer, Categorical
    from skopt.utils import use_named_args
    from skopt.acquisition import gaussian_ei
    BAYESIAN_OPT_AVAILABLE = True
except ImportError:
    print("Warning: scikit-optimize not available. Please install with: pip install scikit-optimize")
    BAYESIAN_OPT_AVAILABLE = False

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool, global_max_pool, global_add_pool
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pickle

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch_geometric.nn import GINConv, BatchNorm, LayerNorm
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
import psutil
import gc
from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')

# Configurations

In [None]:
# One shared seed for the PyTorch and NumPy global generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: cuda


# Loading the dataset

1. **Nodes**: Amino acids.
2. **Edges**: Connections between amino acids that are within 6 Ångströms of each other.
3. **Labels**: Binary classification indicating whether a protein is an enzyme or not.

**Clarification:** the 89 features are *node-level* features (one vector per amino acid, encoding structural information and chemical properties), not graph-level features.


In [None]:
print("Loading DD dataset...")
# TUDataset fetches the "DD" benchmark into the given root directory when it
# is not already there (see the download log in the cell output).
dataset = TUDataset(root='/tmp/DD', name='DD')

# Dataset-level summary.
print(
    f"Dataset: {dataset}",
    f"Number of graphs: {len(dataset)}",
    f"Number of features: {dataset.num_features}",
    f"Number of classes: {dataset.num_classes}",
    sep="\n",
)

# Inspect the first graph as a sanity check of the per-graph tensors.
data = dataset[0]
print(
    "\nFirst graph:",
    f"Number of nodes: {data.x.shape[0]}",
    f"Number of edges: {data.edge_index.shape[1]}",
    f"Node feature shape: {data.x.shape}",
    f"Label: {data.y}",
    sep="\n",
)

Loading DD dataset...


Downloading https://www.chrsmrrs.com/graphkerneldatasets/DD.zip
Processing...


Dataset: DD(1178)
Number of graphs: 1178
Number of features: 89
Number of classes: 2

First graph:
Number of nodes: 327
Number of edges: 1798
Node feature shape: torch.Size([327, 89])
Label: tensor([0])


Done!


# DATA PREPROCESSING

In [None]:
# Materialise the dataset as a plain Python list so it can be indexed by split.
data_list = list(dataset)

## splitting

In [None]:
from sklearn.model_selection import train_test_split

# One integer class label per graph; used to stratify both splits.
labels = [graph.y.item() for graph in data_list]
all_indices = [*range(len(data_list))]

# Stage 1: keep 60% of the graphs for training, hold out the other 40%.
train_idx, temp_idx = train_test_split(
    all_indices, test_size=0.4, stratify=labels, random_state=42
)

# Stage 2: halve the held-out 40% into validation and test,
# i.e. 20% of the full dataset each.
temp_labels = [labels[idx] for idx in temp_idx]
val_idx, test_idx = train_test_split(
    temp_idx, test_size=0.5, stratify=temp_labels, random_state=42
)


In [None]:
# Report how many graphs landed in each split.
print(
    f"Train set size: {len(train_idx)}",
    f"Validation set size: {len(val_idx)}",
    f"Test set size: {len(test_idx)}",
    sep="\n",
)


Train set size: 706
Validation set size: 236
Test set size: 236


In [None]:
# Turn each index list into the corresponding list of graphs (order preserved).
train_dataset, val_dataset, test_dataset = (
    [data_list[idx] for idx in split_indices]
    for split_indices in (train_idx, val_idx, test_idx)
)

# data loaders

In [None]:
batch_size = 32

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(held_out, batch_size=batch_size, shuffle=False)
    for held_out in (val_dataset, test_dataset)
)

print(
    f"\nBatch size: {batch_size}",
    f"Number of training batches: {len(train_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    f"Number of test batches: {len(test_loader)}",
    sep="\n",
)


Batch size: 32
Number of training batches: 23
Number of validation batches: 8
Number of test batches: 8


# TRAINING UTILITIES

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` contribute to the count.

    NOTE: a later cell in this notebook defines another ``count_parameters``
    that returns a tuple; whichever definition ran last is the one in effect.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def get_memory_usage():
    """Return ``(cpu_mb, gpu_mb)`` memory currently in use.

    ``cpu_mb`` is the resident set size of this process as reported by psutil.
    ``gpu_mb`` is ``torch.cuda.memory_allocated()`` converted to MB, or ``0``
    when CUDA is not available.
    """
    rss_bytes = psutil.Process().memory_info().rss
    cpu_mb = rss_bytes / 1024 / 1024
    if not torch.cuda.is_available():
        return cpu_mb, 0
    return cpu_mb, torch.cuda.memory_allocated() / 1024 / 1024

def measure_inference_time(model, loader, device, num_samples=100):
    """Return the mean per-sample forward-pass latency in milliseconds.

    Each batch's wall-clock forward time is divided by the number of labels in
    the batch (``batch.y.size(0)``); the per-batch values are then averaged.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index, batch.batch)``.
            It is switched to eval mode.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``to(device)``.
        device: Device (string or ``torch.device``) the batches are moved to.
        num_samples: Iteration stops at the first batch index ``i`` for which
            ``i * batch_size >= num_samples``.

    Returns:
        Mean latency in ms, or ``nan`` when no batch was timed.
    """
    model.eval()
    torch_device = torch.device(device)
    # CUDA kernels are launched asynchronously: without a synchronize before
    # and after the forward pass, only the launch overhead would be timed.
    sync_cuda = torch_device.type == 'cuda' and torch.cuda.is_available()

    per_sample_times = []
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i * batch.y.size(0) >= num_samples:
                break
            batch = batch.to(device)
            if sync_cuda:
                torch.cuda.synchronize(torch_device)
            start_time = time.perf_counter()
            _ = model(batch.x, batch.edge_index, batch.batch)
            if sync_cuda:
                torch.cuda.synchronize(torch_device)
            elapsed = time.perf_counter() - start_time
            per_sample_times.append(elapsed / batch.y.size(0))

    if not per_sample_times:
        # Nothing was measured (empty loader or num_samples <= 0).
        return float('nan')
    return np.mean(per_sample_times) * 1000  # Convert to milliseconds


# Model definition: GINModel

In [None]:
class GINModel(nn.Module):
    """GIN-based graph classifier.

    A stack of ``GINConv`` layers (each wrapping a Linear-ReLU-Linear MLP),
    optional BatchNorm / LayerNorm after every convolution, a global pooling
    step over ``batch``, and a final linear classifier.

    Args:
        num_features: Input node-feature dimension.
        hidden_dim: Width of every GIN MLP and of the classifier input.
        num_classes: Number of output logits.
        num_layers: Number of GIN layers. At least one convolution is always
            built, even when ``num_layers < 1``.
        dropout: Dropout probability applied after every layer except the last.
        global_pool: One of ``'mean'``, ``'max'``, ``'add'``; any other value
            raises ``KeyError``.
        eps, train_eps: Forwarded to every ``GINConv``.
        batch_norm: Add a ``BatchNorm`` after each convolution.
        layer_norm: Add a ``LayerNorm`` after each convolution (applied after
            the batch norm when both are enabled).
    """

    def __init__(self, num_features, hidden_dim, num_classes, num_layers,
                 dropout=0.3, global_pool='mean', eps=0.0, train_eps=False,
                 batch_norm=True, layer_norm=False):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if batch_norm else None
        self.layer_norms = nn.ModuleList() if layer_norm else None

        # The first convolution maps num_features -> hidden_dim; every further
        # one maps hidden_dim -> hidden_dim.
        for layer_idx in range(max(num_layers, 1)):
            in_dim = num_features if layer_idx == 0 else hidden_dim
            mlp = self._make_mlp(in_dim, hidden_dim)
            self.convs.append(GINConv(mlp, eps=eps, train_eps=train_eps))

        # Normalization layers (one per requested layer)
        if batch_norm:
            for _ in range(num_layers):
                self.batch_norms.append(BatchNorm(hidden_dim))
        if layer_norm:
            for _ in range(num_layers):
                self.layer_norms.append(LayerNorm(hidden_dim))

        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

        # Global pooling
        self.global_pool = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool
        }[global_pool]

    @staticmethod
    def _make_mlp(in_dim, hidden_dim):
        """Build the Linear-ReLU-Linear MLP wrapped by one GINConv."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x, edge_index, batch):
        """Return class logits, one row per graph identified by ``batch``."""
        last_idx = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # An empty or absent ModuleList is falsy, so these guards cover
            # both "normalization disabled" and "no norm layers were built".
            if self.batch_norms:
                x = self.batch_norms[i](x)
            if self.layer_norms:
                x = self.layer_norms[i](x)
            # The last layer feeds the pooling directly: no ReLU / dropout.
            if i < last_idx:
                x = F.relu(x)
                x = self.dropout(x)

        x = self.global_pool(x, batch)
        return self.classifier(x)

In [None]:
def get_model_size(model):
    """Return the size of ``model`` in MB, counting parameters and buffers.

    Every tensor contributes ``nelement() * element_size()`` bytes.
    """
    def _nbytes(tensors):
        return sum(t.nelement() * t.element_size() for t in tensors)

    total_bytes = _nbytes(model.parameters()) + _nbytes(model.buffers())
    return total_bytes / 1024 / 1024

In [None]:
def count_parameters(model):
    """Count total and effectively non-zero parameters of ``model``.

    While a ``torch.nn.utils.prune`` reparametrization is active, a pruned
    tensor is stored as the parameter ``<name>_orig`` plus the module buffer
    ``<name>_mask``. The mask lives on the module, not on the parameter, so it
    is looked up there and applied before counting non-zeros.

    NOTE: this definition replaces the earlier single-value
    ``count_parameters`` from the training-utilities cell once this cell runs.

    Returns:
        ``(total_params, non_zero_params, sparsity)`` where ``sparsity`` is a
        percentage in ``[0, 100]`` (``0.0`` for a model without parameters).
    """
    total_params = 0
    non_zero_params = 0
    seen = set()  # shared parameters are counted once, as model.parameters() does

    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if id(param) in seen:
                continue
            seen.add(id(param))
            total_params += param.numel()

            effective = param.detach()
            if name.endswith('_orig'):
                mask = getattr(module, name[:-len('_orig')] + '_mask', None)
                if mask is not None:
                    # Pruned parameter: masked entries count as zeros
                    effective = effective * mask
            non_zero_params += torch.sum(effective != 0).item()

    if total_params == 0:
        return total_params, non_zero_params, 0.0
    sparsity = (total_params - non_zero_params) / total_params * 100
    return total_params, non_zero_params, sparsity

In [None]:
def calculate_flops_reduction(original_params, pruned_params):
    """Estimate the FLOPs reduction (in percent) from the parameter reduction.

    Args:
        original_params: Parameter count before pruning.
        pruned_params: Parameter count remaining after pruning.

    Returns:
        ``(original - pruned) / original * 100``, or ``0.0`` when
        ``original_params`` is zero (nothing to reduce).
    """
    if original_params == 0:
        return 0.0
    return (original_params - pruned_params) / original_params * 100

# Evaluation of model

In [None]:
def evaluate(model, loader, device):
    """Return the accuracy of ``model`` over all batches of ``loader``.

    The model is put in eval mode and run without gradients; a prediction is
    the argmax over dimension 1 of the model output.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for graphs in loader:
            graphs = graphs.to(device)
            logits = model(graphs.x, graphs.edge_index, graphs.batch)
            hits += (logits.argmax(dim=1) == graphs.y).sum().item()
            seen += graphs.y.size(0)
    return hits / seen

def evaluate_detailed(model, loader, device):
    """Collect predictions over ``loader`` and return binary classification metrics.

    Returns a dict with ``accuracy``, ``f1``, ``precision``, ``recall``
    (the last three computed with ``average='binary'``) together with the raw
    ``predictions`` and ``labels`` lists.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for graphs in loader:
            graphs = graphs.to(device)
            logits = model(graphs.x, graphs.edge_index, graphs.batch)
            all_preds += list(logits.argmax(dim=1).cpu().numpy())
            all_labels += list(graphs.y.cpu().numpy())

    report = {'accuracy': accuracy_score(all_labels, all_preds)}
    binary_scorers = (
        ('f1', f1_score),
        ('precision', precision_score),
        ('recall', recall_score),
    )
    for key, scorer in binary_scorers:
        report[key] = scorer(all_labels, all_preds, average='binary')
    report['predictions'] = all_preds
    report['labels'] = all_labels
    return report


In [None]:
def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and time its forward passes.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index, batch.batch)``.
        test_loader: Iterable of batches exposing ``x``, ``edge_index``,
            ``batch``, ``y`` and ``to(device)``.
        device: Device (string or ``torch.device``) the batches are moved to.

    Returns:
        Dict with ``accuracy``, weighted ``f1`` / ``precision`` / ``recall``
        and ``inference_time_ms`` (mean forward time per batch, in ms).
    """
    from sklearn.metrics import f1_score, precision_score, recall_score

    model.eval()
    torch_device = torch.device(device)
    # CUDA work is asynchronous; synchronize around the forward pass so the
    # timer covers the computation rather than just the kernel launches.
    sync_cuda = torch_device.type == 'cuda' and torch.cuda.is_available()

    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    inference_times = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)

            if sync_cuda:
                torch.cuda.synchronize(torch_device)
            start_time = time.perf_counter()
            out = model(batch.x, batch.edge_index, batch.batch)
            if sync_cuda:
                torch.cuda.synchronize(torch_device)
            inference_time = (time.perf_counter() - start_time) * 1000  # ms

            pred = out.argmax(dim=1)
            correct += (pred == batch.y).sum().item()
            total += batch.y.size(0)

            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())
            inference_times.append(inference_time)

    accuracy = correct / total

    # zero_division=0 scores a class that is never predicted as 0 without
    # emitting a warning, matching PruningComparator.evaluate_model.
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    avg_inference_time = np.mean(inference_times)

    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'inference_time_ms': avg_inference_time
    }


# Training and evaluation function

In [None]:
def train_evaluate(model, train_loader, val_loader, test_loader, device, epochs=100, plot_training=False,
                   lr=0.001, eval_every=10, patience=5):
    """Train ``model`` with Adam + cross-entropy, then evaluate it on the test set.

    Validation runs every ``eval_every`` epochs; training stops early once the
    validation accuracy has failed to improve for ``patience`` consecutive
    validation rounds.

    Args:
        model: Module called as ``model(batch.x, batch.edge_index, batch.batch)``.
        train_loader, val_loader, test_loader: Batch iterables for each split.
        device: Device the batches are moved to.
        epochs: Maximum number of training epochs.
        plot_training: When true, plot the loss/accuracy curves and save them
            to ``training_curves.png``.
        lr: Adam learning rate.
        eval_every: Validate every this many epochs (must be >= 1).
        patience: Number of non-improving validation rounds before stopping.

    Returns:
        The dict produced by ``evaluate_detailed`` on the test set, extended
        with timing, parameter-count and memory fields.

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0
    patience_counter = 0
    start_time = time.time()

    # For plotting and memory tracking
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    epoch_times = []
    peak_train_cpu = peak_train_gpu = 0
    peak_inf_cpu = peak_inf_gpu = 0

    for epoch in range(epochs):
        epoch_start = time.time()

        model.train()
        epoch_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Memory is sampled once per epoch, after the last training batch; the
        # "peak" is the maximum over these end-of-epoch samples.
        cpu_mem, gpu_mem = get_memory_usage()
        peak_train_cpu = max(peak_train_cpu, cpu_mem)
        peak_train_gpu = max(peak_train_gpu, gpu_mem)

        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)

        # Periodic validation
        if epoch % eval_every == 0:
            val_acc = evaluate(model, val_loader, device)
            train_acc = evaluate(model, train_loader, device)
            val_loss = evaluate_loss(model, val_loader, device, criterion)

            train_losses.append(epoch_loss / len(train_loader))
            val_losses.append(val_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:  # Early stopping
                break

    training_time = time.time() - start_time

    # Plot training curves for best model
    if plot_training:
        plt.figure(figsize=(15, 5))
        # One recorded point per validation round, spaced eval_every epochs apart.
        epochs_range = range(0, len(train_losses) * eval_every, eval_every)

        plt.subplot(1, 3, 1)
        plt.plot(epochs_range, train_losses, 'b-', label='Train Loss')
        plt.plot(epochs_range, val_losses, 'r-', label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Train vs Validation Loss')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(epochs_range, train_accs, 'b-', label='Train Accuracy')
        plt.plot(epochs_range, val_accs, 'r-', label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Train vs Validation Accuracy')
        plt.legend()

        plt.subplot(1, 3, 3)
        plt.plot(epochs_range, train_losses, 'b-', label='Train Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.legend()

        plt.tight_layout()
        plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
        plt.show()

    # Memory during inference (sampled after every test batch)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            _ = model(batch.x, batch.edge_index, batch.batch)
            cpu_mem, gpu_mem = get_memory_usage()
            peak_inf_cpu = max(peak_inf_cpu, cpu_mem)
            peak_inf_gpu = max(peak_inf_gpu, gpu_mem)

    metrics = evaluate_detailed(model, test_loader, device)
    inference_time = measure_inference_time(model, test_loader, device)

    metrics.update({
        'training_time': training_time,
        'avg_epoch_time': np.mean(epoch_times),
        # NOTE: uses whichever count_parameters is bound when this runs; the
        # notebook defines it twice with different return shapes.
        'parameters': count_parameters(model),
        'peak_train_cpu_mb': peak_train_cpu,
        'peak_train_gpu_mb': peak_train_gpu,
        'peak_inf_cpu_mb': peak_inf_cpu,
        'peak_inf_gpu_mb': peak_inf_gpu,
        'inference_time_ms': inference_time
    })

    return metrics

# Helper function for validation loss


In [None]:
def evaluate_loss(model, loader, device, criterion):
    """Return the mean of the per-batch ``criterion`` values over ``loader``.

    The model runs in eval mode without gradients. The sum of batch losses is
    divided by ``len(loader)``, i.e. every batch is weighted equally.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for graphs in loader:
            graphs = graphs.to(device)
            logits = model(graphs.x, graphs.edge_index, graphs.batch)
            batch_losses.append(criterion(logits, graphs.y).item())
    return sum(batch_losses) / len(loader)

# Optimize the model

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import time
import gc
import psutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from copy import deepcopy
from sklearn.metrics import f1_score, precision_score, recall_score
from torch_geometric.data import Data, Batch

class PruningComparator:
    def __init__(self, model_config, device=None):
        self.model_config = model_config
        # Auto-detect device if not specified
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"PruningComparator initialized with device: {self.device}")
        self.results = {}

    def ensure_device_consistency(self, model, data):
        """Ensure model and data are on the same device"""
        # Move model to device
        model = model.to(self.device)

        # Handle different data types
        if isinstance(data, (Data, Batch)):
            data = data.to(self.device)
        elif hasattr(data, 'to') and callable(getattr(data, 'to')):
            # For objects that have a .to() method
            data = data.to(self.device)
        elif hasattr(data, '__dict__'):
            # For custom objects with attributes, move tensor attributes to device
            for attr_name in dir(data):
                if not attr_name.startswith('_'):
                    attr_value = getattr(data, attr_name)
                    if torch.is_tensor(attr_value):
                        setattr(data, attr_name, attr_value.to(self.device))
        elif isinstance(data, tuple):
            data = tuple(item.to(self.device) if torch.is_tensor(item) else item for item in data)
        elif torch.is_tensor(data):
            data = data.to(self.device)

        return model, data

    def get_model_size(self, model):
        """Calculate model size in MB"""
        param_size = 0
        buffer_size = 0

        for param in model.parameters():
            param_size += param.nelement() * param.element_size()

        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()

        model_size = (param_size + buffer_size) / 1024 / 1024
        return model_size

    def count_parameters(self, model):
        """Count total and non-zero parameters"""
        total_params = sum(p.numel() for p in model.parameters())

        # Count non-zero parameters (for pruned models)
        non_zero_params = 0
        for p in model.parameters():
            if hasattr(p, 'weight_mask'):
                non_zero_params += torch.sum(p.weight_mask).item()
            else:
                non_zero_params += torch.sum(p != 0).item()

        return total_params, non_zero_params

    def get_memory_usage(self):
        """Return ``(cpu_mb, gpu_mb)`` memory currently in use.

        ``cpu_mb`` is this process's resident set size from psutil. ``gpu_mb``
        is ``torch.cuda.memory_allocated(self.device)`` in MB, or ``0`` unless
        CUDA is available and ``self.device`` is a CUDA device.
        """
        rss_mb = psutil.Process().memory_info().rss / 1024 / 1024  # MB

        on_cuda = torch.cuda.is_available() and self.device.type == 'cuda'
        if not on_cuda:
            return rss_mb, 0
        return rss_mb, torch.cuda.memory_allocated(self.device) / 1024 / 1024  # MB

    def measure_inference_time(self, model, sample_data, num_runs=100):
        """Measure inference time with device consistency"""
        # Ensure everything is on the same device
        model, sample_data = self.ensure_device_consistency(model, sample_data)
        model.eval()

        # Handle different data formats
        if isinstance(sample_data, (Data, Batch)):
            x, edge_index, batch = sample_data.x, sample_data.edge_index, sample_data.batch
        else:
            x, edge_index, batch = sample_data

        # Ensure all tensors are on the correct device
        x = x.to(self.device)
        edge_index = edge_index.to(self.device)
        if batch is not None:
            batch = batch.to(self.device)

        # Warmup runs
        with torch.no_grad():
            for _ in range(10):
                try:
                    if batch is not None:
                        _ = model(x, edge_index, batch)
                    else:
                        _ = model(x, edge_index)
                except Exception as e:
                    print(f"Warning during warmup: {e}")
                    break

        # Measure inference time
        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

        start_time = time.time()
        with torch.no_grad():
            for _ in range(num_runs):
                try:
                    if batch is not None:
                        _ = model(x, edge_index, batch)
                    else:
                        _ = model(x, edge_index)
                except Exception as e:
                    print(f"Error during inference timing: {e}")
                    return float('inf')

        if self.device.type == 'cuda':
            torch.cuda.synchronize(self.device)

        end_time = time.time()
        avg_time = (end_time - start_time) / num_runs * 1000  # ms
        return avg_time

    def unstructured_pruning(self, model, pruning_ratio=0.5):
        """Apply unstructured magnitude-based pruning"""
        model_pruned = deepcopy(model).to(self.device)

        # Collect all linear layers
        modules_to_prune = []
        for name, module in model_pruned.named_modules():
            if isinstance(module, nn.Linear):
                modules_to_prune.append((module, 'weight'))

        if not modules_to_prune:
            print("Warning: No Linear layers found for pruning")
            return model_pruned

        # Apply global unstructured pruning
        prune.global_unstructured(
            modules_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=pruning_ratio,
        )

        # Make pruning permanent
        for module, param_name in modules_to_prune:
            prune.remove(module, param_name)

        return model_pruned

    def structured_pruning(self, model, pruning_ratio=0.5):
        """Apply structured pruning (remove entire neurons/channels)"""
        model_pruned = deepcopy(model).to(self.device)

        # Apply structured pruning to linear layers
        for name, module in model_pruned.named_modules():
            if isinstance(module, nn.Linear) and hasattr(module, 'weight'):
                # Calculate L1 norm for each output neuron
                weight = module.weight.data
                l1_norm = torch.norm(weight, p=1, dim=1)

                # Determine number of neurons to prune
                num_neurons = weight.size(0)
                num_to_prune = int(num_neurons * pruning_ratio)

                if num_to_prune > 0 and num_to_prune < num_neurons:
                    try:
                        # Apply structured pruning
                        prune.ln_structured(module, name='weight', amount=num_to_prune, n=1, dim=0)
                        prune.remove(module, 'weight')
                    except Exception as e:
                        print(f"Warning: Could not apply structured pruning to {name}: {e}")

        return model_pruned

    def magnitude_based_pruning(self, model, pruning_ratio=0.5):
        """Apply magnitude-based pruning with different thresholds"""
        model_pruned = deepcopy(model).to(self.device)

        # Collect all weights
        all_weights = []
        for param in model_pruned.parameters():
            if param.requires_grad:
                all_weights.extend(param.data.abs().flatten().tolist())

        if not all_weights:
            print("Warning: No trainable parameters found")
            return model_pruned

        # Calculate threshold based on magnitude
        all_weights = torch.tensor(all_weights, device=self.device)
        threshold = torch.quantile(all_weights, pruning_ratio)

        # Apply pruning based on magnitude threshold
        for param in model_pruned.parameters():
            if param.requires_grad:
                mask = param.data.abs() > threshold
                param.data *= mask.float()

        return model_pruned

    def random_pruning(self, model, pruning_ratio=0.5):
        """Apply random pruning for comparison"""
        model_pruned = deepcopy(model).to(self.device)

        modules_to_prune = []
        for name, module in model_pruned.named_modules():
            if isinstance(module, nn.Linear):
                modules_to_prune.append((module, 'weight'))

        if not modules_to_prune:
            print("Warning: No Linear layers found for random pruning")
            return model_pruned

        # Apply random pruning
        prune.global_unstructured(
            modules_to_prune,
            pruning_method=prune.RandomUnstructured,
            amount=pruning_ratio,
        )

        # Make pruning permanent
        for module, param_name in modules_to_prune:
            prune.remove(module, param_name)

        return model_pruned

    def evaluate_model(self, model, test_loader, criterion):
        """Evaluate model performance with device consistency"""
        model, _ = self.ensure_device_consistency(model, None)
        model.eval()

        total_loss = 0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for data in test_loader:
                # Move data to device - handle custom data objects
                if hasattr(data, 'to') and callable(getattr(data, 'to')):
                    data = data.to(self.device)
                elif hasattr(data, '__dict__'):
                    # For custom objects, move tensor attributes to device
                    for attr_name in ['x', 'edge_index', 'batch', 'y']:
                        if hasattr(data, attr_name):
                            attr_value = getattr(data, attr_name)
                            if torch.is_tensor(attr_value):
                                setattr(data, attr_name, attr_value.to(self.device))

                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y

                try:
                    outputs = model(x, edge_index, batch)
                    loss = criterion(outputs, y)
                    total_loss += loss.item()

                    _, predicted = torch.max(outputs.data, 1)
                    total += y.size(0)
                    correct += (predicted == y).sum().item()

                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(y.cpu().numpy())
                except Exception as e:
                    print(f"Error during evaluation: {e}")
                    continue

        if total == 0:
            return {
                'accuracy': 0.0,
                'loss': float('inf'),
                'f1_score': 0.0,
                'precision': 0.0,
                'recall': 0.0
            }

        accuracy = correct / total
        avg_loss = total_loss / len(test_loader) if len(test_loader) > 0 else float('inf')

        # Calculate additional metrics
        try:
            f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
            precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
            recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        except Exception as e:
            print(f"Warning: Could not calculate metrics: {e}")
            f1 = precision = recall = 0.0

        return {
            'accuracy': accuracy,
            'loss': avg_loss,
            'f1_score': f1,
            'precision': precision,
            'recall': recall
        }

    def compare_pruning_methods(self, original_model, test_loader, sample_data,
                              pruning_ratios=(0.1, 0.3, 0.5, 0.7, 0.9)):
        """Benchmark the unpruned model and every pruning strategy at each ratio.

        Args:
            original_model: ``nn.Module`` to benchmark; it is moved to
                ``self.device``. The baseline row is measured on a deep copy.
            test_loader: Batches forwarded to ``self.evaluate_model``.
            sample_data: Input forwarded to ``self.measure_inference_time``.
            pruning_ratios: Iterable of ratios tried for every pruning method.
                The unpruned baseline is always recorded once at ratio 0.0.

        Returns:
            pandas.DataFrame with one row per (method, ratio) combination that
            completed without raising. Failed combinations are reported on
            stdout and skipped.
        """
        criterion = nn.CrossEntropyLoss().to(self.device)

        # Ensure original model is on correct device
        original_model = original_model.to(self.device)

        # Materialise once so a one-shot iterable (e.g. a generator) is still
        # available for every pruning method, not just the first one.
        requested_ratios = list(pruning_ratios)

        # ``None`` marks the unpruned baseline (no pruning function to call).
        pruning_methods = {
            'Original': None,
            'Unstructured': self.unstructured_pruning,
            'Structured': self.structured_pruning,
            'Magnitude-based': self.magnitude_based_pruning,
            'Random': self.random_pruning
        }

        results = []

        for method_name, pruning_func in pruning_methods.items():
            print(f"\n{'='*50}")
            print(f"Testing {method_name} Pruning")
            print(f"{'='*50}")

            ratios_to_test = [0.0] if pruning_func is None else requested_ratios

            for ratio in ratios_to_test:
                print(f"\nPruning ratio: {ratio}")

                # Bound before the try so the cleanup below never has to
                # inspect locals() to find out whether pruning got this far.
                pruned_model = None
                try:
                    # Apply pruning
                    if pruning_func is None:
                        pruned_model = deepcopy(original_model).to(self.device)
                    else:
                        pruned_model = pruning_func(original_model, ratio)

                    # Timer covers measurement + evaluation only; the pruning
                    # call above is deliberately outside of it.
                    start_time = time.perf_counter()

                    # Model size and parameters
                    model_size = self.get_model_size(pruned_model)
                    total_params, non_zero_params = self.count_parameters(pruned_model)
                    sparsity = 1 - (non_zero_params / total_params) if total_params > 0 else 0

                    # Memory usage
                    cpu_mem, gpu_mem = self.get_memory_usage()

                    # Inference time
                    inference_time = self.measure_inference_time(pruned_model, sample_data)

                    # Model performance
                    performance = self.evaluate_model(pruned_model, test_loader, criterion)

                    # Reported as 'Compile_Time_s' for backward compatibility.
                    compile_time = time.perf_counter() - start_time

                    results.append({
                        'Method': method_name,
                        'Pruning_Ratio': ratio,
                        'Model_Size_MB': model_size,
                        'Total_Parameters': total_params,
                        'Non_Zero_Parameters': non_zero_params,
                        'Sparsity': sparsity,
                        'CPU_Memory_MB': cpu_mem,
                        'GPU_Memory_MB': gpu_mem,
                        'Inference_Time_ms': inference_time,
                        'Compile_Time_s': compile_time,
                        'Accuracy': performance['accuracy'],
                        'F1_Score': performance['f1_score'],
                        'Precision': performance['precision'],
                        'Recall': performance['recall'],
                        'Loss': performance['loss']
                    })

                    print(f"  Model Size: {model_size:.2f} MB")
                    print(f"  Sparsity: {sparsity:.2%}")
                    print(f"  Accuracy: {performance['accuracy']:.4f}")
                    print(f"  F1-Score: {performance['f1_score']:.4f}")
                    print(f"  Inference Time: {inference_time:.2f} ms")

                except Exception as e:
                    # Best effort: one failing combination must not abort the sweep.
                    print(f"Error processing {method_name} with ratio {ratio}: {e}")
                finally:
                    # Drop the copy before the next iteration allocates another one.
                    del pruned_model
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        return pd.DataFrame(results)

    def create_visualizations(self, results_df):
        """Draw a 3x3 dashboard summarising the pruning comparison.

        Args:
            results_df: DataFrame with the columns read below ('Method',
                'Pruning_Ratio', 'Accuracy', 'Model_Size_MB',
                'Inference_Time_ms', 'Sparsity', 'F1_Score', 'CPU_Memory_MB').

        Returns:
            The matplotlib Figure, or ``None`` when ``results_df`` is empty.
        """
        if results_df.empty:
            print("No results to visualize")
            return None

        plt.style.use('default')
        fig, axes = plt.subplots(3, 3, figsize=(20, 15))
        panels = axes.ravel()

        method_names = list(results_df['Method'].unique())
        rows_by_method = {
            name: results_df[results_df['Method'] == name] for name in method_names
        }
        pruned_names = [name for name in method_names if name != 'Original']

        def finish_panel(ax, xlabel, ylabel, title):
            """Label a panel; add the legend only if something is labelled."""
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                # Skipping the call avoids matplotlib's "no artists with
                # labels" warning on panels that ended up empty.
                ax.legend()
            ax.grid(True, alpha=0.3)

        def line_panel(ax, column, marker, ylabel, title):
            """Plot `column` against the pruning ratio, one line per method."""
            for name in method_names:
                rows = rows_by_method[name]
                ax.plot(rows['Pruning_Ratio'], rows[column],
                        marker=marker, linewidth=2, label=name)
            finish_panel(ax, 'Pruning Ratio', ylabel, title)

        def scatter_panel(ax, names, x_values, xlabel, title):
            """Scatter accuracy against `x_values(rows)` for each method in `names`."""
            for name in names:
                rows = rows_by_method[name]
                ax.scatter(x_values(rows), rows['Accuracy'],
                           s=100, alpha=0.7, label=name)
            finish_panel(ax, xlabel, 'Accuracy', title)

        # 1-3. Accuracy / model size / inference time vs pruning ratio
        line_panel(panels[0], 'Accuracy', 'o', 'Accuracy', 'Accuracy vs Pruning Ratio')
        line_panel(panels[1], 'Model_Size_MB', 's', 'Model Size (MB)',
                   'Model Size vs Pruning Ratio')
        line_panel(panels[2], 'Inference_Time_ms', '^', 'Inference Time (ms)',
                   'Inference Time vs Pruning Ratio')

        # 4. Sparsity vs Accuracy (baseline excluded)
        scatter_panel(panels[3], pruned_names, lambda rows: rows['Sparsity'],
                      'Sparsity', 'Sparsity vs Accuracy Trade-off')

        # 5. F1-Score comparison at 50% pruning
        ax = panels[4]
        # Tolerance instead of ``== 0.5`` so computed ratios still match.
        half_pruned = results_df[(results_df['Pruning_Ratio'] - 0.5).abs() < 1e-9]
        if not half_pruned.empty:
            bar_names = half_pruned['Method'].tolist()
            f1_scores = half_pruned['F1_Score'].tolist()
            bars = ax.bar(bar_names, f1_scores, alpha=0.7,
                          color=plt.cm.Set3(range(len(bar_names))))
            ax.set_ylabel('F1-Score')
            ax.set_title('F1-Score Comparison (50% Pruning)')
            ax.tick_params(axis='x', labelrotation=45)
            for bar, score in zip(bars, f1_scores):
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{score:.3f}', ha='center', va='bottom')

        # 6. Memory usage vs pruning ratio
        line_panel(panels[5], 'CPU_Memory_MB', 'd', 'CPU Memory (MB)',
                   'Memory Usage vs Pruning Ratio')

        # 7. Accuracy vs Model Size (Efficiency Plot)
        scatter_panel(panels[6], method_names, lambda rows: rows['Model_Size_MB'],
                      'Model Size (MB)', 'Accuracy vs Model Size (Efficiency)')

        # 8. Accuracy heatmap (method x pruning ratio)
        ax = panels[7]
        pivot_data = results_df.pivot_table(
            values='Accuracy',
            index='Method',
            columns='Pruning_Ratio',
            fill_value=np.nan
        )
        if not pivot_data.empty:
            sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='YlOrRd',
                        cbar_kws={'label': 'Accuracy'}, ax=ax)
        ax.set_title('Accuracy Heatmap by Method and Pruning Ratio')

        # 9. Pareto Frontier (Accuracy vs Compression)
        # NOTE(review): the ratio is computed from 'Model_Size_MB'; if that
        # column measures dense tensor bytes, zero-masking will not move it.
        # Confirm this panel shows what is intended.
        original_data = rows_by_method.get('Original')
        if original_data is not None and not original_data.empty:
            original_size = original_data['Model_Size_MB'].iloc[0]
            scatter_panel(panels[8], pruned_names,
                          lambda rows: original_size / rows['Model_Size_MB'],
                          'Compression Ratio',
                          'Pareto Frontier: Accuracy vs Compression')

        fig.tight_layout()
        plt.show()

        return fig

    def generate_report(self, results_df):
        """Print a textual summary of the pruning comparison.

        Args:
            results_df: DataFrame with at least the columns 'Method',
                'Pruning_Ratio', 'Accuracy', 'Model_Size_MB' and
                'Inference_Time_ms'.

        Returns:
            ``results_df`` unchanged, so the call can be chained.
        """
        if results_df.empty:
            print("No results to report")
            return results_df

        print("\n" + "="*80)
        print("COMPREHENSIVE PRUNING COMPARISON REPORT")
        print("="*80)

        # Overall statistics
        print(f"\nTested {len(results_df['Method'].unique())} pruning methods")
        print(f"Pruning ratios: {sorted(results_df['Pruning_Ratio'].unique())}")

        # Best performers at different pruning ratios
        for ratio in (0.3, 0.5, 0.7):
            # Compare with a tolerance: exact equality misses ratios that were
            # computed (e.g. 0.1 * 3) rather than typed as literals.
            at_ratio = (results_df['Pruning_Ratio'] - ratio).abs() < 1e-9
            ratio_data = results_df[at_ratio]
            if ratio_data.empty:
                continue

            best_accuracy = ratio_data.loc[ratio_data['Accuracy'].idxmax()]
            # Accuracy per MB: higher means more accuracy for less storage.
            efficiency_metric = ratio_data['Accuracy'] / ratio_data['Model_Size_MB']
            best_efficiency = ratio_data.loc[efficiency_metric.idxmax()]

            # Fixed-point formatting: a bare ``ratio*100`` renders 0.3 as
            # 30.000000000000004.
            print(f"\n--- At {ratio*100:.1f}% Pruning ---")
            print(f"Best Accuracy: {best_accuracy['Method']} "
                  f"({best_accuracy['Accuracy']:.4f})")
            print(f"Best Efficiency: {best_efficiency['Method']} "
                  f"(Acc: {best_efficiency['Accuracy']:.4f}, "
                  f"Size: {best_efficiency['Model_Size_MB']:.2f} MB)")

        # Method comparison summary
        print(f"\n--- Method Summary ---")
        for method in results_df['Method'].unique():
            method_data = results_df[results_df['Method'] == method]

            print(f"{method}:")
            print(f"  Avg Accuracy: {method_data['Accuracy'].mean():.4f}")
            print(f"  Avg Model Size: {method_data['Model_Size_MB'].mean():.2f} MB")
            print(f"  Avg Inference Time: {method_data['Inference_Time_ms'].mean():.2f} ms")

        # Detailed comparison table
        print(f"\n--- Detailed Results ---")
        print(results_df.round(4).to_string(index=False))

        return results_df

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import time
import gc
import psutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from copy import deepcopy
from sklearn.metrics import f1_score, precision_score, recall_score
from torch_geometric.data import Data, Batch

class PruningComparator:
    """Apply several pruning strategies to a model and compare the results.

    Every pruning method works on a deep copy of the model it is given and
    zeroes weights in place; tensors keep their dense shape. The comparison
    records size, sparsity, memory, latency and classification metrics per
    (method, ratio) pair and can plot or print them.
    """

    def __init__(self, model_config, device=None):
        """Store the configuration and resolve the target device.

        Args:
            model_config: Stored as ``self.model_config``; not read by this class.
            device: Anything ``torch.device`` accepts, or ``None`` to pick CUDA
                when available and CPU otherwise.
        """
        self.model_config = model_config
        # Auto-detect device if not specified
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"PruningComparator initialized with device: {self.device}")
        self.results = {}

    def ensure_device_consistency(self, model, data):
        """Move ``model`` and, where possible, ``data`` to ``self.device``.

        ``data`` may be anything exposing ``.to(device)`` (tensors and graph
        batches included), an object with a ``__dict__`` whose public tensor
        attributes are moved in place, or a tuple whose tensor items are moved.
        Anything else (e.g. ``None``) is returned untouched.

        Returns:
            ``(model, data)`` after the move.
        """
        model = model.to(self.device)

        if hasattr(data, 'to') and callable(getattr(data, 'to')):
            # Covers tensors and graph objects alike.
            data = data.to(self.device)
        elif hasattr(data, '__dict__'):
            # For custom objects with attributes, move tensor attributes to device
            for attr_name in dir(data):
                if attr_name.startswith('_'):
                    continue
                attr_value = getattr(data, attr_name)
                if torch.is_tensor(attr_value):
                    setattr(data, attr_name, attr_value.to(self.device))
        elif isinstance(data, tuple):
            data = tuple(item.to(self.device) if torch.is_tensor(item) else item
                         for item in data)

        return model, data

    def get_model_size(self, model):
        """Return the dense size of all parameters and buffers in MB."""
        param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
        buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
        return (param_bytes + buffer_bytes) / 1024 / 1024

    def count_parameters(self, model):
        """Return ``(total_params, non_zero_params)`` over ``model.parameters()``."""
        total_params = 0
        non_zero_params = 0
        for p in model.parameters():
            total_params += p.numel()
            # Count actual zeros. The previous ``weight_mask`` lookup on the
            # parameter itself was dead code: tensors never carry that attribute.
            non_zero_params += torch.count_nonzero(p).item()

        return total_params, non_zero_params

    def get_memory_usage(self):
        """Return ``(cpu_mb, gpu_mb)``: process RSS and CUDA memory allocated.

        ``gpu_mb`` is 0 unless ``self.device`` is a CUDA device.
        """
        process = psutil.Process()
        cpu_memory = process.memory_info().rss / 1024 / 1024  # MB

        gpu_memory = 0
        if torch.cuda.is_available() and self.device.type == 'cuda':
            gpu_memory = torch.cuda.memory_allocated(self.device) / 1024 / 1024  # MB

        return cpu_memory, gpu_memory

    def measure_inference_time(self, model, sample_data, num_runs=100):
        """Return the mean forward-pass latency in milliseconds.

        Args:
            model: Called as ``model(x, edge_index, batch)``, or
                ``model(x, edge_index)`` when ``batch`` is ``None``.
            sample_data: Either an object exposing ``x`` and ``edge_index``
                (and optionally ``batch``) or an ``(x, edge_index, batch)``
                sequence.
            num_runs: Number of timed forward passes; must be positive.

        Returns:
            Average milliseconds per run, or ``float('inf')`` if a timed
            forward pass raised.

        Raises:
            ValueError: If ``num_runs`` is not positive.
        """
        if num_runs <= 0:
            raise ValueError("num_runs must be a positive integer")

        # Ensure everything is on the same device
        model, sample_data = self.ensure_device_consistency(model, sample_data)
        model.eval()

        # Attribute-style inputs are detected by duck typing so custom batch
        # objects work the same way as library graph objects.
        if hasattr(sample_data, 'x') and hasattr(sample_data, 'edge_index'):
            x = sample_data.x
            edge_index = sample_data.edge_index
            batch = getattr(sample_data, 'batch', None)
        else:
            x, edge_index, batch = sample_data

        # Ensure all tensors are on the correct device
        x = x.to(self.device)
        edge_index = edge_index.to(self.device)
        if batch is not None:
            batch = batch.to(self.device)

        def forward_once():
            """Run one forward pass with the arguments resolved above."""
            if batch is not None:
                return model(x, edge_index, batch)
            return model(x, edge_index)

        on_cuda = self.device.type == 'cuda'

        with torch.no_grad():
            # Warmup runs
            for _ in range(10):
                try:
                    forward_once()
                except Exception as e:
                    print(f"Warning during warmup: {e}")
                    break

            # CUDA kernels are asynchronous; synchronise around the timed loop.
            if on_cuda:
                torch.cuda.synchronize(self.device)

            # perf_counter is monotonic and high resolution, unlike time.time().
            start_time = time.perf_counter()
            for _ in range(num_runs):
                try:
                    forward_once()
                except Exception as e:
                    print(f"Error during inference timing: {e}")
                    return float('inf')

            if on_cuda:
                torch.cuda.synchronize(self.device)
            elapsed = time.perf_counter() - start_time

        return elapsed / num_runs * 1000  # ms

    def _global_prune_linear(self, model, pruning_ratio, pruning_method, empty_warning):
        """Globally prune the weights of every ``nn.Linear`` in a copy of ``model``."""
        model_pruned = deepcopy(model).to(self.device)

        modules_to_prune = [
            (module, 'weight')
            for module in model_pruned.modules()
            if isinstance(module, nn.Linear)
        ]

        if not modules_to_prune:
            print(empty_warning)
            return model_pruned

        prune.global_unstructured(
            modules_to_prune,
            pruning_method=pruning_method,
            amount=pruning_ratio,
        )

        # Make pruning permanent
        for module, param_name in modules_to_prune:
            prune.remove(module, param_name)

        return model_pruned

    def unstructured_pruning(self, model, pruning_ratio=0.5):
        """Return a copy with the globally smallest-|w| Linear weights zeroed."""
        return self._global_prune_linear(
            model, pruning_ratio, prune.L1Unstructured,
            "Warning: No Linear layers found for pruning",
        )

    def structured_pruning(self, model, pruning_ratio=0.5):
        """Return a copy in which whole output rows of each Linear are zeroed.

        Per layer, ``int(out_features * pruning_ratio)`` rows with the smallest
        L1 norm are masked. Layers where that count is 0 or would remove every
        row are left untouched.
        """
        model_pruned = deepcopy(model).to(self.device)

        for name, module in model_pruned.named_modules():
            if not isinstance(module, nn.Linear):
                continue

            num_neurons = module.weight.size(0)
            num_to_prune = int(num_neurons * pruning_ratio)

            if 0 < num_to_prune < num_neurons:
                try:
                    # ln_structured ranks rows by L1 norm itself (n=1, dim=0).
                    prune.ln_structured(module, name='weight', amount=num_to_prune, n=1, dim=0)
                    prune.remove(module, 'weight')
                except Exception as e:
                    print(f"Warning: Could not apply structured pruning to {name}: {e}")

        return model_pruned

    def magnitude_based_pruning(self, model, pruning_ratio=0.5):
        """Return a copy with every trainable value at or below a global threshold zeroed.

        The threshold is the ``pruning_ratio`` quantile of the absolute values
        of all trainable parameters (weights, biases, norm parameters alike).
        """
        model_pruned = deepcopy(model).to(self.device)

        trainable = [p for p in model_pruned.parameters() if p.requires_grad]
        if not trainable:
            print("Warning: No trainable parameters found")
            return model_pruned

        # Concatenate on-device instead of round-tripping each value through a
        # Python list. ``.float()`` keeps the float32 threshold used before.
        # NOTE(review): torch.quantile may reject very large inputs - confirm
        # the target models stay within its limit.
        magnitudes = torch.cat([p.data.abs().flatten() for p in trainable]).float()
        threshold = torch.quantile(magnitudes, pruning_ratio)

        for param in trainable:
            keep = param.data.abs() > threshold
            param.data.mul_(keep.to(param.dtype))

        return model_pruned

    def random_pruning(self, model, pruning_ratio=0.5):
        """Return a copy with a random subset of Linear weights zeroed (baseline)."""
        return self._global_prune_linear(
            model, pruning_ratio, prune.RandomUnstructured,
            "Warning: No Linear layers found for random pruning",
        )

    def evaluate_model(self, model, test_loader, criterion):
        """Evaluate ``model`` on ``test_loader`` and return classification metrics.

        Each batch must expose ``x``, ``edge_index``, ``batch`` and ``y``.
        Batches whose forward pass or loss raises are reported and skipped.

        Returns:
            Dict with 'accuracy', 'loss' (mean over the batches that were
            actually evaluated), 'f1_score', 'precision' and 'recall'
            (weighted averages). If no sample could be evaluated, the metrics
            are 0.0 and the loss is ``inf``.
        """
        model = model.to(self.device)
        model.eval()

        total_loss = 0.0
        correct = 0
        total = 0
        evaluated_batches = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for data in test_loader:
                # Move data to device - handle custom data objects
                if hasattr(data, 'to') and callable(getattr(data, 'to')):
                    data = data.to(self.device)
                elif hasattr(data, '__dict__'):
                    for attr_name in ('x', 'edge_index', 'batch', 'y'):
                        attr_value = getattr(data, attr_name, None)
                        if torch.is_tensor(attr_value):
                            setattr(data, attr_name, attr_value.to(self.device))

                x, edge_index, batch, y = data.x, data.edge_index, data.batch, data.y

                try:
                    outputs = model(x, edge_index, batch)
                    loss = criterion(outputs, y)
                    _, predicted = torch.max(outputs, 1)

                    total_loss += loss.item()
                    evaluated_batches += 1
                    total += y.size(0)
                    correct += (predicted == y).sum().item()

                    all_preds.extend(predicted.cpu().tolist())
                    all_labels.extend(y.cpu().tolist())
                except Exception as e:
                    print(f"Error during evaluation: {e}")
                    continue

        if total == 0:
            return {
                'accuracy': 0.0,
                'loss': float('inf'),
                'f1_score': 0.0,
                'precision': 0.0,
                'recall': 0.0
            }

        accuracy = correct / total
        # Average over the batches that contributed a loss. Dividing by
        # len(test_loader) understated the loss whenever a batch was skipped
        # and required the loader to support len().
        avg_loss = total_loss / evaluated_batches if evaluated_batches > 0 else float('inf')

        # Calculate additional metrics
        try:
            f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
            precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
            recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        except Exception as e:
            print(f"Warning: Could not calculate metrics: {e}")
            f1 = precision = recall = 0.0

        return {
            'accuracy': accuracy,
            'loss': avg_loss,
            'f1_score': f1,
            'precision': precision,
            'recall': recall
        }

    def compare_pruning_methods(self, original_model, test_loader, sample_data,
                              pruning_ratios=(0.1, 0.3, 0.5, 0.7, 0.9)):
        """Benchmark the unpruned model and every pruning strategy at each ratio.

        Args:
            original_model: ``nn.Module`` to benchmark; it is moved to
                ``self.device``. Every variant is measured on a deep copy.
            test_loader: Batches forwarded to ``self.evaluate_model``.
            sample_data: Input forwarded to ``self.measure_inference_time``.
            pruning_ratios: Iterable of ratios tried for every pruning method.
                The unpruned baseline is always recorded once at ratio 0.0.

        Returns:
            pandas.DataFrame with one row per (method, ratio) combination that
            completed without raising. Failed combinations are reported on
            stdout and skipped.
        """
        criterion = nn.CrossEntropyLoss().to(self.device)

        # Ensure original model is on correct device
        original_model = original_model.to(self.device)

        # Materialise once so a one-shot iterable (e.g. a generator) is still
        # available for every pruning method, not just the first one.
        requested_ratios = list(pruning_ratios)

        # ``None`` marks the unpruned baseline (no pruning function to call).
        pruning_methods = {
            'Original': None,
            'Unstructured': self.unstructured_pruning,
            'Structured': self.structured_pruning,
            'Magnitude-based': self.magnitude_based_pruning,
            'Random': self.random_pruning
        }

        results = []

        for method_name, pruning_func in pruning_methods.items():
            print(f"\n{'='*50}")
            print(f"Testing {method_name} Pruning")
            print(f"{'='*50}")

            ratios_to_test = [0.0] if pruning_func is None else requested_ratios

            for ratio in ratios_to_test:
                print(f"\nPruning ratio: {ratio}")

                # Bound before the try so the cleanup below never has to
                # inspect locals() to find out whether pruning got this far.
                pruned_model = None
                try:
                    # Apply pruning
                    if pruning_func is None:
                        pruned_model = deepcopy(original_model).to(self.device)
                    else:
                        pruned_model = pruning_func(original_model, ratio)

                    # Timer covers measurement + evaluation only; the pruning
                    # call above is deliberately outside of it.
                    start_time = time.perf_counter()

                    # Model size and parameters
                    model_size = self.get_model_size(pruned_model)
                    total_params, non_zero_params = self.count_parameters(pruned_model)
                    sparsity = 1 - (non_zero_params / total_params) if total_params > 0 else 0

                    # Memory usage
                    cpu_mem, gpu_mem = self.get_memory_usage()

                    # Inference time
                    inference_time = self.measure_inference_time(pruned_model, sample_data)

                    # Model performance
                    performance = self.evaluate_model(pruned_model, test_loader, criterion)

                    # Reported as 'Compile_Time_s' for backward compatibility.
                    compile_time = time.perf_counter() - start_time

                    results.append({
                        'Method': method_name,
                        'Pruning_Ratio': ratio,
                        'Model_Size_MB': model_size,
                        'Total_Parameters': total_params,
                        'Non_Zero_Parameters': non_zero_params,
                        'Sparsity': sparsity,
                        'CPU_Memory_MB': cpu_mem,
                        'GPU_Memory_MB': gpu_mem,
                        'Inference_Time_ms': inference_time,
                        'Compile_Time_s': compile_time,
                        'Accuracy': performance['accuracy'],
                        'F1_Score': performance['f1_score'],
                        'Precision': performance['precision'],
                        'Recall': performance['recall'],
                        'Loss': performance['loss']
                    })

                    print(f"  Model Size: {model_size:.2f} MB")
                    print(f"  Sparsity: {sparsity:.2%}")
                    print(f"  Accuracy: {performance['accuracy']:.4f}")
                    print(f"  F1-Score: {performance['f1_score']:.4f}")
                    print(f"  Inference Time: {inference_time:.2f} ms")

                except Exception as e:
                    # Best effort: one failing combination must not abort the sweep.
                    print(f"Error processing {method_name} with ratio {ratio}: {e}")
                finally:
                    # Drop the copy before the next iteration allocates another one.
                    del pruned_model
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

        return pd.DataFrame(results)

    def create_visualizations(self, results_df):
        """Draw a 3x3 dashboard summarising the pruning comparison.

        Args:
            results_df: DataFrame as returned by ``compare_pruning_methods``.

        Returns:
            The matplotlib Figure, or ``None`` when ``results_df`` is empty.
        """
        if results_df.empty:
            print("No results to visualize")
            return None

        plt.style.use('default')
        fig, axes = plt.subplots(3, 3, figsize=(20, 15))
        panels = axes.ravel()

        method_names = list(results_df['Method'].unique())
        rows_by_method = {
            name: results_df[results_df['Method'] == name] for name in method_names
        }
        pruned_names = [name for name in method_names if name != 'Original']

        def finish_panel(ax, xlabel, ylabel, title):
            """Label a panel; add the legend only if something is labelled."""
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                # Skipping the call avoids matplotlib's "no artists with
                # labels" warning on panels that ended up empty.
                ax.legend()
            ax.grid(True, alpha=0.3)

        def line_panel(ax, column, marker, ylabel, title):
            """Plot `column` against the pruning ratio, one line per method."""
            for name in method_names:
                rows = rows_by_method[name]
                ax.plot(rows['Pruning_Ratio'], rows[column],
                        marker=marker, linewidth=2, label=name)
            finish_panel(ax, 'Pruning Ratio', ylabel, title)

        def scatter_panel(ax, names, x_values, xlabel, title):
            """Scatter accuracy against `x_values(rows)` for each method in `names`."""
            for name in names:
                rows = rows_by_method[name]
                ax.scatter(x_values(rows), rows['Accuracy'],
                           s=100, alpha=0.7, label=name)
            finish_panel(ax, xlabel, 'Accuracy', title)

        # 1-3. Accuracy / model size / inference time vs pruning ratio
        line_panel(panels[0], 'Accuracy', 'o', 'Accuracy', 'Accuracy vs Pruning Ratio')
        line_panel(panels[1], 'Model_Size_MB', 's', 'Model Size (MB)',
                   'Model Size vs Pruning Ratio')
        line_panel(panels[2], 'Inference_Time_ms', '^', 'Inference Time (ms)',
                   'Inference Time vs Pruning Ratio')

        # 4. Sparsity vs Accuracy (baseline excluded)
        scatter_panel(panels[3], pruned_names, lambda rows: rows['Sparsity'],
                      'Sparsity', 'Sparsity vs Accuracy Trade-off')

        # 5. F1-Score comparison at 50% pruning
        ax = panels[4]
        # Tolerance instead of ``== 0.5`` so computed ratios still match.
        half_pruned = results_df[(results_df['Pruning_Ratio'] - 0.5).abs() < 1e-9]
        if not half_pruned.empty:
            bar_names = half_pruned['Method'].tolist()
            f1_scores = half_pruned['F1_Score'].tolist()
            bars = ax.bar(bar_names, f1_scores, alpha=0.7,
                          color=plt.cm.Set3(range(len(bar_names))))
            ax.set_ylabel('F1-Score')
            ax.set_title('F1-Score Comparison (50% Pruning)')
            ax.tick_params(axis='x', labelrotation=45)
            for bar, score in zip(bars, f1_scores):
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{score:.3f}', ha='center', va='bottom')

        # 6. Memory usage vs pruning ratio
        line_panel(panels[5], 'CPU_Memory_MB', 'd', 'CPU Memory (MB)',
                   'Memory Usage vs Pruning Ratio')

        # 7. Accuracy vs Model Size (Efficiency Plot)
        scatter_panel(panels[6], method_names, lambda rows: rows['Model_Size_MB'],
                      'Model Size (MB)', 'Accuracy vs Model Size (Efficiency)')

        # 8. Accuracy heatmap (method x pruning ratio)
        ax = panels[7]
        pivot_data = results_df.pivot_table(
            values='Accuracy',
            index='Method',
            columns='Pruning_Ratio',
            fill_value=np.nan
        )
        if not pivot_data.empty:
            sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='YlOrRd',
                        cbar_kws={'label': 'Accuracy'}, ax=ax)
        ax.set_title('Accuracy Heatmap by Method and Pruning Ratio')

        # 9. Pareto Frontier (Accuracy vs Compression)
        # get_model_size() counts dense tensor bytes and the pruning methods
        # only zero values, so this ratio stays at 1 unless tensors shrink.
        original_data = rows_by_method.get('Original')
        if original_data is not None and not original_data.empty:
            original_size = original_data['Model_Size_MB'].iloc[0]
            scatter_panel(panels[8], pruned_names,
                          lambda rows: original_size / rows['Model_Size_MB'],
                          'Compression Ratio',
                          'Pareto Frontier: Accuracy vs Compression')

        fig.tight_layout()
        plt.show()

        return fig

    def generate_report(self, results_df):
        """Print a textual summary of the pruning comparison.

        Args:
            results_df: DataFrame as returned by ``compare_pruning_methods``.

        Returns:
            ``results_df`` unchanged, so the call can be chained.
        """
        if results_df.empty:
            print("No results to report")
            return results_df

        print("\n" + "="*80)
        print("COMPREHENSIVE PRUNING COMPARISON REPORT")
        print("="*80)

        # Overall statistics
        print(f"\nTested {len(results_df['Method'].unique())} pruning methods")
        print(f"Pruning ratios: {sorted(results_df['Pruning_Ratio'].unique())}")

        # Best performers at different pruning ratios
        for ratio in (0.3, 0.5, 0.7):
            # Compare with a tolerance: exact equality misses ratios that were
            # computed (e.g. 0.1 * 3) rather than typed as literals.
            at_ratio = (results_df['Pruning_Ratio'] - ratio).abs() < 1e-9
            ratio_data = results_df[at_ratio]
            if ratio_data.empty:
                continue

            best_accuracy = ratio_data.loc[ratio_data['Accuracy'].idxmax()]
            # Accuracy per MB: higher means more accuracy for less storage.
            efficiency_metric = ratio_data['Accuracy'] / ratio_data['Model_Size_MB']
            best_efficiency = ratio_data.loc[efficiency_metric.idxmax()]

            # Fixed-point formatting: a bare ``ratio*100`` renders 0.3 as
            # 30.000000000000004.
            print(f"\n--- At {ratio*100:.1f}% Pruning ---")
            print(f"Best Accuracy: {best_accuracy['Method']} "
                  f"({best_accuracy['Accuracy']:.4f})")
            print(f"Best Efficiency: {best_efficiency['Method']} "
                  f"(Acc: {best_efficiency['Accuracy']:.4f}, "
                  f"Size: {best_efficiency['Model_Size_MB']:.2f} MB)")

        # Method comparison summary
        print(f"\n--- Method Summary ---")
        for method in results_df['Method'].unique():
            method_data = results_df[results_df['Method'] == method]

            print(f"{method}:")
            print(f"  Avg Accuracy: {method_data['Accuracy'].mean():.4f}")
            print(f"  Avg Model Size: {method_data['Model_Size_MB'].mean():.2f} MB")
            print(f"  Avg Inference Time: {method_data['Inference_Time_ms'].mean():.2f} ms")

        # Detailed comparison table
        print(f"\n--- Detailed Results ---")
        print(results_df.round(4).to_string(index=False))

        return results_df


# Fixed usage example with proper batch handling
def create_fixed_dummy_data(model_config, batch_size=32, nodes_per_graph=10):
    """Build one random mini-batch of chain graphs as ``(x, edge_index, batch, y)``.

    Each graph is a bidirectional path over ``nodes_per_graph`` consecutive node
    ids, so no edge ever connects nodes belonging to different graphs.

    Args:
        model_config: Mapping providing ``'num_features'`` (number of columns of
            the node-feature matrix) and ``'num_classes'`` (exclusive upper bound
            for the random graph labels).
        batch_size: Number of graphs in the batch.
        nodes_per_graph: Number of nodes in every graph. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        A tuple ``(x, edge_index, batch, y)`` where
        ``x`` is a float tensor of shape ``(batch_size * nodes_per_graph, num_features)``,
        ``edge_index`` is an int64 tensor of shape ``(2, E)`` with
        ``E = 2 * batch_size * max(nodes_per_graph - 1, 0)``,
        ``batch`` is an int64 tensor assigning every node to its graph index, and
        ``y`` is an int64 tensor of shape ``(batch_size,)`` with random labels.

    Raises:
        ValueError: If ``batch_size`` or ``nodes_per_graph`` is negative.
    """
    if batch_size < 0 or nodes_per_graph < 0:
        raise ValueError(
            "batch_size and nodes_per_graph must be non-negative, got "
            f"batch_size={batch_size}, nodes_per_graph={nodes_per_graph}"
        )

    num_graphs = batch_size
    total_nodes = num_graphs * nodes_per_graph

    # Node features (drawn before the labels, so a fixed seed keeps producing
    # the same x / y as the loop-based implementation did).
    x = torch.randn(total_nodes, model_config['num_features'])

    # One row of node ids per graph; dropping the last column leaves exactly the
    # nodes that have a successor inside the same graph.
    node_ids = torch.arange(total_nodes).view(num_graphs, nodes_per_graph)
    src = node_ids[:, :-1].reshape(-1)
    dst = src + 1

    # Interleave forward and backward edges: (j, j+1), (j+1, j), (j+1, j+2), ...
    # Stacking (instead of transposing a possibly empty list) guarantees the
    # (2, E) layout even when there are no edges at all.
    edge_index = torch.stack([
        torch.stack([src, dst], dim=1).reshape(-1),
        torch.stack([dst, src], dim=1).reshape(-1),
    ]).contiguous()

    # Graph membership of every node: 0,0,...,1,1,...
    batch = torch.repeat_interleave(torch.arange(num_graphs), nodes_per_graph)

    # One random class label per graph
    y = torch.randint(0, model_config['num_classes'], (num_graphs,))

    return x, edge_index, batch, y


# Example usage with the fix
if __name__ == "__main__":
    # Hyper-parameters shared by the dummy-data generator and the comparator.
    # The model itself is NOT built in this cell; it must already exist as
    # `original_model` (see the check before the comparison below).
    model_config = {
        'num_features': 128,
        'hidden_dim': 128,
        'num_classes': 2,
        'num_layers': 4,
        'dropout': 0.5,
        'global_pool': 'max',
        'eps': 0.0,
        'train_eps': True,
        'batch_norm': False,
        'layer_norm': True
    }

    # Create device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create properly structured dummy data; the labels `y` are only reported
    # below, `sample_data` carries the model inputs.
    batch_size = 32
    x, edge_index, batch, y = create_fixed_dummy_data(model_config, batch_size)
    sample_data = (x, edge_index, batch)

    class FixedDummyData:
        """Minimal stand-in for a batch object exposing x, edge_index, batch and y."""

        def __init__(self, x, edge_index, batch, y):
            self.x = x
            self.edge_index = edge_index
            self.batch = batch
            self.y = y

    # Ten independent random batches act as the "test loader".
    dummy_test_loader = [
        FixedDummyData(*create_fixed_dummy_data(model_config, batch_size))
        for _ in range(10)
    ]

    # Initialize pruning comparator
    comparator = PruningComparator(model_config, device)

    print("Fixed dummy data created successfully!")
    print("Sample data shapes:")
    print(f"  x: {x.shape}")
    print(f"  edge_index: {edge_index.shape}")
    print(f"  batch: {batch.shape}")
    print(f"  y: {y.shape}")
    print(f"  Unique batch values: {torch.unique(batch)}")

    # The comparison needs a model that this cell does not create. Fail with an
    # actionable message instead of a bare NameError caused by hidden kernel state.
    if 'original_model' not in globals():
        raise NameError(
            "original_model is not defined. Build the model in an earlier cell "
            "(for example: original_model = GINModel(**model_config).to(device), "
            "using your actual model class) before running the pruning comparison."
        )

    # Run the comparison inside the guard: it depends on the objects built above.
    results_df = comparator.compare_pruning_methods(
        original_model,
        dummy_test_loader,
        sample_data,
        pruning_ratios=[0.1, 0.3, 0.5, 0.7, 0.9]
    )

PruningComparator initialized with device: cuda
Fixed dummy data created successfully!
Sample data shapes:
  x: torch.Size([320, 128])
  edge_index: torch.Size([2, 576])
  batch: torch.Size([320])
  y: torch.Size([32])
  Unique batch values: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

Testing Original Pruning

Pruning ratio: 0.0
  Model Size: 0.51 MB
  Sparsity: 0.39%
  Accuracy: 0.4813
  F1-Score: 0.3127
  Inference Time: 2.42 ms

Testing Unstructured Pruning

Pruning ratio: 0.1
  Model Size: 0.51 MB
  Sparsity: 10.23%
  Accuracy: 0.4813
  F1-Score: 0.3127
  Inference Time: 4.74 ms

Pruning ratio: 0.3
  Model Size: 0.51 MB
  Sparsity: 29.92%
  Accuracy: 0.4813
  F1-Score: 0.3127
  Inference Time: 2.63 ms

Pruning ratio: 0.5
  Model Size: 0.51 MB
  Sparsity: 49.62%
  Accuracy: 0.4813
  F1-Score: 0.3235
  Inference Time: 2.41 ms

Pruning ratio: 0.7
  Model Size: 0.51 MB
  Sparsity: 69.3