# Bicameral Manifold Training (Colab Edition)

This notebook trains the **SmolLM3-3B** model to differentiate into two distinct manifolds:
1.  **Logic Manifold:** Fractal Dimension $D_H \approx 1.5$ (Correlated, Low Entropy)
2.  **Creative Manifold:** Fractal Dimension $D_H \approx 2.5$ (Expansive, High Entropy)

It uses the **Neural Fractal Estimator (NFE)** critic you trained previously to enforce these topological constraints.

### Prerequisites
1.  You must have `nfe_critic.pt` (the trained NFE critic weights produced by the previous notebook).

## 1. Setup Environment

In [None]:
# Setup Environment & Dependencies
import os
import sys
import importlib.util
from pathlib import Path

# 1. Dynamic Path Setup
# Walk upward from the working directory until a folder containing `core`
# is found; that folder is treated as the project root.
REPO_DIR_NAME = "Bicameral_Manifold_GMoE"
current_dir = Path.cwd()
project_root = None

for parent in [current_dir] + list(current_dir.parents):
    if (parent / "core").exists():
        project_root = parent
        break

if project_root:
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    print(f"Project root found and added to path: {project_root}")
else:
    # No checkout at or above the working directory: clone one, unless an
    # earlier run of this cell already did.
    if not os.path.exists(REPO_DIR_NAME):
        print("Cloning repository...")
        get_ipython().system('git clone https://github.com/angrysky56/Bicameral_Manifold_GMoE.git')

    if os.path.isdir(REPO_DIR_NAME):
        # Enter the checkout even when it already existed (e.g. after a
        # runtime restart reset the working directory); otherwise the later
        # `from core...` import cannot be resolved.
        get_ipython().run_line_magic('cd', REPO_DIR_NAME)
        project_root = Path.cwd()
        # Register the absolute path so the entry stays valid if the working
        # directory changes again later.
        if str(project_root) not in sys.path:
            sys.path.insert(0, str(project_root))
        print(f"Using repository checkout at: {project_root}")
    else:
        print(f"WARNING: '{REPO_DIR_NAME}' is not available (clone failed?). Project imports will not work.")

# 2. Install/Fix Dependencies
if importlib.util.find_spec("tensorflow") is not None:
    print("TensorFlow found. Uninstalling to prevent conflicts with Transformers...")
    get_ipython().run_line_magic('pip', 'uninstall -y tensorflow')

# Probe every package the notebook relies on, not only torch/transformers,
# so that a partially provisioned environment is repaired as well.
required_packages = (
    "torch",
    "transformers",
    "accelerate",
    "safetensors",
    "numpy",
    "tqdm",
    "matplotlib",
)
missing_packages = [
    name for name in required_packages if importlib.util.find_spec(name) is None
]
if missing_packages:
    print(f"Installing missing packages: {', '.join(missing_packages)}")
    # The full list is passed to pip; already-satisfied requirements are no-ops.
    get_ipython().run_line_magic('pip', 'install -q ' + ' '.join(required_packages))

# 3. Check Hardware Acceleration (GPU or TPU)
import torch
device_found = False

# Check TPU (PyTorch XLA)
if importlib.util.find_spec("torch_xla") is not None:
    try:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    except (ImportError, RuntimeError) as err:
        # torch_xla can be installed without a usable TPU runtime; report it
        # and fall through to the CUDA check instead of aborting the cell.
        print(f"torch_xla is installed but no TPU could be initialised: {err}")
    else:
        print(f"\nSUCCESS: TPU Detected: {device}")
        device_found = True

# Check CUDA
if not device_found and torch.cuda.is_available():
    print(f"\nSUCCESS: GPU Detected: {torch.cuda.get_device_name(0)}")
    device_found = True

if not device_found:
    print("\n" + "="*60)
    print("WARNING: NEITHER CUDA GPU NOR TPU DETECTED!")
    print("Training will be extremely slow on CPU.")
    print("If on Colab: Runtime -> Change runtime type -> T4 GPU or TPU v2")
    print("="*60 + "\n")


## 2. Upload Trained NFE Critic
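Upload your `nfe_critic.pt` file here. On Colab the cell below opens an upload dialog; outside Colab it looks for the file in the current directory and its parent. In both cases the file ends up at `models/checkpoints/nfe_critic.pt`.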

In [None]:
import shutil
import os
from pathlib import Path

# Create checkpoint dir
os.makedirs("models/checkpoints", exist_ok=True)

# Handle NFE Critic File
critic_filename = "nfe_critic.pt"
target_path = f"models/checkpoints/{critic_filename}"

try:
    from google.colab import files
except ImportError:
    print("Local environment detected (no google.colab).")
    # Candidate locations: the working directory, its parent, and the project
    # root when an earlier cell defined `project_root`.
    search_dirs = [Path("."), Path("..")]
    known_root = globals().get("project_root")
    if known_root:
        search_dirs.append(Path(known_root))
    # dict.fromkeys drops duplicate candidates while keeping their order.
    possible_paths = list(dict.fromkeys(d / critic_filename for d in search_dirs))

    found = False
    for p in possible_paths:
        # is_file() rather than exists(): a directory of that name cannot be copied.
        if p.is_file():
            print(f"Found {critic_filename} at {p}")
            shutil.copy(p, target_path)
            print(f"Copied to {target_path}")
            found = True
            break

    if not found:
        print(f"WARNING: {critic_filename} not found. Ensure it is in the notebook directory or project root.")
else:
    print("Google Colab detected. Please upload nfe_critic.pt")
    uploaded = files.upload()
    if not uploaded:
        print(f"WARNING: No file was uploaded; {target_path} was not created.")
    else:
        uploaded_names = list(uploaded)
        # Move exactly one file onto the target: prefer an exact name match,
        # otherwise take the first upload (the browser may have renamed it).
        chosen = critic_filename if critic_filename in uploaded else uploaded_names[0]
        shutil.move(chosen, target_path)
        print(f"Saved {chosen} to {target_path}")
        ignored = [name for name in uploaded_names if name != chosen]
        if ignored:
            print(f"WARNING: Ignored additional uploads: {', '.join(ignored)}")


## 3. Training Script
The training loop runs directly in this notebook (rather than as an external script) so that its progress can be monitored live.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.notebook import tqdm
from core.modeling_bicameral import convert_to_bicameral
import matplotlib.pyplot as plt
from IPython.display import clear_output
import importlib.util

# Configuration
MODEL_ID = "HuggingFaceTB/SmolLM3-3B"

# Smart Device Selection
# Defaults describe the CPU fallback; the checks below upgrade them.
DEVICE = "cpu"
USE_DTYPE = torch.float32

# 1. Try TPU
if importlib.util.find_spec("torch_xla") is not None:
    try:
        import torch_xla.core.xla_model as xm
        tpu_device = xm.xla_device()
    except (ImportError, RuntimeError) as err:
        # torch_xla may be installed without a usable TPU runtime; keep the
        # CPU defaults so the GPU check below still gets its chance.
        print(f"TPU unavailable ({err}); trying other devices.")
    else:
        DEVICE = tpu_device
        USE_DTYPE = torch.bfloat16 # TPUs work best with bfloat16
        print(f"Selected Device: TPU ({DEVICE})")

# 2. Try GPU if no TPU
if str(DEVICE) == "cpu" and torch.cuda.is_available():
    DEVICE = "cuda"
    # NOTE(review): the weights are later trained directly in this dtype with
    # no loss scaling; watch for inf/NaN losses when running in float16.
    USE_DTYPE = torch.float16 # GPUs work well with fp16
    print(f"Selected Device: GPU ({torch.cuda.get_device_name(0)})")

if str(DEVICE) == "cpu":
    print("Selected Device: CPU (Warning: Slow)")

# Load Model & Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Loading model with dtype={USE_DTYPE}...")
# Load the base checkpoint and wrap it in one expression, so no extra
# reference to the unconverted model is kept.
model = convert_to_bicameral(
    AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=USE_DTYPE),
    use_acc=True,
    use_soft_routing=True,
)
model.to(DEVICE)

print("Model converted to Bicameral (ACC + Soft Routing)")

In [None]:
# Define Datasets (Synthetic Prompts)
# Ten seed prompts per category, each list tiled to 500 entries.
PROMPT_REPEATS = 50

logic_seed_prompts = [
    "Solve this math problem: 24 * 7 + 12",
    "Calculate the integral of x squared",
    "If P implies Q and Q implies R, then",
    "Write a python function to sort a list",
    "Explain the theory of relativity step by step",
    "Derive the quadratic formula",
    "Logic puzzle: Three people enter a room...",
    "Analyze the time complexity of bubble sort",
    "What is the capital of France and its population?",
    "Debug this following C++ code snippet",
]
logic_prompts = logic_seed_prompts * PROMPT_REPEATS

creative_seed_prompts = [
    "Write a poem about a lonely robot",
    "Describe a color that doesn't exist",
    "Invent a new mythology for Mars",
    "Write a story starting with 'The clock struck 13'",
    "Imagine a world where sound is visible",
    "Compose a song lyric about rain",
    "Describe the taste of starlight",
    "Write a dialogue between a cat and a ghost",
    "Create a recipe for a magical potion",
    "Stream of consciousness about the ocean",
]
creative_prompts = creative_seed_prompts * PROMPT_REPEATS

In [None]:
def train_phase(phase_name, data, target_id, model, tokenizer, steps=500):
    """Run one training phase that pulls the model's cached ID toward a target.

    Each step feeds a single prompt through the model and minimises the
    language-modelling loss plus ``2.0 * (avg_id - target_id) ** 2``, where
    ``avg_id`` is the mean of ``layer.mlp.cached_id`` over all layers that
    report one after the forward pass.

    Args:
        phase_name: Phase label. ``"Logic"`` forces the ``"LOGIC"`` mode; any
            other value forces ``"CREATIVE"``. Its lowercase form also selects
            the parameters to train, by substring match on parameter names.
        data: Non-empty sequence of prompt strings; it is cycled when ``steps``
            exceeds its length.
        target_id: Target value for the averaged cached ID.
        model: Model exposing ``model.model.layers[*].mlp`` with ``forced_mode``
            and ``cached_id`` attributes (presumably the output of
            ``convert_to_bicameral`` - confirm against the caller).
        tokenizer: Tokenizer used to encode each prompt.
        steps: Number of optimisation steps.

    Returns:
        A list with the averaged cached ID of every step, as Python floats.

    Raises:
        ValueError: If ``data`` is empty or no parameter matches the phase.
    """
    import itertools

    if not data:
        raise ValueError(f"No training prompts supplied for the {phase_name} phase.")

    print(f"\n=== Starting {phase_name} Phase (Target D={target_id}) ===")

    # Set Forced Mode for Training
    mode_str = "LOGIC" if phase_name == "Logic" else "CREATIVE"
    for layer in model.model.layers:
        layer.mlp.forced_mode = mode_str

    # Freeze/Thaw Parameters: train the phase's expert plus the router and
    # inhibitory (ACC) parameters; everything else is frozen.
    phase_key = phase_name.lower()
    trainable_params = []
    for n, p in model.named_parameters():
        p.requires_grad = (phase_key in n) or ("router" in n) or ("inhibitory" in n)
        if p.requires_grad:
            trainable_params.append(p)

    if not trainable_params:
        raise ValueError(
            f"No trainable parameters matched '{phase_key}', 'router' or 'inhibitory'."
        )

    optimizer = optim.AdamW(trainable_params, lr=5e-5)

    # Tracking
    loss_history = []
    id_history = []

    pbar = tqdm(range(steps))
    # cycle() repeats the prompts indefinitely, so any `steps` value is safe
    # and no enlarged copy of the dataset is built.
    data_iter = itertools.cycle(data)

    for i in pbar:
        text = next(data_iter)
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)

        outputs = model(**inputs, labels=inputs["input_ids"])
        lm_loss = outputs.loss

        # Topological Loss: collect the IDs cached during the forward pass.
        ids = [layer.mlp.cached_id for layer in model.model.layers if layer.mlp.cached_id is not None]
        if ids:
            avg_id = torch.stack(ids).mean()
            # Stronger penalty: Weighted MSE
            topo_loss = 2.0 * (avg_id - target_id)**2
        else:
            # No layer reported an ID: log a placeholder of 2.0 and add no penalty.
            avg_id = torch.tensor(2.0)
            topo_loss = torch.tensor(0.0).to(DEVICE)

        total_loss = lm_loss + topo_loss

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Read each scalar back once per step; the values are reused below.
        loss_value = total_loss.item()
        id_value = avg_id.item()
        loss_history.append(loss_value)
        id_history.append(id_value)

        pbar.set_postfix({'Loss': f"{loss_value:.3f}", 'ID': f"{id_value:.2f}"})

        if i % 50 == 0:
            # Live Plotting
            clear_output(wait=True)
            plt.figure(figsize=(10, 4))
            plt.plot(id_history, label='Avg Intrinsic Dimension')
            plt.axhline(y=target_id, color='r', linestyle='--', label=f'Target D={target_id}')
            plt.title(f"{phase_name} Phase Training")
            plt.legend()
            plt.grid(True)
            plt.show()

    return id_history


# --- RUN TRAINING ---
# The two phases run back to back on the same model, logic first.

# Phase 1: Logic (Target D=1.5)
logic_ids = train_phase(
    phase_name="Logic",
    data=logic_prompts,
    target_id=1.5,
    model=model,
    tokenizer=tokenizer,
    steps=300,
)

# Phase 2: Creative (Target D=2.5)
creative_ids = train_phase(
    phase_name="Creative",
    data=creative_prompts,
    target_id=2.5,
    model=model,
    tokenizer=tokenizer,
    steps=300,
)

## 4. Save and Download Weights

In [None]:
import shutil

# Write model and tokenizer into one directory, then bundle it for download.
save_path = "smollm3_bicameral_diff"
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Zip for download
# shutil.make_archive needs no external `zip` binary. `base_dir` keeps the
# top-level folder inside the archive, matching the layout of `zip -r`.
shutil.make_archive(save_path, "zip", root_dir=".", base_dir=save_path)

try:
    from google.colab import files
except ImportError:
    print(f"Local environment detected. Files saved to {save_path}.zip")
else:
    files.download(f"{save_path}.zip")
