## Project Description: Vision-to-Code Baseline

This project implements a baseline model for translating images into CadQuery code using a vision-language architecture. The baseline combines a pre-trained ResNet-50 image encoder with a GPT-2-style language model trained from scratch to generate CadQuery code from images.
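The image conditioning works like a single-token prefix. The pooled ResNet-50 feature vector is projected to the decoder's embedding width and prepended to the token embeddings as one extra position. The NumPy sketch below shows only the shapes involved. Random matrices stand in for the trained weights. The widths 2048 and 1024 match the notebook, while the batch size, sequence length and toy vocabulary are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, seq_len = 2, 5          # illustrative batch size and number of code tokens
feat_dim, n_embd = 2048, 1024  # ResNet-50 pooled features -> GPT-2 embedding width
vocab = 100                    # toy vocabulary (the notebook trains a 32,768-token BPE)

# Stand-ins for learned parameters (random here, trained in the notebook)
W_proj = rng.normal(scale=0.02, size=(feat_dim, n_embd))
wte = rng.normal(scale=0.02, size=(vocab, n_embd))  # token embedding table

image_features = rng.normal(size=(batch, feat_dim))        # pooled encoder output
token_ids = rng.integers(0, vocab, size=(batch, seq_len))  # BOS + code tokens

prefix = (image_features @ W_proj)[:, None, :]  # (batch, 1, n_embd)
tok_emb = wte[token_ids]                        # (batch, seq_len, n_embd)
inputs_embeds = np.concatenate([prefix, tok_emb], axis=1)

print(inputs_embeds.shape)  # (2, 6, 1024): one visual position + 5 token positions
```

The decoder then attends to the visual position exactly as it would to an ordinary token, so GPT-2 needs no architectural change.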

Key Features:

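Dataset: Uses the CADCODER/GenCAD-Code dataset, which pairs CAD images with their CadQuery code (147,289 training, 8,204 validation, and 7,355 test examples).

Image Encoder: Extracts image features with an ImageNet-pretrained ResNet-50 whose backbone is frozen by default (freeze_vision=True).

Code Decoder: Generates CadQuery code with an 8-layer GPT-2-style decoder, using a byte-level BPE tokenizer (32,768-token target vocabulary) trained on the dataset's code.

Training: Supports mixed-precision training (AMP) and dynamically padded batches. Evaluation uses two metrics. Valid Syntax Rate (VSR) is the fraction of generated scripts that execute and yield a CadQuery solid. IoU is the voxel overlap between the generated and reference solids after normalization and principal-axis alignment.

Reproducibility: Seeds Python's and PyTorch's random number generators and saves the trained weights to vision2cadquery.pt. Seeding alone does not make GPU training fully deterministic.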

The baseline is intended as a starting point for research and development in image-to-code translation for CAD applications.
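Both evaluation metrics reduce to simple ratios. The minimal sketch below computes each one on hand-made toy data, and the pass/fail flags and voxel sets are illustrative rather than notebook results. VSR is the fraction of scripts that executed. IoU is intersection over union of the occupied voxel indices. In the notebook, the voxel grids come from voxelizing the normalized meshes with trimesh.

```python
from typing import Dict, Set, Tuple

Voxel = Tuple[int, int, int]


def valid_syntax_rate(executed: Dict[str, bool]) -> float:
    """Fraction of scripts that ran and produced a CadQuery object."""
    return sum(executed.values()) / len(executed) if executed else 0.0


def voxel_iou(a: Set[Voxel], b: Set[Voxel]) -> float:
    """|A & B| / |A | B| over occupied voxel indices (0.0 if both are empty)."""
    union = a | b
    return len(a & b) / len(union) if union else 0.0


# Toy example: a 2x2x2 block versus the same block shifted by one voxel along x
block = {(x, y, z) for x in range(2) for y in range(2) for z in range(2)}
shifted = {(x + 1, y, z) for (x, y, z) in block}

print(valid_syntax_rate({"a": True, "b": False, "c": True, "d": True}))  # 0.75
print(voxel_iou(block, shifted))  # 4 shared / 12 occupied = 0.333...
```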


## Bottlenecks

The primary bottlenecks I encountered in this project were related to both environment setup and compute resources:

Environment Configuration Issues: I ran into dependency conflicts because I had not checked the version constraints in pyproject.toml carefully. Resolving them and configuring the environment took almost three hours, which cut significantly into the time available for model experimentation and tuning. The lesson is to pin and verify dependency versions early in machine learning projects, especially when integrating multiple libraries.

GPU and Compute Limitations: My experiments were restricted to using Google Colab, which provides limited GPU resources. This limited the size and complexity of models I could use, as well as the number of experiments I could run within the available time. Training large image-to-code models or running extensive hyperparameter searches was not feasible in this environment.

Together, these bottlenecks—setup time and restricted compute—were the main constraints on what I could achieve in this project.
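For the environment issue, one cheap guard is to compare installed versions against the minimums you intend to use before running anything heavy. The stdlib sketch below does this. The REQUIREMENTS dict copies the minimums from this notebook's pip install cell, which is assumed to reflect the intended setup. The version parser is deliberately naive and handles only dotted numeric releases. For full PEP 440 handling, packaging.version is the usual choice.

```python
from importlib import metadata
from typing import Dict, List, Tuple

# Minimum versions copied from this notebook's pip install cell
REQUIREMENTS: Dict[str, str] = {
    "cadquery": "2.5.2",
    "datasets": "3.6.0",
    "ipykernel": "6.29.5",
    "scipy": "1.15.3",
    "trimesh": "4.6.11",
}


def parse_version(v: str) -> Tuple[int, ...]:
    """Naive parser: keep leading numeric components, e.g. '4.6.12' -> (4, 6, 12)."""
    parts: List[int] = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # stop at suffixes such as 'rc1' or 'post1'
            break
    return tuple(parts)


def version_lt(a: str, b: str) -> bool:
    """True if version a < version b, treating missing components as zero."""
    pa, pb = parse_version(a), parse_version(b)
    width = max(len(pa), len(pb))
    return pa + (0,) * (width - len(pa)) < pb + (0,) * (width - len(pb))


def check_minimums(reqs: Dict[str, str]) -> List[str]:
    """Return human-readable problems; an empty list means every minimum is met."""
    problems: List[str] = []
    for name, minimum in reqs.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed (need >= {minimum})")
            continue
        if version_lt(installed, minimum):
            problems.append(f"{name}: {installed} installed, need >= {minimum}")
    return problems


for line in check_minimums(REQUIREMENTS) or ["all minimum versions satisfied"]:
    print(line)
```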

## Possible Enhancements with More Time

If I had more time and resources, I would focus on the following enhancements:

Experiment with Larger Decoder Models: I would test bigger and more expressive decoder architectures to potentially improve the accuracy of CadQuery code generation.

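Train for Longer: The current configuration runs a single epoch, and the logged run was interrupted before that epoch finished. More epochs would give the model more opportunity to learn from the dataset.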

Hyperparameter Tuning: I would try different learning rates, batch sizes, and optimizers to optimize performance.

Improve Code Organization: I would refactor the codebase to be more organized and modular, making it easier to understand, maintain, and extend in the future.

Reduce Redundancy: I would clean up the code to eliminate repeated or unnecessary components, ensuring more efficient and readable scripts.
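For the hyperparameter tuning item above, one low-effort pattern is to derive sweep configurations from a Config dataclass using dataclasses.replace. Each run then differs only in the swept fields. The sketch below is illustrative. The Config here is a hypothetical, trimmed copy of the notebook's Config, and the grid values are examples rather than tuned recommendations.

```python
import itertools
from dataclasses import dataclass, replace
from typing import Dict, Iterator, List


@dataclass(frozen=True)
class Config:
    # Hypothetical trimmed copy of the notebook's Config (only swept fields shown)
    learning_rate: float = 3e-4
    batch_size: int = 2
    weight_decay: float = 1e-2
    epochs: int = 1


def grid(base: Config, space: Dict[str, List]) -> Iterator[Config]:
    """Yield one Config per point in the Cartesian product of `space`."""
    keys = list(space)
    for values in itertools.product(*(space[k] for k in keys)):
        yield replace(base, **dict(zip(keys, values)))


space = {"learning_rate": [1e-4, 3e-4, 1e-3], "batch_size": [2, 8]}
configs = list(grid(Config(epochs=3), space))
for c in configs:
    print(c)
print(len(configs), "runs")  # 3 learning rates x 2 batch sizes = 6
```

On limited Colab compute, a small grid like this, with a short epoch budget per run, is usually more practical than an exhaustive search.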



In [None]:
!pip install "cadquery>=2.5.2" \
             "datasets>=3.6.0" \
             "ipykernel>=6.29.5" \
             "scipy>=1.15.3" \
             "trimesh>=4.6.11"



In [None]:
import argparse, importlib.util, runpy, tempfile, itertools, sys
from pathlib import Path
import os
import cadquery as cq
from cadquery import exporters
import numpy as np
import trimesh
from typing import Union
import textwrap

os.environ["CADQUERY_LOG_LEVEL"] = "ERROR"


# ---------- helpers ---------------------------------------------------------


def _load_solid_from_code(
    code: str, script_id: str = "unknown"
) -> Union[cq.Solid, cq.Compound]:
    """Execute Python code and return any CadQuery object found."""
    # Clean up indentation issues
    cleaned_code = textwrap.dedent(code).strip()

    # Provide necessary imports in the execution namespace
    ns = {"cq": cq, "cadquery": cq, "np": np, "numpy": np, "__builtins__": __builtins__}
    try:
        exec(cleaned_code, ns)
    except Exception as e:
        raise ValueError(f"Error executing script {script_id}: {e}")

    # Find any CadQuery objects in the namespace
    cadquery_objects = []
    for var_name, var_value in ns.items():
        if isinstance(var_value, (cq.Workplane, cq.Solid, cq.Compound)):
            cadquery_objects.append((var_name, var_value))

    if not cadquery_objects:
        raise ValueError(
            f"No CadQuery objects (Workplane, Solid, or Compound) found in script {script_id}"
        )

    if len(cadquery_objects) > 1:
        # If multiple objects, prefer common names
        preferred_names = ["solid", "result", "shape", "part", "object", "obj", "res"]
        for preferred in preferred_names:
            for var_name, var_value in cadquery_objects:
                if var_name == preferred:
                    cadquery_objects = [(var_name, var_value)]
                    break
            if len(cadquery_objects) == 1:
                break

        # If still multiple, just take the first one but warn
        if len(cadquery_objects) > 1:
            var_names = [name for name, _ in cadquery_objects]
            print(
                f"Warning: Multiple CadQuery objects found in {script_id}: {var_names}. Using '{cadquery_objects[0][0]}'"
            )

    var_name, solid_obj = cadquery_objects[0]
    # print(
    #     f"Found CadQuery object in variable '{var_name}' of type {type(solid_obj).__name__}"
    # )

    # Handle different CadQuery object types
    if isinstance(solid_obj, cq.Workplane):
        # Extract the solid from the workplane
        solid_obj = solid_obj.val()

    # Handle Compound objects (multiple solids combined)
    if hasattr(solid_obj, "Solids") and callable(getattr(solid_obj, "Solids")):
        solids = solid_obj.Solids()
        if len(solids) == 1:
            solid_obj = solids[0]
        elif len(solids) > 1:
            # If multiple solids, we need to combine them into one
            # Use the compound itself if it's valid for our purposes
            pass  # Keep the compound as is
        else:
            raise ValueError(f"No solids found in compound in script {script_id}")

    # Accept both Solid and Compound objects for our mesh operations
    if not isinstance(solid_obj, (cq.Solid, cq.Compound)):
        raise ValueError(
            f"CadQuery object '{var_name}' is not a Solid or Compound object in script {script_id}, got {type(solid_obj)}"
        )

    return solid_obj


def _load_solid(script_path: Path) -> cq.Solid:
    """Import a CadQuery script in isolation and return the 'solid' object."""
    ns = runpy.run_path(script_path)  # executes the file
    if "solid" not in ns or not isinstance(ns["solid"], cq.Solid):
        raise ValueError(f"'solid' not found in {script_path}")
    return ns["solid"]


def _root_gyration(solid: Union[cq.Solid, cq.Compound]) -> float:
    vol = solid.Volume()
    inertia = np.array(cq.Shape.matrixOfInertia(solid)).reshape(3, 3)
    return np.sqrt(np.trace(inertia) / (2.0 * vol))


def _normalized_mesh(
    solid: Union[cq.Solid, cq.Compound], pitch: float = 0.01
) -> trimesh.Trimesh:
    """Translate to centroid, isotropically scale by r_g, and return a mesh."""
    r_g = _root_gyration(solid)
    center_vector = solid.Center()
    centroid = np.array([center_vector.x, center_vector.y, center_vector.z])
    # Export to temporary STL then load with trimesh
    with tempfile.TemporaryDirectory() as tmp:
        stl_path = Path(tmp) / "part.stl"
        exporters.export(solid, str(stl_path))
        mesh = trimesh.load(str(stl_path), force="mesh")
    mesh.apply_translation(-centroid)
    mesh.apply_scale(1.0 / r_g)
    return mesh


def _principal_axes(mesh: trimesh.Trimesh) -> np.ndarray:
    """Return 3×3 orthonormal matrix whose columns are principal axes."""
    inertia = mesh.moment_inertia
    _, vecs = np.linalg.eigh(inertia)
    return vecs  # columns are eigenvectors


def _apply_rotation(mesh: trimesh.Trimesh, R: np.ndarray) -> trimesh.Trimesh:
    T = np.eye(4)
    T[:3, :3] = R
    mesh_rot = mesh.copy()
    mesh_rot.apply_transform(T)
    return mesh_rot


def _voxel_bool_unified(
    mesh1: trimesh.Trimesh, mesh2: trimesh.Trimesh, pitch: float = 0.05
) -> tuple[np.ndarray, np.ndarray]:
    """Create voxel grids for both meshes using unified bounds."""
    # Voxelize each mesh individually first
    voxel1 = mesh1.voxelized(pitch)
    voxel2 = mesh2.voxelized(pitch)

    # Get the bounds of each voxel grid
    bounds1 = voxel1.bounds
    bounds2 = voxel2.bounds

    # Compute unified bounds
    min_bounds = np.minimum(bounds1[0], bounds2[0])
    max_bounds = np.maximum(bounds1[1], bounds2[1])

    # Calculate grid dimensions
    grid_size = np.ceil((max_bounds - min_bounds) / pitch).astype(int)

    # Create empty unified voxel grids
    vox1 = np.zeros(grid_size, dtype=bool)
    vox2 = np.zeros(grid_size, dtype=bool)

    # Calculate offsets for placing each voxel grid in the unified space
    offset1 = np.round((bounds1[0] - min_bounds) / pitch).astype(int)
    offset2 = np.round((bounds2[0] - min_bounds) / pitch).astype(int)

    # Get shapes of individual voxel matrices
    shape1 = voxel1.matrix.shape
    shape2 = voxel2.matrix.shape

    # Calculate end positions
    end1 = offset1 + shape1
    end2 = offset2 + shape2

    # Place voxels in unified grids with bounds checking
    if np.all(offset1 >= 0) and np.all(end1 <= grid_size):
        vox1[offset1[0] : end1[0], offset1[1] : end1[1], offset1[2] : end1[2]] = (
            voxel1.matrix
        )

    if np.all(offset2 >= 0) and np.all(end2 <= grid_size):
        vox2[offset2[0] : end2[0], offset2[1] : end2[1], offset2[2] : end2[2]] = (
            voxel2.matrix
        )

    return vox1, vox2


def _voxel_bool(mesh: trimesh.Trimesh, pitch: float = 0.05) -> np.ndarray:
    vox = mesh.voxelized(pitch)
    return vox.matrix  # boolean 3-D numpy array


def iou_best(
    mesh_gt: trimesh.Trimesh, mesh_pred: trimesh.Trimesh, pitch: float = 0.05
) -> float:
    """IOU after best principal-axis alignment (4 valid sign flips)."""
    axes_gt = _principal_axes(mesh_gt)
    axes_pr = _principal_axes(mesh_pred)

    best = 0.0
    for signs in [(1, 1, 1), (1, 1, -1), (1, -1, 1), (-1, 1, 1)]:
        D = np.diag(signs)
        axes_pr_flipped = axes_pr @ D  # change axis directions
        R = axes_gt @ axes_pr_flipped.T  # rotation to align
        m_aligned = _apply_rotation(mesh_pred, R)

        # Use unified voxelization
        vox_gt, vox_pr = _voxel_bool_unified(mesh_gt, m_aligned, pitch)

        inter = np.logical_and(vox_gt, vox_pr).sum()
        union = np.logical_or(vox_gt, vox_pr).sum()

        if union > 0:
            iou = inter / union
            best = max(best, iou)

    return best


# ---------- main ------------------------------------------------------------


def evaluate_codes(gt_codes: dict, pred_codes: dict, pitch: float = 0.05):
    """Evaluate predictions against ground-truth using Python code directly.

    Args:
        gt_codes: Dict with IDs as keys and ground-truth Python code as values
        pred_codes: Dict with IDs as keys and prediction Python code as values
        pitch: Voxel pitch for IoU calculation
    """
    ids = sorted(gt_codes.keys())
    if not ids:
        sys.exit("no ground-truth scripts provided")

    vsr_success = 0
    ious = []

    for _id in ids:
        if _id not in pred_codes:
            print(f"missing prediction for {_id}, skipping")
            continue

        try:
            solid_gt = _load_solid_from_code(gt_codes[_id], f"gt_{_id}")
            solid_pr = _load_solid_from_code(pred_codes[_id], f"pred_{_id}")
            vsr_success += 1
        except Exception as exc:
            print(f"{_id}: syntax/runtime error -> {exc}")
            continue

        mesh_gt = _normalized_mesh(solid_gt)
        mesh_pr = _normalized_mesh(solid_pr)
        ious.append(iou_best(mesh_gt, mesh_pr, pitch))

    n_total = len(ids)
    vsr = vsr_success / n_total if n_total else 0.0
    iou_b = np.mean(ious) if ious else 0.0

    print(f"Valid Syntax Rate: {vsr:.3f}")
    print(f"Mean IOU_best   : {iou_b:.3f}")

    return {"vsr": vsr, "iou_best": iou_b}


def evaluate(gt_dir: Path, pred_dir: Path, pitch: float = 0.05):
    """Original file-based evaluation function."""
    ids = sorted(p.stem for p in gt_dir.glob("*.py"))
    if not ids:
        sys.exit("no ground-truth scripts found")

    vsr_success = 0
    ious = []

    for _id in ids:
        gt_path = gt_dir / f"{_id}.py"
        pr_path = pred_dir / f"{_id}.py"
        if not pr_path.exists():
            print(f"missing prediction for {_id}, skipping")
            continue

        try:
            solid_gt = _load_solid(gt_path)
            solid_pr = _load_solid(pr_path)
            vsr_success += 1
        except Exception as exc:
            print(f"{_id}: syntax/runtime error -> {exc}")
            continue

        mesh_gt = _normalized_mesh(solid_gt)
        mesh_pr = _normalized_mesh(solid_pr)
        ious.append(iou_best(mesh_gt, mesh_pr, pitch))

    n_total = len(ids)
    vsr = vsr_success / n_total if n_total else 0.0
    iou_b = np.mean(ious) if ious else 0.0

    print(f"Valid Syntax Rate: {vsr:.3f}")
    print(f"Mean IOU_best   : {iou_b:.3f}")


def get_iou_best(code1: str, code2: str):
    solid1 = _load_solid_from_code(code1)
    solid2 = _load_solid_from_code(code2)
    mesh1 = _normalized_mesh(solid1)
    mesh2 = _normalized_mesh(solid2)
    iou = iou_best(mesh1, mesh2)
    return iou


if __name__ == "__main__":
    code1 = """
        height = 60.0
        width = 80.0
        thickness = 10.0
        res = cq.Workplane("XY").box(height, width, thickness)
    """
    code2 = """
        height = 60.0
 width = 80.0
 thickness = 10.0
 diameter = 22.0
 padding = 12.0

 # make the base
 result = (
     cq.Workplane("XY")
     .box(height, width, thickness)
     .faces(">Z")
     .workplane()
     .hole(diameter)
     .faces(">Z")
     .workplane()
     .rect(height - padding, width - padding, forConstruction=True)
     .vertices()
     .cboreHole(2.4, 4.4, 2.1)
 )
    """
    solid1 = _load_solid_from_code(code1)
    solid2 = _load_solid_from_code(code2)
    mesh1 = _normalized_mesh(solid1)
    mesh2 = _normalized_mesh(solid2)
    iou = iou_best(mesh1, mesh2)
    print(f"IOU: {iou}")


IOU: 0.5834943417057687
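A note on iou_best. The eigenvector signs returned by np.linalg.eigh are arbitrary, so the four fixed sign patterns it tries can include reflections (det R = -1). A reflection mirrors the predicted shape instead of rotating it. One way to restrict the search to proper rotations is to try all eight sign patterns and keep only those with det R = +1. Because det R = det(axes_gt) · det(D) · det(axes_pr), exactly four survive.

The sketch below checks this on a synthetic point cloud using NumPy only. It makes two assumptions. First, a point cloud stands in for the mesh. Second, the covariance stands in for the inertia tensor, which is valid because the two share eigenvectors.

```python
import itertools

import numpy as np


def proper_alignments(axes_gt: np.ndarray, axes_pr: np.ndarray) -> list:
    """All R = axes_gt @ (axes_pr @ D).T with D = diag(+-1, +-1, +-1) and det(R) = +1."""
    rotations = []
    for signs in itertools.product((1, -1), repeat=3):
        R = axes_gt @ (axes_pr @ np.diag(signs)).T
        if np.linalg.det(R) > 0:  # keep rotations, drop reflections
            rotations.append(R)
    return rotations


def principal_axes(points: np.ndarray) -> np.ndarray:
    """Columns are eigenvectors of the centered scatter matrix (ascending eigenvalues)."""
    centered = points - points.mean(axis=0)
    _, vecs = np.linalg.eigh(centered.T @ centered)
    return vecs


rng = np.random.default_rng(0)
# Anisotropic cloud so the three principal axes are well separated
gt = rng.normal(size=(2000, 3)) * np.array([3.0, 2.0, 1.0])

# Random proper rotation via QR, with the determinant forced to +1
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(Q) < 0:
    Q[:, 0] *= -1
pred = gt @ Q.T  # the "prediction" is the ground truth rotated by Q

cands = proper_alignments(principal_axes(gt), principal_axes(pred))
errors = [np.abs(pred @ R.T - gt).max() for R in cands]
print(len(cands), "proper rotations; best max error:", min(errors))
```

One of the four candidates undoes Q up to floating-point error. Taking the best IoU over these four would therefore compare shapes only up to rotation, not reflection.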


In [None]:
import sys
import os
import cadquery as cq
import numpy as np
import textwrap
from typing import Union, Dict, List

os.environ["CADQUERY_LOG_LEVEL"] = "ERROR"


def _load_solid_from_code(
    code: str, script_id: str = "unknown"
) -> Union[cq.Solid, cq.Compound]:
    """Execute Python code and return any CadQuery object found."""
    # Clean up indentation issues
    cleaned_code = textwrap.dedent(code).strip()

    # Provide necessary imports in the execution namespace
    ns = {"cq": cq, "cadquery": cq, "np": np, "numpy": np, "__builtins__": __builtins__}
    try:
        exec(cleaned_code, ns)
    except Exception as e:
        raise ValueError(f"Error executing script {script_id}: {e}")

    # Find any CadQuery objects in the namespace
    cadquery_objects = []
    for var_name, var_value in ns.items():
        if isinstance(var_value, (cq.Workplane, cq.Solid, cq.Compound)):
            cadquery_objects.append((var_name, var_value))

    if not cadquery_objects:
        raise ValueError(
            f"No CadQuery objects (Workplane, Solid, or Compound) found in script {script_id}"
        )

    if len(cadquery_objects) > 1:
        # If multiple objects, prefer common names
        preferred_names = ["solid", "result", "shape", "part", "object", "obj", "res"]
        for preferred in preferred_names:
            for var_name, var_value in cadquery_objects:
                if var_name == preferred:
                    cadquery_objects = [(var_name, var_value)]
                    break
            if len(cadquery_objects) == 1:
                break

        # If still multiple, just take the first one but warn
        if len(cadquery_objects) > 1:
            var_names = [name for name, _ in cadquery_objects]
            print(
                f"Warning: Multiple CadQuery objects found in {script_id}: {var_names}. Using '{cadquery_objects[0][0]}'"
            )

    var_name, solid_obj = cadquery_objects[0]

    # Handle different CadQuery object types
    if isinstance(solid_obj, cq.Workplane):
        # Extract the solid from the workplane
        solid_obj = solid_obj.val()

    # Handle Compound objects (multiple solids combined)
    if hasattr(solid_obj, "Solids") and callable(getattr(solid_obj, "Solids")):
        solids = solid_obj.Solids()
        if len(solids) == 1:
            solid_obj = solids[0]
        elif len(solids) > 1:
            # If multiple solids, we need to combine them into one
            # Use the compound itself if it's valid for our purposes
            pass  # Keep the compound as is
        else:
            raise ValueError(f"No solids found in compound in script {script_id}")

    # Accept both Solid and Compound objects for our mesh operations
    if not isinstance(solid_obj, (cq.Solid, cq.Compound)):
        raise ValueError(
            f"CadQuery object '{var_name}' is not a Solid or Compound object in script {script_id}, got {type(solid_obj)}"
        )

    return solid_obj


def evaluate_syntax_rate(
    codes: Dict[str, str], verbose: bool = True
) -> Dict[str, Union[float, int, List[str]]]:
    """Evaluate valid syntax rate for a dictionary of CadQuery code strings.

    Args:
        codes: Dict with IDs as keys and Python code strings as values
        verbose: Whether to print detailed results

    Returns:
        Dict with 'vsr' (valid syntax rate), 'successful' (count), 'total' (count),
        'failed_ids' (list of IDs that failed)
    """
    if not codes:
        if verbose:
            print("No code provided")
        return {"vsr": 0.0, "successful": 0, "total": 0, "failed_ids": []}

    ids = sorted(codes.keys())
    successful_count = 0
    failed_ids = []

    for script_id in ids:
        code = codes[script_id]
        try:
            solid = _load_solid_from_code(code, script_id)
            successful_count += 1
            if verbose:
                print(f"✓ {script_id}: Successfully executed")
        except Exception as exc:
            failed_ids.append(script_id)
            if verbose:
                print(f"✗ {script_id}: {exc}")

    total_count = len(ids)
    vsr = successful_count / total_count if total_count > 0 else 0.0

    if verbose:
        print(f"\n--- SUMMARY ---")
        print(f"Successful: {successful_count}/{total_count}")
        print(f"Valid Syntax Rate: {vsr:.3f}")
        if failed_ids:
            print(f"Failed IDs: {failed_ids}")

    return {
        "vsr": vsr,
        "successful": successful_count,
        "total": total_count,
        "failed_ids": failed_ids,
    }


def evaluate_syntax_rate_simple(codes: Dict[str, str]) -> float:
    """Simple function that just returns the valid syntax rate as a float."""
    result = evaluate_syntax_rate(codes, verbose=False)
    return result["vsr"]


if __name__ == "__main__":
    # Test cases
    test_codes = {
        "simple_box": """
            height = 60.0
            width = 80.0
            thickness = 10.0
            result = cq.Workplane("XY").box(height, width, thickness)
        """,
        "box_with_hole": """
            height = 60.0
            width = 80.0
            thickness = 10.0
            diameter = 22.0
            padding = 12.0

            # make the base
            result = (
                cq.Workplane("XY")
                .box(height, width, thickness)
                .faces(">Z")
                .workplane()
                .hole(diameter)
                .faces(">Z")
                .workplane()
                .rect(height - padding, width - padding, forConstruction=True)
                .vertices()
                .cboreHole(2.4, 4.4, 2.1)
            )
        """,
        "syntax_error": """
            result = cq.Workplane("XY").box(10, 10, 10
            # Missing closing parenthesis
        """,
        "runtime_error": """
            result = cq.Workplane("XY").box(undefined_variable, 10, 10)
        """,
        "no_cadquery_object": """
            x = 5
            y = 10
            z = x + y
        """,
    }

    print("Testing Valid Syntax Rate evaluation:")
    print("=" * 50)

    result = evaluate_syntax_rate(test_codes)
    print(f"\nOverall VSR: {result['vsr']:.1%}")


Testing Valid Syntax Rate evaluation:
✓ box_with_hole: Successfully executed
✗ no_cadquery_object: No CadQuery objects (Workplane, Solid, or Compound) found in script no_cadquery_object
✗ runtime_error: Error executing script runtime_error: name 'undefined_variable' is not defined
✓ simple_box: Successfully executed
✗ syntax_error: Error executing script syntax_error: '(' was never closed (<string>, line 1)

--- SUMMARY ---
Successful: 2/5
Valid Syntax Rate: 0.400
Failed IDs: ['no_cadquery_object', 'runtime_error', 'syntax_error']

Overall VSR: 40.0%
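The Valid Syntax Rate computed above actually measures whether a script executes and yields a CadQuery object. As a result, syntax errors, runtime errors and scripts with no CadQuery result all count as the same kind of failure. When diagnosing a model, it can help to separate them.

The stdlib-only sketch below compiles each script first and executes it only if compilation succeeds. The category names are hypothetical. The has_result check is a stand-in for the CadQuery object lookup in _load_solid_from_code.

```python
from typing import Callable, Dict


def classify_script(code: str, has_result: Callable[[dict], bool]) -> str:
    """Return 'syntax_error', 'runtime_error', 'no_result', or 'ok'."""
    try:
        compiled = compile(code, "<generated>", "exec")  # parse only; nothing runs yet
    except SyntaxError:
        return "syntax_error"
    ns: dict = {}
    try:
        exec(compiled, ns)
    except Exception:
        return "runtime_error"
    return "ok" if has_result(ns) else "no_result"


def has_result(ns: dict) -> bool:
    # Stand-in for "a CadQuery object exists": any variable named `result`
    return "result" in ns


samples: Dict[str, str] = {
    "fine": "result = 1 + 2",
    "syntax": "result = (1 + 2",
    "runtime": "result = undefined_variable",
    "empty": "x = 5",
}
for name, code in samples.items():
    print(name, "->", classify_script(code, has_result))
```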


# Train
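For scale, it helps to count the parameters of the decoder configured in the training cell (8 layers, width 1024, 512 positions). The estimate below uses the standard GPT-2 block layout: two LayerNorms, a fused QKV projection, an output projection, and a 4x MLP. The LM head is tied to the token embeddings, which is the GPT2LMHeadModel default. The vocabulary size of 32,768 is an assumption, since len(hf_tokenizer) can differ slightly if BPE training produces fewer merges.

```python
def gpt2_param_count(vocab: int, n_positions: int, n_embd: int, n_layer: int) -> int:
    """Trainable parameters of a GPT-2 LM with tied input/output embeddings."""
    d = n_embd
    per_block = (
        2 * d                 # ln_1 (weight + bias)
        + d * 3 * d + 3 * d   # attn.c_attn: fused Q, K, V projection
        + d * d + d           # attn.c_proj
        + 2 * d               # ln_2
        + d * 4 * d + 4 * d   # mlp.c_fc (default inner size 4 * d)
        + 4 * d * d + d       # mlp.c_proj
    )                         # = 12 d^2 + 13 d
    embeddings = vocab * d + n_positions * d  # wte + wpe (lm_head shares wte)
    final_ln = 2 * d
    return embeddings + n_layer * per_block + final_ln


decoder = gpt2_param_count(vocab=32_768, n_positions=512, n_embd=1024, n_layer=8)
# Linear layers that build the image prefix: 2048 -> 1024, then 1024 -> 1024
projections = (2048 * 1024 + 1024) + (1024 * 1024 + 1024)
print(f"decoder: {decoder:,}  projections: {projections:,}")
# decoder: 134,850,560  projections: 3,147,776
```

So the decoder alone is slightly larger than GPT-2 small (about 124M parameters), and here it is trained from scratch on Colab hardware.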

In [None]:
import os, sys, random
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights
from datasets import load_dataset, disable_caching
from tokenizers import ByteLevelBPETokenizer
from transformers import GPT2Config, GPT2LMHeadModel, PreTrainedTokenizerFast
from tqdm.auto import tqdm
from huggingface_hub import login


if sys.version_info < (3, 10):
    raise RuntimeError("Python ≥3.10 required – upgrade runtime.")

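# Only log in when HF_TOKEN is set, rather than passing an empty string to login()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)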

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
use_amp = torch.cuda.is_available()


disable_caching()

@dataclass
class Config:
    dataset_name: str = "CADCODER/GenCAD-Code"
    cache_dir: str = "./hf_cache"
    img_size: int = 224
    max_len: int = 512
    batch_size: int = 2
    num_workers: int = 2
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    epochs: int = 1
    decode_max_len: int = 256
    beam_size: int = 1
    freeze_vision: bool = True
    tokenizer_slice: str = "train"

cfg = Config()
Path(cfg.cache_dir).mkdir(parents=True, exist_ok=True)


#Tokeniser (BPE)

TOKENIZER_JSON = Path("tokenizer/tokenizer.json")
TOKENIZER_JSON.parent.mkdir(exist_ok=True)
if not TOKENIZER_JSON.exists():
    print("Tokenizer not found – training BPE …")
    ds_sample = load_dataset(cfg.dataset_name, split=cfg.tokenizer_slice,
                             cache_dir=cfg.cache_dir, token=True)
    codes = ds_sample["cadquery"]
    tok = ByteLevelBPETokenizer()
    tok.train_from_iterator(codes, vocab_size=32_768, min_frequency=2,
                            special_tokens=["<s>", "<pad>", "</s>", "<unk>"])
    tok.save(str(TOKENIZER_JSON))

hf_tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(TOKENIZER_JSON),
                                       bos_token="<s>", eos_token="</s>", pad_token="<pad>")


# Dataset & loaders
def _image_transform():
    # Resize, convert to a tensor, and normalise with the ImageNet statistics
    # expected by the pretrained ResNet-50 weights
    return transforms.Compose([
        transforms.Resize((cfg.img_size, cfg.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


class CadQueryDataset(Dataset):
    """Wraps the HF dataset to return image tensors and token IDs."""

    def __init__(self, split: str = "train"):
        self.ds = load_dataset(cfg.dataset_name, split=split, cache_dir=cfg.cache_dir)
        self.tf = _image_transform()

    @classmethod
    def from_raw(cls, hf_ds):
        """Construct directly from an already-loaded HF split (to avoid double I/O)."""
        obj = cls.__new__(cls)
        obj.ds = hf_ds
        obj.tf = _image_transform()
        return obj

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int):
        item = self.ds[idx]
        # Force 3 channels in case some images are grayscale or RGBA
        img = self.tf(item["image"].convert("RGB"))
        # Reserve 3 positions: image prefix + BOS (added in the model) and EOS (added here)
        ids = hf_tokenizer.encode(item["cadquery"], truncation=True,
                                  max_length=cfg.max_len - 3, add_special_tokens=False)
        ids.append(hf_tokenizer.eos_token_id)  # gives the decoder a target for stopping
        return {"pixel_values": img, "labels": torch.tensor(ids, dtype=torch.long)}


def collate_fn(batch):
    imgs = torch.stack([b["pixel_values"] for b in batch])
    lens = [len(b["labels"]) for b in batch]
    max_len = max(lens)
    labels = torch.full((len(batch), max_len), hf_tokenizer.pad_token_id, dtype=torch.long)
    for i, b in enumerate(batch):
        labels[i, :lens[i]] = b["labels"]
    return {"pixel_values": imgs, "labels": labels}


print("Loading dataset splits …")
train_hf = load_dataset(cfg.dataset_name, split="train", cache_dir=cfg.cache_dir).shuffle(seed=42)
val_hf = load_dataset(cfg.dataset_name, split="validation", cache_dir=cfg.cache_dir)

train_loader = DataLoader(CadQueryDataset.from_raw(train_hf),
                          batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True,
                          collate_fn=collate_fn)

val_loader = DataLoader(CadQueryDataset.from_raw(val_hf),
                        batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=True,
                        collate_fn=collate_fn)


#Model

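class ResNetEncoder(nn.Module):
    def __init__(self, freeze=True):
        super().__init__()
        rn = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.features = nn.Sequential(*list(rn.children())[:-1])  # drop the fc head
        self.proj = nn.Linear(rn.fc.in_features, 1024)              # 2048 -> 1024
        self.freeze = freeze
        if freeze:
            for p in self.features.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True):
        # Keep BatchNorm running statistics fixed when the backbone is frozen;
        # otherwise model.train() would keep updating them
        super().train(mode)
        if self.freeze:
            self.features.eval()
        return self

    def forward(self, x):
        return self.proj(torch.flatten(self.features(x), 1))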


class Vision2Code(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ResNetEncoder(cfg.freeze_vision)
        gpt_cfg = GPT2Config(vocab_size=len(hf_tokenizer), n_positions=cfg.max_len, n_ctx=cfg.max_len,
                             n_layer=8, n_head=8, n_embd=1024,
                             bos_token_id=hf_tokenizer.bos_token_id,
                             eos_token_id=hf_tokenizer.eos_token_id,
                             pad_token_id=hf_tokenizer.pad_token_id)
        self.decoder = GPT2LMHeadModel(gpt_cfg)
        self.vis_proj = nn.Linear(1024, 1024)

    def forward(self, pixel_values, labels=None):
        B = pixel_values.size(0)
        dev = pixel_values.device
        # One visual "prefix" embedding per image: (B, 1, n_embd)
        prefix = self.vis_proj(self.encoder(pixel_values)).unsqueeze(1)
        bos = torch.full((B, 1), hf_tokenizer.bos_token_id, dtype=torch.long, device=dev)

        if labels is not None:
            # Cap the code length so prefix + BOS + code fits in n_positions
            labels = labels[:, : cfg.max_len - 2]
            tok_emb = self.decoder.transformer.wte(torch.cat([bos, labels], dim=1))
            inputs_embeds = torch.cat([prefix, tok_emb], dim=1)  # (B, L + 2, n_embd)
            # No loss on the prefix/BOS slots or on padding; GPT-2 shifts labels
            # internally, so the BOS position is scored against the first code token
            ignore = torch.full((B, 2), -100, dtype=torch.long, device=dev)
            dec_labels = torch.cat([ignore, labels], dim=1)
            dec_labels[dec_labels == hf_tokenizer.pad_token_id] = -100
            return self.decoder(inputs_embeds=inputs_embeds, labels=dec_labels)

        # Inference: condition on [prefix, BOS] and decode autoregressively
        inputs_embeds = torch.cat([prefix, self.decoder.transformer.wte(bos)], dim=1)
        gen = self.decoder.generate(inputs_embeds=inputs_embeds,
                                    max_new_tokens=cfg.decode_max_len,
                                    num_beams=cfg.beam_size,
                                    pad_token_id=hf_tokenizer.pad_token_id,
                                    eos_token_id=hf_tokenizer.eos_token_id)
        # With inputs_embeds and no input_ids, recent transformers versions return only
        # the newly generated tokens, so no prompt positions need to be sliced off
        return gen


#Training


def set_seed(s=42):
    random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)

def train_epoch(model, loader, opt, scaler, epoch):
    model.train()
    pbar = tqdm(loader, desc=f"train {epoch}")
    for batch in pbar:
        opt.zero_grad(set_to_none=True)
        # torch.amp.autocast replaces the deprecated torch.cuda.amp.autocast
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            out = model(batch["pixel_values"].to(device), batch["labels"].to(device))
            loss = out.loss
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})


#Evaluation


def ids_to_str(ids: torch.Tensor) -> str:
    return hf_tokenizer.decode(ids.tolist(), clean_up_tokenization_spaces=True,
                               skip_special_tokens=True)

def evaluate_epoch(model, loader, max_batches: int = 200):
    model.eval()
    preds, refs = {}, {}
    with torch.no_grad():
        for b_idx, batch in enumerate(loader):
            if b_idx >= max_batches:
                break
            pix = batch["pixel_values"].to(device)
            gen_ids = model(pix)
            for i, (g, r) in enumerate(zip(gen_ids, batch["labels"])):
                key = f"{b_idx:04d}_{i:02d}"
                preds[key] = ids_to_str(g)

                refs[key] = ids_to_str(r[r != hf_tokenizer.pad_token_id])

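    vsr = evaluate_syntax_rate_simple(preds) * 100.0
    # get_iou_best raises on non-executable code, so score only the pairs that run
    # (the same convention as evaluate_codes); failures are already reflected in VSR
    ious = []
    for k in preds:
        try:
            ious.append(get_iou_best(refs[k], preds[k]))
        except Exception:
            continue
    mean_iou = sum(ious) / len(ious) if ious else 0.0
    return vsr, mean_iou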


def main():
    set_seed()
    model = Vision2Code().to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate,
                            weight_decay=cfg.weight_decay)
    # torch.amp.GradScaler (PyTorch 2.3+) replaces the deprecated torch.cuda.amp.GradScaler.
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    for ep in range(1, cfg.epochs + 1):
        train_epoch(model, train_loader, opt, scaler, ep)
        vsr, iou = evaluate_epoch(model, val_loader)
        print(f"✦ Epoch {ep}  VSR {vsr:5.1f}%   IoU {iou:.3f}")

    torch.save(model.state_dict(), "vision2cadquery.pt")
    print("✓ Training complete – saved to vision2cadquery.pt")


if __name__ == "__main__":
    main()


Using device: cuda
Loading dataset splits …


Generating train split:   0%|          | 0/147289 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7355 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8204 [00:00<?, ? examples/s]

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


train 1:   0%|          | 0/73645 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  wi

KeyboardInterrupt: 

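`set_seed` above seeds only Python's `random` module and PyTorch. Repeatable training in PyTorch also depends on cuDNN kernel selection and on how DataLoader workers are seeded. A stricter variant might look like the following; `set_seed_strict` and `seed_worker` are hypothetical names, and the recipe follows the general pattern in PyTorch's reproducibility notes. Even with these settings, bitwise-identical results across different GPUs or library versions are not guaranteed.

```python
import os
import random

import numpy as np
import torch


def set_seed_strict(seed: int = 42, deterministic_kernels: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs; optionally request deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)              # only matters if some code path draws from NumPy
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    if deterministic_kernels:
        # cuBLAS reads this when its handle is created, so it only takes
        # effect if set before the first CUDA matmul in the process.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # warn_only=True: ops without a deterministic implementation warn instead of raising.
        torch.use_deterministic_algorithms(True, warn_only=True)


def seed_worker(worker_id: int) -> None:
    """DataLoader worker_init_fn: derive NumPy/random seeds from the worker's torch seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

To seed the workers as well, pass `worker_init_fn=seed_worker` and a seeded `torch.Generator` as `generator=` when building the training `DataLoader`. Deterministic cuDNN kernels can be noticeably slower, which is why the flag is optional here.
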
# Test

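import argparse

import torch
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
# The metric helpers used in evaluate() are assumed to live in the same module as
# the training code, whose evaluate_epoch calls them as globals.
from vision2cadquery_updated import (Vision2Code, hf_tokenizer, cfg,
                                     evaluate_syntax_rate_simple, get_iou_best)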


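def ids_to_str(ids):
    # Same decoding options as the validation loop, so test and val metrics are comparable.
    return hf_tokenizer.decode(ids.tolist(), clean_up_tokenization_spaces=True,
                               skip_special_tokens=True)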


def get_test_loader(batch_size=4):
    tf = transforms.Compose([
        transforms.Resize((cfg.img_size, cfg.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    def collate(batch):
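        # convert("RGB") guards against grayscale/RGBA images, which would break the 3-channel Normalize.
        images = torch.stack([tf(b["image"].convert("RGB")) for b in batch])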
        labels = [hf_tokenizer.encode(b["cadquery"], truncation=True, max_length=cfg.max_len)
                  for b in batch]
        return {"pixel_values": images, "labels": labels}

    ds = load_dataset(cfg.dataset_name, split="test", cache_dir=cfg.cache_dir)
    return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=collate)


@torch.no_grad()
def evaluate(weights_path: str):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Vision2Code().to(device)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    loader = get_test_loader()

    preds, refs = {}, {}
    for b_idx, batch in enumerate(loader):
        pix = batch["pixel_values"].to(device)
        gen_ids = model(pix)
        for i, (g, r) in enumerate(zip(gen_ids, batch["labels"])):
            key = f"{b_idx:04d}_{i:02d}"
            preds[key] = ids_to_str(g)
            refs[key] = ids_to_str(torch.tensor(r))

    vsr = evaluate_syntax_rate_simple(preds) * 100.0
    mean_iou = sum(get_iou_best(refs[k], preds[k]) for k in preds) / len(preds)
    print(f"Test split  VSR {vsr:5.1f}%   IoU {mean_iou:.3f}")


if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Evaluate Vision2CadQuery on the test split.")
    p.add_argument("weights", nargs="?", default="vision2cadquery.pt",
                   help="Path to the .pt checkpoint (default: vision2cadquery.pt)")
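    # Inside Jupyter/Colab, sys.argv holds the kernel's own flags (e.g. -f <connection file>),
    # which parse_args() rejects; parse an empty list there and the real command line otherwise.
    import sys
    args = p.parse_args([] if "ipykernel" in sys.modules else None)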

    evaluate(args.weights)
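
Both evaluation paths depend on `evaluate_syntax_rate_simple` and `get_iou_best`, which are not defined in this section. When those helpers (or CadQuery itself) are not available, a much cheaper sanity check is to count how many generated scripts are at least syntactically valid Python. The function below is a hypothetical proxy, not the project's metric: it only parses the text with the standard-library `ast` module and never executes it, so it can only be an upper bound on any validity metric that also requires the script to run.

```python
import ast


def syntax_rate_proxy(preds: dict[str, str]) -> float:
    """Fraction of predicted scripts that parse as Python (no execution)."""
    if not preds:
        return 0.0
    ok = 0
    for code in preds.values():
        try:
            ast.parse(code)
            ok += 1
        except (SyntaxError, ValueError):  # ValueError: null bytes on older Pythons
            pass
    return ok / len(preds)
```

Like `evaluate_syntax_rate_simple` in the code above, it returns a fraction, so it can be multiplied by 100 and printed next to the VSR for a quick comparison.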
