<a href="https://colab.research.google.com/github/Arian-Space/Cortex_V0.1_proyect/blob/main/Bio_Inspired_Modular_Architecture_for_Efficient_Memory_and_Reasoning_in_Neural_Networks_Preliminary_Results_of_a_Compact_Prototype_%22Cortex_V0_1%22.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cortex-V0.1: Bio-Inspired Modular Architecture for Efficient Memory & Reasoning

## Compact Prototype on SQuAD (Preliminary Results)

**Author**: Arian Vazquez Fernandez  
**License**: Apache 2.0  
**Date**: January 2026  

This notebook implements the two-phase bio-inspired architecture described in the paper:  
- **Phase 1**: Topological routing + modern Hopfield memory (frozen after training)  
- **Phase 2**: Dynamic reasoning with Mixture of Experts (MoE), with optional ablation (dense MLP)  

All experiments run on a free Colab GPU (T4 recommended). Weights & Biases (wandb) handles logging and sweeps.

### 1. Setup & Dependencies

Install required packages (quiet mode) and import everything at once.

In [None]:
# === 1. Install Dependencies ===
!pip install torch wandb sentence-transformers datasets scikit-learn matplotlib --quiet

# === 2. Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

import wandb
from sentence_transformers import SentenceTransformer
from datasets import load_dataset

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import random
from google.colab import userdata

# Login to Weights & Biases (store your key in Colab secrets)
wandb.login(key=userdata.get('WANDB-KEY'))

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

### 2. Load & Preprocess Data (SQuAD subset)

This preprocessing step is crucial for the bio-inspired separation of concerns:

* Unique contexts serve as "long-term memories" stored in the Hopfield module (Phase 1), mimicking hippocampal pattern completion.

* Expanded answer windows provide richer semantic targets for reasoning (Phase 2), simulating cortical integration of partial recall.

The simple keyword-based classification into "math", "history", and "facts" is not used for supervision during training; it is only collected for post-hoc visualization of emergent expert specialization via PCA. This allows us to observe whether experts naturally cluster by semantic domain without explicit labels, a key indicator of distributed, bio-like representation learning.
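
The classifier in the next cell matches keywords as plain substrings, so `'war'` also fires on "award" and `'who'` on "whole". The sketch below is a hypothetical word-boundary variant (`classify_type_strict` is not part of the notebook) that reuses the same keyword lists; it is offered only to show how much the visualization labels depend on the matching rule.

```python
import re

# Same keyword lists as the notebook's classify_type.
MATH_KEYWORDS = ['how many', 'what is', 'calculate', 'sum', 'difference', 'product',
                 'square', 'root', 'area', 'perimeter', 'angle', 'prime']
HISTORY_KEYWORDS = ['who', 'when', 'year', 'first', 'last', 'president', 'war',
                    'battle', 'discovered', 'invented', 'born', 'died']


def _compile(keywords):
    # \b anchors each keyword at word boundaries, so 'war' no longer matches 'award'.
    return re.compile(r"\b(?:" + "|".join(re.escape(k) for k in keywords) + r")\b")


MATH_RE = _compile(MATH_KEYWORDS)
HISTORY_RE = _compile(HISTORY_KEYWORDS)


def classify_type_strict(question):
    """Hypothetical stricter variant: whole-word keyword matching only."""
    q = question.lower()
    if MATH_RE.search(q):
        return "math"
    if HISTORY_RE.search(q):
        return "history"
    return "facts"


for q in ["What award did she win?", "Who discovered penicillin?"]:
    print(q, "->", classify_type_strict(q))
```

With substring matching the first question would be labeled "history" because of "award"; with word boundaries it falls through to "facts".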

In [None]:
# === Load Dataset ===
print("Loading SQuAD subset (first 10k samples)...")
dataset = load_dataset("squad", split="train[:10000]")

# Phase 1: Unique contexts for memory
contexts = list(dict.fromkeys(item['context'] for item in dataset))  # order-preserving dedupe; set() order varies between sessions
print(f"{len(contexts)} unique contexts for Hopfield memory.")

# Embed contexts (frozen embedder)
embed_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
with torch.no_grad():
    context_embs = embed_model.encode(contexts, batch_size=64, convert_to_tensor=True, device=device)
    context_embs = F.normalize(context_embs, dim=-1)

# Train/eval split for memory patterns
num_train = int(len(context_embs) * 0.8)
context_embs_train = context_embs[:num_train]
context_embs_eval = context_embs[num_train:]

# Phase 2: QA pairs with expanded answers
qa_data = []
for item in dataset:
    answer_text = item['answers']['text'][0]
    context = item['context']
    start = context.find(answer_text)
    if start != -1:
        window_start = max(0, start - 60)
        window_end = min(len(context), start + len(answer_text) + 60)
        long_answer = context[window_start:window_end].strip()
    else:
        long_answer = f"The answer is {answer_text}."
    qa_data.append({
        "question": item['question'],
        "answer": long_answer
    })

# Simple question type classification (for PCA visualization only)
def classify_type(q):
    q = q.lower()
    math_keywords = ['how many', 'what is', 'calculate', 'sum', 'difference', 'product', 'square', 'root', 'area', 'perimeter', 'angle', 'prime']
    history_keywords = ['who', 'when', 'year', 'first', 'last', 'president', 'war', 'battle', 'discovered', 'invented', 'born', 'died']
    if any(kw in q for kw in math_keywords):
        return "math"
    if any(kw in q for kw in history_keywords):
        return "history"
    return "facts"

for d in qa_data:
    d['type'] = classify_type(d['question'])

print(f"QA data ready: {len(qa_data)} examples classified.")

### 3. Model Definitions

#### 3.1 Phase 1: Topological Routing + Modern Hopfield Memory

The routing layers (input_proj → gray → brown → blue_proj) gradually transform the raw question embedding into a query optimized for associative retrieval. This multi-stage processing with a residual connection (gray) and a GRU refinement stage (brown, applied here to a length-1 sequence) imitates the hierarchical feature extraction and temporal binding found in biological sensory pathways. The modern Hopfield module itself has no trainable parameters: it acts purely as a content-addressable memory buffer. Modern Hopfield networks store far more patterns than the classical model (polynomial capacity in Krotov & Hopfield, 2016; exponential in the dimension in Ramsauer et al., 2020), which supports pattern completion from noisy or partial cues.
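
As a rough illustration of the retrieval step described above, here is a dependency-free sketch in plain Python so the arithmetic is visible. It follows the same recipe as the notebook's `SimpleModernHopfield.forward` (scaled dot product, softmax, top-k renormalization), but `hopfield_retrieve` and the toy patterns are assumptions made for this example only.

```python
import math


def hopfield_retrieve(query, patterns, beta=20.0, top_k=2):
    """Toy softmax retrieval over stored patterns (illustrative, not the notebook's class)."""
    dim = len(query)
    # Scaled dot-product similarity between the query and every stored pattern.
    logits = [beta * sum(q * p for q, p in zip(query, pat)) / math.sqrt(dim)
              for pat in patterns]
    # Numerically stable softmax over all patterns.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    attn = [e / total for e in exps]
    # Keep the top_k attention weights and renormalize them to sum to 1.
    k = min(top_k, len(patterns))
    top_idx = sorted(range(len(attn)), key=lambda i: attn[i], reverse=True)[:k]
    norm = sum(attn[i] for i in top_idx)
    weights = [attn[i] / norm for i in top_idx]
    # Retrieved vector = weighted sum of the selected patterns.
    retrieved = [sum(w * patterns[i][d] for w, i in zip(weights, top_idx))
                 for d in range(dim)]
    return retrieved, weights, top_idx


# Three stored patterns and a noisy cue that is closest to pattern 0.
patterns = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
noisy_query = [0.9, 0.1, -0.05, 0.2]
retrieved, weights, top_idx = hopfield_retrieve(noisy_query, patterns)
print(top_idx, [round(w, 4) for w in weights])
```

With a large `beta` the softmax is sharply peaked, so the noisy cue is pulled almost entirely onto its nearest stored pattern; a smaller `beta` would blend several patterns instead.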

In [3]:
class SimpleModernHopfield(nn.Module):
    """Modern Hopfield memory module (buffer-based, no trainable params)."""
    def __init__(self, dim=384, max_patterns=10000, beta=20.0):
        super().__init__()
        self.dim = dim
        self.beta = beta
        self.register_buffer('patterns', torch.zeros(max_patterns, dim))
        self.num_stored = 0

    def store(self, new_patterns):
        n = new_patterns.shape[0]
        self.patterns[self.num_stored:self.num_stored + n] = new_patterns.float()
        self.num_stored += n

    def forward(self, query, top_k=64):
        patterns_active = self.patterns[:self.num_stored].float()
        logits = torch.matmul(query.float(), patterns_active.t()) / (self.dim ** 0.5)
        logits = self.beta * logits
        attn = F.softmax(logits, dim=-1)
        effective_k = min(top_k, self.num_stored)
        topk_weights, topk_idx = torch.topk(attn, effective_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        selected = patterns_active[topk_idx]
        retrieved = (topk_weights.unsqueeze(-1) * selected).sum(dim=1)
        return retrieved, topk_weights, topk_idx, logits


class MultimodalModel(nn.Module):
    """Phase 1 full model: Routing layers + Hopfield memory."""
    def __init__(self, embed_dim=384, short_mem_dim=384, hopfield_dim=384, top_k=64, beta=20.0):
        super().__init__()
        self.top_k = top_k
        self.input_proj = nn.Linear(embed_dim, short_mem_dim)
        self.gray = nn.Sequential(
            nn.Linear(short_mem_dim, short_mem_dim * 2),
            nn.ReLU(),
            nn.Linear(short_mem_dim * 2, short_mem_dim),
            nn.Tanh()
        )
        self.brown = nn.GRU(short_mem_dim, short_mem_dim, num_layers=3, batch_first=True, dropout=0.2)
        self.brown_norm = nn.LayerNorm(short_mem_dim)
        self.blue_proj = nn.Linear(short_mem_dim, hopfield_dim)
        self.violet = SimpleModernHopfield(dim=hopfield_dim, max_patterns=len(contexts), beta=beta)

    def forward(self, queries):
        x = self.input_proj(queries)
        mod = self.gray(x)
        x = x + mod
        x_short, _ = self.brown(x.unsqueeze(1))
        x_short = self.brown_norm(x_short.squeeze(1))
        query = self.blue_proj(x_short)
        retrieved, topk_weights, topk_idx, logits = self.violet(query, top_k=self.top_k)
        return retrieved, topk_weights, topk_idx, logits

#### 3.2 Phase 2: Dynamic Reasoning (MoE or Ablation MLP)

The Mixture of Experts layer introduces conditional computation and dynamic sparsity: only a small subset of experts (top-k) is activated per input, which reduces expert compute per input relative to running every expert while still allowing specialization. This mirrors biological cortical columns where different neuronal populations handle distinct cognitive functions. The router learns to assign experts implicitly based on the retrieved memory vector, fostering emergent division of labor. The ablation (dense MLP) serves as a baseline to quantify the benefit of this sparse, modular routing, testing whether specialization truly emerges from the interaction between noisy retrieval and expert gating.
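
To make the sparsity argument concrete, the sketch below counts `nn.Linear` parameters for the layer shapes used in the next cell (dim 384, 16 experts of 4 layers, top-k 4, and a 16-layer dense ablation). The helper names are hypothetical; only weights and biases of linear layers are counted, since GELU and Dropout have no parameters.

```python
def linear_params(in_dim, out_dim):
    # Weight matrix plus bias vector of one linear layer.
    return in_dim * out_dim + out_dim


def moe_param_counts(dim=384, num_experts=16, expert_layers=4, top_k=4):
    """Return (total, active-per-input) parameter counts for the MoE layer."""
    per_layer = linear_params(dim, dim)
    router = linear_params(dim, num_experts)
    total = router + num_experts * expert_layers * per_layer
    # Only top_k experts run for a given input; the router always runs.
    active = router + top_k * expert_layers * per_layer
    return total, active


def ablation_params(dim=384, depth=16):
    """Parameter count of a dense stack of `depth` dim-by-dim linear layers."""
    return depth * linear_params(dim, dim)


total, active = moe_param_counts()
print(f"MoE total:  {total:,}")
print(f"MoE active: {active:,}")
print(f"Ablation:   {ablation_params():,}")
```

Under these defaults the dense ablation is close to the MoE's active parameters per input, while the MoE holds roughly four times as many parameters in total, which is worth keeping in mind when comparing the two.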

In [4]:
class DynamicReasoningLayer(nn.Module):
    """Mixture of Experts layer for dynamic reasoning."""
    def __init__(self, dim=384, num_experts=16, expert_layers=4, top_k=4):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(*[nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Dropout(0.1)) for _ in range(expert_layers)])
            for _ in range(num_experts)
        ])
        self.top_k = top_k

    def forward(self, retrieved):
        gate_scores = F.softmax(self.router(retrieved), dim=-1)
        topk_vals, topk_idx = torch.topk(gate_scores, self.top_k, dim=-1)
        outputs = torch.zeros_like(retrieved)
        for i in range(self.top_k):
            weight = topk_vals[:, i].unsqueeze(-1)
            expert_idx = topk_idx[:, i]
            selected_out = torch.stack([self.experts[idx](retrieved[b:b+1]).squeeze(0) for b, idx in enumerate(expert_idx)])
            outputs += weight * selected_out
        return outputs, gate_scores.detach()


class AblationMLP(nn.Module):
    """Dense MLP ablation: 16 Linear+GELU layers, matching the active depth of the default MoE (top_k=4 x expert_layers=4), not its total parameter count."""
    def __init__(self, dim=384):
        super().__init__()
        self.net = nn.Sequential(*[nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(16)])

    def forward(self, retrieved):
        return self.net(retrieved), None


class Phase2Model(nn.Module):
    """Full Phase 2 model: Frozen Phase 1 + Reasoning layer."""
    def __init__(self, phase1_model, num_experts=16, expert_layers=4, top_k=4, ablation=False):
        super().__init__()
        self.phase1 = phase1_model
        if ablation:
            self.reasoning = AblationMLP()
        else:
            self.reasoning = DynamicReasoningLayer(num_experts=num_experts, expert_layers=expert_layers, top_k=top_k)

    def forward(self, query_emb):
        with torch.no_grad():
            retrieved, _, _, _ = self.phase1(query_emb)
        reasoned, gate_scores = self.reasoning(retrieved)
        return reasoned, gate_scores

### 4. Datasets for Training

Both datasets are designed for efficiency and reproducibility on Colab T4 GPUs:

* Phase 1 uses memory patterns directly (no labels needed beyond self-supervision via cosine + ranking loss).
* Phase 2 uses frozen embeddings to avoid token-level generation overhead, focusing purely on semantic vector alignment — a practical choice for rapid prototyping while still capturing high-level reasoning quality via cosine similarity.

In [5]:
# Phase 1 Dataset (memory patterns with optional noise)
class ContextDataset(Dataset):
    def __init__(self, embs, is_eval=False, noise_level=0.1):
        self.embs = embs
        self.is_eval = is_eval
        self.noise_level = noise_level

    def __len__(self):
        return len(self.embs)

    def __getitem__(self, idx):
        emb = self.embs[idx]
        if self.is_eval:
            emb = emb + self.noise_level * torch.randn_like(emb)
            emb = F.normalize(emb, dim=-1)
        return emb, idx


# Phase 2 Dataset (QA pairs + types)
class QADataset(Dataset):
    def __init__(self, qa_list, embed_model, device):
        self.query_embs = embed_model.encode([d["question"] for d in qa_list], convert_to_tensor=True, device=device)
        self.query_embs = F.normalize(self.query_embs, dim=-1)
        self.answer_embs = embed_model.encode([d["answer"] for d in qa_list], convert_to_tensor=True, device=device)
        self.answer_embs = F.normalize(self.answer_embs, dim=-1)
        self.types = [d["type"] for d in qa_list]

    def __len__(self):
        return len(self.query_embs)

    def __getitem__(self, i):
        return self.query_embs[i], self.answer_embs[i], self.types[i]

### 5. Training Phase 1 (Topological Routing + Memory)

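Training is self-supervised: the model learns to map each context embedding, after it passes through the routing layers, back to its own stored pattern. As written, the training loader passes clean embeddings; `ContextDataset` adds Gaussian noise only when `is_eval=True`, so recall from noisy or partial cues is an evaluation-time setting. The combined loss (cosine similarity + ranking cross-entropy) encourages both accurate retrieval and correct ranking of top candidates. After 30 epochs, the routing layers are frozen, preserving the learned associative pathways while preventing catastrophic interference during Phase 2.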
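
The combined objective can be worked through for a single sample in plain Python. This is a sketch under the assumption that the loss is `(1 - cosine) + 0.1 * cross_entropy`, as in the training cell below; `phase1_loss` is a hypothetical helper, not a function from the notebook.

```python
import math


def phase1_loss(retrieved, target, logits, target_idx, rank_weight=0.1):
    """Single-sample version of the cosine + ranking loss (illustrative)."""
    # Cosine term: 1 - cos(retrieved, target).
    dot = sum(r * t for r, t in zip(retrieved, target))
    norm_r = math.sqrt(sum(r * r for r in retrieved))
    norm_t = math.sqrt(sum(t * t for t in target))
    loss_cos = 1.0 - dot / (norm_r * norm_t)
    # Ranking term: cross-entropy of the memory logits against the true pattern index,
    # computed with a stable log-sum-exp.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    loss_rank = log_sum_exp - logits[target_idx]
    return loss_cos + rank_weight * loss_rank, loss_cos, loss_rank


target = [1.0, 0.0, 0.0]
logits = [9.0, 1.0, -0.5]  # memory similarity scores for three stored patterns
good = phase1_loss([1.0, 0.0, 0.0], target, logits, target_idx=0)
bad = phase1_loss([1.0, 0.0, 0.0], target, logits, target_idx=1)
print("correct index:", good)
print("wrong index:  ", bad)
```

The cosine term only checks the retrieved vector, while the ranking term penalizes logits that favor the wrong stored pattern, so the two terms can disagree when several patterns are near-duplicates.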

In [6]:
def train_phase1_simple():
    """Train Phase 1: Routing layers only (Hopfield is buffer)."""
    wandb.init(project="cortex-squad-phase1-final", name="run-final-dim384")

    model = MultimodalModel().to(device)
    model.violet.store(context_embs_train.to(device))

    loader = DataLoader(ContextDataset(context_embs_train), batch_size=32, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=5e-4)

    for epoch in range(30):
        model.train()
        total_loss = 0
        for queries, target_idx in loader:
            queries = queries.to(device)
            target_idx = target_idx.to(device)
            optimizer.zero_grad()
            retrieved, _, topk_idx, logits = model(queries)
            target = context_embs_train[target_idx].to(device)
            loss_cos = 1 - F.cosine_similarity(retrieved, target).mean()
            loss_rank = 0.1 * F.cross_entropy(logits, target_idx)
            loss = loss_cos + loss_rank
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(loader)
        print(f"Phase 1 Epoch {epoch+1:2d} | Loss: {avg_loss:.4f}")
        wandb.log({"phase1_loss": avg_loss})

    torch.save(model.state_dict(), "phase1_final_dim384.pth")
    print("Phase 1 training completed and saved.")
    wandb.finish()

### 6. Training Phase 2 (Dynamic Reasoning with Sweep)

This phase is the core of the paper's novelty: by freezing Phase 1, we force the reasoning module to develop robustness to imperfect memory recall, a key bio-inspired property. The MoE structure must compensate for noisy or approximate retrieved vectors, which is intended to produce emergent semantic specialization (inspected through the PCA plots). Wandb sweeps systematically explore the hyperparameter space, allowing comparison between MoE variants and the dense MLP ablation. Gate scores are accumulated per question type for visualization only; because no type labels enter the loss, any specialization that appears must arise implicitly.

The PCA plots (logged every 10 epochs) are a highlight: each plot places one point per question type, the mean gate vector accumulated so far, so they track how the routing profiles of the types move from a near-uniform start (epoch 0) toward distinct semantic groups (epoch 40), providing visual evidence of distributed, brain-like organization emerging in a compact model (<15M parameters in total for the default configuration).
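
The points fed to PCA are per-type averages of router outputs. A simplified sketch of that aggregation in plain Python follows; `mean_gate_by_type` is a hypothetical helper, and it averages individual rows, whereas the training cell averages per-batch means, so the two can differ slightly when batches are unevenly mixed.

```python
def mean_gate_by_type(gate_rows, types):
    """Average gate vectors (one row per example) within each question type."""
    sums, counts = {}, {}
    for row, t in zip(gate_rows, types):
        if t not in sums:
            sums[t] = [0.0] * len(row)
            counts[t] = 0
        # Accumulate the gate probabilities expert by expert.
        sums[t] = [s + g for s, g in zip(sums[t], row)]
        counts[t] += 1
    return {t: [s / counts[t] for s in sums[t]] for t in sums}


# Toy router outputs over two experts for three questions.
gate_rows = [[0.7, 0.3], [0.5, 0.5], [0.1, 0.9]]
types = ["math", "math", "history"]
print(mean_gate_by_type(gate_rows, types))
```

Each resulting vector has `num_experts` entries, so with three question types PCA receives at most three points per plot.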

In [7]:
def train_phase2():
    """Train Phase 2: Only reasoning layer (Phase 1 frozen)."""
    with wandb.init() as run:
        config = wandb.config
        run.name = f"lr_{config.lr}-bs_{config.batch_size}-experts_{config.num_experts}-layers_{config.expert_layers}-topk_{config.top_k}-ablation_{config.ablation}"

        # Load frozen Phase 1
        phase1_model = MultimodalModel().to(device)
        phase1_model.load_state_dict(torch.load("phase1_final_dim384.pth", map_location=device))
        phase1_model.eval()
        for p in phase1_model.parameters():
            p.requires_grad = False
        phase1_model.violet.store(context_embs_train.to(device))

        model = Phase2Model(phase1_model,
                            num_experts=config.num_experts,
                            expert_layers=config.expert_layers,
                            top_k=config.top_k,
                            ablation=config.ablation).to(device)

        optimizer = AdamW(model.reasoning.parameters(), lr=config.lr)

        dataset = QADataset(qa_data, embed_model, device)
        loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

        exact_threshold = 0.9
        gate_accum = {"math": [], "history": [], "facts": []}  # List of (num_experts,) arrays

        for epoch in range(50):
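            model.train()
            model.phase1.eval()  # model.train() also flips the frozen Phase 1 to train mode; keep its GRU dropout off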
            total_loss = 0
            total_sim = 0
            total_exact = 0

            for q_emb, a_emb, q_type in loader:
                q_emb, a_emb = q_emb.to(device), a_emb.to(device)
                optimizer.zero_grad()
                reasoned, gate_scores = model(q_emb)
                loss = 1 - F.cosine_similarity(reasoned, a_emb).mean()
                loss.backward()
                optimizer.step()

                sim = F.cosine_similarity(reasoned, a_emb)
                total_loss += loss.item() * q_emb.size(0)
                total_sim += sim.sum().item()
                total_exact += (sim > exact_threshold).sum().item()

                # Safe gate accumulation
                if gate_scores is not None:
                    gate_scores = gate_scores.detach().cpu()
                    for t in set(q_type):
                        mask = [typ == t for typ in q_type]
                        if any(mask):
                            indices = [i for i, m in enumerate(mask) if m]
                            mean_gate = gate_scores[indices].mean(0).numpy()  # (num_experts,)
                            gate_accum[t].append(mean_gate)

            avg_loss = total_loss / len(dataset)
            avg_sim = total_sim / len(dataset)
            exact_rate = total_exact / len(dataset)

            wandb.log({
                "loss": avg_loss,
                "close_answer_mean": avg_sim,
                "exact_match": exact_rate
            })
            print(f"Epoch {epoch+1:2d} | Loss: {avg_loss:.4f} | Close: {avg_sim:.4f} | Exact: {exact_rate:.4f}")

            # PCA visualization every 10 epochs (robust)
            if not config.ablation and epoch % 10 == 0:
                valid_types = [t for t in gate_accum if len(gate_accum[t]) > 0]
                if len(valid_types) >= 2:
                    mean_gates = {}
                    for t in valid_types:
                        gates_array = np.stack(gate_accum[t])  # (n_accumulated_batches, num_experts)
                        mean_gates[t] = np.mean(gates_array, axis=0)
                    points = np.array([mean_gates[t] for t in valid_types])
                    pca = PCA(n_components=2)
                    pca_points = pca.fit_transform(points)

                    fig, ax = plt.subplots(figsize=(8, 6))
                    colors = {"math": "blue", "history": "red", "facts": "green"}
                    for i, t in enumerate(valid_types):
                        ax.scatter(pca_points[i,0], pca_points[i,1], c=colors[t], label=t, s=200)
                        ax.text(pca_points[i,0]+0.02, pca_points[i,1], t, fontsize=14)
                    ax.set_title(f"Expert Specialization - Epoch {epoch}")
                    ax.set_xlabel("PCA Component 1")
                    ax.set_ylabel("PCA Component 2")
                    ax.legend()
                    wandb.log({"expert_specialization_2d": wandb.Image(fig)})
                    plt.close(fig)

        print("Phase 2 training completed.")
        wandb.finish()

### 7. Run Experiments

Workflow summary:

1. Run Phase 1 once → generates a frozen routing + memory checkpoint.
2. Launch the wandb sweep for Phase 2 → the grid spans 486 combinations (3 values for each of 5 hyperparameters × 2 ablation settings); the agent's `count=30` runs a subset of them.
3. Monitor in wandb dashboard: loss curves, close_answer_mean (semantic quality), exact_match, and especially the evolving PCA plots for expert specialization.

This modular, two-phase training + freezing strategy is computationally cheap (fits in free Colab), reproducible, and directly supports the paper's claims of extreme efficiency and cognitive robustness in limited-resource settings.
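
Before launching, it can help to enumerate the grid locally. The sketch below copies the parameter values from the sweep configuration in the next cell and counts combinations with `itertools.product`; it also counts how many are distinct in practice, on the assumption (true of `AblationMLP` as defined above) that ablation runs ignore `num_experts`, `expert_layers`, and `top_k`.

```python
from itertools import product

# Values copied from the sweep configuration below.
parameters = {
    "lr": [1e-3, 5e-4, 1e-4],
    "batch_size": [16, 32, 64],
    "num_experts": [8, 16, 32],
    "expert_layers": [2, 4, 6],
    "top_k": [2, 4, 8],
    "ablation": [False, True],
}

# Full Cartesian grid, one dict per configuration.
grid = [dict(zip(parameters, values)) for values in product(*parameters.values())]


def effective_key(cfg):
    # Ablation runs depend only on lr and batch_size; MoE runs depend on everything.
    if cfg["ablation"]:
        return (cfg["lr"], cfg["batch_size"], True)
    return tuple(cfg.values())


distinct = {effective_key(cfg) for cfg in grid}
# top_k must not exceed num_experts, otherwise torch.topk would fail.
invalid = [c for c in grid if not c["ablation"] and c["top_k"] > c["num_experts"]]

print(len(grid), len(distinct), len(invalid))
```

Since many ablation entries in the grid are duplicates of one another, a separate, smaller sweep for the ablation may give a cleaner comparison within a fixed run budget.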

In [None]:
# === Run Phase 1 (run this once) ===
train_phase1_simple()

# === Sweep configuration for Phase 2 ===
sweep_config = {
    "method": "grid",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "lr": {"values": [1e-3, 5e-4, 1e-4]},
        "batch_size": {"values": [16, 32, 64]},
        "num_experts": {"values": [8, 16, 32]},
        "expert_layers": {"values": [2, 4, 6]},
        "top_k": {"values": [2, 4, 8]},
        "ablation": {"values": [False, True]}  # Include ablation for comparison
    }
}

# Launch sweep (adjust count as needed)
sweep_id = wandb.sweep(sweep_config, project="cortex-fase2-community")
wandb.agent(sweep_id, function=train_phase2, count=30)  # Run 30 configurations