# Train Custom YOLO Models for Object Detection

This notebook is designed to help you train your own YOLO object detection models from scratch or fine-tune existing ones.  
You can use either:

- **Local datasets** (e.g., in YOLO format stored on your Google Drive or GitHub)
- **Datasets from Roboflow**, which can be easily imported via a download link

The workflow includes:
- Loading and organizing your dataset
- Writing a custom `.yaml` config file
- Launching training with the `ultralytics` YOLO implementation
- (Optional) Exporting and evaluating your trained model

This is ideal for training models on custom objects — whether you're working with animals, vehicles, tools, or underwater footage.

---

Make sure your dataset is in the correct YOLO structure:

```
dataset/
├── train/
│   ├── images/
│   └── labels/
├── valid/
│   ├── images/
│   └── labels/
├── test/   # optional
│   ├── images/
│   └── labels/
└── data.yaml
```
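To catch layout problems before training, you can check the structure programmatically. The sketch below is a hypothetical helper: `check_yolo_structure` is not part of Ultralytics or of this tutorial's `utils`. It assumes the tree above with `.jpg`/`.jpeg`/`.png` images, and it lists images that have no label file with the same name stem.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumption: adjust to your image formats

def check_yolo_structure(root, splits=("train", "valid")):
    """Return {split: [image stems without a label file]}.

    Raises FileNotFoundError if a required images/ or labels/ folder is missing.
    """
    root = Path(root)
    missing = {}
    for split in splits:
        img_dir = root / split / "images"
        lbl_dir = root / split / "labels"
        for folder in (img_dir, lbl_dir):
            if not folder.is_dir():
                raise FileNotFoundError(f"Missing folder: {folder}")
        # A label file shares the image's stem, e.g. frame_001.jpg -> frame_001.txt
        label_stems = {p.stem for p in lbl_dir.glob("*.txt")}
        missing[split] = sorted(
            p.stem for p in img_dir.iterdir()
            if p.suffix.lower() in IMAGE_SUFFIXES and p.stem not in label_stems
        )
    return missing
```

Ultralytics treats an image without a label file as background (no objects), so an unexpected entry here usually points to a missing or misnamed annotation rather than something that would crash training.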

# Import libraries

In [1]:
import os
import glob
import math
import random
import shutil
import time
import zipfile
from pathlib import Path

import numpy as np

# Local helper module shipped with this tutorial (provides prepare_yolo_dataset, used below)
from utils import *

# Only needed on Google Colab:
# from google.colab import runtime
# from google.colab import drive

import ultralytics
from ultralytics import YOLO
from IPython.display import Image, clear_output, display

clear_output()
ultralytics.checks()

HOME = os.getcwd()
print(HOME)

Ultralytics 8.3.38 🚀 Python-3.12.7 torch-2.5.1.post4 CPU (Apple M2 Pro)
Setup complete ✅ (12 CPUs, 32.0 GB RAM, 880.9/926.4 GB disk)
/Users/ang/Seafile/TRex-tutorials-data/code


#### There are two options: 1) load data from a local folder (.zip archive) or 2) get data directly from Roboflow
Choose the one that applies to your case.

# 1) Load data from local folder

In [2]:
# -------------------------------
# USER OPTION: Choose your data source
# -------------------------------
USE_LOCAL = False  # Set to True to unzip a dataset from a local/Google Drive folder

# -------------------------------
# Setup
# -------------------------------
# All datasets go into ./datasets next to this notebook (the /content paths only exist on Colab)
local_base = Path.cwd() / "datasets"
local_base.mkdir(parents=True, exist_ok=True)

if USE_LOCAL:
    # --- Option 1: Unzip a dataset from Google Drive (or any local folder) ---
    dir_name = Path("/Users/ang/Google Drive/models/hexbugs")
    name = "hexseg.v2i.yolov11"
    dataset_path = local_base / name
    shutil.rmtree(dataset_path, ignore_errors=True)  # start from a clean copy

    # Unzip the dataset (zipfile copes with the space in "Google Drive", unlike an unquoted shell command)
    with zipfile.ZipFile(dir_name / f"{name}.zip", "r") as zip_ref:
        zip_ref.extractall(dataset_path)

else:
    # --- Option 2: Clone from GitHub (for local Jupyter setup) ---
    name = "hexbugs-annotation-dataset"
    repo_url = "https://github.com/albiangela/TRex-tutorials-data.git"
    dataset_path = local_base / name
    shutil.rmtree(dataset_path, ignore_errors=True)  # start from a clean copy

    # Clone the repository
    os.system(f"git clone {repo_url}")

    # Copy the dataset folder to the local datasets directory
    source_folder = Path("TRex-tutorials-data/YOLO-models") / name
    shutil.copytree(source_folder, dataset_path, dirs_exist_ok=True)

    # Unzip dataset (assumes only one ZIP file in the folder)
    zip_files = list(dataset_path.glob("*.zip"))
    if zip_files:
        with zipfile.ZipFile(zip_files[0], 'r') as zip_ref:
            zip_ref.extractall(dataset_path)
        zip_files[0].unlink()  # remove zip after extraction

    # Move unzipped contents up one level if nested in a subfolder
    nested_folder = dataset_path / name
    if nested_folder.exists() and nested_folder.is_dir():
        for item in nested_folder.iterdir():
            shutil.move(str(item), str(dataset_path))
        shutil.rmtree(nested_folder)

    # Clean up cloned repo
    shutil.rmtree("TRex-tutorials-data")


Cloning into 'TRex-tutorials-data'...


In [3]:
# Define a simple class to hold dataset metadata
class Data:
    def __init__(self, location, name, version=1):
        self.location = Path(location).resolve()  # Convert to absolute Path object
        self.version = version                    # Optional versioning
        self.name = name                          # Name of the dataset

# Create an instance of the Data class with the path and name of the dataset
dataset = Data(Path("datasets") / name, name)

# Assert that the 'train' folder exists within the dataset location
# This will raise an error if the folder is missing
assert (dataset.location / "train").exists(), f"'train' folder not found in {dataset.location}"

# Return or print the dataset location
print("Dataset location:", dataset.location)

Dataset location: /Users/ang/Seafile/TRex-tutorials-data/code/datasets/hexbugs-annotation-dataset


In [5]:
# Set output path to a new folder named 'rebalanced_dataset' in the current working directory
output_path = Path.cwd() / "rebalanced_dataset"

prepare_yolo_dataset(
    dataset_path=dataset.location,  # your original dataset object
    output_path=str(output_path),   # convert Path object to string
    split=(0.7, 0.2, 0.1),          # train, valid, test
    remove_test=False               # keep test if it exists
)

# ## To filter or remap class labels, pass the extra arguments below (define collapse_map and new_class_ids first)
# prepare_yolo_dataset(
#     dataset_path=dataset.location,  # your original dataset object
#     output_path=str(output_path),   # convert Path object to string
#     split=(0.7, 0.2, 0.1),
#     remove_test=False,
#     allowed_ids={0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
#     collapse_map=collapse_map,
#     new_class_ids=new_class_ids,
#     drop_others=False
# )

Rebalanced dataset saved to: /Users/ang/Seafile/TRex-tutorials-data/code/rebalanced_dataset
Train: 28 | valid: 8 | test: 5
Dataset prepared successfully.

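The counts printed above (28 / 8 / 5 out of 41 images) are consistent with rounding the train and valid shares down and giving the remainder to test. The sketch below illustrates that arithmetic with hypothetical `split_counts` and `split_files` functions; it is an assumption about the behavior, not the actual code of `prepare_yolo_dataset` in `utils`.

```python
import random

def split_counts(n, split=(0.7, 0.2, 0.1)):
    """Return (n_train, n_valid, n_test): train/valid rounded down, test gets the rest."""
    n_train = int(n * split[0])
    n_valid = int(n * split[1])
    return n_train, n_valid, n - n_train - n_valid

def split_files(files, split=(0.7, 0.2, 0.1), seed=0):
    """Shuffle a copy of `files` reproducibly and cut it into train/valid/test lists."""
    files = list(files)
    random.Random(seed).shuffle(files)
    n_train, n_valid, _ = split_counts(len(files), split)
    return (files[:n_train],
            files[n_train:n_train + n_valid],
            files[n_train + n_valid:])

print(split_counts(41))  # (28, 8, 5)
```

With a dataset this small, the test split holds only a handful of images, so metrics computed on it will be noisy.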

In [8]:
from collections import Counter

def count_labels(label_dir):
    class_counts = Counter()
    for fname in os.listdir(label_dir):
        if not fname.endswith('.txt'):
            continue
        with open(os.path.join(label_dir, fname)) as f:
            for line in f:
                parts = line.strip().split()
                if parts:
                    try:
                        class_id = int(parts[0])
                        class_counts[class_id] += 1
                    except ValueError:
                        continue
    return class_counts

# Check balance
print("Train class counts:", count_labels(Path.cwd() / 'rebalanced_dataset/train/labels'))
print("Valid class counts:", count_labels(Path.cwd() / 'rebalanced_dataset/valid/labels'))
print("Test class counts:", count_labels(Path.cwd() / 'rebalanced_dataset/test/labels'))

Train class counts: Counter({0: 140})
Valid class counts: Counter({0: 40})
Test class counts: Counter({0: 25})

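`count_labels` only reads the first token of each line, so it works for both detection and segmentation labels. For reference, a YOLO detection label line is `class x_center y_center width height`, and a segmentation line is `class x1 y1 x2 y2 ...` (a polygon), with all coordinates normalized to the 0 to 1 range. The hypothetical `parse_label_line` helper below sketches a stricter per-line check; pose labels (box plus keypoints) are not handled.

```python
def parse_label_line(line):
    """Parse one YOLO label line into (class_id, kind, coords).

    kind is "box" for 4 coordinates (x_center y_center width height) and
    "polygon" for an even number >= 6 (x1 y1 x2 y2 ...). Raises ValueError
    for malformed lines or coordinates outside the normalized 0..1 range.
    """
    parts = line.split()
    if not parts:
        raise ValueError("empty line")
    class_id = int(parts[0])               # raises ValueError for non-integer ids
    coords = [float(v) for v in parts[1:]]
    if any(not 0.0 <= v <= 1.0 for v in coords):
        raise ValueError(f"coordinate outside [0, 1]: {line!r}")
    if len(coords) == 4:
        kind = "box"
    elif len(coords) >= 6 and len(coords) % 2 == 0:
        kind = "polygon"
    else:
        raise ValueError(f"unexpected number of coordinates ({len(coords)}): {line!r}")
    return class_id, kind, coords
```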

In [10]:
# Define a simple class to hold dataset metadata
class Data:
    def __init__(self, location, name, version=1):
        self.location = Path(location).resolve()  # Ensure it's a full path
        self.version = version
        self.name = name

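# Name the rebalanced dataset after the original one; removesuffix() stops re-runs of this cell
# from stacking "_rebalanced" twice. Falls back to a default name if dataset.name is not defined.
try:
    dataset = Data(Path.cwd() / "rebalanced_dataset", dataset.name.removesuffix("_rebalanced") + "_rebalanced")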
except NameError:
    print("⚠️ 'dataset.name' not defined. Falling back to default name 'rebalanced'.")
    dataset = Data(Path.cwd() / "rebalanced_dataset", "rebalanced")

# Assert that the 'train' folder exists
assert (dataset.location / "train").exists(), f"'train' folder not found in: {dataset.location}"

# Show the resolved dataset path
dataset.location

PosixPath('/Users/ang/Seafile/TRex-tutorials-data/code/rebalanced_dataset')

In [11]:
# Define source and destination paths
source_yaml = Path(dataset_path) / "data.yaml"  # or "dataset.yaml" if that's the actual name
if not source_yaml.exists():
    source_yaml = Path(dataset_path) / "dataset.yaml"  # fallback if using a different name

destination_yaml = dataset.location / "data.yaml"

# Copy the YAML file
shutil.copy(source_yaml, destination_yaml)

print(f"✅ Copied dataset config from {source_yaml} to {destination_yaml}")

✅ Copied dataset config from /Users/ang/Seafile/TRex-tutorials-data/code/datasets/hexbugs-annotation-dataset/data.yaml to /Users/ang/Seafile/TRex-tutorials-data/code/rebalanced_dataset/data.yaml

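The cell above copies the original `data.yaml`. If your dataset has none, or you prefer explicit paths, you can write one yourself: Ultralytics reads `path` (dataset root), `train`/`val`/`test` (image folders relative to `path`) and `names` (class index to name). The sketch below writes such a file with plain string formatting; `write_data_yaml` and the class name `hexbug` are hypothetical examples. Values are written as JSON-style double-quoted strings, which YAML also accepts.

```python
import json
from pathlib import Path

def write_data_yaml(dataset_root, class_names, include_test=True):
    """Write a minimal data.yaml into dataset_root and return its path.

    Assumes the train/valid/test layout shown at the top of this notebook.
    """
    root = Path(dataset_root).resolve()
    lines = [
        f"path: {json.dumps(str(root))}",
        "train: train/images",
        "val: valid/images",
    ]
    if include_test:
        lines.append("test: test/images")
    lines.append("names:")
    # One "index: name" entry per class, in class-id order
    lines += [f"  {i}: {json.dumps(name)}" for i, name in enumerate(class_names)]
    yaml_path = root / "data.yaml"
    yaml_path.write_text("\n".join(lines) + "\n")
    return yaml_path

# Hypothetical usage (this would overwrite the copied file):
# write_data_yaml(dataset.location, ["hexbug"])
```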

# 4) Define Training parameters

In [12]:
# Define your local output folder for saving models
REMOTE_URL = Path.cwd() / "models" / "rocks"
HOME = Path.cwd()

# Report the working directory (the next cell switches to it with %cd)
print(f"Working in directory: {HOME}")

# Create the directory if it doesn't exist
if not REMOTE_URL.exists():
    REMOTE_URL.mkdir(parents=True, exist_ok=True)
    print(f"✅ Directory '{REMOTE_URL}' created.")
else:
    print(f"📁 Directory '{REMOTE_URL}' already exists.")

Working in directory: /Users/ang/Seafile/TRex-tutorials-data/code
✅ Directory '/Users/ang/Seafile/TRex-tutorials-data/code/models/rocks' created.


In [12]:

# Change to home directory
%cd {HOME}

# ---- User-defined Settings ----
resolution = 1980                # Image resolution for training (Ultralytics rounds up to a multiple of 32, i.e. 1984)
epochs = 100                     # Number of training epochs
batch_size = 4                   # Batch size
base_model = "yolo11n-seg"       # Choose model variant. Options: "yolo11n-pose", "yolo11n-seg", "yolo11n" etc.
cropped = False                  # Whether images are cropped (affects naming only)
# sharkcam = True                  # Used for naming (optional toggle)


# ---- Auto-detect task type ----
if "-seg" in base_model:
    task = "segment"
elif "-pose" in base_model:
    task = "pose"
else:
    task = "detect"

# ---- 🔧 Training Settings ----
common_settings = {
    "translate": 0.25,       # Max random translation (± fraction of image size)
    "mixup": 0.001,          # Probability of applying MixUp (blends two images and their labels)
    "copy_paste": 0.15,      # Probability of Copy-Paste augmentation (needs segmentation labels)
    "scale": 0.25,           # Random scaling gain (± fraction)
    "mosaic": 1,             # Probability of Mosaic augmentation (combines 4 images into 1)
    "close_mosaic": 10,      # Disable mosaic for the final 10 epochs to fine-tune on unmodified images
    "line_width": 1,         # Line width for plotted boxes and labels
    "optimize": True,        # Export-only: optimize TorchScript for mobile (ignored during training)
    "dynamic": True,         # Export-only: dynamic input shapes, e.g. for ONNX (ignored during training)
    "format": "tflite",      # Export-only: target format (e.g., 'tflite', 'onnx', 'torchscript')
    "nms": True,             # Export-only: embed Non-Maximum Suppression in the exported model (format-dependent)
    "half": False,           # Use 16-bit (half) precision for inference/export if supported
    "plots": True,           # Save training plots (loss, mAP, etc.)
    "cache": "disk",         # Cache preprocessed images on disk to speed up data loading
    "single_cls": False,     # If True, treat all objects as one class (for class-agnostic detection)
    "amp": True,             # Automatic mixed precision (only takes effect on CUDA GPUs)
    "augment": False,        # Test-time augmentation for prediction (not training augmentation)
    "workers": 16            # Number of dataloader workers (adjust to your CPU core count)
}

# Modify task-specific augmentations
if task == "detect":
    common_settings.update({
        "degrees": 180,       # Allow full rotation
        "flipud": 0.25,       # Vertical flip probability
        "fliplr": 0.25        # Horizontal flip probability
    })
else:
    common_settings.update({
        "degrees": 0,         # No rotation for pose/seg
        "flipud": 0.0,
        "fliplr": 0.0
    })

# Print CLI training parameters
parms = " ".join([f"{k}={v}" for k, v in common_settings.items()])
print("🔧 Training params:", parms)

# ---- 🗂️ Model Output Naming ----
from datetime import datetime
now = datetime.now()
date_string = now.strftime("%Y-%m-%d-%H") + "_" + dataset.name.replace(" ", "-") + "-" + str(dataset.version)

project = f"{resolution}-{base_model}"
if common_settings["mosaic"] > 0:
    project += "-mosaic"
if cropped:
    project += "-cropped"
# if sharkcam:
#     project += "-sharkcam"  # Add logic if needed

# ---- 🧠 Model Weights Source ----
model = base_model  # or path to a pretrained model
print(f"🧪 resolution={resolution} | project={project} | date_string={date_string}")
print(f"📦 model={model} | base_model={base_model} | task={task}")

# ---- 🔒 Safety Check ----
import os
assert model == base_model or os.path.exists(model + ".pt"), f"Model path not found: {model}.pt"

/Users/ang/Seafile/TRex-tutorials-data/code
🔧 Training params: translate=0.25 mixup=0.001 copy_paste=0.15 scale=0.25 mosaic=1 close_mosaic=10 line_width=1 optimize=True dynamic=True format=tflite nms=True half=False plots=True cache=disk single_cls=False amp=True augment=False workers=16 degrees=0 flipud=0.0 fliplr=0.0
🧪 resolution=1980 | project=1980-yolo11n-seg-mosaic | date_string=2025-06-10-00_hexbugs-annotation-dataset_rebalanced_rebalanced-1
📦 model=yolo11n-seg | base_model=yolo11n-seg | task=segment

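Two of these settings interact with the training schedule in ways worth checking by hand: Ultralytics rounds `imgsz` up to a multiple of the model's maximum stride (32 for these YOLO11 models), and `close_mosaic=10` turns mosaic off for the last 10 epochs. The helpers below are a hypothetical sketch of that arithmetic, not Ultralytics code.

```python
import math

def round_up_to_stride(imgsz, stride=32):
    """Round an image size up to the nearest multiple of the model stride."""
    return math.ceil(imgsz / stride) * stride

def mosaic_epochs(epochs, close_mosaic):
    """Split a run into (epochs with mosaic, epochs without), assuming 0 <= close_mosaic <= epochs."""
    if not 0 <= close_mosaic <= epochs:
        raise ValueError("expected 0 <= close_mosaic <= epochs")
    return epochs - close_mosaic, close_mosaic

print(round_up_to_stride(1980))   # 1984
print(mosaic_epochs(100, 10))     # (90, 10)
```

Choosing a resolution that is already a multiple of 32 (such as 1984 or 1920) avoids the resize warning and keeps the run name consistent with the size actually used.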

### Start training

In [14]:
# Change to your working directory
%cd {HOME}

yolo_cmd = f"""
yolo task={task} \
     mode=train \
     resume=False \
     model={model}.pt \
     data={dataset.location}/data.yaml \
     device=cpu \
     name={date_string} \
     project={project} \
     epochs={epochs} \
     imgsz={resolution} \
     batch={batch_size} \
     patience=0 \
     visualize=True \
     {parms}
"""

# ▶️ Run the command
!{yolo_cmd}

/Users/ang/Seafile/TRex-tutorials-data/code
New https://pypi.org/project/ultralytics/8.3.152 available 😃 Update with 'pip install -U ultralytics'
Ultralytics 8.3.38 🚀 Python-3.12.7 torch-2.5.1.post4 CPU (Apple M2 Pro)
[34m[1mengine/trainer: [0mtask=segment, mode=train, model=yolo11n-seg.pt, data=/Users/ang/Seafile/TRex-tutorials-data/code/rebalanced_dataset/data.yaml, epochs=100, time=None, patience=0, batch=4, imgsz=1980, save=True, save_period=-1, cache=disk, device=cpu, workers=16, project=1980-yolo11n-seg-mosaic, name=2025-06-10-00_hexbugs-annotation-dataset_rebalanced_rebalanced-1, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=T

# 5) Locate trained model
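The next cell picks the run folder with the newest modification time. A folder's mtime only changes when entries are added to or removed from it directly, so an alternative is to compare the timestamps of the checkpoints themselves. A sketch of that variant, where `latest_run` is a hypothetical name:

```python
from pathlib import Path

def latest_run(project_dir):
    """Return the run folder whose weights/last.pt was written most recently, or None."""
    checkpoints = list(Path(project_dir).glob("*/weights/last.pt"))
    if not checkpoints:
        return None
    newest = max(checkpoints, key=lambda p: p.stat().st_mtime)
    return newest.parent.parent  # .../<run>/weights/last.pt -> .../<run>
```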

In [None]:
# Change to the working directory
%cd {HOME}

# List all subdirectories in the project folder
all_subdirs = [project + '/' + d for d in os.listdir(project)]

# Keep only those that contain a trained model
all_subdirs = [d for d in all_subdirs if os.path.exists(d + "/weights/last.pt")]

# Get the most recently modified subdirectory
latest_subdir = max(all_subdirs, key=os.path.getmtime)

# Construct the full path to the latest run (os.path.join accepts HOME as a str or a Path)
full_path = os.path.join(HOME, latest_subdir)

print(project)
print(latest_subdir)
print(full_path)

# Save training parameters to a parms.txt
!echo "{parms}" > {latest_subdir}/parms.txt

### Select and Save Best YOLO Model Based on mAP Metrics
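The selection below scores each epoch as 0.9 × mAP50-95 + 0.1 × mAP50, adding the box and mask scores together for segmentation models. As far as we know this mirrors the default fitness Ultralytics uses to decide which epoch becomes `best.pt`. A minimal sketch over plain dictionaries (hypothetical rows standing in for `results.csv`) makes the weighting explicit:

```python
def fitness(row, has_masks):
    """Weighted score for one epoch: 0.9 * mAP50-95 + 0.1 * mAP50, box plus mask if present."""
    score = 0.9 * row["metrics/mAP50-95(B)"] + 0.1 * row["metrics/mAP50(B)"]
    if has_masks:
        score += 0.9 * row["metrics/mAP50-95(M)"] + 0.1 * row["metrics/mAP50(M)"]
    return score

def best_epoch_index(rows):
    """Index of the first row with the highest fitness (ties go to the earlier epoch)."""
    has_masks = "metrics/mAP50-95(M)" in rows[0]
    scores = [fitness(r, has_masks) for r in rows]
    return max(range(len(rows)), key=scores.__getitem__)

# Hypothetical metrics for two epochs (box-only model)
rows = [
    {"metrics/mAP50(B)": 0.80, "metrics/mAP50-95(B)": 0.50},  # score 0.53
    {"metrics/mAP50(B)": 0.60, "metrics/mAP50-95(B)": 0.55},  # score 0.555
]
print(best_epoch_index(rows))  # 1
```

The heavy weight on mAP50-95 means an epoch with tighter boxes wins even if its mAP50 is lower.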

In [15]:
# Check if the best model weights file exists
print(os.path.exists(latest_subdir + "/weights/best.pt"))

# Load training results CSV
import pandas as pd
csv = pd.read_csv(latest_subdir + "/results.csv")

# Strip whitespace from column names
for c in csv.columns:
    csv = csv.rename(columns={c: c.strip()})
    # print(c.strip())  # Optional: print cleaned column names

# If mask (M) metrics exist (segmentation model), score each epoch on both masks and boxes
if "metrics/mAP50-95(M)" in csv.columns:
    # Combined score: sum of the mask and box scores, each weighted 90% mAP50-95 + 10% mAP50
    combined = (csv["metrics/mAP50-95(M)"] * 0.9 + csv["metrics/mAP50(M)"] * 0.1) + \
               (csv["metrics/mAP50-95(B)"] * 0.9 + csv["metrics/mAP50(B)"] * 0.1)
    
    # Get index of best epoch based on combined score
    index = combined.argmax()
    
    # Extract best mAP values for masks
    best_map50_95 = csv["metrics/mAP50-95(M)"].values[index]
    best_map50 = csv["metrics/mAP50(M)"].values[index]
else:
    # Only box metrics available; compute weighted score accordingly
    combined = (csv["metrics/mAP50-95(B)"] * 0.9 + csv["metrics/mAP50(B)"] * 0.1)
    index = combined.argmax()
    
    # Extract best mAP values for boxes
    best_map50_95 = csv["metrics/mAP50-95(B)"].values[index]
    best_map50 = csv["metrics/mAP50(B)"].values[index]

# Define source path of best model
from_path = latest_subdir + "/weights/best.pt"

# Define destination path with project name, date, and mAP scores in filename
# (an f-string works whether HOME is a str or a Path; on Google Colab you may use /content instead of HOME)
to_path = f"{HOME}/{project}-{date_string}-mAP5095_{best_map50_95}-mAP50_{best_map50}.pt"

# Log the copy action with source and destination paths
print("copying from ", from_path, "to", to_path)

# Copy the best model weights to the destination path with informative filename
!cp {from_path} {to_path}

# Upload the copied model file to a remote location using rsync with progress display
!rsync --progress {to_path} {REMOTE_URL}/

# Create a ZIP archive of the full training results folder
!zip -r "{HOME}/{latest_subdir}.zip" "{full_path}"

# Upload the zipped training results to the remote server using rsync with progress shown
!rsync --progress "{HOME}/{latest_subdir}.zip" "{REMOTE_URL}/"


### Training results plot

In [16]:
# Change working directory to HOME
%cd {HOME}

# Display the training results plot (e.g. loss and metrics curves)
Image(filename=f'{latest_subdir}/results.png', width=1200)

/Users/ang/Seafile/TRex-tutorials-data/code



### Sample batch of validation predictions

In [None]:
# Change working directory to HOME
%cd {HOME}

# Display a sample batch of validation predictions (visual output of model)
Image(filename=f'{latest_subdir}/val_batch0_pred.jpg', width=600)

# 6) Validate Custom Model

This step runs **model validation** using the best trained checkpoint (`best.pt`) on the validation split defined in `data.yaml`. It evaluates the model's performance using standard YOLO metrics:

- **mAP50**: mean Average Precision at an IoU threshold of 0.5
- **mAP50-95**: mean AP averaged over IoU thresholds from 0.5 to 0.95 in steps of 0.05
- **Precision & Recall** for each class

The validation results are saved inside the specified project folder (in a `val` subfolder by default) and include:

- A `confusion_matrix.png` comparing predicted and true classes (including background)
- Precision-recall and F1 curve plots
- A `val_batch0_pred.jpg` showing predictions (boxes, plus masks for segmentation models) on a sample batch

The training curves (`results.png`) are not produced by validation; they are in the training run folder shown above.

You can use these visual and quantitative outputs to assess if the model generalizes well to unseen data.
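mAP50 counts a prediction as correct when its IoU (intersection over union) with a ground-truth object is at least 0.5; mAP50-95 repeats the AP computation at ten thresholds (0.50 to 0.95 in steps of 0.05) and averages the results, so it rewards tight boxes and masks. The sketch below computes IoU for axis-aligned boxes and counts how many of those thresholds a single match passes. It is illustrative only and omits the full AP computation (ranking by confidence, precision-recall integration); for segmentation, IoU is computed on masks instead of boxes.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) in the same units."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

# The 10 IoU thresholds averaged by mAP50-95: 0.50, 0.55, ..., 0.95
IOU_THRESHOLDS = [round(0.5 + 0.05 * i, 2) for i in range(10)]

def thresholds_matched(iou):
    """How many of the mAP50-95 thresholds a match with this IoU would pass."""
    return sum(iou >= t for t in IOU_THRESHOLDS)

gt = (0, 0, 10, 10)
pred = (0, 0, 10, 8)                  # same box, 2 px shorter
iou = box_iou(gt, pred)               # 80 / 100
print(iou, thresholds_matched(iou))   # 0.8 7
```

This prediction counts fully towards mAP50 but only at 7 of the 10 thresholds behind mAP50-95.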

In [None]:
# Change working directory to HOME
%cd {HOME}

# Run YOLO validation on the best model checkpoint using the specified dataset and image size
!yolo task={task} mode=val model={latest_subdir}/weights/best.pt data={dataset.location}/data.yaml project={project} imgsz={resolution} line_width=1

# 7) Run Inference on Validation Images

This step performs **inference (prediction)** using the best trained YOLO model (`best.pt`) on the validation image set. It is useful to **visually inspect how the model performs** on real images after training.

What this does:

- Removes any existing `predict` folder, so that Ultralytics writes to `predict` again instead of creating `predict2`, `predict3`, and so on (the later cells read from `{project}/predict`)
- Runs YOLO in `predict` mode using:
  - The best model checkpoint
  - Images from the validation set
  - A low confidence threshold (`conf=0.1`) to allow more predictions for visual inspection
  - The specified image size (`imgsz`)
- Saves predicted images (with boxes, masks, or keypoints depending on the task) in the folder `{project}/predict`

This is especially helpful for qualitatively checking the model's detection performance, spotting failure cases, or selecting images for visualization or presentations.
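The confidence threshold discards every detection whose score does not exceed `conf` before results are drawn; Ultralytics uses 0.25 by default for prediction. Lowering it to 0.1 reveals more (and noisier) candidates. The sketch below illustrates the effect with made-up detection tuples; `filter_by_conf` and the `hexbug` label are hypothetical.

```python
def filter_by_conf(detections, conf=0.25):
    """Keep (label, score) detections whose score is above `conf`."""
    return [d for d in detections if d[1] > conf]

# Hypothetical detections for one image
detections = [("hexbug", 0.91), ("hexbug", 0.42), ("hexbug", 0.12), ("hexbug", 0.05)]
print(len(filter_by_conf(detections)))            # 2 with the default 0.25
print(len(filter_by_conf(detections, conf=0.1)))  # 3 with conf=0.1 as in this notebook
```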

In [None]:
# Change working directory to HOME
%cd {HOME}

# Remove any previous YOLO prediction results to avoid overwriting conflicts
%rm -rf {latest_subdir}/../predict

# Run YOLO prediction on validation images using the best model checkpoint
!yolo task={task} mode=predict model={latest_subdir}/weights/best.pt project={project} name=predict conf=0.1 source={dataset.location}/valid/images save=true imgsz={resolution} line_width=1

### Zip and Save Prediction Results

This step creates a ZIP archive of the prediction results generated in the previous step. The archive is saved in the working directory (`HOME`) and named after the training subdirectory, so it is easy to trace which model produced it.

This makes it simple to download, share, or upload the predictions for external use (e.g., for presentations, manual inspection, or further analysis).
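If the `zip` command-line tool is not available (for example on some Windows setups), the same archive can be built from Python. A minimal sketch using the standard library's `shutil.make_archive`; `zip_folder` is a hypothetical helper name:

```python
import shutil
from pathlib import Path

def zip_folder(folder, archive_base):
    """Zip the contents of `folder` into `<archive_base>.zip` and return the archive path.

    Paths inside the archive are relative to `folder`, so the zip does not
    contain your full absolute directory structure.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(folder)
    return shutil.make_archive(str(archive_base), "zip", root_dir=folder)

# Hypothetical usage mirroring the cell below:
# zip_folder(f"{HOME}/{project}/predict", f"{HOME}/prediction_{zipname}")
```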

In [None]:
# Create a ZIP archive of the YOLO prediction results
# The ZIP file will be named using the current training subdirectory name to keep it traceable
zipname = latest_subdir.replace('/', '_')
!zip -r "{HOME}/prediction_{zipname}.zip" {HOME}/{project}/predict -i "{HOME}/{project}/predict/*"

# Upload the zipped prediction results to the remote server using rsync with progress feedback
!rsync --progress "{HOME}/prediction_{zipname}.zip" "{REMOTE_URL}/"

# 10) Display Sample Predictions

This step randomly selects and displays 5 predicted images from the `predict` folder.

Each image includes the model's output (e.g., bounding boxes, masks, or keypoints) overlaid on the validation images.  
It provides a quick **visual inspection** of model performance across different examples.  

This qualitative check helps identify:
- How well the model localizes objects
- Possible false positives or negatives
- Class confusion or missed detections

In [None]:
# Randomly select up to 5 distinct predicted images from the YOLO prediction output folder
pred_files = glob.glob(f'{HOME}/{project}/predict/*.jpg')
files = np.random.choice(pred_files, size=min(5, len(pred_files)), replace=False)
print(files.shape)

# Display each selected image and print a newline for spacing
for image_path in files:
    display(Image(filename=image_path, height=600))
    print("\n")

In [None]:
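# Only relevant on Google Colab: release the runtime when everything is done
try:
    from google.colab import runtime
except ImportError:
    runtime = None  # not running on Colab

# Wait for 30 seconds (e.g., to ensure all background tasks finish before disconnecting)
time.sleep(30)

# Gracefully disconnect the current Colab runtime session (no-op elsewhere)
if runtime is not None:
    runtime.unassign()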

## 🏆 Congratulations

### Find more learning resources here

Roboflow has produced many resources that you may find interesting as you advance your knowledge of computer vision:

- [Roboflow Notebooks](https://github.com/roboflow/notebooks): A repository of over 20 notebooks that walk through how to train custom models with a range of model types, from YOLOv7 to SegFormer.
- [Roboflow YouTube](https://www.youtube.com/c/Roboflow): Our library of videos featuring deep dives into the latest in computer vision, detailed tutorials that accompany our notebooks, and more.
- [Roboflow Discuss](https://discuss.roboflow.com/): Have a question about how to do something on Roboflow? Ask your question on our discussion forum.
- [Roboflow Models](https://roboflow.com): Learn about state-of-the-art models and their performance. Find links and tutorials to guide your learning.

### Convert data formats

Roboflow provides free utilities to convert data between dozens of popular computer vision formats. Check out [Roboflow Formats](https://roboflow.com/formats) to find tutorials on how to convert data between formats in a few clicks.

### Connect computer vision to your project logic

[Roboflow Templates](https://roboflow.com/templates) is a public gallery of code snippets that you can use to connect computer vision to your project logic. Code snippets range from sending emails after inference to measuring object distance between detections.