# Elite Dangerous Core Mining - ML Detection Pipeline

Trains YOLOv8 object-detection models to identify core asteroids in Elite
Dangerous screenshots in real time.

## Labeling tool recommendation

You need to draw bounding boxes around every core asteroid visible in each
screenshot before training. Best free options:

**Recommended: [Roboflow Annotate](https://roboflow.com)** (web-based)
- Free for solo use, unlimited public projects, up to 10k images
- Drag-and-drop upload, clean bounding-box UI, good keyboard shortcuts
- Export as "YOLOv8" format (one .txt per image, same filename as the image)
- No train/val split needed on export - we handle that with cross-validation

**Local alternative: [LabelImg](https://github.com/HumanSignal/labelImg)**
- Fully offline, runs on Windows
- `pip install labelImg` then run `labelImg` from the command line
- Set output format to YOLO in the toolbar
- Shortcut `W` to draw a box, `D` for next image

## Dataset structure this notebook expects

Just dump all your images and label files into two flat folders per ring type.
No train/val subdirectory needed - the k-fold split is done automatically.

```
datasets/
  ice/
    images/    <- all your .png or .jpg screenshots, mixed together
    labels/    <- matching YOLO .txt files (same filename as the image)
  metallic/
    images/
    labels/
  rocky/
    images/
    labels/
  unified/     <- optional: copy of all images+labels from all ring types
    images/
    labels/
```

Each label .txt has one line per core:
`0 <cx> <cy> <w> <h>` - class 0 = core, then normalized center-x, center-y,
width, height (all 0.0-1.0). Both Roboflow and LabelImg produce this format.

## Strategy

- One model per ring type (ice, metallic, rocky) because each ring type has
  exactly one asteroid shape that can contain a core, and you mine one ring
  type per session so only one model needs to be loaded at a time.
- A "unified" model that combines the data from all ring types is also trained for comparison.
- Three YOLOv8 sizes (nano, small, medium) are compared per ring type.
- K-fold cross-validation gives reliable metric estimates even with small
  datasets - especially useful early on when you have fewer than 200 images.


## 1. Install dependencies

Run once. Restart the kernel after.

In [None]:
import subprocess, sys

# Minimum versions of every third-party package the notebook relies on.
packages = [
    "ultralytics>=8.2.0",
    "opencv-python>=4.9.0",
    "pyyaml>=6.0.1",
    "matplotlib>=3.9.0",
    "scikit-learn>=1.5.0",
    "pandas>=2.2.0",
    "Pillow>=10.3.0",
    "onnx>=1.16.0",
    "onnxruntime-gpu>=1.18.0",
]

# Invoke pip through the interpreter that runs this kernel so the packages
# land in the same environment; check=True raises if pip exits non-zero.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--quiet", *packages],
    check=True,
)
print("All packages installed.")


## 2. Imports and GPU check

In [None]:
import os
import sys
import json
import math
import shutil
import csv
import yaml
import cv2
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from datetime import datetime
from collections import defaultdict

import torch
from ultralytics import YOLO
from sklearn.model_selection import KFold

# Report interpreter and framework versions, then describe the training device.
print(f"Python:  {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    # CPU-only: training still works but the user should know why it is slow.
    print("WARNING: no GPU found. Training on CPU will be very slow.")
    print("Reinstall PyTorch with CUDA support:")
    print("  pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121")
else:
    gpu = torch.cuda.get_device_properties(0)
    vram_gb = gpu.total_memory / 1024 ** 3  # bytes -> GiB
    print(f"GPU: {gpu.name}")
    print(f"VRAM: {vram_gb:.1f} GB")
    if vram_gb < 6:
        print("WARNING: less than 6 GB VRAM - reduce BATCH_SIZE to 8 if training crashes")


## 3. Configuration

All tunable settings live here. Edit before running training.

In [None]:
# ---- MAIN CONFIG - edit this before running ----
# Every function in this notebook reads these module-level names directly,
# so re-run this cell after changing anything.

# Ring types to train. Remove any type you don't have data for yet.
# Each entry is also the name of the dataset sub-folder under DATASET_ROOT.
RING_TYPES = ["ice", "metallic", "rocky", "unified"]

# Model sizes to compare per ring type.
# n = fastest/smallest, s = good balance, m = best accuracy but slower inference.
# All three fit on RTX 3070 8GB at batch=16.
# Training loads the weights file "<size>.pt" for each entry.
MODEL_SIZES = ["yolov8n", "yolov8s", "yolov8m"]

# Number of cross-validation folds.
# 5 is a solid default. Use 3 if you have fewer than ~30 images per ring type.
# Each fold trains one full model and uses 1/k of the data as validation.
# A ring type needs at least K_FOLDS * 2 paired images or it is skipped.
K_FOLDS = 5

# Training hyperparameters
IMG_SIZE    = 640   # YOLO standard input size
EPOCHS      = 100   # max epochs per fold (early stopping usually kicks in earlier)
BATCH_SIZE  = 16    # reduce to 8 if you get CUDA out-of-memory errors
PATIENCE    = 20    # stop early if val mAP doesn't improve for this many epochs

# Confidence threshold for inference (0.0 - 1.0)
# 0.4 is a good starting point. Tune based on real-game performance.
CONF_THRESHOLD = 0.4

# We only detect one class: a core asteroid.
# The position of a name in this list is the class id used in label files.
CLASS_NAMES = ["core"]

# Folder layout
DATASET_ROOT = Path("datasets")  # <ring_type>/images and <ring_type>/labels live here
RUNS_ROOT    = Path("runs")      # training run folders and all_results.json
EXPORTS_ROOT = Path("exports")   # presumably the ONNX export target - confirm in the export section
CV_TMP_ROOT  = Path("cv_tmp")   # temporary per-fold datasets, cleaned up after training

# Echo the active settings so the cell output records what a run used.
print("Config loaded.")
print(f"Ring types: {RING_TYPES}")
print(f"Model sizes: {MODEL_SIZES}")
print(f"K-folds: {K_FOLDS}")
print(f"Device: {'GPU - ' + torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")


## 4. Create dataset folder structure

Run once. Then drop all your screenshots and matching label files into the folders it prints.

In [None]:
def create_dataset_structure():
    """
    Create the flat image and label folders for each ring type, plus the
    output folders used by training, export and cross-validation.

    Reads RING_TYPES, DATASET_ROOT, RUNS_ROOT, EXPORTS_ROOT and CV_TMP_ROOT
    from the config cell. Safe to call repeatedly: existing folders are left
    untouched. Prints the folders the user should drop files into.
    """
    for ring_type in RING_TYPES:
        for sub in ("images", "labels"):
            (DATASET_ROOT / ring_type / sub).mkdir(parents=True, exist_ok=True)

    # parents=True so a nested root configured in the config cell
    # (e.g. Path("out/runs")) does not fail with FileNotFoundError.
    for root in (RUNS_ROOT, EXPORTS_ROOT, CV_TMP_ROOT):
        root.mkdir(parents=True, exist_ok=True)

    print("Folder structure created. Drop your files here:\n")
    for ring_type in RING_TYPES:
        # Print the folders that were really created rather than a hard-coded
        # "datasets/" prefix, so the hint stays correct if DATASET_ROOT changes.
        base = (DATASET_ROOT / ring_type).as_posix()
        print(f"  {base}/images/  <- all screenshots (.png or .jpg, no subfolders)")
        print(f"  {base}/labels/  <- matching YOLO .txt label files")
        print()
    print("No train/val split needed - k-fold handles that automatically.")


create_dataset_structure()


## 5. COCO JSON to YOLO format converter

Only needed if your labeling tool exported COCO JSON. Roboflow can export YOLO directly - use that and skip this.

In [None]:
def convert_coco_to_yolo(coco_json_path, output_labels_dir):
    """
    Convert a COCO-format annotations JSON to per-image YOLO .txt label files.

    COCO: one big JSON, bounding boxes as pixel-space (x, y, w, h) top-left origin.
    YOLO: one .txt per image, normalized center cx cy w h on a 0-1 scale.

    Every image listed in the JSON gets a label file. Images without any
    annotation of a known class get an EMPTY file, so they still count as
    labeled (background) images in the rest of this notebook, which pairs
    images and labels by filename. Boxes are clipped to the image bounds so
    the normalized values always stay within 0-1; boxes that end up with no
    area are dropped.

    Args:
        coco_json_path: path to _annotations.coco.json
        output_labels_dir: folder where .txt files will be written

    Returns:
        Number of label files written (0 if no COCO category matches
        CLASS_NAMES).
    """
    output_labels_dir = Path(output_labels_dir)
    output_labels_dir.mkdir(parents=True, exist_ok=True)

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(coco_json_path, encoding="utf-8") as f:
        coco = json.load(f)

    # COCO category id -> YOLO class index, for the classes we train on.
    cat_map = {
        cat["id"]: CLASS_NAMES.index(cat["name"])
        for cat in coco["categories"]
        if cat["name"] in CLASS_NAMES
    }

    if not cat_map:
        print(f"WARNING: no COCO categories match CLASS_NAMES {CLASS_NAMES}")
        print(f"Found in file: {[c['name'] for c in coco['categories']]}")
        return 0

    images = {img["id"]: img for img in coco["images"]}

    anns_by_image = defaultdict(list)
    orphan_anns = 0
    for ann in coco["annotations"]:
        if ann["category_id"] not in cat_map:
            continue
        if ann["image_id"] not in images:
            # Annotation points at an image the file does not list.
            orphan_anns += 1
            continue
        anns_by_image[ann["image_id"]].append(ann)

    converted = 0
    dropped_boxes = 0
    for img_id, img_info in images.items():
        W = img_info["width"]
        H = img_info["height"]
        img_name = Path(img_info["file_name"]).stem

        if W <= 0 or H <= 0:
            print(f"WARNING: skipping {img_info['file_name']} - invalid size {W}x{H}")
            continue

        lines = []
        for ann in anns_by_image.get(img_id, []):
            cls = cat_map[ann["category_id"]]
            x, y, w, h = ann["bbox"]
            # Clip the box corners to the image so overhanging annotations
            # do not produce coordinates outside the 0-1 range.
            x1 = min(max(x, 0), W)
            y1 = min(max(y, 0), H)
            x2 = min(max(x + w, 0), W)
            y2 = min(max(y + h, 0), H)
            if x2 <= x1 or y2 <= y1:
                dropped_boxes += 1
                continue
            cx = (x1 + x2) / 2 / W
            cy = (y1 + y2) / 2 / H
            nw = (x2 - x1) / W
            nh = (y2 - y1) / H
            lines.append(f"{cls} {cx:.6f} {cy:.6f} {nw:.6f} {nh:.6f}")

        out_path = output_labels_dir / f"{img_name}.txt"
        out_path.write_text("\n".join(lines))
        converted += 1

    if orphan_anns:
        print(f"WARNING: ignored {orphan_anns} annotation(s) that reference unknown image ids")
    if dropped_boxes:
        print(f"WARNING: dropped {dropped_boxes} box(es) with no area inside the image")
    print(f"Converted {converted} images -> {output_labels_dir}")
    return converted


# Example:
# convert_coco_to_yolo(
#     coco_json_path="my_export/_annotations.coco.json",
#     output_labels_dir="datasets/ice/labels",
# )


## 6. Validate datasets

Run after adding screenshots and labels. Catches missing files, bad label format, and warns if K_FOLDS is too high for the number of images you have.

In [None]:
def validate_dataset(ring_type):
    """
    Check a ring type dataset for problems before training.

    Checks performed:
      - every image has a label file and every label file has an image
      - every label line has 5 fields with coordinates in range
      - every class id is a valid index into CLASS_NAMES (data.yaml declares
        nc=len(CLASS_NAMES), so any other id cannot be trained on)
      - there are at least K_FOLDS * 2 paired images

    Args:
        ring_type: dataset folder name under DATASET_ROOT, e.g. "ice"

    Returns:
        (stats dict, list of issue strings). Empty issues = all clear.
        stats keys: "images", "paired", "cores", "avg_cores_per_image".
    """
    img_dir = DATASET_ROOT / ring_type / "images"
    lbl_dir = DATASET_ROOT / ring_type / "labels"

    images = list(img_dir.glob("*.png")) + list(img_dir.glob("*.jpg"))
    labels = list(lbl_dir.glob("*.txt"))

    img_stems = {p.stem for p in images}
    lbl_stems = {p.stem for p in labels}

    issues = []

    missing_labels = img_stems - lbl_stems
    orphan_labels  = lbl_stems - img_stems

    if missing_labels:
        issues.append(
            f"{len(missing_labels)} image(s) have no label file: "
            + str(sorted(missing_labels)[:5])
        )
    if orphan_labels:
        issues.append(
            f"{len(orphan_labels)} label file(s) have no matching image: "
            + str(sorted(orphan_labels)[:5])
        )

    num_classes = len(CLASS_NAMES)
    total_cores = 0
    bad_lines = 0
    bad_class_lines = 0
    for lbl in labels:
        for line in lbl.read_text().splitlines():
            parts = line.split()
            if not parts:
                continue  # blank line
            if len(parts) != 5:
                bad_lines += 1
                continue
            try:
                cls = int(parts[0])
                cx, cy, w, h = (float(p) for p in parts[1:])
            except ValueError:
                bad_lines += 1
                continue

            if not (0 <= cx <= 1 and 0 <= cy <= 1 and 0 < w <= 1 and 0 < h <= 1):
                bad_lines += 1
            elif not 0 <= cls < num_classes:
                # Previously such lines were silently ignored: neither counted
                # as a core nor reported, even though they cannot be trained.
                bad_class_lines += 1
            elif cls == 0:
                total_cores += 1

    if bad_lines:
        issues.append(f"{bad_lines} malformed label line(s)")
    if bad_class_lines:
        issues.append(
            f"{bad_class_lines} label line(s) use a class id outside "
            f"0-{num_classes - 1} (CLASS_NAMES={CLASS_NAMES})"
        )

    paired = len(img_stems & lbl_stems)
    min_needed = K_FOLDS * 2
    if paired < min_needed:
        issues.append(
            f"only {paired} labeled images but K_FOLDS={K_FOLDS} needs at least {min_needed}. "
            f"Add more data or reduce K_FOLDS in config."
        )

    stats = {
        "images":  len(images),
        "paired":  paired,
        "cores":   total_cores,
        "avg_cores_per_image": round(total_cores / paired, 2) if paired else 0,
    }

    print(f"\n{'=' * 44}")
    print(f"  {ring_type.upper()} dataset")
    print(f"{'=' * 44}")
    print(f"  Images:              {stats['images']}")
    print(f"  Paired (img+label):  {stats['paired']}")
    print(f"  Total core boxes:    {stats['cores']}")
    print(f"  Avg cores/image:     {stats['avg_cores_per_image']}")

    if issues:
        print("  ISSUES:")
        for iss in issues:
            print(f"    - {iss}")
    else:
        print("  All checks passed.")

    return stats, issues


for ring_type in RING_TYPES:
    validate_dataset(ring_type)


## 7. Preview labeled images

Sanity check - confirm your labels look correct before training.

In [None]:
def preview_labels(ring_type, n=4):
    """
    Show the first n labeled images with bounding boxes overlaid.
    Use this to confirm labels imported correctly.

    Images that cannot be decoded are shown as an empty panel titled
    "<name> (unreadable)" instead of aborting the whole preview, and label
    lines that are not 5 numeric fields are skipped.

    Args:
        ring_type: dataset folder name under DATASET_ROOT, e.g. "ice"
        n:         maximum number of images to show (sorted by path)
    """
    img_dir = DATASET_ROOT / ring_type / "images"
    lbl_dir = DATASET_ROOT / ring_type / "labels"

    images = sorted(list(img_dir.glob("*.png")) + list(img_dir.glob("*.jpg")))[:n]

    if not images:
        print(f"No images found in {img_dir}")
        return

    cols = min(len(images), 4)
    rows = math.ceil(len(images) / cols)
    # squeeze=False always returns a 2-D array of axes, so one flattening
    # path covers the 1x1, single-row and multi-row layouts.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = list(axes.flat)

    for i, ax in enumerate(axes):
        ax.axis("off")
        if i >= len(images):
            continue  # unused cell in the last grid row

        img_path = images[i]
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread returns None for corrupt/unsupported files; passing
            # that to cvtColor would raise.
            ax.set_title(f"{img_path.name} (unreadable)", fontsize=8)
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        H, W = img.shape[:2]

        lbl_path = lbl_dir / (img_path.stem + ".txt")
        if lbl_path.exists():
            for line in lbl_path.read_text().strip().splitlines():
                parts = line.split()
                if len(parts) != 5:
                    continue
                try:
                    _, cx, cy, w, h = (float(p) for p in parts)
                except ValueError:
                    continue  # non-numeric field; validate_dataset reports these
                # Normalized center/size -> pixel corner coordinates.
                x1 = int((cx - w / 2) * W)
                y1 = int((cy - h / 2) * H)
                x2 = int((cx + w / 2) * W)
                y2 = int((cy + h / 2) * H)
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(img, "core", (x1, max(0, y1 - 5)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        ax.imshow(img)
        ax.set_title(img_path.name, fontsize=8)

    plt.suptitle(f"{ring_type} - label preview", fontsize=12)
    plt.tight_layout()
    plt.show()


# Change ring_type to whichever dataset you want to inspect
preview_labels("ice", n=4)


## 8. K-fold cross-validation helpers

These functions build temporary per-fold datasets using hard links (fast, and
no extra disk space; files are copied instead where hard links are not
supported), write a data.yaml for each fold, then clean up afterward.

**Why k-fold instead of a fixed train/val split?**

With a fixed split, one metric number depends heavily on which images happened
to land in the val set by luck. With k-fold you train k models, each validated
on a different 1/k slice of your data, and average the results. This gives a
much more reliable accuracy estimate - especially important with small datasets
early in the project. The best fold's weights are kept as the deployed model.


In [None]:
def get_paired_samples(ring_type):
    """
    Collect the (image_path, label_path) pairs available for a ring type.

    Images (.png / .jpg) are taken in sorted path order; an image is kept
    only when a .txt file with the same stem exists in the labels folder.
    """
    ring_dir = DATASET_ROOT / ring_type
    img_dir = ring_dir / "images"
    lbl_dir = ring_dir / "labels"

    candidates = sorted([*img_dir.glob("*.png"), *img_dir.glob("*.jpg")])
    # Pair every image with the label path it would need, then drop the
    # ones whose label file is missing.
    candidate_pairs = ((img, lbl_dir / (img.stem + ".txt")) for img in candidates)
    return [(img, lbl) for img, lbl in candidate_pairs if lbl.exists()]


def build_fold_dataset(pairs_train, pairs_val, fold_dir):
    """
    Build a temporary YOLO dataset for one fold using hard links.
    Hard links point to the same inode as the original - no extra disk space used.
    Falls back to a regular file copy if the filesystem doesn't support hard links.

    Any existing "images" / "labels" trees inside fold_dir are removed first.
    Fold folder names are reused between runs, so leftovers from an
    interrupted run (old image links, label cache files) would otherwise mix
    into the new split and could leak validation images into training.

    Resulting layout:
        fold_dir/images/{train,val}/<image files>
        fold_dir/labels/{train,val}/<label files>

    Args:
        pairs_train: list of (img_path, lbl_path) for training
        pairs_val:   list of (img_path, lbl_path) for validation
        fold_dir:    Path to write this fold's dataset into
    """
    fold_dir = Path(fold_dir)

    # Start from a clean slate; only the two sub-trees this function owns
    # are deleted, nothing else inside fold_dir is touched.
    for sub in ("images", "labels"):
        stale = fold_dir / sub
        if stale.exists():
            shutil.rmtree(stale)

    def link_or_copy(src, dst):
        dst.parent.mkdir(parents=True, exist_ok=True)
        # Still needed within one call: two images sharing a stem map to the
        # same label filename.
        if dst.exists():
            dst.unlink()
        try:
            os.link(src, dst)
        except OSError:
            # e.g. cross-device link or a filesystem without hard links
            shutil.copy2(src, dst)

    for split, pairs in (("train", pairs_train), ("val", pairs_val)):
        # Create the split folders even when a split is empty so the paths
        # referenced by data.yaml always exist.
        (fold_dir / "images" / split).mkdir(parents=True, exist_ok=True)
        (fold_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
        for img_path, lbl_path in pairs:
            link_or_copy(img_path, fold_dir / "images" / split / img_path.name)
            link_or_copy(lbl_path, fold_dir / "labels" / split / lbl_path.name)


def write_fold_yaml(fold_dir):
    """
    Write fold_dir/data.yaml describing this fold's dataset for YOLO.

    The file records the absolute dataset root, the train/val image folders
    relative to it, and the class count and names from CLASS_NAMES.
    Returns the path of the written file.
    """
    config = dict(
        path=str(fold_dir.resolve()),
        train="images/train",
        val="images/val",
        nc=len(CLASS_NAMES),
        names=CLASS_NAMES,
    )
    yaml_path = fold_dir / "data.yaml"
    serialized = yaml.dump(config, default_flow_style=False)
    with open(yaml_path, "w") as f:
        f.write(serialized)
    return yaml_path


def cleanup_fold(fold_dir):
    """Remove a fold's temporary dataset folder; do nothing if it is absent."""
    if not fold_dir.exists():
        return
    shutil.rmtree(fold_dir)


In [None]:
def train_one_fold(yaml_path, run_name, model_size, n_train, n_val):
    """
    Train one YOLO model on a prepared fold dataset.

    Args:
        yaml_path:  Path to the fold's data.yaml
        run_name:   name used as the output folder name under RUNS_ROOT
        model_size: e.g. "yolov8s"
        n_train:    number of training images (for logging)
        n_val:      number of validation images

    Returns:
        (weights_path str, metrics dict) where metrics has the keys
        "mAP50", "mAP50_95", "precision" and "recall".
    """
    print(f"  Training {run_name}  (train={n_train}, val={n_val})")

    model = YOLO(f"{model_size}.pt")

    results = model.train(
        data=str(yaml_path),
        epochs=EPOCHS,
        imgsz=IMG_SIZE,
        batch=BATCH_SIZE,
        patience=PATIENCE,
        device=0 if torch.cuda.is_available() else "cpu",
        project=str(RUNS_ROOT),
        name=run_name,

        # -- Augmentation tuned for Elite Dangerous screenshots --
        # Asteroids spin at all angles, so heavy rotation is always valid
        degrees=45,
        # Horizontal flip is always geometrically valid in space
        fliplr=0.5,
        # Vertical flip is valid too - no fixed orientation in zero-g
        flipud=0.3,
        # Cores appear at different distances = different apparent sizes
        scale=0.5,
        # Mosaic helps the model handle cluttered asteroid fields
        mosaic=1.0,
        close_mosaic=10,
        # Brightness/saturation shift for star-lit vs shadow-side lighting
        hsv_v=0.4,
        hsv_s=0.5,
        # Small hue shift - ring types differ but not wildly
        hsv_h=0.015,
        # NOTE: Ultralytics has no `blur` / `median_blur` train arguments.
        # Unknown keys are rejected when the config is validated, so they
        # must not be passed here.

        # Don't save intermediate checkpoints to save disk space
        # (-1 is the documented "disabled" value)
        save_period=-1,
        verbose=False,
    )

    # Ask the trainer where it actually wrote best.pt: Ultralytics may adjust
    # the run folder (e.g. append a suffix when the name already exists), so a
    # path built from run_name can point at the wrong place. Fall back to the
    # conventional location if the trainer does not expose it.
    trainer = getattr(model, "trainer", None)
    best = getattr(trainer, "best", None)
    weights = Path(best) if best else RUNS_ROOT / run_name / "weights" / "best.pt"

    rd = getattr(results, "results_dict", None) or {}
    metrics = {
        "mAP50":     float(rd.get("metrics/mAP50(B)",     0)),
        "mAP50_95":  float(rd.get("metrics/mAP50-95(B)",  0)),
        "precision": float(rd.get("metrics/precision(B)", 0)),
        "recall":    float(rd.get("metrics/recall(B)",    0)),
    }
    print(f"    -> mAP50={metrics['mAP50']:.3f}  P={metrics['precision']:.3f}  R={metrics['recall']:.3f}")
    return str(weights), metrics


## 9. Train models with k-fold cross-validation

For each combination of ring type x model size this will:
1. Split the full dataset into K folds
2. Train K models, each validated on a different held-out chunk
3. Average metrics across all folds for a stable accuracy estimate
4. Keep the weights from the best-scoring fold as the final model

Rough timing on RTX 3070: yolov8n ~5-10 min/fold, yolov8s ~10-20 min/fold,
yolov8m ~20-40 min/fold. Total = folds x model sizes x ring types x per-fold time.
Early stopping usually cuts it well short of the full EPOCHS count.


In [None]:
def train_kfold(ring_type, model_size):
    """
    Run K-fold cross-validation for one ring type + model size combination.

    For each fold a temporary dataset is built under CV_TMP_ROOT, one model
    is trained on it, and the temporary dataset is removed again - also when
    training raises, so an interrupted run does not leave stale fold folders
    behind.

    Args:
        ring_type:  dataset folder name under DATASET_ROOT, e.g. "ice"
        model_size: e.g. "yolov8s"

    Returns:
        A result dict with metrics averaged over the folds ("mAP50",
        "mAP50_95", "precision", "recall", "mAP50_std"), the per-fold
        results ("fold_results"), and the number and weights path of the
        fold with the highest mAP50 ("best_fold", "best_weights"); or None
        if the ring type doesn't have enough data.
    """
    pairs = get_paired_samples(ring_type)

    if len(pairs) < K_FOLDS * 2:
        print(f"Skipping {ring_type}/{model_size}: only {len(pairs)} paired samples "
              f"(need at least {K_FOLDS * 2} for {K_FOLDS} folds).")
        return None

    print(f"\n{'=' * 52}")
    print(f"  {ring_type.upper()} / {model_size}  ({len(pairs)} images, {K_FOLDS} folds)")
    print(f"{'=' * 52}")

    # Fixed seed: every model size is evaluated on the same splits.
    kf        = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    fold_results = []

    # KFold only needs the sample count; split on positions and index the
    # plain list, which avoids packing Path objects into a numpy array.
    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(pairs)))):
        fold_num    = fold_idx + 1
        pairs_train = [pairs[i] for i in train_idx]
        pairs_val   = [pairs[i] for i in val_idx]

        fold_dir = CV_TMP_ROOT / f"{ring_type}_{model_size}_fold{fold_num}"
        run_name = f"{ring_type}_{model_size}_fold{fold_num}_{timestamp}"

        print(f"\nFold {fold_num}/{K_FOLDS}")
        try:
            build_fold_dataset(pairs_train, pairs_val, fold_dir)
            yaml_path = write_fold_yaml(fold_dir)

            weights, metrics = train_one_fold(
                yaml_path, run_name, model_size, len(pairs_train), len(pairs_val)
            )
        finally:
            # Always remove the temporary dataset, even if training failed.
            cleanup_fold(fold_dir)

        fold_results.append({"fold": fold_num, "weights": weights, **metrics})

    # average metrics across folds
    avg = {
        "mAP50":     float(np.mean([r["mAP50"]    for r in fold_results])),
        "mAP50_95":  float(np.mean([r["mAP50_95"] for r in fold_results])),
        "precision": float(np.mean([r["precision"] for r in fold_results])),
        "recall":    float(np.mean([r["recall"]    for r in fold_results])),
        "mAP50_std": float(np.std( [r["mAP50"]    for r in fold_results])),
    }

    best_fold = max(fold_results, key=lambda r: r["mAP50"])

    print(f"\nAverage across {K_FOLDS} folds:")
    print(f"  mAP50={avg['mAP50']:.3f} (+/-{avg['mAP50_std']:.3f})  "
          f"mAP50-95={avg['mAP50_95']:.3f}  "
          f"P={avg['precision']:.3f}  R={avg['recall']:.3f}")
    print(f"  Best fold: {best_fold['fold']} (mAP50={best_fold['mAP50']:.3f})")
    print(f"  Best weights: {best_fold['weights']}")

    return {
        "best_weights": best_fold["weights"],
        "best_fold":    best_fold["fold"],
        "fold_results": fold_results,
        **avg,
    }


In [None]:
# Run all ring type x model size combinations.
# Skips any ring type that doesn't have enough data yet.

all_results = {}

# Resolve the output file up front so a bad RUNS_ROOT fails before any
# training starts rather than after hours of work.
results_path = RUNS_ROOT / "all_results.json"
RUNS_ROOT.mkdir(parents=True, exist_ok=True)

for ring_type in RING_TYPES:
    stats, issues = validate_dataset(ring_type)
    if issues:
        print(f"Skipping {ring_type} - fix dataset issues first:")
        for iss in issues:
            print(f"  {iss}")
        continue

    if stats["paired"] < K_FOLDS * 2:
        print(f"Skipping {ring_type} - not enough images yet ({stats['paired']}).")
        continue

    all_results[ring_type] = {}

    for model_size in MODEL_SIZES:
        result = train_kfold(ring_type, model_size)
        if result is not None:
            all_results[ring_type][model_size] = result
            # Checkpoint after every finished combination: if a later run
            # crashes or the kernel is interrupted, completed results stay
            # on disk instead of being lost.
            with open(results_path, "w") as f:
                json.dump(all_results, f, indent=2)

# Final write, so the file reflects this run even when nothing was trained.
with open(results_path, "w") as f:
    json.dump(all_results, f, indent=2)

print("\nAll training runs complete. Results saved to:", results_path)


## 10. Compare model results

In [None]:
def show_comparison_table():
    """
    Print a formatted table of all trained model metrics averaged across folds.
    The +/- column shows standard deviation across folds - lower means more stable.

    Exactly one row per ring type - the model size with the highest mean
    mAP50 (the first one listed wins a tie) - is tagged "<-- best".

    Returns:
        dict mapping ring type -> best model size; empty if
        RUNS_ROOT/all_results.json does not exist yet.
    """
    results_path = RUNS_ROOT / "all_results.json"
    if not results_path.exists():
        print("No results yet - run training first.")
        return {}

    with open(results_path) as f:
        results = json.load(f)

    header = (
        f"{'Ring':<12} {'Model':<12} {'mAP50':>8} {'  +/-':>6} "
        f"{'mAP50-95':>10} {'Precision':>10} {'Recall':>8}"
    )
    print(header)
    print("-" * len(header))

    best_per_ring = {}
    for ring_type, models in results.items():
        # Pick the winner BEFORE printing. Marking rows while scanning for a
        # running maximum tags every row that beats the ones above it, so
        # several rows could be labeled "best".
        if models:
            best_per_ring[ring_type] = max(models, key=lambda name: models[name]["mAP50"])

        for model_size, m in models.items():
            marker = "  <-- best" if best_per_ring.get(ring_type) == model_size else ""
            print(
                f"{ring_type:<12} {model_size:<12} "
                f"{m['mAP50']:>8.3f} {m.get('mAP50_std', 0):>6.3f} "
                f"{m['mAP50_95']:>10.3f} "
                f"{m['precision']:>10.3f} {m['recall']:>8.3f}"
                + marker
            )
        print()

    return best_per_ring


best_models = show_comparison_table()
print("Best model per ring type:", best_models)


## 11. Plot per-fold metrics

Shows how much variance there is between folds. High variance usually means you need more labeled data.

In [None]:
def plot_fold_metrics(ring_type, model_size):
    """
    Draw grouped bars of mAP50, precision and recall for every fold.

    A dashed horizontal line shows the mean mAP50 over the folds. Bars that
    differ a lot between folds mean the score depends on which images ended
    up in validation - usually a sign that more training data is needed.
    Prints a message and returns without plotting when no stored results
    exist for the requested combination.
    """
    summary_file = RUNS_ROOT / "all_results.json"
    if not summary_file.exists():
        print("No results found.")
        return

    with open(summary_file) as fh:
        summary = json.load(fh)

    entry = summary.get(ring_type, {}).get(model_size)
    if not entry:
        print(f"No data for {ring_type}/{model_size}")
        return

    per_fold = entry.get("fold_results", [])
    if not per_fold:
        print("No per-fold results stored.")
        return

    tick_labels = [f"Fold {rec['fold']}" for rec in per_fold]
    positions = np.arange(len(tick_labels))
    bar_w = 0.25

    fig, ax = plt.subplots(figsize=(10, 5))

    # (result key, legend label, color, slot relative to the tick center)
    series = (
        ("mAP50",     "mAP50",     "steelblue", -1),
        ("precision", "Precision", "seagreen",   0),
        ("recall",    "Recall",    "tomato",     1),
    )
    for key, label, color, slot in series:
        heights = [rec[key] for rec in per_fold]
        ax.bar(positions + slot * bar_w, heights, bar_w, label=label, color=color)

    mean_map = np.mean([rec["mAP50"] for rec in per_fold])
    ax.axhline(mean_map, color="steelblue", linestyle="--", linewidth=1.2,
               label=f"mean mAP50={mean_map:.3f}")

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Score")
    ax.set_title(f"{ring_type} / {model_size} - metrics per fold")
    ax.legend()
    ax.grid(True, axis="y", alpha=0.3)
    plt.tight_layout()
    plt.show()


plot_fold_metrics("ice", "yolov8s")


## 12. Plot training curves for a specific fold

In [None]:
def plot_training_curves(ring_type, model_size, fold=None):
    """
    Plot training loss and validation mAP50 per epoch for one fold.

    The run folder is located from the fold's stored weights path
    (<run>/weights/best.pt -> <run>/results.csv).

    Args:
        ring_type:  e.g. "ice"
        model_size: e.g. "yolov8s"
        fold:       fold number to plot, or None to plot the best fold
    """
    summary_file = RUNS_ROOT / "all_results.json"
    if not summary_file.exists():
        print("No results found.")
        return

    with open(summary_file) as fh:
        summary = json.load(fh)

    entry = summary.get(ring_type, {}).get(model_size)
    if not entry:
        print(f"No data for {ring_type}/{model_size}")
        return

    if fold is None:
        fold = entry["best_fold"]
        print(f"Plotting best fold: {fold}")

    # First stored fold record with the requested number.
    weights_path = None
    for record in entry["fold_results"]:
        if record["fold"] == fold:
            weights_path = record["weights"]
            break
    if weights_path is None:
        print(f"Fold {fold} not found in results.")
        return

    csv_path = Path(weights_path).parent.parent / "results.csv"
    if not csv_path.exists():
        print(f"results.csv not found at {csv_path}")
        return

    # Header names and values are stripped because the CSV may pad them.
    with open(csv_path) as fh:
        rows = [
            {name.strip(): value.strip() for name, value in raw.items()}
            for raw in csv.DictReader(fh)
        ]

    epochs = [int(r.get("epoch", 0)) for r in rows]
    # Empty/missing box loss falls back to the DFL loss column, then to 0.
    box_loss = [
        float(r.get("train/box_loss", 0) or r.get("train/dfl_loss", 0) or 0)
        for r in rows
    ]
    cls_loss = [float(r.get("train/cls_loss", 0) or 0) for r in rows]
    map50 = [float(r.get("metrics/mAP50(B)", 0) or 0) for r in rows]

    fig, (loss_ax, map_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(epochs, box_loss, label="box loss")
    loss_ax.plot(epochs, cls_loss, label="cls loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title(f"{ring_type} / {model_size} fold {fold} - training loss")
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    map_ax.plot(epochs, map50, color="green", label="mAP50")
    map_ax.set_xlabel("Epoch")
    map_ax.set_ylabel("mAP50")
    map_ax.set_title(f"{ring_type} / {model_size} fold {fold} - validation mAP50")
    map_ax.legend()
    map_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


plot_training_curves("ice", "yolov8s")


## 13. Test inference on a screenshot

In [None]:
def detect_cores(image_path, ring_type, model_size=None, conf=CONF_THRESHOLD):
    """
    Detect cores in one screenshot with a trained model and display the result.

    Loads the best-fold weights recorded in RUNS_ROOT/all_results.json,
    shows the annotated image and prints one line per detection.

    Args:
        image_path: path to a .png or .jpg screenshot
        ring_type:  which ring type model to load ("ice", "metallic", "rocky", "unified")
        model_size: e.g. "yolov8s", or None to auto-pick the best one
        conf:       confidence threshold

    Returns:
        The prediction result for the image, or None when the results file,
        the model, the weights file or the image is missing.
    """
    summary_file = RUNS_ROOT / "all_results.json"
    if not summary_file.exists():
        print("No trained models found. Run training first.")
        return

    with open(summary_file) as fh:
        summary = json.load(fh)

    ring_models = summary.get(ring_type, {})
    if not ring_models:
        print(f"No trained model for ring type '{ring_type}'.")
        return

    if model_size is None:
        # Highest mean mAP50 across folds wins.
        model_size = max(ring_models, key=lambda m: ring_models[m]["mAP50"])
        print(f"Auto-selected best model: {model_size} (mAP50={ring_models[model_size]['mAP50']:.3f})")

    weights = ring_models.get(model_size, {}).get("best_weights", "")
    if not (weights and Path(weights).exists()):
        print(f"Weights file not found: {weights}")
        return

    model = YOLO(weights)
    frame = cv2.imread(str(image_path))
    if frame is None:
        print(f"Could not load image: {image_path}")
        return

    preds = model.predict(frame, conf=conf, verbose=False)[0]
    n_found = len(preds.boxes)

    # plot() output is converted from BGR to RGB for matplotlib.
    rendered = cv2.cvtColor(preds.plot(), cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(14, 8))
    plt.imshow(rendered)
    plt.title(
        f"{ring_type} cores | {model_size} | "
        f"{n_found} detection(s) | conf>={conf}"
    )
    plt.axis("off")
    plt.tight_layout()
    plt.show()

    print(f"Detections: {n_found}")
    for box in preds.boxes:
        corners = [round(v, 1) for v in box.xyxy[0].tolist()]
        print(f"  core | conf={float(box.conf):.3f} | bbox={corners}")

    return preds


# Usage:
# detect_cores("screenshot.png", ring_type="ice")
# detect_cores("screenshot.png", ring_type="metallic", model_size="yolov8s")


## 14. One-button retrain

Drop new labeled screenshots into the dataset folder and call `retrain()`.
It re-validates, runs the full k-fold loop using the best model size from the
last run, and updates the results table.


In [None]:
def retrain(ring_type="all", model_size="best"):
    """
    Retrain after adding new labeled screenshots.

    Drop new images + label files into datasets/<ring_type>/images/ and labels/,
    then call this. Each target ring type is re-checked with validate_dataset()
    and skipped (with a printed message) when it reports issues or when
    stats["paired"] is below K_FOLDS * 2. Results returned by train_kfold() are
    merged into RUNS_ROOT/all_results.json, and show_comparison_table() is
    called at the end.

    Args:
        ring_type:  "all" retrains every ring type, or pass one e.g. "ice"
        model_size: "best" reuses the size that scored highest (mAP50) last
                    time, falling back to "yolov8s" when the ring type has no
                    previous results; or pass a specific size like "yolov8s"
    """
    targets = RING_TYPES if ring_type == "all" else [ring_type]

    results_path = RUNS_ROOT / "all_results.json"
    prev_results = {}
    if results_path.exists():
        with open(results_path) as f:
            prev_results = json.load(f)

    def save_results():
        # The runs folder may not exist yet (e.g. first run where every ring
        # type was skipped), so create it before opening the file for writing.
        results_path.parent.mkdir(parents=True, exist_ok=True)
        with open(results_path, "w") as f:
            json.dump(prev_results, f, indent=2)

    for rt in targets:
        stats, issues = validate_dataset(rt)
        if issues:
            print(f"Dataset issues for '{rt}' - fix before retraining:")
            for iss in issues:
                print(f"  {iss}")
            continue

        if stats["paired"] < K_FOLDS * 2:
            print(f"Skipping '{rt}' - not enough images ({stats['paired']}). "
                  f"Need at least {K_FOLDS * 2}.")
            continue

        if model_size == "best":
            ring_prev = prev_results.get(rt, {})
            if ring_prev:
                size = max(ring_prev, key=lambda m: ring_prev[m]["mAP50"])
                print(f"Using best model size from previous run: {size}")
            else:
                # Nothing to compare against for this ring type yet.
                size = "yolov8s"
                print(f"No previous results for '{rt}' - defaulting to {size}")
        else:
            size = model_size

        result = train_kfold(rt, size)
        if result is not None:
            prev_results.setdefault(rt, {})[size] = result
            # Persist right away so a failure while training a later ring type
            # does not throw away the training that already finished.
            save_results()

    # Always leave a results file behind, even when nothing was retrained.
    save_results()

    print("\nRetrain complete!")
    show_comparison_table()


# Examples:
# retrain()                      # retrain all ring types, best model size each
# retrain("ice")                 # retrain only ice, auto-pick best model size
# retrain("ice", "yolov8m")      # force a specific model size


## 15. Export best models to ONNX

ONNX lets the companion app run inference without a full PyTorch install and enables DirectML/CUDA acceleration via ONNX Runtime on Windows.

In [None]:
def export_best_models():
    """
    Export the best fold's weights for each ring type to ONNX format.
    Output goes to exports/<ring_type>_best.onnx (under EXPORTS_ROOT).

    For every ring type recorded in RUNS_ROOT/all_results.json, the model size
    with the highest stored mAP50 is selected and its "best_weights" file is
    exported. Ring types with no models, or whose weights entry is missing or
    points to a file that does not exist, are skipped with a printed message.

    Returns:
        None. Progress and the exported file sizes are printed.
    """
    results_path = RUNS_ROOT / "all_results.json"
    if not results_path.exists():
        print("No trained models to export.")
        return

    with open(results_path) as f:
        all_r = json.load(f)

    # parents=True so a nested EXPORTS_ROOT works on a fresh checkout.
    EXPORTS_ROOT.mkdir(parents=True, exist_ok=True)

    for ring_type, models in all_r.items():
        if not models:
            continue

        best_size    = max(models, key=lambda m: models[m]["mAP50"])
        # Same lookup style as detect_cores(): a missing key is treated like a
        # missing file. The emptiness check matters because Path("") resolves
        # to the current directory, which always "exists".
        best_weights = models[best_size].get("best_weights", "")

        if not best_weights or not Path(best_weights).exists():
            print(f"Weights missing for {ring_type}/{best_size}: {best_weights}")
            continue

        print(f"Exporting {ring_type} ({best_size}, mAP50={models[best_size]['mAP50']:.3f})...")
        model    = YOLO(best_weights)
        exported = model.export(
            format="onnx",
            imgsz=IMG_SIZE,
            simplify=True,
            opset=17,
            dynamic=False,
        )

        # export() writes the file itself and returns where it put it; copy it
        # to a predictable per-ring-type name.
        dest    = EXPORTS_ROOT / f"{ring_type}_best.onnx"
        shutil.copy(exported, dest)
        size_mb = dest.stat().st_size / 1024 / 1024
        print(f"  -> {dest}  ({size_mb:.1f} MB)")

    print("\nExport complete. ONNX models are in:", EXPORTS_ROOT)


# Run the export as soon as this cell executes. If RUNS_ROOT/all_results.json
# does not exist yet, the call only prints a message and returns.
export_best_models()


## 16. Batch inference on a folder of screenshots

In [None]:
def batch_detect(screenshots_folder, ring_type, output_folder=None, conf=CONF_THRESHOLD):
    """
    Run detection on every image in a folder and save annotated copies.

    The model size with the highest stored mAP50 for `ring_type` is read from
    RUNS_ROOT/all_results.json and used for every image. Only *.png and *.jpg
    files directly inside `screenshots_folder` are processed (no recursion).
    Files that cannot be read as images are skipped with a message.

    Args:
        screenshots_folder: folder with .png/.jpg screenshots
        ring_type:          which ring type model to use
        output_folder:      where to save annotated images (default: screenshots_folder/detected)
        conf:               confidence threshold

    Returns:
        None. A summary is printed when finished.
    """
    screenshots_folder = Path(screenshots_folder)
    if not screenshots_folder.is_dir():
        print(f"Screenshots folder not found: {screenshots_folder}")
        return

    if output_folder is None:
        output_folder = screenshots_folder / "detected"
    output_folder = Path(output_folder)

    results_path = RUNS_ROOT / "all_results.json"
    if not results_path.exists():
        print("No trained models. Run training first.")
        return

    with open(results_path) as f:
        all_r = json.load(f)

    ring_models = all_r.get(ring_type, {})
    if not ring_models:
        print(f"No model for ring type '{ring_type}'")
        return

    best_size    = max(ring_models, key=lambda m: ring_models[m]["mAP50"])
    # Same guard as detect_cores(): fail with a clear message instead of
    # letting YOLO() raise on a missing or empty weights path.
    best_weights = ring_models[best_size].get("best_weights", "")
    if not best_weights or not Path(best_weights).exists():
        print(f"Weights file not found: {best_weights}")
        return
    model = YOLO(best_weights)

    # Created only once we know we can run; parents=True allows a nested
    # custom output_folder.
    output_folder.mkdir(parents=True, exist_ok=True)

    images = sorted(
        list(screenshots_folder.glob("*.png"))
        + list(screenshots_folder.glob("*.jpg"))
    )
    print(f"Running {best_size} ({ring_type}) on {len(images)} images...")

    total_cores = 0
    processed   = 0
    for img_path in images:
        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread returns None for unreadable/corrupt files. Passing
            # None on to predict() would not analyse this screenshot
            # (Ultralytics substitutes its default sample source), so skip it.
            print(f"Could not load image: {img_path}")
            continue

        preds = model.predict(img, conf=conf, verbose=False)[0]
        out_path = output_folder / img_path.name
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(out_path), preds.plot()):
            print(f"Could not save annotated image: {out_path}")
        total_cores += len(preds.boxes)
        processed   += 1

    print(f"Done. {total_cores} total core detections across {processed} screenshots.")
    if processed != len(images):
        print(f"Skipped {len(images) - processed} unreadable file(s).")
    print(f"Annotated images saved to: {output_folder}")


# Usage:
# batch_detect(
#     screenshots_folder=r"C:/Users/YourName/Pictures/Frontier Developments/Elite Dangerous",
#     ring_type="ice",
# )
