# C.A.S.T.L.E.

#### Component Assembly Structure Tracking for Learning Emergence

Castle is a neural network that tracks its own assembly structure. It partitions weight tensors into `molecular blocks` and assigns each block an `atomic code` based on block statistics. The resulting `molecular lattice` encodes the network's structural motifs, from which an assembly index quantifying **modularity and reuse** is computed. During training, the `loss function` incorporates a `complexity reward` (or penalty) derived from the assembly index, and gradients are modulated according to molecular reuse. If the assembly index exceeds a threshold, the architecture is evolved toward more efficient or interpretable structures. The approach connects neural network optimization with principles from assembly theory, with the aim of promoting modularity, interpretability, and adaptive architectural evolution.

***

## Mathematical Model of Assembly Tracking Neural Network with Atomic Codes

## 1. Weight Tensor Partitioning

Let $W \in \mathbb{R}^{m \times n}$ be the weight matrix of a neural network layer, where:
- $m$: number of input features,
- $n$: number of output features.

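Partition $W$ into non-overlapping blocks (molecules) of size $k \times k$. The union below covers all of $W$ only when $k$ divides both $m$ and $n$; otherwise the floors in $M$ and $N$ leave the trailing rows and columns outside every block: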

$$W = \bigcup_{i=1}^{M} \bigcup_{j=1}^{N} W_{i,j}$$

where:
- $W_{i,j}$: block (molecule) at position $(i, j)$,
- $M = \lfloor m/k \rfloor$: number of blocks along rows,
- $N = \lfloor n/k \rfloor$: number of blocks along columns,
- $k$: molecule (block) size.
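
The partitioning step can be sketched with a NumPy reshape. This is a minimal illustration, not the project's own code: the function name `partition_into_molecules` is hypothetical, and dropping trailing rows and columns that do not fill a whole block is an assumption that follows from the floor in $M$ and $N$.

```python
import numpy as np

def partition_into_molecules(W, k):
    """Split a 2-D weight matrix into non-overlapping k x k blocks.

    Returns an array of shape (M, N, k, k) with M = m // k and N = n // k.
    Assumption: rows/columns beyond M*k and N*k are dropped.
    """
    m, n = W.shape
    M, N = m // k, n // k
    trimmed = W[:M * k, :N * k]
    # Reshape to (M, k, N, k), then bring the two block indices to the front.
    return trimmed.reshape(M, k, N, k).swapaxes(1, 2)

W = np.arange(24, dtype=float).reshape(4, 6)
blocks = partition_into_molecules(W, k=2)
print(blocks.shape)   # (2, 3, 2, 2)
print(blocks[0, 1])   # rows 0-1, columns 2-3 of W
```

Here `blocks[i, j]` plays the role of $W_{i,j}$, with zero-based indices.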

## 2. Atomic Code Assignment

For each molecule $W_{i,j}$, compute its average value:

$$\mu_{i,j} = \frac{1}{k^2} \sum_{a=1}^{k} \sum_{b=1}^{k} W_{i,j}[a, b]$$

where:
- $\mu_{i,j}$: average value of molecule $W_{i,j}$,
- $a, b$: indices within the block.

Assign an atomic code $S_{i,j}$ using a discretization function $f$:

$$S_{i,j} = f(\mu_{i,j})$$

where:
- $S_{i,j}$: atomic code assigned to molecule $W_{i,j}$,
- $f$: function mapping average value to a symbol (e.g., 'Ze', 'Sm', 'Md').
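
A possible shape for the discretization function $f$ is sketched below. The text names the symbols but not the rule that produces them, so the thresholds in `CODE_BINS`, the use of the absolute mean, and the reading of 'Ze', 'Sm', 'Md' as zero, small, and medium magnitude are all assumptions made for illustration.

```python
import numpy as np

# Hypothetical cut-offs on |mu|; the text does not specify the thresholds.
CODE_BINS = [(0.01, "Ze"), (0.1, "Sm"), (float("inf"), "Md")]

def atomic_code(mu):
    """Discretization f: return the first symbol whose upper bound exceeds |mu|."""
    for upper, symbol in CODE_BINS:
        if abs(mu) < upper:
            return symbol

def assign_codes(blocks):
    """blocks has shape (M, N, k, k); returns an (M, N) array of atomic codes."""
    means = blocks.mean(axis=(2, 3))  # mu_{i,j} for every block
    return np.array([[atomic_code(mu) for mu in row] for row in means])

# Three 2 x 2 blocks filled with the constants 0.0, 0.05 and 0.5
blocks = np.stack([np.full((2, 2), v) for v in (0.0, 0.05, 0.5)])[np.newaxis]
codes = assign_codes(blocks)
print(codes)  # [['Ze' 'Sm' 'Md']]
```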

## 3. Lattice Construction

The set of atomic codes forms a molecular lattice $L$:

$$L = \{ S_{i,j} \mid 1 \leq i \leq M, 1 \leq j \leq N \}$$

where:
- $L$: molecular lattice,
- $S_{i,j}$: atomic code at position $(i, j)$.

## 4. Assembly Index Calculation

Define the assembly index $A(L)$ as the minimal number of assembly steps required to construct $L$, where each unique atomic code and each reused sub-lattice counts as one step:

$$A(L) = \min \left( \text{number of steps to assemble } L \text{ from atomic codes and reused sub-lattices} \right)$$

where:
- $A(L)$: assembly index of lattice $L$,
- "steps": count of unique codes and reused patterns needed for construction.
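
Finding the true minimum over all construction orders is combinatorially expensive, so a cheap stand-in can be useful while experimenting. The sketch below is one such proxy and rests on a stated assumption: each distinct code costs one step, and each distinct row pattern costs one step to build once and is then reused for free. The name `assembly_index_proxy` is hypothetical, and the value it returns is not the minimal assembly index.

```python
def assembly_index_proxy(lattice):
    """Crude proxy for A(L): distinct codes plus distinct row patterns.

    This is an easy-to-compute stand-in, not the minimal step count.
    """
    unique_codes = {code for row in lattice for code in row}
    unique_rows = {tuple(row) for row in lattice}
    return len(unique_codes) + len(unique_rows)

repetitive = [["Ze", "Sm"], ["Ze", "Sm"], ["Ze", "Sm"]]
varied = [["Ze", "Sm"], ["Md", "Ze"], ["Sm", "Md"]]
print(assembly_index_proxy(repetitive))  # 3
print(assembly_index_proxy(varied))      # 6
```

A lattice that repeats one row scores lower than one of the same size with three different rows, which matches the intent that reuse lowers the index.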

## 5. Complexity Reward in Loss Function

During training, modify the loss function to reward or penalize assembly complexity:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{base}} - \lambda \cdot R(A(L))$$

where:
- $\mathcal{L}_{\text{total}}$: total loss,
- $\mathcal{L}_{\text{base}}$: standard loss (e.g., binary cross-entropy),
- $R(A(L))$: reward function based on assembly index,
- $\lambda$: hyperparameter controlling reward strength.
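
As a minimal sketch of the formula above, the reward can be any function of the assembly index. The choice $R(A) = 1 / (1 + A / \text{scale})$ below is an assumption made for illustration (it favors low indices), and the names `complexity_reward` and `total_loss` are hypothetical.

```python
def complexity_reward(assembly_index, scale=10.0):
    """Hypothetical reward R(A): equals 1 at A = 0 and decays toward 0 as A grows."""
    return 1.0 / (1.0 + assembly_index / scale)

def total_loss(base_loss, assembly_index, lam=0.1):
    """L_total = L_base - lambda * R(A(L))."""
    return base_loss - lam * complexity_reward(assembly_index)

low = total_loss(0.50, assembly_index=0)    # full reward subtracted
high = total_loss(0.50, assembly_index=30)  # smaller reward subtracted
print(low < high)  # True
```

With this choice of $R$, a lattice with a lower assembly index yields a lower total loss for the same base loss.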

## 6. Gradient Modification

Scale gradients for each molecule according to reuse and complexity:

$$\nabla W_{i,j} \leftarrow \nabla W_{i,j} \cdot g(S_{i,j}, A(L))$$

where:
- $\nabla W_{i,j}$: gradient for molecule $W_{i,j}$,
- $g$: modifier function (e.g., reduces updates for highly reused molecules),
- $S_{i,j}$: atomic code,
- $A(L)$: assembly index.
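
One way to realize the modifier $g$ is to damp updates in proportion to how common a block's code is. The sketch below assumes exactly that, $g = 1 - \text{damping} \cdot n_s / |\mathcal{S}|$, and leaves out the dependence on $A(L)$. The function name and the `damping` parameter are hypothetical.

```python
import numpy as np
from collections import Counter

def scale_block_gradients(grad, codes, k, damping=0.5):
    """Scale each k x k gradient block by g = 1 - damping * (share of its code).

    grad:  array of shape (M*k, N*k); codes: nested list of shape (M, N).
    Assumption: g depends only on code frequency, not on A(L).
    """
    counts = Counter(code for row in codes for code in row)
    total = sum(counts.values())
    scaled = grad.copy()
    for i, row in enumerate(codes):
        for j, code in enumerate(row):
            g = 1.0 - damping * counts[code] / total
            scaled[i * k:(i + 1) * k, j * k:(j + 1) * k] *= g
    return scaled

grad = np.ones((4, 4))
codes = [["Ze", "Ze"], ["Ze", "Md"]]
scaled = scale_block_gradients(grad, codes, k=2)
print(scaled[0, 0], scaled[3, 3])  # 0.625 0.875
```

The three 'Ze' blocks, which share the most common code, receive the smaller update.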

## 7. Architecture Evolution

If the assembly index exceeds a threshold $T$, suggest architectural changes:

$$\text{If } A(L) > T, \text{ then modify layer sizes or add/remove layers}$$

where:
- $T$: complexity threshold for architectural evolution,
- $A(L)$: assembly index.

## 8. System Complexity

### 8.1 Assembly Theory Metric

$$A_{\text{sys}}^{(t)} = \frac{1}{|\mathcal{P}|} \sum_{i=1}^{|\mathcal{P}|} e^{a_i} \cdot \frac{n_i - 1}{|\mathcal{P}|}$$

where:
- $A_{\text{sys}}^{(t)}$: System assembly complexity at time $t$  
- $|\mathcal{P}|$: Population size  
- $a_i$: Assembly complexity of module $i$  
- $n_i$: Copy number of module type $i$  

## 9. Refined Assembly Theory Metric for Neural Networks

To better capture the modularity and reuse in neural networks, we refine the assembly theory metric by explicitly accounting for the diversity and recurrence of molecular patterns (atomic codes) in the lattice:

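Let $\mathcal{S} = \{ S_{i,j} \}$ be the collection of atomic codes in lattice $L$, counted with multiplicity so that $|\mathcal{S}|$ is the total number of molecules. Let $u_s$ be the number of unique codes, $n_s$ the copy number of code $s$, and $a_s$ the assembly complexity of code $s$. Sums over $s \in \mathcal{S}$ run over the distinct codes.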

Define the refined system assembly complexity as:

$$A_{\text{sys}}^{(t)} = \frac{1}{u_s} \sum_{s \in \mathcal{S}} e^{a_s} \cdot \frac{n_s - 1}{|\mathcal{S}|}$$

where:
- $u_s$: number of unique atomic codes in $L$,
- $n_s$: number of occurrences of code $s$ in $L$,
- $a_s$: assembly complexity of code $s$ (e.g., minimal steps to construct $s$ from primitives),
- $|\mathcal{S}|$: total number of molecules in $L$.

**Interpretation:**  
This metric rewards reuse (high $n_s$) and penalizes diversity (high $u_s$), while weighting by the intrinsic complexity $a_s$ of each code. The exponential term amplifies the impact of complex codes, and the normalization by $u_s$ and $|\mathcal{S}|$ ensures comparability across networks.
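
The refined metric can be computed directly from a lattice of codes. The sketch below reads the sum as running over distinct codes, which is an interpretation of the formula rather than something the text states outright. The per-code complexities in `a` are hypothetical values chosen for illustration.

```python
import math
from collections import Counter

def refined_system_complexity(lattice, code_complexity):
    """A_sys = (1 / u_s) * sum over distinct codes s of exp(a_s) * (n_s - 1) / |S|."""
    codes = [code for row in lattice for code in row]
    counts = Counter(codes)   # n_s for each distinct code
    u_s = len(counts)         # number of distinct codes
    total = len(codes)        # |S|, total number of molecules
    weighted = sum(math.exp(code_complexity[s]) * (n - 1) / total
                   for s, n in counts.items())
    return weighted / u_s

# Hypothetical assembly complexities a_s for each code
a = {"Ze": 0.0, "Sm": 1.0, "Md": 2.0}

all_reused = refined_system_complexity([["Ze", "Ze"], ["Ze", "Ze"]], a)
no_reuse = refined_system_complexity([["Ze", "Sm", "Md"]], a)
print(all_reused)  # 0.75
print(no_reuse)    # 0.0
```

A lattice in which every code occurs once scores zero because each $n_s - 1$ term vanishes, while a lattice built from a single repeated code scores $e^{a_s}(n_s - 1)/|\mathcal{S}|$.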

**Application in Training:**  
- Use $A_{\text{sys}}^{(t)}$ as a regularizer in the loss function to promote modularity and efficient reuse.
- Track $A_{\text{sys}}^{(t)}$ over epochs to monitor the evolution of network complexity and modularity.
- Suggest architectural changes if $A_{\text{sys}}^{(t)}$ exceeds a threshold, indicating excessive diversity or insufficient reuse.

## 10. Complexity-Accuracy Relationship

Let $\mathcal{A}(t)$ denote the model accuracy at time $t$, and $A_{\text{sys}}^{(t)}$ represent the system assembly complexity. The relationship between accuracy and complexity can be modeled as:

$$\mathcal{A}(t) = \mathcal{A}_{\text{base}} + \gamma \cdot \left(1 - e^{-\beta \cdot A_{\text{sys}}^{(t)}}\right)$$

where:
- $\mathcal{A}_{\text{base}}$: base accuracy achievable with minimal complexity,
- $\gamma$: maximum potential accuracy improvement from complexity,
- $\beta$: rate parameter controlling how quickly complexity benefits manifest.

This formulation captures three key properties:
1. As $A_{\text{sys}}^{(t)} \to 0$, $\mathcal{A}(t) \to \mathcal{A}_{\text{base}}$,
2. As $A_{\text{sys}}^{(t)} \to \infty$, $\mathcal{A}(t) \to \mathcal{A}_{\text{base}} + \gamma$,
3. The marginal benefit of additional complexity diminishes as complexity increases, reflecting the principle of diminishing returns.

### Alternative Formulation for Small Complexity Ranges

For practical application with small ranges of complexity, a linear approximation may be suitable:

$$\mathcal{A}(t) \approx \mathcal{A}_{\text{min}} + \delta \cdot A_{\text{sys}}^{(t)}$$

where:
- $\mathcal{A}_{\text{min}}$: minimum accuracy with zero complexity,
- $\delta$: linear rate of accuracy improvement with complexity.
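
The linear form can be related to the exponential model by a first-order Taylor expansion. This is a short derivation under the assumption that $\beta \cdot A_{\text{sys}}^{(t)} \ll 1$:

$$1 - e^{-\beta A_{\text{sys}}^{(t)}} = \beta A_{\text{sys}}^{(t)} - \frac{\left(\beta A_{\text{sys}}^{(t)}\right)^2}{2} + O\!\left(\left(\beta A_{\text{sys}}^{(t)}\right)^3\right)$$

Keeping only the first term gives

$$\mathcal{A}(t) \approx \mathcal{A}_{\text{base}} + \gamma \beta \cdot A_{\text{sys}}^{(t)}$$

so under this assumption the two parameter sets are related by $\mathcal{A}_{\text{min}} = \mathcal{A}_{\text{base}}$ and $\delta = \gamma \beta$, and the leading error of the linear form is $\gamma \left(\beta A_{\text{sys}}^{(t)}\right)^2 / 2$.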

## 11. Empirical Validation

This relationship can be validated by:
1. Measuring $A_{\text{sys}}^{(t)}$ and $\mathcal{A}(t)$ at multiple training epochs,
2. Fitting the model parameters ($\mathcal{A}_{\text{base}}$, $\gamma$, $\beta$) or ($\mathcal{A}_{\text{min}}$, $\delta$),
3. Testing predictive power on held-out model configurations.
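
The fitting step can be sketched with NumPy alone. For a fixed $\beta$ the exponential model is linear in $\mathcal{A}_{\text{base}}$ and $\gamma$, so the sketch scans a grid of $\beta$ values and solves a least-squares problem for each; a general nonlinear solver such as `scipy.optimize.curve_fit` would be a more usual choice. The function name, the grid, and the synthetic data are assumptions for illustration, not measured results.

```python
import numpy as np

def fit_exponential_model(complexity, accuracy, betas):
    """Fit A = A_base + gamma * (1 - exp(-beta * c)) by scanning beta over a grid.

    For each beta, (A_base, gamma) come from linear least squares; the beta
    with the smallest sum of squared residuals is returned.
    """
    best = None
    for beta in betas:
        design = np.column_stack([np.ones_like(complexity),
                                  1.0 - np.exp(-beta * complexity)])
        coef, *_ = np.linalg.lstsq(design, accuracy, rcond=None)
        sse = float(np.sum((design @ coef - accuracy) ** 2))
        if best is None or sse < best[0]:
            best = (sse, coef[0], coef[1], beta)
    _, a_base, gamma, beta = best
    return a_base, gamma, beta

# Synthetic, noise-free measurements (illustrative values only)
c = np.linspace(0.0, 5.0, 20)
acc = 0.60 + 0.30 * (1.0 - np.exp(-0.8 * c))

a_base, gamma, beta = fit_exponential_model(c, acc, betas=np.linspace(0.1, 2.0, 20))
print(round(a_base, 3), round(gamma, 3), round(beta, 3))
```

On real measurements the grid resolution limits the precision of $\beta$, and held-out configurations should be scored with the fitted parameters rather than refitted.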

***

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, confusion_matrix, classification_report
from sklearn.datasets import load_breast_cancer, load_iris, load_wine

import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Project-local modules
from core.trainer import train_complexity_aware, train_basic, train_reward_complexity
from core.analyzer import validate_complexity_accuracy_relationship, analyze_complexity_accuracy_relationship
from visualize.lattice_builder import plot_pyvis_3d_lattice_interactive

# Suppress warnings for cleaner output and fix seeds for reproducibility
warnings.filterwarnings("ignore")
torch.manual_seed(42)
np.random.seed(42)
```

```python
# Data parameters
test_size = 0.2         # Proportion of each dataset held out for the test split (the loaders also default to 0.2)
random_state = 42

# Data loading functions
def load_breast_cancer_data(test_size=0.2, random_state=42):
    # Load the breast cancer dataset
    data = load_breast_cancer()
    X, y = data.data, data.target
    # Split into training and testing sets using the caller's arguments
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return X_train, X_test, y_train, y_test

def load_iris_data(test_size=0.2, random_state=42):
    # Load the iris dataset
    data = load_iris()
    X, y = data.data, data.target
    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return X_train, X_test, y_train, y_test

def load_wine_data(test_size=0.2, random_state=42):
    # Load the wine dataset
    data = load_wine()
    X, y = data.data, data.target
    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return X_train, X_test, y_train, y_test

def load_mnist_data():
    # Load the MNIST dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # One-hot encode the labels
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)
    # Scale pixel values to [0, 1] and flatten each image into a feature vector
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
    x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))
    return x_train, x_test, y_train, y_test

def class_labels(y):
    # Unique class labels; one-hot targets (MNIST) are decoded with argmax first
    y = np.asarray(y)
    return np.unique(y.argmax(axis=1) if y.ndim > 1 else y)

# Load all datasets
try:
    X_train_cancer, X_test_cancer, y_train_cancer, y_test_cancer = load_breast_cancer_data(test_size, random_state)
    X_train_iris, X_test_iris, y_train_iris, y_test_iris = load_iris_data(test_size, random_state)
    X_train_wine, X_test_wine, y_train_wine, y_test_wine = load_wine_data(test_size, random_state)
    X_train_mnist, X_test_mnist, y_train_mnist, y_test_mnist = load_mnist_data()
    print("Data loaded successfully.")
except Exception as e:
    print(f"An error occurred while loading the data: {e}")

# Summarize each dataset's shape, classes, and number of features
train_sets = [X_train_cancer, X_train_iris, X_train_wine, X_train_mnist]
train_labels = [class_labels(y) for y in (y_train_cancer, y_train_iris, y_train_wine, y_train_mnist)]
(
    pd.DataFrame({
        "Dataset": ["Breast Cancer", "Iris", "Wine", "MNIST"],
        "Shape of Training Data": [X.shape for X in train_sets],
        "Number of Classes": [len(labels) for labels in train_labels],
        "Number of Features": [X.shape[1] for X in train_sets],
        "Classes": train_labels,
    })
    .set_index("Dataset").T
    .applymap(lambda x: x if isinstance(x, str) else str(x))
    .style.set_properties(**{'text-align': 'left'})
    .set_table_attributes('style="width: 100%; border-collapse: collapse;"')
    .set_caption("Dataset Information")
    .format(na_rep='N/A')
)
```

Data loaded successfully.

| Dataset | Breast Cancer | Iris | Wine | MNIST |
|---|---|---|---|---|
| Shape of Training Data | (455, 30) | (120, 4) | (142, 13) | (60000, 784) |
| Number of Classes | 2 | 3 | 3 | 10 |
| Number of Features | 30 | 4 | 13 | 784 |
| Classes | [0 1] | [0 1 2] | [0 1 2] | [0 1 2 3 4 5 6 7 8 9] |


***

## Testing and Validation

The exponential and linear accuracy-complexity models defined in Section 10 translate directly into code:

>```python
>def exponential_model(complexity, A_base, gamma, beta):
>    """
>    Exponential model for accuracy-complexity relationship:
>    A(t) = A_base + γ * (1 - exp(-β * A_sys(t)))
>    """
>    return A_base + gamma * (1 - np.exp(-beta * complexity))
>
>def linear_model(complexity, A_min, delta):
>    """
>    Linear model for accuracy-complexity relationship:
>    A(t) ≈ A_min + δ * A_sys(t)
>    """
>    return A_min + delta * complexity
>```