In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import random

### Define the Neural Network

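We want to be able to set the network's weights manually: the model exposes all of its parameters as a single flat vector (and can load one back), so an evolutionary algorithm can evolve the weights directly instead of training them with gradient descent.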

In [None]:
class NeuralNet(nn.Module):
    """Small MLP (Linear -> ReLU -> Linear) whose parameters can be read and
    written as one flat vector, so the weights can serve as a GA chromosome.

    Args:
        in_dim (int): Number of input features.
        out_dim (int): Number of output logits.
        hidden_dim (int): Width of the hidden layer. Defaults to 4, which gives
            22 parameters for in_dim=2, out_dim=2.
    """

    def __init__(self, in_dim=2, out_dim=2, hidden_dim=4):
        super().__init__()

        self.layer1 = nn.Linear(in_features=in_dim, out_features=hidden_dim)
        self.layer2 = nn.Linear(in_features=hidden_dim, out_features=out_dim)

        self.weights_initialization()

    def weights_initialization(self):
        """Xavier-uniform weights and zero biases for every nn.Linear submodule.

        This is an explicit choice, NOT PyTorch's default nn.Linear init
        (which uses Kaiming-uniform weights and uniform-random biases).
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return raw logits (no softmax) for inputs of shape (..., in_dim)."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

    def get_flat_params(self):
        """Concatenate all parameters, in self.parameters() order, into one 1-D tensor.

        The result is still attached to the autograd graph (it has a grad_fn);
        call .detach() if a graph-free tensor is needed.
        """
        return U.parameters_to_vector(self.parameters())

    def set_flat_params(self, flat_params):
        """Load a flat vector (e.g. a chromosome) into the model's parameters.

        Args:
            flat_params (tensor): 1-D tensor with exactly one value per parameter,
                laid out in the same order get_flat_params() returns.

        Raises:
            ValueError: if the vector length differs from the parameter count.
                Without this check, vector_to_parameters would silently ignore
                any extra trailing values.

        Note: vector_to_parameters makes each parameter's data a view into
        `flat_params` rather than a copy.
        """
        expected = sum(p.numel() for p in self.parameters())
        if flat_params.numel() != expected:
            raise ValueError(
                f"expected a flat vector of {expected} values, got {flat_params.numel()}"
            )
        U.vector_to_parameters(flat_params, self.parameters())

### Load Data

In [3]:
X, y = make_moons(n_samples=1000, noise=0.1, random_state=42)

# Split dataset (80% train / 20% test, fixed seed for a reproducible split)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Features as float32 to match the Linear layers' default dtype; labels as int64
# class indices, the dtype argmax predictions have, so `pred == y_test` compares
# integers directly instead of relying on int/float type promotion.
# NOTE(review): the fitness function scores individuals on this test split, so it
# is not a true held-out set for the evolved network.
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.long)

### Evolutionary Algorithm

In our case, the chromosomes are the network's flattened parameter vectors: each individual in the population is one complete set of weights and biases.

In [None]:
# Seed both RNGs used by the evolutionary algorithm (Python's `random` and torch's
# global generator) before building the model, so init and every GA run are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Initialize model and get flat parameters
model = NeuralNet()
params_vector = model.get_flat_params()
print(params_vector)

# Configuration
POPULATION_SIZE = 100
# One gene per network parameter (all weights and biases).
CHROMOSOME_LENGTH = params_vector.numel()

tensor([ 5.4157e-01, -9.6938e-01,  6.8159e-01, -3.3370e-01,  3.3292e-01,
         6.7682e-01,  9.4453e-01, -1.3159e-01,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  7.5796e-02, -2.5452e-01,  1.1230e-01,
         9.0249e-01, -1.3721e-04, -2.8852e-01,  7.4950e-01,  8.8189e-01,
         0.0000e+00,  0.0000e+00], grad_fn=<CatBackward0>)


In [41]:
def generate_population(size, chromosome_length, low=0.0, high=1.0):
    """ Generate random population

    Each chromosome is a float32 vector whose genes are drawn independently and
    uniformly from [low, high). The default [0, 1) range gives only non-negative
    weights; pass e.g. low=-1.0 so the search can also start from negative weights.

    Args:
        size (int): Number of individuals.
        chromosome_length (int): Genes per individual (network parameter count).
        low (float): Inclusive lower bound of each gene.
        high (float): Exclusive upper bound of each gene.

    Returns:
        list[tensor]: `size` independent 1-D tensors of length `chromosome_length`.
    """
    if high <= low:
        raise ValueError(f"high ({high}) must be greater than low ({low})")
    span = high - low
    return [low + span * torch.rand(chromosome_length) for _ in range(size)]

def fitness(chromosome, model, inputs=None, targets=None):
    """ Algorithm's fitness function

    Classification accuracy of `model` after its parameters are set to `chromosome`.

    Args:
        chromosome (tensor): Flat parameter vector for `model`.
        model: Object with `set_flat_params` that maps a batch of inputs to
            per-class scores (argmax over dim 1 is the prediction).
        inputs (tensor, optional): Samples to score on. Defaults to global `X_test`.
        targets (tensor, optional): Class labels for `inputs`. Defaults to global `y_test`.

    Returns:
        float: Fraction of samples whose predicted class equals the label.

    Note: `model` is modified in place; after the call it holds `chromosome`'s weights.
    NOTE(review): with the defaults, selection is driven by test-set accuracy;
    pass training data here to keep the test split truly held out.
    """
    if inputs is None:
        inputs = X_test
    if targets is None:
        targets = y_test

    model.set_flat_params(chromosome)   # Update model's parameters with the chromosome
    with torch.no_grad():
        pred = model(inputs).argmax(dim=1)
        accuracy = (pred == targets).float().mean().item()

    return accuracy

def selection(population, fitness_scores):
    """ Roulette Wheel Selection

    Draws two parents with replacement (the same individual may be picked twice),
    each with probability proportional to its fitness.

    Args:
        population (sequence): Candidate individuals.
        fitness_scores (sequence of float): Non-negative fitness per individual.

    Returns:
        tuple: (parent1, parent2).

    Raises:
        ValueError: if the population is empty or the lengths differ.
    """
    if not population:
        raise ValueError("cannot select from an empty population")
    if len(population) != len(fitness_scores):
        raise ValueError("population and fitness_scores must have the same length")

    # random.choices normalises weights itself. If every score is 0 the roulette
    # wheel is undefined (the original divided by zero), so fall back to uniform.
    weights = fitness_scores if sum(fitness_scores) > 0 else None
    parent1, parent2 = random.choices(population, weights=weights, k=2)

    return parent1, parent2

def crossover(parent1, parent2):
    """ Two-Point Crossover

    Picks two distinct cut points in [1, len - 1] and swaps the segment between
    them. The first and last genes of each offspring therefore always come from
    its own-side parent.

    Args:
        parent1, parent2 (tensor): 1-D chromosomes of equal length (>= 3).

    Returns:
        tuple[tensor, tensor]: (offspring1, offspring2), new tensors.

    Raises:
        ValueError: if the parents differ in length or are shorter than 3 genes.
            The original retry loop never terminated for length 2.
    """
    length = len(parent1)
    if length != len(parent2):
        raise ValueError("parents must have the same length")
    if length < 3:
        raise ValueError("two-point crossover needs chromosomes of length >= 3")

    # Sample two distinct cut points and order them so start < end.
    start, end = sorted(random.sample(range(1, length), 2))

    offspring1 = torch.cat([parent1[:start], parent2[start:end], parent1[end:]])
    offspring2 = torch.cat([parent2[:start], parent1[start:end], parent2[end:]])

    return offspring1, offspring2

def mutate_v1(chromosome, mutation_rate=0.01):
    """ Mutation Function (dense variant)

    Adds zero-mean Gaussian noise to every gene and returns a new tensor.
    Despite its name, `mutation_rate` is the noise standard deviation, not a
    per-gene probability.
    """
    perturbation = mutation_rate * torch.randn_like(chromosome)
    return perturbation + chromosome

def mutate(chromosome, mutation_rate=0.05, mutation_strength=0.1):
    """
        Mutate a small subset of the genes of the chromosome.

        Each gene is independently selected with probability `mutation_rate`;
        selected genes get Gaussian noise with standard deviation
        `mutation_strength` added. The input tensor is not modified.

        Args:
            chromosome (tensor): Chromosome to mutate.
            mutation_rate (float): Probability of mutating each gene.
            mutation_strength (float): Standard deviation of the noise added to mutated genes.

        Returns:
            tensor: A new chromosome with the same shape and dtype.
    """
    # Vectorised: one draw for the mask and one for the noise, instead of a Python
    # loop over every gene. Only torch's RNG is used, so torch.manual_seed alone
    # makes mutation reproducible.
    mask = torch.rand_like(chromosome) < mutation_rate
    noise = torch.randn_like(chromosome) * mutation_strength
    return torch.where(mask, chromosome + noise, chromosome)



In [43]:
# Smoke-test each GA operator once on a freshly generated population.
population = generate_population(POPULATION_SIZE, CHROMOSOME_LENGTH)
fitness_scores = [fitness(candidate, model) for candidate in population]

parent1, parent2 = selection(population, fitness_scores)
offspring1, offspring2 = crossover(parent1, parent2)
mutated_offspring1 = mutate(offspring1)

print(f"Population Sample: {population[0]}\n")
print(f"Fitness Score Sample: {fitness_scores[0]}\n")
print(f"Selection Result: {(parent1, parent2)}\n")
print(f"Crossover Result: {(offspring1, offspring2)}\n")
print(f"Mutated Offspring 1: {mutated_offspring1}")


Population Sample: tensor([0.5301, 0.2916, 0.7357, 0.5392, 0.3083, 0.5996, 0.8071, 0.7868, 0.6914,
        0.1242, 0.8817, 0.5972, 0.8713, 0.4578, 0.5880, 0.2074, 0.9841, 0.9371,
        0.5900, 0.7942, 0.3009, 0.3163])

Fitness Score Sample: 0.5

Selection Result: (tensor([0.9512, 0.7355, 0.3156, 0.3789, 0.2979, 0.9569, 0.4994, 0.5506, 0.9616,
        0.2657, 0.8021, 0.5724, 0.5779, 0.1046, 0.9346, 0.6555, 0.0720, 0.9649,
        0.1768, 0.9612, 0.3261, 0.0611]), tensor([0.0786, 0.8736, 0.5765, 0.9070, 0.0564, 0.3788, 0.5298, 0.0863, 0.8242,
        0.6577, 0.4509, 0.4715, 0.9762, 0.7907, 0.2238, 0.2519, 0.1346, 0.0428,
        0.3270, 0.8231, 0.4135, 0.9019]))

Crossover Result: (tensor([0.9512, 0.7355, 0.3156, 0.9070, 0.0564, 0.3788, 0.5298, 0.0863, 0.8242,
        0.6577, 0.4509, 0.5724, 0.5779, 0.1046, 0.9346, 0.6555, 0.0720, 0.9649,
        0.1768, 0.9612, 0.3261, 0.0611]), tensor([0.0786, 0.8736, 0.5765, 0.3789, 0.2979, 0.9569, 0.4994, 0.5506, 0.9616,
        0.2657, 0.8021, 0.4