# CRUST-bench Multi-Stage Fine-Tuning Pipeline

This notebook covers dataset loading, stage 1 fine-tuning on stubs, pseudo-labeling for full transpilation, stage 2 fine-tuning, and evaluation.

## 1. Setup Environment
```bash
# Create and activate conda environment
conda env create -f environment.yml
conda activate crustbench
```

## 2: Dataset Loading and Preprocessing
Below we define `CRUSTLoader`, a helper class that:
- **Locates** all project directories under `datasets/CBench`.
- **Reads** `.c` and `.h` files, filters out tests/examples, and **removes comments**.
- **Merges** each header into its matching `.c` source (the header text is embedded as a comment block at the top of the source), producing one self-contained snippet per source file.
- **Pairs** each cleaned C snippet with its Rust stub interface by matching filename stems.

This prepares raw data for instruction-style fine-tuning in chat format.

In [56]:
import json
import re
from pathlib import Path
from typing import List, Dict, Tuple

class CRUSTLoader:
    """Load CRUST-bench C projects and pair them with their Rust interface stubs.

    Layout read by this class:
      * ``<cbench_dir>/<project>/**/*.c`` and ``*.h`` -- C sources and headers.
      * ``<rbench_dir>/<project>/src/interfaces/*.rs`` -- Rust interface stubs.
      * ``<rbench_dir>/<project>/src/bin/*.rs`` -- Rust test drivers.
    """

    # Path components (relative to a project root) whose files are skipped.
    EXCLUDED_PARTS = frozenset(
        {"test", "tests", "example", "examples", "bin", "unity"}
    )

    # Single left-to-right pass over C text. String and char literals are
    # matched as well (and kept), so "//" or "/*" inside a literal is not
    # treated as a comment, and a "//" inside a block comment cannot swallow
    # the "*/" that closes it.
    _COMMENT_OR_LITERAL = re.compile(
        r"//[^\n]*"                 # line comment (its newline is kept)
        r"|/\*.*?\*/"               # block comment
        r'|"(?:\\.|[^"\\\n])*"'     # string literal
        r"|'(?:\\.|[^'\\\n])*'",    # character literal
        re.DOTALL,
    )

    def __init__(self, cbench_dir: str, rbench_dir: str):
        self.cbench = Path(cbench_dir)
        self.rbench = Path(rbench_dir)
        # Only directories under CBench are projects. Sorted because iterdir()
        # order is filesystem-dependent and export_jsonl splits by position.
        self.projects = sorted(p.name for p in self.cbench.iterdir() if p.is_dir())

    @staticmethod
    def remove_comments(code: str) -> str:
        """Strip C ``//`` and ``/* */`` comments, then collapse runs of newlines.

        String and character literals are preserved verbatim. A line comment
        is removed up to (not including) its newline, including a final line
        with no trailing newline. Block comments are removed entirely.
        """

        def _keep_literals(match: "re.Match[str]") -> str:
            token = match.group(0)
            # Comments start with "/", literals with a quote character.
            return "" if token.startswith("/") else token

        code = CRUSTLoader._COMMENT_OR_LITERAL.sub(_keep_literals, code)
        # collapse multiple blank lines
        return re.sub(r"\n+", "\n", code)

    def get_c_files(self, project: str) -> List[Dict]:
        """
        Load all .c and .h files (excluding tests/examples/bin) for one project.

        Returns a list of {'file_name': ..., 'content': ...} sorted by file
        name, with comments removed from each file's content.
        """
        records = []
        base = self.cbench / project
        for f in base.rglob("*"):
            if not f.is_file() or f.suffix not in {".c", ".h"}:
                continue
            # Only look at the path *below* the project root: a parent
            # directory of the dataset named e.g. "test" or "bin" must not
            # exclude every file.
            if self.EXCLUDED_PARTS.intersection(f.relative_to(base).parts):
                continue
            raw = f.read_text(encoding="utf-8", errors="ignore")
            records.append({"file_name": f.name, "content": self.remove_comments(raw)})
        # sort for deterministic output
        return sorted(records, key=lambda d: d["file_name"])

    def process_c_and_h(self, c_records: List[Dict]) -> List[Dict]:
        """
        Merge each header into its same-named .c file where both exist.

        A merged record is named after the .c file. Its content is
        "// from header\\n/*\\n" + header + "\\n*/\\n" + source. Headers without
        a matching .c are kept unchanged, and so are sources with no header.
        Returns a new list sorted by file name.
        """
        by_name = {r["file_name"]: r["content"] for r in c_records}
        out = []
        merged_sources = set()

        # Pass 1: headers -- merge into the matching source when present.
        for rec in c_records:
            name = rec["file_name"]
            if not name.endswith(".h"):
                continue
            source = name[:-2] + ".c"
            if source in by_name:
                # NOTE(review): the header text is wrapped in /* ... */, i.e. it
                # is embedded as a comment, not live code -- confirm intended.
                merged = (
                    "// from header\n/*\n"
                    + rec["content"]
                    + "\n*/\n"
                    + by_name[source]
                )
                out.append({"file_name": source, "content": merged})
                merged_sources.add(source)
            else:
                out.append(rec)

        # Pass 2: every non-header record that did not absorb a header.
        out.extend(
            rec for rec in c_records
            if not rec["file_name"].endswith(".h")
            and rec["file_name"] not in merged_sources
        )
        return sorted(out, key=lambda d: d["file_name"])

    @staticmethod
    def _load_rs_dir(directory: Path) -> List[Dict]:
        """Read every ``*.rs`` file directly inside ``directory`` (if it exists)."""
        if not directory.exists():
            return []
        recs = [
            {"file_name": f.name, "content": f.read_text(encoding="utf-8")}
            for f in directory.glob("*.rs")
        ]
        return sorted(recs, key=lambda d: d["file_name"])

    def load_rust_interfaces(self, project: str) -> List[Dict]:
        """
        Load all stub interfaces (*.rs in src/interfaces) with their raw content.
        """
        return self._load_rs_dir(self.rbench / project / "src" / "interfaces")

    def load_rust_bins(self, project: str) -> List[Dict]:
        """
        Load all test drivers (*.rs in src/bin) if you need them later.
        """
        return self._load_rs_dir(self.rbench / project / "src" / "bin")

    def get_interface_pairs(self) -> List[Tuple[str, str]]:
        """
        Return (c_content, rust_stub_content) pairs for every project.

        A (merged) C record and a Rust interface are paired when their file
        stems match (e.g. foo.c / foo.h with foo.rs). Projects are visited in
        self.projects order and stubs in file-name order.
        """
        pairs = []
        for proj in self.projects:
            c_recs = self.process_c_and_h(self.get_c_files(proj))
            # stem -> C content; a later record with the same stem wins.
            cmap = {Path(r["file_name"]).stem: r["content"] for r in c_recs}
            for ri in self.load_rust_interfaces(proj):
                stem = Path(ri["file_name"]).stem
                if stem in cmap:
                    pairs.append((cmap[stem], ri["content"]))
        return pairs

    def export_jsonl(self, out_train: str, out_valid: str, split: float = 0.9):
        """
        Write the C / Rust-stub pairs as train and valid JSONL files.

        Each line is {"ctext": <C code>, "rtext": <Rust stub>}. The first
        int(len(pairs) * split) pairs (in get_interface_pairs order, not
        shuffled) go to out_train and the rest to out_valid. Parent
        directories are created as needed.
        """
        exs = [{"ctext": c, "rtext": r} for c, r in self.get_interface_pairs()]
        cut = int(len(exs) * split)
        for path, batch in ((out_train, exs[:cut]), (out_valid, exs[cut:])):
            p = Path(path)
            p.parent.mkdir(parents=True, exist_ok=True)
            with p.open("w", encoding="utf-8") as f:
                f.writelines(json.dumps(e) + "\n" for e in batch)


In [None]:
# Build the loader over the CBench/RBench checkouts and write the default 90/10 split.
loader = CRUSTLoader(cbench_dir="datasets/CBench", rbench_dir="datasets/RBench")
loader.export_jsonl(out_train="train.jsonl", out_valid="valid.jsonl")

## 4. Convert to instruction-style (chat-style) data

In [None]:
import json
from pathlib import Path

def to_chat_example(line):
    """Convert one ``{"ctext", "rtext"}`` JSONL line into a chat fine-tuning example.

    Args:
        line: A single JSON-encoded line with ``ctext`` (C code) and
            ``rtext`` (Rust stub) keys.

    Returns:
        ``{"messages": [system, user, assistant]}``. The user turn holds the C
        code in a ```c fence and the assistant turn the Rust code in a
        ```rust fence.
    """
    ex = json.loads(line)
    ctext, rtext = ex["ctext"], ex["rtext"]
    # A fence needs a newline after its language tag and before the closing
    # backticks; otherwise the tag fuses with the first line of code
    # (e.g. "```c#include ..." or "```rustuse ...").
    return {
        "messages": [
            {"role": "system", "content": "You are a Rust expert tasked with translating C to Rust."},
            {"role": "user", "content": (
                "Translate the following C code into fully implemented, idiomatic, and compilable Rust.\n"
                "```c\n" + ctext + "\n```"
            )},
            {"role": "assistant", "content": "```rust\n" + rtext + "\n```"},
        ],
    }

def convert(in_path: Path, out_path: Path):
    """Rewrite a ``{"ctext", "rtext"}`` JSONL file as chat-format JSONL.

    Each non-blank input line is passed through ``to_chat_example`` and
    written as one JSON line to ``out_path``, which is overwritten.
    """
    with in_path.open(encoding="utf-8") as fin, out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            # Tolerate blank or trailing empty lines instead of crashing in json.loads.
            if not line.strip():
                continue
            fout.write(json.dumps(to_chat_example(line)) + "\n")

# Turn each raw split into its chat-format stage 1 file.
for split_in, split_out in (
    ("train.jsonl", "train_stage1.jsonl"),
    ("valid.jsonl", "valid_stage1.jsonl"),
):
    convert(Path(split_in), Path(split_out))

## 4.1 Stage 1: Upload Data

In [7]:
from dotenv import load_dotenv
from openai import OpenAI
import os

load_dotenv()
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Upload dataset. Context managers close each file handle once its upload
# request has returned (bare open() calls left them open).
with open('train_stage1.jsonl', 'rb') as fh:
    train_file = client.files.create(file=fh, purpose='fine-tune')
with open('valid_stage1.jsonl', 'rb') as fh:
    valid_file = client.files.create(file=fh, purpose='fine-tune')

print(f"Train file ID: {train_file.id}")
print(f"Valid file ID: {valid_file.id}")

Train file ID: file-Uc7nuJ3vq2a175zy9rSGNc
Valid file ID: file-QR2o1ETk6BFHM9GNERgZJi


## 4.2 Stage 1 Fine-Tuning (Stubs)
Start the stage 1 fine-tune on `gpt-4.1-nano-2025-04-14` using the train/valid files uploaded in the previous step.

In [8]:
# Start the stage 1 fine-tune on the uploaded train/valid files.
stage1_base_model = 'gpt-4.1-nano-2025-04-14'
job = client.fine_tuning.jobs.create(
    model=stage1_base_model,
    training_file=train_file.id,
    validation_file=valid_file.id,
)
print("Stage1 FT Job ID:", job.id)


Stage1 FT Job ID: ftjob-4LPBSZI2pisefYDjuonYaJSq


## 4.3: Monitoring Training Metrics
We poll the fine-tuning job's events to extract **`metrics`** events, which include:
- Training and validation loss per step
- Mean token-level accuracy metrics

Printing these in real time helps you verify convergence and detect issues early.

In [47]:
# 7. Monitor Fine-Tuning Job
# Replace JOB_ID with your actual fine-tune job ID
import time
import pandas as pd

def print_metrics(job_id, poll_interval: float = 30):
    """Poll a fine-tuning job and print new events and metrics until it ends.

    Uses the module-level OpenAI ``client``.

    Args:
        job_id: ID of the fine-tuning job to monitor.
        poll_interval: Seconds to sleep between polls (default 30).
    """
    # Without "cancelled", a cancelled job would be polled forever.
    terminal_states = {"succeeded", "failed", "cancelled"}
    seen_event_ids = set()
    while True:
        # Read the status *before* listing events, so the final poll still
        # prints the events emitted right before the job finished.
        status = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id).status
        evts = client.fine_tuning.jobs.list_events(
            fine_tuning_job_id=job_id, limit=100
        ).data
        for e in evts:
            # Each poll returns the most recent events again; report each only once.
            if e.id in seen_event_ids:
                continue
            seen_event_ids.add(e.id)
            print(f"Event: {e.type} - {e.created_at} - {e.message}")
            if e.type == "metrics" and e.data:
                print(f"New metrics event: {e.data}")
                print(f"Step {e.data['step']}: train_loss={e.data['train_loss']:.4f}, valid_loss={e.data.get('valid_loss', -1):.4f}")
        if status in terminal_states:
            break
        time.sleep(poll_interval)

In [48]:

# Stage 1 job ID (as printed when the job was launched).
job_id = "ftjob-4LPBSZI2pisefYDjuonYaJSq"
print_metrics(job_id=job_id)

Event: message - 1750191642 - The job has successfully completed
Event: message - 1750191635 - Usage policy evaluations completed, model is now enabled for sampling
Event: moderation_checks - 1750191635 - Moderation checks for snapshot ft:gpt-4.1-nano-2025-04-14:personal::BjWp3qnB passed.
Event: message - 1750190629 - Evaluating model against our usage policies
Event: message - 1750190629 - New fine-tuned model created
Event: message - 1750190629 - Checkpoint created at step 386
Event: message - 1750190629 - Checkpoint created at step 193
Event: metrics - 1750190581 - Step 579/579: training loss=0.20, validation loss=0.62, full validation loss=0.42
New metrics event: {'step': 579, 'train_loss': 0.20390063524246216, 'valid_loss': 0.6245212878211069, 'total_steps': 579, 'full_valid_loss': 0.4175889157885576, 'train_mean_token_accuracy': 0.9356725215911865, 'valid_mean_token_accuracy': 0.8135593220338984, 'full_valid_mean_token_accuracy': 0.8911336873701223}
Step 579: train_loss=0.2039, v

## 5.1 Stage 2: Pseudo-Labeling with `gpt-4.1-nano`
Here, we use the `gpt-4.1-nano` chat API to generate **full Rust implementations** from the C code plus the existing stub.  
- For each entry, we prompt with both the original C snippet and the Rust stub.
- The model returns a complete Rust translation, which we capture as a new training example.

## 5.2 Instruction-Style Fine-Tuning (Stage 2)
We now fine-tune a second `gpt-4.1-nano-2025-04-14` model on the fully transpiled examples (`full_transpile.jsonl`). The training prompts contain only the C code (no stub), which encourages the model to learn end-to-end C→Rust transpilation rather than stub filling.

In [None]:
def transpile(ctext: str, stub: str, *, model: str = "gpt-4.1-nano",
              max_tokens: int = 2048, temperature: float = 0) -> str:
    """Ask the chat model for a full Rust implementation of ``ctext``.

    The prompt includes both the C source and the existing Rust stub.
    Uses the module-level OpenAI ``client``.

    Args:
        ctext: C source code to translate.
        stub: Existing Rust interface stub for the same file.
        model: Chat model used for pseudo-labeling.
        max_tokens: Completion token budget.
        temperature: Sampling temperature (0 for the most deterministic output).

    Returns:
        The model's reply text (it may contain Markdown code fences).

    Warns:
        UserWarning if the reply was cut off at ``max_tokens``. A truncated
        reply would otherwise silently become a Stage 2 training target.
    """
    import warnings

    messages = [
        {"role": "system", "content": "You are a Rust expert."},
        {"role": "user", "content": (
            "Translate this C code into fully implemented, idiomatic Rust. "
            "Here’s the C code:\n```c\n"
            + ctext +
            "\n```\n"
            "And here’s the existing Rust stub:\n```rust\n"
            + stub +
            "\n```"
        )}
    ]
    resp = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    choice = resp.choices[0]
    if choice.finish_reason == "length":
        warnings.warn(
            f"transpile: completion truncated at max_tokens={max_tokens}",
            stacklevel=2,
        )
    return choice.message.content

input_path = Path("train.jsonl")
output_path = Path("full_transpile.jsonl")       # chat-format Stage 2 training examples
transpile_path = Path("train_transpile.jsonl")   # raw {ctext, rtext, full} records
entries = []
new_data = []

# Records are written as soon as they are produced. transpile() is a slow,
# billed API call per example, and previously nothing was written until the
# whole loop finished, so one exception midway discarded all results.
with input_path.open(encoding="utf-8") as fin, \
        transpile_path.open("w", encoding="utf-8") as raw_out, \
        output_path.open("w", encoding="utf-8") as chat_out:
    for line in fin:
        if not line.strip():
            continue
        ex = json.loads(line)
        ctext = ex["ctext"]
        stub = ex["rtext"]
        full = transpile(ctext, stub)

        record = {"ctext": ctext, "rtext": stub, "full": full}
        new_data.append(record)
        raw_out.write(json.dumps(record) + "\n")

        if not full:
            # An empty/None assistant message is not a usable training target.
            # The raw record above is kept for inspection.
            print("Skipping example with empty completion")
            continue

        # Chat-style example. The prompt carries only the C code (no stub), so
        # the model is trained on end-to-end translation.
        messages = [
            {"role": "system", "content": "You are a Rust expert."},
            {"role": "user", "content": (
                f"Translate this C code into fully implemented Rust.\n"
                f"```c\n{ctext}\n```\n"
            )},
            {"role": "assistant", "content": full},
        ]
        entry = {"messages": messages}
        entries.append(entry)
        chat_out.write(json.dumps(entry) + "\n")

print(f"Wrote {len(entries)} chat-style examples to {output_path}")


Wrote 193 chat-style examples to full_transpile.jsonl


## 5.3. Stage 2 Fine-Tune on Full Transpiled Data

In [49]:
# Upload full_transpile.jsonl. The context manager closes the file handle
# once the upload request has returned.
with open('full_transpile.jsonl', 'rb') as fh:
    train_transpile_file = client.files.create(file=fh, purpose='fine-tune')

print(f"Full Transpile Train file ID: {train_transpile_file.id}")

Full Transpile Train file ID: file-EzgWV7tQ5gUEXFmpzmBEEb


In [50]:
# Start the second-stage fine-tune. No validation_file is passed, so no
# validation loss is computed for this job.
stage2_base_model = 'gpt-4.1-nano-2025-04-14'
job = client.fine_tuning.jobs.create(
    model=stage2_base_model,
    training_file=train_transpile_file.id,
)
print("Stage2 FT Job ID:", job.id)

Stage2 FT Job ID: ftjob-sHK5yvlKmPHDQGZ5KD9fANb4


In [57]:
# Monitor the stage 2 job that was just launched.
job_id = job.id
print_metrics(job_id=job_id)

Event: message - 1750197201 - The job has successfully completed
Event: message - 1750197195 - Usage policy evaluations completed, model is now enabled for sampling
Event: moderation_checks - 1750197195 - Moderation checks for snapshot ft:gpt-4.1-nano-2025-04-14:personal::BjYGioVU passed.
Event: message - 1750196189 - Evaluating model against our usage policies
Event: message - 1750196189 - New fine-tuned model created
Event: message - 1750196188 - Checkpoint created at step 386
Event: message - 1750196188 - Checkpoint created at step 193
Event: metrics - 1750196180 - Step 579/579: training loss=0.07
New metrics event: {'step': 579, 'train_loss': 0.06870583444833755, 'total_steps': 579, 'train_mean_token_accuracy': 0.976097583770752}
Step 579: train_loss=0.0687, valid_loss=-1.0000
Event: metrics - 1750196180 - Step 578/579: training loss=0.23
New metrics event: {'step': 578, 'train_loss': 0.23106878995895386, 'total_steps': 579, 'train_mean_token_accuracy': 0.9442622661590576}
Step 578

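## 6: Evaluation
Finally, we evaluate a handful of examples for:
- **Quality**: a whitespace-token **BLEU** score against the reference Rust output. The references are the `gpt-4.1-nano` pseudo-labels from Stage 2, and the samples are taken from the Stage 2 training file. The score therefore measures fit to the training data rather than generalization to unseen code.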


In [61]:
import subprocess, tempfile
from nltk.translate.bleu_score import sentence_bleu

FT_MODEL_STAGE2 = 'ft:gpt-4.1-nano-2025-04-14:personal::BjYGioVU'  # stage 2 fine-tuned model (snapshot name from the stage 2 job log)

def bleu(ref: str, pred: str) -> float:
    """Return sentence-level BLEU-4 of ``pred`` against the single reference ``ref``.

    Both strings are tokenized on whitespace. NLTK's smoothing method1 is
    applied so the score does not collapse to ~0 (with a warning) when some
    n-gram order has no match. If either input has no tokens, returns 0.0.
    """
    from nltk.translate.bleu_score import SmoothingFunction

    ref_tokens, pred_tokens = ref.split(), pred.split()
    if not ref_tokens or not pred_tokens:
        return 0.0
    return sentence_bleu(
        [ref_tokens],
        pred_tokens,
        weights=(0.25,) * 4,
        smoothing_function=SmoothingFunction().method1,
    )

# Evaluate the first 5 samples.
# NOTE(review): these come from full_transpile.jsonl, the Stage 2 *training*
# file, and the reference answers are the transpile() pseudo-labels rather
# than ground truth. The scores therefore measure fit to the training data.
with open('full_transpile.jsonl', encoding='utf-8') as fin:
    samples = [json.loads(line) for line in fin][:5]
results = []
for s in samples:
    resp = client.chat.completions.create(
        model=FT_MODEL_STAGE2,
        messages=s['messages'][:2],  # system + user prompt only
        # Same budget as transpile() used to produce the references, so long
        # translations are not truncated and penalized by BLEU's brevity penalty.
        max_tokens=2048,
        temperature=0,
    )
    gen = (resp.choices[0].message.content or "").strip()
    results.append({
        'bleu': bleu(s['messages'][2]['content'], gen)
    })

print(results)


[{'bleu': 0.4502649507952619}, {'bleu': 0.21063448590878106}, {'bleu': 0.2392737051184368}, {'bleu': 0.2760314019175371}, {'bleu': 0.2577859944295302}]
