In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def _evaluate_predictions(model, validation_loader, device, num_passes):
    """Return ``(accuracy, mean stability)`` of *model* over *validation_loader*."""
    correct = 0
    total = 0
    stability_scores = []

    with torch.no_grad():
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)

            # Repeat the forward pass and keep the class probabilities of each run.
            logits = None
            probabilities = []
            for _ in range(num_passes):
                logits = model(images)
                probabilities.append(F.softmax(logits, dim=1))

            # Population variance (ddof=0) across passes, computed on-device so
            # the full probability tensors never have to be copied to the host.
            variance = torch.var(torch.stack(probabilities, dim=0), dim=0, unbiased=False)
            stability_scores.append(-variance.mean().item())  # lower variance is better

            # Accuracy is measured on the logits of the last forward pass.
            _, predicted = torch.max(logits, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # An empty loader yields neutral terms instead of a ZeroDivisionError / NaN.
    accuracy = correct / total if total else 0.0
    avg_stability = float(np.mean(stability_scores)) if stability_scores else 0.0
    return accuracy, avg_stability


def _weight_statistics(model):
    """Return ``(zero-weight ratio, mean per-parameter weight health)`` of *model*."""
    num_params = 0
    num_zeros = 0
    weight_health_scores = []

    with torch.no_grad():
        for param in model.parameters():
            count = param.numel()
            if count == 0:
                continue  # nothing to measure; mean/std of an empty tensor is NaN

            num_params += count
            num_zeros += (param == 0).sum().item()

            # The sample std of a single value is NaN, which would poison the
            # whole fitness score; treat a lone value as having no spread.
            weight_std = torch.std(param).item() if count > 1 else 0.0
            weight_mean = torch.mean(param).item()
            # Ideally std close to 1 and mean close to 0.
            weight_health_scores.append(-abs(weight_std - 1.0) - abs(weight_mean))

    sparsity_ratio = num_zeros / num_params if num_params else 0.0
    avg_weight_health = float(np.mean(weight_health_scores)) if weight_health_scores else 0.0
    return sparsity_ratio, avg_weight_health


def fitness_function(models, validation_loader, alpha=0.5, beta=0.3, gamma=0.2, num_passes=3):
    """
    Pick the model with the highest fitness score from a list of candidates.

    For every model the score is::

        alpha * accuracy - beta * sparsity_ratio + gamma * (stability + weight_health)

    where
      * accuracy is the fraction of correctly classified validation samples,
      * sparsity_ratio is the fraction of parameter entries that are exactly zero
        (note that it is *subtracted*, so more zeros lower the score),
      * stability is the negated mean variance of the softmax outputs across
        ``num_passes`` forward passes of the same batch (the model is in eval
        mode, so a fully deterministic model contributes 0 here),
      * weight_health is the mean over parameters of
        ``-|std - 1| - |mean|``.

    Each model is moved to the selected device (CUDA if available, else CPU)
    and switched to eval mode; it is left in that state afterwards.

    Args:
        models (list): List of models with different weights.
        validation_loader (DataLoader): Iterable of ``(images, labels)`` batches.
            It is iterated once per model, so it must be re-iterable.
        alpha (float): Weight for accuracy.
        beta (float): Weight for the sparsity penalty.
        gamma (float): Weight for stability and weight health.
        num_passes (int): Number of forward passes per batch used for the
            stability estimate. Must be at least 1.

    Returns:
        best_model (torch.nn.Module): Model with the highest fitness score, or
        ``None`` if ``models`` is empty or no model produced a comparable
        (non-NaN) score.

    Raises:
        ValueError: If ``num_passes`` is smaller than 1.
    """
    if num_passes < 1:
        raise ValueError("num_passes must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    best_score = float('-inf')
    best_model = None

    for model in models:
        model.to(device)
        model.eval()

        accuracy, avg_stability = _evaluate_predictions(
            model, validation_loader, device, num_passes
        )
        sparsity_ratio, avg_weight_health = _weight_statistics(model)

        # Combination of accuracy, sparsity, stability, and weight health.
        fitness_score = (
            alpha * accuracy - beta * sparsity_ratio + gamma * (avg_stability + avg_weight_health)
        )

        # Strict '>' keeps the earliest model on ties and never selects a NaN score.
        if fitness_score > best_score:
            best_score = fitness_score
            best_model = model

    return best_model
