## 1. Genome design (Genotype construction)
- Node genes: a neuron in the neural network
- Connection genes: connections between nodes, each tagged with an innovation number
- Genomes: the encoding (genotype) of a complete neural network (phenotype); stores all of its connection and node genes

In [None]:
import math
import random
import time
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
import pyglet
from pyglet import shapes


In [None]:
"""
represents a capability of the system 
ex: firing capability, wind speed
"""
class NodeGene:
    """
    A single node (neuron) gene.

    - id: identifier of the node; genomes key their node table by it
    - layer: layer the node belongs to (input, hidden, output)
    - activation: activation name stored on the node
    - bias: bias value stored on the node
    """
    def __init__(self, id, layer, activation, bias):
        # identity and position in the network
        self.id, self.layer = id, layer
        # numeric behaviour of the node
        self.activation, self.bias = activation, bias

"""
represents a rule that utilizes the capability
ex: wind speed (input node A) increases (connection) firing capability (output node B)
"""
class NodeConnection:
    """
    A directed connection gene between two nodes.

    - id: identifier of the connection
    - output_node_id: node the connection feeds into
    - input_node_id: node the connection reads from
    - weight: connection strength (learned)
    - enabled: False marks a disabled gene
    - innovation: id that tracks when the connection first showed up in
      evolutionary history
    """
    def __init__(self, id, output_node_id, input_node_id, weight, enabled, innovation):
        self.id = id
        # note the parameter order: the target node comes before the source node
        self.output_node_id, self.input_node_id = output_node_id, input_node_id
        self.weight, self.enabled = weight, enabled
        self.innovation = innovation

"""
represents the entire rule book for a team
ex: doctrine manual for red team vs blue team or a single vessel
"""
class Genome:
    """
    A full set of node and connection genes plus a fitness score.

    - nodes: iterable of node genes; stored in a dict keyed by each gene's id
    - connections: iterable of connection genes; stored in a dict keyed by
      each gene's innovation number (a later gene with the same key replaces
      an earlier one)
    """
    def __init__(self, nodes, connections):
        node_table = {}
        for gene in nodes:
            node_table[gene.id] = gene
        self.nodes = node_table

        link_table = {}
        for gene in connections:
            link_table[gene.innovation] = gene
        self.connections = link_table

        # fitness starts at zero until something assigns a score
        self.fitness = 0

    def __repr__(self):
        return f"Genome(nodes={len(self.nodes)}, connections={len(self.connections)}, fitness={self.fitness})"

## 2. Crossover algorithm (heritage)
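Perform crossover between two parent genomes to derive a new child genome that inherits structure from both parents
- matching genes: connections present in both parents (same innovation number). The parent whose copy is expressed is chosen at random; the topological connection is the same but the strength (weight) may differ
- disjoint genes: connections present in only one parent whose innovation number falls within the other parent's range of innovation numbers; inherited only from the fitter parent
- excess genes: connections present in only one parent whose innovation number lies beyond the other parent's highest innovation number; inherited only from the fitter parent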

In [None]:

"""
finds the child connection set
input: parent_a and parent_b genomes
output: child genome
"""
def crossover(parent_a, parent_b, population):
    """
    Build a child genome from two parents.

    - parent_a, parent_b: parent genomes (need .connections, .nodes, .fitness)
    - population: pool used to pick a replacement for parent_a when both
      parents are the same genome
    - returns: a new Genome made of freshly copied genes

    Matching genes (innovation number in both parents) are copied from a
    randomly chosen parent; genes found in only one parent are copied only
    when that parent is the fitter one (ties favour parent_a).
    Only nodes referenced by an inherited connection are copied to the child.
    Raises KeyError if an inherited connection references a node that
    neither parent holds.
    """
    # ensure parents are different whenever the population offers an alternative
    if parent_a == parent_b:
        alternatives = [g for g in population if g != parent_b]
        if alternatives:
            parent_a = random.choice(alternatives)
        # with no alternative, a genome crossed with itself yields a plain
        # copy instead of crashing on random.choice([])

    parent_fittest = find_fittest_parent(parent_a, parent_b)
    parent_other = parent_b if parent_fittest is parent_a else parent_a

    # get set of innovation numbers of parents
    inn_num_parent_a = set(parent_a.connections.keys())
    inn_num_parent_b = set(parent_b.connections.keys())

    matching = inn_num_parent_a & inn_num_parent_b
    only_in_a = inn_num_parent_a - inn_num_parent_b
    only_in_b = inn_num_parent_b - inn_num_parent_a

    # collect the node connections expressed in the child
    child_node_connections = []
    for innovation in matching | only_in_a | only_in_b:
        if innovation in matching:
            # 1. matching gene: random selection between the two parents
            source = random.choice([parent_a.connections[innovation], parent_b.connections[innovation]])
        elif innovation in parent_fittest.connections:
            # 2. disjoint/excess gene of the fitter parent: inherit it
            source = parent_fittest.connections[innovation]
        else:
            # 3. disjoint/excess gene of the less fit parent: drop it
            continue
        child_node_connections.append(_copy_connection(source))

    # collect the ids of every node the child's connections touch
    child_node_ids = set()
    for conn in child_node_connections:
        child_node_ids.add(conn.input_node_id)
        child_node_ids.add(conn.output_node_id)

    # copy node genes, preferring the fitter parent's version; fall back to
    # the other parent when the fitter one does not hold that node
    child_nodes = []
    for node_id in child_node_ids:
        template = parent_fittest.nodes.get(node_id)
        if template is None:
            template = parent_other.nodes[node_id]
        child_nodes.append(
            NodeGene(
                id=node_id,
                layer=template.layer,
                activation=template.activation,
                bias=template.bias,
            )
        )

    return Genome(nodes=child_nodes, connections=child_node_connections)


def _copy_connection(conn):
    """Return a new NodeConnection carrying the same gene data as conn."""
    return NodeConnection(
        id=conn.id,
        output_node_id=conn.output_node_id,
        input_node_id=conn.input_node_id,
        weight=conn.weight,
        enabled=conn.enabled,
        innovation=conn.innovation,
    )

"""
finds the fitter parent based on fitness score
input: parent_a and parent_b genomes
output: fitter parent genome
"""
def find_fittest_parent(a, b):
    """Return whichever genome has the higher fitness; `a` wins ties."""
    if a.fitness >= b.fitness:
        return a
    return b

"""
generates a graph visualization of genome
"""
def visualize_genome(genome, ax=None):
    """
    Draw the enabled connections of a genome as a directed graph.

    - genome: object whose .connections dict holds the connection genes
    - ax: optional matplotlib axes to draw on; when omitted the graph is
      drawn on the current figure and shown immediately
    """
    graph = nx.DiGraph()

    # one edge per enabled connection, pointing from source to target
    active = (gene for gene in genome.connections.values() if gene.enabled)
    for gene in active:
        graph.add_edge(gene.input_node_id, gene.output_node_id)

    # caller-supplied axes: draw there and leave showing to the caller
    if ax is not None:
        nx.draw(graph, ax=ax, with_labels=True, arrows=True)
        return

    nx.draw(graph, with_labels=True, arrows=True)
    plt.show()

"""
print genome state
"""
def print_genome_state(genome, label="Genome"):
    """
    Print the node ids and every connection of a genome.

    - genome: object with .nodes and .connections dicts
    - label: heading printed above the listing
    """
    print(f"\n=== {label} ===")
    print(f"Nodes: {list(genome.nodes.keys())}")
    print("Connections:")
    for key, gene in genome.connections.items():
        if gene.enabled:
            state = "enabled"
        else:
            state = "DISABLED"
        print(f"  [{key}] {gene.input_node_id} → {gene.output_node_id} | weight: {gene.weight:.4f} | {state}")

## 3. Innovation counter
tracks the evolutionary history of all new genetic mutations that are created over time for a set population


In [None]:
class InnovationCounter:
    """
    Hands out innovation numbers, one per distinct (in_node, out_node) pair.

    - global_innovations: highest innovation number issued so far
    - node_connection_innovations: dict mapping (in_node_id, out_node_id)
      to its innovation number; pass None to start with an empty registry
    """
    def __init__(self, global_innovations, node_connection_innovations):
        self.global_innovations = global_innovations
        # Only None is replaced by a fresh dict. An empty dict supplied by the
        # caller is kept as-is so the caller's reference stays in sync with
        # the numbers issued here.
        if node_connection_innovations is None:
            node_connection_innovations = {}
        self.node_connection_innovations = node_connection_innovations

    def get_inn_for_connection(self, in_node_id, out_node_id):
        """
        get innovation number per connection. create new one if it doesn't exist
        tracks new unique strategies developed

        - returns: the number already registered for the pair, otherwise a
          newly issued one (global counter + 1), which is also recorded
        """
        # key by in/out node
        key = (in_node_id, out_node_id)
        registry = self.node_connection_innovations

        if key not in registry:
            # first time this pair is seen: advance the global counter
            self.global_innovations += 1
            registry[key] = self.global_innovations

        return registry[key]


## 4. Mutations
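apply noise to children after crossover for natural evolution
- mutation_rate - probability that a mutation is applied (per connection for weight mutation, per genome for adding a connection)
- perturbation_strength - standard deviation of the Gaussian noise added when a weight is slightly changed
- reset_rate - probability that a mutated weight is fully randomized instead of perturbed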

In [None]:
"""
update execution intensity of particular strategies
"""
def mutate_weights(genome, mutation_rate, perturbation_strength, reset_rate):
    """
    Randomly mutate connection weights of a genome in place.

    - mutation_rate: per-connection probability that the weight is mutated
    - perturbation_strength: sigma of the gaussian noise added to a weight
    - reset_rate: probability that a mutated weight is redrawn from
      uniform(0, 1) instead of being perturbed
    """
    for gene in genome.connections.values():
        # first roll: does this connection mutate at all?
        if not (random.random() < mutation_rate):
            continue

        # second roll: full reset or slight gaussian nudge
        if random.random() < reset_rate:
            gene.weight = random.uniform(0, 1)
            continue

        gene.weight += random.gauss(0, perturbation_strength)

# """
# create new units or variables (ex: sensor unit, wind speed)
# """
# function mutate_nodes(genome, innov: InnovationCounter):

"""
create new connection between two nodes
"""
def mutate_connections(genome, mutation_rate, innov: "InnovationCounter"):
    """
    With probability mutation_rate, add one new connection to the genome.

    - genome: genome mutated in place (needs .nodes and .connections dicts)
    - mutation_rate: probability that a connection is added
    - innov: innovation counter that supplies the innovation number
    - returns: None

    Candidate pairs run from a lower layer to a strictly higher layer and
    must not already be connected. Nothing happens when no candidate exists
    or when the mutation roll fails.
    """
    nodes = genome.nodes

    # (source, target) pairs that already have a connection gene
    existing_connections = {
        (conn.input_node_id, conn.output_node_id)
        for conn in genome.connections.values()
    }

    # lower -> higher layer only, which also rules out self-loops and cycles
    valid_pairs = [
        (a, b)
        for a in nodes
        for b in nodes
        if a != b
        and (a, b) not in existing_connections
        and nodes[a].layer < nodes[b].layer
    ]

    if not valid_pairs:
        return

    # failed roll: leave the genome untouched (previously this fell through
    # and crashed on the unassigned in_node/out_node)
    if not (random.random() < mutation_rate):
        return

    in_node, out_node = random.choice(valid_pairs)

    # create new connection and add to genome
    innovation_num = innov.get_inn_for_connection(in_node, out_node)

    new_conn = NodeConnection(
        id=innovation_num,  # same convention as create_initial_genome
        input_node_id=in_node,
        output_node_id=out_node,
        weight=random.uniform(0, 1),
        enabled=True,
        innovation=innovation_num,
    )

    genome.connections[innovation_num] = new_conn

#### 4a. test crossover x mutations algorithm
expected: child should adopt all of parent_a's traits (fitter)

In [None]:
# nodes
node_1 = NodeGene(id=1, layer=0, activation='sigmoid', bias=0) # input
node_2 = NodeGene(id=2, layer=0, activation='sigmoid', bias=0) # input

node_3 = NodeGene(id=3, layer=1, activation='sigmoid', bias=0) # hidden
node_4 = NodeGene(id=4, layer=2, activation='sigmoid', bias=0) # output

# create connections for parent A
conn_a1 = NodeConnection(id=1, output_node_id=3, input_node_id=1, 
                        weight=0.87, enabled=True, innovation=1)
conn_a2 = NodeConnection(id=2, output_node_id=4, input_node_id=3, 
                        weight=0.5, enabled=True, innovation=2)

# create connections for parent B  
conn_b1 = NodeConnection(id=1, output_node_id=3, input_node_id=1, 
                        weight=0.5, enabled=True, innovation=1)
conn_b3 = NodeConnection(id=3, output_node_id=4, input_node_id=1, 
                        weight=0.5, enabled=True, innovation=3)

# build genomes
parent_a = Genome(nodes=[node_1, node_2, node_3, node_4], 
                connections=[conn_a1, conn_a2])
parent_a.fitness = 10

# parent B owns node 4 as well, because conn_b3 ends at node 4
parent_b = Genome(nodes=[node_1, node_2, node_3, node_4], 
                connections=[conn_b1, conn_b3])
parent_b.fitness = 5

population = [parent_a, parent_b]

# init counter
# the registry mirrors the connection genes built above, keyed by
# (input node, output node); otherwise a mutation could be handed an
# innovation number that already belongs to a different connection
innov = InnovationCounter(
    global_innovations = 3, # start after existing innovations (1, 2, 3)
    node_connection_innovations = {(1, 3): 1, (3, 4): 2, (1, 4): 3}  # dict with existing connections
)

# print parents
fig, axes = plt.subplots(1, 4, figsize=(15, 5))  # 1 row, 4 columns

print_genome_state(parent_a, "Parent A (fitness=10)")
visualize_genome(parent_a, ax=axes[0])
axes[0].set_title("Parent A (fitness=10)")

print_genome_state(parent_b, "Parent B (fitness=5)")
visualize_genome(parent_b, ax=axes[1])
axes[1].set_title("Parent B (fitness=5)")

# create and print crossover
child = crossover(parent_a, parent_b, population)
print_genome_state(child, "Before Mutations")
visualize_genome(child, ax=axes[2])
axes[2].set_title("Child before mutation")

# mutate weights, then try to add one new connection
mutate_weights(child, mutation_rate=1.0, perturbation_strength=0.3, reset_rate=0.1)
mutate_connections(child, mutation_rate=1.0, innov=innov)
print_genome_state(child, "After Mutations")
visualize_genome(child, ax=axes[3])
axes[3].set_title("Child after mutation")

plt.tight_layout()
plt.show()

## 5. Evolutionary loop (Phenotype construction)
apply crossover over generations of genomes until population fitness threshold is met

Core loop:
1. evaluation: forward pass and evaluate against fitness function
2. speciation: genomes with similar topologies are chosen to reproduce
3. selection: those with high fitness scores are selected to reproduce
4. reproduction: crossover and mutation across selected pairs

In [None]:
class EvolutionLoop:
    """
    track the fitness and innovation numbers for the entire population
    """
    def __init__(self, innov):
        """Keep a reference to the innovation counter passed in by the caller."""
        # self.fitness_fn = fitness_fn
        self.innov = innov

    """
    sort genome based on input->hidden->output structure
    """
    def topological_sort(self, genome):
        """
        Order the node ids so that every enabled connection runs from an
        earlier node to a later one.

        - genome: object with .nodes and .connections dicts
        - returns: list of node ids in processing order
        - raises ValueError("Cycle detected") when not every node can be
          ordered, KeyError when an enabled connection targets a node id
          that is missing from genome.nodes
        """
        # in-degree per node, plus the targets of each node's enabled
        # outgoing connections (kept in connection order)
        in_degree = {nid: 0 for nid in genome.nodes}
        outgoing = {}
        for conn in genome.connections.values():
            if conn.enabled:
                in_degree[conn.output_node_id] += 1
                outgoing.setdefault(conn.input_node_id, []).append(conn.output_node_id)

        # fifo queue of nodes with no unprocessed incoming edges;
        # deque gives O(1) pops from the left
        queue = deque(nid for nid, deg in in_degree.items() if deg == 0)
        topo_order = []

        while queue:
            curr_node = queue.popleft()
            topo_order.append(curr_node)

            # release every node that was waiting on curr_node
            for out in outgoing.get(curr_node, ()):
                in_degree[out] -= 1
                if in_degree[out] == 0:
                    queue.append(out)

        # nodes left unordered never reached in-degree zero
        if len(topo_order) != len(genome.nodes):
            raise ValueError("Cycle detected")

        return topo_order
    
    """
    given an input state of the world, compute the resulting actions
    - genome: node structure
    - input_values: external node values fed into the network
    """
    def forward_pass(self, genome, input_values):
        # collect intermediate values of each node in the network
        node_values = {}

        # map input values to input nodes
        for nid, val in input_values.items():
            if nid in genome.nodes:  # only process nodes that exist in this genome
                node_values[nid] = val
        
        # topological sort of nodes
        ordered_nodes = self.topological_sort(genome)

        # process each node in the network
        for nid in ordered_nodes:
            # skip all input values and nodes
            if nid in input_values and genome.nodes[nid].layer == 0:
                continue
        
            # incoming_sum = bias + Σ (source_value × weight)
            incoming_sum = genome.nodes[nid].bias
            for conn in genome.connections.values():
                if conn.enabled and conn.output_node_id == nid:
                    source_value = node_values[conn.input_node_id]
                    incoming_sum += source_value * conn.weight
            
            # apply activation function to sum (sigmoid)
            normalized_sum = max(-700, min(700, incoming_sum))
            node_values[nid] = 1 / (1 + math.exp(-normalized_sum))

        # return output nodes from the full pass
        return [
            node_values[nid] for nid in node_values
            if nid in genome.nodes and genome.nodes[nid].layer == 2 
        ]

    """
    compute fitness based on the lanchester square equation
    rewards greatest units number x firepower strength
    - input: current env state
    """
    # \(dA/dt=-bB\) and \(dB/dt=-aA\),
    def calculate_fitness(self, obs):
        """
        Score an observation as a weighted mix of three terms.

        - obs: indexable observation; reads obs[1] (blue units), obs[2]
          (red units), obs[4] and obs[5] (multiplied with the squared unit
          counts as each side's combat power)
        - returns: 0.4 * blue preserved + 0.3 * red eliminated
          + 0.2 * victory bonus

        The divisors 27 and 33 are presumably the starting unit counts of
        blue and red — TODO confirm against the environment.
        """
        blue_left, red_left = obs[1], obs[2]

        # share of blue preserved and share of red eliminated
        preserved = blue_left / 27
        eliminated = (33 - red_left) / 33

        # victory condition (lanchester): firepower x units squared
        blue_power = obs[4] * (blue_left ** 2)
        red_power = obs[5] * (red_left ** 2)
        if blue_power > red_power:
            bonus = 1.0
        else:
            bonus = 0.0

        # accumulate the weighted terms in the same left-to-right order
        score = 0.4 * preserved      # Preserve your forces
        score += 0.3 * eliminated    # Eliminate enemy
        score += 0.2 * bonus         # Victory condition
        return score

    """
    create initial genome state
    """
    def create_initial_genome(self, num_inputs, num_outputs, innov):
        """
        Build a genome with every input node connected to every output node.

        - num_inputs: number of layer-0 nodes, given ids 0 .. num_inputs-1
        - num_outputs: number of layer-2 nodes, given the ids that follow
        - innov: counter that supplies the innovation number per connection
        - returns: a Genome with random biases and weights in [-1, 1]
        """
        # input nodes first, then output nodes, each with a random bias
        sensors = [
            NodeGene(id=idx, layer=0, activation='sigmoid', bias=random.uniform(-1, 1))
            for idx in range(num_inputs)
        ]
        actuators = [
            NodeGene(id=num_inputs + offset, layer=2, activation='sigmoid', bias=random.uniform(-1, 1))
            for offset in range(num_outputs)
        ]

        # one enabled connection per (input, output) pair; the connection id
        # reuses the innovation number
        links = []
        for src in range(num_inputs):
            for dst in range(num_inputs, num_inputs + num_outputs):
                number = innov.get_inn_for_connection(src, dst)
                links.append(
                    NodeConnection(
                        id=number,
                        output_node_id=dst,
                        input_node_id=src,
                        weight=random.uniform(-1, 1),  # init random weight
                        enabled=True,
                        innovation=number,
                    )
                )

        return Genome(nodes=sensors + actuators, connections=links)
        
    """
    create default population from genome state
    """
    def init_population(self, pop_size, num_inputs, num_outputs, innov):
        # initialize random population
        population = []
        for _ in range(pop_size):
            genome = self.create_initial_genome(num_inputs, num_outputs, innov)
            population.append(genome)
        
        return population
    
    def eval_population(self, population, max_steps, env):
        for genome in population:
            # run simulation episode
            env.reset()

            # perform actions at each timestamp
            for step in range(max_steps):
                obs = env.get_observation()
                normalized_obs = self.normalize_observations(obs)
                input_values = {i: normalized_obs[i+1] for i in range(5)}  # Maps 0->obs[1], 1->obs[2], etc.
                action = self.forward_pass(genome, input_values)
                obs, done = env.step(action)

                max_indx = action.index(max(action))
                print(f"timestamp {step}: "
                    f"Blue={obs[1]:.1f}, Red={obs[2]:.1f}, "
                    f"Action= ({['Forward','Retreat','Attack','Hold'][max_indx]})")
                
                if done:
                    break
                # viz.render()
                # time.sleep(0.1)  # 100ms delay per step

            # compute fitness of genome
            final_obs = env.get_observation()
            # normalized_obs = self.normalize_observations(final_obs)
            genome.fitness = self.calculate_fitness(final_obs) # set fitness variable
        
    
    def reproduce(self, population, innov):
        """
        Produce the next generation, the same size as the current one.

        - population: current genomes, each with a .fitness
        - innov: innovation counter handed to mutate_connections
        - returns: new list holding the top 80% by fitness (at least one)
          unchanged, topped up with mutated crossover children whose
          parents are drawn from the whole population
        """
        target_size = len(population)

        # elitism: carry the best performers over unchanged
        FRAC = 0.8
        keep = max(1, int(target_size * FRAC))
        ranked = sorted(population, key=lambda g: g.fitness, reverse=True)
        next_gen = ranked[:keep]

        # TODO: only use population pool to preserve diversity
        pool = population

        # fill the remaining slots with children of two distinct parents
        while len(next_gen) < target_size:
            first = random.choice(pool)
            second = random.choice(pool)

            if first != second:
                print("parent1", (first), "\nparent2", (second))
                offspring = crossover(first, second, pool)
                mutate_weights(offspring, mutation_rate=0.6, perturbation_strength=0.5, reset_rate=0.1)
                mutate_connections(offspring, mutation_rate=0.6, innov=innov)
                print("child", offspring)
                next_gen.append(offspring)

        return next_gen

    def evolution_loop(self, pop_size, num_generations, env, num_inputs, num_outputs, innov):
        # initialize & eval initial env
        # we repeat the steps 100 times based on the timestamp
        max_steps = 100 
        population = self.init_population(pop_size, num_inputs, num_outputs, innov)
        self.eval_population(population, max_steps, env)

        # repeat for num of generations
        for gen in range(num_generations):
            self.eval_population(population, max_steps, env) # compute fitness of all genomes
            
            fitnesses = [g.fitness for g in population]
            print(f"Gen {gen+1}: max_fit={max(fitnesses):.4f}, "
                f"avg_fit={sum(fitnesses)/len(fitnesses):.4f}, "
                f"zero_fit (full fail)={sum(1 for f in fitnesses if f == 0)}")

            population = self.reproduce(population, innov) # assign back to population

    # ==== HELPER ====
    def normalize_observations(self, obs):
        """
        normalize observations to [0, 1]
        prevents saturating the sigmoid

        - obs: dict keyed 1..5; any other key raises KeyError
        - returns: new dict with the same keys and min-max scaled values
          (values outside the expected range are not clamped)
        """
        ranges = {
            1: (0, 50),    # blue_units: expect 0-50
            2: (0, 50),    # red_units: expect 0-50
            3: (0, 1),     # wind_speed: already 0-1
            4: (0, 1),     # blue_firepower: already 0-1
            5: (0, 1),     # red_firepower: already 0-1
        }

        scaled = {}
        for key, val in obs.items():
            low, high = ranges[key]
            if high > low:
                scaled[key] = (val - low) / (high - low)
            else:
                # degenerate range: pass the value through untouched
                scaled[key] = val

        return scaled

        

## 6. test simulation

In [None]:
# project-local modules: the simulated environment and its renderer
from environment import WargameEnv
from game import Game

env = WargameEnv()
viz = Game(env)

# obs = env.get_observation()
# action = loop.forward_pass(child, obs)
# print(action)
# env.step(action)
# print(env.step(action))
# viz.render()

# init params
# start from an empty innovation registry; the same counter is stored on the
# loop and passed to evolution_loop
innov = InnovationCounter(global_innovations=0, node_connection_innovations={})
loop = EvolutionLoop(innov)

num_inputs = 5 # eval_population feeds observations 1..5 into inputs 0..4
num_outputs = 4 # Forward, Retreat, Attack, Hold (labels printed by eval_population)
pop_size = 60
num_generations=50

# run the evolution, then render once at the end
loop.evolution_loop(pop_size, num_generations, env, num_inputs, num_outputs, innov)
viz.render()
