# RoadEye Wildlife Detection - YOLO11 Training with NAM Attention

This notebook trains a YOLO11 model with NAM (Normalisation-based Attention Module) for detecting Tasmanian wildlife roadkill.

**Features:**
- **YOLO11** - Latest Ultralytics model; YOLO11m uses 22% fewer parameters than YOLOv8m
- **NAM Attention** - Normalisation-based attention for better feature extraction
- **FiftyOne Integration** - Visualise and analyse your dataset
- Transfer learning with frozen backbone to preserve pre-trained animal features
- Exports trained model for deployment

**Target Species:**
- Tasmanian Devil (endangered)
- Feral Cat
- Tasmanian Pademelon
- Bennett's Wallaby
- Bare-nosed Wombat
- Brushtail Possum
- Fallow Deer
- Southern Brown Bandicoot
- Currawong
- Bronzewing

## 1. Setup and Installation

In [None]:
# Check GPU availability
!nvidia-smi

# Install dependencies
# Pin requests-ratelimiter to avoid BucketFullException import error in pyinaturalist
# Quote version specifiers so the shell does not treat ">" as an output redirect
!pip install -q "ultralytics>=8.3.0" "requests-ratelimiter<0.8" pyinaturalist tqdm pyyaml

import os

# Try to mount Google Drive (optional - skip if it fails)
DRIVE_OUTPUT = None
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_OUTPUT = '/content/drive/MyDrive/RoadEye'
    os.makedirs(DRIVE_OUTPUT, exist_ok=True)
    os.makedirs(f'{DRIVE_OUTPUT}/models', exist_ok=True)
    print(f"Drive mounted. Output directory: {DRIVE_OUTPUT}")
except Exception as e:
    print(f"Drive mount skipped ({e})")
    print("Models will be saved locally at /content/roadeye_output/")
    DRIVE_OUTPUT = '/content/roadeye_output'
    os.makedirs(DRIVE_OUTPUT, exist_ok=True)
    os.makedirs(f'{DRIVE_OUTPUT}/models', exist_ok=True)

# Check ultralytics version
import ultralytics
print(f"Ultralytics version: {ultralytics.__version__}")

## 2. Configuration

In [None]:
# Training Configuration
CONFIG = {
    # Project settings
    "project_name": "roadeye-wildlife-yolo11",
    
    # Base model options (YOLO11 recommended):
    # - "yolo11n.pt" (nano, fastest, 2.6M params)
    # - "yolo11s.pt" (small, 9.4M params)
    # - "yolo11m.pt" (medium, recommended, 20.1M params)
    # - "yolo11l.pt" (large, 25.3M params)
    # - "yolo11x.pt" (extra large, most accurate, 56.9M params)
    "base_model": "yolo11m.pt",
    
    # NAM Attention settings
    "use_nam_attention": True,
    
    # Training parameters
    "epochs_phase1": 50,      # Frozen backbone training
    "epochs_phase2": 50,      # Fine-tuning (optional)
    "batch_size": 16,
    "image_size": 640,
    "patience": 20,           # Early stopping patience
    "freeze_layers": 10,      # Layers to freeze (backbone)
    
    # Species to train (MEWC class ID -> species info)
    # Taxon IDs verified against iNaturalist API
    "species": {
        0: {"code": "DEVIL", "scientific": "Sarcophilus harrisii", "taxon_id": 40232},
        1: {"code": "FCAT", "scientific": "Felis catus", "taxon_id": 118552},
        2: {"code": "PADEM", "scientific": "Thylogale billardierii", "taxon_id": 42970},
        3: {"code": "WALBY", "scientific": "Notamacropus rufogriseus", "taxon_id": 1453431},
        4: {"code": "WOMBAT", "scientific": "Vombatus ursinus", "taxon_id": 43009},
        5: {"code": "BPOSM", "scientific": "Trichosurus vulpecula", "taxon_id": 42808},
        6: {"code": "FDEER", "scientific": "Dama dama", "taxon_id": 42161},
        7: {"code": "BANDI", "scientific": "Isoodon obesulus", "taxon_id": 43294},
        8: {"code": "CURRA", "scientific": "Strepera graculina", "taxon_id": 8423},
        9: {"code": "BRONZ", "scientific": "Phaps chalcoptera", "taxon_id": 3335},
    },
    
    # Data collection settings
    "max_images_per_species": 500,  # Increased for better training
    "include_dead_only": False,
    "train_val_split": 0.8,
}

# Display config
print("Training Configuration")
print("=" * 40)
print(f"Base model: {CONFIG['base_model']} (YOLO11)")
print(f"NAM Attention: {CONFIG['use_nam_attention']}")
print(f"Epochs (Phase 1 - frozen): {CONFIG['epochs_phase1']}")
print(f"Epochs (Phase 2 - fine-tune): {CONFIG['epochs_phase2']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Image size: {CONFIG['image_size']}")
print(f"Species count: {len(CONFIG['species'])}")
print(f"Max images per species: {CONFIG['max_images_per_species']}")
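
Later cells build the class list with `CONFIG["species"][i]` for `i in range(len(...))`, which only works if the class IDs are contiguous and start at 0. A minimal sanity check could look like the sketch below; `validate_species` and the two-class `example` mapping are hypothetical names used for illustration, not part of the notebook.

```python
def validate_species(species: dict) -> list:
    """Return species codes in class-ID order, or raise ValueError if the mapping is inconsistent."""
    ids = sorted(species)
    expected = list(range(len(species)))
    if ids != expected:
        raise ValueError(f"Class IDs must be contiguous from 0; got {ids}")

    codes = [species[i]["code"] for i in expected]
    duplicates = sorted({c for c in codes if codes.count(c) > 1})
    if duplicates:
        raise ValueError(f"Duplicate species codes: {duplicates}")
    return codes


# Hypothetical two-class mapping in the same shape as CONFIG["species"]
example = {
    0: {"code": "DEVIL", "scientific": "Sarcophilus harrisii"},
    1: {"code": "FCAT", "scientific": "Felis catus"},
}
print(validate_species(example))  # ['DEVIL', 'FCAT']
```

In the notebook this would be called as `validate_species(CONFIG["species"])` right after the config cell.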

## 3. NAM Attention Module

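The Normalisation-based Attention Module (NAM) uses batch normalisation scaling factors to measure channel importance: a channel whose BN scale has a larger magnitude is treated as more informative. Because it reuses the BN weights instead of adding the fully connected and convolutional layers used by SE and CBAM, it is more efficient than CBAM while maintaining competitive performance.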

Reference: [NAM: Normalization-based Attention Module](https://arxiv.org/abs/2111.12419)
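
For orientation, the channel attention in the paper can be written as follows. The notation is lightly adapted here: $F$ is the input feature map, $c$ indexes channels, and $\gamma_c$, $\beta_c$ are the BN scale and shift for channel $c$.

```latex
\mathrm{BN}(F)_c = \gamma_c \, \frac{F_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,
\qquad
w_c = \frac{\gamma_c}{\sum_j \gamma_j},
\qquad
M_c = \operatorname{sigmoid}\!\big( w_c \cdot \mathrm{BN}(F)_c \big)
```

The `NAMChannelAttention` class in this notebook is a variant of this idea: its gate is `sigmoid(gamma * w)`, which depends on the learned weights but not on the input features.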

In [None]:
import torch
import torch.nn as nn


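class NAMChannelAttention(nn.Module):
    """NAM Channel Attention Module.
    
    Uses batch normalisation scaling factor to represent channel importance.
    More efficient than SE/CBAM channel attention.
    
    Note: the gate is sigmoid(gamma * normalised |BN weight|), a learned
    per-channel scale that does not depend on the input features.
    """
    
    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        # `reduction` is currently unused
        self.channels = channels
        self.bn = nn.BatchNorm2d(channels, affine=True)
        # gamma starts at zero, so every channel gate starts at sigmoid(0) = 0.5
        self.gamma = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalised magnitude of the BN scale factors. Detached, so the gate
        # does not back-propagate into the BN weights.
        bn_weight = self.bn.weight.detach().abs()
        bn_weight = bn_weight / (bn_weight.sum() + 1e-8)
        out = self.bn(x)
        weight = self.gamma * bn_weight.view(1, -1, 1, 1)
        attention = self.sigmoid(weight)
        return out * attention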


class NAMSpatialAttention(nn.Module):
    """NAM Spatial Attention Module."""
    
    def __init__(self, channels: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels, affine=True)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.bn(x)
        attention = self.sigmoid(out)
        return x * attention


class NAMBlock(nn.Module):
    """Combined NAM Attention Block (Channel + Spatial) with residual connection.
    
    Drop-in module for YOLO architectures.
    """
    
    def __init__(self, channels: int, use_spatial: bool = True):
        super().__init__()
        self.channel_attention = NAMChannelAttention(channels)
        self.use_spatial = use_spatial
        if use_spatial:
            self.spatial_attention = NAMSpatialAttention(channels)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.channel_attention(x)
        if self.use_spatial:
            out = self.spatial_attention(out)
        return out + x  # Residual connection


# Register NAMBlock with Ultralytics so the custom YAML can reference it
if CONFIG["use_nam_attention"]:
    try:
        import ultralytics.nn.modules as modules
        from ultralytics.nn import tasks
        
        if not hasattr(modules, "NAMBlock"):
            modules.NAMBlock = NAMBlock
        if not hasattr(tasks, "NAMBlock"):
            tasks.NAMBlock = NAMBlock
        
        print("NAMBlock registered with Ultralytics")
        print("  Custom YOLO11 architecture will use NAM attention blocks")
        print("  Architecture YAML: configs/yolo11m-nam.yaml")
    except ImportError:
        print("Ultralytics could not be imported - NAM will be registered before training")
else:
    print("NAM Attention disabled in config")

print("\nNAM Attention modules defined:")
print("  - NAMChannelAttention: BN scaling for channel importance")
print("  - NAMSpatialAttention: Pixel-wise attention via BN")
print("  - NAMBlock: Combined with residual connection")

## 4. Load Training Data

Downloads images from iNaturalist for each species. This takes ~20-30 minutes on first run.

If you already have `roadeye_training_images.zip` on Google Drive, set `LOAD_FROM_DRIVE = True` below to skip the download.
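
The Drive loader extracts the zip into `/content/` and then scans `/content/data/training/raw/<source>/<species>/`, so the archive needs to store paths relative to `/content`. A minimal sketch for creating such an archive after a fresh download is shown below; `archive_training_data` is a hypothetical helper, not something the notebook defines.

```python
import zipfile
from pathlib import Path


def archive_training_data(content_root, zip_path, subdirs=("data/training/raw",)):
    """Zip the given subdirectories of content_root, storing paths relative to content_root.

    Returns the number of files written.
    """
    root = Path(content_root)
    count = 0
    # ZIP_STORED: JPEG/PNG files are already compressed, so deflating gains little
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf:
        for sub in subdirs:
            for file in sorted((root / sub).rglob("*")):
                if file.is_file():
                    zf.write(file, arcname=file.relative_to(root).as_posix())
                    count += 1
    return count


# Example call in Colab (assumes DRIVE_OUTPUT from the setup cell):
# archive_training_data("/content", f"{DRIVE_OUTPUT}/roadeye_training_images.zip")
```

Adding `"configs"` to `subdirs` would also bundle the architecture YAML, assuming it lives under `/content/configs`.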

In [None]:
import zipfile
import shutil
import time
import requests
from pathlib import Path
from typing import List, Dict, Optional
from dataclasses import dataclass, field
from tqdm.notebook import tqdm
from collections import Counter

# === Choose data loading method ===
LOAD_FROM_DRIVE = False  # Set to True if you uploaded roadeye_training_images.zip to Drive

# Where the zip lives on Google Drive (only used if LOAD_FROM_DRIVE=True)
DRIVE_ZIP = f"{DRIVE_OUTPUT}/roadeye_training_images.zip"
# Local working directory on Colab
RAW_DIR = Path("/content/data/training/raw")

@dataclass
class ObservationRecord:
    """Record of a downloaded observation."""
    observation_id: str
    species_code: str
    class_id: int
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    observed_on: Optional[str] = None
    image_url: str = ""
    local_path: Optional[str] = None
    is_dead: bool = False
    source: str = "unknown"

if LOAD_FROM_DRIVE:
    # ---- Load from Google Drive zip ----
    print("Loading training data from Google Drive")
    print("=" * 50)

    if not Path(DRIVE_ZIP).exists():
        raise FileNotFoundError(
            f"Zip not found at {DRIVE_ZIP}\n"
            f"Upload roadeye_training_images.zip to Google Drive under MyDrive/RoadEye/"
        )

    print(f"Unzipping {DRIVE_ZIP} to /content/ ...")
    with zipfile.ZipFile(DRIVE_ZIP, "r") as zf:
        zf.extractall("/content/")
    print("Unzipped.")

    configs_dir = Path("/content/configs")
    if configs_dir.exists():
        print(f"  Found configs/ directory with NAM YAML")

    all_observations = []
    class_names = [CONFIG["species"][i]["code"] for i in range(len(CONFIG["species"]))]
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    extra_classes = ["FOX", "HARE", "HEDGE", "OTHER"]
    for i, ec in enumerate(extra_classes):
        if ec not in class_to_idx:
            class_to_idx[ec] = len(class_names) + i

    for source_dir in sorted(RAW_DIR.iterdir()):
        if not source_dir.is_dir():
            continue
        source_name = source_dir.name
        for species_dir in sorted(source_dir.iterdir()):
            if not species_dir.is_dir():
                continue
            species_code = species_dir.name
            cid = class_to_idx.get(species_code, -1)
            if cid == -1:
                continue
            for img_path in sorted(species_dir.iterdir()):
                if img_path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".webp"):
                    continue
                all_observations.append(ObservationRecord(
                    observation_id=img_path.stem,
                    species_code=species_code,
                    class_id=cid,
                    local_path=str(img_path),
                    source=source_name,
                ))

    print(f"\nLoaded {len(all_observations)} images from Drive")

else:
    # ---- Download fresh from iNaturalist ----
    print("Downloading training data from iNaturalist")
    print("=" * 50)
    print("This takes ~20-30 minutes. Set LOAD_FROM_DRIVE=True to skip next time.\n")

    from pyinaturalist import get_observations

    class INaturalistCollector:
        def __init__(self, output_dir: str = "/content/data/training/raw/inaturalist"):
            self.output_dir = Path(output_dir)
            self.output_dir.mkdir(parents=True, exist_ok=True)

        def collect_species(self, class_id, species_code, taxon_id, max_images=500, dead_only=False):
            params = {
                "taxon_id": taxon_id, "quality_grade": "research",
                "photos": True, "per_page": min(200, max_images),
            }
            if dead_only:
                params["term_id"] = 17
                params["term_value_id"] = 19

            print(f"  Fetching {species_code} from iNaturalist...")
            all_results = []
            page = 1
            while len(all_results) < max_images:
                try:
                    params["page"] = page
                    response = get_observations(**params)
                    results = response.get("results", [])
                    if not results:
                        break
                    all_results.extend(results)
                    page += 1
                    time.sleep(1)
                except Exception as e:
                    print(f"  API error: {e}")
                    break

            species_dir = self.output_dir / species_code
            species_dir.mkdir(parents=True, exist_ok=True)

            observations = []
            for obs in tqdm(all_results[:max_images], desc=f"  {species_code}", leave=False):
                if not obs.get("photos"):
                    continue
                location = obs.get("location")
                lat, lng = None, None
                if location:
                    try:
                        lat, lng = map(float, location.split(","))
                    except ValueError:
                        pass
                photo = obs["photos"][0]
                image_url = photo.get("url", "").replace("square", "medium")
                is_dead = any(
                    a.get("controlled_attribute", {}).get("id") == 17
                    and a.get("controlled_value", {}).get("id") == 19
                    for a in obs.get("annotations", [])
                )
                filename = f"{obs['id']}.jpg"
                filepath = species_dir / filename
                if not filepath.exists():
                    try:
                        r = requests.get(image_url, timeout=30)
                        if r.status_code == 200:
                            filepath.write_bytes(r.content)
                        time.sleep(1.0)
                    except requests.RequestException:
                        continue
                if filepath.exists():
                    observations.append(ObservationRecord(
                        observation_id=str(obs["id"]), species_code=species_code,
                        class_id=class_id, latitude=lat, longitude=lng,
                        observed_on=obs.get("observed_on"), image_url=image_url,
                        local_path=str(filepath), is_dead=is_dead, source="inaturalist",
                    ))
            print(f"  Downloaded {len(observations)} images for {species_code}")
            return observations

    collector = INaturalistCollector()
    all_observations = []
    for class_id, species_info in CONFIG["species"].items():
        print(f"\n[{class_id + 1}/{len(CONFIG['species'])}] {species_info['code']}")
        observations = collector.collect_species(
            class_id=class_id, species_code=species_info["code"],
            taxon_id=species_info["taxon_id"],
            max_images=CONFIG["max_images_per_species"],
            dead_only=CONFIG["include_dead_only"],
        )
        all_observations.extend(observations)

# Summary
print(f"\n{'=' * 50}")
print(f"Total images: {len(all_observations)}")
species_counts = Counter(o.species_code for o in all_observations)
source_counts = Counter(o.source for o in all_observations)
print(f"\nBy source:")
for src, count in sorted(source_counts.items()):
    print(f"  {src}: {count}")
print(f"\nBy species:")
for species, count in sorted(species_counts.items()):
    print(f"  {species}: {count}")
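
The download loop above gives up on an image after a single failed request. If transient network errors turn out to cost many images, a retry with exponential backoff is one option. The sketch below is illustrative only: `fetch_with_retry` and its parameters are hypothetical, and the fetch function is injected so that the helper has no dependency on `requests`.

```python
import time


def fetch_with_retry(fetch, url, attempts=3, base_delay=1.0,
                     retry_on=(OSError,), sleep=time.sleep):
    """Call fetch(url) up to `attempts` times, doubling the delay after each failure.

    Returns the result of fetch, or None if every attempt raised one of `retry_on`.
    """
    for attempt in range(attempts):
        try:
            return fetch(url)
        except retry_on as exc:
            if attempt == attempts - 1:
                print(f"Giving up on {url}: {exc}")
                return None
            # Wait base_delay, then 2x, 4x, ... before the next attempt
            sleep(base_delay * 2 ** attempt)
    return None
```

With `requests`, this could be used as `fetch_with_retry(lambda u: requests.get(u, timeout=30), image_url)`. `requests.RequestException` derives from `OSError`, so the default `retry_on` covers it. The caller still needs to check `status_code`, because `requests.get` does not raise on HTTP error responses.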

## 5. Create YOLO Dataset

In [None]:
import random
import shutil
import yaml
from pathlib import Path

def create_yolo_dataset(
    observations: List[ObservationRecord],
    output_dir: str = "/content/dataset_yolo",
    train_ratio: float = 0.8,
) -> str:
    """Create YOLO-format dataset from observations.
    
    Returns path to data.yaml file.
    """
    output_path = Path(output_dir)
    
    # Create directory structure
    for split in ["train", "val"]:
        (output_path / split / "images").mkdir(parents=True, exist_ok=True)
        (output_path / split / "labels").mkdir(parents=True, exist_ok=True)
        
    # Get class names in order
    class_names = [CONFIG["species"][i]["code"] for i in range(len(CONFIG["species"]))]
    
    # Filter valid observations
    valid_obs = [o for o in observations if o.local_path and Path(o.local_path).exists()]
    print(f"Valid observations with images: {len(valid_obs)}")
    
    # Shuffle and split
    random.shuffle(valid_obs)
    split_idx = int(len(valid_obs) * train_ratio)
    train_obs = valid_obs[:split_idx]
    val_obs = valid_obs[split_idx:]
    
    print(f"Train set: {len(train_obs)}")
    print(f"Val set: {len(val_obs)}")
    
    def process_split(obs_list, split_name):
        for obs in tqdm(obs_list, desc=f"Processing {split_name}"):
            src_path = Path(obs.local_path)
            
            # Copy image
            dst_img = output_path / split_name / "images" / src_path.name
            shutil.copy(src_path, dst_img)
            
            # Create label (full image bounding box as placeholder)
            # Format: class_id x_center y_center width height (normalised)
            label_content = f"{obs.class_id} 0.5 0.5 1.0 1.0\n"
            
            label_path = output_path / split_name / "labels" / f"{src_path.stem}.txt"
            label_path.write_text(label_content)
            
    process_split(train_obs, "train")
    process_split(val_obs, "val")
    
    # Create data.yaml
    data_yaml = {
        "path": str(output_path.absolute()),
        "train": "train/images",
        "val": "val/images",
        "nc": len(class_names),
        "names": class_names,
    }
    
    yaml_path = output_path / "data.yaml"
    with open(yaml_path, "w") as f:
        yaml.dump(data_yaml, f, default_flow_style=False)
        
    print(f"\nDataset created at {output_path}")
    print(f"data.yaml: {yaml_path}")
    
    return str(yaml_path)

# Create dataset
DATA_YAML = create_yolo_dataset(
    observations=all_observations,
    train_ratio=CONFIG["train_val_split"],
)

# Display data.yaml contents
print("\ndata.yaml contents:")
print("-" * 30)
with open(DATA_YAML) as f:
    print(f.read())
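
`create_yolo_dataset` shuffles all observations together before splitting, so a rare species can end up under-represented in the validation set, and the split changes on every run. One alternative is a seeded, per-class (stratified) split. The sketch below assumes nothing beyond the standard library; `stratified_split` is a hypothetical helper and the tuples are toy data.

```python
import random
from collections import defaultdict


def stratified_split(items, key, train_ratio=0.8, seed=42):
    """Split items into (train, val) so each class keeps roughly the same ratio."""
    rng = random.Random(seed)  # local RNG: reproducible, leaves global random state alone
    by_class = defaultdict(list)
    for item in items:
        by_class[key(item)].append(item)

    train, val = [], []
    for cls in sorted(by_class):
        group = by_class[cls]
        rng.shuffle(group)
        cut = int(len(group) * train_ratio)
        train.extend(group[:cut])
        val.extend(group[cut:])

    # Mix the classes again so neither list is grouped by class
    rng.shuffle(train)
    rng.shuffle(val)
    return train, val


# Toy data: 10 images of one class, 5 of another
items = [("DEVIL", i) for i in range(10)] + [("FCAT", i) for i in range(5)]
train, val = stratified_split(items, key=lambda pair: pair[0])
print(len(train), len(val))  # 12 3
```

For the notebook's records the call would be `stratified_split(all_observations, key=lambda o: o.class_id, train_ratio=CONFIG["train_val_split"])`. A class with a single image ends up entirely in the validation set, so very small classes may need special handling.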

## 5b. Source Separation and Exploration

Separate images by data source (iNaturalist vs Zenodo) to compare quality and decide which sources to include in training.

In [None]:
from pathlib import Path
from collections import Counter, defaultdict

# === Source Separation ===
# Scan raw directories to identify images by source
# RAW_DIR is set in the data loading cell above

if not RAW_DIR.exists():
    print(f"Raw directory not found at {RAW_DIR}")
    print("Run the data loading cell first.")
else:
    # Count images per source and species
    source_counts = defaultdict(lambda: defaultdict(int))
    source_images = defaultdict(list)  # source -> list of (species, path)

    for source_dir in sorted(RAW_DIR.iterdir()):
        if not source_dir.is_dir():
            continue
        source_name = source_dir.name
        for species_dir in sorted(source_dir.iterdir()):
            if not species_dir.is_dir():
                continue
            species_code = species_dir.name
            for img in species_dir.iterdir():
                if img.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp"):
                    source_counts[source_name][species_code] += 1
                    source_images[source_name].append((species_code, img))

    # Print comparison table
    print("=" * 70)
    print("DATA SOURCE COMPARISON")
    print("=" * 70)

    all_species = sorted(set(
        sp for counts in source_counts.values() for sp in counts
    ))
    sources = sorted(source_counts.keys())

    header = f"{'Species':<10}" + "".join(f"{s:<16}" for s in sources) + f"{'Total':<10}"
    print(header)
    print("-" * len(header))

    for sp in all_species:
        row = f"{sp:<10}"
        total = 0
        for src in sources:
            count = source_counts[src].get(sp, 0)
            row += f"{count:<16}"
            total += count
        row += f"{total:<10}"
        print(row)

    print("-" * len(header))
    totals_row = f"{'TOTAL':<10}"
    grand = 0
    for src in sources:
        t = sum(source_counts[src].values())
        totals_row += f"{t:<16}"
        grand += t
    totals_row += f"{grand:<10}"
    print(totals_row)
    print("=" * 70)

    print(f"\nKey observations:")
    print(f"  iNaturalist: {sum(source_counts.get('inaturalist', {}).values())} images, "
          f"research-grade, mostly live animals")
    print(f"  Zenodo:      {sum(source_counts.get('zenodo', {}).values())} images, "
          f"verified roadkill, European + some Australian")
    print(f"  Zenodo 'OTHER' category: {source_counts.get('zenodo', {}).get('OTHER', 0)} "
          f"images (unmapped European species)")

In [None]:
# Build separate YOLO datasets per source
import shutil, yaml, random

DATASET_BASE = RAW_DIR.parent  # /content/data/training

def build_source_dataset(source_name, image_list, class_names, output_base, train_ratio=0.8):
    """Build a YOLO dataset from a list of (species_code, image_path) tuples."""
    out_dir = Path(output_base) / f"dataset_{source_name}"
    for split in ("train", "val"):
        (out_dir / split / "images").mkdir(parents=True, exist_ok=True)
        (out_dir / split / "labels").mkdir(parents=True, exist_ok=True)
    
    valid = [(sp, p) for sp, p in image_list if sp in class_names]
    random.shuffle(valid)
    split_idx = int(len(valid) * train_ratio)
    splits = {"train": valid[:split_idx], "val": valid[split_idx:]}
    
    class_to_idx_local = {name: idx for idx, name in enumerate(class_names)}
    
    for split_name, items in splits.items():
        for species_code, img_path in items:
            dst = out_dir / split_name / "images" / img_path.name
            shutil.copy(img_path, dst)
            class_id = class_to_idx_local[species_code]
            label_path = out_dir / split_name / "labels" / f"{img_path.stem}.txt"
            label_path.write_text(f"{class_id} 0.5 0.5 1.0 1.0\n")
    
    data_yaml = {
        "path": str(out_dir.absolute()),
        "train": "train/images",
        "val": "val/images",
        "nc": len(class_names),
        "names": class_names,
    }
    yaml_path = out_dir / "data.yaml"
    with open(yaml_path, "w") as f:
        yaml.dump(data_yaml, f, default_flow_style=False)
    
    print(f"  {source_name}: {len(valid)} images -> {yaml_path}")
    return str(yaml_path)

# Define class lists per source
INAT_CLASSES = [CONFIG["species"][i]["code"] for i in range(len(CONFIG["species"]))]
ZENODO_CLASSES = INAT_CLASSES + ["FOX", "HARE", "HEDGE", "OTHER"]
COMBINED_CLASSES = sorted(set(INAT_CLASSES + ZENODO_CLASSES))

print("Building separate datasets per source...")
print("=" * 50)

SOURCE_YAMLS = {}
for source_name, images in source_images.items():
    if source_name == "inaturalist":
        cls_list = INAT_CLASSES
    elif source_name == "zenodo":
        cls_list = [sp for sp in ZENODO_CLASSES if any(s == sp for s, _ in images)]
    else:
        cls_list = COMBINED_CLASSES
    
    yaml_path = build_source_dataset(source_name, images, cls_list, DATASET_BASE)
    SOURCE_YAMLS[source_name] = yaml_path

# Build combined
all_images = []
for imgs in source_images.values():
    all_images.extend(imgs)
combined_cls = sorted(set(sp for sp, _ in all_images))
SOURCE_YAMLS["combined"] = build_source_dataset("combined", all_images, combined_cls, DATASET_BASE)

print(f"\nDataset paths:")
for name, path in SOURCE_YAMLS.items():
    print(f"  {name}: {path}")

# === SELECT WHICH SOURCE TO USE FOR TRAINING ===
#   "inaturalist" - Australian wildlife only (~1,600 images)
#   "zenodo"      - European roadkill (~330 images)
#   "combined"    - All sources merged
TRAINING_SOURCE = "inaturalist"  # <-- Change this to select source
DATA_YAML = SOURCE_YAMLS[TRAINING_SOURCE]
print(f"\nSelected for training: {TRAINING_SOURCE} -> {DATA_YAML}")

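## 6. FiftyOne Dataset Exploration

In [None]:
# FiftyOne is not installed by the setup cell, so install it here
!pip install -q fiftyone

import fiftyone as fo
import fiftyone.zoo as foz
from pathlib import Path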

# Load dataset into FiftyOne
print("Loading dataset into FiftyOne...")

# Get class names
class_names = [CONFIG["species"][i]["code"] for i in range(len(CONFIG["species"]))]

# Create FiftyOne dataset from YOLO format
dataset_path = Path(DATA_YAML).parent

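# Remove any dataset left over from a previous run of this cell
DATASET_NAME = "roadeye-wildlife-train"
if fo.dataset_exists(DATASET_NAME):
    fo.delete_dataset(DATASET_NAME)

# Load training split (the YAML written earlier is data.yaml, not the default dataset.yaml)
dataset = fo.Dataset.from_dir(
    dataset_type=fo.types.YOLOv5Dataset,
    dataset_dir=str(dataset_path),
    yaml_path="data.yaml",
    split="train",
    name=DATASET_NAME,
)

# Load validation split
val_dataset = fo.Dataset.from_dir(
    dataset_type=fo.types.YOLOv5Dataset,
    dataset_dir=str(dataset_path),
    yaml_path="data.yaml",
    split="val",
)

# Tag each split, then merge into the main dataset.
# tag_samples() writes the tags to the database; appending to sample.tags
# without calling sample.save() would lose the "val" tags on merge.
dataset.tag_samples("train")
val_dataset.tag_samples("val")

dataset.merge_samples(val_dataset)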

# Add source tags based on filename origin
# Build reverse lookup: image filename -> source
filename_to_source = {}
for source_name, images in source_images.items():
    for _, img_path in images:
        filename_to_source[img_path.name] = source_name

for sample in dataset:
    img_name = Path(sample.filepath).name
    source = filename_to_source.get(img_name, "unknown")
    sample.tags.append(source)
    sample["source"] = source
    sample.save()

print(f"\nFiftyOne dataset loaded:")
print(f"  Total samples: {len(dataset)}")
print(f"  Train samples: {len(dataset.match_tags('train'))}")
print(f"  Val samples: {len(dataset.match_tags('val'))}")

# Source breakdown (count_values aggregates in the database instead of looping over samples)
for src, count in sorted(dataset.count_values("source").items(), key=lambda kv: str(kv[0])):
    print(f"  Source '{src}': {count}")

print(dataset)

In [None]:
# Launch FiftyOne App for visual exploration
# Note: In Colab, the App is displayed in the cell output

session = fo.launch_app(dataset)

print("\nFiftyOne App Features:")
print("  - Filter by species using the sidebar")
print("  - Click samples to see bounding boxes")
print("  - Use the Brain to find duplicates/outliers")
print("  - Tag samples for review")

In [None]:
# Dataset statistics and analysis
print("Dataset Statistics")
print("=" * 40)

# Count by species
print("\nSamples per species:")
for species_code in class_names:
    view = dataset.filter_labels(
        "ground_truth",
        fo.ViewField("label") == species_code
    )
    print(f"  {species_code}: {len(view)}")

# Check for potential issues
print("\nData Quality Checks:")

# Check for samples without labels
no_labels = dataset.match(fo.ViewField("ground_truth.detections").length() == 0)
print(f"  Samples without labels: {len(no_labels)}")

# Check image dimensions (compute_metadata populates width and height for each sample)
dataset.compute_metadata()

# Get size distribution
widths = [s.metadata.width for s in dataset if s.metadata]
heights = [s.metadata.height for s in dataset if s.metadata]
if widths:
    print(f"  Image width range: {min(widths)} - {max(widths)}")
    print(f"  Image height range: {min(heights)} - {max(heights)}")

## 6b. Auto-Labelling with Foundation Models

Replace placeholder bounding boxes (`0.5 0.5 1.0 1.0`) with real detections using Grounding DINO or YOLO-World. Run this section on GPU.
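
Detectors such as Grounding DINO return boxes as pixel coordinates `(x1, y1, x2, y2)`, whereas YOLO labels use normalised `class cx cy w h`. The conversion, with clamping for boxes that extend past the image edge, can be sketched as follows. `xyxy_to_yolo` and `yolo_line` are hypothetical helper names used for illustration.

```python
def _clamp(value, low, high):
    return max(low, min(high, value))


def xyxy_to_yolo(box, img_w, img_h):
    """Convert a pixel-space (x1, y1, x2, y2) box to normalised (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    # Clamp to the image and make sure x1 <= x2, y1 <= y2
    x1, x2 = sorted((_clamp(x1, 0, img_w), _clamp(x2, 0, img_w)))
    y1, y2 = sorted((_clamp(y1, 0, img_h), _clamp(y2, 0, img_h)))
    return (
        (x1 + x2) / 2 / img_w,
        (y1 + y2) / 2 / img_h,
        (x2 - x1) / img_w,
        (y2 - y1) / img_h,
    )


def yolo_line(class_id, box, img_w, img_h):
    """Format one YOLO label line for a pixel-space box."""
    cx, cy, w, h = xyxy_to_yolo(box, img_w, img_h)
    return f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


print(yolo_line(0, (160, 120, 480, 360), 640, 480))
# 0 0.500000 0.500000 0.500000 0.500000
```

A box that covers the whole image converts to `0.5 0.5 1.0 1.0`, which is exactly the placeholder written by the dataset cells.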

In [None]:
# Install auto-labelling dependencies
!pip install -q transformers

# === Auto-Labelling Configuration ===
AUTOLABEL_CONFIG = {
    # Method: "grounding_dino" or "yolo_world"
    "method": "grounding_dino",
    
    # Grounding DINO settings
    "dino_model": "IDEA-Research/grounding-dino-base",
    "box_threshold": 0.20,
    "text_threshold": 0.15,
    
    # YOLO-World settings (alternative)
    "yolo_world_model": "yolov8x-worldv2.pt",
    "yolo_world_conf": 0.20,
}

# Species prompts - multiple text prompts per species to improve recall
# Foundation models know "cat" better than "feral cat", "pigeon" better than "bronzewing"
SPECIES_PROMPTS = {
    "DEVIL": ["tasmanian devil"],
    "FCAT": ["cat", "feral cat"],
    "PADEM": ["pademelon", "small wallaby"],
    "WALBY": ["wallaby", "kangaroo"],
    "WOMBAT": ["wombat"],
    "BPOSM": ["possum", "brushtail possum"],
    "FDEER": ["deer", "fallow deer"],
    "BANDI": ["bandicoot", "small mammal"],
    "BRONZ": ["pigeon", "bird", "common bronzewing"],
    "CURRA": ["currawong", "black bird"],
    "FOX": ["fox", "red fox"],
    "HARE": ["hare", "rabbit"],
    "HEDGE": ["hedgehog"],
    "OTHER": ["animal"],
}

print(f"Auto-labelling method: {AUTOLABEL_CONFIG['method']}")
print(f"Box threshold: {AUTOLABEL_CONFIG['box_threshold']}")
print(f"Species with prompts: {len(SPECIES_PROMPTS)}")

In [None]:
import json
from pathlib import Path

import torch
import yaml  # used below to read data.yaml
from PIL import Image
from tqdm.notebook import tqdm

# Read data.yaml to get class names for the selected dataset
with open(DATA_YAML) as f:
    data_config = yaml.safe_load(f)
autolabel_classes = data_config["names"]
class_to_idx = {name: idx for idx, name in enumerate(autolabel_classes)}

# Build prompt-to-class mapping
prompt_to_code = {}
all_prompts = []
for species_code in autolabel_classes:
    prompts = SPECIES_PROMPTS.get(species_code, [species_code.lower()])
    for p in prompts:
        prompt_to_code[p] = species_code
        all_prompts.append(p)

# Build species map from folder structure for cross-validation
species_map = {}  # filename -> expected species
for source_name, images in source_images.items():
    for species_code, img_path in images:
        species_map[img_path.name] = species_code

# Output directory for auto-generated labels
LABELS_OUTPUT = Path(DATA_YAML).parent.parent / "auto_labels"
LABELS_OUTPUT.mkdir(parents=True, exist_ok=True)

# Collect all images from the selected dataset
# (same extensions the data loading cell accepts, not only .jpg)
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
dataset_dir = Path(DATA_YAML).parent
image_paths = [
    p
    for split in ("train", "val")
    for p in sorted((dataset_dir / split / "images").iterdir())
    if p.suffix.lower() in IMAGE_EXTS
]
print(f"Images to label: {len(image_paths)}")

if AUTOLABEL_CONFIG["method"] == "grounding_dino":
    from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
    
    print(f"Loading Grounding DINO: {AUTOLABEL_CONFIG['dino_model']}")
    processor = AutoProcessor.from_pretrained(AUTOLABEL_CONFIG["dino_model"])
    dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
        AUTOLABEL_CONFIG["dino_model"]
    ).to("cuda")
    
    text_prompt = ". ".join(all_prompts) + "."
    print(f"Text prompt: {text_prompt[:100]}...")
    
    report = {"labelled": 0, "empty": 0, "disagreements": [], "per_class": {}}
    all_confs = []
    
    for img_path in tqdm(image_paths, desc="Auto-labelling with DINO"):
        image = Image.open(img_path).convert("RGB")  # handle greyscale, palette and RGBA files
        w, h = image.size
        
        inputs = processor(images=image, text=text_prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = dino_model(**inputs)
        
        results = processor.post_process_grounded_object_detection(
            outputs, inputs["input_ids"],
            box_threshold=AUTOLABEL_CONFIG["box_threshold"],
            text_threshold=AUTOLABEL_CONFIG["text_threshold"],
            target_sizes=[(h, w)],
        )
        
        lines = []
        detected_species = []
        for box, score, label_text in zip(
            results[0]["boxes"], results[0]["scores"], results[0]["labels"]
        ):
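            label_lower = label_text.lower().strip()
            if not label_lower:
                continue  # an empty label is a substring of every prompt
            # Prefer an exact prompt match, then the longest overlapping prompt,
            # so "wallaby" maps to WALBY rather than to PADEM's "small wallaby"
            matched_code = prompt_to_code.get(label_lower)
            if matched_code is None:
                candidates = [
                    p for p in prompt_to_code
                    if p in label_lower or label_lower in p
                ]
                if candidates:
                    matched_code = prompt_to_code[max(candidates, key=len)]
            if matched_code is None or matched_code not in class_to_idx:
                continue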
            
            x1, y1, x2, y2 = box.tolist()
            cx = ((x1 + x2) / 2) / w
            cy = ((y1 + y2) / 2) / h
            bw = (x2 - x1) / w
            bh = (y2 - y1) / h
            cid = class_to_idx[matched_code]
            lines.append(f"{cid} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}")
            all_confs.append(float(score))
            detected_species.append(matched_code)
            report["per_class"][matched_code] = report["per_class"].get(matched_code, 0) + 1
        
        label_path = LABELS_OUTPUT / f"{img_path.stem}.txt"
        label_path.write_text("\n".join(lines))
        
        if lines:
            report["labelled"] += 1
        else:
            report["empty"] += 1
        
        # Cross-validate against known species
        expected = species_map.get(img_path.name)
        if expected and detected_species and expected not in detected_species:
            report["disagreements"].append({
                "image": img_path.name, "expected": expected, "detected": detected_species
            })
    
    report["avg_confidence"] = sum(all_confs) / len(all_confs) if all_confs else 0
    del dino_model  # Free GPU memory
    torch.cuda.empty_cache()

elif AUTOLABEL_CONFIG["method"] == "yolo_world":
    from ultralytics import YOLO as YOLOWorld
    
    print(f"Loading YOLO-World: {AUTOLABEL_CONFIG['yolo_world_model']}")
    yw_model = YOLOWorld(AUTOLABEL_CONFIG["yolo_world_model"])
    
    # Use first prompt per species as class text
    class_texts = [SPECIES_PROMPTS.get(c, [c.lower()])[0] for c in autolabel_classes]
    yw_model.set_classes(class_texts)
    
    results = yw_model.predict(
        source=[str(p) for p in image_paths],
        conf=AUTOLABEL_CONFIG["yolo_world_conf"],
        device="cuda",
        verbose=False,
    )
    
    report = {"labelled": 0, "empty": 0, "disagreements": [], "per_class": {}}
    all_confs = []
    
    for img_path, result in zip(image_paths, results):
        boxes = result.boxes
        lines = []
        detected_species = []
        for i in range(len(boxes)):
            box = boxes.xywhn[i].cpu().numpy()
            cls_id = int(boxes.cls[i])
            conf = float(boxes.conf[i])
            if cls_id < len(autolabel_classes):
                cx, cy, bw, bh = box
                lines.append(f"{cls_id} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}")
                all_confs.append(conf)
                sp = autolabel_classes[cls_id]
                detected_species.append(sp)
                report["per_class"][sp] = report["per_class"].get(sp, 0) + 1
        
        label_path = LABELS_OUTPUT / f"{img_path.stem}.txt"
        label_path.write_text("\n".join(lines))
        
        if lines:
            report["labelled"] += 1
        else:
            report["empty"] += 1
        
        expected = species_map.get(img_path.name)
        if expected and detected_species and expected not in detected_species:
            report["disagreements"].append({
                "image": img_path.name, "expected": expected, "detected": detected_species
            })
    
    report["avg_confidence"] = sum(all_confs) / len(all_confs) if all_confs else 0
    del yw_model
    torch.cuda.empty_cache()

# Save report
report_path = LABELS_OUTPUT / "labelling_report.json"
with open(report_path, "w") as f:
    json.dump(report, f, indent=2)

print("\nAuto-labelling Results")
print("=" * 40)
print(f"  Labelled: {report['labelled']}")
print(f"  Empty (no detections): {report['empty']}")
print(f"  Avg confidence: {report['avg_confidence']:.3f}")
print(f"  Disagreements with folder labels: {len(report['disagreements'])}")
print(f"\nDetections per class:")
for cls, count in sorted(report["per_class"].items()):
    print(f"  {cls}: {count}")
print(f"\nLabels saved to: {LABELS_OUTPUT}")
print(f"Report saved to: {report_path}")
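
The box arithmetic in the Grounding DINO branch can be isolated into a small helper. The sketch below is illustrative only: `xyxy_to_yolo` and `yolo_line` are hypothetical names, and the clamping step is an added assumption (the cell above writes coordinates without clamping).

```python
def xyxy_to_yolo(x1, y1, x2, y2, img_w, img_h):
    """Convert pixel corner coordinates to normalised YOLO (cx, cy, w, h)."""
    # Assumption: clamp to the image so slightly out-of-frame boxes stay in [0, 1]
    x1, x2 = sorted((min(max(x1, 0.0), img_w), min(max(x2, 0.0), img_w)))
    y1, y2 = sorted((min(max(y1, 0.0), img_h), min(max(y2, 0.0), img_h)))
    cx = (x1 + x2) / 2 / img_w
    cy = (y1 + y2) / 2 / img_h
    bw = (x2 - x1) / img_w
    bh = (y2 - y1) / img_h
    return cx, cy, bw, bh


def yolo_line(class_id, box):
    """Format one label row the same way the auto-labelling cell does."""
    return f"{class_id} " + " ".join(f"{v:.6f}" for v in box)


# A 200x200 px box inside a 640x480 image
print(yolo_line(3, xyxy_to_yolo(100, 50, 300, 250, 640, 480)))
```

Clamping matters mainly for detections that touch the image border, where a raw box can extend a pixel or two outside the frame.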

In [None]:
# Copy auto-labels into the dataset, replacing placeholders
import shutil

dataset_dir = Path(DATA_YAML).parent
replaced = 0

for split in ("train", "val"):
    labels_dir = dataset_dir / split / "labels"
    for label_file in labels_dir.glob("*.txt"):
        auto_label = LABELS_OUTPUT / label_file.name
        if auto_label.exists():
            content = auto_label.read_text().strip()
            if content:  # Only replace if auto-label has detections
                shutil.copy(auto_label, label_file)
                replaced += 1

print(f"Replaced {replaced} placeholder labels with auto-generated bounding boxes")
print(f"Remaining placeholders: {len(image_paths) - replaced}")

# Verify: check a few labels are no longer placeholders
sample_labels = list((dataset_dir / "train" / "labels").glob("*.txt"))[:5]
print("\nSample labels (should show real coordinates, not 0.5 0.5 1.0 1.0):")
for lbl in sample_labels:
    print(f"  {lbl.name}: {lbl.read_text().strip()[:80]}")
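
Printing five samples is a spot check. To count every remaining placeholder, something like the following could be used. It assumes a placeholder is a file with exactly one row whose box is `0.5 0.5 1.0 1.0` (inferred from the message above), and `is_placeholder` and `count_placeholders` are hypothetical helper names.

```python
from pathlib import Path

# Assumption: a placeholder is a single full-image box
PLACEHOLDER = (0.5, 0.5, 1.0, 1.0)


def is_placeholder(label_text, tol=1e-6):
    """True if the label text holds exactly one full-image box."""
    rows = [ln.split() for ln in label_text.splitlines() if ln.strip()]
    if len(rows) != 1 or len(rows[0]) != 5:
        return False
    coords = [float(v) for v in rows[0][1:]]
    return all(abs(a - b) <= tol for a, b in zip(coords, PLACEHOLDER))


def count_placeholders(labels_dir):
    """Return (total label files, names of files that still look like placeholders)."""
    files = sorted(Path(labels_dir).glob("*.txt"))
    flagged = [f.name for f in files if is_placeholder(f.read_text())]
    return len(files), flagged
```

Calling `count_placeholders(dataset_dir / "train" / "labels")` would then report how many training labels still need a real bounding box.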

## 7. Phase 1: Frozen Backbone Training with YOLO11

Train with the backbone frozen to preserve pre-trained features.

In [None]:
from ultralytics import YOLO

print("Phase 1: Frozen Backbone Training (YOLO11)")
print("=" * 50)

# Load model: either custom NAM architecture or stock YOLO11
if CONFIG["use_nam_attention"]:
    # Re-register NAMBlock in case this cell runs in a fresh runtime
    import ultralytics.nn.modules as _modules
    from ultralytics.nn import tasks as _tasks
    if not hasattr(_modules, "NAMBlock"):
        _modules.NAMBlock = NAMBlock
    if not hasattr(_tasks, "NAMBlock"):
        _tasks.NAMBlock = NAMBlock

    # Write the NAM YAML config if it doesn't exist yet
    NAM_YAML = "/content/configs/yolo11m-nam.yaml"
    Path("/content/configs").mkdir(parents=True, exist_ok=True)
    if not Path(NAM_YAML).exists():
        Path(NAM_YAML).write_text("""\
# YOLO11m with NAM (Normalisation-based Attention Module)
nc: 9
scales:
  m: [0.50, 1.00, 512]

backbone:
  - [-1, 1, Conv, [64, 3, 2]]           # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]          # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]   # 2
  - [-1, 1, Conv, [256, 3, 2]]          # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]   # 4
  - [-1, 1, NAMBlock, [512]]            # 5 - NAM after P3 features
  - [-1, 1, Conv, [512, 3, 2]]          # 6-P4/16
  - [-1, 2, C3k2, [512, True]]          # 7
  - [-1, 1, NAMBlock, [512]]            # 8 - NAM after P4 features
  - [-1, 1, Conv, [1024, 3, 2]]         # 9-P5/32
  - [-1, 2, C3k2, [1024, True]]         # 10
  - [-1, 1, SPPF, [1024, 5]]            # 11
  - [-1, 2, C2PSA, [1024]]              # 12

head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]  # 13
  - [[-1, 8], 1, Concat, [1]]           # 14 - cat backbone P4 (after NAM)
  - [-1, 2, C3k2, [512, False]]         # 15
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]  # 16
  - [[-1, 5], 1, Concat, [1]]           # 17 - cat backbone P3 (after NAM)
  - [-1, 2, C3k2, [256, False]]         # 18 (P3/8-small)
  - [-1, 1, Conv, [256, 3, 2]]          # 19
  - [[-1, 15], 1, Concat, [1]]          # 20 - cat head P4
  - [-1, 2, C3k2, [512, False]]         # 21 (P4/16-medium)
  - [-1, 1, Conv, [512, 3, 2]]          # 22
  - [[-1, 12], 1, Concat, [1]]          # 23 - cat head P5
  - [-1, 2, C3k2, [1024, True]]         # 24 (P5/32-large)
  - [[18, 21, 24], 1, Detect, [nc]]     # 25 - Detect(P3, P4, P5)
""")
        print(f"  Created NAM YAML at {NAM_YAML}")

    print(f"Architecture: {NAM_YAML} (YOLO11m + NAM attention)")
    print(f"Pretrained weights: {CONFIG['base_model']}")
    model = YOLO(NAM_YAML)
    model.load(CONFIG['base_model'])

    # NAM inserts 2 extra layers in the backbone (layers 5 and 8), so
    # freeze=12 (layers 0-11) locks the same modules as freeze=10 on stock
    # YOLO11. The final C2PSA block (layer 12) stays trainable in both cases.
    freeze_layers = 12
    print(f"Freeze layers: {freeze_layers} (adjusted for NAM backbone)")
else:
    print(f"Architecture: stock {CONFIG['base_model']}")
    model = YOLO(CONFIG['base_model'])
    freeze_layers = CONFIG["freeze_layers"]
    print(f"Freeze layers: {freeze_layers}")

print(f"Epochs: {CONFIG['epochs_phase1']}")
print(f"Learning rate: 0.001")
print()

# Train with frozen backbone
results_phase1 = model.train(
    data=DATA_YAML,
    epochs=CONFIG["epochs_phase1"],
    batch=CONFIG["batch_size"],
    imgsz=CONFIG["image_size"],
    patience=CONFIG["patience"],
    project=CONFIG["project_name"],
    name="phase1_frozen_yolo11",
    exist_ok=True,

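    freeze=freeze_layers,
    # Set the optimizer explicitly: with the default optimizer="auto",
    # Ultralytics ignores lr0 and chooses its own learning rate
    optimizer="AdamW",
    lr0=0.001,
    lrf=0.01,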

    hsv_h=0.015,
    hsv_s=0.7,
    hsv_v=0.4,
    degrees=0.0,
    translate=0.1,
    scale=0.5,
    fliplr=0.5,
    mosaic=1.0,

    device=0,
    workers=4,
)

print("\nPhase 1 training complete!")
PHASE1_MODEL = f"{CONFIG['project_name']}/phase1_frozen_yolo11/weights/best.pt"
print(f"Best model: {PHASE1_MODEL}")
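
To make the effect of an integer `freeze` value concrete, the sketch below mimics the mapping from `freeze=N` to parameter-name prefixes. It is a simplified model of the behaviour, not the Ultralytics source: it assumes parameters are named `model.<layer index>.<...>` and that `freeze=N` locks layer indices 0 to N-1. `frozen_params` is a hypothetical helper.

```python
def frozen_params(param_names, freeze):
    """Return the parameter names that an integer `freeze` value would lock."""
    # Assumption: freeze=N covers layer indices 0..N-1
    prefixes = [f"model.{i}." for i in range(freeze)]
    return [n for n in param_names if any(n.startswith(p) for p in prefixes)]


# Illustrative names for the 26 layers (0-25) of the NAM architecture above
names = [f"model.{i}.conv.weight" for i in range(26)]
locked = frozen_params(names, 12)

print(f"Frozen layers: {len(locked)}")
print(f"Last frozen:   {locked[-1]}")
print(f"Layer 12 frozen? {'model.12.conv.weight' in locked}")
```

Under these assumptions, `freeze=12` stops at layer 11 (SPPF), so layer 12 and the whole head remain trainable.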

## 8. Phase 2: Fine-tuning (Optional)

Unfreeze all layers and fine-tune with a very low learning rate.

In [None]:
RUN_PHASE2 = True  # Set to False to skip fine-tuning

if RUN_PHASE2:
    print("Phase 2: Fine-tuning (All Layers)")
    print("=" * 50)
    print(f"Starting from: {PHASE1_MODEL}")
    print(f"Epochs: {CONFIG['epochs_phase2']}")
    print(f"Learning rate: 0.0001 (very low to preserve features)")
    print()

    # Re-register NAMBlock if using NAM (needed if runtime restarted)
    if CONFIG["use_nam_attention"]:
        import ultralytics.nn.modules as _modules
        from ultralytics.nn import tasks as _tasks
        if not hasattr(_modules, "NAMBlock"):
            _modules.NAMBlock = NAMBlock
        if not hasattr(_tasks, "NAMBlock"):
            _tasks.NAMBlock = NAMBlock
        print("NAMBlock registered (Phase 2)")

    # Load Phase 1 model (already has NAM architecture if enabled)
    model_phase2 = YOLO(PHASE1_MODEL)

    # Fine-tune with all layers unfrozen
    results_phase2 = model_phase2.train(
        data=DATA_YAML,
        epochs=CONFIG["epochs_phase2"],
        batch=CONFIG["batch_size"],
        imgsz=CONFIG["image_size"],
        patience=CONFIG["patience"],
        project=CONFIG["project_name"],
        name="phase2_finetune_yolo11",
        exist_ok=True,

        # No freezing - all layers trainable
        freeze=0,

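        # Very low learning rate to preserve knowledge.
        # Set the optimizer explicitly: with the default optimizer="auto",
        # Ultralytics ignores lr0 and chooses its own learning rate
        optimizer="AdamW",
        lr0=0.0001,
        lrf=0.001,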

        # Same augmentation
        hsv_h=0.015,
        hsv_s=0.7,
        hsv_v=0.4,
        degrees=0.0,
        translate=0.1,
        scale=0.5,
        fliplr=0.5,
        mosaic=1.0,

        device=0,
        workers=4,
    )

    FINAL_MODEL = f"{CONFIG['project_name']}/phase2_finetune_yolo11/weights/best.pt"
    print("\nPhase 2 fine-tuning complete!")
else:
    FINAL_MODEL = PHASE1_MODEL
    print("Skipping Phase 2 - using Phase 1 model")

print(f"Final model: {FINAL_MODEL}")
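
The `lr0` and `lrf` values in the two phases define where the learning rate starts and ends: the final rate is `lr0 * lrf`. The sketch below assumes a linear decay and ignores warm-up, so it approximates the schedule rather than reproducing the trainer. `linear_lr` is a hypothetical helper and the 10-epoch horizon is illustrative.

```python
def linear_lr(epoch, epochs, lr0, lrf):
    """Learning rate at `epoch` for a linear decay from lr0 down to lr0 * lrf."""
    factor = max(1 - epoch / epochs, 0) * (1.0 - lrf) + lrf
    return lr0 * factor


# Values taken from the two training cells; 10 epochs is an illustrative horizon
for name, lr0, lrf in [("phase1", 0.001, 0.01), ("phase2", 0.0001, 0.001)]:
    schedule = [f"{linear_lr(e, 10, lr0, lrf):.2e}" for e in (0, 5, 10)]
    print(name, schedule)
```

With these settings Phase 1 ends at about 1e-5 and Phase 2 at about 1e-7, so late Phase 2 epochs change the weights very little.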

## 9. Model Evaluation with FiftyOne

Use FiftyOne to visualise model predictions and identify failure modes.

In [None]:
# Evaluate final model
print("Evaluating Final Model")
print("=" * 50)

model_eval = YOLO(FINAL_MODEL)
metrics = model_eval.val(data=DATA_YAML)

print(f"\nResults:")
print(f"  mAP50: {metrics.box.map50:.4f}")
print(f"  mAP50-95: {metrics.box.map:.4f}")
print(f"  Precision: {metrics.box.mp:.4f}")
print(f"  Recall: {metrics.box.mr:.4f}")

# Per-class metrics
print(f"\nPer-class mAP50:")
class_names = [CONFIG["species"][i]["code"] for i in range(len(CONFIG["species"]))]
for i, name in enumerate(class_names):
    if i < len(metrics.box.ap50):
        print(f"  {name}: {metrics.box.ap50[i]:.4f}")

In [None]:
# Add model predictions to FiftyOne dataset for analysis
print("Adding model predictions to FiftyOne...")

# Get validation samples
val_view = dataset.match_tags("val")

# Run inference on validation set
for sample in tqdm(val_view, desc="Running inference"):
    # Run YOLO inference
    results = model_eval(sample.filepath, verbose=False)
    
    # Convert predictions to FiftyOne format
    detections = []
    for result in results:
        boxes = result.boxes
        for i in range(len(boxes)):
            # Get box in YOLO format (x_center, y_center, width, height) normalised
            box = boxes.xywhn[i].cpu().numpy()
            x, y, w, h = box
            
            # Convert to FiftyOne format (x, y, width, height) - top-left corner
            fo_box = [x - w/2, y - h/2, w, h]
            
            cls_id = int(boxes.cls[i])
            conf = float(boxes.conf[i])
            label = class_names[cls_id] if cls_id < len(class_names) else "unknown"
            
            detections.append(
                fo.Detection(
                    label=label,
                    bounding_box=fo_box,
                    confidence=conf,
                )
            )
    
    # Add predictions to sample
    sample["predictions"] = fo.Detections(detections=detections)
    sample.save()

print("\nPredictions added. Evaluating against ground truth...")

# Evaluate predictions against ground truth
results = val_view.evaluate_detections(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
    compute_mAP=True,
)

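print("\nFiftyOne Evaluation Results:")
# print_report() prints a formatted per-class table; report() only returns a dict
results.print_report()
print(f"mAP: {results.mAP():.4f}")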
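
Detection evaluation depends on two small pieces of geometry: the conversion from centre-based boxes to top-left boxes used above, and the intersection-over-union (IoU) score that decides whether a prediction matches a ground-truth box. The sketch below shows both in plain Python. `to_top_left` and `iou` are hypothetical helpers, and the 0.5 threshold in the example is a common choice rather than a value taken from this notebook.

```python
def to_top_left(cx, cy, w, h):
    """Centre-based (cx, cy, w, h) to top-left [x, y, w, h], all relative to image size."""
    return [cx - w / 2, cy - h / 2, w, h]


def iou(a, b):
    """Intersection-over-union of two [x, y, w, h] boxes with a top-left origin."""
    ax2, ay2 = a[0] + a[2], a[1] + a[3]
    bx2, by2 = b[0] + b[2], b[1] + b[3]
    inter_w = max(0.0, min(ax2, bx2) - max(a[0], b[0]))
    inter_h = max(0.0, min(ay2, by2) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


pred = to_top_left(0.50, 0.50, 0.20, 0.40)
truth = to_top_left(0.52, 0.50, 0.20, 0.40)
score = iou(pred, truth)
print(f"IoU = {score:.3f}, match at 0.5: {score >= 0.5}")
```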

In [None]:
# Launch FiftyOne to explore predictions and failures
session = fo.launch_app(val_view)

print("\nFiftyOne Model Evaluation Features:")
print("  - Compare predictions vs ground truth")
print("  - Filter by TP/FP/FN using eval field")
print("  - Sort by confidence to find threshold")
print("  - Identify failure modes by species")
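
Sorting by confidence to find a threshold can also be done numerically. The sketch below picks the confidence cut-off that maximises F1 from a list of `(confidence, is_true_positive)` pairs. It assumes the TP/FP flags have already been computed at a fixed IoU threshold (for example, from the per-detection `eval` field written by the evaluation above). `best_f1_threshold` is a hypothetical helper and the sample data is made up for illustration.

```python
def best_f1_threshold(preds, num_gt):
    """Return (best F1, confidence threshold) for predictions flagged TP or FP.

    preds:  list of (confidence, is_true_positive) pairs
    num_gt: number of ground-truth boxes
    """
    best = (0.0, 0.0)
    tp = fp = 0
    # Walk from the most to the least confident prediction, lowering the threshold
    for conf, is_tp in sorted(preds, key=lambda p: p[0], reverse=True):
        tp += bool(is_tp)
        fp += not is_tp
        precision = tp / (tp + fp)
        recall = tp / num_gt if num_gt else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        if f1 > best[0]:
            best = (f1, conf)
    return best


# Illustrative data only: 5 predictions against 4 ground-truth boxes
sample = [(0.9, True), (0.8, True), (0.6, False), (0.5, True), (0.3, False)]
f1, threshold = best_f1_threshold(sample, num_gt=4)
print(f"Best F1 {f1:.2f} at confidence >= {threshold}")
```

For roadkill detection, where a missed animal may cost more than a false alarm, a threshold below the F1 optimum may be preferable.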

## 10. Export and Save to Google Drive

In [None]:
import shutil
from datetime import datetime

print("Exporting Model")
print("=" * 50)

# Generate timestamp for versioning
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Create output directory
export_dir = f"{DRIVE_OUTPUT}/models/{timestamp}"
os.makedirs(export_dir, exist_ok=True)

# Copy PyTorch model
pt_dest = f"{export_dir}/roadeye_yolo11_{timestamp}.pt"
shutil.copy(FINAL_MODEL, pt_dest)
print(f"PyTorch model saved: {pt_dest}")

# Export to ONNX (for deployment)
model_export = YOLO(FINAL_MODEL)
onnx_path = model_export.export(format="onnx")
onnx_dest = f"{export_dir}/roadeye_yolo11_{timestamp}.onnx"
shutil.copy(onnx_path, onnx_dest)
print(f"ONNX model saved: {onnx_dest}")

# Save training config
config_path = f"{export_dir}/training_config.yaml"
with open(config_path, "w") as f:
    yaml.dump(CONFIG, f, default_flow_style=False)
print(f"Config saved: {config_path}")

# Save metrics
metrics_data = {
    "map50": float(metrics.box.map50),
    "map50_95": float(metrics.box.map),
    "precision": float(metrics.box.mp),
    "recall": float(metrics.box.mr),
    "model": "YOLO11",
    "timestamp": timestamp,
}
metrics_path = f"{export_dir}/metrics.yaml"
with open(metrics_path, "w") as f:
    yaml.dump(metrics_data, f, default_flow_style=False)
print(f"Metrics saved: {metrics_path}")

print(f"\nAll exports saved to: {export_dir}")

## Summary

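Training is complete with **YOLO11** and, if `use_nam_attention` is enabled, the **NAM attention** module.

**Key improvements over YOLOv8:**
- 22% fewer parameters (the figure Ultralytics reports for YOLO11m against YOLOv8m)
- Better small object detection (important for roadside wildlife)
- NAM attention adds normalisation-based feature reweighting at the P3 and P4 backbone stages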

**Next steps:**
1. Download the `.pt` file from Google Drive
2. Use FiftyOne to review failure cases and improve labels
3. Consider manual annotation with CVAT or Label Studio for better bounding boxes

**Labelling Tools (for manual annotation):**
- **CVAT** - Free, web-based, industry standard: https://cvat.ai
- **Label Studio** - Free, self-hosted option: https://labelstud.io
- **LabelImg** - Simple cross-platform desktop app: `pip install labelImg`