# Dataset & Training Setup

We combine **GTSDB (German Traffic Sign Detection Benchmark)** and **GTSRB (German Traffic Sign Recognition Benchmark)** to build a larger and more balanced training dataset.

## Original Data (GTSDB)
- ~900 images available:
  - **600** used for training
  - **300** for validation
- Validation set contains **only original images** (no synthetic augmentation).

## Augmented Data
- To avoid underfitting due to limited original samples, **~13k augmented images** were generated using GTSRB and GTSDB signs placed on empty backgrounds with various transformations.
- Total training data: **~14,958 images**  
  (≈13,158 augmented + 1,800 original samples, i.e. the 600 originals upsampled ×3)

## Ratio-based Training
- Training uses a **DynamicMixedDataset** that controls the ratio of original vs augmented images per epoch:
  - Early epochs → more augmented images (for diversity)  
  - Later epochs → more original images (for realism)

## Validation
- Validation strictly uses the **300 original images**.  
- This ensures evaluation is fair and not biased by synthetic data.

## Dataset Structure
```text
images/
├── train_org/   # 1,800 (600 originals upsampled ×3)
├── train_aug/   # ~13k augmented
└── val/         # 300 originals (validation only)
```


## Why this setup?
1. Large and diverse training dataset (~15k images).  
2. Balanced use of augmented vs real images with a dynamic ratio schedule.  
3. Reliable validation on untouched original data only.  

---

## Download
The dataset is available on Kaggle:  
👉 [German Traffic Signs Detection (YOLO, Aug + Org)](https://www.kaggle.com/datasets/wahburrehman/german-traffic-signs-detection-yolo-aug-org)


In [None]:
!pip -q install ultralytics opencv-python

# verify GPU
!nvidia-smi

In [None]:
# =============== CUSTOM DATASET STRUCTURE ===============

In [None]:
%%writefile custom_dataset.py
import os
import random
import tempfile
import yaml
from pathlib import Path
from torch.utils.data import Dataset
from ultralytics.data.dataset import YOLODataset

class DynamicMixedDataset(Dataset):
    """
    A custom dataset that dynamically mixes original and augmented images each epoch
    based on a provided ratio schedule, all within a continuous training loop.

    Each call to ``refresh_epoch_mix`` (directly, or through ``set_epoch``) samples a
    new list of image paths, writes it to a temporary ``.txt`` file plus a matching
    temporary data YAML (its path is exposed as ``current_epoch_yaml``), and rebuilds
    an internal ``YOLODataset`` over that list. The previous mix's temp files are
    deleted whenever a new mix is created and when this object is garbage-collected.
    """

    def __init__(
        self,
        original_img_dir,
        augmented_img_dir,
        data_yaml_path,
        img_size=640,
        augment=True,
        hyp=None,
        ratio_schedule_fn=None,
        total_target=9000,
    ):
        """
        Args:
            original_img_dir: Path to directory with original training images
                (searched recursively).
            augmented_img_dir: Path to directory with augmented training images
                (searched recursively).
            data_yaml_path: Path to the original data.yaml file. Its 'val' and
                'names' entries (and 'path' / 'nc' when present) are copied into
                every per-epoch config.
            img_size: Passed to YOLODataset as ``imgsz``.
            augment: Passed to YOLODataset as ``augment``.
            hyp: Optional hyperparameter object forwarded to YOLODataset as ``hyp``.
                When None it is not passed, so YOLODataset's own default applies.
            ratio_schedule_fn: Function(epoch) -> (original_ratio, augmented_ratio).
                When None, a fixed 0.6 / 0.4 split is used.
            total_target: Total number of images to use per epoch.
        """
        self.original_img_dir = Path(original_img_dir)
        self.augmented_img_dir = Path(augmented_img_dir)
        self.data_yaml_path = data_yaml_path
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.ratio_schedule_fn = ratio_schedule_fn
        self.total_target = total_target
        self.current_epoch = 0

        # Load base data config (for val path and classes)
        with open(data_yaml_path, "r") as f:
            self.base_data_config = yaml.safe_load(f)

        # Get all image paths
        self.original_images = self._get_image_paths(self.original_img_dir)
        self.augmented_images = self._get_image_paths(self.augmented_img_dir)

        print(f"Found {len(self.original_images)} original images")
        print(f"Found {len(self.augmented_images)} augmented images")

        # Per-epoch state, populated by refresh_epoch_mix()
        self.base_dataset = None
        self.current_epoch_config = None
        self.current_epoch_yaml = None
        self.temp_train_file = None  # closed NamedTemporaryFile handle (.txt list)
        self.temp_yaml_file = None   # closed NamedTemporaryFile handle (.yaml config)
        self.refresh_epoch_mix()

    def _get_image_paths(self, directory: Path):
        """Get all image paths from a directory recursively."""
        extensions = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
        return sorted(
            [p.as_posix() for p in directory.rglob("*") if p.suffix.lower() in extensions]
        )

    @staticmethod
    def _sample_with_replacement(pool, target):
        """Return ``target`` items drawn from ``pool``.

        When the pool is large enough, items are sampled without replacement.
        Otherwise the whole pool is repeated as many times as fits, topped up with
        a random subset, and shuffled. Returns ``[]`` when ``target <= 0``. Also
        returns ``[]`` (with a printed warning) when ``pool`` is empty, rather than
        dividing by zero.
        """
        if target <= 0:
            return []
        if not pool:
            print(f"Warning: requested {target} images from an empty pool; using none.")
            return []
        if len(pool) >= target:
            return random.sample(pool, target)
        reps, rem = divmod(target, len(pool))
        selected = pool * reps + random.sample(pool, rem)
        random.shuffle(selected)
        return selected

    @staticmethod
    def _write_temp_file(suffix, write):
        """Create a temp file (kept on disk) with ``suffix`` and fill it via ``write(fh)``.

        Returns the closed ``NamedTemporaryFile`` handle; its ``.name`` is the path.
        """
        with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as fh:
            write(fh)
        return fh

    @staticmethod
    def _unlink_quietly(path):
        """Delete ``path``, ignoring a missing file or other OS-level failure."""
        if not path:
            return
        try:
            os.unlink(path)
        except OSError:
            # Best-effort cleanup of our own temp files; a file that is already
            # gone (or locked) must not abort training.
            pass

    def _cleanup_temp_files(self):
        """Delete the temp train list and temp YAML from the last mix, if any."""
        for attr in ("temp_train_file", "temp_yaml_file"):
            handle = getattr(self, attr, None)
            if handle is not None:
                self._unlink_quietly(handle.name)

    def refresh_epoch_mix(self):
        """Create a new mix of images for the current epoch.

        Raises:
            ValueError: if no images at all were selected for this epoch.
        """
        # Get ratios for current epoch
        if self.ratio_schedule_fn:
            org_ratio, aug_ratio = self.ratio_schedule_fn(self.current_epoch)
        else:
            org_ratio, aug_ratio = 0.6, 0.4  # Default

        # round() instead of int(): ratio * total can land a hair below a whole
        # number in floating point (e.g. 0.29 * 100 == 28.999...), and int() would
        # silently drop an image.
        org_target = int(round(org_ratio * self.total_target))
        aug_target = int(round(aug_ratio * self.total_target))

        selected_org = self._sample_with_replacement(self.original_images, org_target)
        selected_aug = self._sample_with_replacement(self.augmented_images, aug_target)

        # Combine and shuffle
        epoch_images = selected_org + selected_aug
        if not epoch_images:
            raise ValueError(
                f"Epoch {self.current_epoch}: no training images selected "
                f"(ratios {org_ratio}/{aug_ratio}, total_target={self.total_target})"
            )
        random.shuffle(epoch_images)

        print(
            f"Epoch {self.current_epoch}: Using {len(selected_org)} original + "
            f"{len(selected_aug)} augmented images (Ratio: {org_ratio:.1f}/{aug_ratio:.1f})"
        )

        # Drop the previous epoch's .txt AND .yaml so they do not pile up in the
        # temp directory over a long run.
        self._cleanup_temp_files()

        # Per-epoch image list consumed through the 'train' key below
        self.temp_train_file = self._write_temp_file(
            ".txt", lambda fh: fh.write("\n".join(epoch_images))
        )

        # Build a per-epoch data config that points 'train' to the temp file
        names = self.base_data_config["names"]
        epoch_config = {
            "path": self.base_data_config.get("path", ""),  # optional
            "train": self.temp_train_file.name,             # per-epoch list
            "val": self.base_data_config["val"],
            "names": names,
            # Fall back to the class count implied by 'names' when the base YAML omits 'nc'.
            "nc": self.base_data_config.get("nc", len(names)),
        }
        self.current_epoch_config = epoch_config

        # Same config as a YAML file, so a trainer can be pointed at it by path
        self.temp_yaml_file = self._write_temp_file(
            ".yaml", lambda fh: yaml.safe_dump(epoch_config, fh)
        )
        self.current_epoch_yaml = self.temp_yaml_file.name

        # Build the internal YOLO dataset; forward hyp only when the caller supplied one
        extra_kwargs = {} if self.hyp is None else {"hyp": self.hyp}
        self.base_dataset = YOLODataset(
            img_path=epoch_config["train"],  # required by BaseDataset
            data=epoch_config,
            task="detect",
            imgsz=self.img_size,
            augment=self.augment,
            **extra_kwargs,
        )

    def set_epoch(self, epoch: int):
        """Call this at the start of each epoch to refresh the mix."""
        self.current_epoch = int(epoch)
        self.refresh_epoch_mix()

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        return self.base_dataset[index]

    def __del__(self):
        # Best-effort cleanup of both temp files. __del__ can run during interpreter
        # shutdown when module globals are already torn down, so it must never raise.
        try:
            self._cleanup_temp_files()
        except Exception:
            pass


In [None]:
# === Model Training  ===
from ultralytics import YOLO
from ultralytics.data import build_yolo_dataset, build_dataloader
from custom_dataset import DynamicMixedDataset

# ratio schedule
def ratio_schedule(epoch: int):
    """Return ``(original_ratio, augmented_ratio)`` for ``epoch``.

    The phases are inclusive on their last epoch: up to epoch 40 the split is
    40/60, up to 80 it is 60/40, up to 90 it is 80/20, and afterwards only
    original images are used.
    """
    # (last epoch of phase, original share, augmented share)
    phases = (
        (40, 0.4, 0.6),
        (80, 0.6, 0.4),
        (90, 0.8, 0.2),
    )
    for last_epoch, org_share, aug_share in phases:
        if epoch <= last_epoch:
            return (org_share, aug_share)
    return (1.0, 0.0)   # 100% original, 0% augmented

# Dataset locations -- adjust base_path for your environment.
base_path = "/kaggle/input/aug-org-traffic-sign-detection-yolo-format"
original_dir, augmented_dir = (
    "/".join((base_path, "images", split)) for split in ("train_org", "train_aug")
)
data_yaml_path = "/".join((base_path, "data.yaml"))

# Constructing the dataset already samples and writes the epoch-0 mix
# (train list + YAML) that model.train() starts from.
train_dataset = DynamicMixedDataset(
    total_target=9000,
    ratio_schedule_fn=ratio_schedule,
    augment=True,
    img_size=1024,
    data_yaml_path=data_yaml_path,
    augmented_img_dir=augmented_dir,
    original_img_dir=original_dir,
)

# define a callback to refresh mix each epoch and rebuild the loader
def on_train_epoch_start(trainer):
    """Resample the original/augmented mix and swap in a fresh train loader.

    Registered below as the ``on_train_epoch_start`` callback, so it runs once per
    epoch with the live trainer.

    Args:
        trainer: the Ultralytics trainer. This callback reads ``epoch``, ``epochs``,
            ``batch_size``, ``data`` and ``args``, and replaces ``args.data`` and
            ``train_loader``.
    """
    # 1) advance schedule and regenerate temp train.txt + YAML
    train_dataset.set_epoch(trainer.epoch)

    # 2) point trainer to new YAML path so internal checks use the fresh config
    trainer.args.data = train_dataset.current_epoch_yaml

    # 3) rebuild the train dataloader for this epoch
    ds = build_yolo_dataset(
        trainer.args,
        img_path=train_dataset.current_epoch_config["train"],  # temp .txt
        batch=trainer.batch_size,
        data=trainer.data,
        mode="train",
    )
    trainer.train_loader = build_dataloader(
        ds,                      # dataset
        trainer.batch_size,      # batch (positional)
        trainer.args.workers,    # workers
        True,                    # shuffle
    )

    # 4) Keep mosaic off during the final `close_mosaic` epochs.
    #    Ultralytics' BaseTrainer disables mosaic on its current loader only once,
    #    when epoch == epochs - close_mosaic, after this callback has run (confirm
    #    against your installed version). Because step 3 builds a brand-new,
    #    mosaic-enabled dataset every epoch, every later epoch would otherwise get
    #    mosaic back. The exact switch-over epoch is still handled by the trainer.
    close_mosaic = getattr(trainer.args, "close_mosaic", 0) or 0
    if close_mosaic and trainer.epoch > trainer.epochs - close_mosaic:
        close_fn = getattr(trainer, "_close_dataloader_mosaic", None)  # private API
        if callable(close_fn):
            close_fn()

# Load pretrained YOLOv8m weights and hook the per-epoch mix refresh into training.
model = YOLO("yolov8m.pt")
model.add_callback("on_train_epoch_start", on_train_epoch_start)

# Training configuration. 'data' starts at the epoch-0 YAML written by
# train_dataset; the registered callback repoints trainer.args.data each epoch.
_train_kwargs = dict(
    data=train_dataset.current_epoch_yaml,
    epochs=100,
    imgsz=1024,
    batch=8,
    device=0,
    workers=2,
    cos_lr=True,
    patience=30,
    cache="ram",
    save_period=5,
    verbose=True,
    name="detect/train",
    exist_ok=True,
)
results = model.train(**_train_kwargs)


In [None]:
import inspect

# Dump the source of build_yolo_dataset (imported in the training cell) so its
# signature can be compared with how the epoch callback calls it.
_build_src = inspect.getsource(build_yolo_dataset)
print(_build_src)
