# 🎨 MAP-Elites NAS - Self-Contained Notebook

**Everything in one notebook - no external files needed!**

This notebook contains all the code you need. Just run the cells in order.

---

## 📦 Install Dependencies

In [None]:
# Run this once to install packages
!pip install torch torchvision networkx matplotlib tqdm scipy

## 🏗️ Section 1: Architecture Classes

This section defines how we represent neural architectures.

In [None]:
import random
import copy
import networkx as nx
from typing import List, Tuple, Dict, Optional

# Names of the operations a hidden node may carry.
# The list order defines each operation's integer index (see OP_TO_IDX).
OPERATION_POOL = [
    'conv3x3', 'conv5x5',
    'sep_conv3x3', 'sep_conv5x5',
    'max_pool3x3', 'avg_pool3x3',
    'skip_connect',
]

# Reverse lookup: operation name -> its index in OPERATION_POOL.
OP_TO_IDX = dict(zip(OPERATION_POOL, range(len(OPERATION_POOL))))


class ArchitectureState:
    """Mutable graph description of a neural network.

    Nodes are integer ids. Each node has an operation name, a channel count
    and an integer position, stored in dictionaries keyed by node id. Edges
    are ``(src, dst)`` tuples kept in insertion order. The graph is meant to
    be a DAG, but nothing in this class enforces that.
    """

    def __init__(self):
        self._node_counter = 0  # next id handed out by add_node
        self.input_node = None
        self.output_node = None
        self.nodes = []
        self.edges = []
        self.operations = {}
        self.channels = {}
        self.positions = {}

    @staticmethod
    def initialize_starter(operation_strategy='diverse'):
        """Build the seed architecture: input -> conv3x3(32) -> conv3x3(64) -> output.

        ``operation_strategy`` is accepted but not used by this implementation.
        """
        starter = ArchitectureState()

        starter.input_node = starter.add_node('input', 3, 0)
        first_hidden = starter.add_node('conv3x3', 32, 1)
        second_hidden = starter.add_node('conv3x3', 64, 2)
        starter.output_node = starter.add_node('output', 10, 3)

        # Wire the four nodes into a single chain.
        chain = [starter.input_node, first_hidden, second_hidden, starter.output_node]
        for src, dst in zip(chain, chain[1:]):
            starter.add_edge(src, dst)

        return starter

    def add_node(self, operation: str, channels: int, position: int) -> int:
        """Register a new node and return its freshly allocated id."""
        new_id = self._node_counter
        self._node_counter = new_id + 1

        self.nodes.append(new_id)
        self.operations[new_id] = operation
        self.channels[new_id] = channels
        self.positions[new_id] = position

        return new_id

    def add_edge(self, src: int, dst: int):
        """Insert the edge ``src -> dst`` unless it is already present."""
        link = (src, dst)
        if link in self.edges:
            return
        self.edges.append(link)

    def remove_node(self, node_id: int):
        """Delete a hidden node and bridge each predecessor to each successor.

        The input and output nodes are never removed.
        """
        if node_id == self.input_node or node_id == self.output_node:
            return

        # One pass over the edges: record neighbours, keep the untouched edges.
        incoming, outgoing, untouched = [], [], []
        for src, dst in self.edges:
            if dst == node_id:
                incoming.append(src)
            if src == node_id:
                outgoing.append(dst)
            if src != node_id and dst != node_id:
                untouched.append((src, dst))
        self.edges = untouched

        # Bridge the gap left by the removed node.
        for pred in incoming:
            for succ in outgoing:
                self.add_edge(pred, succ)

        # Drop the node itself and its per-node attributes.
        self.nodes.remove(node_id)
        for table in (self.operations, self.channels, self.positions):
            del table[node_id]

    def remove_edge(self, src: int, dst: int):
        """Delete the edge ``src -> dst`` if it exists; otherwise do nothing."""
        try:
            self.edges.remove((src, dst))
        except ValueError:
            pass

    def increase_channels(self, node_id: int):
        """Double the node's channel count, capped at 512."""
        doubled = self.channels[node_id] * 2
        self.channels[node_id] = min(doubled, 512)

    def decrease_channels(self, node_id: int):
        """Halve the node's channel count (integer division), floored at 16."""
        halved = self.channels[node_id] // 2
        self.channels[node_id] = max(halved, 16)

    def copy(self):
        """Return an independent deep copy of this architecture."""
        return copy.deepcopy(self)

    @property
    def depth(self) -> int:
        """Spread between the largest and smallest node position (0 if empty)."""
        if self.positions:
            levels = self.positions.values()
            return max(levels) - min(levels)
        return 0

    @property
    def avg_width(self) -> float:
        """Mean channel count over all nodes (0 if there are no nodes)."""
        if not self.channels:
            return 0
        widths = list(self.channels.values())
        return sum(widths) / len(widths)

    @property
    def total_params(self) -> int:
        """Rough parameter estimate: every edge is costed as a 3x3 convolution."""
        return sum(
            self.channels[src] * self.channels[dst] * 9
            for src, dst in self.edges
        )

    @property
    def num_skip_connections(self) -> int:
        """Number of edges that jump more than one position forward."""
        return sum(
            1
            for src, dst in self.edges
            if self.positions[dst] - self.positions[src] > 1
        )

# Confirmation banner for the cell; the check-mark emoji was mis-encoded.
print("✅ Architecture classes loaded")

## 🧠 Section 2: Neural Network Model

Convert architecture to executable PyTorch model.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    """Executable PyTorch model built from an ArchitectureState graph.

    Each non-input node that is reachable from the input node gets one layer,
    stored in ``self.layers`` under ``str(node_id)``. A node's input is the
    channel-wise concatenation of the outputs of its reachable predecessors.
    The output node applies global average pooling followed by a linear
    classifier.

    Nodes with no path from the input never receive data, so they are
    ignored both when sizing layers and in the forward pass.

    The graph bookkeeping is computed once in ``__init__``; the architecture
    must not be mutated after the model has been constructed.

    Args:
        arch: Architecture to realise. It must expose ``nodes``, ``edges``,
            ``operations``, ``channels``, ``positions``, ``input_node`` and
            ``output_node``.
        num_classes: Size of the classifier output.

    Raises:
        ValueError: If the edges of ``arch`` contain a cycle.
    """

    def __init__(self, arch: "ArchitectureState", num_classes: int = 10):
        super().__init__()
        self.arch = arch
        self.num_classes = num_classes
        self.layers = nn.ModuleDict()

        # Nodes that actually receive data from the input.
        self._live = self._reachable_from_input()
        # Per node: predecessors that produce an output, in edge order. The
        # same list sizes the layer and drives the concatenation in forward(),
        # so the two can never disagree.
        self._live_preds = {
            node: [src for src, dst in arch.edges
                   if dst == node and src in self._live]
            for node in arch.nodes
        }
        self._order = [n for n in self._topological_sort() if n in self._live]

        # Create layers
        for node in arch.nodes:
            if node == arch.input_node or node not in self._live:
                continue

            in_channels = self._get_input_channels(node)

            if node == arch.output_node:
                self.layers[str(node)] = nn.Linear(in_channels, num_classes)
            else:
                self.layers[str(node)] = self._create_operation(
                    arch.operations[node], in_channels, arch.channels[node]
                )

    def _reachable_from_input(self):
        """Return the set of nodes reachable from the input node via edges."""
        successors = {}
        for src, dst in self.arch.edges:
            successors.setdefault(src, []).append(dst)

        live = {self.arch.input_node}
        frontier = [self.arch.input_node]
        while frontier:
            node = frontier.pop()
            for succ in successors.get(node, ()):
                if succ not in live:
                    live.add(succ)
                    frontier.append(succ)
        return live

    def _get_input_channels(self, node: int) -> int:
        """Return the channel count of the tensor fed into ``node``.

        This is the sum of the channels of the predecessors that produce an
        output, or 3 when the node has no such predecessor.
        """
        predecessors = self._live_preds.get(node, [])
        if not predecessors:
            return 3
        return sum(self.arch.channels[pred] for pred in predecessors)

    def _create_operation(self, op: str, in_channels: int, out_channels: int):
        """Build the layer for operation ``op``.

        All operations keep the spatial size (stride 1, "same" padding) so
        that outputs of different nodes can be concatenated. Unknown names
        fall back to a 3x3 convolution block.
        """
        def bn_relu(*front):
            # Shared tail: every block ends with BatchNorm + ReLU.
            return nn.Sequential(*front, nn.BatchNorm2d(out_channels), nn.ReLU())

        def separable(kernel):
            # Depthwise (groups == in_channels) followed by pointwise 1x1.
            return bn_relu(
                nn.Conv2d(in_channels, in_channels, kernel,
                          padding=kernel // 2, groups=in_channels),
                nn.Conv2d(in_channels, out_channels, 1),
            )

        if op == 'conv3x3':
            return bn_relu(nn.Conv2d(in_channels, out_channels, 3, padding=1))
        if op == 'conv5x5':
            return bn_relu(nn.Conv2d(in_channels, out_channels, 5, padding=2))
        if op == 'sep_conv3x3':
            return separable(3)
        if op == 'sep_conv5x5':
            return separable(5)
        if op == 'max_pool3x3':
            # Pooling cannot change the channel count; a 1x1 conv adapts it.
            return bn_relu(
                nn.MaxPool2d(3, stride=1, padding=1),
                nn.Conv2d(in_channels, out_channels, 1),
            )
        if op == 'avg_pool3x3':
            return bn_relu(
                nn.AvgPool2d(3, stride=1, padding=1),
                nn.Conv2d(in_channels, out_channels, 1),
            )
        if op == 'skip_connect':
            if in_channels == out_channels:
                return nn.Identity()
            return nn.Conv2d(in_channels, out_channels, 1)
        return bn_relu(nn.Conv2d(in_channels, out_channels, 3, padding=1))

    def forward(self, x):
        """Run ``x`` through the graph and return the output node's logits.

        Raises:
            RuntimeError: If the output node is not reachable from the input.
        """
        arch = self.arch
        if arch.output_node not in self._live:
            raise RuntimeError(
                "output node is not reachable from the input node"
            )

        outputs = {arch.input_node: x}

        for node in self._order:
            if node == arch.input_node:
                continue

            # Every live predecessor precedes `node` in topological order,
            # so its output is already available.
            inputs = [outputs[pred] for pred in self._live_preds[node]]
            x = torch.cat(inputs, dim=1) if len(inputs) > 1 else inputs[0]

            if node == arch.output_node:
                # Global average pool, then flatten to (batch, channels).
                x = F.adaptive_avg_pool2d(x, 1)
                x = x.flatten(1)
            outputs[node] = self.layers[str(node)](x)

        return outputs[arch.output_node]

    def _topological_sort(self) -> List[int]:
        """Return all nodes ordered so that every edge goes forward.

        Uses Kahn's algorithm over ``arch.edges``; ties between ready nodes
        are broken by position. Ordering by position alone is not enough
        because edges may connect two nodes that share a position.

        Raises:
            ValueError: If the edges contain a cycle.
        """
        arch = self.arch
        indegree = {node: 0 for node in arch.nodes}
        successors = {node: [] for node in arch.nodes}
        for src, dst in arch.edges:
            if src in indegree and dst in indegree:
                indegree[dst] += 1
                successors[src].append(dst)

        def position_of(node):
            return arch.positions[node]

        ready = sorted((n for n in arch.nodes if indegree[n] == 0),
                       key=position_of)
        order = []
        while ready:
            node = ready.pop(0)
            order.append(node)
            released = []
            for succ in successors[node]:
                indegree[succ] -= 1
                if indegree[succ] == 0:
                    released.append(succ)
            if released:
                ready = sorted(ready + released, key=position_of)

        if len(order) != len(arch.nodes):
            raise ValueError("architecture graph contains a cycle")
        return order

# Confirmation banner for the cell; the check-mark emoji was mis-encoded.
print("✅ ConvNet model loaded")

## 🎯 Section 3: Training Function

Train and evaluate architectures.

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

def get_data_loaders(dataset='mnist', batch_size=128, subset_size=None):
    """Build the train and test DataLoaders for one of the supported datasets.

    ``'mnist'`` and ``'fashion'`` images are resized to 32 and their single
    channel is repeated three times. Any other value selects CIFAR-10.
    Data is stored under ``./data`` and downloaded when missing.

    Args:
        dataset: ``'mnist'``, ``'fashion'``, or anything else for CIFAR-10.
        batch_size: Batch size for both loaders.
        subset_size: If truthy, train on that many randomly chosen training
            samples (drawn without replacement) instead of the full set.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    # Pick the dataset class and, for grayscale sets, its normalisation stats.
    if dataset == 'mnist':
        source, mean, std = torchvision.datasets.MNIST, 0.1307, 0.3081
    elif dataset == 'fashion':
        source, mean, std = torchvision.datasets.FashionMNIST, 0.2860, 0.3530
    else:  # cifar10
        source, mean, std = torchvision.datasets.CIFAR10, None, None

    if mean is None:
        steps = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    else:
        steps = [
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
            transforms.Normalize((mean,) * 3, (std,) * 3),
        ]
    transform = transforms.Compose(steps)

    train_data = source(root='./data', train=True, download=True, transform=transform)
    test_data = source(root='./data', train=False, download=True, transform=transform)

    if subset_size:
        chosen = np.random.choice(len(train_data), subset_size, replace=False)
        train_data = Subset(train_data, chosen)

    shared = dict(batch_size=batch_size, num_workers=2)
    train_batches = DataLoader(train_data, shuffle=True, **shared)
    test_batches = DataLoader(test_data, shuffle=False, **shared)
    return train_batches, test_batches


def train_architecture(arch: ArchitectureState, 
                      epochs: int = 3,
                      device: str = 'cuda',
                      dataset: str = 'mnist',
                      subset_size: int = 10000) -> float:
    """Train a ConvNet built from ``arch`` and return its test accuracy.

    Training uses SGD (lr 0.01, momentum 0.9) with cross-entropy loss.
    Any exception raised while building, training or evaluating the model
    is printed and reported as an accuracy of 0.1.

    Args:
        arch: Architecture to realise and train.
        epochs: Number of passes over the training loader.
        device: Device the model and batches are moved to.
        dataset: Dataset name forwarded to ``get_data_loaders``.
        subset_size: Training subset size forwarded to ``get_data_loaders``.

    Returns:
        Fraction of correctly classified test samples, or 0.1 on failure.
    """
    try:
        net = ConvNet(arch, num_classes=10).to(device)
        train_batches, test_batches = get_data_loaders(dataset, subset_size=subset_size)

        loss_fn = nn.CrossEntropyLoss()
        sgd = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

        # --- optimisation ---
        net.train()
        for _ in range(epochs):
            for images, targets in train_batches:
                images = images.to(device)
                targets = targets.to(device)

                sgd.zero_grad()
                batch_loss = loss_fn(net(images), targets)
                batch_loss.backward()
                sgd.step()

        # --- evaluation ---
        net.eval()
        hits, seen = 0, 0
        with torch.no_grad():
            for images, targets in test_batches:
                images = images.to(device)
                targets = targets.to(device)
                guesses = net(images).max(1)[1]  # index of the largest logit
                seen += targets.size(0)
                hits += guesses.eq(targets).sum().item()

        return hits / seen

    except Exception as err:
        print(f"Error training: {err}")
        return 0.1

# Confirmation banner for the cell; the check-mark emoji was mis-encoded.
print("✅ Training function loaded")

## 🧬 Section 4: MAP-Elites Algorithm

Quality-Diversity search.

In [None]:
from collections import defaultdict
from tqdm import tqdm

class BehaviorSpace:
    """Maps an architecture to a discrete (depth, width, skip) grid cell.

    Each of the three descriptors is clipped to a fixed range and split into
    a configurable number of equal-width bins.
    """

    def __init__(self, depth_bins=5, width_bins=5, skip_bins=4):
        self.depth_bins, self.width_bins, self.skip_bins = depth_bins, width_bins, skip_bins

        # Value ranges onto which the bins are laid out.
        self.depth_range = (3, 20)
        self.width_range = (16, 256)
        self.skip_range = (0.0, 1.0)

    def get_behavior(self, arch: ArchitectureState) -> Tuple[int, int, int]:
        """Return the ``(depth_bin, width_bin, skip_bin)`` cell of ``arch``."""
        depth_bin = self._discretize(arch.depth, self.depth_range, self.depth_bins)
        width_bin = self._discretize(arch.avg_width, self.width_range, self.width_bins)

        # Skip ratio: skip edges relative to the number of unordered node
        # pairs; the epsilon avoids dividing by zero for tiny graphs.
        node_count = len(arch.nodes)
        pair_count = node_count * (node_count - 1) / 2
        skip_ratio = arch.num_skip_connections / (pair_count + 1e-6)
        skip_bin = self._discretize(skip_ratio, self.skip_range, self.skip_bins)

        return (depth_bin, width_bin, skip_bin)

    def _discretize(self, value: float, value_range: Tuple[float, float], num_bins: int) -> int:
        """Clip ``value`` to ``value_range`` and return its bin index.

        The result is always in ``0 .. num_bins - 1``.
        """
        low, high = value_range
        clipped = np.clip(value, low, high)
        fraction = (clipped - low) / (high - low + 1e-6)
        return min(int(fraction * num_bins), num_bins - 1)

    def get_total_cells(self) -> int:
        """Return the number of cells in the grid (product of the bin counts)."""
        return self.depth_bins * self.width_bins * self.skip_bins


class MutationOperator:
    """Random structural and parametric mutations of an architecture.

    ``mutate`` never modifies its argument: it works on a copy and returns
    the result. The input and output nodes are never removed, never have
    their operation replaced and never have their channel count changed.
    """

    def __init__(self):
        self.mutation_types = [
            'add_node', 'remove_node', 'add_edge', 'remove_edge',
            'increase_channels', 'decrease_channels', 'replace_operation'
        ]

    def mutate(self, arch: "ArchitectureState", operation_strategy='diverse') -> "ArchitectureState":
        """Return a copy of ``arch`` with one randomly chosen mutation applied.

        A mutation that is not applicable leaves the copy unchanged. If a
        mutation fails on a malformed architecture (lookup or value error),
        an unmodified copy is returned instead of a half-mutated one. Other
        exceptions, including KeyboardInterrupt, propagate.

        Args:
            arch: Architecture to mutate; left untouched.
            operation_strategy: Forwarded to the add-node mutation, which
                does not currently use it.
        """
        new_arch = arch.copy()
        mutation = random.choice(self.mutation_types)

        single_arg_mutations = {
            'remove_node': self._remove_node,
            'add_edge': self._add_edge,
            'remove_edge': self._remove_edge,
            'increase_channels': self._increase_channels,
            'decrease_channels': self._decrease_channels,
            'replace_operation': self._replace_operation,
        }

        try:
            if mutation == 'add_node':
                self._add_node(new_arch, operation_strategy)
            elif mutation in single_arg_mutations:
                single_arg_mutations[mutation](new_arch)
        except (IndexError, KeyError, ValueError):
            # Best effort: fall back to a clean copy rather than returning a
            # possibly inconsistent architecture.
            return arch.copy()

        return new_arch

    def _hidden_nodes(self, arch):
        """Return the nodes that are neither the input nor the output node."""
        return [n for n in arch.nodes
                if n not in (arch.input_node, arch.output_node)]

    def _add_node(self, arch, strategy):
        """Insert a random node and connect it to one earlier and one later node.

        Does nothing once the architecture has 20 nodes. ``strategy`` is
        currently unused.
        """
        if len(arch.nodes) >= 20:
            return
        operation = random.choice(OPERATION_POOL)
        position = random.randint(1, max(arch.positions.values()))
        channels = random.choice([32, 64, 128])
        new_id = arch.add_node(operation, channels, position)

        # Connect randomly. Successor candidates include nodes that share the
        # new node's position, so consumers must not rely on positions alone
        # to order nodes.
        prev_nodes = [n for n in arch.nodes if arch.positions[n] < position and n != new_id]
        next_nodes = [n for n in arch.nodes if arch.positions[n] >= position and n != new_id]
        if prev_nodes:
            arch.add_edge(random.choice(prev_nodes), new_id)
        if next_nodes:
            arch.add_edge(new_id, random.choice(next_nodes))

    def _remove_node(self, arch):
        """Remove a random hidden node, keeping at least three nodes."""
        removable = self._hidden_nodes(arch)
        if removable and len(arch.nodes) > 3:
            arch.remove_node(random.choice(removable))

    def _add_edge(self, arch):
        """Add a random missing edge that goes from a lower to a higher position."""
        possible = [(s, d) for s in arch.nodes for d in arch.nodes
                    if s != d and (s, d) not in arch.edges and arch.positions[s] < arch.positions[d]]
        if possible:
            src, dst = random.choice(possible)
            arch.add_edge(src, dst)

    def _remove_edge(self, arch):
        """Remove a random edge, but only while there are more edges than nodes."""
        if len(arch.edges) > len(arch.nodes):
            arch.remove_edge(*random.choice(arch.edges))

    def _increase_channels(self, arch):
        """Double the channels of a random hidden node that is below the 512 cap.

        The input node is excluded because its channel count must match the
        image data; the output node is excluded because its channel count
        does not affect the built model. No-op when no node qualifies.
        """
        candidates = [n for n in self._hidden_nodes(arch) if arch.channels[n] < 512]
        if candidates:
            arch.increase_channels(random.choice(candidates))

    def _decrease_channels(self, arch):
        """Halve the channels of a random hidden node that is above the 16 floor.

        Input and output nodes are excluded for the same reasons as in
        ``_increase_channels``. No-op when no node qualifies.
        """
        candidates = [n for n in self._hidden_nodes(arch) if arch.channels[n] > 16]
        if candidates:
            arch.decrease_channels(random.choice(candidates))

    def _replace_operation(self, arch):
        """Assign a random operation from OPERATION_POOL to a random hidden node."""
        nodes = self._hidden_nodes(arch)
        if nodes:
            node = random.choice(nodes)
            arch.operations[node] = random.choice(OPERATION_POOL)


class MAPElites:
    """MAP-Elites search: keep the best architecture found per behavior cell.

    Attributes:
        archive: Maps a behavior cell to ``(architecture, performance)``.
        history: One record per child evaluated in the main loop of ``run``.
    """

    def __init__(self, behavior_space, mutation_operator, operation_strategy='diverse'):
        self.behavior_space = behavior_space
        self.mutation_operator = mutation_operator
        self.operation_strategy = operation_strategy
        self.archive = {}  # behavior -> (arch, performance)
        self.history = []

    def initialize(self, num_random=20):
        """Seed the archive with randomly mutated starter architectures.

        Each seed is the starter architecture after 2 to 5 mutations and is
        stored with a placeholder performance of 0.0. Seeds that land in the
        same cell overwrite each other.
        """
        for _ in range(num_random):
            seed = ArchitectureState.initialize_starter()
            mutation_count = random.randint(2, 5)
            for _ in range(mutation_count):
                seed = self.mutation_operator.mutate(seed)
            cell = self.behavior_space.get_behavior(seed)
            self.archive[cell] = (seed, 0.0)

    def run(self, num_iterations, evaluate_fn, verbose=True):
        """Run the search for ``num_iterations`` steps and return ``self``.

        An empty archive is seeded first, and every archived architecture is
        (re-)evaluated before the main loop starts. Each step mutates a
        uniformly chosen elite, evaluates the child, and stores it if its
        cell is empty or it beats the cell's current elite.

        Args:
            num_iterations: Number of children to generate and evaluate.
            evaluate_fn: Callable mapping an architecture to a performance.
            verbose: Show a tqdm progress bar with coverage and best score.
        """
        if not self.archive:
            self.initialize()

        # Replace whatever performance the archived entries carry with a
        # fresh evaluation.
        for cell, entry in list(self.archive.items()):
            archived = entry[0]
            self.archive[cell] = (archived, evaluate_fn(archived))

        iterator = tqdm(range(num_iterations)) if verbose else range(num_iterations)

        for i in iterator:
            # Pick a parent uniformly among the occupied cells.
            parent_cell = random.choice(list(self.archive.keys()))
            parent = self.archive[parent_cell][0]

            child = self.mutation_operator.mutate(parent)
            performance = evaluate_fn(child)
            child_behavior = self.behavior_space.get_behavior(child)

            # Keep the child if its cell is empty or it is strictly better.
            incumbent = self.archive.get(child_behavior)
            if incumbent is None or performance > incumbent[1]:
                self.archive[child_behavior] = (child, performance)

            self.history.append({
                'architecture': child,
                'performance': performance,
                'behavior': child_behavior,
            })

            # Refresh the progress-bar summary every tenth step.
            if verbose and i % 10 == 0:
                coverage = len(self.archive) / self.behavior_space.get_total_cells()
                best = max(perf for _, perf in self.archive.values())
                iterator.set_postfix({'coverage': f'{coverage:.2%}', 'best': f'{best:.4f}'})

        return self

    def get_stats(self):
        """Return coverage, fill count, best/mean performance and evaluation count."""
        perfs = [entry[1] for entry in self.archive.values()]
        filled = len(self.archive)
        return {
            'coverage': filled / self.behavior_space.get_total_cells(),
            'num_filled': filled,
            'best_performance': max(perfs) if perfs else 0,
            'mean_performance': np.mean(perfs) if perfs else 0,
            'total_evaluated': len(self.history)
        }

    def get_all_architectures(self):
        """Return ``(architecture copy, performance, behavior)`` for every elite."""
        elites = []
        for behavior, (arch, perf) in self.archive.items():
            elites.append((arch.copy(), perf, behavior))
        return elites

# Confirmation banner for the cell; the check-mark emoji was mis-encoded.
print("✅ MAP-Elites loaded")

## ⚙️ Configuration

Set your parameters here!

In [None]:
# ===== CONFIGURATION =====

# Dataset to search on: 'mnist', 'fashion' or 'cifar10'.
DATASET = 'mnist'
# Number of MAP-Elites iterations (children evaluated).
ITERATIONS = 200
# Behavior space resolution along depth, width and skip ratio.
DEPTH_BINS, WIDTH_BINS, SKIP_BINS = 5, 5, 4
# How many of the best architectures the results cell lists.
TOP_K = 10

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print("\n".join([
    "Configuration:",
    f"  Dataset: {DATASET}",
    f"  Iterations: {ITERATIONS}",
    f"  Behavior space: {DEPTH_BINS}x{WIDTH_BINS}x{SKIP_BINS} = {DEPTH_BINS*WIDTH_BINS*SKIP_BINS} cells",
    f"  Device: {device}",
]))

## 🚀 Run MAP-Elites!

In [None]:
# Build the search components from the configuration cell.
behavior_space = BehaviorSpace(DEPTH_BINS, WIDTH_BINS, SKIP_BINS)
mutation_operator = MutationOperator()
map_elites = MAPElites(behavior_space, mutation_operator)


def evaluate(arch):
    """Score an architecture: 3 epochs on a 10,000-sample subset of DATASET."""
    return train_architecture(arch, epochs=3, device=device, dataset=DATASET, subset_size=10000)


# Run the search. Every iteration trains one network, so this is the slow cell.
print("\n🎯 Running MAP-Elites...\n")
map_elites.run(ITERATIONS, evaluate, verbose=True)

# Summary of the final archive.
stats = map_elites.get_stats()
print("\n✅ Complete!")
print(f"Coverage: {stats['coverage']:.2%}")
print(f"Best: {stats['best_performance']:.4f}")

## 📊 Results

In [None]:
# Collect every elite as (architecture, performance, behavior), best first.
all_archs = map_elites.get_all_architectures()
all_archs.sort(key=lambda x: x[1], reverse=True)

print(f"\n🏆 Top {min(TOP_K, len(all_archs))} Architectures:\n")

for i, (arch, perf, behavior) in enumerate(all_archs[:TOP_K]):
    print(f"{i+1}. Acc: {perf:.4f} | Behavior: {behavior} | Nodes: {len(arch.nodes)} | Depth: {arch.depth}")

# Details of the single best architecture. The guard avoids an IndexError
# when the archive is empty (e.g. the search cell was not run).
if all_archs:
    best_arch, best_perf, best_behavior = all_archs[0]
    print("\n🥇 Best Architecture:")
    print(f"   Accuracy: {best_perf:.4f}")
    print(f"   Behavior: {best_behavior}")
    print(f"   Nodes: {len(best_arch.nodes)}")
    print(f"   Depth: {best_arch.depth}")
    print(f"   Avg Width: {best_arch.avg_width:.1f}")
else:
    print("\nNo architectures in the archive - run the search cell first.")

## ✅ Done!

You've successfully run MAP-Elites! 🎉

**Key Results:**
- Discovered diverse architectures across behavior space
- No external files needed - everything in one notebook!
- Simpler and faster than Deep RL approaches

**Next steps:**
- Increase ITERATIONS for better results
- Try different datasets (fashion, cifar10)
- Adjust behavior space resolution
- Fully train top architectures with more epochs