# Redox All-in-One Colab: Rust + Iron Fine-tune, Predict, Evaluate

Run this notebook with **Runtime -> Run all**.

It will, in one run:
1. Train **Rust arm** adapter.
2. Generate Rust test predictions JSONL.
3. Train **Iron arm** adapter.
4. Generate Iron test predictions JSONL.
5. Repeat for each seed in `SEEDS`.
6. Evaluate each seed and, when at least two seeds were run, write an aggregate multi-seed report.

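Expected Google Drive layout (all paths relative to the repo root `/content/drive/MyDrive/redox/`):
- Dataset files under `data/pilot/<version>/unsloth/`: `rust_{train,val,test}.jsonl` and `iron_{train,val,test}.jsonl`. The notebook tries `PREFERRED_DATA_VERSION` first, then falls back to `foundation_v2` and `foundation_v1`.
- The Cargo project that builds the `redox` binary (`cargo build --release --bin redox`).
- The evaluation scripts `scripts/evaluate_predictions.py` and `scripts/aggregate_eval_reports.py`.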


In [None]:
from google.colab import drive
from pathlib import Path
from datetime import datetime
import os

drive.mount('/content/drive')

# Root of the redox checkout on Google Drive; later cells use paths relative to it.
REPO_DIR = Path('/content/drive/MyDrive/redox').resolve()
# Fail fast if the checkout is missing. Otherwise the mkdir(parents=True) calls
# below would silently create an empty redox/ tree on Drive, and the failure
# would only surface later (dataset search, cargo build, evaluation scripts).
if not REPO_DIR.is_dir():
    raise FileNotFoundError(
        f'Repo directory not found: {REPO_DIR}\n'
        'Copy or clone the redox repository to that location on Google Drive first.'
    )

# Prefer this dataset version first; the notebook auto-falls back to other known versions if it is missing.
PREFERRED_DATA_VERSION = 'foundation_v2'

BASE_MODEL = 'unsloth/Qwen3-4B-Instruct-2507'
SEEDS = [3407, 2108]
ARMS = ['rust', 'iron']
MAX_SEQ_LENGTH = 2048
MAX_NEW_TOKENS = 320
NUM_TRAIN_EPOCHS = 1

# Timestamped run tag so that separate runs write to separate output directories.
stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
RUN_TAG = f'all_in_one_qwen3_4b_multiseed_{stamp}'
TRAINING_ROOT = REPO_DIR / 'training' / 'all_in_one' / RUN_TAG
EVAL_ROOT = REPO_DIR / 'eval' / 'all_in_one' / RUN_TAG
TRAINING_ROOT.mkdir(parents=True, exist_ok=True)
EVAL_ROOT.mkdir(parents=True, exist_ok=True)

# Later cells rely on repo-relative paths (cargo build, scripts/...).
os.chdir(REPO_DIR)


def has_required_files(data_dir: Path) -> bool:
    """Report whether ``data_dir`` holds train/val/test JSONL files for every arm.

    For each arm in ``ARMS`` the files ``<arm>_train.jsonl``,
    ``<arm>_val.jsonl`` and ``<arm>_test.jsonl`` must all exist.
    """
    splits = ('train', 'val', 'test')
    return all(
        (data_dir / f'{arm}_{split}.jsonl').exists()
        for arm in ARMS
        for split in splits
    )


# Search order: the preferred version, then known fallbacks.
# dict.fromkeys drops duplicates while keeping first-seen order.
candidate_versions = list(dict.fromkeys([
    PREFERRED_DATA_VERSION,
    'foundation_v2',
    'foundation_v1',
]))

pilot_root = REPO_DIR / 'data' / 'pilot'

# First candidate whose unsloth/ directory has every required file, else None.
DATA_VERSION = next(
    (v for v in candidate_versions if has_required_files(pilot_root / v / 'unsloth')),
    None,
)
DATA_DIR = None if DATA_VERSION is None else pilot_root / DATA_VERSION / 'unsloth'

if DATA_DIR is None:
    searched_listing = '\n'.join(str(pilot_root / v / 'unsloth') for v in candidate_versions)
    raise FileNotFoundError(
        'Could not find dataset with required files for both arms.\n'
        'Expected files: rust_{train,val,test}.jsonl and iron_{train,val,test}.jsonl\n'
        'Searched:\n' + searched_listing
    )

for label, value in (
    ('Repo dir:', REPO_DIR),
    ('Data version:', DATA_VERSION),
    ('Data dir:', DATA_DIR),
    ('Run tag :', RUN_TAG),
    ('Seeds   :', SEEDS),
    ('Training outputs:', TRAINING_ROOT),
    ('Eval outputs:', EVAL_ROOT),
):
    print(label, value)


In [None]:
%%capture
# Install Unsloth and its training stack; %%capture hides the (long) pip output.
import os, re
# Colab detection: any environment-variable name containing 'COLAB_' counts as Colab.
if 'COLAB_' not in ''.join(os.environ.keys()):
    # Not on Colab: plain install, letting pip resolve dependencies.
    !pip install unsloth
else:
    import torch
    # major.minor of the installed torch (e.g. '2.8'), used to pick a matching xformers build.
    v = re.match(r'[\d]{1,}\.[\d]{1,}', str(torch.__version__)).group(0)
    # Unknown torch versions fall back to xformers 0.0.34.
    xformers = 'xformers==' + {'2.10':'0.0.34','2.9':'0.0.33.post1','2.8':'0.0.32.post2'}.get(v, '0.0.34')
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    # {xformers} is IPython variable expansion. --no-deps presumably keeps pip from
    # replacing Colab's preinstalled packages (e.g. torch) — confirm if it matters.
    !pip install --no-deps unsloth_zoo bitsandbytes accelerate {xformers} peft trl triton unsloth
# Pin transformers and trl on both branches above.
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2


In [None]:
import os
import shutil

if shutil.which('cargo') is None:
    !curl https://sh.rustup.rs -sSf | sh -s -- -y
    os.environ['PATH'] += ':/root/.cargo/bin'

!cargo --version
!rustc --version
!cargo build --release --bin redox

REDOX_CMD = str(REPO_DIR / 'target' / 'release' / 'redox')
print('Redox binary:', REDOX_CMD)


In [None]:
import gc
import inspect
import json
import re
from typing import Any

import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import (
    get_chat_template,
    standardize_data_formats,
    train_on_responses_only,
)
from tqdm.auto import tqdm


def load_jsonl(path: Path) -> list[dict[str, Any]]:
    """Read a JSON Lines file into a list of objects, skipping blank lines.

    The file is iterated line by line in text mode, so records are split only on
    real newlines. ``str.splitlines()`` is avoided on purpose: it also breaks on
    U+2028, U+2029, U+0085 and other separators, which JSON permits unescaped
    inside strings. That would cut one record into undecodable halves.

    Args:
        path: JSONL file (a ``str`` path is accepted too).

    Returns:
        One decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    rows = []
    with Path(path).open(encoding='utf-8') as fh:
        for line in fh:
            if line.strip():
                rows.append(json.loads(line))
    return rows


def get_user_prompt(row: dict[str, Any]) -> str:
    """Return the content of the first user turn in ``row['conversations']``.

    Yields '' when the row has no ``conversations``, has no user turn, or the
    first user turn carries no ``content``.
    """
    user_turns = (
        turn for turn in row.get('conversations', [])
        if turn.get('role') == 'user'
    )
    first_user = next(user_turns, None)
    if first_user is None:
        return ''
    return first_user.get('content', '')


def clean_generation(text: str) -> str:
    """Strip surrounding whitespace and an enclosing Markdown code fence.

    When the (stripped) generation starts with a backtick fence, the whole
    opening fence line is removed, whatever info string it carries (``rust``,
    ``c++``, a trailing space, nothing at all). A closing fence at the very end
    is removed as well. CRLF line endings around the fences are handled.
    Text that does not start with a fence is only whitespace-stripped.
    """
    out = text.strip()
    if out.startswith('```'):
        # Drop the entire opening fence line, not just [A-Za-z0-9_-] info strings;
        # a leftover "```c++" or "```rust " line would break compilation downstream.
        out = re.sub(r'^```[^\n]*\r?\n', '', out)
        out = re.sub(r'\r?\n```\s*$', '', out)
    return out.strip()


def build_model_and_tokenizer(seed: int):
    """Load the 4-bit base model, attach a fresh LoRA adapter, and set the chat template.

    Args:
        seed: forwarded as ``random_state`` to the LoRA initialisation.

    Returns:
        ``(peft_model, tokenizer)``; the tokenizer uses the 'qwen3-instruct' chat template.
    """
    lora_targets = [
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    ]

    base_model, base_tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=True,
        load_in_8bit=False,
        full_finetuning=False,
    )

    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=32,
        target_modules=lora_targets,
        lora_alpha=32,
        lora_dropout=0,
        bias='none',
        use_gradient_checkpointing='unsloth',
        random_state=seed,
        use_rslora=False,
        loftq_config=None,
    )

    chat_tokenizer = get_chat_template(base_tokenizer, chat_template='qwen3-instruct')
    return peft_model, chat_tokenizer


def format_dataset_for_sft(dataset_split, tokenizer):
    """Render every conversation into a single ``text`` column for SFT.

    The split is first normalised with ``standardize_data_formats``. Then each
    ``conversations`` entry is rendered through the tokenizer's chat template
    (untokenized, without a trailing generation prompt).
    """
    standardized = standardize_data_formats(dataset_split)

    def _render_batch(batch):
        rendered = []
        for conversation in batch['conversations']:
            rendered.append(
                tokenizer.apply_chat_template(
                    conversation,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            )
        return {'text': rendered}

    return standardized.map(_render_batch, batched=True)


def generate_once(model, tokenizer, messages):
    """Greedy-decode one assistant reply to ``messages`` and return it cleaned.

    Args:
        model: model supporting ``generate``; inputs are moved to the device of
            its first parameter.
        tokenizer: chat-template-aware tokenizer used to render and decode.
        messages: chat messages (list of ``{'role', 'content'}`` dicts).

    Returns:
        The newly generated tokens only (prompt excluded), decoded with special
        tokens skipped and passed through ``clean_generation``.
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    device = next(model.parameters()).device
    model_inputs = tokenizer(rendered, return_tensors='pt').to(device)
    prompt_len = model_inputs['input_ids'].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            # Plain greedy decoding. The sampling knobs are cleared to None rather
            # than set to values such as temperature=0.0. They have no effect when
            # do_sample=False, and non-default values (including any the model's
            # own generation config may carry) are flagged by transformers'
            # generation-config validation.
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = outputs[0][prompt_len:]
    return clean_generation(tokenizer.decode(new_tokens, skip_special_tokens=True))


def _build_sft_trainer(model, tokenizer, train_dataset, val_dataset, run_dir: Path, seed: int):
    """Create an SFTTrainer for one arm whose loss is masked to assistant turns."""
    trainer_kwargs = {
        'model': model,
        'train_dataset': train_dataset,
        'eval_dataset': val_dataset,
        'args': SFTConfig(
            dataset_text_field='text',
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_steps=5,
            num_train_epochs=NUM_TRAIN_EPOCHS,
            learning_rate=2e-4,
            logging_steps=10,
            eval_steps=50,
            eval_strategy='steps',
            save_steps=50,
            save_total_limit=2,
            optim='adamw_8bit',
            weight_decay=0.001,
            lr_scheduler_type='linear',
            seed=seed,
            report_to='none',
            output_dir=str(run_dir),
            bf16=is_bfloat16_supported(),
            fp16=not is_bfloat16_supported(),
        ),
    }

    # Older TRL releases take ``tokenizer``; newer ones take ``processing_class``.
    # Pass whichever keyword the installed SFTTrainer accepts.
    trainer_sig = inspect.signature(SFTTrainer.__init__)
    if 'tokenizer' in trainer_sig.parameters:
        trainer_kwargs['tokenizer'] = tokenizer
    elif 'processing_class' in trainer_sig.parameters:
        trainer_kwargs['processing_class'] = tokenizer

    trainer = SFTTrainer(**trainer_kwargs)

    # Train only on assistant responses; the markers are the turn headers of the
    # Qwen chat template applied in build_model_and_tokenizer.
    return train_on_responses_only(
        trainer,
        instruction_part='<|im_start|>user\n',
        response_part='<|im_start|>assistant\n',
    )


def _write_predictions(model, tokenizer, arm: str, seed: int, test_path: Path) -> Path:
    """Greedy-decode every test row and write one prediction per line to a JSONL file.

    Only the first user turn of each row is sent to the model.
    NOTE(review): if the training conversations include a system message, it is
    not reproduced here. Confirm this matches how the arm was trained.

    Returns:
        Path of the written ``predictions_<arm>_seed<seed>.jsonl`` under EVAL_ROOT.
    """
    test_rows = load_jsonl(test_path)
    pred_path = EVAL_ROOT / f'predictions_{arm}_seed{seed}.jsonl'

    with pred_path.open('w', encoding='utf-8') as f:
        for row in tqdm(test_rows, desc=f'Generating {arm} predictions seed={seed}'):
            prompt = get_user_prompt(row)
            prediction = generate_once(model, tokenizer, [{'role': 'user', 'content': prompt}])

            out_row = {
                'id': row.get('id'),
                'family': row.get('family'),
                'split': row.get('split', 'test'),
                'arm': arm,
                'prompt': prompt,
                'prediction': prediction,
            }
            f.write(json.dumps(out_row, ensure_ascii=True) + '\n')

    return pred_path


def train_and_predict_arm(arm: str, seed: int) -> dict[str, str]:
    """Fine-tune a LoRA adapter for one arm/seed, save it, and write test predictions.

    Args:
        arm: dataset arm name; reads ``<arm>_{train,val,test}.jsonl`` from DATA_DIR.
        seed: seed for LoRA initialisation and the trainer.

    Returns:
        ``{'adapter_dir': ..., 'predictions': ...}`` as string paths.
    """
    arm_train = DATA_DIR / f'{arm}_train.jsonl'
    arm_val = DATA_DIR / f'{arm}_val.jsonl'
    arm_test = DATA_DIR / f'{arm}_test.jsonl'

    run_name = f'qwen3-4b-redox-{arm}-seed{seed}'
    run_dir = TRAINING_ROOT / run_name
    adapter_dir = run_dir / 'final_adapter'

    model = tokenizer = trainer = None
    try:
        # The test split is loaded too, so malformed test data fails before training starts.
        ds = load_dataset(
            'json',
            data_files={
                'train': str(arm_train),
                'validation': str(arm_val),
                'test': str(arm_test),
            },
        )

        model, tokenizer = build_model_and_tokenizer(seed)

        train_dataset = format_dataset_for_sft(ds['train'], tokenizer)
        val_dataset = format_dataset_for_sft(ds['validation'], tokenizer)

        run_dir.mkdir(parents=True, exist_ok=True)
        trainer = _build_sft_trainer(model, tokenizer, train_dataset, val_dataset, run_dir, seed)

        print(f'\n=== Training {arm.upper()} arm @ seed {seed} ===')
        train_result = trainer.train()
        print(train_result)

        trainer.model.save_pretrained(str(adapter_dir))
        tokenizer.save_pretrained(str(adapter_dir))

        FastLanguageModel.for_inference(model)
        model.eval()

        pred_path = _write_predictions(model, tokenizer, arm, seed, arm_test)
    finally:
        # Drop this function's references to the model/trainer even when training
        # or generation raises. Otherwise the GPU memory they hold can make the
        # next arm (or a re-run of the cell) fail with an out-of-memory error.
        del trainer, model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return {
        'adapter_dir': str(adapter_dir),
        'predictions': str(pred_path),
    }


In [None]:
# Train and predict every arm for every seed. Results are keyed by str(seed),
# which is how the evaluation cell looks them up.
results_by_seed = {}
for seed in SEEDS:
    print(f'\n================ SEED {seed} ================')
    results_by_seed[str(seed)] = {
        arm: train_and_predict_arm(arm, seed)
        for arm in ARMS
    }

print('\nCompleted arm runs by seed:')
print(json.dumps(results_by_seed, indent=2))


In [None]:
import json
import subprocess
import sys
from pathlib import Path


def _run_repo_script(cmd: list[str], failure_message: str) -> None:
    """Run a repo helper script from REPO_DIR, echo its stdout, and raise on failure.

    stderr is printed only when the script exits non-zero.

    Raises:
        RuntimeError: with ``failure_message`` if the script's exit status is non-zero.
    """
    print('Running:', ' '.join(cmd))
    # cwd=REPO_DIR makes the relative 'scripts/...' paths resolve regardless of
    # the kernel's current directory.
    proc = subprocess.run(cmd, text=True, capture_output=True, check=False, cwd=REPO_DIR)
    print(proc.stdout)
    if proc.returncode != 0:
        print(proc.stderr)
        raise RuntimeError(failure_message)


# sys.executable runs the scripts with this kernel's interpreter and packages,
# not whichever `python3` happens to be first on PATH.
report_paths = []
for seed in SEEDS:
    seed_outputs = results_by_seed[str(seed)]
    report_path = EVAL_ROOT / f'report_seed{seed}.json'
    _run_repo_script(
        [
            sys.executable,
            'scripts/evaluate_predictions.py',
            '--rust', str(Path(seed_outputs['rust']['predictions'])),
            '--iron', str(Path(seed_outputs['iron']['predictions'])),
            '--redox-cmd', REDOX_CMD,
            '--out', str(report_path),
        ],
        f'Evaluation failed for seed {seed}',
    )
    report_paths.append(str(report_path))

# Aggregation is only run when there are at least two per-seed reports.
aggregate_report = None
if len(report_paths) >= 2:
    aggregate_report = str(EVAL_ROOT / 'report_aggregate_multiseed.json')
    _run_repo_script(
        [sys.executable, 'scripts/aggregate_eval_reports.py', *report_paths, '--out', aggregate_report],
        'Aggregate report generation failed',
    )

# Run metadata. The configuration keys (data version, model, hyperparameters)
# make a run reproducible from this file alone.
summary_path = EVAL_ROOT / 'run_summary.json'
summary_path.write_text(
    json.dumps(
        {
            'repo_dir': str(REPO_DIR),
            'data_dir': str(DATA_DIR),
            'data_version': DATA_VERSION,
            'base_model': BASE_MODEL,
            'arms': ARMS,
            'max_seq_length': MAX_SEQ_LENGTH,
            'max_new_tokens': MAX_NEW_TOKENS,
            'num_train_epochs': NUM_TRAIN_EPOCHS,
            'run_tag': RUN_TAG,
            'seeds': SEEDS,
            'training_root': str(TRAINING_ROOT),
            'eval_root': str(EVAL_ROOT),
            'results_by_seed': results_by_seed,
            'seed_reports': report_paths,
            'aggregate_report': aggregate_report,
        },
        indent=2,
    ) + '\n',
    encoding='utf-8',
)

print('Seed reports:')
for p in report_paths:
    print('-', p)
if aggregate_report:
    print('Aggregate report:', aggregate_report)
print('Saved run summary:', summary_path)


## Outputs

All outputs are stored under your Drive folder:

- Training artifacts: `training/all_in_one/<run_tag>/`
- Per-seed predictions and reports: `eval/all_in_one/<run_tag>/`
- Multi-seed aggregate report: `eval/all_in_one/<run_tag>/report_aggregate_multiseed.json` (written only when `SEEDS` has at least two entries)
- Run metadata: `eval/all_in_one/<run_tag>/run_summary.json`

You can compare the aggregate report directly against your previous baselines.
