In [None]:
# Notebook repaired cell placeholder
# This file is managed by the repo. If you edit it here, ensure compatibility with the installed environment.


# CRSM Colab Runbook

This notebook scaffolds running distillation or training for the CRSM project on Google Colab. It mounts Google Drive, installs dependencies (preferring `mamba` when available), and provides interactive cells to choose whether to run distillation or training, configure hyperparameters, and save checkpoints to Drive.

Run the cells in order, from top to bottom. Gemini 2.5 Flash is the default teacher, but you can switch to another provider or to a local model.

In [None]:
# Cell 1: Detect environment and prefer mamba for package installation
#
# Defines two notebook-level names used by later cells:
#   HAS_MAMBA   - True when a `mamba` executable is found on PATH.
#   INSTALL_CMD - install command prefix; package names are appended to it.
import shutil
import sys
import os

HAS_MAMBA = shutil.which('mamba') is not None
print('mamba installed:', HAS_MAMBA)

# Provide helper install command variable.
# `--skip-existing` is a conda-build flag, not an install flag, so it is not
# passed here; install already leaves satisfied packages in place.
if HAS_MAMBA:
    INSTALL_CMD = 'mamba install -y -c conda-forge'
else:
    INSTALL_CMD = 'pip install'
print('Using install command:', INSTALL_CMD)

# CUDA check. torch is only installed by the next cell, so a missing torch
# must not abort this cell on a fresh environment.
try:
    import torch
except ImportError as exc:
    print('torch is not installed yet; skipping CUDA check:', exc)
else:
    print('CUDA available:', torch.cuda.is_available())
    print('torch version:', torch.__version__)


In [None]:
# Cell 2: Install core dependencies (uses INSTALL_CMD from previous cell)
#
# Builds one install command from INSTALL_CMD plus the package list, runs it
# through the shell, and reports whether it succeeded.
import subprocess

print('Installing CRSM dependencies...')

# Package names as published on PyPI.
packages = ['torch', 'einops', 'transformers', 'datasets', 'tokenizers', 'wandb']

# conda-forge publishes PyTorch under the name `pytorch`; asking mamba for
# `torch` fails with a package-not-found error.
if INSTALL_CMD.startswith('mamba'):
    packages = ['pytorch' if name == 'torch' else name for name in packages]

cmd = f"{INSTALL_CMD} {' '.join(packages)}"
print('Running:', cmd)
res = subprocess.run(cmd, shell=True)
print('Install exit code:', res.returncode)

if res.returncode != 0:
    # Make the failure visible instead of following it with a restart hint.
    print('WARNING: dependency installation failed; later cells that import these packages may not work until the install succeeds.')
else:
    print('If you used pip, you may need to restart the runtime before importing heavy packages like torch or transformers.')


In [None]:
# Cell 3: Mount Google Drive (optional) and set WORKDIR
#
# On Colab, Drive is mounted at /content/drive and WORKDIR points into it.
# When the google.colab package cannot be imported (i.e. not running on
# Colab), WORKDIR falls back to a `crsm_runs` folder under the current
# working directory so the remaining cells still have somewhere to write.
import os

try:
    from google.colab import drive
except ImportError:
    WORKDIR = os.path.abspath('crsm_runs')
    print('google.colab not available; using a local WORKDIR instead of Google Drive.')
else:
    drive.mount('/content/drive')
    WORKDIR = '/content/drive/MyDrive/crsm_runs'

os.makedirs(WORKDIR, exist_ok=True)
print('WORKDIR:', WORKDIR)


In [None]:
# Cell 4: Example - create HF tokenizer and small streaming dataset
#
# Leaves two notebook-level names behind: `tok` (the tokenizer) and `ds`
# (the streaming dataset).
from crsm.dataset import StreamingTextDataset
from crsm.tokenizer import Tokenizer

# Try a small Hugging Face tokenizer first; on any error, report it and
# fall back to a default-constructed Tokenizer.
try:
    tok = Tokenizer(hf_name='sshleifer/tiny-gpt2')
    print('Loaded HF tokenizer, vocab_size=', tok.vocab_size)
except Exception as e:
    print('HF tokenizer not available, using fallback:', e)
    tok = Tokenizer()

# The dataset reads from WORKDIR/texts (create that folder and upload files there).
texts_dir = f'{WORKDIR}/texts'
ds = StreamingTextDataset(
    data_dir=texts_dir,
    seq_len=64,
    hf_tokenizer_name=None,
)
print('Streaming dataset created; first few examples:')

# Show at most two examples. zip() checks range(2) first, so the dataset is
# never asked for a third item.
for i, (inp, tgt) in zip(range(2), ds):
    print('inp.shape=', inp.shape, 'tgt.shape=', tgt.shape)


In [None]:
# Cell 5: Example - run a short training job (small, for demo)
# Calls crsm.train.main directly inside the kernel with demo-sized settings.
# For longer runs, use the CLI in a shell.

from crsm.train import main as train_main

# Demo settings, gathered in one place so they are easy to tweak.
demo_settings = dict(
    epochs=1,
    batch_size=8,
    vocab_size=1000,
    seq_len=32,
    lr=1e-3,
    data_dir=None,
    checkpoint_dir=f'{WORKDIR}/checkpoints',  # checkpoints are written under WORKDIR
    wandb_enabled=False,
)

train_main(**demo_settings)
print('Training demo finished.')


In [None]:
# Cell 6: Example - generate distillation traces from prompts file using distill_runner
# Create a small prompts file (one prompt per line), read it back, and pass
# those prompts to the distill runner (this will attempt to call the provider).
prompts_path = WORKDIR + '/prompts.txt'
with open(prompts_path, 'w', encoding='utf-8') as f:
    f.write('Write a concise chain-of-thought explanation for why 2+2=4\n')

out_path = WORKDIR + '/distill_traces.jsonl'

# run distill runner (note: requires provider libs/API keys to be configured)
from crsm.distill_runner import generate_batch

# Load the prompts from the file written above so the file is what drives the
# run; blank lines are skipped.
with open(prompts_path, encoding='utf-8') as f:
    prompts = [line.strip() for line in f if line.strip()]

try:
    generate_batch(prompts, provider_kind='local', out_path=out_path, prompt_erasure=False, workers=1)
    print('Wrote traces to', out_path)
except Exception as e:
    # Best-effort demo: report the failure and let the notebook continue.
    print('Distillation generation failed (expected if provider libs not installed or API keys missing):', e)
