In [None]:
# 1. IMPORTS
import copy
import io
import os
import time
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from thop import profile
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

# 2. CONFIGURATION
class Config:
    """Single home for every tunable knob of the pruning experiment.

    Settings are grouped into nested namespaces (DATASET, TRAINING, MODEL,
    GA, FITNESS) and read as plain class attributes, e.g. ``config.GA.POP_SIZE``.
    """
    # Fall back to the CPU unless PyTorch reports a usable CUDA device.
    DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

    class DATASET:
        # Directory MNIST is downloaded into / read from.
        ROOT = 'dataset/'
        # Images kept for training; the rest of the MNIST train split becomes validation.
        TRAIN_SIZE = 50_000
        # Size of the training subset used for the short fine-tuning runs inside the GA.
        FINETUNE_SUBSET_SIZE = 10_000

    class TRAINING:
        BATCH_SIZE = 128
        # Passes over the full training set for the unpruned baseline.
        BASELINE_EPOCHS = 5
        # Passes over the fine-tune subset for every candidate the GA scores.
        FINETUNE_EPOCHS = 3
        LEARNING_RATE = 1e-3

    class MODEL:
        # Where the trained baseline checkpoint is cached between runs.
        BASELINE_MODEL_PATH = 'models/baseline_model_improved.pth'

    class GA:
        POP_SIZE = 20
        NUM_GENERATIONS = 10
        CROSSOVER_RATE = 0.8
        # Per-gene bit-flip probability; kept fairly high to encourage exploration.
        MUTATION_RATE = 0.05
        TOURNAMENT_SIZE = 3
        # Conv layers whose output filters are encoded as genes, in this order.
        PRUNABLE_LAYER_NAMES = ['conv1', 'conv2']

    class FITNESS:
        # fitness = W_ACC * (accuracy / 100) + W_MACS * (1 - macs_ratio)
        W_ACC = 0.6
        W_MACS = 0.4

# 3. MODEL DEFINITION
class BaselineCNN(nn.Module):
    """
    A simple CNN for MNIST classification. The channel numbers are chosen
    to be easily prunable.

    Layer shapes for a ``(N, 1, 28, 28)`` input:
        conv1 (1 -> c1, 5x5, padding 2) -> max-pool 2 -> ReLU   : (N, c1, 14, 14)
        conv2 (c1 -> c2, 5x5, padding 2) -> max-pool 2 -> ReLU  : (N, c2, 7, 7)
        flatten -> fc1 (7*7*c2 -> 1024) -> ReLU -> fc2 (1024 -> 10)

    ``forward`` returns log-probabilities (``log_softmax`` over the 10 classes).

    Args:
        c1: Number of filters in ``conv1``.
        c2: Number of filters in ``conv2``.
    """
    def __init__(self, c1: int = 32, c2: int = 64):
        super().__init__()
        self.c1 = c1
        self.c2 = c2
        self.conv1 = nn.Conv2d(1, self.c1, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(self.c1, self.c2, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(7 * 7 * self.c2, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(N, 1, 28, 28)`` batch to ``(N, 10)`` class log-probabilities."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten while keeping the batch dimension fixed. A ``view(-1, 7*7*c2)``
        # would silently re-split the batch for non-28x28 inputs; with flatten
        # such inputs fail loudly at fc1 instead of yielding a wrong batch size.
        # The (C, H, W) flatten order is relied on when slicing fc1 during pruning.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 4. UTILITY FUNCTIONS
def get_data_loaders(config: Config, seed: Optional[int] = 0) -> Tuple[DataLoader, DataLoader, DataLoader, DataLoader]:
    """Load MNIST and build the train, fine-tune, validation and test loaders.

    The MNIST training split is divided into ``config.DATASET.TRAIN_SIZE``
    training images and the remainder for validation. A random subset of
    ``config.DATASET.FINETUNE_SUBSET_SIZE`` *training* images is used for the
    short fine-tuning runs during GA evolution. The test loader uses the MNIST
    test split. Both splits are downloaded into ``config.DATASET.ROOT`` if missing.

    Args:
        config: Source of the dataset root, split sizes and batch size.
        seed: Seeds the train/validation split and the fine-tune subset so they
            are identical across runs. This matters because the baseline model
            is cached on disk: with an unseeded split, a later run could move
            images the cached baseline was trained on into the validation set
            used to score GA candidates. ``None`` uses the global RNG state.

    Returns:
        ``(train_loader, finetune_loader, val_loader, test_loader)``; only the
        train and fine-tune loaders shuffle.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    full_train_dataset = datasets.MNIST(root=config.DATASET.ROOT, train=True, download=True, transform=transform)
    val_size = len(full_train_dataset) - config.DATASET.TRAIN_SIZE

    # Only pass a generator when seeding, so seed=None keeps random_split's default.
    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = random_split(
        full_train_dataset, [config.DATASET.TRAIN_SIZE, val_size], **split_kwargs
    )

    # Smaller subset of the training data for faster fine-tuning of GA candidates.
    rng = np.random if seed is None else np.random.default_rng(seed)
    finetune_indices = rng.choice(len(train_dataset), config.DATASET.FINETUNE_SUBSET_SIZE, replace=False)
    finetune_subset = Subset(train_dataset, finetune_indices)

    test_dataset = datasets.MNIST(root=config.DATASET.ROOT, train=False, download=True, transform=transform)

    batch_size = config.TRAINING.BATCH_SIZE
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    finetune_loader = DataLoader(finetune_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, finetune_loader, val_loader, test_loader

def train_model(model: nn.Module, loader: DataLoader, epochs: int, lr: float, device: str) -> List[float]:
    """Train (or fine-tune) ``model`` in place with Adam and cross-entropy.

    Args:
        model: Network to optimise; it is moved to ``device`` and left in train mode.
        loader: Yields ``(data, target)`` batches.
        epochs: Number of full passes over ``loader``.
        lr: Adam learning rate.
        device: Torch device string such as ``"cuda"`` or ``"cpu"``.

    Returns:
        The sample-weighted mean training loss of each epoch. The value is
        ``nan`` for an epoch in which ``loader`` yielded no batches. Callers
        that ignore the return value are unaffected.
    """
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # NOTE: BaselineCNN already returns log-probabilities. CrossEntropyLoss
    # re-applies log_softmax, which is idempotent, so for that model this is
    # numerically the same as NLLLoss.
    criterion = nn.CrossEntropyLoss()

    epoch_losses: List[float] = []
    for epoch in range(epochs):
        # Accumulate on-device so there is no host sync (.item()) per batch.
        loss_sum = torch.zeros((), device=device)
        n_samples = 0
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            # The loss is a batch mean; weight it by batch size for an exact epoch mean.
            loss_sum += loss.detach() * target.size(0)
            n_samples += target.size(0)
        # Guard the empty-loader case, where no batch loss exists.
        mean_loss = loss_sum.item() / n_samples if n_samples else float("nan")
        epoch_losses.append(mean_loss)
        print(f"  Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")
    return epoch_losses

def evaluate_model(model: nn.Module, loader: DataLoader, device: str) -> float:
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage.

    The model is moved to ``device`` and switched to eval mode, and it is left
    that way. Gradients are disabled during evaluation.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        loader: Yields ``(data, target)`` batches with integer class targets.
        device: Torch device string such as ``"cuda"`` or ``"cpu"``.

    Returns:
        ``100 * correct / total`` in the range [0, 100].

    Raises:
        ValueError: If ``loader`` yields no samples, since accuracy is undefined.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("evaluate_model received an empty loader; accuracy is undefined")
    return 100 * correct / total

def calculate_model_metrics(model: nn.Module, test_loader: DataLoader, device: str) -> Dict[str, Any]:
    """Collect accuracy, parameter count, MACs and serialized size for ``model``.

    Args:
        model: Network to measure; it is moved to ``device``.
        test_loader: Data used for the accuracy measurement.
        device: Torch device string such as ``"cuda"`` or ``"cpu"``.

    Returns:
        A dict with these keys:
        - ``accuracy``: percentage from ``evaluate_model``.
        - ``params`` and ``macs``: as reported by ``thop.profile`` for a single
          ``(1, 1, 28, 28)`` input.
        - ``size_mb``: size of the ``torch.save``-serialized state_dict, in MiB.
    """
    model.to(device)

    # Serialize into memory rather than a temp file: there is no directory
    # side effect, no clash between concurrent runs, and no stale file if the
    # save fails. This is done before profiling, in case the installed thop
    # version leaves extra buffers registered on the model (unverified; older
    # releases may).
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = len(buffer.getvalue()) / (1024 * 1024)

    accuracy = evaluate_model(model, test_loader, device)

    dummy_input = torch.randn(1, 1, 28, 28).to(device)
    macs, params = profile(model, inputs=(dummy_input,), verbose=False)

    return {"accuracy": accuracy, "params": params, "macs": macs, "size_mb": size_mb}

# 5. GENETIC ALGORITHM & PRUNING LOGIC

def create_pruned_model_from_chromosome(
    baseline_model: nn.Module,
    chromosome: Tuple[int, ...],
    config: Config
) -> nn.Module:
    """
    Build a smaller ``BaselineCNN`` that keeps only the filters selected by ``chromosome``.

    This is structured pruning: entire filters are removed. There is one gene
    per output filter of every layer named in ``config.GA.PRUNABLE_LAYER_NAMES``,
    concatenated in the order those layers appear in ``named_modules()``. A gene
    of 1 keeps the filter. The construction then:
    - copies the weights and biases of the kept filters;
    - slices the next conv layer's input channels to match;
    - slices fc1's input features to match the surviving conv2 channels;
    - copies fc1's bias and fc2 unchanged.

    If all of a layer's genes are 0, the single filter with the largest L1 norm
    is kept, so the layer stays valid. This fallback is deterministic, so a
    chromosome always maps to the same model. That matters because GA fitness
    is cached per chromosome and the final model is rebuilt from the winning
    chromosome.

    Raises:
        ValueError: If ``len(chromosome)`` differs from the total ``out_channels``
            of the prunable layers.
    """
    baseline_state_dict = baseline_model.state_dict()
    prunable_layers = [
        (name, module) for name, module in baseline_model.named_modules()
        if name in config.GA.PRUNABLE_LAYER_NAMES
    ]

    expected_length = sum(module.out_channels for _, module in prunable_layers)
    if len(chromosome) != expected_length:
        raise ValueError(
            f"Chromosome has {len(chromosome)} genes but the prunable layers have "
            f"{expected_length} filters in total"
        )

    # 1. Decide which filters survive in each prunable layer.
    kept_by_layer: Dict[str, List[int]] = {}
    gene_offset = 0
    for name, module in prunable_layers:
        layer_genes = chromosome[gene_offset: gene_offset + module.out_channels]
        kept = [i for i, gene in enumerate(layer_genes) if gene == 1]
        if not kept:
            # Deterministic fallback: keep the filter with the largest L1 norm.
            filter_norms = baseline_state_dict[f'{name}.weight'].abs().flatten(1).sum(dim=1)
            kept = [int(filter_norms.argmax())]
        kept_by_layer[name] = kept
        gene_offset += module.out_channels

    # 2. Instantiate the smaller architecture; the sizes come from the same kept lists.
    pruned_model = BaselineCNN(c1=len(kept_by_layer['conv1']), c2=len(kept_by_layer['conv2']))
    pruned_state_dict = pruned_model.state_dict()

    # 3. Copy conv weights: slice output filters, then the inputs the previous layer kept.
    previous_kept = None
    for name, _ in prunable_layers:
        kept = kept_by_layer[name]
        weight = baseline_state_dict[f'{name}.weight'][kept]
        if previous_kept is not None:
            weight = weight[:, previous_kept]
        pruned_state_dict[f'{name}.weight'] = weight
        pruned_state_dict[f'{name}.bias'] = baseline_state_dict[f'{name}.bias'][kept]
        previous_kept = kept

    # 4. fc1 sees the last conv output flattened as (C, 7, 7), so channel c owns
    #    the contiguous feature block [c*49, (c+1)*49).
    feature_map_size = 7 * 7
    kept_features = [
        channel * feature_map_size + offset
        for channel in previous_kept
        for offset in range(feature_map_size)
    ]
    pruned_state_dict['fc1.weight'] = baseline_state_dict['fc1.weight'][:, kept_features]
    pruned_state_dict['fc1.bias'] = baseline_state_dict['fc1.bias']

    # 5. The classifier head is unaffected by filter pruning.
    pruned_state_dict['fc2.weight'] = baseline_state_dict['fc2.weight']
    pruned_state_dict['fc2.bias'] = baseline_state_dict['fc2.bias']

    pruned_model.load_state_dict(pruned_state_dict)
    return pruned_model


class GeneticAlgorithmPruner:
    """
    Manages the evolutionary process for finding the optimal pruning mask.

    A chromosome is a flat list of 0/1 genes, one per output filter of each
    layer named in ``config.GA.PRUNABLE_LAYER_NAMES`` (1 = keep the filter).
    Each generation proceeds as follows:
    - the population is scored with ``fitness_evaluator`` (higher is better);
    - the best chromosome seen so far is carried over unchanged (elitism);
    - the rest of the next generation is bred with tournament selection,
      uniform crossover and bit-flip mutation.
    """
    def __init__(self, baseline_model: nn.Module, config: "Config", fitness_evaluator):
        """
        Args:
            baseline_model: Unpruned network. Only its module names and the
                ``out_channels`` of the prunable layers are read.
            config: Provides the ``GA`` hyper-parameters.
            fitness_evaluator: Callable that maps a chromosome *tuple* to a float
                fitness. Tuples are passed so the callable can be memoised,
                e.g. with ``functools.lru_cache``.
        """
        self.baseline_model = baseline_model
        self.config = config
        self.fitness_evaluator = fitness_evaluator
        self.chromosome_length = self._get_chromosome_length()
        self.population = self._initialize_population()
        # Per-generation best/average fitness, filled in by run_evolution().
        self.history = {'best_fitness': [], 'avg_fitness': []}
        print(f"🧬 GA Initialized. Chromosome length: {self.chromosome_length}")

    def _get_chromosome_length(self) -> int:
        """Return the total number of output filters in the prunable layers (one gene each)."""
        modules = dict(self.baseline_model.named_modules())  # built once, not per layer
        return sum(modules[name].out_channels for name in self.config.GA.PRUNABLE_LAYER_NAMES)

    def _initialize_population(self) -> List[List[int]]:
        """Return POP_SIZE random chromosomes, each gene drawn uniformly from {0, 1}."""
        return [np.random.randint(0, 2, self.chromosome_length).tolist() for _ in range(self.config.GA.POP_SIZE)]

    def _tournament_selection(self, fitnesses: List[float]) -> Tuple[List[int], List[int]]:
        """Pick two parents, each the fittest member of a tournament sampled without replacement."""
        # Clamp so that a population smaller than TOURNAMENT_SIZE doesn't make
        # np.random.choice(..., replace=False) raise.
        tournament_size = min(self.config.GA.TOURNAMENT_SIZE, len(self.population))
        parents = []
        for _ in range(2):
            contenders = np.random.choice(len(self.population), tournament_size, replace=False)
            winner = contenders[np.argmax([fitnesses[i] for i in contenders])]
            parents.append(self.population[winner])
        return tuple(parents)

    def _uniform_crossover(self, parent1: List[int], parent2: List[int]) -> Tuple[List[int], List[int]]:
        """Recombine two parents gene by gene, with probability CROSSOVER_RATE.

        When crossover happens, each position of child1 takes parent1's or
        parent2's gene with equal probability, and child2 takes the other one.
        Otherwise copies of the parents are returned.
        """
        if np.random.rand() > self.config.GA.CROSSOVER_RATE:
            return parent1[:], parent2[:]

        mask = np.random.rand(self.chromosome_length) < 0.5
        child1 = [p1 if m else p2 for p1, p2, m in zip(parent1, parent2, mask)]
        child2 = [p2 if m else p1 for p1, p2, m in zip(parent1, parent2, mask)]
        return child1, child2

    def _bit_flip_mutation(self, chromosome: List[int]) -> List[int]:
        """Return a new chromosome in which each gene is flipped with probability MUTATION_RATE."""
        return [(1 - gene if np.random.rand() < self.config.GA.MUTATION_RATE else gene) for gene in chromosome]

    def run_evolution(self) -> Tuple[int, ...]:
        """Evolve for NUM_GENERATIONS generations and return the fittest chromosome seen.

        Per-generation best and average fitness are stored in ``self.history``
        and plotted via ``_plot_history``.

        Raises:
            ValueError: If ``config.GA.NUM_GENERATIONS`` < 1, because no
                chromosome would ever be evaluated.
        """
        num_generations = self.config.GA.NUM_GENERATIONS
        if num_generations < 1:
            raise ValueError(f"NUM_GENERATIONS must be >= 1, got {num_generations}")

        print("\n---  Starting Genetic Algorithm Evolution ---\n")
        history = {'best_fitness': [], 'avg_fitness': []}
        self.history = history
        best_chromosome_so_far = None
        # -inf rather than a magic -1: with a -1 sentinel, a generation whose
        # fitnesses are all <= -1 would leave the elite as None and crash the
        # next generation.
        best_fitness_so_far = float('-inf')

        for gen in range(num_generations):
            start_time = time.time()

            # Fitness evaluation (the evaluator may cache results per chromosome tuple).
            fitnesses = [self.fitness_evaluator(tuple(chromo)) for chromo in self.population]

            best_fitness_gen = np.max(fitnesses)
            avg_fitness_gen = np.mean(fitnesses)
            best_idx_gen = np.argmax(fitnesses)

            if best_fitness_gen > best_fitness_so_far:
                best_fitness_so_far = best_fitness_gen
                best_chromosome_so_far = self.population[best_idx_gen]

            history['best_fitness'].append(best_fitness_gen)
            history['avg_fitness'].append(avg_fitness_gen)

            gen_time = time.time() - start_time
            print(f" Gen {gen+1:02d}/{num_generations} | "
                  f" Best Fitness: {best_fitness_gen:.4f} | "
                  f" Avg Fitness: {avg_fitness_gen:.4f} | "
                  f" Time: {gen_time:.2f}s")

            # Create the next generation.
            new_population = [best_chromosome_so_far]  # Elitism

            while len(new_population) < self.config.GA.POP_SIZE:
                parent1, parent2 = self._tournament_selection(fitnesses)
                child1, child2 = self._uniform_crossover(parent1, parent2)
                child1 = self._bit_flip_mutation(child1)
                child2 = self._bit_flip_mutation(child2)
                new_population.extend([child1, child2])

            # Breeding adds children in pairs, so trim a possible overshoot.
            self.population = new_population[:self.config.GA.POP_SIZE]

        self._plot_history(history)
        print("\n---  Evolution Complete ---")
        return tuple(best_chromosome_so_far)

    def _plot_history(self, history: dict):
        """Plot best and average fitness per generation and save it to 'ga_fitness_evolution.png'."""
        plt.figure(figsize=(12, 6))
        plt.plot(history['best_fitness'], label='Best Fitness per Generation', marker='o', linestyle='-')
        plt.plot(history['avg_fitness'], label='Average Fitness per Generation', marker='x', linestyle='--')
        plt.title('GA Fitness Evolution Over Generations')
        plt.xlabel('Generation')
        plt.ylabel('Fitness Score')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('ga_fitness_evolution.png')
        print("\n Fitness evolution graph saved to 'ga_fitness_evolution.png'")


# 6. MAIN EXECUTION BLOCK
if __name__ == "__main__":
    config = Config()

    # --- STEP 1: LOAD DATA ---
    print("--- 1. Loading Data ---")
    train_loader, finetune_loader, val_loader, test_loader = get_data_loaders(config)
    print(f"Using device: {config.DEVICE}")

    # --- STEP 2: TRAIN OR LOAD BASELINE MODEL ---
    print("\n--- 2. Setting Up Baseline Model ---")
    baseline_path = config.MODEL.BASELINE_MODEL_PATH
    if not os.path.exists(baseline_path):
        print("No baseline model found. Training a new one...")
        # Create the directory of the *configured* checkpoint path rather than a
        # hard-coded 'models/', so changing BASELINE_MODEL_PATH can't break torch.save.
        os.makedirs(os.path.dirname(baseline_path) or ".", exist_ok=True)
        baseline_model = BaselineCNN()
        train_model(baseline_model, train_loader, config.TRAINING.BASELINE_EPOCHS, config.TRAINING.LEARNING_RATE, config.DEVICE)
        torch.save(baseline_model.state_dict(), baseline_path)
        print(f"Baseline model trained and saved to {baseline_path}")
    else:
        print(f"Loading existing baseline model from {baseline_path}")
        baseline_model = BaselineCNN()
        baseline_model.load_state_dict(torch.load(baseline_path, map_location=config.DEVICE))

    # --- STEP 3: EVALUATE BASELINE MODEL ---
    print("\n--- 3. Evaluating Baseline Model Performance ---")
    baseline_metrics = calculate_model_metrics(baseline_model, test_loader, config.DEVICE)
    baseline_macs = baseline_metrics['macs']
    print(f"Baseline test accuracy: {baseline_metrics['accuracy']:.2f}% | MACs: {baseline_macs / 1e6:.2f}M")

    # --- STEP 4: DEFINE FITNESS FUNCTION AND RUN GA ---
    # Core GA objective: build a pruned model, fine-tune it, and score it.
    # Each evaluation fine-tunes a network, so results are memoised per chromosome
    # tuple. The key space is bounded by POP_SIZE * NUM_GENERATIONS.
    @lru_cache(maxsize=None)
    def evaluate_fitness(chromosome: Tuple[int, ...]) -> float:
        """
        Score a chromosome as
        W_ACC * (val_accuracy / 100) + W_MACS * (1 - pruned_macs / baseline_macs).
        """
        print(f"\n>> Evaluating new chromosome: {str(chromosome)[:50]}...")
        # 1. Create pruned model from chromosome
        pruned_model = create_pruned_model_from_chromosome(baseline_model, chromosome, config)

        # 2. Fine-tune the new model on a subset of data
        print(f"   Fine-tuning for {config.TRAINING.FINETUNE_EPOCHS} epochs...")
        train_model(pruned_model, finetune_loader, config.TRAINING.FINETUNE_EPOCHS, config.TRAINING.LEARNING_RATE, config.DEVICE)

        # 3. Evaluate on the validation set (the test set is reserved for the final report)
        accuracy = evaluate_model(pruned_model, val_loader, config.DEVICE)

        # 4. Calculate computational cost (MACs) relative to the baseline
        dummy_input = torch.randn(1, 1, 28, 28).to(config.DEVICE)
        macs, _ = profile(pruned_model, inputs=(dummy_input,), verbose=False)
        mac_reduction = 1.0 - (macs / baseline_macs)

        # 5. Calculate final fitness score
        fitness = (config.FITNESS.W_ACC * (accuracy / 100.0)) + (config.FITNESS.W_MACS * mac_reduction)
        print(f"   Done. Accuracy: {accuracy:.2f}%, MAC Reduction: {mac_reduction:.2%}, Fitness: {fitness:.4f}")
        return fitness

    # Instantiate and run the GA
    ga_pruner = GeneticAlgorithmPruner(baseline_model, config, fitness_evaluator=evaluate_fitness)
    best_chromosome = ga_pruner.run_evolution()

    # --- STEP 5: BUILD AND EVALUATE THE FINAL PRUNED MODEL ---
    print("\n--- 4. Building and Evaluating Final Pruned Model ---")
    print(f"Best chromosome found: {best_chromosome}")
    final_pruned_model = create_pruned_model_from_chromosome(baseline_model, best_chromosome, config)

    # Re-train the final model on the full training set for a fair comparison
    print("Fine-tuning final model on the full training dataset...")
    train_model(final_pruned_model, train_loader, config.TRAINING.BASELINE_EPOCHS, config.TRAINING.LEARNING_RATE, config.DEVICE)

    # Evaluate the final, re-trained model on the unseen test set
    pruned_metrics = calculate_model_metrics(final_pruned_model, test_loader, config.DEVICE)

    # --- STEP 6: DISPLAY FINAL RESULTS ---
    print("\n\n" + "="*50)
    print("               FINAL RESULTS COMPARISON")
    print("="*50)
    print(f"{'Metric':<20} | {'Baseline Model':<15} | {'Pruned Model':<15}")
    print("-"*50)
    print(f"{'Accuracy (%)':<20} | {baseline_metrics['accuracy']:<15.2f} | {pruned_metrics['accuracy']:<15.2f}")
    print(f"{'Parameters (M)':<20} | {baseline_metrics['params']/1e6:<15.2f} | {pruned_metrics['params']/1e6:<15.2f}")
    # This 28x28 network needs on the order of 1e7 MACs, so a G column at two
    # decimals would round to ~0.01 vs 0.00; report in millions instead.
    print(f"{'MACs (M)':<20} | {baseline_metrics['macs']/1e6:<15.2f} | {pruned_metrics['macs']/1e6:<15.2f}")
    print(f"{'Size (MB)':<20} | {baseline_metrics['size_mb']:<15.2f} | {pruned_metrics['size_mb']:<15.2f}")
    print("-"*50)

    reduction_params = 100 * (1 - pruned_metrics['params'] / baseline_metrics['params'])
    reduction_macs = 100 * (1 - pruned_metrics['macs'] / baseline_metrics['macs'])
    print(f"Parameter Reduction: {reduction_params:.2f}%")
    print(f"Computational Reduction (MACs): {reduction_macs:.2f}%")
    print("="*50)

--- 1. Loading Data ---
Using device: cuda

--- 2. Setting Up Baseline Model ---
No baseline model found. Training a new one...
  Epoch 1/5, Loss: 0.0350
  Epoch 2/5, Loss: 0.0131
  Epoch 3/5, Loss: 0.0100
  Epoch 4/5, Loss: 0.0130
  Epoch 5/5, Loss: 0.0014
Baseline model trained and saved to models/baseline_model_improved.pth

--- 3. Evaluating Baseline Model Performance ---
🧬 GA Initialized. Chromosome length: 96

--- 🚀 Starting Genetic Algorithm Evolution ---


>> Evaluating new chromosome: (1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0...
   Fine-tuning for 3 epochs...
  Epoch 1/3, Loss: 0.0089
  Epoch 2/3, Loss: 0.0004
  Epoch 3/3, Loss: 0.0000
   Done. Accuracy: 99.01%, MAC Reduction: 66.49%, Fitness: 0.8600

>> Evaluating new chromosome: (0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1...
   Fine-tuning for 3 epochs...
  Epoch 1/3, Loss: 0.0002
  Epoch 2/3, Loss: 0.0002
  Epoch 3/3, Loss: 0.0003
   Done. Accuracy: 99.03%, MAC Reduction: 66.15%, Fitness: 0.8588

>> Evaluati

: 