### Import and parameter

In [3]:
import shutil
import os
import tqdm as notebook_tqdm
import csv
from PIL import Image
from datasets import load_dataset
from ultralytics import YOLO
from pathlib import Path

### Download from hugging face

The images are put into /silver as they are already preprocessed: properly named and resized to 512 px max side.

Three classes are selected: pizza, spaghetti_bolognese, and spaghetti_carbonara.
Pizza can serve as a class that is more distinct. But both spaghetti dishes can be confused for each other.

30 images are downloaded for training, and 10 for validation. 

It is planned to use the Yolo classification model. Due to that, the folder structure and naming are adapted accordingly, and dataset.yaml is created.

In [None]:
CLASSES = {
    "pizza": 76,
    "spaghetti_bolognese": 90,
    "spaghetti_carbonara": 91,
}
N_SPLIT = {
    "train": 3,
    "validation": 1,
}
OUT_ROOT = Path("../data/silver")
MAKE_CSV = True
SPLIT_MAP = {"train": "train", "validation": "val"}

def map_labels():
    """Map Food-101 label ids <-> names and keep a stable class order."""
    label_names = load_dataset("ethz/food101", split="train").features["label"].names
    id_to_name = {i: n for i, n in enumerate(label_names)}
    target_pairs = [(id_to_name[food101_id], food101_id) for _, food101_id in CLASSES.items()]
    target_names = [n for n, _ in target_pairs]
    target_ids = {i for _, i in target_pairs}
    return id_to_name, target_names, target_ids

def ensure_dirs(root: Path, target_names):
    for split_food101, split_yolo in SPLIT_MAP.items():
        for cls in target_names:
            (root / split_yolo / cls).mkdir(parents=True, exist_ok=True)

def build_split(split_food101: str, OUT_ROOT: Path, target_names, target_ids, id_to_name):
    """
    Download images for the given Food-101 split and save in YOLO classification layout:
    OUT_ROOT/<train|val>/<class_name>/image.jpg
    Optionally writes a labels.csv at split root (filename,label).
    """
    split_yolo = SPLIT_MAP[split_food101]
    split_dir = OUT_ROOT / split_yolo
    split_dir.mkdir(parents=True, exist_ok=True)

    # counters by class name
    saved_counts = {cls: 0 for cls in target_names}
    target_per_class = N_SPLIT[split_food101]

    ds = load_dataset("ethz/food101", split=split_food101)

    rows = []  # optional CSV

    for idx, ex in enumerate(ds):
        lid = ex["label"]
        if lid not in target_ids:
            continue

        cls_name = id_to_name[lid]
        if saved_counts[cls_name] >= target_per_class:
            continue

        fname = f"{idx:06d}_{lid:02d}_{cls_name}.jpg"
        out_path = split_dir / cls_name / fname
        ex["image"].save(out_path)
        saved_counts[cls_name] += 1
        if MAKE_CSV:
            rows.append([f"{cls_name}/{fname}", cls_name])

        # stop early when all classes hit the quota
        if all(saved_counts[c] >= target_per_class for c in saved_counts):
            break

    # optional: CSV summary for this split
    if MAKE_CSV:
        with open(split_dir / "labels.csv", "w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["filename", "label"])
            w.writerows(rows)

    print(f"[{split_food101}] saved counts:", saved_counts)

def write_dataset_yaml(root: Path, target_names):
    """
    Create dataset.yaml for YOLO classification training.
    """
    names_block = "\n".join([f"  - {n}" for n in target_names])
    yaml_text = f"""# YOLO classification dataset generated from Food-101
path: {root.resolve()}
train: train
val: val
names:
{names_block}
"""
    (root / "dataset.yaml").write_text(yaml_text)
    print(f"dataset.yaml written to: {root / 'dataset.yaml'}")
    print("Class order:", target_names)

def prepare_dataset():
    id_to_name, target_names, target_ids = map_labels()
    ensure_dirs(OUT_ROOT, target_names)
    for split in ["train", "validation"]:
        build_split(split, OUT_ROOT, target_names, target_ids, id_to_name)
    write_dataset_yaml(OUT_ROOT, target_names)
    return (OUT_ROOT / "dataset.yaml").resolve()

prepare_dataset()




[train] saved counts: {'pizza': 30, 'spaghetti_bolognese': 30, 'spaghetti_carbonara': 30}
[validation] saved counts: {'pizza': 10, 'spaghetti_bolognese': 10, 'spaghetti_carbonara': 10}
dataset.yaml written to: ../data/silver/dataset.yaml
Class order: ['pizza', 'spaghetti_bolognese', 'spaghetti_carbonara']


PosixPath('/workspaces/marktguru-home-assignment/data/silver/dataset.yaml')

### Model parameters

In [None]:
DATASET_YAML = Path("../data/silver/")  # point to your dataset.yaml
BASE_MODEL = "yolov8n-cls.pt"
EPOCHS = 20
IMG_SIZE = 224
BATCH = 16
DEVICE = "cpu"

### First run of the pre-trained model

In [None]:
model = YOLO(BASE_MODEL)
model.predict(
    source=Path(DATASET_YAML) / "val/spaghetti_carbonara",
    imgsz=IMG_SIZE,
    device=DEVICE,
    save=True,
)


image 1/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022250_91_spaghetti_carbonara.jpg: 224x224 carbonara 1.00, pretzel 0.00, plate 0.00, spaghetti_squash 0.00, wok 0.00, 17.0ms
image 2/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022251_91_spaghetti_carbonara.jpg: 224x224 carbonara 0.69, burrito 0.10, spaghetti_squash 0.05, guacamole 0.03, hotdog 0.02, 17.1ms
image 3/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022252_91_spaghetti_carbonara.jpg: 224x224 carbonara 0.99, spaghetti_squash 0.01, wok 0.00, broccoli 0.00, plate 0.00, 13.8ms
image 4/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022253_91_spaghetti_carbonara.jpg: 224x224 spaghetti_squash 0.40, carbonara 0.22, cucumber 0.15, zucchini 0.14, plate 0.03, 12.5ms
image 5/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbo

[ultralytics.engine.results.Results object with attributes:
 
 boxes: None
 keypoints: None
 masks: None
 names: {0: 'tench', 1: 'goldfish', 2: 'great_white_shark', 3: 'tiger_shark', 4: 'hammerhead', 5: 'electric_ray', 6: 'stingray', 7: 'cock', 8: 'hen', 9: 'ostrich', 10: 'brambling', 11: 'goldfinch', 12: 'house_finch', 13: 'junco', 14: 'indigo_bunting', 15: 'robin', 16: 'bulbul', 17: 'jay', 18: 'magpie', 19: 'chickadee', 20: 'water_ouzel', 21: 'kite', 22: 'bald_eagle', 23: 'vulture', 24: 'great_grey_owl', 25: 'European_fire_salamander', 26: 'common_newt', 27: 'eft', 28: 'spotted_salamander', 29: 'axolotl', 30: 'bullfrog', 31: 'tree_frog', 32: 'tailed_frog', 33: 'loggerhead', 34: 'leatherback_turtle', 35: 'mud_turtle', 36: 'terrapin', 37: 'box_turtle', 38: 'banded_gecko', 39: 'common_iguana', 40: 'American_chameleon', 41: 'whiptail', 42: 'agama', 43: 'frilled_lizard', 44: 'alligator_lizard', 45: 'Gila_monster', 46: 'green_lizard', 47: 'African_chameleon', 48: 'Komodo_dragon', 49: 'Afri

In [None]:


def train_eval_predict():
    model = YOLO(BASE_MODEL)

    # Train
    # results = model.train(
    #     data=str(DATASET_YAML),
    #     epochs=EPOCHS,
    #     imgsz=IMG_SIZE,
    #     batch=BATCH,
    #     device=DEVICE,
    #     verbose=False
    # )

    # Validate
    # val_results = model.val(data=str(DATASET_YAML), imgsz=IMG_SIZE, device=DEVICE)
    # model.val(data=str(DATASET_YAML), imgsz=IMG_SIZE, device=DEVICE)
    # print("\n=== Validation Results ===")
    # print("Top-1 Accuracy:", val_results.top1)
    # print("Top-5 Accuracy:", val_results.top5)
    # # optional summary
    # print("Summary dict:", val_results.summary())

    # Predict on the validation folder
    model.predict(
        source=Path(DATASET_YAML).parent / "silver/val/spaghetti_carbonara",
        imgsz=IMG_SIZE,
        device=DEVICE,
        save=True,
    )

# direct call
train_eval_predict()





image 1/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022250_91_spaghetti_carbonara.jpg: 224x224 carbonara 1.00, pretzel 0.00, plate 0.00, spaghetti_squash 0.00, wok 0.00, 10.3ms
image 2/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022251_91_spaghetti_carbonara.jpg: 224x224 carbonara 0.69, burrito 0.10, spaghetti_squash 0.05, guacamole 0.03, hotdog 0.02, 13.7ms
image 3/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022252_91_spaghetti_carbonara.jpg: 224x224 carbonara 0.99, spaghetti_squash 0.01, wok 0.00, broccoli 0.00, plate 0.00, 10.4ms
image 4/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbonara/022253_91_spaghetti_carbonara.jpg: 224x224 spaghetti_squash 0.40, carbonara 0.22, cucumber 0.15, zucchini 0.14, plate 0.03, 13.2ms
image 5/10 /workspaces/marktguru-home-assignment/notebooks/../data/silver/val/spaghetti_carbon