<a href="https://colab.research.google.com/github/Zenbagi/LOMMO/blob/main/Train_with_rf_detr_on_detection_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training RF-DETR Object Detection on a Custom Dataset

---
[![hf space](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/SkalskiP/RF-DETR)
[![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-rf-detr-on-detection-dataset.ipynb)
[![roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://blog.roboflow.com/rf-detr)
[![code](https://badges.aleen42.com/src/github.svg)](https://github.com/roboflow/rf-detr)


## Introduction

### Choose the right `batch_size`

Different GPUs have different amounts of VRAM (video memory), which limits how many samples they can process at once during training. To train reliably on any machine, adjust two settings: `batch_size`, the number of samples processed per forward/backward pass, and `grad_accum_steps`, the number of passes whose gradients are accumulated before each weight update. Keep their product equal to 16, our recommended total batch size. For example, on powerful GPUs like the A100, set `batch_size=16` and `grad_accum_steps=1`. On smaller GPUs like the T4, use `batch_size=4` and `grad_accum_steps=4`. Gradient accumulation lets the model simulate training with a larger batch size by summing gradients over several small batches before adjusting the weights.
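
As a quick check of the arithmetic above, the sketch below derives `grad_accum_steps` from a chosen `batch_size` so that the product stays at 16. The helper name `grad_accum_for` is hypothetical and not part of RF-DETR.

```python
TARGET_TOTAL_BATCH = 16  # recommended total batch size from the text above


def grad_accum_for(batch_size: int, target: int = TARGET_TOTAL_BATCH) -> int:
    """Return grad_accum_steps such that batch_size * grad_accum_steps == target."""
    if batch_size <= 0 or target % batch_size != 0:
        raise ValueError(f"batch_size must be a positive divisor of {target}")
    return target // batch_size


# A100-class GPU: the whole batch fits in memory at once.
print(grad_accum_for(16))  # 1
# T4-class GPU: four samples per pass, four passes per weight update.
print(grad_accum_for(4))   # 4
```

If training still runs out of memory, halve `batch_size` and let the helper double `grad_accum_steps`.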

### Train with multiple GPUs

You can fine-tune RF-DETR on multiple GPUs using PyTorch’s Distributed Data Parallel (DDP). Create a `main.py` script that initializes your model and calls `.train()` as usual, then run it from a terminal:

```bash
python -m torch.distributed.launch \
    --nproc_per_node=8 \
    --use_env \
    main.py
```

Replace `8` in the `--nproc_per_node` argument with the number of GPUs you want to use. This approach creates one training process per GPU and splits the workload automatically. Note that your effective batch size is multiplied by the number of GPUs, so you may need to reduce `batch_size` or `grad_accum_steps` to keep the same overall batch size.
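
The multi-GPU adjustment can be sketched as follows. It assumes every process uses the same `batch_size` and `grad_accum_steps`, so the effective batch size is their product times the number of GPUs. The function name is illustrative only.

```python
def effective_batch_size(batch_size: int, grad_accum_steps: int, num_gpus: int = 1) -> int:
    """Samples contributing to one weight update across all DDP processes."""
    return batch_size * grad_accum_steps * num_gpus


# Single T4: 4 * 4 * 1
print(effective_batch_size(4, 4))              # 16
# Eight GPUs with the single-GPU settings inflate the total ...
print(effective_batch_size(4, 4, num_gpus=8))  # 128
# ... so shrink the per-GPU settings to get back to the recommended total.
print(effective_batch_size(2, 1, num_gpus=8))  # 16
```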

## Environment setup

### Configure API keys
You need to provide your W&B API key. Follow these steps:
*   In Colab, go to the left pane and click on Secrets (🔑).
*   Store your Weights & Biases API key under the name `W&B_API_KEY`.



<img src="http://wandb.me/logo-im-png" width="400" alt="Weights & Biases" />
<!--- @wandbcode{intro-colab} -->

Use [W&B](https://wandb.ai/site?utm_source=intro_colab&utm_medium=code&utm_campaign=intro) logging for machine learning experiment tracking and model checkpointing. See the full [W&B documentation](https://docs.wandb.ai/).

In [None]:
import os
from google.colab import userdata

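# wandb reads the key from the WANDB_API_KEY environment variable;
# the Colab secret itself is stored under the name "W&B_API_KEY".
os.environ["WANDB_API_KEY"] = userdata.get("W&B_API_KEY")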
#os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
#os.environ["ROBOFLOW_API_KEY"] = userdata.get("ROBOFLOW_API_KEY")

In [None]:
# @title Mount Google Drive

from google.colab import drive

drive.mount('/content/drive',force_remount=True)


# Main Google Drive paths
base_path = os.path.join('/content/drive', 'MyDrive', 'Workstation')
ds_path = os.path.join(base_path, 'Datasets')
prime_path = os.path.join(base_path, 'Training_results')

local_base = "/content/dataset"


os.makedirs(base_path, exist_ok=True)
os.makedirs(ds_path, exist_ok=True)
os.makedirs(prime_path, exist_ok=True)


Mounted at /content/drive




```
Workstation/
  ├──Inference_results/
  ├──Datasets/
  │    └── dataset_name/
  │         ├── all_images/
  │         └── _annotations.coco.json
  │   
  └──Training_results/
        └──project/
            ├── project_info.json ← (name,descr,categs used,etc)
            ├── training_notebook.ipynb
            ├── dataset/ <-- snapshots used
            │       ├── split_2025-05-13_14-20-31.json <-- Splits used
            │       ...
            │
            └── run_name/
                  ├── logs/
                  ├── run_info.json (hyperp,configs, used_split,etc)
                  └── model_checkpoints/
                      ├── last.pth
                      ├── epoch_10.pth
                      ├── epoch_20.pth
                      ...

```
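
The `split_<timestamp>.json` files in the tree are written by the notebook's `create_snapshot_split` helper, defined later. The following is a rough, self-contained sketch of the same index arithmetic. The file names are made up, and a fixed seed is added only so the example is reproducible.

```python
import random

# Hypothetical image names standing in for the contents of a dataset folder.
image_files = sorted(f"img_{i:03d}.jpg" for i in range(25))
split_ratios = (0.8, 0.1, 0.1)

rng = random.Random(0)  # seeded for this example; the notebook shuffles unseeded
rng.shuffle(image_files)

total = len(image_files)
train_end = int(total * split_ratios[0])              # 20
valid_end = train_end + int(total * split_ratios[1])  # 20 + 2 = 22

snapshot = {
    "description": "Example split",
    "split_percentages": dict(zip(("train", "valid", "test"), split_ratios)),
    "splits": {
        "train": image_files[:train_end],
        "valid": image_files[train_end:valid_end],
        "test": image_files[valid_end:],  # receives the rounding remainder
    },
}

counts = {name: len(files) for name, files in snapshot["splits"].items()}
print(counts)  # {'train': 20, 'valid': 2, 'test': 3}
```

Because the train and valid sizes are truncated with `int()`, any leftover images end up in the test split.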



### Check GPU availability

Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to check. If it fails, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `T4 GPU`, and then click `Save`.

In [None]:
# @title Details of the GPU in use
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


## Internal utility functions

In [None]:
# @title Acquire GPU name
import subprocess

def get_gpu_info():
    try:
        # Get the GPU information
        gpu_info = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=gpu_name,pstate,memory.total", "--format=csv,noheader"],
            encoding="utf-8"
        ).strip()

        # nvidia-smi prints one line per GPU; split the first GPU's line into components
        gpu_name, pstate, memory_total = gpu_info.splitlines()[0].split(", ")

        # Create a dictionary with the GPU information
        return {
            "gpu_name": gpu_name,
            "pstate": pstate,
            "memory_total": memory_total
        }
    except Exception as e:
        return {
            "gpu_name": "Unavailable",
            "pstate": "N/A",
            "memory_total": "N/A"
        }
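
`nvidia-smi` prints one CSV line per GPU, so a multi-GPU host returns several lines. Here is a small sketch of parsing such output into one dictionary per GPU. The sample string is illustrative rather than captured from a real run, and `parse_gpu_lines` is a hypothetical helper.

```python
# Illustrative output of
#   nvidia-smi --query-gpu=gpu_name,pstate,memory.total --format=csv,noheader
# on a hypothetical two-GPU machine.
sample_output = "Tesla T4, P8, 15360 MiB\nTesla T4, P8, 15360 MiB"


def parse_gpu_lines(raw: str) -> list:
    """Return one dict per GPU from comma-separated nvidia-smi output."""
    gpus = []
    for line in raw.strip().splitlines():
        gpu_name, pstate, memory_total = (field.strip() for field in line.split(","))
        gpus.append({"gpu_name": gpu_name, "pstate": pstate, "memory_total": memory_total})
    return gpus


gpus = parse_gpu_lines(sample_output)
print(len(gpus), gpus[0])
```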

In [None]:
# @title Dataset importing system
"""
- A dedicated, robust GDrive-import mechanism for
  dataset split snapshots pulled by project.

- A more general-purpose import from various sources
  like Roboflow, ZIPs, or other local formats.
"""
import os
import shutil
import json
from pathlib import Path
from datetime import datetime
from glob import glob
import random
import subprocess  # used by fast_copy_dataset for the rsync fast path


# Merge multiple Roboflow annotations into one
def merge_coco_annotations(annotation_paths, output_path):
    merged = {
        "info": {"description": "Merged annotations"},
        "licenses": [],
        "categories": [],
        "images": [],
        "annotations": []
    }

    image_id_offset = 0
    annotation_id_offset = 0
    image_id_map = {}
    license_set = {}
    categories_set = {}

    for path in annotation_paths:
        try:
            with open(path) as f:
                data = json.load(f)
        except Exception as e:
            print(f"Error reading {path}: {e}")
            continue

        if "info" in data:
            merged["info"].update(data["info"])

        for lic in data.get("licenses", []):
            if lic["id"] not in license_set:
                merged["licenses"].append(lic)
                license_set[lic["id"]] = lic

        if not merged["categories"]:
            merged["categories"] = data["categories"]
            categories_set = {cat["id"]: cat for cat in data["categories"]}
        else:
            for cat in data["categories"]:
                if cat["id"] not in categories_set:
                    raise ValueError("Different category IDs found across annotation files.")

        for img in data["images"]:
            old_id = img["id"]
            img["id"] = image_id_offset
            image_id_map[(path, old_id)] = image_id_offset
            merged["images"].append(img)
            image_id_offset += 1

        for ann in data["annotations"]:
            ann["id"] = annotation_id_offset
            ann["image_id"] = image_id_map[(path, ann["image_id"])]
            merged["annotations"].append(ann)
            annotation_id_offset += 1


    try:
        with open(output_path, "w") as f:
            json.dump(merged, f, indent=2)
        print(f"Merged annotations saved to {output_path}")
    except Exception as e:
        print(f"Error writing to {output_path}: {e}")


def create_snapshot_split(image_dir, split_ratios=(0.8, 0.1, 0.1), description="Initial split", output_dir="./dataset"):
    image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    image_files.sort()
    random.shuffle(image_files)

    total = len(image_files)
    train_end = int(total * split_ratios[0])
    valid_end = train_end + int(total * split_ratios[1])

    splits = {
        "train": image_files[:train_end],
        "valid": image_files[train_end:valid_end],
        "test": image_files[valid_end:]
    }

    snapshot = {
        "description": description,
        "timestamp": datetime.now().isoformat(),
        "split_percentages": {
            "train": split_ratios[0],
            "valid": split_ratios[1],
            "test": split_ratios[2]
        },
        "splits": splits
    }

    os.makedirs(output_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    out_file = os.path.join(output_dir, f"split_{ts}.json")

    with open(out_file, "w") as f:
        json.dump(snapshot, f, indent=2)

    print(f"Snapshot saved to {out_file}")
    return out_file



def import_roboflow_flattened(rf_ds_path, source, output_path, snapshot_output_dir=None):

    # Find all annotation files and split image folders (train/valid/test)
    annotation_paths = glob(os.path.join(rf_ds_path, "**", "*_annotations.coco.json"), recursive=True)
    split_dirs = {
        name: os.path.join(rf_ds_path, name)
        for name in ["train", "valid", "test"]
        if os.path.isdir(os.path.join(rf_ds_path, name))
    }

    # Save snapshot of original Roboflow split folders
    if snapshot_output_dir:
        combined_images = {}
        for split_name, folder in split_dirs.items():
            image_files = [
                f for f in os.listdir(folder)
                if f.lower().endswith((".jpg", ".jpeg", ".png"))
            ]
            combined_images[split_name] = image_files

        snapshot = {
            "description": "Snapshot from Roboflow original splits",
            "timestamp": datetime.now().isoformat(),
            "splits": combined_images,
            "source": source,
        }

        os.makedirs(snapshot_output_dir, exist_ok=True)
        # The Roboflow snapshot uses a fixed name so it can be looked up later
        snapshot_path = os.path.join(snapshot_output_dir, "rf_split.json")
        with open(snapshot_path, "w") as f:
            json.dump(snapshot, f, indent=2)
        print(f"Saved snapshot of original Roboflow splits to: {snapshot_path}")

    # Creates dataset folder with given output_path
    flat_image_dir = os.path.join(output_path)
    os.makedirs(flat_image_dir, exist_ok=True)
    print("Merging Roboflow splits:\n...")
    for root, _, files in os.walk(rf_ds_path):
        for f in files:
            if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                src = os.path.join(root, f)
                dst = os.path.join(flat_image_dir, f)
                if not os.path.exists(dst):
                    shutil.copy(src, dst)

    print(f"Flattened all images to: {flat_image_dir}")

    # Merge all annotations into one
    merged_ann_path = os.path.join(output_path, "_annotations.coco.json")
    merge_coco_annotations(annotation_paths, merged_ann_path)

    print(f"Roboflow import complete: {output_path}")


def filter_annotations(src, dst, file_list=None, max_files=None, category_names=None):
    """
    Filters COCO annotations based on image file list and category names.
    """
    annotation_file = "_annotations.coco.json"
    annotation_path = os.path.join(src, annotation_file)

    if not os.path.exists(annotation_path):
        print(f"Annotation file not found at: {annotation_path}")
        return

    with open(annotation_path, "r") as f:
        data = json.load(f)

    # Build selected file set
    if file_list:
        selected_file_set = set(os.path.basename(f) for f in file_list[:max_files] if not f.endswith(".json"))
    else:
        selected_file_set = set(img["file_name"] for img in data["images"])

    # Map category names to IDs
    category_name_to_id = {cat["name"]: cat["id"] for cat in data["categories"]}
    if category_names:
        selected_cat_ids = {category_name_to_id[name] for name in category_names if name in category_name_to_id}
    else:
        selected_cat_ids = None

    # Filter images
    filtered_images = [img for img in data["images"] if img["file_name"] in selected_file_set]
    image_ids = {img["id"] for img in filtered_images}

    # Filter annotations
    filtered_annotations = [
        ann for ann in data["annotations"]
        if ann["image_id"] in image_ids and (selected_cat_ids is None or ann["category_id"] in selected_cat_ids)
    ]

    # Keep only used categories
    if selected_cat_ids is not None:
        used_cat_ids = {ann["category_id"] for ann in filtered_annotations}
        filtered_categories = [cat for cat in data["categories"] if cat["id"] in used_cat_ids]
    else:
        filtered_categories = data["categories"]

    # Write new annotations
    new_data = {
        "info": data.get("info", {}),
        "licenses": data.get("licenses", []),
        "categories": filtered_categories,
        "images": filtered_images,
        "annotations": filtered_annotations,
    }

    os.makedirs(dst, exist_ok=True)
    dst_ann_path = os.path.join(dst, annotation_file)
    with open(dst_ann_path, "w") as f:
        json.dump(new_data, f)

    print(f"Filtered annotations saved to {dst_ann_path}")


# Copy to local memory from singular split/dataset
def fast_copy_dataset(src, dst, max_files=None, file_list=None):
    """
    Copy specific image files (if file_list is provided), or up to max_files random images, from src to dst.
    The _annotations.coco.json file is copied unchanged; use filter_annotations to restrict it to the copied images.
    """
    os.makedirs(dst, exist_ok=True)

    all_files = sorted(os.listdir(src))
    annotation_file = "_annotations.coco.json"
    annotation_path = os.path.join(src, annotation_file)

    # Select image files
    if file_list is not None:
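        # Copy the list so appending the annotation file below does not mutate the caller's list
        selected_files = list(file_list[:max_files]) if max_files else list(file_list)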
    elif max_files:
        image_files = [f for f in all_files if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        random.shuffle(image_files)
        selected_files = image_files[:max_files]
    else:
        # Use rsync if available
        try:
            subprocess.run(['rsync', '-a', '--size-only', src + '/', dst], check=True)
            print(f"Copied dataset to: {dst}")
            return
        except (subprocess.CalledProcessError, FileNotFoundError):
            print("rsync failed. Falling back to manual copy.")
            selected_files = [f for f in all_files if f != annotation_file]

    # Add annotation file if it exists
    if os.path.exists(annotation_path):
        selected_files.append(annotation_file)

    # Copy selected files
    for f in selected_files:
        src_file = os.path.join(src, f)
        dst_file = os.path.join(dst, f)
        if os.path.exists(src_file):
            shutil.copy(src_file, dst_file)
        else:
            print(f"Warning: File not found: {src_file}")

    print(f"Copied {len(selected_files)} files to: {dst}")



def import_gdrive_project_dataset(project_name, dataset_name, split="train", snapshot_name=None,used_category_names=None, max_files=None):
    """
    Import dataset based on snapshot-split from GDrive.
    """
    assert split in {"train", "valid", "test"}, "split must be 'train', 'valid', or 'test'"

    base_dir = os.path.join(prime_path, project_name)
    #image_dir = os.path.join(base_dir, "images")
    snapshot_dir = os.path.join(base_dir, "dataset") # the dataset folder with the snapshots created in the project



    src_path = os.path.join(ds_path, dataset_name)

    if not os.path.exists(src_path):
        raise FileNotFoundError(f"Dataset directory not found: {src_path}")

    # Determine the snapshot to use
    if snapshot_name=="rf_split.json":
        snapshot_path = os.path.join(ds_path, dataset_name, snapshot_name)
    elif snapshot_name:
        snapshot_path = os.path.join(snapshot_dir, snapshot_name)
    else:
        # Use latest snapshot by timestamp
        all_snapshots = sorted([f for f in os.listdir(snapshot_dir) if f.endswith(".json")])
        if not all_snapshots:
            raise FileNotFoundError(f"No snapshot splits found in: {snapshot_dir}")
        snapshot_path = os.path.join(snapshot_dir, all_snapshots[-1])

    with open(snapshot_path) as f:
        snapshot = json.load(f)

    split_images = snapshot["splits"].get(split)
    if not split_images:
        raise ValueError(f"Split '{split}' not found in snapshot: {snapshot_path}")

    dst_path = os.path.join(local_base, split)
    os.makedirs(dst_path, exist_ok=True)

    print(f" Importing <{split}> data from GDrive: {src_path} → {dst_path} \n ...")

    fast_copy_dataset(src_path, dst_path, max_files=max_files, file_list=split_images)

    print(f" Filtering annotations of <{split}> for max image number = {max_files} \n",
    f"categories selected: {used_category_names} \n...")

    filter_annotations(src_path, dst_path, file_list=split_images, max_files=max_files, category_names=used_category_names)

    return dst_path



def import_dataset(source="local", path_or_url=None, project_name=None, dataset_name=None, snapshot_name=None, split="train", used_category_names=None, max_files=None):
    """
    Generic import handler supporting multiple sources.
    source: "local", "local_zip", "roboflow", "gdrive"
    """
    dest_path = os.path.join(prime_path, project_name, "dataset")
    dataset_main = os.path.join(ds_path, dataset_name)

    if source == "local":
        if not os.path.exists(path_or_url):
            raise FileNotFoundError(f"Local path not found: {path_or_url}")
        shutil.copytree(path_or_url, dataset_main, dirs_exist_ok=True)

    elif source == "roboflow":
        os.makedirs(dest_path, exist_ok=True)
        zip_path = os.path.join(dest_path, "dataset.zip")
        os.system(f'wget "{path_or_url}" -O "{zip_path}"')
        # Extract the downloaded archive before flattening it
        rf_ds_path = "/content/dataset"
        os.system(f'unzip -q -o "{zip_path}" -d "{rf_ds_path}"')
        import_roboflow_flattened(rf_ds_path, source=path_or_url, output_path=dataset_main, snapshot_output_dir=dataset_main)


    elif source == "local_zip":
        os.system(f'unzip -q {path_or_url} -d /content/dataset')
        shutil.copytree("/content/dataset", dataset_main, dirs_exist_ok=True)

    elif source == "gdrive":
        if project_name is None:
            raise ValueError("project_name must be provided when using source='gdrive'")
        elif not dataset_name:
            raise ValueError("dataset_name must be provided when using source='gdrive'")
        return import_gdrive_project_dataset(project_name, dataset_name, split=split, snapshot_name=snapshot_name,
                                             used_category_names=used_category_names, max_files=max_files)

    else:
        raise ValueError(f"Unsupported dataset source: {source}")
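
The ID remapping in `merge_coco_annotations` is the step most likely to go wrong, because every source file numbers its images independently. Below is a minimal in-memory sketch of the same idea, using two tiny made-up COCO dictionaries instead of files. Unlike the notebook function, it copies entries rather than modifying them in place.

```python
# Two made-up COCO fragments whose image IDs collide (both start at 0).
coco_a = {"images": [{"id": 0, "file_name": "a.jpg"}],
          "annotations": [{"id": 0, "image_id": 0, "category_id": 1}]}
coco_b = {"images": [{"id": 0, "file_name": "b.jpg"}],
          "annotations": [{"id": 0, "image_id": 0, "category_id": 1}]}

merged = {"images": [], "annotations": []}
next_image_id = 0
next_ann_id = 0

for data in (coco_a, coco_b):
    id_map = {}  # old image ID -> new image ID, scoped to this source
    for img in data["images"]:
        id_map[img["id"]] = next_image_id
        merged["images"].append({**img, "id": next_image_id})
        next_image_id += 1
    for ann in data["annotations"]:
        merged["annotations"].append(
            {**ann, "id": next_ann_id, "image_id": id_map[ann["image_id"]]}
        )
        next_ann_id += 1

print([(img["id"], img["file_name"]) for img in merged["images"]])
# [(0, 'a.jpg'), (1, 'b.jpg')]
print([(ann["id"], ann["image_id"]) for ann in merged["annotations"]])
# [(0, 0), (1, 1)]
```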


In [None]:
# @title Eliminate redundant folders in case of unsuccessful runs
from glob import glob

def get_latest_experiment(base_path, skip=0):
    all_runs = [f for f in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, f))]
    all_runs.sort(reverse=True)

    if len(all_runs) > skip:
        return all_runs[skip]
    return None

def delete_run_if_empty_or_unused(run_path):
    checkpoints_path = os.path.join(run_path, "model_checkpoints")
    logs_path = os.path.join(run_path, "logs")

    should_delete = False

    # Case 1: Whole run folder is empty
    if not os.listdir(run_path):
        should_delete = True

    # Case 2: Checkpoints folder exists but is empty
    elif os.path.exists(checkpoints_path) and not os.listdir(checkpoints_path):
        should_delete = True

    if should_delete:
        shutil.rmtree(run_path)
        print(f"Deleted unused run folder: {run_path}")


In [None]:
# @title Inquiring Training Directory for existing models
import re
from datetime import datetime


# Step 1: list all projects (like Football_players)
all_projects = [proj for proj in os.listdir(prime_path) if os.path.isdir(os.path.join(prime_path, proj))]

# Step 2: build nested structure: project -> runs -> checkpoints
project_runs = {}

for project in all_projects:
    project_dir = os.path.join(prime_path, project)
    all_runs = [run for run in os.listdir(project_dir) if os.path.isdir(os.path.join(project_dir, run))
                and run.startswith(f"{project}_")]

    runs_dict = {}
    for run in all_runs:
        run_path = os.path.join(project_dir, run)
        ckpt_dir = os.path.join(run_path, "model_checkpoints")
        if os.path.exists(ckpt_dir):
            checkpoints = [f for f in os.listdir(ckpt_dir) if f.endswith(".pth")]
            if checkpoints:
                # Save runs inside project
                runs_dict[run] = checkpoints
            else :
                runs_dict[run] = []

    if runs_dict:
        project_runs[project] = runs_dict
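
After this cell runs, `project_runs` maps each project to its runs and each run to its checkpoint files. The following self-contained sketch performs the same scan on a temporary directory. The project and run names are invented for the example.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Build a throwaway tree: <root>/<project>/<project>_<run>/model_checkpoints/*.pth
    ckpt_dir = os.path.join(root, "Demo", "Demo_2025-05-13", "model_checkpoints")
    os.makedirs(ckpt_dir)
    for name in ("last.pth", "epoch_10.pth"):
        open(os.path.join(ckpt_dir, name), "w").close()
    # A folder without the "<project>_" prefix is ignored by the scan.
    os.makedirs(os.path.join(root, "Demo", "dataset"))

    project_runs = {}
    for project in sorted(os.listdir(root)):
        project_dir = os.path.join(root, project)
        runs = {}
        for run in sorted(os.listdir(project_dir)):
            run_ckpts = os.path.join(project_dir, run, "model_checkpoints")
            if run.startswith(f"{project}_") and os.path.isdir(run_ckpts):
                runs[run] = sorted(f for f in os.listdir(run_ckpts) if f.endswith(".pth"))
        if runs:
            project_runs[project] = runs

print(project_runs)
# {'Demo': {'Demo_2025-05-13': ['epoch_10.pth', 'last.pth']}}
```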


In [None]:
# @title Get-config-stats-data functions

#extract category and image counts
def get_filtered_dataset_stats(annotation_path, selected_category_names):
    """
    Returns:
    - number of selected categories
    - number of images that have at least one annotation in those categories
    """
    if not os.path.exists(annotation_path):
        return {"num_categories": 0, "num_images": 0}

    with open(annotation_path, 'r') as f:
        data = json.load(f)

    # Map category name to ID
    category_name_to_id = {
        cat["name"]: cat["id"]
        for cat in data.get("categories", [])
        if cat["name"] in selected_category_names
    }

    selected_cat_ids = set(category_name_to_id.values())

    # Map image_id to categories present
    image_to_cats = {}
    for ann in data.get("annotations", []):
        if ann["category_id"] in selected_cat_ids:
            image_to_cats.setdefault(ann["image_id"], set()).add(ann["category_id"])

    num_filtered_images = len(image_to_cats)

    return {
        "num_categories": len(selected_cat_ids),
        "num_images": num_filtered_images
    }


#Summary of selected run
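def get_run_summary(run_path):
    # The metrics file is named after the run folder
    run_name = os.path.basename(os.path.normpath(run_path))
    info_file = os.path.join(run_path, "experiment_info.json")
    logs_file = os.path.join(run_path, "logs", f"metrics_{run_name}.json")
    summary = {}
    logs_info = {}

    # Attempt to load logs file
    try:
        with open(logs_file, "r") as f:
            logs_info = json.load(f)
    except (OSError, json.JSONDecodeError):
        logs_info = {}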
    info ={}
    try:
        with open(info_file, "r") as f:
            info = json.load(f)

        # Pull structured values
        summary["project"] = info.get("project")
        summary["created_at"] = info.get("created_at")
        summary["dataset_name"] = info.get("dataset_name")
        summary["snapshot_splits"] = info.get("snapshot_splits")
        summary["split_description"] = info.get("split_description")
        summary["run_description"] = info.get("run_description")
        summary["resolution"] = info.get("hyperparameters", {}).get("resolution")
        summary["epochs"] = info.get("hyperparameters", {}).get("epochs")
        summary["GPU_Name"] = info.get("gpu", {}).get("GPU_Name")
        summary["GPU_Perf"] = info.get("gpu", {}).get("Perf")
        summary["batch_size"] = info.get("hyperparameters", {}).get("batch_size")
        summary["grad_accum_steps"] = info.get("hyperparameters", {}).get("grad_accum_steps")
        summary["learning_rate"] = info.get("hyperparameters", {}).get("learning_rate")

        summary["max_num_img"] = info.get("max_num_img")
        summary["num_images"] = info.get("num_images")
        summary["num_categories"] = info.get("num_categories")
        summary["categories"] = info.get("categories")

    except Exception:
        # Missing or malformed info file: leave the structured fields out of the summary
        info = {}

    # Extract mAP
    if isinstance(logs_info, list) and logs_info:
        last_entry = logs_info[-1]
        summary["Last mAP"] = last_entry.get("test_coco_eval_bbox")
    else:
        summary["Last mAP"] = None

    return summary

# Retrieve dataset name of a project
def get_dataset_of_project(project_name):
    config_path = os.path.join(prime_path, project_name, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r") as f:
            data = json.load(f)
            return data.get("dataset_name", None)
    return None
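
To make the counting rule in `get_filtered_dataset_stats` concrete, here is a small in-memory sketch on a made-up COCO dictionary. An image counts only if it has at least one annotation in a selected category.

```python
# Made-up COCO data: three images, two categories.
data = {
    "categories": [{"id": 1, "name": "player"}, {"id": 2, "name": "ball"}],
    "annotations": [
        {"image_id": 10, "category_id": 1},
        {"image_id": 10, "category_id": 2},
        {"image_id": 11, "category_id": 2},
        {"image_id": 12, "category_id": 1},
    ],
}
selected_category_names = ["player"]

selected_cat_ids = {
    cat["id"] for cat in data["categories"] if cat["name"] in selected_category_names
}
images_with_selected = {
    ann["image_id"] for ann in data["annotations"] if ann["category_id"] in selected_cat_ids
}

stats = {"num_categories": len(selected_cat_ids), "num_images": len(images_with_selected)}
print(stats)  # {'num_categories': 1, 'num_images': 2}
```

Image 11 is excluded because its only annotation belongs to the unselected `ball` category.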

# Configurations of the experiment

## Provide dataset for training
Import a dataset from one of several sources into the experiment's dataset folder.

*   `source`: one of `"local"`, `"local_zip"`, `"roboflow"`, `"gdrive"`.
*   `path_or_url`: a string, either a file path (local, gdrive) or a URL (Roboflow Universe).

### Configure your model
Please run the configuration cell and click 'Save Configuration'.

In [None]:
# @title Widgets constructs
import ipywidgets as widgets
from IPython.display import display
import os
from IPython.display import Javascript
import random


#----------Create new model UI--------------
###########################################

# ---------- Project Name ----------
project_widget = widgets.Text(
    value="",
    description="Project Name:",
    style={'description_width': 'initial'}
)

model_name_widget = widgets.Text(
    value="RF-DETR",
    description="Model Name:",
    style={'description_width': 'initial'}
)

experiment_description_widget = widgets.Textarea(
    value="", #The Mahjong dataset aims to facilitate the development of object
    # detection models that can accurately identify and distinguish different mahjong tiles.
    description="Description of the project:",
    layout=widgets.Layout(width='100%', height='80px'),
    style={'description_width': 'initial'}
)

# ---------- Dataset Name ----------
dataset_widget = widgets.Text(
    value="",
    description="Dataset Name:",
    style={'description_width': 'initial'}
)

#------------ Dropdown for dataset source-------------------
source_widget = widgets.Dropdown(
    options=["roboflow", "local" ,"local_zip", "gdrive"],
    value="gdrive",
    description="Dataset Source:",
    style={'description_width': 'initial'}
)

existing_dataset_dropdown = widgets.Dropdown(
    options=["Select dataset"],
    description="Choose Dataset from GDrive:"
)

path_URL_widget = widgets.Text(
    value="", # https://universe.roboflow.com/ds/MTIfKJbiPR?key=g6bXSfng8g
    description="Path/URL to dataset source ('None' for local/gdrive source):",
    layout=widgets.Layout(width='60%', display='none'),
    style={'description_width': 'initial'}
)

# ---------- Dataset Name Dropdown (GDrive only) ----------
def get_gdrive_datasets():
    return [d for d in os.listdir(ds_path)
    if os.path.isdir(os.path.join(ds_path, d))]

gdrive_dataset_dropdown = widgets.Dropdown(
    options=["Select a dataset"] + ["Create new dataset..."] + get_gdrive_datasets(),
    description="Select or Create Dataset:",
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)


# ---------- Category Selector Placeholder ----------
category_checkboxes = []
category_box = widgets.VBox([])
scrollable_categories = widgets.Box([category_box], layout=widgets.Layout(
    overflow='auto',
    border='1px solid gray',
    width='100%',
    height='150px',
    flex_flow='column',
    display='flex'
))

# Check/Uncheck All
check_all_btn = widgets.Button(description="Check All", layout=widgets.Layout(width='150px'))
uncheck_all_btn = widgets.Button(description="Uncheck All", layout=widgets.Layout(width='150px'))


#------------  Dataset Splitting------------------
max_num_img_widget = widgets.IntText(
    value=None,
    description="Max Images",
    placeholder='Leave blank for all',
)

split_mode_radio = widgets.RadioButtons(
    options=["Use existing snapshot", "Create new split"],
    value="Use existing snapshot",
    description="Split Mode:",
    style={'description_width': 'initial'}
)

# Dropdown for existing snapshot
existing_split_dropdown = widgets.Dropdown(
    options=["Select a split"],
    description="Snapshot Split:",
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

split_description_widget = widgets.Textarea(
    description="Description of the split:",
    layout=widgets.Layout(width='60%', height='50px'),
    style={'description_width': 'initial'}
)
split_description_label = widgets.HTML()
# Percentage fields
train_split = widgets.BoundedIntText(
    value=80, min=0, max=100, description="Train %",
    style={'description_width': 'initial'}
)

val_split = widgets.BoundedIntText(
    value=10, min=0, max=100, description="Val %",
    style={'description_width': 'initial'}
)
test_split = widgets.BoundedIntText(
    value=10, min=0, max=100, description="Test %",
    style={'description_width': 'initial'}
)


# ---------- Split Configuration ----------

split_warning = widgets.HTML("")

# Group percent inputs in one box
split_percent_inputs = widgets.HBox([train_split, val_split, test_split])
split_percent_inputs.layout.display = 'none'  # Hide initially

# ------------------New Run-----------------------
#################################################

run_description_widget = widgets.Textarea(
    value="",
    description="Description of the run:",
    layout=widgets.Layout(width='100%', height='80px'),
    style={'description_width': 'initial'}
)

# ----------Retake Ongoing Training--------
###########################################
# Choose a model
project_dropdown = widgets.Dropdown(
    options=["Select a project"] + list(project_runs.keys()),
    description="Choose Project:",
    style={'description_width': 'initial'}
)

# Date Pickers (Start and End)
date_start_picker = widgets.DatePicker(
    description='Start Date',
    style={'description_width': 'initial'}
)

date_end_picker = widgets.DatePicker(
    description='End Date',
    style={'description_width': 'initial'}
)

# Run Dropdown
run_dropdown = widgets.Dropdown(
    options=[],
    description="Choose Run:",
    style={'description_width': 'initial'}
)
# Checkpoint Dropdown
ckpt_dropdown = widgets.Dropdown(
    options=[],
    description="Choose Checkpoint:",
    style={'description_width': 'initial'}
)

run_summary_label = widgets.HTML("")

# -----------Hyperparameters-----------------
epochs_widget = widgets.IntText(
    value=50,
    description="Epochs:",
    style={'description_width': 'initial'}
)

batch_size_widget = widgets.IntText(
    value=4,
    description="Batch Size:",
    style={'description_width': 'initial'}
)

grad_accum_widget = widgets.IntText(
    value=4,
    description="Grad Accum Steps:",
    style={'description_width': 'initial'}
)

lr_widget = widgets.FloatText(
    value=1e-4,
    description="Learning Rate:",
    style={'description_width': 'initial'}
)


# Slider for the input image resolution
resolution_widget = widgets.IntSlider(
    value=560,
    min=224,
    max=1120,
    step=56,
    description='Resolution:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='30%')
)


# split UI box
split_ui = widgets.VBox([
    widgets.HTML("<h4>Dataset Split Configuration</h4>"),
    max_num_img_widget,
    split_mode_radio,
    existing_split_dropdown,
    split_description_label,
    split_percent_inputs,
    split_warning,
    split_description_widget
])

# Source UI box
source_ui_Hbox = widgets.HBox([
    source_widget,
    path_URL_widget,
])
source_ui = widgets.VBox([
    dataset_widget,
    source_ui_Hbox,
])

# Hyperparameters UI box
hyper_param_ui = widgets.VBox([
    resolution_widget,
    widgets.HBox([epochs_widget, batch_size_widget]),
    widgets.HBox([grad_accum_widget, lr_widget]),

])

mode_selector = widgets.ToggleButtons(
    options=['New Project', 'New Run', 'Resume Training'],
    description='Training Mode:',
    style={'description_width': 'initial'}
)

# Save Button
save_button = widgets.Button(description="Save Configuration", button_style='success')


# New Project UI box
create_ui = widgets.VBox([
    widgets.HTML("<h3>Create New Model Configuration</h3>"),
    gdrive_dataset_dropdown,
    source_ui,
    widgets.HTML("<b>Select Categories:</b>"),
    scrollable_categories,
    widgets.HBox([check_all_btn, uncheck_all_btn]),
    project_widget,
    model_name_widget,
    experiment_description_widget,
    split_ui,
    hyper_param_ui,
    run_description_widget,
])

# New Run UI box
new_run_ui = widgets.VBox([
    widgets.HTML("<h3>New Run Configuration</h3>"),
    project_dropdown,
    run_description_widget,
    split_ui,
    hyper_param_ui,
])

# Resume Training UI box
retake_ui = widgets.VBox([
    widgets.HTML("<h3>Resume Ongoing Training</h3>"),
    project_dropdown,
    widgets.HBox([date_start_picker, date_end_picker]),
    widgets.VBox([run_dropdown,run_summary_label]),
    ckpt_dropdown,
    epochs_widget
])

 # containers
create_container = widgets.VBox([create_ui])
new_run_container = widgets.VBox([new_run_ui])
retake_container = widgets.VBox([retake_ui])
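
The defaults of `batch_size_widget` and `grad_accum_widget` above (4 and 4) follow the guidance from the introduction: keep `batch_size * grad_accum_steps` equal to 16. As a rough sketch, the matching `grad_accum_steps` for a given `batch_size` could be derived as follows. The helper name `grad_accum_for` and its `target_total` parameter are hypothetical and are not part of the notebook or of RF-DETR.

```python
def grad_accum_for(batch_size: int, target_total: int = 16) -> int:
    """Return grad_accum_steps such that batch_size * steps == target_total.

    Hypothetical helper for illustration; 16 is the total batch size
    recommended in the introduction of this notebook.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if target_total % batch_size != 0:
        raise ValueError(
            f"batch_size={batch_size} does not divide the target total {target_total}"
        )
    return target_total // batch_size


# Print the pairing for a few common batch sizes.
for bs in (16, 8, 4, 2):
    print(f"batch_size={bs:>2} -> grad_accum_steps={grad_accum_for(bs)}")
```

A check like this could run before training starts, so that a batch size that does not divide 16 is flagged early rather than silently changing the effective batch size.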



In [None]:
#@title Training UI
"""
Please run the configuration cell and click 'Save Configuration'.
"""


def update_path_url(change):
    if change['new'] in ['gdrive', 'local', 'local_zip']:
        path_URL_widget.layout.display = 'none'
    else:
        path_URL_widget.layout.display = 'block'
        path_URL_widget.value = 'https://universe.roboflow.com/your-default-url'



def handle_dataset_selection(change):
    if mode_selector.value != "New Project":
        return

    if change['new'] == "Create new dataset...":
        show_dataset_import_ui()
    else:
        hide_dataset_import_ui()


def check_uncheck_all(state):
    for cb in category_checkboxes:
        cb.value = state

check_all_btn.on_click(lambda b: check_uncheck_all(True))
uncheck_all_btn.on_click(lambda b: check_uncheck_all(False))


# ---------- Load Categories on Dataset Change ----------
def load_categories_from_dataset(change):
    dataset_name = change['new']
    annotation_path = os.path.join(ds_path, dataset_name, "_annotations.coco.json")
    if not os.path.exists(annotation_path):
        category_box.children = [widgets.Label("No _annotations.coco.json found.")]
        return

    with open(annotation_path, 'r') as f:
        data = json.load(f)

    categories = data.get("categories", [])
    category_checkboxes.clear()
    checkboxes = []
    # Create New Checkboxes for Each Category
    for cat in categories:
        cb = widgets.Checkbox(value=True, description=cat["name"])
        cb.observe(update_split_mode_based_on_categories, names='value')
        category_checkboxes.append(cb)
        checkboxes.append(cb)

    category_box.children = checkboxes

    update_split_mode_based_on_categories()



def validate_splits(change=None):
    # `change` is optional so the function works both as an observer and as a direct call.
    total = train_split.value + val_split.value + test_split.value
    if total == 0:
        # Checked first: a total of 0 would otherwise be reported as "must be 100%".
        split_warning.value = "<span style='color: red;'> Cannot compute ratios. Total split is 0.</span>"
    elif total != 100:
        split_warning.value = f"<span style='color: red;'> Split total must be 100% (Current: {total}%)</span>"
    else:
        split_warning.value = "<span style='color: green;'></span>"
    return total




# Hide/show fields based on split mode
def toggle_split_mode(change):
    mode = change['new']
    if mode == "Use existing snapshot":
        existing_split_dropdown.layout.display = 'block'
        split_percent_inputs.layout.display = 'none'
        split_description_widget.layout.display = 'none'
    else:
        existing_split_dropdown.layout.display = 'none'
        split_percent_inputs.layout.display = 'block'
        split_description_widget.layout.display = 'block'
        validate_splits(change)


# Populated dynamically with snapshots based on the selected dataset
def update_combined_split_dropdown(change=None):
    project, dataset_name = get_active_project_dataset()

    custom_splits = []
    # List the custom splits of the project (skipped when no project is selected,
    # because os.path.join fails on None)
    if project:
        project_dataset_dir = os.path.join(prime_path, project, "dataset")
        if os.path.exists(project_dataset_dir):
            custom_splits = sorted([
                f for f in os.listdir(project_dataset_dir)
                if f.startswith("split_") and f.endswith(".json")
            ])

    # Roboflow split is added to existing splits
    if dataset_name is not None:
        rf_split_path = os.path.join(ds_path, dataset_name, "rf_split.json")
        if os.path.exists(rf_split_path):
            custom_splits.append("rf_split.json")

    if custom_splits:
        existing_split_dropdown.options = ["Select a split"] + custom_splits
    else:
        existing_split_dropdown.options = []



def update_split_mode_based_on_categories(change=None):
    # Filter checkboxes only
    checkboxes = [cb for cb in category_box.children if isinstance(cb, widgets.Checkbox)]
    if not checkboxes:
        return  # no categories available

    selected_values = [cb.value for cb in checkboxes]
    all_selected = all(selected_values)

    if not all_selected:
        # Only allow creating a new snapshot
        split_mode_radio.options = ["Create new snapshot"]
        split_mode_radio.value = "Create new snapshot"
    else:
        # Enable both options, default to existing
        split_mode_radio.options = ["Use existing snapshot", "Create new snapshot"]
        split_mode_radio.value = "Use existing snapshot"

    update_combined_split_dropdown()


checkboxes = [cb for cb in category_box.children if isinstance(cb, widgets.Checkbox)]

for cb in checkboxes:
    cb.observe(update_split_mode_based_on_categories, names='value')

# Extract datetime from run name string.
def extract_run_datetime(run_name):
    match = re.search(r'(\d{8}_\d{6})', run_name)
    if match:
        return datetime.strptime(match.group(1), "%Y%m%d_%H%M%S")
    return None


# Auto-fill date pickers based on available runs.
def update_date_pickers(project):
    runs = list(project_runs.get(project, {}).keys())
    dates = []

    for run in runs:
        run_dt = extract_run_datetime(run)
        if run_dt:
            dates.append(run_dt)

    if dates:
        min_date = min(dates).date()
        max_date = max(dates).date()
        date_start_picker.value = min_date
        date_end_picker.value = max_date


def update_runs(change=None):
    project = project_dropdown.value
    update_date_pickers(project)  # Auto-update dates

    runs = list(project_runs.get(project, {}).keys())

    # Date filtering
    if date_start_picker.value and date_end_picker.value:
        start_dt = datetime.combine(date_start_picker.value, datetime.min.time())
        end_dt = datetime.combine(date_end_picker.value, datetime.max.time())

        # Keep only the runs whose embedded timestamp falls inside the selected range
        filtered_runs = []
        for run in runs:
            run_dt = extract_run_datetime(run)
            if run_dt and start_dt <= run_dt <= end_dt:
                filtered_runs.append(run)

        run_dropdown.options = filtered_runs
        if filtered_runs:
            run_dropdown.value = filtered_runs[-1]
        else:
            run_dropdown.options = []
    else:
        run_dropdown.options = runs
        if runs:
            run_dropdown.value = runs[0]



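def update_run_summary(change):
    run = change['new']
    if not run:
        # The dropdown was emptied; there is no run to summarize.
        run_summary_label.value = ""
        return
    run_path = os.path.join(prime_path, project_dropdown.value, run)
    summary = get_run_summary(run_path)
    run_summary_label.value = f"<b>Selected Run:</b> {summary}"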


def update_split_description(change):
    selected_split = change['new']
    project, dataset_name = get_active_project_dataset()
    project_dataset_dir = os.path.join(prime_path, project, "dataset")

    if selected_split and selected_split != "Select a split":
        split_path = os.path.join(project_dataset_dir, selected_split)
        if not os.path.exists(split_path):  # try rf_split as fallback
            split_path = os.path.join(ds_path, dataset_name, selected_split)

        try:
            with open(split_path, 'r') as f:
                data = json.load(f)
                desc = data.get("description", "No description found in split file.")
                split_description_label.value = f"<b>Split Description:</b> {desc}"
        except Exception as e:
            split_description_label.value = f"<b>Error loading description:</b> {str(e)}"
    else:
        split_description_label.value = ""


def update_checkpoints(change):
    project = project_dropdown.value
    run = change['new']
    checkpoints = project_runs.get(project, {}).get(run, [])

    if 'last.pth' in checkpoints:
        ckpt_dropdown.options = checkpoints
        ckpt_dropdown.value = 'last.pth'
    elif checkpoints:
        ckpt_dropdown.options = checkpoints
        ckpt_dropdown.value = checkpoints[0]
    else:
        ckpt_dropdown.options = []

# Separate project_dataset_name logic per mode
def get_active_project_dataset():
    mode = mode_selector.value
    if mode == "New Project":
        return project_widget.value, gdrive_dataset_dropdown.value
    elif mode in ["New Run", "Resume Training"]:
        project = project_dropdown.value
        # Load from folder name inside project; fall back to None if the lookup fails
        try:
            dataset_name = get_dataset_of_project(project)
        except Exception:
            dataset_name = None
        return project, dataset_name
    return None, None



# Accept only values divisible by 56 (e.g. 448, 560, 672, etc.)
def validate_resolution(change):
    val = change['new']
    if val % 56 != 0:
        resolution_widget.value = val - (val % 56)
        print(f"Rounded resolution down to a multiple of 56: {resolution_widget.value}")



config_complete = False  # Flag to block the rest until setup is done

config_values = {}

def on_button_clicked(b):
    mode = mode_selector.value
    config_values.clear()
    project, dataset_name = get_active_project_dataset()
    base_path = os.path.join(prime_path, project)
    dataset_path = os.path.join(ds_path, dataset_name)
    os.makedirs(base_path, exist_ok=True)

    # Get selected categories
    checkboxes = [cb for cb in category_box.children if isinstance(cb, widgets.Checkbox)]
    selected_categories = [cb.description for cb in checkboxes if cb.value]
    num_selected_categories = len(selected_categories)

    # Optional: also compute filtered image count
    annotation_path = os.path.join(dataset_path, "_annotations.coco.json")
    filtered_stats = get_filtered_dataset_stats(annotation_path, selected_categories)


    if mode == 'New Project':

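        config_values.update({
            "mode": "New Project",
            "project": project,
            "dataset_name": dataset_name,
            "base_path": base_path,
            "model_name": model_name_widget.value,
            "project_description": experiment_description_widget.value,
            "run_description": run_description_widget.value,
            "num_categories": filtered_stats["num_categories"],
            "categories": selected_categories,
            "num_images": filtered_stats["num_images"],
            # These keys are also read by the "Define hyperparameters" cell
            "resolution": resolution_widget.value,
            "batch_size": batch_size_widget.value,
            "grad_accum_steps": grad_accum_widget.value,
            "learning_rate": lr_widget.value,
            "dataset_path": dataset_path,
            "dataset_source": source_widget.value,
            "path_or_url": path_URL_widget.value,
        })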



    elif mode in ["New Run", "Resume Training"]:
        run = run_dropdown.value
        run_path = os.path.join(base_path, run)
        run_description = run_description_widget.value

        run_info = get_run_summary(run_path)

        config_values.update({
            "mode": "New Run",
            "project": project,
            "run": run,
            "run_description": run_description,
            "dataset_name": dataset_name,
            "base_path": base_path,
            "num_categories": run_info.get("num_categories",filtered_stats["num_categories"]),
            "num_images": run_info.get("num_images",filtered_stats["num_images"]),
            "categories": run_info.get("categories"),
            "resolution": resolution_widget.value,

            "batch_size": batch_size_widget.value,
            "grad_accum_steps":grad_accum_widget.value,
            "learning_rate":  lr_widget.value,
            "dataset_path": dataset_path,
            "dataset_source": source_widget.value,
            "path_or_url": path_URL_widget.value,
        })

        if mode == "Resume Training":
            checkpoint = ckpt_dropdown.value
            ckpt_path = os.path.join(base_path, run, "model_checkpoints", checkpoint)

            config_values.update({
            "mode": "Resume Training",
            "run_description": run_info.get("run_description"),
            "dataset_path": dataset_path,
            "checkpoint": checkpoint,
            "resume_from": ckpt_path,
            "dataset_name": run_info.get("dataset_name",dataset_name),
            "resolution": run_info.get("resolution", resolution_widget.value),
            "max_num_img": run_info.get("max_num_img"),
            "batch_size": run_info.get("batch_size", batch_size_widget.value),
            "grad_accum_steps": run_info.get("grad_accum_steps", grad_accum_widget.value),
            "learning_rate": run_info.get("learning_rate", lr_widget.value),
            "num_categories": run_info.get("num_categories"),
            "num_images": run_info.get("num_images"),
            "dataset_source": run_info.get("dataset_source", source_widget.value),
            "path_or_url": run_info.get("path_or_url", path_URL_widget.value),
            "split_mode": "existing",
            "selected_split": run_info.get("selected_split")

            })
    config_values["epochs"] = epochs_widget.value

    if mode in ["New Run", "New Project"]:
        max_imgs = max_num_img_widget.value
        if isinstance(max_imgs, int) and max_imgs > 0:
            config_values["max_num_img"] = max_imgs
        else:
            config_values["max_num_img"] = None

        if split_mode_radio.value == "Create new snapshot":

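            total = validate_splits(None)
            if total == 0:
                # Avoid dividing by zero below; leave the configuration unsaved.
                config_values.clear()
                return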
            split_ratios = (
                train_split.value / total,
                val_split.value / total,
                test_split.value / total
            )
            norm_split = tuple(round(r, 3) for r in split_ratios)
            config_values["split_ratios"] = norm_split
            config_values["split_mode"] = "new_snapshot"
            config_values["split_description"] = split_description_widget.value
        else:
            config_values["split_mode"] = "existing"
            config_values["selected_split"] = existing_split_dropdown.value

    preview_config()


# - - - - - - - - - - - - - Run Observers - - -- - - - - - - - - - - - - -
gdrive_dataset_dropdown.observe(load_categories_from_dataset, names='value')
train_split.observe(validate_splits, names='value')
val_split.observe(validate_splits, names='value')
test_split.observe(validate_splits, names='value')
existing_split_dropdown.observe(update_split_description, names='value')

# Observe category checkbox state changes
for cb in scrollable_categories.children:
    if isinstance(cb, widgets.Checkbox):
        cb.observe(update_split_mode_based_on_categories, names='value')

resolution_widget.observe(validate_resolution, names='value')
source_widget.observe(update_path_url, names='value')
split_mode_radio.observe(toggle_split_mode, names='value')

gdrive_dataset_dropdown.observe(handle_dataset_selection, names='value')

gdrive_dataset_dropdown.observe(lambda change: update_combined_split_dropdown(), names='value')

def safe_update_splits(change):
    if mode_selector.value != "New Project":
        update_combined_split_dropdown(change)

project_dropdown.observe(safe_update_splits, names='value')
run_dropdown.observe(update_run_summary, names='value')
project_widget.observe(update_combined_split_dropdown, names='value')

project_dropdown.observe(update_runs, names='value')

run_dropdown.observe(update_checkpoints, names='value')



update_split_mode_based_on_categories()

save_button.on_click(on_button_clicked)


def show_dataset_import_ui():
  source_ui.layout.display = 'block'
def hide_dataset_import_ui():
  source_ui.layout.display = 'none'


# Hide/show logic on mode
def toggle_ui_visibility(change=None):
    mode = mode_selector.value
    create_container.layout.display = 'none'
    new_run_container.layout.display = 'none'
    retake_container.layout.display = 'none'

    # Reset dataset input fields when mode changes
    dataset_widget.value = ""
    path_URL_widget.value = ""
    gdrive_dataset_dropdown.options = ["Select a dataset"] + get_gdrive_datasets()
    if mode == "New Project":
        gdrive_dataset_dropdown.options = ["Select a dataset"] + ["Create new dataset..."] + get_gdrive_datasets()
        create_container.layout.display = 'block'
        new_run_container.layout.display = 'none'
        retake_container.layout.display = 'none'
    elif mode == "New Run":
        new_run_container.layout.display = 'block'
        create_container.layout.display = 'none'
        retake_container.layout.display = 'none'
    elif mode == "Resume Training":
        retake_container.layout.display = 'block'
        create_container.layout.display = 'none'
        new_run_container.layout.display = 'none'

    # force update based on current mode
    update_combined_split_dropdown()


# Call once to initialize
mode_selector.observe(toggle_ui_visibility, names='value')
toggle_ui_visibility()

def preview_config():
    preview = widgets.HTML("<b>Configuration Preview:</b><br>" + "<br>".join([
        f"{k}: {v}" for k, v in config_values.items()
    ]))
    display(preview)

# Display all widgets
display(widgets.VBox([
    widgets.HTML("<h2>Training Configuration</h2>"),
    mode_selector,
    widgets.HTML("<hr>"),
    create_container,
    new_run_container,
    retake_container,
    ######

    widgets.HTML("<hr>"),
    save_button
]))
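
The "Create new snapshot" branch of `on_button_clicked` turns the three split percentages into ratios by dividing each by their sum and rounding to three decimals. A minimal standalone sketch of that step is shown below. The function name `normalize_splits` is hypothetical, and raising an error for a zero total is an assumption made here rather than the notebook's behavior.

```python
from typing import Tuple


def normalize_splits(train: float, val: float, test: float) -> Tuple[float, float, float]:
    """Convert split percentages into ratios that sum to (about) 1.

    Mirrors the normalization in on_button_clicked: each value is divided
    by the total and rounded to 3 decimals.
    """
    total = train + val + test
    if total <= 0:
        # Assumption: refuse to normalize instead of dividing by zero.
        raise ValueError("Split percentages sum to 0; cannot compute ratios.")
    return tuple(round(p / total, 3) for p in (train, val, test))


print(normalize_splits(70, 20, 10))
```

Because the values are divided by their own sum, percentages that do not add up to 100 are still rescaled into valid ratios, which is why the UI only warns about a total other than 100.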




In [None]:
# @title Validate Configuration Settings - Run cell to continue ->
from IPython.display import display, Javascript

# Wait for config to be saved
assert config_values, "Please click 'Save Configuration' first!"

print("Configuration complete. Proceeding...")


def run_all_cells_below():
    display(Javascript('IPython.notebook.execute_cells(IPython.notebook.get_selected_index()+1, IPython.notebook.ncells())'))

run_all_cells_below()
## Run-all downward from this cell :
#              ↓ ↓
#              ↓ ↓
#              ↓ ↓
#            ↓ ↓ ↓ ↓
#             ↓ ↓ ↓
#               ↓


In [None]:
# @title Install libraries. Log in to W&B API
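# Install the pinned version together with its metrics extra in one step,
# so that a later "-U" install does not upgrade away from 1.1.0.
!pip install "rfdetr[metrics]==1.1.0" -qU
!pip install wandb -qU
import os
import wandb

# If W&B_API_KEY is not set in the environment, wandb prompts for the key once (NOTE: the key must be a str)
wandb.login(key=os.environ.get("W&B_API_KEY"))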

# Initialization of the experiment

In [None]:
# @title Define hyperparameters
import os
import json
from datetime import datetime


PROJECT = config_values["project"]
experiment_description = config_values["run_description"]

# Path to project dir
base_path = config_values["base_path"]
source_dataset = config_values["dataset_source"]
path_url_dataset = config_values["path_or_url"]
dataset_name = config_values["dataset_name"]
categories = config_values["categories"]


# Path to dataset linked to project
dataset_path = config_values["dataset_path"]
num_selected_categories = config_values.get("num_categories", "N/A")
num_filtered_images = config_values.get("num_images", "N/A")
max_files = config_values.get("max_num_img",None)
gpu_info = get_gpu_info()

# Configs of the experiment - Customize as needed

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Hyperparameters
RESOLUTION=config_values["resolution"]
EPOCHS = config_values["epochs"]
BATCH_SIZE = config_values["batch_size"]
GRAD_ACCUM_STEPS = config_values["grad_accum_steps"]
LR = config_values["learning_rate"]


# "selected_split" is only present when an existing snapshot was chosen
snapshot_splits = config_values.get("selected_split")
if snapshot_splits == "Select a split":
    snapshot_splits = "rf_split.json"



# Create New RUN Folder
run_name = f"{PROJECT}_batch{BATCH_SIZE}_grad{GRAD_ACCUM_STEPS}_{timestamp}"
run_path = os.path.join(base_path, run_name)
os.makedirs(run_path, exist_ok=True)

# SUBFOLDERS
folders = {
    "checkpoints": os.path.join(run_path, "model_checkpoints"),
    "logs": os.path.join(run_path, "logs"),
}

for folder in folders.values():
    os.makedirs(folder, exist_ok=True)


# The mode string must match the value saved by the UI ("New Project", capital P)
if config_values["mode"] == "New Project":

      project_description = config_values["project_description"]
      model_name = config_values["model_name"]


      # Save experiment_info.json
      new_project_info = {
          "project": PROJECT,
          "project_description": project_description,
          "model": model_name,
          "dataset_name": dataset_name,
          "num_categories":num_selected_categories,
          "categories": categories,
          "created_at": timestamp
      }
      # Write project info if new project
      proj_config=os.path.join(base_path, f"config.json")

      with open(proj_config, "w") as f:
          json.dump(new_project_info, f, indent=4)
      print(f"New project created. [ {proj_config}]")

      if config_values["split_mode"] == "new_snapshot":
          split_description = config_values["split_description"]
          split_ratios = config_values["split_ratios"]
          image_dir = dataset_path
          output_dir = os.path.join(base_path, "dataset")

          # New split with split_ratios
          snapshot_splits = create_snapshot_split(image_dir, split_ratios=split_ratios, description=split_description, output_dir=output_dir)
else:
      # Delete empty previous runs
      last_run= os.path.join(base_path, get_latest_experiment(base_path,0))
      delete_run_if_empty_or_unused(last_run)

if config_values["mode"] !=  "Resume Training":
    # Save experiment_info.json
    experiment_info = {
        "project": PROJECT,
        "run_name": run_name,
        "dataset_name": dataset_name,
        "snapshot_splits":snapshot_splits,
        "run_description": experiment_description,
        "gpu": {
            "GPU_Name": gpu_info['gpu_name'],
            "Perf": gpu_info['pstate'],
            "Total Memory MB": gpu_info['memory_total']
        },
        "hyperparameters": {
            "epochs": EPOCHS,
            "batch_size": BATCH_SIZE,
            "grad_accum_steps": GRAD_ACCUM_STEPS,
            "learning_rate": LR,
            "resolution": RESOLUTION,
        },
        "num_images":num_filtered_images,
        "max_num_img":max_files,
        "num_categories":num_selected_categories,
        "categories": categories,
        "created_at": timestamp
    }

    # Write run info
    with open(os.path.join(run_path, "experiment_info.json"), "w") as f:
        json.dump(experiment_info, f, indent=4)

    print(f" Experiment folder created at: {run_path}")
    print(f" Subfolders: {', '.join(folders.keys())}")
    print(f"Configurations of experiment saved to experiment_info.json")
else:
    print(f"Experiment folder: {run_path}")
    print("Configurations available in \"experiment_info.json\"")


# Import datasets
import_dataset(source=source_dataset, path_or_url=path_url_dataset, project_name=PROJECT, dataset_name=dataset_name, snapshot_name=snapshot_splits, split="train",used_category_names=categories, max_files=max_files)

import_dataset(source=source_dataset, path_or_url=path_url_dataset, project_name=PROJECT, dataset_name=dataset_name, snapshot_name=snapshot_splits, split="valid",used_category_names=categories, max_files=max_files)
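
Run folders are named `{PROJECT}_batch{BATCH_SIZE}_grad{GRAD_ACCUM_STEPS}_{timestamp}`, and the UI later recovers the timestamp with a regular expression to filter runs by date. The sketch below shows that round trip in isolation. The function names `build_run_name` and `parse_run_datetime` and the project name `demo` are hypothetical, chosen only for illustration.

```python
import re
from datetime import datetime
from typing import Optional

TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S"


def build_run_name(project: str, batch_size: int, grad_accum_steps: int, when: datetime) -> str:
    """Build a run name using the same pattern as the notebook."""
    return f"{project}_batch{batch_size}_grad{grad_accum_steps}_{when.strftime(TIMESTAMP_FORMAT)}"


def parse_run_datetime(run_name: str) -> Optional[datetime]:
    """Extract the first YYYYMMDD_HHMMSS timestamp from a run name, or None."""
    match = re.search(r"(\d{8}_\d{6})", run_name)
    if not match:
        return None
    try:
        return datetime.strptime(match.group(1), TIMESTAMP_FORMAT)
    except ValueError:
        # Digits matched the pattern but do not form a valid date/time.
        return None


name = build_run_name("demo", 4, 4, datetime(2025, 1, 31, 14, 5, 9))
print(name)
print(parse_run_datetime(name))
```

Note that the regular expression picks the first match, so a project name that itself contains a timestamp-like string would be parsed instead of the run timestamp.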



In [None]:
# @title Patched method version
from rfdetr import RFDETRBase
from rfdetr.util.metrics import MetricsPlotSink, MetricsTensorBoardSink, MetricsWandBSink
import os
import json
from rfdetr.config import TrainConfig
import logging


logger = logging.getLogger(__name__)
shared_metrics = {}

def train_from_config2(self, config: TrainConfig, **kwargs):
        start_epoch = getattr(config, "resume_from", 0)
        with open(
            os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r"
        ) as f:
            anns = json.load(f)
            num_classes = len(anns["categories"])
            class_names = [c["name"] for c in anns["categories"] if c["supercategory"] != "none"]
            self.model.class_names = class_names

        if self.model_config.num_classes != num_classes:
            logger.warning(
                f"num_classes mismatch: model has {self.model_config.num_classes} classes, but your dataset has {num_classes} classes\n"
                f"reinitializing your detection head with {num_classes} classes."
            )
            self.model.reinitialize_detection_head(num_classes)


        train_config = config.dict()
        model_config = self.model_config.dict()
        model_config.pop("num_classes")
        if "class_names" in model_config:
            model_config.pop("class_names")

        if "class_names" in train_config and train_config["class_names"] is None:
            train_config["class_names"] = class_names

        for k, v in train_config.items():
            if k in model_config:
                model_config.pop(k)
            if k in kwargs:
                kwargs.pop(k)

        all_kwargs = {**model_config, **train_config, **kwargs, "num_classes": num_classes}


        metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)

        def save_plot_on_train_end(_=None):
            metrics_plot_sink.save()

        self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update)

        # Modification to base method -------------------------------
        self.callbacks["on_fit_epoch_end"].append(save_plot_on_train_end)
        # -----------------------------------------------------------
        self.callbacks["on_train_end"].append(save_plot_on_train_end)

        if config.tensorboard:
            tensorboard_log_dir = os.path.join(config.output_dir, "logs", "tensorboard")
            metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=tensorboard_log_dir)
            self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update)
            self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close)

        if config.wandb:
            metrics_wandb_sink = MetricsWandBSink(
                output_dir=config.output_dir,
                project=config.project,
                run=config.run,
                config=config.model_dump()
            )

            # Wrap update to capture metrics into shared_metrics -------------------------------
            def wrapped_wandb_update(values):
                if isinstance(values, dict):
                    shared_metrics.clear()
                    shared_metrics.update(values)
            # -----------------------------------------------------------
                metrics_wandb_sink.update(values)
            self.callbacks["on_fit_epoch_end"].append(wrapped_wandb_update)
            self.callbacks["on_train_end"].append(metrics_wandb_sink.close)

        if config.early_stopping:
            from rfdetr.util.early_stopping import EarlyStoppingCallback
            early_stopping_callback = EarlyStoppingCallback(
                model=self.model,
                patience=config.early_stopping_patience,
                min_delta=config.early_stopping_min_delta,
                use_ema=config.early_stopping_use_ema
            )
            self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update)

        self.model.train(
            **all_kwargs,
            callbacks=self.callbacks,
            start_epoch=start_epoch,
        )
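
The patched method above wraps the W&B sink's `update` so that each epoch's metrics are also copied into the module-level `shared_metrics` dictionary, where the custom callback can read them. The pattern can be sketched without RF-DETR or W&B as follows. The `capture_into` helper and the list used as a stand-in sink are hypothetical.

```python
from typing import Any, Callable, Dict, List

shared_metrics: Dict[str, Any] = {}


def capture_into(store: Dict[str, Any], sink_update: Callable[[Any], None]) -> Callable[[Any], None]:
    """Wrap a sink's update function so the latest metrics are mirrored into `store`."""
    def wrapped(values: Any) -> None:
        if isinstance(values, dict):
            # Keep only the most recent epoch's metrics.
            store.clear()
            store.update(values)
        # Always forward to the original sink.
        sink_update(values)
    return wrapped


# Stand-in for a metrics sink: it just records what it receives.
received: List[Any] = []
callbacks: Dict[str, List[Callable[[Any], None]]] = {"on_fit_epoch_end": []}
callbacks["on_fit_epoch_end"].append(capture_into(shared_metrics, received.append))

# Simulate two epochs firing the callback list.
for epoch in range(2):
    payload = {"epoch": epoch, "train_loss": 1.0 / (epoch + 1)}
    for cb in callbacks["on_fit_epoch_end"]:
        cb(payload)

print(shared_metrics)
```

Callbacks run in registration order, so a callback that reads `shared_metrics` needs to be appended after the wrapped sink to see the current epoch's values.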

In [None]:
# @title Initialize the model and setup Callback function's data and initialize log tools
import shutil
import glob
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter


# Patch RFDETRBase with custom training logic
RFDETRBase.train_from_config = train_from_config2

# Initializing RFDETR Model
model = RFDETRBase(resolution=RESOLUTION)
optimizer = optim.Adam(model.model.model.parameters(), lr=LR)

local_output = f"/content/training_output_{timestamp}"

logged_metrics = []

#  Setting TrainConfig parameters
config = TrainConfig(
    dataset_dir=local_base,
    output_dir = local_output,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    grad_accum_steps=GRAD_ACCUM_STEPS,
    lr=LR,
    tensorboard=True,
    project=PROJECT,
    run=run_name
)
#  Init WandB
wandb.init(
    project=config.project,
    name=config.run,
    id=config.run,
    resume="allow",
    # .dict() method has been Deprecated in Pydantic V2.0 to be removed in V3.0. in favor of .model_dump().
    config=config.model_dump(),
)

# Syncing callback
def create_logging_callback(model, optimizer, folders, local_output, run_name):
    def callback(data):
        epoch = data['epoch']
        # Use shared metrics collected from wandb sink
        loss = shared_metrics.get('train_loss')
        val_loss = shared_metrics.get('val_loss')
        test_metrics = shared_metrics.get('test_coco_eval_bbox') or [None] * 9
        ema_metrics = shared_metrics.get('ema_test_coco_eval_bbox') or [None] * 9


        def safe_float(x):
            try:
                return float(x)
            except (TypeError, ValueError):
                return None

        # Normalize the values
        test_metrics = [safe_float(x) for x in test_metrics]
        ema_metrics = [safe_float(x) for x in ema_metrics]

        # Create W&B compatible logging format
        metrics = {
            "epoch": epoch,
            "train_loss": safe_float(loss),
            "val_loss": safe_float(val_loss),
            "Metrics/Base/AP50": test_metrics[1],
            "Metrics/EMA/AP50": ema_metrics[1],
            "Metrics/Base/AP50_90": test_metrics[0],
            "Metrics/EMA/AP50_90": ema_metrics[0],
            "Metrics/Base/AR50_90": test_metrics[8],
            "Metrics/EMA/AR50_90": ema_metrics[8],
        }
        # Log to W&B
        wandb.log(metrics)

        # Save raw metrics for local JSON dump
        raw_metrics = {
            "epoch": epoch,
            "train_loss": metrics["train_loss"],
            "test_loss": metrics["val_loss"],
            "test_coco_eval_bbox": test_metrics,
            "ema_test_coco_eval_bbox": ema_metrics
        }
        logged_metrics.append(raw_metrics)


        print(f"[Callback] Epoch {epoch}: loss={loss}, val_loss={val_loss}, mAP={test_metrics}")

        # Save checkpoint every 10 epochs or on last
        if epoch % 10 == 0 or epoch == EPOCHS-1:

            ckpt_path = os.path.join(folders["checkpoints"], f"model_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model': model.model.model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, ckpt_path)

            # Sync plot to GDrive
            src_plot = os.path.join(local_output, "metrics_plot.png")
            dst_plot = os.path.join(folders["logs"], f"metrics_plot_epoch_{epoch}.png")
            if os.path.exists(src_plot):
                shutil.copy2(src_plot, dst_plot)


        # Always save latest
        last_ckpt = os.path.join(folders["checkpoints"], "last.pth")
        torch.save({
            'epoch': epoch,
            'model': model.model.model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, last_ckpt)

        # Final save and W&B shutdown on the last epoch (epochs are 0-indexed)
        if epoch >= EPOCHS-1:
            metrics_json_path = os.path.join(folders["logs"], f"metrics_{run_name}.json")
            with open(metrics_json_path, "w") as f:
                json.dump(logged_metrics, f, indent=2)
            print(f"Saved metrics to: {metrics_json_path}")

            wandb.finish()

        # Sync plot to GDrive
        src_plot = os.path.join(local_output, "metrics_plot.png")
        dst_plot = os.path.join(folders["logs"], f"metrics_plot_{run_name}.png")
        if os.path.exists(src_plot):
            shutil.copy2(src_plot, dst_plot)


        # Sync TensorBoard logs local_output -> gdrive
        src_tb_dir = os.path.join(local_output, "logs", "tensorboard")
        dst_tb_dir = os.path.join(folders["logs"], "tensorboard")
        os.makedirs(dst_tb_dir, exist_ok=True)
        for tb_file in glob.glob(os.path.join(src_tb_dir, "events.out.tfevents.*")):
            shutil.copy2(tb_file, dst_tb_dir)

    return callback


# Attach callback
custom_callback = create_logging_callback(model, optimizer, folders, local_output, run_name)
model.callbacks["on_fit_epoch_end"].append(custom_callback)


## Model Training

Training resumes from a checkpoint if one is available; otherwise it starts from scratch.
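
The resume decision can be sketched as a small pure function. This is only an illustration, assuming epochs are 0-indexed and that the checkpoint stores the last completed epoch under an `epoch` key next to the `model` weights, as the checkpoints saved by the callback above do. The name `next_epoch_to_run` is hypothetical and not part of RF-DETR.

```python
def next_epoch_to_run(checkpoint, total_epochs):
    """Return the 0-indexed epoch to resume from, or 0 to start fresh."""
    # A usable checkpoint is a dict that contains the model weights.
    if not isinstance(checkpoint, dict) or "model" not in checkpoint:
        return 0
    # The checkpoint stores the last completed epoch; resume with the next one.
    resume_epoch = checkpoint.get("epoch", -1) + 1
    # The last valid epoch index is total_epochs - 1, so anything beyond it
    # means the previous run already finished.
    if resume_epoch >= total_epochs:
        return 0
    return resume_epoch


print(next_epoch_to_run({"model": {}, "epoch": 9}, total_epochs=50))  # 10
```

A checkpoint saved at epoch 49 of a 50-epoch run yields 0 here, which corresponds to the "already complete, starting fresh" branch in the cell below.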

In [None]:
# @title Resume from the last run or the checkpoint selected by the user:

resume_from = None  # Default: train from scratch unless a usable checkpoint is found

if config_values["mode"] == "Resume Training":
    latest_exp_name = config_values["run"]

    if latest_exp_name:
        print(f" Latest experiment: {latest_exp_name}")
        ckpt_path = config_values["resume_from"]
        resume_from = ckpt_path if ckpt_path and os.path.exists(ckpt_path) else None
        if resume_from:
            print(f" Resuming from: {resume_from}")
            # Sync TensorBoard logs: gdrive -> local_output
            src_tb_dir = os.path.join(base_path, latest_exp_name, "logs", "tensorboard")
            dst_tb_dir = os.path.join(local_output, "logs", "tensorboard")
            os.makedirs(dst_tb_dir, exist_ok=True)
            for tb_file in glob.glob(os.path.join(src_tb_dir, "events.out.tfevents.*")):
                shutil.copy2(tb_file, dst_tb_dir)
        else:
            print(" No checkpoint selected, training will start from scratch.")
    else:
        print(" No previous experiment found.")
else:
    print(" Resume mode not selected, training will start from scratch.")

# Set base config
config_updates = {
    "wandb": True,
}

resume_epoch = 0

if resume_from:
    checkpoint = torch.load(resume_from, map_location='cpu')

    # Check for model weights
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']

        # Check how far it got (-1 if the checkpoint stores no epoch)
        last_epoch = checkpoint.get('epoch', -1)
        print(f"Found checkpoint with epoch: {last_epoch}")
        # Point to the FOLLOWING epoch to start training with
        resume_epoch = last_epoch + 1

        # Epochs are 0-indexed, so the run is complete once resume_epoch reaches EPOCHS
        if resume_epoch < EPOCHS:
            print("Resuming training...")
            should_resume = True

            # Filter & load weights
            model_dict = model.model.model.state_dict()
            filtered_state_dict = {
                k: v for k, v in state_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            model_dict.update(filtered_state_dict)
            model.model.model.load_state_dict(model_dict)

            # Load optimizer
            if 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])

            wandb.config.update({"resume_from": resume_epoch}, allow_val_change=True)
            config = config.model_copy(update={"resume_from": resume_epoch,})

        else:
            print("Checkpoint indicates training is already complete. Starting fresh.")
    else:
        print("Invalid checkpoint format: 'model' missing. Starting fresh.")

# Update config with all needed info
config = config.model_copy(update=config_updates)
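
The weight-loading step above keeps only checkpoint entries whose name and shape match the current model, so a checkpoint with a differently sized detection head does not abort the load. Below is a minimal sketch of that filter on its own. The helper name `filter_compatible` is hypothetical, and `SimpleNamespace` objects stand in for tensors (anything with a `.shape` attribute works), so the example runs without PyTorch.

```python
from types import SimpleNamespace


def filter_compatible(saved_state, current_state):
    """Keep saved entries whose key exists in the current model with the same shape."""
    return {
        name: tensor
        for name, tensor in saved_state.items()
        if name in current_state and tensor.shape == current_state[name].shape
    }


# Stand-ins for tensors: only the .shape attribute matters for the filter.
current = {
    "backbone.w": SimpleNamespace(shape=(4, 4)),
    "head.w": SimpleNamespace(shape=(3, 4)),
}
saved = {
    "backbone.w": SimpleNamespace(shape=(4, 4)),  # same shape: kept
    "head.w": SimpleNamespace(shape=(91, 4)),     # different class count: dropped
    "extra.b": SimpleNamespace(shape=(4,)),       # unknown key: dropped
}

kept = filter_compatible(saved, current)
print(sorted(kept))  # ['backbone.w']
```

Dropped entries keep their freshly initialized values in the current model, so it is worth printing which keys were skipped when a resume behaves unexpectedly.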


#### Initialize TensorBoard
Initialize before the training loop so it tracks everything from the beginning.
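
When a run is resumed, the dashboard can only show the earlier part of the curves if the previous event files are already in the log directory. The resume cell and the logging callback both copy `events.out.tfevents.*` files for that reason. The sketch below isolates that copy step; the helper name `sync_event_files` and the temporary folders are assumptions for illustration only.

```python
import glob
import os
import shutil
import tempfile


def sync_event_files(src_dir, dst_dir):
    """Copy TensorBoard event files from src_dir to dst_dir; return the copied names."""
    os.makedirs(dst_dir, exist_ok=True)
    copied = []
    for path in sorted(glob.glob(os.path.join(src_dir, "events.out.tfevents.*"))):
        shutil.copy2(path, dst_dir)  # copy2 also preserves file timestamps
        copied.append(os.path.basename(path))
    return copied


# Demo with throwaway folders standing in for Drive and the local output directory.
with tempfile.TemporaryDirectory() as root:
    src = os.path.join(root, "drive_logs")
    dst = os.path.join(root, "local_logs")
    os.makedirs(src)
    for name in ("events.out.tfevents.1", "notes.txt"):
        open(os.path.join(src, name), "w").close()
    print(sync_event_files(src, dst))  # ['events.out.tfevents.1']
```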

In [None]:
%load_ext tensorboard

%tensorboard --logdir {local_output}/logs/tensorboard

### Begin Training!

In [None]:

from google.colab import runtime
import time

# Train using the patched method
tstart = time.time()
try:
    model.train_from_config(config)
except Exception as e:
    print(f"An error occurred during training: {e}")
else:
    print("Training completed successfully!")
finally:
    print("Training process has ended.")
tend = time.time()

# Format the elapsed time by hand: time.strftime("%H:%M:%S") would wrap around after 24 hours
hours, remainder = divmod(int(tend - tstart), 3600)
minutes, seconds = divmod(remainder, 60)
formatted = f"{hours:02d}:{minutes:02d}:{seconds:02d}"

print(f"Training lasted: {formatted}")
config_complete = False  # Block another "Run all" until the setup is re-confirmed in the widget UI

In [None]:
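from IPython.display import Image, display

# Show the metrics plot written during training, if it exists
img_path = os.path.join(local_output, "metrics_plot.png")
if os.path.exists(img_path):
    display(Image(filename=img_path))
else:
    print(f"No metrics plot found at: {img_path}")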

In [None]:
# Terminate session

runtime.unassign()
