In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.models as models
import copy
import numpy as np
from torch.utils.data import Dataset, DataLoader, TensorDataset

# from metadata import MetadataChromosome, initialize_population, select_parents, genetic_algorithm

class MetadataChromosome:
    def __init__(self, metadata):
        self.metadata = metadata
        self.fitness = 0.0

    def calculate_fitness(self, models):
        """
        Evaluates different pretrained MobileNetV3 weight spaces based on accuracy, sparsity, stability, and weight health.
    
        Args:
            models (list): List of models with different weights.
            validation_loader (DataLoader): Validation dataset.
            alpha (float): Weight for accuracy.
            beta (float): Weight for sparsity.
            gamma (float): Weight for stability and weight health.

        Returns:
            best_model (torch.nn.Module): Model with the best weight space.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
        best_score = float('-inf')
        best_model = None

        for model in models:
            model.to(device)
            model.eval()
        
            correct = 0
            total = 0
            sparsity = 0
            num_params = 0
            stability_scores = []
            weight_health_scores = []
        
            with torch.no_grad():
                for images, labels in validation_loader:
                    images, labels = images.to(device), labels.to(device)

                    # Run multiple forward passes for stability check
                    outputs_list = []
                    for _ in range(3):  # Three forward passes
                        outputs = model(images)
                        outputs_list.append(F.softmax(outputs, dim=1).cpu().numpy())

                    # Compute stability as variance across multiple predictions
                    outputs_array = np.stack(outputs_list, axis=0)
                    stability_score = -np.mean(np.var(outputs_array, axis=0))  # Lower variance is better
                
                    stability_scores.append(stability_score)

                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == labels).sum().item()
                    total += labels.size(0)
        
            # Calculate accuracy
            accuracy = correct / total

            # Compute sparsity (percentage of zero weights)
            for param in model.parameters():
                num_params += param.numel()
                sparsity += (param == 0).sum().item()
        
            sparsity_ratio = sparsity / num_params
        
            # Compute weight health (variance of weights should not be too low/high)
            for param in model.parameters():
                weight_std = torch.std(param).item()
                weight_mean = torch.mean(param).item()
                weight_health_score = -abs(weight_std - 1.0) - abs(weight_mean)  # Ideally, std close to 1 and mean close to 0
                weight_health_scores.append(weight_health_score)
        
            avg_stability = np.mean(stability_scores)
            avg_weight_health = np.mean(weight_health_scores)

            # Define the fitness score (combination of accuracy, sparsity, stability, and weight health)
            fitness_score = (
                alpha * accuracy - beta * sparsity_ratio + gamma * (avg_stability + avg_weight_health)
            )
        
            # Select the best model based on fitness score
            if fitness_score > best_score:
                best_score = fitness_score
                best_model = model

        return best_model
    

    def mutate(self):
        # Randomly mutate one of the metadata values
        key = random.choice(list(self.metadata.keys()))
        self.metadata[key] = random.uniform(0, 1)

    def crossover(self, other):
        # Perform crossover with another chromosome
        child_metadata = {}
        for key in self.metadata.keys():
            if random.random() < 0.5:
                child_metadata[key] = self.metadata[key]
            else:
                child_metadata[key] = other.metadata[key]
        return MetadataChromosome(child_metadata)


# Initialize a population of chromosomes
def initialize_model_population(models):
    population = []
    for model in models:
        metadata = {'model': model}
        chromosome = MetadataChromosome(metadata)
        population.append(chromosome)
    return population


#def initialize_population(size, metadata_keys):
#    population = []
#    for _ in range(size):
#        metadata = {key: random.uniform(0, 1) for key in metadata_keys}
#        chromosome = MetadataChromosome(metadata)
#        population.append(chromosome)
#    return population

# Select parents for crossover
def select_parents(population):
    population.sort(key=lambda x: x.fitness, reverse=True)
    return population[:2]


# Genetic Algorithm
def genetic_algorithm(metadata_keys, models, population, generations=100, population_size=10):
    # population = initialize_population(population_size, metadata_keys)

    for generation in range(generations):
        # Calculate fitness for each chromosome
        for chromosome in population:
            chromosome.calculate_fitness(models)

        # Select parents
        parents = select_parents(population)

        # Create next generation
        new_population = []
        for _ in range(population_size // 2):
            parent1, parent2 = random.sample(parents, 2)
            child1 = parent1.crossover(parent2)
            child2 = parent2.crossover(parent1)
            child1.mutate()
            child2.mutate()
            new_population.extend([child1, child2])

        population = new_population

        # Print the best fitness of the current generation
        best_fitness = max(chromosome.fitness for chromosome in population)
        print(f"Generation {generation + 1}: Best Fitness = {best_fitness}")

    # Return the best chromosome
    best_chromosome = max(population, key=lambda x: x.fitness)
    return best_chromosome


# Federated Learning Parameters
NUM_CLIENTS = 5  # Number of clients
ROUNDS = 10  # Number of training rounds
EPOCHS_PER_CLIENT = 2  # Local training epochs per client
AGGREGATION_METHOD = "FedAvg"  # Federated averaging


def federated_learning_with_ga():
    # Define synthetic data
    num_samples = 1000
    num_classes = 10
    image_size = (3, 224, 224)

    # Generate random images and labels
    images = torch.randn(num_samples, *image_size)
    labels = torch.randint(0, num_classes, (num_samples,))

    # Define preprocessing steps
    transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Apply transformations to synthetic data
    images = torch.stack([transform(image) for image in images])

    # Create TensorDataset and DataLoader
    synthetic_dataset = TensorDataset(images, labels)
    train_loader = DataLoader(synthetic_dataset, batch_size=32, shuffle=True, num_workers=4)
    
    global_model = load_mobilenetv3()
    clients = [load_mobilenetv3() for _ in range(NUM_CLIENTS)]
    
    # Update train_loaders
    train_loaders = [train_loader for _ in range(NUM_CLIENTS)]  

    for round in range(ROUNDS):
        print(f"Round {round+1}/{ROUNDS}")
        client_models = []

        for client_id in range(NUM_CLIENTS):
            client_model = copy.deepcopy(global_model)
            updated_weights = train_client(client_model, train_loaders[client_id])
            client_models.append(updated_weights)

        # Initialize population with client models
        population = initialize_model_population(client_models)

        # Run genetic algorithm to find the best model
        best_chromosome = genetic_algorithm(metadata_keys=['model'], models=client_models, population=population, 
                                            generations=10, population_size=NUM_CLIENTS)
        global_model = best_chromosome.metadata['model']
        print("Global model updated using Genetic Algorithm.")

        # Validation step
        global_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        criterion = nn.CrossEntropyLoss()
        
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = global_model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        
        val_loss /= len(val_loader)
        accuracy = correct / total
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')

    return global_model


# Load Pretrained MobileNetV3 Model
def load_mobilenetv3():
    model = models.mobilenet_v3_small(pretrained=True)
    model.classifier[3] = nn.Linear(model.classifier[3].in_features, 10)  # Example: 10 classes
    return model


# Simulated Client Training
def train_client(model, train_loader, epochs=EPOCHS_PER_CLIENT, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
    return model.state_dict()


# Federated Averaging (FedAvg)
def federated_averaging(global_model, client_models):
    new_state_dict = copy.deepcopy(global_model.state_dict())
    for key in new_state_dict.keys():
        new_state_dict[key] = torch.stack([client_models[i][key] for i in range(NUM_CLIENTS)], dim=0).mean(dim=0)
    global_model.load_state_dict(new_state_dict)
    return global_model


# Federated Learning Simulation
def federated_learning():
    global_model = load_mobilenetv3()
    clients = [load_mobilenetv3() for _ in range(NUM_CLIENTS)]
    # # Replace with actual DataLoaders [None]
    train_loaders =  DataLoader(dataset=train_dataset, batch_size=1000, num_workers=1, shuffle=True) * NUM_CLIENTS  

    # Load dataset
    train_dataset = datasets.ImageFolder(root='train_data_path/', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    
    for round in range(ROUNDS):
        print(f"Round {round+1}/{ROUNDS}")
        client_models = []
        
        for client_id in range(NUM_CLIENTS):
            client_model = copy.deepcopy(global_model)
            updated_weights = train_client(client_model, train_loaders[client_id])
            client_models.append(updated_weights)
        
        global_model = federated_averaging(global_model, client_models)
        print("Global model updated.")
    
    return global_model

# Run federated learning
global_model = federated_learning_with_ga()

#global_model = federated_learning()



Round 1/10


AttributeError: 'collections.OrderedDict' object has no attribute 'to'