In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque, defaultdict
import pickle

import torch

# Sentinel meaning "no index / no connection". Its value is the int32 maximum
# (2**31 - 1), not the int64 maximum.
I_INF = 2**31 - 1

def fetch_first(mask: torch.Tensor) -> int:
    """Locate the first True entry of ``mask``.

    Returns the position (along the first dimension) of the first True
    element as a Python int, or the ``I_INF`` sentinel when ``mask`` contains
    no True element.
    """
    hits = mask.nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        return I_INF
    return hits[0].item()

def unflatten_conns(nodes: torch.Tensor, conns: torch.Tensor) -> torch.Tensor:
    """Build an (N, N) lookup table mapping (input node, output node) to a connection index.

    Args:
        nodes: (N, node_attrs) tensor; column 0 holds each node's key.
        conns: (C, conn_attrs) tensor; column 0 is the input key and column 1
            the output key of each connection. An empty tensor (no
            connections) is accepted.

    Returns:
        An int64 tensor of shape (N, N), on the same device as ``nodes``.
        Entry ``[i, o]`` is the row index in ``conns`` of the connection going
        from the node at position ``i`` to the node at position ``o``, or
        ``I_INF`` when no such connection exists. If several rows describe the
        same (input, output) pair, the last one wins.

    Raises:
        IndexError: if a connection references a key that is not present in
            ``nodes`` (NaN keys never match anything).
    """
    n_nodes = nodes.shape[0]
    # Filled on CPU (cheap scalar writes) and moved to the target device once.
    unflatten = torch.full((n_nodes, n_nodes), I_INF, dtype=torch.int64)

    # torch.tensor([]) has shape (0,), so column indexing below would fail.
    if conns.numel() == 0:
        return unflatten.to(nodes.device)

    # key -> position of its FIRST occurrence, built in one pass over the nodes.
    key_to_index = {}
    for position, key in enumerate(nodes[:, 0].tolist()):
        key_to_index.setdefault(key, position)

    for conn_idx, (i_key, o_key) in enumerate(conns[:, :2].tolist()):
        i = key_to_index.get(i_key)
        o = key_to_index.get(o_key)
        if i is None or o is None:
            raise IndexError(
                f"connection {conn_idx} references an unknown node key: "
                f"({i_key}, {o_key})"
            )
        unflatten[i, o] = conn_idx

    return unflatten.to(nodes.device)

def topological_sort(nodes: torch.Tensor, conns: torch.Tensor) -> torch.Tensor:
    """Order node positions so that every node comes after all of its parents.

    Args:
        nodes: (N, node_attrs) tensor; a NaN in column 0 marks a slot without
            a node.
        conns: (N, N) table in which any entry other than ``I_INF`` denotes a
            connection from the row position to the column position.

    Returns:
        An int64 tensor of length N. The leading entries are node positions in
        topological order (ties resolved by lowest position first); slots that
        were never filled keep the ``I_INF`` sentinel.
    """
    n_slots = nodes.shape[0]

    # Number of parents still unprocessed; NaN for slots without a node key.
    has_key = torch.isnan(nodes[:, 0]).logical_not()
    incoming = (conns != I_INF).sum(dim=0).float()
    remaining = torch.where(has_key, incoming, torch.nan)

    order = torch.full((n_slots,), I_INF, dtype=torch.int64)
    filled = 0

    ready = torch.nonzero(remaining == 0.0, as_tuple=True)[0]
    while ready.numel() > 0:
        current = ready[0].item()
        order[filled] = current
        filled += 1
        remaining[current] = -1  # never select this node again

        # Each child of `current` has one fewer pending parent.
        is_child = conns[current, :] != I_INF
        remaining = torch.where(is_child, remaining - 1, remaining)

        ready = torch.nonzero(remaining == 0.0, as_tuple=True)[0]

    return order


class Node:
    """A single node gene of a network.

    Attributes:
        id: identifier of the node.
        is_input / is_output: role flags.
        bias: float64 tensor of shape (1,).
        initial_val: float64 tensor of shape (1,).
        val: current activation; None until a value is assigned.
        num_incoming_connections: counter of connections ending at this node.
        received: starts as None.
    """

    def __init__(self, id_, input_=False, output=False, bias=0.0, initial_val=1.0):
        # Wrap a scalar in a one-element float64 tensor.
        as_f64 = lambda value: torch.tensor([value], dtype=torch.float64)

        self.id, self.is_input, self.is_output = id_, input_, output

        self.bias = as_f64(bias)
        self.initial_val = as_f64(initial_val)
        self.val = None

        self.num_incoming_connections = 0
        self.received = None

# --------------------------------------------------------------------------------------------------------


class ConnectionGene:
    """A weighted, directed connection between two nodes.

    Attributes:
        in_node / out_node: the endpoint node objects themselves (not ids).
        innov_num: innovation number of this connection.
        weight: connection weight (a tensor).
        enable: whether the connection is active; a disabled connection can
            be re-enabled later.
    """

    def __init__(self, in_node, out_node, innov_num, weight):
        self.in_node, self.out_node = in_node, out_node
        self.innov_num = innov_num
        self.weight = weight
        self.enable = True


# --------------------------------------------------------------------------------------------------------


class NN(nn.Module):
    """Feed-forward network described by explicit node and connection genes.

    Node ids and innovation numbers are tracked in class-level state that is
    shared by every instance, so that the same (in, out) connection receives
    the same innovation number in every network.
    """

    # Next unused node id (shared by all instances).
    next_node_id = 0

    # key: (in_id, out_id)
    # value: node created by splitting that connection (None until it is split)
    resulting_node_map = {}

    # key: (in_id, out_id)
    # value: innovation number of that connection
    innov_num_map = {}
    next_innov_num = 0

    def __init__(self, input_dim, output_dim, cloned=False):
        """Create a network.

        Args:
            input_dim: number of input nodes (ids ``0 .. input_dim - 1``).
            output_dim: number of output nodes (ids
                ``input_dim .. input_dim + output_dim - 1``).
            cloned: when True the network is left empty (no nodes, no
                connections) so the caller can populate it; when False a
                fully connected input->output network with random weights is
                built.
        """
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        self.nodes = []              # Node objects of this network
        self.connections_by_id = {}  # in-node id -> list of outgoing connections
        self.connections = []        # every connection of this network

        # Cloned / externally populated networks must not be initialised here.
        if cloned:
            return

        # Every freshly initialised network uses the same node ids and
        # innovation numbers, so the shared counters only advance the first
        # time an id / (in, out) pair is seen.
        for i in range(input_dim):
            NN.next_node_id = max(NN.next_node_id, i + 1)
            self.nodes.append(Node(i, True))

        for i in range(output_dim):
            node_index = i + input_dim
            NN.next_node_id = max(NN.next_node_id, node_index + 1)
            self.nodes.append(Node(node_index, False, True))

            for in_id in range(input_dim):
                in_out_tuple = (in_id, node_index)

                if in_out_tuple not in NN.resulting_node_map:
                    # No resulting node until the connection is split.
                    NN.resulting_node_map[in_out_tuple] = None
                    NN.innov_num_map[in_out_tuple] = NN.next_innov_num
                    NN.next_innov_num += 1

                innov_num = NN.innov_num_map[in_out_tuple]
                conn = ConnectionGene(
                    self.nodes[in_id], self.nodes[node_index], innov_num, torch.randn(1)
                )

                self.connections_by_id.setdefault(in_id, []).append(conn)
                self.connections.append(conn)

                # Each output starts off fully connected to the inputs.
                self.nodes[node_index].num_incoming_connections += 1

    def clone(self):
        """Return an independent deep copy of this network.

        Nodes (including their bias and initial_val tensors) and connections
        (including weights and enable flags) are copied; the class-level
        innovation bookkeeping is shared, not copied.
        """
        new_nn = NN(self.input_dim, self.output_dim, cloned=True)

        id_to_node = {}
        for node in self.nodes:
            new_node = Node(node.id, node.is_input, node.is_output)
            # Copy the learned per-node parameters; constructing the Node with
            # defaults alone would silently reset them to 0.0 / 1.0.
            new_node.bias = node.bias.clone().detach()
            new_node.initial_val = node.initial_val.clone().detach()
            new_node.num_incoming_connections = node.num_incoming_connections
            new_nn.nodes.append(new_node)
            id_to_node[node.id] = new_node

        for conn_list in self.connections_by_id.values():
            for conn in conn_list:
                in_node = id_to_node[conn.in_node.id]
                out_node = id_to_node[conn.out_node.id]

                new_conn = ConnectionGene(
                    in_node, out_node, conn.innov_num, conn.weight.clone().detach()
                )
                new_conn.enable = conn.enable

                new_nn.connections.append(new_conn)
                new_nn.connections_by_id.setdefault(in_node.id, []).append(new_conn)

        return new_nn

    def forward(self, x: torch.Tensor):
        """Evaluate the network on a batch.

        Args:
            x: tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must total ``input_dim``.

        Returns:
            Tensor of shape (batch_size, num_output_nodes), one column per
            output node in the order the output nodes appear in ``self.nodes``.

        Raises:
            ValueError: if the flattened input width differs from
                ``input_dim``, or an enabled connection references a node that
                is not part of ``self.nodes``.

        Only enabled connections are used. Nodes are evaluated in topological
        order; a node that can never be reached that way (it sits on or behind
        a cycle) keeps a value of zero. Non-input nodes compute
        ``act(sum(in.val * weight) * initial_val + bias)`` with ReLU on output
        nodes and sigmoid on all others.
        """
        # Flatten e.g. square images; reshape also accepts non-contiguous input.
        x = x.reshape(x.size(0), -1)

        if x.shape[1] != self.input_dim:
            raise ValueError("Input dim is not correct")

        batch_size = x.shape[0]
        device = x.device

        # Reset node states.
        for node in self.nodes:
            node.val = torch.zeros(batch_size, device=device)

        # The first `input_dim` entries of self.nodes receive the input columns.
        for idx in range(self.input_dim):
            self.nodes[idx].val = x[:, idx]

        # Index the enabled connections once (instead of rescanning the whole
        # connection list for every node).
        node_by_id = {node.id: node for node in self.nodes}
        incoming = {node_id: [] for node_id in node_by_id}
        children = {node_id: [] for node_id in node_by_id}
        for conn in self.connections:
            if not conn.enable:
                continue
            src_id, dst_id = conn.in_node.id, conn.out_node.id
            if src_id not in node_by_id or dst_id not in node_by_id:
                raise ValueError(
                    f"Connection ({src_id} -> {dst_id}) references a node that is not in this network"
                )
            incoming[dst_id].append(conn)
            children[src_id].append(dst_id)

        # Kahn's algorithm: a node is evaluated once all of its parents are.
        pending = {node_id: len(conns) for node_id, conns in incoming.items()}
        ready = deque(node for node in self.nodes if pending[node.id] == 0)

        while ready:
            curr_node = ready.popleft()

            # Input nodes already hold their value.
            if not curr_node.is_input:
                total_input = torch.zeros(batch_size, device=device)
                for conn in incoming[curr_node.id]:
                    # Out-of-place add so mixed float32/float64 operands are
                    # promoted instead of failing on an in-place dtype cast.
                    total_input = total_input + conn.in_node.val * conn.weight

                # Scale by initial_val, then add the bias.
                total_input = total_input * curr_node.initial_val + curr_node.bias

                if curr_node.is_output:
                    curr_node.val = F.relu(total_input)
                else:
                    curr_node.val = torch.sigmoid(total_input)

            for child_id in children[curr_node.id]:
                pending[child_id] -= 1
                if pending[child_id] == 0:
                    ready.append(node_by_id[child_id])

        # Collect the values of the output nodes.
        output_vals = [node.val for node in self.nodes if node.is_output]
        logits = torch.stack(output_vals, dim=1)  # shape: (batch_size, num_outputs)
        return logits

    def creates_cycle(self, source, target):
        """Return True if adding an edge ``source -> target`` would create a cycle.

        That is the case exactly when ``source`` is already reachable from
        ``target`` through enabled connections (which includes
        ``source`` and ``target`` being the same node). Nodes are compared by
        id. Disabled connections are ignored.
        """
        successors = {}
        for conn in self.connections:
            if conn.enable:
                successors.setdefault(conn.in_node.id, []).append(conn.out_node)

        # Iterative depth-first search from `target` (no recursion-depth limit).
        visited = set()
        stack = [target]
        while stack:
            node = stack.pop()
            if node.id == source.id:
                return True  # found a path back to source
            if node.id in visited:
                continue
            visited.add(node.id)
            stack.extend(successors.get(node.id, ()))
        return False

    def to(self, device):
        """Move every tensor held by the nodes and connections to ``device``.

        Covers node bias / initial_val, the cached ``val`` / ``received``
        tensors when present, and all connection weights. Returns ``self``.
        """
        for node in self.nodes:
            # bias and initial_val take part in forward(), so they must live
            # on the same device as the activations.
            node.bias = node.bias.to(device)
            node.initial_val = node.initial_val.to(device)
            if node.val is not None:
                node.val = node.val.to(device)
            if node.received is not None:
                node.received = node.received.to(device)
        for conn in self.connections:
            conn.weight = conn.weight.to(device)
        return self

    @classmethod
    def from_numpy(cls, nodes_np: np.ndarray, conns_np: np.ndarray, input_dim: int, output_dim: int):
        """Build a network from node and connection tables.

        Args:
            nodes_np: rows of ``[node_id, bias, initial_val, ...]``; extra
                columns are ignored.
            conns_np: rows of ``[in_id, out_id, weight, ...]``; extra columns
                are ignored.
            input_dim: ids below this value are input nodes.
            output_dim: ids in ``[input_dim, input_dim + output_dim)`` are
                output nodes; every other id is a hidden node.

        Returns:
            The populated network. Class-level innovation numbers and
            ``next_node_id`` are updated for any new connection / node id.

        NOTE(review): column 2 is multiplied into every node's summed input in
        forward(). If that column is 0 for all rows, the network output no
        longer depends on its input - confirm the column layout of the data
        being loaded.
        """
        new_nn = cls(input_dim, output_dim, cloned=True)

        id_to_node = {}

        for row in nodes_np:
            node_id = int(row[0])
            bias = float(row[1])
            initial_val = float(row[2])
            is_input = node_id < input_dim
            is_output = input_dim <= node_id < input_dim + output_dim

            node = Node(node_id, is_input, is_output, bias=bias, initial_val=initial_val)
            new_nn.nodes.append(node)
            id_to_node[node_id] = node

        for row in conns_np:
            in_id, out_id = int(row[0]), int(row[1])
            weight = torch.tensor([float(row[2])], dtype=torch.float64)

            in_node = id_to_node[in_id]
            out_node = id_to_node[out_id]

            in_out_tuple = (in_id, out_id)
            if in_out_tuple not in NN.innov_num_map:
                NN.innov_num_map[in_out_tuple] = NN.next_innov_num
                NN.resulting_node_map[in_out_tuple] = None
                NN.next_innov_num += 1

            innov_num = NN.innov_num_map[in_out_tuple]
            conn = ConnectionGene(in_node, out_node, innov_num, weight)

            new_nn.connections_by_id.setdefault(in_id, []).append(conn)
            new_nn.connections.append(conn)
            out_node.num_incoming_connections += 1

        # max() of an empty sequence raises, so skip the update for an empty table.
        if id_to_node:
            NN.next_node_id = max(NN.next_node_id, max(id_to_node) + 1)

        return new_nn

    @classmethod
    def from_pickle(cls, filepath: str, input_dim: int, output_dim: int):
        """Load node / connection tables from a pickle file and build a network.

        The pickle must contain either a ``(nodes, conns)`` 2-tuple or a dict
        with the keys ``"nodes"`` and ``"conns"``.

        Raises:
            ValueError: if the pickle has any other layout.

        SECURITY: ``pickle.load`` can execute arbitrary code; only use this on
        files from a trusted source.
        """
        with open(filepath, "rb") as f:
            data = pickle.load(f)

        if isinstance(data, tuple) and len(data) == 2:
            # Arrays saved directly.
            nodes_np, conns_np = data
        elif isinstance(data, dict):
            nodes_np = data["nodes"]
            conns_np = data["conns"]
        else:
            raise ValueError("Unsupported pickle format")

        return cls.from_numpy(nodes_np, conns_np, input_dim, output_dim)




def reset_NN_class_state():
    """Restore the class-level bookkeeping of ``NN`` to its initial state.

    Both counters go back to 0 and both (in, out)-keyed maps are replaced by
    fresh empty dicts.
    """
    NN.next_node_id = NN.next_innov_num = 0
    NN.resulting_node_map, NN.innov_num_map = {}, {}

In [79]:
import numpy as np
import torch
torch.set_default_dtype(torch.float64)

# Load the pickled genome; the code below reads its "nodes" and "conns" entries.
# SECURITY: pickle.load can execute arbitrary code - only open trusted files.
with open("models/4.pkl", "rb") as f:
    data = pickle.load(f)

# Keep only the rows that contain no NaN.
nodes_clean = data["nodes"][~np.isnan(data["nodes"]).any(axis=1)]
conns_clean = data["conns"][~np.isnan(data["conns"]).any(axis=1)]

# X_val holds the node table and Y_val the connection table (names kept for
# any later cells that use them).
X_val = torch.tensor(nodes_clean, dtype=torch.float64)
Y_val = torch.tensor(conns_clean, dtype=torch.float64)
print(X_val)

# NOTE(review): NN.from_numpy reads column 2 of the node table as initial_val,
# which NN.forward multiplies into each node's weighted input sum. The table
# printed above shows 0.0 in that column, which would make the outputs depend
# only on the biases - confirm the column layout of the saved genome.
model = NN.from_numpy(X_val, Y_val, input_dim=784, output_dim=10)


tensor([[ 0.0000e+00, -1.9828e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0000e+00, -1.3459e+00,  0.0000e+00,  0.0000e+00],
        [ 2.0000e+00,  3.5875e-01,  0.0000e+00,  0.0000e+00],
        ...,
        [ 7.9300e+02, -2.4648e-01,  0.0000e+00,  0.0000e+00],
        [ 3.9411e+04, -8.0786e-01,  0.0000e+00,  0.0000e+00],
        [ 4.0350e+05,  1.3595e-02,  0.0000e+00, -1.0000e+00]])


In [80]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F

def evaluate_model(model, batch_size=256, train=True):
    """Evaluate ``model`` on MNIST and report mean cross-entropy and accuracy.

    Args:
        model: module mapping a (batch, 784) tensor to (batch, 10) scores.
        batch_size: evaluation batch size.
        train: which MNIST split to evaluate on; True (the default) uses the
            training split, False the test split.

    Returns:
        The accuracy (fraction of correctly classified samples).

    The dataset is downloaded to ``./data`` if it is not already there. The
    model is moved to CUDA when available and switched to eval mode. A summary
    line labelled with the evaluated split is printed.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)),  # Flatten 28x28 -> 784
    ])

    dataset = datasets.MNIST(
        root="./data",
        train=train,
        transform=transform,
        download=True,
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    eps = 1e-8  # keeps log() finite when a probability is exactly 0
    total_loss = 0.0
    total_samples = 0
    correct = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)                     # [batch, num_classes]
            probs = F.softmax(outputs, dim=1)           # match JAX preds
            Y_onehot = F.one_hot(labels, num_classes=10).float()

            # Manual cross-entropy: -mean(sum(Y * log(preds)))
            loss = -(Y_onehot * torch.log(probs + eps)).sum(dim=1).mean()

            # Weight the batch mean by the batch size so the final average is
            # per sample even when the last batch is smaller.
            total_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)

            predicted_classes = torch.argmax(probs, dim=-1)
            correct += (predicted_classes == labels).sum().item()

    avg_loss = total_loss / total_samples
    accuracy = correct / total_samples

    # Label the summary with the split that was actually evaluated.
    split = "Train" if train else "Test"
    print(f"{split} Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return accuracy


In [81]:
# Run the evaluation: prints a loss/accuracy summary line and returns the accuracy.
evaluate_model(model)

Test Loss: 2.4509, Accuracy: 0.0986


0.09863333333333334