# 🧠 ExperimentalDLLM — Colab Training

Resume pre-training of the Episodic-Centric PCModel from step 610,000.

**Prerequisites:**
1. Upload `ExperimentalDLLM_colab.tar.gz` to Google Drive root
2. Use Colab Pro with a GPU runtime (A100 preferred, T4 acceptable)

---

## 1. GPU Check

In [None]:
import torch
import os

# Fail fast: every later cell assumes a CUDA device is present.
if not torch.cuda.is_available():
    raise RuntimeError(
        "❌ No GPU detected!\n"
        "Go to Runtime → Change runtime type → GPU (A100 or T4)"
    )

# Describe device 0: its name and total memory in decimal gigabytes.
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9

print("\n".join((
    f"✅ GPU: {gpu_name}",
    f"   Memory: {gpu_mem:.1f} GB",
    f"   CUDA: {torch.version.cuda}",
    f"   PyTorch: {torch.__version__}",
)))

# Recommend a batch size from total GPU memory.
# Rows are (exclusive memory floor in GB, batch size, banner template) and are
# checked top-down; the -inf floor makes the last row the catch-all.
_BATCH_TIERS = (
    (35, 32, "\n🚀 A100 detected — can use batch_size={}"),
    (14, 16, "\n⚡ Good GPU — recommended batch_size={}"),
    (float("-inf"), 8, "\n⚠️  Limited GPU — using batch_size={}"),
)
RECOMMENDED_BATCH, _banner = next(
    (batch, banner)
    for floor_gb, batch, banner in _BATCH_TIERS
    if gpu_mem > floor_gb
)
print(_banner.format(RECOMMENDED_BATCH))

## 2. Mount Google Drive & Extract Code

In [None]:
# Mount Google Drive, unpack the project tarball into /content, and make the
# project directory the current working directory for the remaining cells.
# Relies on `os`, which is imported in the GPU Check cell.
from google.colab import drive
drive.mount('/content/drive')

# Extract the uploaded tarball
DRIVE_TARBALL = '/content/drive/MyDrive/ExperimentalDLLM_colab.tar.gz'
WORK_DIR = '/content/ExperimentalDLLM'

# Stop early with an actionable message if the tarball was not uploaded.
if not os.path.exists(DRIVE_TARBALL):
    raise FileNotFoundError(
        f"❌ Tarball not found at {DRIVE_TARBALL}\n"
        "Please upload ExperimentalDLLM_colab.tar.gz to Google Drive root."
    )

print(f">> Extracting {DRIVE_TARBALL}...")
# NOTE(review): tar's exit status is not checked here; a failed extraction
# only surfaces later (e.g. when os.chdir cannot find WORK_DIR).
!tar -xzf {DRIVE_TARBALL} -C /content/

# Rename if needed
# If the archive unpacked to ExperimentalDLLM_colab, either rename it to
# WORK_DIR (first run) or merge its contents into an existing WORK_DIR.
extracted = '/content/ExperimentalDLLM_colab'
if os.path.exists(extracted) and not os.path.exists(WORK_DIR):
    os.rename(extracted, WORK_DIR)
elif os.path.exists(extracted):
    # NOTE(review): the `*` glob skips dotfiles, so hidden files at the top
    # level of the extracted tree are not merged — confirm that is acceptable.
    !cp -r {extracted}/* {WORK_DIR}/

# Later cells use paths relative to the project root (e.g. train_episodic.py).
os.chdir(WORK_DIR)
print(f"\n✅ Working directory: {os.getcwd()}")
print(f"   Contents: {os.listdir('.')}")

## 3. Install Dependencies

In [None]:
# Install the packages named below, quietly (-q).
# NOTE(review): the success message is printed unconditionally; pip's exit
# status is not checked, so a failed install is not reported here.
!pip install -q transformers datasets tqdm matplotlib
print("✅ Dependencies installed")

## 4. Verify Checkpoint

Make sure the step 610,000 checkpoint is available.

In [None]:
CKPT_DIR = os.path.join(WORK_DIR, 'checkpoints')
CKPT_FILE = os.path.join(CKPT_DIR, 'checkpoint_small_pretrain_step610000.pt')

# Set up Drive checkpoint directory for saving
DRIVE_CKPT_DIR = '/content/drive/MyDrive/ExperimentalDLLM_checkpoints'
os.makedirs(DRIVE_CKPT_DIR, exist_ok=True)

if os.path.exists(CKPT_FILE):
    size_gb = os.path.getsize(CKPT_FILE) / 1e9
    print(f"✅ Checkpoint found: {CKPT_FILE} ({size_gb:.1f} GB)")
else:
    # Check Drive for previously saved checkpoints
    drive_ckpts = sorted([
        f for f in os.listdir(DRIVE_CKPT_DIR)
        if f.startswith('checkpoint_small_pretrain_step') and f.endswith('.pt')
        and 'latest' not in f
    ]) if os.path.exists(DRIVE_CKPT_DIR) else []
    
    if drive_ckpts:
        latest = drive_ckpts[-1]
        print(f">> Copying latest checkpoint from Drive: {latest}")
        os.makedirs(CKPT_DIR, exist_ok=True)
        !cp "{DRIVE_CKPT_DIR}/{latest}" "{CKPT_DIR}/"
        !cp "{DRIVE_CKPT_DIR}/{latest}" "{CKPT_DIR}/checkpoint_small_pretrain_latest.pt"
        CKPT_FILE = os.path.join(CKPT_DIR, latest)
        print(f"✅ Checkpoint restored: {latest}")
    else:
        print("⚠️  No checkpoint found. Training will start from scratch.")
        CKPT_FILE = None

print(f"\n📁 Drive checkpoint directory: {DRIVE_CKPT_DIR}")
print(f"   New checkpoints will be saved here automatically.")

## 5. Configure Training

Adjust batch size for GPU memory. Data streams directly from HuggingFace — no caching needed.

In [None]:
# ============================================================
# CONFIGURATION — Adjust these as needed
# ============================================================
# TOTAL_STEPS, CHECKPOINT_INTERVAL and LOG_INTERVAL are passed to
# train_episodic.py as --steps, --checkpoint_interval and --log_interval
# by the training cell.

BATCH_SIZE = RECOMMENDED_BATCH  # Auto-set by GPU (cell 1)
TOTAL_STEPS = 5000000           # ~80B tokens at batch=32
CHECKPOINT_INTERVAL = 5000      # Save every N steps
LOG_INTERVAL = 100              # Print loss every N steps

# ============================================================

# Patch the config to use the optimal batch size
# NOTE(review): this changes ConfigSmall only inside the notebook kernel.
# The training cell launches train_episodic.py as a separate process with no
# batch-size argument, so that process imports its own unpatched
# src.config — confirm the batch size actually reaches the trainer.
from src.config import ConfigSmall
original_bs = ConfigSmall.batch_size  # kept only for the before → after printout
ConfigSmall.batch_size = BATCH_SIZE
print(f">> Batch size: {original_bs} → {BATCH_SIZE}")
print(f">> Total steps: {TOTAL_STEPS:,}")
print(f">> Checkpoint interval: {CHECKPOINT_INTERVAL:,}")
print(f">> Data: Streaming from HuggingFace (no cache needed)")

## 6. 🚀 Start Training

Uses `--stream` to load data directly from HuggingFace FineWeb-Edu.
Checkpoints auto-sync to Google Drive when they are saved, at most once every 5 minutes.

In [None]:
import subprocess
import shutil
import signal
import time

# Launch train_episodic.py as a child process, echo its output line by line,
# and copy checkpoints to Google Drive as the trainer reports saving them.

# Build the training command.
# `-u` turns off Python's output buffering in the child. Output written to a
# pipe is otherwise block-buffered (unless the script flushes explicitly), so
# log lines — including the "Checkpoint saved" marker that triggers the Drive
# sync below — would arrive late and in bursts.
cmd = [
    'python', '-u', 'train_episodic.py',
    '--config', 'small',
    '--mode', 'pretrain',
    '--steps', str(TOTAL_STEPS),
    '--device', 'cuda',
    '--stream',
    '--save_dir', CKPT_DIR,
    '--checkpoint_interval', str(CHECKPOINT_INTERVAL),
    '--log_interval', str(LOG_INTERVAL),
    '--resume', CKPT_DIR,
]

DRIVE_SYNC_INTERVAL = 300  # Minimum seconds between periodic Drive syncs
LATEST_CKPT_NAME = 'checkpoint_small_pretrain_latest.pt'


def _copy_to_drive(src, dest_name=None):
    """Copy ``src`` into DRIVE_CKPT_DIR and return True on success.

    ``dest_name`` overrides the destination filename (defaults to the source
    basename). I/O errors are reported and turned into a False return so a
    Drive hiccup cannot abort the loop that drains the trainer's output.
    """
    dest = os.path.join(DRIVE_CKPT_DIR, dest_name or os.path.basename(src))
    try:
        shutil.copy2(src, dest)
    except OSError as err:
        print(f"   ⚠️  Drive sync failed for {os.path.basename(src)}: {err}")
        return False
    return True


def _sync_checkpoint(line):
    """Sync the checkpoint announced by log ``line``; return True if copied.

    Tries the step-numbered file named in the log line first and refreshes
    the Drive "latest" alias from it. If the step cannot be parsed, or that
    file does not exist, falls back to copying the local "latest" checkpoint.
    """
    latest = os.path.join(CKPT_DIR, LATEST_CKPT_NAME)
    if not os.path.exists(latest):
        return False

    # Expected shape: "... step <N> → ..."; anything else uses the fallback.
    try:
        step_str = line.split('step')[1].split('\u2192')[0].strip()
    except IndexError:
        step_str = ''

    step_file = f'checkpoint_small_pretrain_step{step_str}.pt'
    src = os.path.join(CKPT_DIR, step_file)
    if step_str and os.path.exists(src):
        if _copy_to_drive(src) and _copy_to_drive(src, LATEST_CKPT_NAME):
            print(f"   ☁️  Synced to Google Drive: {step_file}")
            return True
        return False

    if _copy_to_drive(latest):
        print(f"   ☁️  Synced latest checkpoint to Google Drive")
        return True
    return False


print(f">> Command: {' '.join(cmd)}")
print(
    ">> Checkpoints sync to Google Drive when saved, "
    f"at most once every {DRIVE_SYNC_INTERVAL // 60} minutes"
)
print("="*60)

# Run training with real-time output (stderr is merged into stdout).
process = subprocess.Popen(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    universal_newlines=True,
    bufsize=1
)

last_drive_sync = time.time()

try:
    for line in process.stdout:
        print(line, end='')

        # Sync when a save is announced, throttled to DRIVE_SYNC_INTERVAL.
        # The timer only resets after a successful copy, so a failed sync is
        # retried at the next announced checkpoint.
        if '💾 Checkpoint saved' in line and time.time() - last_drive_sync > DRIVE_SYNC_INTERVAL:
            if _sync_checkpoint(line):
                last_drive_sync = time.time()

except KeyboardInterrupt:
    print("\n\n⚠️  Training interrupted! Saving final checkpoint...")
    process.send_signal(signal.SIGINT)
    try:
        process.wait(timeout=30)
    except subprocess.TimeoutExpired:
        # Do not let a hung trainer block the final wait below indefinitely.
        print("⚠️  Training process did not exit within 30s; killing it.")
        process.kill()

finally:
    # Wait for the trainer to exit before the final copy so a checkpoint it
    # is still writing during shutdown is not copied half-written.
    retcode = process.wait()

    latest = os.path.join(CKPT_DIR, LATEST_CKPT_NAME)
    if os.path.exists(latest) and _copy_to_drive(latest, LATEST_CKPT_NAME):
        print(f"\n☁️  Final checkpoint synced to Google Drive")

    print(f"\nTraining process exited with code: {retcode}")

## 7. Benchmark (Optional)

Run the WikiText-2 perplexity benchmark on the latest checkpoint.

In [None]:
# Find the latest checkpoint
ckpts = sorted([
    f for f in os.listdir(CKPT_DIR)
    if f.startswith('checkpoint_small_pretrain_step') and f.endswith('.pt')
    and 'latest' not in f
])

if ckpts:
    latest_ckpt = os.path.join(CKPT_DIR, ckpts[-1])
    print(f">> Benchmarking: {ckpts[-1]}")
    !python benchmark_all.py --synapse_ckpt "{latest_ckpt}" --config small
else:
    print("❌ No checkpoints found to benchmark.")

---

## 🔄 Session Recovery

If your Colab session times out:

1. **Start a new runtime** (Runtime → Change runtime type → GPU)
2. **Run cells 1–4** — cell 4 falls back to the latest checkpoint on Google Drive when the step 610,000 checkpoint is not found locally
3. **Run cells 5–6** — training auto-resumes from where it left off

Checkpoints are safe in `My Drive/ExperimentalDLLM_checkpoints/`