# CRSM: End-to-End Cloud Training Pipeline

This notebook runs the complete training pipeline for the Continuous Reasoning State Model (CRSM) on Google Colab: a data preparation step followed by four training stages.

**Steps:**
1.  **Data Preparation:** Download and tokenize FineWeb-Edu/GSM8K.
2.  **Stage 1 (System 1):** Train the Mamba backbone.
3.  **Stage 2 (Subconscious):** Distill the Latent Dynamics Model.
4.  **Stage 3 (Judgment):** Train the Value Head via offline expert iteration.
5.  **Stage 4 (Assembly):** Assemble the final artifact.

In [None]:
!nvidia-smi

## 1. Setup Environment

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

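# Create experiment directories in Drive.
# The paths are listed explicitly because IPython interpolates {...} in "!" commands
# as a Python expression, so stage_{1,2,3,4} would not reach the shell as brace expansion.
!mkdir -p /content/drive/MyDrive/crsm_experiments/stage_1 \
    /content/drive/MyDrive/crsm_experiments/stage_2 \
    /content/drive/MyDrive/crsm_experiments/stage_3 \
    /content/drive/MyDrive/crsm_experiments/stage_4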

In [None]:
# Clone Repository
!git clone https://github.com/Pomilon-Intelligence-Lab/crsm.git
%cd crsm

# Install Dependencies
!pip install -r requirements.txt
!pip install -e .

## 2. Data Preparation
We prepare the **FineWeb-Edu** dataset for the backbone and **GSM8K** for reasoning tasks.
The data is tokenized and saved as `uint16` binary files for memory-mapped streaming.
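To make the storage format concrete, here is a minimal sketch of writing token IDs as raw `uint16` values and reading a window back through a memory map. It uses only the standard library; `write_tokens` and `read_window` are hypothetical helpers, not functions from the CRSM repository, and the on-disk layout (a flat array of native-endian `uint16` with no header) is an assumption. If the data is tokenized with the GPT-2 vocabulary (50,257 entries), every ID fits in `uint16`, whose maximum is 65,535. In practice `numpy.memmap` is the more common way to do the same thing.

```python
import mmap
import os
import tempfile
from array import array


def write_tokens(path, token_ids):
    """Write token IDs as raw native-endian uint16 values (2 bytes each)."""
    buf = array("H", token_ids)  # raises OverflowError for IDs above 65535
    with open(path, "wb") as f:
        buf.tofile(f)


def read_window(path, start, length):
    """Return `length` token IDs starting at token index `start`.

    The file is memory-mapped, so only the pages backing the requested
    window need to be read from disk.
    """
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            view = memoryview(mm).cast("H")  # reinterpret bytes as uint16
            try:
                return view[start:start + length].tolist()
            finally:
                view.release()  # must be released before the map closes


if __name__ == "__main__":
    demo_path = os.path.join(tempfile.mkdtemp(), "tokens.bin")
    write_tokens(demo_path, [1, 2, 50256, 7, 9])
    print(read_window(demo_path, 1, 3))
```

Reading by token index rather than byte offset keeps the caller independent of the 2-byte item size.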

In [None]:
# Prepare FineWeb-Edu (Sample)
# Note: This might take a while. We use a small shard size to see results quickly.
!python scripts/data/prepare_dataset.py --dataset fineweb --subset sample-10BT --shard-size 20000000 --output-dir data/fineweb

In [None]:
# Prepare GSM8K (Reasoning)
!python scripts/data/prepare_dataset.py --dataset gsm8k --output-dir data/gsm8k

## 3. Stage 1: Backbone Training (System 1)
We train the Mamba backbone on the prepared FineWeb data.
We use the `baseline_27m` configuration, overriding the data directory and setting the epoch count on the command line.
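The override pattern can be pictured as command-line flags layered over values loaded from the YAML file. The following is a minimal sketch of that idea, not the repository's actual implementation: `apply_overrides`, the dictionary keys, and the placeholder values are all hypothetical, and a plain dictionary stands in for the parsed config.

```python
import argparse


def apply_overrides(config, argv):
    """Return a copy of `config` with any CLI-supplied values layered on top."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", default=None)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--no-wandb", action="store_true")
    args = parser.parse_args(argv)

    merged = dict(config)  # leave the file-level config untouched
    if args.data_dir is not None:
        merged["data_dir"] = args.data_dir
    if args.epochs is not None:
        merged["epochs"] = args.epochs
    if args.no_wandb:
        merged["use_wandb"] = False
    return merged


# Placeholder values standing in for a parsed YAML config (hypothetical keys).
file_config = {"data_dir": "data/default", "epochs": 10, "use_wandb": True}
effective = apply_overrides(
    file_config, ["--data-dir", "data/fineweb", "--epochs", "1", "--no-wandb"]
)
print(effective)
```

Flags that are not passed leave the file's values in place, which is why the Stage 1 command only needs to name the settings it changes.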

In [None]:
!python scripts/training/stage_1_backbone.py \
    --config configs/baseline_27m.yaml \
    --data-dir data/fineweb \
    --epochs 1 \
    --no-wandb

In [None]:
# Backup to Drive
!cp experiments/stage_1/backbone_final.pt /content/drive/MyDrive/crsm_experiments/stage_1/
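A plain `cp` does not stop the notebook if a stage did not produce its checkpoint. As an optional alternative, the sketch below copies a checkpoint with an existence check and a size comparison. `backup_checkpoint` is a hypothetical helper, not part of the CRSM repository; the same function could be reused for the backup cells of the later stages.

```python
import shutil
from pathlib import Path


def backup_checkpoint(src, dest_dir):
    """Copy `src` into `dest_dir` and return the destination path.

    Raises FileNotFoundError if the checkpoint is missing, and OSError if
    the copied file's size differs from the source.
    """
    src = Path(src)
    if not src.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {src}")
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = Path(shutil.copy2(src, dest_dir))  # copy2 also preserves timestamps
    if dest.stat().st_size != src.stat().st_size:
        raise OSError(f"Size mismatch after copying {src} to {dest}")
    return dest


# Example call, using the paths from the cell above:
# backup_checkpoint("experiments/stage_1/backbone_final.pt",
#                   "/content/drive/MyDrive/crsm_experiments/stage_1")
```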

## 4. Stage 2: Dynamics Distillation (The Subconscious)
We freeze the backbone and train the latent dynamics model to predict state transitions.

In [None]:
!python scripts/training/stage_2_dynamics.py \
    --config configs/baseline_27m.yaml \
    --epochs 2 \
    --samples 10000

In [None]:
# Backup to Drive
!cp experiments/stage_2/dynamics_final.pt /content/drive/MyDrive/crsm_experiments/stage_2/

## 5. Stage 3: Value Head Training (The Judgment)
We train the Value Head using offline MCTS rollouts to recognize high-quality reasoning paths.

In [None]:
!python scripts/training/stage_3_value_head.py \
    --config configs/baseline_27m.yaml \
    --epochs 1

In [None]:
# Backup to Drive
!cp experiments/stage_3/backbone_with_value.pt /content/drive/MyDrive/crsm_experiments/stage_3/

## 6. Stage 4: Assembly & Verification
We combine the trained components into the final CRSM artifact.
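Before assembling, it can help to confirm that the earlier stages left their checkpoints in place. The sketch below checks the three paths used by the backup cells above. Which of these files the assembly script actually reads is an assumption here, and `missing_checkpoints` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Checkpoint paths taken from the backup cells of the earlier stages.
STAGE_CHECKPOINTS = {
    "stage_1": "experiments/stage_1/backbone_final.pt",
    "stage_2": "experiments/stage_2/dynamics_final.pt",
    "stage_3": "experiments/stage_3/backbone_with_value.pt",
}


def missing_checkpoints(root="."):
    """Return the names of stages whose checkpoint file is absent under `root`."""
    return [
        name
        for name, rel_path in STAGE_CHECKPOINTS.items()
        if not (Path(root) / rel_path).is_file()
    ]


missing = missing_checkpoints()
if missing:
    print("Missing checkpoints for:", ", ".join(missing))
else:
    print("All stage checkpoints found.")
```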

In [None]:
!python scripts/training/stage_4_assembly.py \
    --config configs/baseline_27m.yaml

In [None]:
# Save Final Model to Drive
!cp experiments/stage_4/crsm_final.pt /content/drive/MyDrive/crsm_experiments/stage_4/

## 7. Inference Demo
Load the assembled model and run a test generation with the "Thinking" loop active.

In [None]:
import torch
import asyncio
from crsm.model import CRSMModel, CRSMConfig
from crsm.tokenizer import Tokenizer

# Load Model
checkpoint_path = "experiments/stage_4/crsm_final.pt"
ckpt = torch.load(checkpoint_path)
config = CRSMConfig.from_dict(ckpt['config']['model'])
config.autonomous_mode = True # Enable background thinking

model = CRSMModel(config).cuda()
model.load_state_dict(ckpt['model_state_dict'], strict=False)
model.load_dynamics("experiments/stage_2/dynamics_final.pt") # Ensure dynamics are loaded
model.eval()

tokenizer = Tokenizer("gpt2")

async def generate_demo(text):
    print(f"Prompt: {text}")
    input_ids = torch.tensor([tokenizer.encode(text)]).cuda()
    
    output_ids = await model.crsm.think_and_generate(
        input_ids, 
        max_length=50, 
        use_deliberation=True, 
        deliberation_lag=3
    )
    
    output_text = tokenizer.decode(output_ids.tolist())
    print(f"\nGenerated: {output_text}")

# Run
await generate_demo("The future of artificial intelligence depends on")
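The top-level `await` above works because the notebook already runs an event loop. In a plain Python script there is no running loop, so the coroutine has to be started with `asyncio.run`. The sketch below shows that pattern with `fake_generate`, a hypothetical stand-in for the model call that only echoes the prompt.

```python
import asyncio


async def fake_generate(prompt):
    """Hypothetical stand-in for the model call; it only echoes the prompt."""
    await asyncio.sleep(0)  # yield to the event loop, as a real async call would
    return prompt + " ..."


async def main():
    return await fake_generate("The future of artificial intelligence depends on")


if __name__ == "__main__":
    # In a script there is no running event loop, so start one explicitly.
    print(asyncio.run(main()))
```

Calling `asyncio.run` inside a notebook cell raises a `RuntimeError` because a loop is already running there, so each environment needs its own entry point.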