# ARCNET
***
## TODO:

<span style="color:#f04646">**Red**</span> indicates high prio issue. 

<span style="color:#67e063">**Green**</span> indicates future goals. 

Everything else is mainly stuff that just needs to get done. I'll use issue flags once project is technically 'submitted'.

- **Transfer class functions of `ARCNET.models.arcnet.ConceptModule` to their respective locations**
    - e.g., *assembly_complexity* / *mutation* / *message_passing* should go into core / evolution / etc. 
    - update ConceptModule with gettr settr methods to call replaced methods
<br>
- **Clean up `ARCNET.core.trainer` and `ARCNET.models.arcnet.ConceptModule` and `ARCNET.evolution.fitness`**
    - bloat comments and legacy code
    - ensure loops are working as intended
<br>
- <span style="color:#FFF646">**Implement a simple assembly tracking to ensure that it is properly working**</span>
    
    - It's hard to tell with so many interconnected parts that if assembly tracking is properly happening. Some testruns need to be done with early stop output to see how hashes of layer weights are being stored and transfered. 
    - The current assembly complexity actually seems to be somewhat working (better than the original version)
<br>
- **Update output functions for graphical viewing**
    - code was copy/paste from original 8000 line jupyter notebook and not implemented properly in the new modular system, needs to be
<br>
- **Implement `tqdm` in main `core.trainer` loop**
    - better visualization of progress, less clutter on cell output
<br>
- <span style="color:#f04646">**Solve the fitness params issue**</span>
    - Too many items are currently conflicting leading to the models being unable to come to a final answer when running
    - Likely has to do with fitness caluclations and something inside `ConceptModule` as this wasn't an issue in the second gen version 
<br>
- <span style="color:#f04646">**Assembly tracking needs to track generation depth and components properly**</span>
    - current implementation works computationally but trying to visualize it seems to have issues. A better plotting function is needed
    - look into pyvis for interactive plots, allow users to explore node components more easily
<br>
- <span style="color:#67e063">**Add layer flexibility to the `ConceptModule` layer structure**</span>
    - ideally we want it to be able to change it's own shape/hyper params to better adjust to the data (e.g., self adapting organism)
<br>
- **Move the assembly stats printing below to own module**




***
***
## Methods:

#### 1. State Space Definition
Let $\mathcal{S}$ be the system state space where each module $m_i^{(t)}$ at time $t$ is defined as:

$$m_i^{(t)} = \{\mathbf{W}_i, \mathbf{b}_i, \mathbf{p}_i^{(t)}, Q_i^{(t)}, f_i^{(t)}, A_i^{(t)}\}$$

Where:

- $\mathbf{W}_i, \mathbf{b}_i$: Neural network parameters  
- $\mathbf{p}_i^{(t)} \in [0,1]^d$: Position on learned manifold $\mathcal{M}$  
- $Q_i^{(t)}$: Q-learning function (neural or tabular)  
- $f_i^{(t)} \in [0,1]$: Fitness score  
- $A_i^{(t)}$: Assembly properties  

#### 2. Core Mathematical Operations

**2.1 Manifold Learning Component**

The system learns a manifold embedding: 

$$
\mathbf{z}_i = \text{Encoder}(\bar{\mathbf{x}}) \rightarrow \mathbf{p}_i^{(0)} = \sigma(\mathbf{z}_i)
$$

Where:  
- $\mathbf{z}_i$: Latent vector for module $i$  
- $\text{Encoder}(\cdot)$: Neural encoder function  
- $\bar{\mathbf{x}}$: Input data  
- $\sigma(\cdot)$: Activation function (e.g., sigmoid)  
- $\mathbf{p}_i^{(0)}$: Initial position on manifold for module $i$  

Geodesic distance on learned manifold: 

$$
d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j) = 
\begin{cases} 
|\mathbf{p}_i - \mathbf{p}_j|_2 & \text{if } \mathbf{T}_i = \emptyset \\
|\mathbf{T}_i^T(\mathbf{p}_j - \mathbf{p}_i)|_2 \cdot (1 + 0.1|\kappa_i||\mathbf{T}_i^T(\mathbf{p}_j - \mathbf{p}_i)|_2) & \text{otherwise} 
\end{cases}
$$

Where:  
- $d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j)$: Geodesic distance between modules $i$ and $j$  
- $\mathbf{p}_i, \mathbf{p}_j$: Positions on manifold  
- $\mathbf{T}_i$: Tangent space at $\mathbf{p}_i$  
- $\kappa_i$: Curvature at $\mathbf{p}_i$  
- $|\cdot|_2$: Euclidean norm  

**2.2 Q-Learning Evolution Dynamics**

Each module maintains a Q-function $Q_i: \mathcal{S} \times \mathcal{A} \rightarrow \mathbb{R}$ updated via: 

$$
Q_i(s,a) \leftarrow Q_i(s,a) + \alpha[R(m_{\text{child}}) + \gamma \max_{a'} Q_i(s',a') - Q_i(s,a)]
$$

Where:  
- $Q_i(s,a)$: Q-value for state $s$ and action $a$ for module $i$  
- $\alpha$: Learning rate  
- $R(m_{\text{child}})$: Reward for child module  
- $\gamma$: Discount factor  
- $s'$: Next state  
- $a'$: Next action  
- $\max_{a'} Q_i(s',a')$: Maximum Q-value for next state  

Proof of Q-Learning Convergence: Under standard assumptions (bounded rewards, sufficient exploration), this satisfies the Robbins-Monro conditions, guaranteeing convergence to $Q^*$.

**2.3 Multi-Objective Fitness Function**

The adaptive fitness function implements a time-varying optimization:

$$
F_i^{(t)} = w_1^{(t)} \cdot \text{Accuracy}_i + w_2^{(t)} \cdot \text{Diversity}_i + w_3^{(t)} \cdot \text{Entropy}_i - \text{Penalty}_i^{(t)}
$$

Where:  
- $F_i^{(t)}$: Fitness of module $i$ at time $t$  
- $w_1^{(t)}, w_2^{(t)}, w_3^{(t)}$: Time-varying weights  
- $\text{Accuracy}_i$: Accuracy metric for module $i$  
- $\text{Diversity}_i$: Diversity metric  
- $\text{Entropy}_i$: Entropy metric  
- $\text{Penalty}_i^{(t)}$: Penalty term at time $t$  

Weights evolve as: 

$$
w^{(t)} = 
\begin{cases} 
(1.5w_n, 1.2w_e, 0.8w_a) & \text{if } t < 0.3T \\
(w_n, w_e, w_a) & \text{if } 0.3T \leq t < 0.7T \\
(0.7w_n, 0.8w_e, 1.3w_a) & \text{if } t \geq 0.7T 
\end{cases}
$$

Where:  
- $w_n, w_e, w_a$: Base weights for accuracy, diversity, and entropy  
- $T$: Total time steps  
- $t$: Current time step  

#### 3. Autocatalytic Assembly Dynamics

The assembly complexity grows according to: 

$$
A_{\text{sys}}^{(t)} = \frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} e^{a_i} \cdot \frac{n_i - 1}{|\mathcal{P}|}
$$

Where:  
- $A_{\text{sys}}^{(t)}$: System assembly complexity at time $t$  
- $|\mathcal{P}|$: Population size  
- $a_i$: Assembly complexity of module $i$  
- $n_i$: Copy number of module type $i$  

Assembly Theorem: Under controlled catalysis, the system complexity $A_{\text{sys}}^{(t)}$ is bounded by: 

$$
A_{\text{sys}}^{(t)} \leq C \cdot \log(t) + A_0
$$

Where:  
- $C$: Constant  
- $A_0$: Initial complexity  
- $t$: Time step  

Proof Sketch: The exponential assembly growth is constrained by the bias elimination mechanism and survivor selection, creating a logarithmic upper bound.

#### 4. Bias Elimination Mechanism

The system implements an adaptive bias detection function: 

$$
\text{Bias}(m_i, t) = \max_k P(y_i = k | \mathbf{x}) - \text{threshold}(t)
$$

Where:  
- $\text{Bias}(m_i, t)$: Bias for module $i$ at time $t$  
- $P(y_i = k | \mathbf{x})$: Probability of class $k$ given input $\mathbf{x}$  
- $\text{threshold}(t)$: Adaptive threshold at time $t$  

Where $\text{threshold}(t) = 0.75 + 0.2 \cdot \frac{t}{T}$ creates an adaptive tolerance.

- $T$: Total time steps  
- $t$: Current time step  

Anti-Convergence Theorem: The bias elimination ensures population diversity: 

$$
\lim_{t \rightarrow \infty} \mathbb{E}[\text{Entropy}(\mathcal{P}^{(t)})] \geq H_{\min} > 0
$$

Where:  
- $\mathbb{E}[\text{Entropy}(\mathcal{P}^{(t)})]$: Expected entropy of population at time $t$  
- $H_{\min}$: Minimum entropy bound  

#### 5. Message Passing and Information Flow

For each module $m_i$, select neighbors by manifold distance: 

$$
\mathcal{N}_{\mathcal{M}}(m_i) = \{m_j : d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j) \text{ among } k \text{ smallest}\}
$$

Where:  
- $\mathcal{N}_{\mathcal{M}}(m_i)$: Set of $k$ nearest neighbors to $m_i$ on the manifold  
- $d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j)$: Manifold distance between $i$ and $j$  
- $k$: Number of neighbors  

Information propagates via manifold-aware messaging: 

$$
\mathbf{M}_i^{(t)} = \sigma(g_i) \cdot \frac{1}{|\mathcal{N}_{\mathcal{M}}(m_i)|} \sum_{j \in \mathcal{N}_{\mathcal{M}}(m_i)} \frac{\mathbf{h}_j^{(t-1)}}{1 + d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j)}
$$

Where:  
- $\mathbf{M}_i^{(t)}$: Message received by module $i$ at time $t$  
- $\sigma(g_i)$: Activation function applied to gating variable $g_i$  
- $|\mathcal{N}_{\mathcal{M}}(m_i)|$: Number of neighbors  
- $\mathbf{h}_j^{(t-1)}$: Hidden state of neighbor $j$ at previous time  
- $d_{\mathcal{M}}(\mathbf{p}_i, \mathbf{p}_j)$: Manifold distance  

Information Flow Theorem: Under connected manifold topology, information propagates globally in $O(\log n)$ steps.

#### 6. Main Convergence Proof

Theorem (AAN-Q Global Convergence): Under the following conditions:

- Bounded fitness landscape: $F: \mathcal{S} \rightarrow [0,1]$
- Sufficient exploration: $\epsilon$-greedy Q-learning with $\epsilon > 0$
- Controlled bias elimination: $|\text{eliminated}| \leq \alpha |\mathcal{P}|$ per step
- Manifold Lipschitz continuity: $|F(\mathbf{p}_1) - F(\mathbf{p}_2)| \leq L \cdot d_{\mathcal{M}}(\mathbf{p}_1, \mathbf{p}_2)$

The system converges to a stable configuration: 

$$
\lim_{t \rightarrow \infty} \mathbb{E}[F_{\text{best}}^{(t)}] = F^* - \delta
$$

Where:  
- $\mathbb{E}[F_{\text{best}}^{(t)}]$: Expected best fitness at time $t$  
- $F^*$: Optimal fitness  
- $\delta$: Exploration-exploitation gap  

Proof Strategy:

- Q-Learning Convergence: Standard MDP theory guarantees $Q_i \rightarrow Q_i^*$
- Manifold Regularization: The geodesic distance creates smooth fitness landscapes
- Population Dynamics: Bias elimination prevents premature convergence
- Assembly Constraints: Bounded complexity prevents runaway growth

#### 7. Computational Complexity

The algorithm has complexity:

Per Step: $O(n^2 d + n \cdot |Q|)$ where $n = |\mathcal{P}|$, $d$ = manifold dimension

Overall: $O(T \cdot n^2 d)$ for $T$ evolution steps

Where:  
- $n$: Population size  
- $d$: Manifold dimension  
- $|Q|$: Size of Q-table or Q-function  
- $T$: Number of evolution steps  

#### 8. System Complexity

**8.1 Assembly Theory Metric**

$$
A_{\text{sys}}^{(t)} = \frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} e^{a_i} \cdot \frac{n_i - 1}{|\mathcal{P}|}
$$

Where:  
- $A_{\text{sys}}^{(t)}$: System assembly complexity at time $t$  
- $|\mathcal{P}|$: Population size  
- $a_i$: Assembly complexity of module $i$  
- $n_i$: Copy number of module type $i$  

Each neural module (`ConceptModule`) tracks the minimal assembly complexity of it's weights. 
- Each layer's weights are wrapped in `ARCENET.core.ModuleComponent` that records the pathway
- When a module is mutated, new components are created with parent references, enabling reuse tracking
- The per-layer and total assemble complexity can be queried for any module with `ARCENET.models.arcnet_learner.system_assembly_complexity`
- System-level complexity is computed as the formula above, using the sum of exponentiated assembly indicies

#### 9. Complete System Evolution

The system evolves according to: 

$$
\mathcal{P}^{(t+1)} = \text{Select}(\mathcal{P}^{(t)}) \cup \text{Mutate}(\text{Select}(\mathcal{P}^{(t)}), Q^{(t)})
$$

Where:  
- $\mathcal{P}^{(t)}$: Population at time $t$  
- $\text{Select}(\cdot)$: Selection operator  
- $\text{Mutate}(\cdot, Q^{(t)})$: Mutation operator using Q-function $Q^{(t)}$  

With the objective of maximizing: 

$$
\mathcal{L} = \mathbb{E}_{t,i}\left[ R(m_i^{(t)}) \right] - \lambda A_{\text{sys}}^{(t)}
$$

Where:  
- $\mathcal{L}$: Objective function  
- $R(m_i^{(t)})$: Reward for module $i$ at time $t$  
- $\lambda$: Regularization parameter  
- $A_{\text{sys}}^{(t)}$: System assembly complexity at time $t$  

***
***

## Init

In [1]:
# import packages
import torch
import warnings
import numpy as np
import pandas as pd
import random
import sys
import os
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.datasets import fetch_openml
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, classification_report, roc_auc_score
from sklearn.exceptions import UndefinedMetricWarning

# get current working directory (avoid issues with relative imports)
sys.path.append(os.getcwd())

# import ARCNET modules
from core.trainer import (Trainer, enhanced_trainer_step_with_assembly_tracking)
from core.registry import AssemblyTrackingRegistry

# set all seeds
seeds = 42
random.seed(seeds)
np.random.seed(seeds)
torch.manual_seed(seeds)
torch.cuda.manual_seed_all(seeds)

# suppress UserWarnings from torchvision (no User path output)
warnings.filterwarnings("ignore", category=UserWarning, module='torchvision')
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)


In [2]:
# check if GPU is available
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
print(torch.cuda.is_available())  # Should print True if a compatible GPU is available
print(torch.cuda.device_count())  # Number of GPUs available
print(torch.cuda.get_device_name(0))  # Name of the first GPU

if torch.cuda.is_available():
    total = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    reserved = torch.cuda.memory_reserved(0) / (1024**3)
    allocated = torch.cuda.memory_allocated(0) / (1024**3)
    free = reserved - allocated
    print(f"Total GPU memory: {total:.2f} GB")
    print(f"Reserved by PyTorch: {reserved:.2f} GB")
    print(f"Currently allocated: {allocated:.2f} GB")
    print(f"Free inside reserved: {free:.2f} GB")

    # choose to enable device
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    global cpu
    cpu = torch.device("cpu")


PyTorch version: 2.7.1+cu128
CUDA version: 12.8
cuDNN version: 90701
True
1
NVIDIA GeForce RTX 4070 Laptop GPU
Total GPU memory: 8.00 GB
Reserved by PyTorch: 0.00 GB
Currently allocated: 0.00 GB
Free inside reserved: 0.00 GB


In [3]:
def statistical_breakdown(population, y_test, X_test):
    """
    Provides a statistical breakdown of the best module's performance.
    """
    
    best = max(population, key=lambda m: m.fitness)
    
    # Convert inputs to tensors if they're not already
    if isinstance(X_test, np.ndarray):
        X_test = torch.tensor(X_test, dtype=torch.float32)
    if isinstance(y_test, np.ndarray):
        y_test = torch.tensor(y_test, dtype=torch.long)
    
    # Move to same device as model
    device = next(best.parameters()).device
    X_test = X_test.to(device)
    y_test = y_test.to(device)
    
    with torch.no_grad():
        outputs = best(X_test)
        predictions = outputs.argmax(dim=1).cpu().numpy()
        y_true = y_test.cpu().numpy()
        
        # Handle probability extraction for binary classification
        if outputs.shape[1] == 2:
            probs = outputs.softmax(dim=1)[:, 1].cpu().numpy()
        else:
            # For multi-class, use max probability
            probs = outputs.softmax(dim=1).max(dim=1)[0].cpu().numpy()

    # Ensure binary classification by converting to 0/1
    y_true_binary = (y_true > 0).astype(int)
    predictions_binary = (predictions > 0).astype(int)

    # Compute metrics
    acc = accuracy_score(y_true_binary, predictions_binary)
    prec = precision_score(y_true_binary, predictions_binary, average='binary', zero_division=0)
    rec = recall_score(y_true_binary, predictions_binary, average='binary', zero_division=0)
    f1 = f1_score(y_true_binary, predictions_binary, average='binary', zero_division=0)
    auc = roc_auc_score(y_true_binary, probs)
    cm = confusion_matrix(y_true_binary, predictions_binary)
    report = classification_report(y_true_binary, predictions_binary, output_dict=True)

    print("=== Statistical Breakdown ===")
    print(f"Accuracy: {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall: {rec:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"ROC AUC: {auc:.4f}")
    print("\nConfusion Matrix:")
    print(cm)
    print("\nPer-class breakdown:")
    for label in report:
        if label in ['0', '1']:
            print(f"Class {label}: Precision={report[label]['precision']:.3f}, Recall={report[label]['recall']:.3f}, F1={report[label]['f1-score']:.3f}, Support={report[label]['support']:.0f}")

***
***

## Load Datasets & Split/Encode/etc.

In [4]:
# Run the following cell to load the breast cancer dataset and split it into training and testing sets
def load_breast_cancer_data(test_size=0.2, random_state=42):
    
    """
    default test_size is 0.2 (20% of the data for testing)
    """
    
    # Load the breast cancer dataset
    data = load_breast_cancer()
    X, y = data.data, data.target

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    return X_train, X_test, y_train, y_test

# Output
X_train, X_test, y_train, y_test = load_breast_cancer_data()

In [None]:
# Run the following cell to load the wine dataset and split it into training and testing sets
def load_wine_data(test_size=0.2, random_state=42):
    """
    default test_size is 0.2 (20% of the data for testing)
    """
    
    # Load the wine dataset
    data = fetch_openml(name='wine', version=1, as_frame=True)
    X, y = data.data, data.target

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Encode the labels if necessary
    label_encoder = LabelEncoder()
    y_train = label_encoder.fit_transform(y_train)
    y_test = label_encoder.transform(y_test)

    return X_train, X_test, y_train, y_test

# Output
X_train, X_test, y_train, y_test = load_wine_data()

In [None]:
# Run the following cell to load the MNIST dataset and split it into training and testing sets

# load mnist dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# convert to one-hot vector
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# resize and normalize
x_train = np.reshape(x_train, [-1, 784])
x_train = x_train.astype('float32') / 255
x_test = np.reshape(x_test, [-1, 784])
x_test = x_test.astype('float32') / 255

***
***
## Model Generation

Models are procedurely generated through the `core.trainer` file, which has a wrapper for the entire evolutionary process. See the file itself for the controllable structure, default values are passed for most arguments. 


#### Fitness method (original AAN_Manifold_Implementation_copy3.ipynb)

```python
... = Trainer(..., 
    training_method = 'fitness', 
    ...
    )
```

In [None]:
# Train the model 
allmods, lineagesnap, flname, bestmod, stats = Trainer(
    X_train=X_train,
    y_train=y_train,
    input_dim=X_train.shape[1],
    hidden_dim=30,
    output_dim=2,
    initial_population=80,
    steps=75,
    epochs=1,
    lineage_prune_rate=1500,
    lineage_kept=800,
    num_survivors=22,
    q_learning_method='neural',
    training_method='fitness',
    enable_irxn=True,
    enable_model_save=False,
    enable_bias_elimination=False,
    enable_lineage_snap=False,
    experiment_name="arcnet_modular",
    track_best_models=True,
    top_k_per_generation=5,
    debug=False
)

STARTING FRESH EVOLUTION

=== Step 0 Diversity Check ===
Module 670487: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.025
Module 116739: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.025
Module 26225: BIASED - Class 0: 98.5%, Class 1: 1.5%, Fitness: 0.057
Module 777572: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.025
Module 288389: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.025
Module 256787: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.025
Module 234053: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.025
Module 146316: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.025
Population Bias Risk: 0.961
Biased Modules: 77/80
 INTERVENTION REQUIRED: High bias detected
Step 0: Avg=0.047, Best=0.541, Q-Mem=244.2MB, Q-Exp=0
Step 1: Avg=0.139, Best=0.620, Q-Mem=91.7MB, Q-Exp=1826
Step 2: Avg=0.100, Best=0.604, Q-Mem=82.5MB, Q-Exp=2604

=== Step 3 Diversity Check ===
Module 442417: BIASED - Class 0: 96.9%, Class 1: 3.1%, Fitness: 0.114
Module 426117: B

In [10]:
# Print the best model's performance
statistical_breakdown(allmods, y_test, X_test)


=== Statistical Breakdown ===
Accuracy: 0.9035
Precision: 0.8947
Recall: 0.9577
F1 Score: 0.9252
ROC AUC: 0.9417

Confusion Matrix:
[[35  8]
 [ 3 68]]

Per-class breakdown:
Class 0: Precision=0.921, Recall=0.814, F1=0.864, Support=43
Class 1: Precision=0.895, Recall=0.958, F1=0.925, Support=71


In [11]:
# Print assembly information of best model
best_model = max(allmods, key=lambda m: m.fitness)

print("=== Assembly Information for Best Model ===")
print(f"Model ID: {best_model.id}")
print(f"Fitness: {best_model.fitness:.4f}")
print(f"Created at step: {best_model.created_at}")
print(f"Parent ID: {best_model.parent_id}")
print()

# Assembly complexity information
print("=== Assembly Complexity ===")
assembly_complexity = best_model.get_assembly_complexity()
print(f"Assembly Complexity: {assembly_complexity}")
print(f"Assembly Index: {best_model.assembly_index}")
print(f"Assembly Steps: {best_model.assembly_steps}")
print(f"Copy Number: {best_model.copy_number}")
print()

# Assembly pathway information
print("=== Assembly Pathway ===")
if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
    print(f"Assembly pathway length: {len(best_model.assembly_pathway)}")
    for i, component in enumerate(best_model.assembly_pathway[:5]):  # Show first 5 components
        if hasattr(component, 'id'):
            print(f"  Component {i}: {component.id}")
        else:
            print(f"  Component {i}: {type(component).__name__}")
    if len(best_model.assembly_pathway) > 5:
        print(f"  ... and {len(best_model.assembly_pathway) - 5} more components")
else:
    print("No assembly pathway recorded")
print()

# Assembly operations
print("=== Assembly Operations ===")
if hasattr(best_model, 'assembly_operations') and best_model.assembly_operations:
    print(f"Number of assembly operations: {len(best_model.assembly_operations)}")
    for i, op in enumerate(best_model.assembly_operations):
        print(f"  Operation {i}: {op}")
else:
    print("No assembly operations recorded")
print()

# Catalytic information
print("=== Catalytic Information ===")
print(f"Is autocatalytic: {best_model.is_autocatalytic}")
print(f"Catalyzed by: {best_model.catalyzed_by}")
print(f"Catalyzes: {best_model.catalyzes}")
print()

# Layer components assembly complexity
print("=== Layer Components Assembly Complexity ===")
if hasattr(best_model, 'layer_components'):
    for layer_name, component in best_model.layer_components.items():
        if hasattr(component, 'get_minimal_assembly_complexity'):
            complexity = component.get_minimal_assembly_complexity()
            print(f"  {layer_name}: {complexity}")
        else:
            print(f"  {layer_name}: No complexity method available")
print()

# System-level assembly complexity
print("=== System Assembly Complexity ===")
try:
    from models.arcnet import system_assembly_complexity
    sys_complexity = system_assembly_complexity(allmods)
    print(f"Total system assembly complexity: {sys_complexity:.4f}")
except Exception as e:
    print(f"Could not compute system complexity: {e}")
print()

# Position and manifold information
print("=== Manifold Position Information ===")
print(f"Manifold dimension: {best_model.manifold_dim}")
print(f"Position: {best_model.position.data[:5].tolist()}...")  # Show first 5 dimensions
print(f"Curvature: {best_model.curvature}")
if best_model.position_info:
    print(f"Position info: {best_model.position_info}")
print()

# Q-learning information related to assembly
print("=== Q-Learning Assembly Information ===")
print(f"Q-learning method: {best_model.q_learning_method}")
if hasattr(best_model, 'q_function') and best_model.q_function is not None:
    if hasattr(best_model.q_function, 'replay_buffer'):
        print(f"Q-function replay buffer size: {len(best_model.q_function.replay_buffer)}")
    print(f"Q-memory usage: {best_model.get_q_memory_usage():.2f} MB")
else:
    print("No Q-function available")

=== Assembly Information for Best Model ===
Model ID: 66907
Fitness: 1.5642
Created at step: 68
Parent ID: 335476

=== Assembly Complexity ===
Assembly Complexity: {'fc1': 7, 'fc2': 7, 'fc3': 7, 'total': 21}
Assembly Index: 13
Assembly Steps: 13
Copy Number: 1

=== Assembly Pathway ===
Assembly pathway length: 3
  Component 0: bf092c722aa8433996a768256e9affb4
  Component 1: 438932d0d9ec4349941a30ef5892a962
  Component 2: 3491f8a55f5548ff8c5f8cf8557c3e1c

=== Assembly Operations ===
Number of assembly operations: 1
  Operation 0: {'type': 'mutation', 'inputs': [335476], 'catalysts': [335476, 842836], 'step': 68}

=== Catalytic Information ===
Is autocatalytic: False
Catalyzed by: [335476, 842836]
Catalyzes: [164344, 858679, 871826, 393329, 784536, 576378]

=== Layer Components Assembly Complexity ===
  fc1: 7
  fc2: 7
  fc3: 7

=== System Assembly Complexity ===
Total system assembly complexity: 0.0000

=== Manifold Position Information ===
Manifold dimension: 7
Position: [1.0, 7.088146

***
#### Loss method (NEW)

```python
... = Trainer(..., 
    epochs = 5,
    training_method = 'loss', 
    ...
    )
```

In [6]:
# Train the model 
allmods, lineagesnap, flname, bestmod, stats = Trainer(
    X_train=X_train,
    y_train=y_train,
    input_dim=X_train.shape[1],
    hidden_dim=30,
    output_dim=2,
    initial_population=80,
    steps=75,
    epochs=1,
    lineage_prune_rate=1500,
    lineage_kept=800,
    num_survivors=22,
    q_learning_method='neural',
    training_method='loss',
    enable_irxn=True,
    enable_model_save=False,
    enable_bias_elimination=False,
    enable_lineage_snap=False,
    experiment_name="arcnet_modular_loss",
    track_best_models=True,
    top_k_per_generation=5,
    debug=False
)

STARTING FRESH EVOLUTION

=== Step 0 Diversity Check ===
Module 161450: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.032
Module 522584: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.025
Module 977957: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.060
Module 751413: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.231
Module 306039: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.116
Module 533590: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.033
Module 739677: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.085
Module 286706: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.040
Population Bias Risk: 0.885
Biased Modules: 71/80
 INTERVENTION REQUIRED: High bias detected
Step 0: Avg=0.145, Best=0.442, Q-Mem=244.2MB, Q-Exp=0
Step 1: Avg=0.283, Best=0.452, Q-Mem=76.4MB, Q-Exp=1479
Step 2: Avg=0.292, Best=0.403, Q-Mem=79.5MB, Q-Exp=2444

=== Step 3 Diversity Check ===
Module 214353: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.451
Module 104180

In [6]:
# Print the best model's performance
statistical_breakdown(allmods, y_test, X_test)

=== Statistical Breakdown ===
Accuracy: 0.8772
Precision: 0.8353
Recall: 1.0000
F1 Score: 0.9103
ROC AUC: 0.9745

Confusion Matrix:
[[29 14]
 [ 0 71]]

Per-class breakdown:
Class 0: Precision=1.000, Recall=0.674, F1=0.806, Support=43
Class 1: Precision=0.835, Recall=1.000, F1=0.910, Support=71


In [7]:
# Print assembly information of best model
best_model = max(allmods, key=lambda m: m.fitness)

print("=== Assembly Information for Best Model ===")
print(f"Model ID: {best_model.id}")
print(f"Fitness: {best_model.fitness:.4f}")
print(f"Created at step: {best_model.created_at}")
print(f"Parent ID: {best_model.parent_id}")
print()

# Assembly complexity information
print("=== Assembly Complexity ===")
assembly_complexity = best_model.get_assembly_complexity()
print(f"Assembly Complexity: {assembly_complexity}")
print(f"Assembly Index: {best_model.assembly_index}")
print(f"Assembly Steps: {best_model.assembly_steps}")
print(f"Copy Number: {best_model.copy_number}")
print()

# Assembly pathway information
print("=== Assembly Pathway ===")
if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
    print(f"Assembly pathway length: {len(best_model.assembly_pathway)}")
    for i, component in enumerate(best_model.assembly_pathway[:5]):  # Show first 5 components
        if hasattr(component, 'id'):
            print(f"  Component {i}: {component.id}")
        else:
            print(f"  Component {i}: {type(component).__name__}")
    if len(best_model.assembly_pathway) > 5:
        print(f"  ... and {len(best_model.assembly_pathway) - 5} more components")
else:
    print("No assembly pathway recorded")
print()

# Assembly operations
print("=== Assembly Operations ===")
if hasattr(best_model, 'assembly_operations') and best_model.assembly_operations:
    print(f"Number of assembly operations: {len(best_model.assembly_operations)}")
    for i, op in enumerate(best_model.assembly_operations):
        print(f"  Operation {i}: {op}")
else:
    print("No assembly operations recorded")
print()

# Catalytic information
print("=== Catalytic Information ===")
print(f"Is autocatalytic: {best_model.is_autocatalytic}")
print(f"Catalyzed by: {best_model.catalyzed_by}")
print(f"Catalyzes: {best_model.catalyzes}")
print()

# Layer components assembly complexity
print("=== Layer Components Assembly Complexity ===")
if hasattr(best_model, 'layer_components'):
    for layer_name, component in best_model.layer_components.items():
        if hasattr(component, 'get_minimal_assembly_complexity'):
            complexity = component.get_minimal_assembly_complexity()
            print(f"  {layer_name}: {complexity}")
        else:
            print(f"  {layer_name}: No complexity method available")
print()

# System-level assembly complexity
print("=== System Assembly Complexity ===")
try:
    from models.arcnet import system_assembly_complexity
    sys_complexity = system_assembly_complexity(allmods)
    print(f"Total system assembly complexity: {sys_complexity:.4f}")
except Exception as e:
    print(f"Could not compute system complexity: {e}")
print()

# Position and manifold information
print("=== Manifold Position Information ===")
print(f"Manifold dimension: {best_model.manifold_dim}")
print(f"Position: {best_model.position.data[:5].tolist()}...")  # Show first 5 dimensions
print(f"Curvature: {best_model.curvature}")
if best_model.position_info:
    print(f"Position info: {best_model.position_info}")
print()

# Q-learning information related to assembly
print("=== Q-Learning Assembly Information ===")
print(f"Q-learning method: {best_model.q_learning_method}")
if hasattr(best_model, 'q_function') and best_model.q_function is not None:
    if hasattr(best_model.q_function, 'replay_buffer'):
        print(f"Q-function replay buffer size: {len(best_model.q_function.replay_buffer)}")
    print(f"Q-memory usage: {best_model.get_q_memory_usage():.2f} MB")
else:
    print("No Q-function available")

=== Assembly Information for Best Model ===
Model ID: 883498
Fitness: 0.6785
Created at step: 73
Parent ID: 182138

=== Assembly Complexity ===
Assembly Complexity: {'fc1': 20, 'fc2': 20, 'fc3': 20, 'total': 60}
Assembly Index: 60
Assembly Steps: 60
Copy Number: 1

=== Assembly Pathway ===
Assembly pathway length: 3
  Component 0: c0f34851b3cd42c59eb10d0a0a7d7828
  Component 1: e64829a7ee324cb2a73080efcc24b28c
  Component 2: 6d4b6f6568344c468270be85407fcd40

=== Assembly Operations ===
Number of assembly operations: 1
  Operation 0: {'type': 'mutation', 'inputs': [182138], 'catalysts': [182138], 'step': 73}

=== Catalytic Information ===
Is autocatalytic: False
Catalyzed by: [182138]
Catalyzes: [626241]

=== Layer Components Assembly Complexity ===
  fc1: 20
  fc2: 20
  fc3: 20

=== System Assembly Complexity ===
Total system assembly complexity: 0.0000

=== Manifold Position Information ===
Manifold dimension: 7
Position: [0.0009999951580539346, 0.0010013002902269363, 0.00099999480880

***
## Comments

The assembly complexity of the loss is much larger than the pure fitness evaluation, however this could be due to `epochs` being active per generation. 


In [8]:
from pyvis.network import Network

def plot_full_component_lineage_tree_pyvis(best_model, lineagesnap):
    import networkx as nx

    # Traverse all modules in the lineage, collecting their generation step (created_at)
    lineage_modules = []
    current = best_model
    visited = set()
    while current is not None and getattr(current, 'id', None) is not None and current.id not in visited:
        lineage_modules.append(current)
        visited.add(current.id)
        parent_id = getattr(current, 'parent_id', None)
        if parent_id is not None and parent_id in lineagesnap:
            current = lineagesnap[parent_id]
        else:
            break

    # Sort modules by generation step (created_at, oldest to newest)
    lineage_modules = sorted(lineage_modules, key=lambda m: getattr(m, 'created_at', 0))
    
    # Create a mapping of component ID to module for detailed info
    comp_to_module = {}

    # Build a component graph across all modules in the lineage
    G_components = nx.DiGraph()
    node_steps = {}  # Map component id to module generation step for labeling
    for mod in lineage_modules:
        step = getattr(mod, 'created_at', 0)
        if hasattr(mod, 'assembly_pathway') and mod.assembly_pathway:
            for idx, comp in enumerate(mod.assembly_pathway):
                comp_id = getattr(comp, 'id', f'{mod.id}_comp_{idx}')
                G_components.add_node(comp_id)
                node_steps[comp_id] = step
                comp_to_module[comp_id] = (comp, mod)  # Store component and module reference
                if hasattr(comp, 'parents') and comp.parents:
                    for parent in comp.parents:
                        parent_id = getattr(parent, 'id', str(parent))
                        G_components.add_edge(parent_id, comp_id)

    # Identify final components of the best model
    final_comp_ids = set()
    if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
        for comp in best_model.assembly_pathway:
            comp_id = getattr(comp, 'id', None)
            if comp_id is not None:
                final_comp_ids.add(comp_id)

    # Create PyVis network with larger nodes
    net = Network(height="800px", width="100%", directed=True, notebook=True)
    net.barnes_hut()
    
    # Configure the physics and appearance for larger nodes
    net.set_options("""
    var options = {
      "nodes": {
        "font": {
          "size": 18,
          "color": "black",
          "face": "Arial"
        },
        "borderWidth": 3,
        "size": 40,
        "scaling": {
          "min": 30,
          "max": 60
        }
      },
      "edges": {
        "color": {
          "inherit": "from"
        },
        "smooth": {
          "enabled": true,
          "type": "dynamic"
        },
        "arrows": {
          "to": {
            "enabled": true,
            "scaleFactor": 1.2
          }
        },
        "width": 2
      },
      "physics": {
        "enabled": true,
        "stabilization": {
          "enabled": true,
          "iterations": 150
        },
        "barnesHut": {
          "gravitationalConstant": -8000,
          "centralGravity": 0.3,
          "springLength": 120,
          "springConstant": 0.04,
          "damping": 0.09
        }
      },
      "interaction": {
        "hover": true,
        "selectConnectedEdges": true
      }
    }
    """)

    # Add nodes with detailed component information
    for node in G_components.nodes():
        generation = node_steps.get(node, '?')
        label = f"Gen {generation}"
        
        # Build detailed tooltip with component and layer information
        tooltip_parts = [f"<b>Component ID:</b> {node}", f"<b>Generation:</b> {generation}"]
        
        if node in comp_to_module:
            comp, mod = comp_to_module[node]
            
            # Add module information
            tooltip_parts.extend([
                f"<b>Module ID:</b> {mod.id}",
                f"<b>Module Fitness:</b> {getattr(mod, 'fitness', 'N/A'):.4f}" if hasattr(mod, 'fitness') else "<b>Module Fitness:</b> N/A",
                f"<b>Assembly Complexity:</b> {getattr(comp, 'assembly_complexity', 'N/A')}" if hasattr(comp, 'assembly_complexity') else "<b>Assembly Complexity:</b> N/A"
            ])
            
            # Add layer component information if available
            if hasattr(mod, 'layer_components'):
                tooltip_parts.append("<br><b>Layer Components:</b>")
                for layer_name, layer_comp in mod.layer_components.items():
                    if hasattr(layer_comp, 'get_minimal_assembly_complexity'):
                        layer_complexity = layer_comp.get_minimal_assembly_complexity()
                        tooltip_parts.append(f"• {layer_name}: {layer_complexity}")
                    else:
                        tooltip_parts.append(f"• {layer_name}: No complexity data")
            
            # Add assembly pathway information
            if hasattr(mod, 'assembly_pathway') and mod.assembly_pathway:
                tooltip_parts.append(f"<br><b>Assembly Pathway Length:</b> {len(mod.assembly_pathway)}")
                
            # Add parent information
            if hasattr(comp, 'parents') and comp.parents:
                parent_ids = [getattr(p, 'id', str(p)) for p in comp.parents]
                tooltip_parts.append(f"<b>Parent Components:</b> {', '.join(parent_ids[:3])}" + ("..." if len(parent_ids) > 3 else ""))
        
        title = "<br>".join(tooltip_parts)
        
        # Larger nodes with different sizes for final vs intermediate components
        if node in final_comp_ids:
            net.add_node(node, 
                         label=label, 
                         title=title,
                         color={"background": "#ff4444", "border": "#cc0000", "highlight": {"background": "#ff6666", "border": "#ff0000"}}, 
                         size=50,
                         font={'size': 20, 'color': 'white', 'face': 'Arial Bold'})
        else:
            net.add_node(node, 
                         label=label, 
                         title=title,
                         color={"background": "#ff9500", "border": "#cc7700", "highlight": {"background": "#ffaa33", "border": "#ff8800"}}, 
                         size=35,
                         font={'size': 16, 'color': 'black', 'face': 'Arial'})

    # Add edges with better styling
    for source, target in G_components.edges():
        net.add_edge(source, target, width=2, color={'color': '#666666', 'highlight': '#000000'})

    # Show with custom HTML template for better interaction
    html_template = """
    <html>
    <head>
        <title>Component Lineage Tree</title>
        <style>
            body { 
                font-family: Arial, sans-serif; 
                margin: 10px; 
                background-color: #f5f5f5;
            }
            #mynetworkid { 
                border: 2px solid #ccc; 
                border-radius: 8px;
                background-color: white;
            }
            .info { 
                margin: 10px 0; 
                padding: 10px; 
                background-color: #e7f3ff; 
                border-radius: 5px;
                border-left: 4px solid #2196F3;
            }
        </style>
    </head>
    <body>
        <div class="info">
            <h3>Component Lineage Visualization</h3>
            <p><strong>Red nodes:</strong> Final components in best model | <strong>Orange nodes:</strong> Intermediate components</p>
            <p><strong>Click on nodes</strong> to see detailed component and layer information | <strong>Drag</strong> to reposition</p>
        </div>
        <div id="mynetworkid"></div>
    </body>
    </html>
    """
    
    # Save with custom template
    net.show("component_lineage_tree.html")

    


# Usage:
plot_full_component_lineage_tree_pyvis(best_model, lineagesnap)


component_lineage_tree.html


In [9]:

from pyvis.network import Network
import networkx as nx

def plot_comprehensive_assembly_lineage(best_model, lineagesnap):
    """
    Creates a comprehensive visualization showing the full assembly lineage
    with proper component reuse tracking and hierarchical organization.
    """
    
    # First, let's gather ALL modules in the lineage and organize by generation
    all_modules = {}  # generation -> list of modules
    lineage_modules = []
    
    # Traverse the lineage
    current = best_model
    visited = set()
    while current is not None and getattr(current, 'id', None) is not None and current.id not in visited:
        lineage_modules.append(current)
        visited.add(current.id)
        generation = getattr(current, 'created_at', 0)
        if generation not in all_modules:
            all_modules[generation] = []
        all_modules[generation].append(current)
        
        parent_id = getattr(current, 'parent_id', None)
        if parent_id is not None and parent_id in lineagesnap:
            current = lineagesnap[parent_id]
        else:
            break

    print(f"Found {len(lineage_modules)} modules across {len(all_modules)} generations")
    
    # Now build a comprehensive component graph that tracks actual reuse
    G = nx.DiGraph()
    component_info = {}  # component_id -> {module, generation, layer, complexity, etc.}
    component_hierarchy = {}  # track parent-child relationships
    
    # Process each module's assembly pathway
    for mod in lineage_modules:
        generation = getattr(mod, 'created_at', 0)
        mod_id = getattr(mod, 'id', 'unknown')
        
        print(f"\nProcessing Module {mod_id} (Gen {generation}):")
        
        # Check assembly pathway
        if hasattr(mod, 'assembly_pathway') and mod.assembly_pathway:
            print(f"  Assembly pathway length: {len(mod.assembly_pathway)}")
            
            for idx, comp in enumerate(mod.assembly_pathway):
                comp_id = getattr(comp, 'id', f'{mod_id}_comp_{idx}')
                
                # Store comprehensive component information
                component_info[comp_id] = {
                    'module': mod,
                    'generation': generation,
                    'component': comp,
                    'pathway_index': idx,
                    'assembly_complexity': getattr(comp, 'assembly_complexity', 0),
                    'layer_info': f"Component {idx}"
                }
                
                G.add_node(comp_id)
                print(f"    Component {idx}: {comp_id} (complexity: {getattr(comp, 'assembly_complexity', 'N/A')})")
                
                # Track parent relationships
                if hasattr(comp, 'parents') and comp.parents:
                    for parent in comp.parents:
                        parent_id = getattr(parent, 'id', str(parent))
                        if parent_id != comp_id:  # Avoid self-loops
                            G.add_edge(parent_id, comp_id)
                            print(f"      -> Parent: {parent_id}")
                            
                            # Store hierarchy info
                            if comp_id not in component_hierarchy:
                                component_hierarchy[comp_id] = []
                            component_hierarchy[comp_id].append(parent_id)
        
        # Also check layer components for additional assembly info
        if hasattr(mod, 'layer_components'):
            print(f"  Layer components: {list(mod.layer_components.keys())}")
            for layer_name, layer_comp in mod.layer_components.items():
                if hasattr(layer_comp, 'get_minimal_assembly_complexity'):
                    complexity = layer_comp.get_minimal_assembly_complexity()
                    print(f"    {layer_name}: complexity {complexity}")

    print(f"\nTotal components in graph: {G.number_of_nodes()}")
    print(f"Total edges (reuse relationships): {G.number_of_edges()}")
    
    # Identify different types of components
    final_components = set()
    if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
        for comp in best_model.assembly_pathway:
            comp_id = getattr(comp, 'id', None)
            if comp_id:
                final_components.add(comp_id)
    
    # Find root components (no parents) and leaf components (no children)
    root_components = [n for n in G.nodes() if G.in_degree(n) == 0]
    leaf_components = [n for n in G.nodes() if G.out_degree(n) == 0]
    
    print(f"Root components (no parents): {len(root_components)}")
    print(f"Leaf components (no children): {len(leaf_components)}")
    print(f"Final model components: {len(final_components)}")
    
    # Find longest paths to understand assembly chains
    longest_paths = []
    for root in root_components:
        for leaf in leaf_components:
            try:
                path = nx.shortest_path(G, root, leaf)
                longest_paths.append((len(path), path))
            except nx.NetworkXNoPath:
                continue
    
    longest_paths.sort(reverse=True)
    if longest_paths:
        print(f"Longest assembly chain: {longest_paths[0][0]} components")
        print(f"Path: {' -> '.join(longest_paths[0][1][:5])}{'...' if len(longest_paths[0][1]) > 5 else ''}")

    # Create enhanced PyVis visualization
    net = Network(height="900px", width="100%", directed=True, notebook=True)
    net.barnes_hut()
    
    # Enhanced options for better visualization
    net.set_options("""
    var options = {
      "nodes": {
        "font": {
          "size": 14,
          "color": "black",
          "face": "Arial"
        },
        "borderWidth": 2,
        "scaling": {
          "min": 20,
          "max": 80
        }
      },
      "edges": {
        "color": {
          "color": "#666666",
          "highlight": "#333333"
        },
        "arrows": {
          "to": {
            "enabled": true,
            "scaleFactor": 1
          }
        },
        "smooth": {
          "enabled": true,
          "type": "dynamic",
          "roundness": 0.5
        },
        "width": 2
      },
      "layout": {
        "hierarchical": {
          "enabled": true,
          "direction": "LR",
          "sortMethod": "directed",
          "shakeTowards": "leaves"
        }
      },
      "physics": {
        "enabled": true,
        "hierarchicalRepulsion": {
          "centralGravity": 0.3,
          "springLength": 100,
          "springConstant": 0.01,
          "nodeDistance": 120,
          "damping": 0.09
        }
      },
      "interaction": {
        "hover": true,
        "selectConnectedEdges": true,
        "navigationButtons": true,
        "keyboard": true
      }
    }
    """)

    # Add nodes with comprehensive information and size based on complexity
    for node in G.nodes():
        info = component_info.get(node, {})
        generation = info.get('generation', '?')
        complexity = info.get('assembly_complexity', 0)
        
        # Create detailed label and tooltip
        label = f"G{generation}\nC{complexity}"
        
        tooltip_parts = [
            f"<b>Component ID:</b> {node}",
            f"<b>Generation:</b> {generation}",
            f"<b>Assembly Complexity:</b> {complexity}",
            f"<b>Pathway Index:</b> {info.get('pathway_index', 'N/A')}"
        ]
        
        if 'module' in info:
            mod = info['module']
            tooltip_parts.extend([
                f"<b>Module ID:</b> {getattr(mod, 'id', 'N/A')}",
                f"<b>Module Fitness:</b> {getattr(mod, 'fitness', 'N/A'):.4f}" if hasattr(mod, 'fitness') else "<b>Module Fitness:</b> N/A"
            ])
            
            # Add layer components info
            if hasattr(mod, 'layer_components'):
                tooltip_parts.append("<br><b>Layer Components:</b>")
                for layer_name, layer_comp in mod.layer_components.items():
                    if hasattr(layer_comp, 'get_minimal_assembly_complexity'):
                        layer_complexity = layer_comp.get_minimal_assembly_complexity()
                        tooltip_parts.append(f"• {layer_name}: {layer_complexity}")

        # Add parent/child information
        parents = component_hierarchy.get(node, [])
        children = list(G.successors(node))
        if parents:
            tooltip_parts.append(f"<b>Parents:</b> {', '.join(parents[:3])}" + ("..." if len(parents) > 3 else ""))
        if children:
            tooltip_parts.append(f"<b>Children:</b> {', '.join(children[:3])}" + ("..." if len(children) > 3 else ""))
        
        title = "<br>".join(tooltip_parts)
        
        # Determine node appearance based on role and complexity
        size = max(20, min(60, 20 + complexity * 2))  # Size based on complexity
        
        if node in final_components:
            # Final components - red
            color = {"background": "#ff4444", "border": "#cc0000", "highlight": {"background": "#ff6666", "border": "#ff0000"}}
        elif node in root_components:
            # Root components - green
            color = {"background": "#44ff44", "border": "#00cc00", "highlight": {"background": "#66ff66", "border": "#00ff00"}}
        elif complexity > 10:
            # High complexity intermediate - purple
            color = {"background": "#8844ff", "border": "#6600cc", "highlight": {"background": "#aa66ff", "border": "#8800ff"}}
        else:
            # Regular intermediate - orange
            color = {"background": "#ff9500", "border": "#cc7700", "highlight": {"background": "#ffaa33", "border": "#ff8800"}}
        
        net.add_node(node, label=label, title=title, color=color, size=size)

    # Add edges
    for source, target in G.edges():
        net.add_edge(source, target)

    # Generate and save
    net.show("comprehensive_assembly_lineage.html")
    
    print("🟢 Green = Root components | 🔴 Red = Final components | 🟣 Purple = High complexity | 🟠 Orange = Regular")
    return G, component_info, longest_paths

# Usage:
assembly_graph, comp_info, chains = plot_comprehensive_assembly_lineage(best_model, lineagesnap)


Found 20 modules across 19 generations

Processing Module 883498 (Gen 73):
  Assembly pathway length: 3
    Component 0: c0f34851b3cd42c59eb10d0a0a7d7828 (complexity: N/A)
      -> Parent: d076e21440104c959be2e41c77b4b2a4
    Component 1: e64829a7ee324cb2a73080efcc24b28c (complexity: N/A)
      -> Parent: 6250897c5abd468da669d6c6dbe22431
    Component 2: 6d4b6f6568344c468270be85407fcd40 (complexity: N/A)
      -> Parent: b4251790e9944dab82862bfb8e74e9d1
  Layer components: ['fc1', 'fc2', 'fc3']
    fc1: complexity 20
    fc2: complexity 20
    fc3: complexity 20

Processing Module 182138 (Gen 72):
  Assembly pathway length: 3
    Component 0: d076e21440104c959be2e41c77b4b2a4 (complexity: N/A)
      -> Parent: 567c561e82af469ab83bd4f4d9411105
    Component 1: 6250897c5abd468da669d6c6dbe22431 (complexity: N/A)
      -> Parent: efa06686b4cc4ea0930e759b6114f027
    Component 2: b4251790e9944dab82862bfb8e74e9d1 (complexity: N/A)
      -> Parent: 577eea96c6014ac7818e862e1f9df3e2
  Layer comp

***
***
## Testing Assembly Trainer



In [5]:
# Train the model 
allmods, lineagesnap, flname, bestmod, stats, assembly_registry = Trainer(
    X_train=X_train,
    y_train=y_train,
    input_dim=X_train.shape[1],
    hidden_dim=30,
    output_dim=2,
    initial_population=80,
    steps=75,
    epochs=1,
    lineage_prune_rate=1500,
    lineage_kept=800,
    num_survivors=22,
    q_learning_method='neural',
    training_method='loss',
    enable_irxn=True,
    enable_model_save=False,
    enable_bias_elimination=False,
    enable_lineage_snap=False,
    experiment_name="arcnet_assembly_loss",
    track_best_models=True,
    top_k_per_generation=5,
    debug=False
)

STARTING FRESH EVOLUTION
Step 0: Enhanced message passing with assembly tracking for 80 modules
  Message hop 1/6


  base_message = torch.tensor(base_message.data, dtype=torch.float32)


  Message hop 2/6
  Message hop 3/6
  Message hop 4/6
  Message hop 5/6
  Message hop 6/6
Assembly Tracking Stats at Step 0:
  Components tracked: 240
  Modules tracked: 80
  Message influences: 0
  Q-learning influences: 0
  Max assembly depth: 0
  Total assembly events: 560

=== Step 0 Diversity Check ===
Module 670487: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.051
Module 116739: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.068
Module 26225: DIVERSE - Class 0: 21.8%, Class 1: 78.2%, Fitness: 0.295
Module 777572: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.053
Module 288389: BIASED - Class 0: 6.8%, Class 1: 93.2%, Fitness: 0.355
Module 256787: BIASED - Class 0: 0.0%, Class 1: 100.0%, Fitness: 0.101
Module 234053: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.114
Module 146316: BIASED - Class 0: 100.0%, Class 1: 0.0%, Fitness: 0.071
Population Bias Risk: 0.896
Biased Modules: 72/80
 INTERVENTION REQUIRED: High bias detected
Step 0: Avg=0.170, Best=0.464, 

In [6]:
statistical_breakdown(allmods, y_test, X_test)


=== Statistical Breakdown ===
Accuracy: 0.8947
Precision: 0.8554
Recall: 1.0000
F1 Score: 0.9221
ROC AUC: 0.9135

Confusion Matrix:
[[31 12]
 [ 0 71]]

Per-class breakdown:
Class 0: Precision=1.000, Recall=0.721, F1=0.838, Support=43
Class 1: Precision=0.855, Recall=1.000, F1=0.922, Support=71


In [7]:
# Print assembly information of best model
best_model = max(allmods, key=lambda m: m.fitness)

print("=== Assembly Information for Best Model ===")
print(f"Model ID: {best_model.id}")
print(f"Fitness: {best_model.fitness:.4f}")
print(f"Created at step: {best_model.created_at}")
print(f"Parent ID: {best_model.parent_id}")
print()

# Assembly complexity information
print("=== Assembly Complexity ===")
assembly_complexity = best_model.get_assembly_complexity()
print(f"Assembly Complexity: {assembly_complexity}")
print(f"Assembly Index: {best_model.assembly_index}")
print(f"Assembly Steps: {best_model.assembly_steps}")
print(f"Copy Number: {best_model.copy_number}")
print()

# Assembly pathway information
print("=== Assembly Pathway ===")
if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
    print(f"Assembly pathway length: {len(best_model.assembly_pathway)}")
    for i, component in enumerate(best_model.assembly_pathway[:5]):  # Show first 5 components
        if hasattr(component, 'id'):
            print(f"  Component {i}: {component.id}")
        else:
            print(f"  Component {i}: {type(component).__name__}")
    if len(best_model.assembly_pathway) > 5:
        print(f"  ... and {len(best_model.assembly_pathway) - 5} more components")
else:
    print("No assembly pathway recorded")
print()

# Assembly operations
print("=== Assembly Operations ===")
if hasattr(best_model, 'assembly_operations') and best_model.assembly_operations:
    print(f"Number of assembly operations: {len(best_model.assembly_operations)}")
    for i, op in enumerate(best_model.assembly_operations):
        print(f"  Operation {i}: {op}")
else:
    print("No assembly operations recorded")
print()

# Catalytic information
print("=== Catalytic Information ===")
print(f"Is autocatalytic: {best_model.is_autocatalytic}")
print(f"Catalyzed by: {best_model.catalyzed_by}")
print(f"Catalyzes: {best_model.catalyzes}")
print()

# Layer components assembly complexity
print("=== Layer Components Assembly Complexity ===")
if hasattr(best_model, 'layer_components'):
    for layer_name, component in best_model.layer_components.items():
        if hasattr(component, 'get_minimal_assembly_complexity'):
            complexity = component.get_minimal_assembly_complexity()
            print(f"  {layer_name}: {complexity}")
        else:
            print(f"  {layer_name}: No complexity method available")
print()

# System-level assembly complexity
print("=== System Assembly Complexity ===")
try:
    from models.arcnet import system_assembly_complexity
    sys_complexity = system_assembly_complexity(allmods)
    print(f"Total system assembly complexity: {sys_complexity:.4f}")
except Exception as e:
    print(f"Could not compute system complexity: {e}")
print()

# Position and manifold information
print("=== Manifold Position Information ===")
print(f"Manifold dimension: {best_model.manifold_dim}")
print(f"Position: {best_model.position.data[:5].tolist()}...")  # Show first 5 dimensions
print(f"Curvature: {best_model.curvature}")
if best_model.position_info:
    print(f"Position info: {best_model.position_info}")
print()

# Q-learning information related to assembly
print("=== Q-Learning Assembly Information ===")
print(f"Q-learning method: {best_model.q_learning_method}")
if hasattr(best_model, 'q_function') and best_model.q_function is not None:
    if hasattr(best_model.q_function, 'replay_buffer'):
        print(f"Q-function replay buffer size: {len(best_model.q_function.replay_buffer)}")
    print(f"Q-memory usage: {best_model.get_q_memory_usage():.2f} MB")
else:
    print("No Q-function available")

=== Assembly Information for Best Model ===
Model ID: 2303
Fitness: 0.6369
Created at step: 14
Parent ID: 603853

=== Assembly Complexity ===
Assembly Complexity: {'fc1': 7, 'fc2': 7, 'fc3': 7, 'total': 21}
Assembly Index: 21
Assembly Steps: 10
Copy Number: 1

=== Assembly Pathway ===
Assembly pathway length: 3
  Component 0: e703bdbacd454d9fac3f39976b9cdc5f
  Component 1: b26159087e584451b04ea1deb36dac0e
  Component 2: 87fac7284d6246bfb1f60a7afaf64b1d

=== Assembly Operations ===
No assembly operations recorded

=== Catalytic Information ===
Is autocatalytic: False
Catalyzed by: [603853]
Catalyzes: [122520, 433007, 343194, 770158, 80250, 755009, 257912, 247879, 860471, 364519, 90196, 868033, 749290, 848530, 799995, 733218, 105002, 701007, 41750, 418632, 31236, 648841, 658639, 129093, 6128, 679219, 6403, 573258, 619484, 421307, 524914, 428910, 768327, 182826]

=== Layer Components Assembly Complexity ===
  fc1: 7
  fc2: 7
  fc3: 7

=== System Assembly Complexity ===
Total system assemb

In [None]:
from pyvis.network import Network
import networkx as nx

def plot_comprehensive_assembly_lineage(best_model, lineagesnap, assembly_registry=None):
    """
    Creates a comprehensive visualization showing the full assembly lineage
    with proper component reuse tracking and hierarchical organization.
    Uses all available information in assembly_registry if provided.
    """
    
    # First, let's gather ALL modules in the lineage and organize by generation
    all_modules = {}  # generation -> list of modules
    lineage_modules = []
    
    # Traverse the lineage
    current = best_model
    visited = set()
    while current is not None and getattr(current, 'id', None) is not None and current.id not in visited:
        lineage_modules.append(current)
        visited.add(current.id)
        generation = getattr(current, 'created_at', 0)
        if generation not in all_modules:
                all_modules[generation] = []
        all_modules[generation].append(current)
        
        parent_id = getattr(current, 'parent_id', None)
        if parent_id is not None and parent_id in lineagesnap:
            current = lineagesnap[parent_id]
        else:
                break

    print(f"Found {len(lineage_modules)} modules across {len(all_modules)} generations")
    
    # Now build a comprehensive component graph that tracks actual reuse
    G = nx.DiGraph()
    component_info = {}  # component_id -> {module, generation, layer, complexity, etc.}
    component_hierarchy = {}  # track parent-child relationships
    
    # If assembly_registry is provided, use it to enrich component_info and graph
    if assembly_registry is not None:
        for comp_id, comp_data in assembly_registry.component_registryitems():
            # comp_data could be a dict or an object, try to extract info
            if isinstance(comp_data, dict):
                info = comp_data.copy()
                component = info.get('component', None)
            else:
                info = {}
                component = comp_data
            # Try to extract module, generation, complexity, etc.
            module = getattr(component, 'module', None) or info.get('module', None)
            generation = getattr(component, 'generation', None) or info.get('generation', None)
            if generation is None and module is not None:
                generation = getattr(module, 'created_at', 0)
            complexity = getattr(component, 'assembly_complexity', None) or info.get('assembly_complexity', 0)
            pathway_index = getattr(component, 'pathway_index', None) or info.get('pathway_index', None)
            layer_info = getattr(component, 'layer_info', None) or info.get('layer_info', None)
            # Compose info
            component_info[comp_id] = {
                'module': module,
                'generation': generation,
                'component': component,
                'pathway_index': pathway_index,
                'assembly_complexity': complexity,
                'layer_info': layer_info
            }
            G.add_node(comp_id)
            # Add parent relationships if available
            parents = getattr(component, 'parents', None) or info.get('parents', [])
            if parents:
                for parent in parents:
                    parent_id = getattr(parent, 'id', str(parent))
                    if parent_id != comp_id:
                        G.add_edge(parent_id, comp_id)
                        if comp_id not in component_hierarchy:
                            component_hierarchy[comp_id] = []
                        component_hierarchy[comp_id].append(parent_id)
    else:
        # Process each module's assembly pathway
        for mod in lineage_modules:
            generation = getattr(mod, 'created_at', 0)
            mod_id = getattr(mod, 'id', 'unknown')
            
            print(f"\nProcessing Module {mod_id} (Gen {generation}):")
            
            # Check assembly pathway
            if hasattr(mod, 'assembly_pathway') and mod.assembly_pathway:
                print(f"  Assembly pathway length: {len(mod.assembly_pathway)}")
                
                for idx, comp in enumerate(mod.assembly_pathway):
                    comp_id = getattr(comp, 'id', f'{mod_id}_comp_{idx}')
                    
                    # Store comprehensive component information
                    component_info[comp_id] = {
                        'module': mod,
                        'generation': generation,
                        'component': comp,
                        'pathway_index': idx,
                        'assembly_complexity': getattr(comp, 'assembly_complexity', 0),
                        'layer_info': f"Component {idx}"
                    }
                    
                    G.add_node(comp_id)
                    print(f"    Component {idx}: {comp_id} (complexity: {getattr(comp, 'assembly_complexity', 'N/A')})")
                    
                    # Track parent relationships
                    if hasattr(comp, 'parents') and comp.parents:
                        for parent in comp.parents:
                            parent_id = getattr(parent, 'id', str(parent))
                            if parent_id != comp_id:  # Avoid self-loops
                                G.add_edge(parent_id, comp_id)
                                print(f"      -> Parent: {parent_id}")
                                
                                # Store hierarchy info
                                if comp_id not in component_hierarchy:
                                    component_hierarchy[comp_id] = []
                                component_hierarchy[comp_id].append(parent_id)
            
            # Also check layer components for additional assembly info
            if hasattr(mod, 'layer_components'):
                print(f"  Layer components: {list(mod.layer_components.keys())}")
                for layer_name, layer_comp in mod.layer_components.items():
                    if hasattr(layer_comp, 'get_minimal_assembly_complexity'):
                        complexity = layer_comp.get_minimal_assembly_complexity()
                        print(f"    {layer_name}: complexity {complexity}")

    print(f"\nTotal components in graph: {G.number_of_nodes()}")
    print(f"Total edges (reuse relationships): {G.number_of_edges()}")
    
    # Identify different types of components
    final_components = set()
    if hasattr(best_model, 'assembly_pathway') and best_model.assembly_pathway:
        for comp in best_model.assembly_pathway:
            comp_id = getattr(comp, 'id', None)
            if comp_id:
                final_components.add(comp_id)
    
    # Find root components (no parents) and leaf components (no children)
    root_components = [n for n in G.nodes() if G.in_degree(n) == 0]
    leaf_components = [n for n in G.nodes() if G.out_degree(n) == 0]
    
    print(f"Root components (no parents): {len(root_components)}")
    print(f"Leaf components (no children): {len(leaf_components)}")
    print(f"Final model components: {len(final_components)}")
    
    # Find longest paths to understand assembly chains
    longest_paths = []
    for root in root_components:
        for leaf in leaf_components:
            try:
                path = nx.shortest_path(G, root, leaf)
                longest_paths.append((len(path), path))
            except nx.NetworkXNoPath:
                continue
    
    longest_paths.sort(reverse=True)
    if longest_paths:
        print(f"Longest assembly chain: {longest_paths[0][0]} components")
        print(f"Path: {' -> '.join(longest_paths[0][1][:5])}{'...' if len(longest_paths[0][1]) > 5 else ''}")

    # Create enhanced PyVis visualization
    net = Network(height="900px", width="100%", directed=True, notebook=True)
    net.barnes_hut()
    
    # Enhanced options for better visualization
    net.set_options("""
    var options = {
      "nodes": {
        "font": {
          "size": 14,
          "color": "black",
          "face": "Arial"
        },
        "borderWidth": 2,
        "scaling": {
          "min": 20,
          "max": 80
        }
      },
      "edges": {
        "color": {
          "color": "#666666",
          "highlight": "#333333"
        },
        "arrows": {
          "to": {
            "enabled": true,
            "scaleFactor": 1
          }
        },
        "smooth": {
          "enabled": true,
          "type": "dynamic",
          "roundness": 0.5
        },
        "width": 2
      },
      "layout": {
        "hierarchical": {
          "enabled": true,
          "direction": "LR",
          "sortMethod": "directed",
          "shakeTowards": "leaves"
        }
      },
      "physics": {
        "enabled": true,
        "hierarchicalRepulsion": {
          "centralGravity": 0.3,
          "springLength": 100,
          "springConstant": 0.01,
          "nodeDistance": 120,
          "damping": 0.09
        }
      },
      "interaction": {
        "hover": true,
        "selectConnectedEdges": true,
        "navigationButtons": true,
        "keyboard": true
      }
    }
    """)

    # Add nodes with comprehensive information and size based on complexity
    for node in G.nodes():
        info = component_info.get(node, {})
        generation = info.get('generation', '?')
        complexity = info.get('assembly_complexity', 0)
        
        # Create detailed label and tooltip
        label = f"G{generation}\nC{complexity}"
        
        tooltip_parts = [
            f"<b>Component ID:</b> {node}",
            f"<b>Generation:</b> {generation}",
            f"<b>Assembly Complexity:</b> {complexity}",
            f"<b>Pathway Index:</b> {info.get('pathway_index', 'N/A')}"
        ]
        
        if 'module' in info and info['module'] is not None:
            mod = info['module']
            tooltip_parts.extend([
                f"<b>Module ID:</b> {getattr(mod, 'id', 'N/A')}",
                f"<b>Module Fitness:</b> {getattr(mod, 'fitness', 'N/A'):.4f}" if hasattr(mod, 'fitness') else "<b>Module Fitness:</b> N/A"
            ])
            
            # Add layer components info
            if hasattr(mod, 'layer_components'):
                tooltip_parts.append("<br><b>Layer Components:</b>")
                for layer_name, layer_comp in mod.layer_components.items():
                    if hasattr(layer_comp, 'get_minimal_assembly_complexity'):
                        layer_complexity = layer_comp.get_minimal_assembly_complexity()
                        tooltip_parts.append(f"• {layer_name}: {layer_complexity}")

        # Add parent/child information
        parents = component_hierarchy.get(node, [])
        children = list(G.successors(node))
        if parents:
            tooltip_parts.append(f"<b>Parents:</b> {', '.join(parents[:3])}" + ("..." if len(parents) > 3 else ""))
        if children:
            tooltip_parts.append(f"<b>Children:</b> {', '.join(children[:3])}" + ("..." if len(children) > 3 else ""))
        
        title = "<br>".join(tooltip_parts)
        
        # Determine node appearance based on role and complexity
        size = max(20, min(60, 20 + complexity * 2))  # Size based on complexity
        
        if node in final_components:
            # Final components - red
            color = {"background": "#ff4444", "border": "#cc0000", "highlight": {"background": "#ff6666", "border": "#ff0000"}}
        elif node in root_components:
            # Root components - green
            color = {"background": "#44ff44", "border": "#00cc00", "highlight": {"background": "#66ff66", "border": "#00ff00"}}
        elif complexity > 10:
            # High complexity intermediate - purple
            color = {"background": "#8844ff", "border": "#6600cc", "highlight": {"background": "#aa66ff", "border": "#8800ff"}}
        else:
            # Regular intermediate - orange
            color = {"background": "#ff9500", "border": "#cc7700", "highlight": {"background": "#ffaa33", "border": "#ff8800"}}
        
        net.add_node(node, label=label, title=title, color=color, size=size)

    # Add edges
    for source, target in G.edges():
        net.add_edge(source, target)

    # Generate and save
    net.show("comprehensive_assembly_lineage.html")
    
    print("🟢 Green = Root components | 🔴 Red = Final components | 🟣 Purple = High complexity | 🟠 Orange = Regular")
    return G, component_info, longest_paths

# Usage:
# If you have assembly_registry, pass it as a third argument:
assembly_graph, comp_info, chains = plot_comprehensive_assembly_lineage(best_model, lineagesnap, assembly_registry)


Found 7 modules across 7 generations


AttributeError: 'AssemblyTrackingRegistry' object has no attribute 'items'