# Cortex Stage A1 Training - Google Colab

This notebook trains Cortex fast-weight sidecars on synthetic long-context tasks.

**Requirements:** T4 GPU (free tier) or better

**Runtime:** ~1-2 hours for full training


## Setup Environment


In [16]:
# Check GPU
# Runs the `nvidia-smi` command-line tool in a subshell and shows its output in
# the cell, so the attached GPU can be confirmed by eye before training.
!nvidia-smi

In [17]:
# Install dependencies
# The version specifiers are quoted on purpose: the %pip line is handed to a
# shell, where an unquoted `>` is an output redirect. Unquoted, `torch>=2.0.0`
# would install torch with no minimum version and send pip's output to a file
# named "=2.0.0" instead of applying the constraint.
%pip install -q "torch>=2.0.0" "transformers>=4.30.0" accelerate sentencepiece pyyaml matplotlib numpy


In [18]:
# Auto-detect environment and setup paths
import sys
import os

# Known cortex-4 checkouts, probed in order; the first root that exists wins.
# Each entry is (root, in_colab, checkpoint_dir, log_dir). For the local
# checkout the outputs live inside the repo; for the Colab cloud checkout they
# sit next to it under /content.
_CORTEX_CANDIDATES = [
    (
        '/Users/mazalcohen/cortex-4',
        False,
        '/Users/mazalcohen/cortex-4/checkpoints',
        '/Users/mazalcohen/cortex-4/logs/a1',
    ),
    (
        '/content/cortex-4',
        True,
        '/content/checkpoints',
        '/content/logs/a1',
    ),
]

# Smart detection: check which cortex-4 path actually exists
_detected = next(
    (entry for entry in _CORTEX_CANDIDATES if os.path.exists(entry[0])),
    None,
)
if _detected is None:
    # The message is built from the same table that drives detection, so the
    # two can never disagree about which locations were checked.
    _searched = " or ".join(entry[0] for entry in _CORTEX_CANDIDATES)
    raise FileNotFoundError(f"cortex-4 folder not found at {_searched}")

# Globals consumed by the later cells; assigned only after a root was found so
# a failed detection never leaves half-initialised configuration behind.
CORTEX_ROOT, IN_COLAB, CHECKPOINT_DIR, LOG_DIR = _detected

# Add cortex-4 to path
sys.path.insert(0, CORTEX_ROOT)

# Display config
print(f"Environment: {'Google Colab (cloud)' if IN_COLAB else 'Local Colab'}")
print(f"Cortex root: {CORTEX_ROOT}")
print(f"Checkpoints: {CHECKPOINT_DIR}")
print(f"Logs: {LOG_DIR}")
# Fixed mis-encoded check mark (was rendered as mojibake).
print("✓ Setup complete!")


## Auto-Configure Based on GPU Memory


In [None]:
import torch

# Fail fast when no GPU is attached. The training cell reads GAPS, BATCH_SIZE
# and SAMPLES; merely printing a message here would leave them undefined and
# surface later as an unrelated-looking NameError.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU detected! Enable GPU in Runtime > Change runtime type")

# Total memory of device 0, in decimal gigabytes (bytes / 1e9).
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Memory: {gpu_mem_gb:.1f} GB\n")

# Auto-config based on available memory: larger cards get longer gaps,
# a bigger batch and more samples per gap.
if gpu_mem_gb > 14:
    GAPS = [512, 1024, 2048]
    BATCH_SIZE = 2
    SAMPLES = 256
    print("Using MEDIUM config (T4/16GB)")
elif gpu_mem_gb > 10:
    GAPS = [256, 512, 1024]
    BATCH_SIZE = 1
    SAMPLES = 128
    print("Using SMALL config (12GB)")
else:
    GAPS = [256, 512]
    BATCH_SIZE = 1
    SAMPLES = 64
    print("Using MINIMAL config (<10GB)")

print(f"  Gaps: {GAPS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Samples per gap: {SAMPLES}")


## Run Stage A1 Training


In [None]:
import os
import sys

# The script path below is relative, so run from the repository root.
os.chdir(CORTEX_ROOT)
print(f"Working directory: {os.getcwd()}\n")

# Create output directories
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)

# Single definition of the base model so the banner and the command line
# below cannot drift apart.
MODEL_NAME = 'Qwen/Qwen1.5-1.8B-Chat'

# --gaps takes the gap lengths as separate space-delimited arguments.
gaps_str = ' '.join(map(str, GAPS))

print("Starting Stage A1 Training:")
print("  Task: Key-Value Binding")
print(f"  Gaps: {GAPS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Samples per gap: {SAMPLES}")
print(f"  Model: {MODEL_NAME}")
print(f"  Output: {CHECKPOINT_DIR}")
print(f"  Logs: {LOG_DIR}\n")

# Launch with the kernel's own interpreter (sys.executable) rather than
# whatever `python` is first on PATH, so the script sees the packages that
# %pip installed into this kernel's environment. Interpolated paths are
# quoted so directories containing spaces survive the shell.
!"{sys.executable}" scripts/stage_a1_enable_fast.py \
    --model {MODEL_NAME} \
    --task kv \
    --gaps {gaps_str} \
    --batch_size {BATCH_SIZE} \
    --epochs 2 \
    --save_dir "{CHECKPOINT_DIR}" \
    --log_dir "{LOG_DIR}" \
    --amp true \
    --samples_per_gap {SAMPLES} \
    --seed 42 \
    --fast_rank 16 \
    --lr_sidecar 2e-4 \
    --grad_clip 1.0


## Quick Results Analysis


In [None]:
import json
from pathlib import Path
from collections import defaultdict

# Summarise probe results: per-gap accuracy and mean fast-share, read from a
# probes.jsonl file found anywhere under LOG_DIR.
probe_files = list(Path(LOG_DIR).rglob('probes.jsonl'))
if probe_files:
    # Directory traversal order is not guaranteed, so when several runs have
    # left probe logs behind, analyse the most recently modified one rather
    # than whichever happens to be listed first.
    probe_file = max(probe_files, key=lambda path: path.stat().st_mtime)
    print(f"Analyzing: {probe_file}\n")

    gap_accuracy = defaultdict(list)    # gap -> list of per-probe 'correct' values
    gap_fast_share = defaultdict(list)  # gap -> list of per-probe 'fast_share_mean' values

    with open(probe_file, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline), which would
            # otherwise make json.loads raise.
            if not line.strip():
                continue
            record = json.loads(line)
            gap = record['gap']
            gap_accuracy[gap].append(record['correct'])
            # Records without a fast-share measurement count as 0.
            gap_fast_share[gap].append(record.get('fast_share_mean', 0))

    print("=" * 60)
    print("ACCURACY BY GAP LENGTH:")
    print("=" * 60)
    for gap in sorted(gap_accuracy.keys()):
        acc_list = gap_accuracy[gap]
        acc = sum(acc_list) / len(acc_list)
        fast_share = sum(gap_fast_share[gap]) / len(gap_fast_share[gap])
        print(f"Gap {gap:4d}: {acc*100:5.1f}% ({sum(acc_list):3d}/{len(acc_list):3d}) | Fast Share: {fast_share:.3f}")
    print("=" * 60)
else:
    print(f"No probe logs found in {LOG_DIR}")


## Download Results


In [None]:
# Archive results
import os

# The archive is written next to the cortex-4 checkout (its parent directory).
result_zip = os.path.join(os.path.dirname(CORTEX_ROOT), 'cortex_a1_results.zip')

# `zip -r` updates an existing archive in place and keeps entries that are no
# longer on disk, so start from a clean file to avoid shipping stale results
# from an earlier run.
if os.path.exists(result_zip):
    os.remove(result_zip)

print(f"Archiving results to: {result_zip}")
# Paths are quoted so directories containing spaces survive the shell.
!zip -r "{result_zip}" "{LOG_DIR}" "{CHECKPOINT_DIR}"

if IN_COLAB:
    # Only importable inside Colab; triggers a browser download of the archive.
    from google.colab import files
    files.download(result_zip)
    print("\nDownload started! Check your browser's download folder.")
else:
    print(f"\nResults saved to: {result_zip}")
    print("You can find your logs and checkpoints in:")
    print(f"  - Logs: {LOG_DIR}")
    print(f"  - Checkpoints: {CHECKPOINT_DIR}")
