# Age Estimation using the YOLO algorithm

Authors: Isak Killingrød, Jon A B Larssen, Jon I J Skånøy

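## About

This notebook downloads the Adience face dataset and converts it into YOLO-format age datasets (one face box per image, labeled with one of eight age brackets). It then tunes YOLO hyperparameters with Optuna and trains models several times with the best parameters found.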

## Setup

### Imports

In [None]:
# %matplotlib inline
from tqdm.notebook import tqdm

In [None]:
import os
import shutil
import requests
import tarfile
from requests.auth import HTTPBasicAuth
import torch
import torchvision.transforms as transforms
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import numpy as np
import re
import cv2
from concurrent.futures import ThreadPoolExecutor
import glob
from pathlib import Path
import yaml
import json
from datetime import datetime
import time
import optuna
from optuna.trial import TrialState
from ultralytics import YOLO
import platform
from facenet_pytorch import MTCNN
import warnings
import optuna
import random


### Config

In [None]:
# Hide every UserWarning for the rest of the notebook session.
warnings.filterwarnings(action="ignore", category=UserWarning)

In [None]:
# Per-platform Optuna storage name and model grid. Windows walks the grid from
# the smallest to the largest model and Linux from the largest to the smallest
# (presumably so two machines can work from opposite ends of the grid).
_os_name = platform.system()
print(f"Running on {_os_name}")

if _os_name == "Windows":
    STORAGE = 'YOLO_NB_LOCAL'
    MODEL_SIZES = ['n', 's', 'm', 'l', 'x']
    MODEL_VERSIONS = [8, 9, 10, 11, 12]
elif _os_name == "Linux":
    STORAGE = 'YOLO_NB_SERVER'
    MODEL_SIZES = ['x', 'l', 'm', 's', 'n']
    MODEL_VERSIONS = [12, 11, 10, 9, 8]
else:
    STORAGE = 'YOLO_NB_UNKNOWN'
    # Define the grid here too; later cells read MODEL_SIZES / MODEL_VERSIONS /
    # IMAGE_SIZE and would otherwise raise NameError on other platforms.
    MODEL_SIZES = ['n', 's', 'm', 'l', 'x']
    MODEL_VERSIONS = [8, 9, 10, 11, 12]

# Training / tuning image size (identical on every platform).
IMAGE_SIZE = 416

In [None]:
# Toggles for the expensive, stateful steps later in the notebook:
#   RUN_OPTUNA -> run the Optuna hyperparameter search cell
#   MERGE_DB   -> merge the per-machine Optuna databases into merged.db
RUN_OPTUNA, MERGE_DB = False, False

In [None]:
# Worker count for preprocessing only (not tuning or training). os.cpu_count()
# may return None, and on machines with fewer than 3 cores plain floor division
# yields 0, so fall back to 1 CPU and clamp to at least one worker.
NUM_WORKERS = max(1, (os.cpu_count() or 1) // 3)

In [None]:
# HTTP basic-auth credentials for the Adience download area. Defaults are the
# values this notebook has always used; set ADIENCE_USERNAME / ADIENCE_PASSWORD
# in the environment to override them without editing the notebook.
USERNAME = os.environ.get('ADIENCE_USERNAME', 'adiencedb')
PASSWORD = os.environ.get('ADIENCE_PASSWORD', 'adience')

In [None]:
DATA_DIR = 'data'

# Root of the Adience download area; the fold files and the faces archive live under it.
BASE_URL = (
    'http://www.cslab.openu.ac.il/download/adiencedb/'
    'AdienceBenchmarkOfUnfilteredFacesForGenderAndAgeClassification/'
)
ARCHIVE_URL = f"{BASE_URL}faces.tar.gz"
ARCHIVE_PATH = os.path.join(DATA_DIR, "faces.tar.gz")

### Functions

In [None]:
def show_sample_with_bbox(img_path, label_path, creator):
    """Display an image with its bounding box"""
    # TODO: Make it plot 3 samples in a 1 x 3 subplot
    # Load image
    img = Image.open(img_path)
    img_w, img_h = img.size
    
    # Load label
    with open(label_path, 'r') as f:
        line = f.readline().strip().split()
        class_id = int(line[0])
        x_center, y_center, width, height = map(float, line[1:5])
    
    # Convert YOLO format to pixel coordinates
    x1 = int((x_center - width/2) * img_w)
    y1 = int((y_center - height/2) * img_h)
    x2 = int((x_center + width/2) * img_w)
    y2 = int((y_center + height/2) * img_h)
    
    # Plot
    plt.figure(figsize=(8, 8))
    plt.imshow(np.array(img))
    plt.gca().add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, edgecolor='red', linewidth=2))
    
    # Label with age category
    age_category = creator.age_categories[class_id]
    plt.title(f"Age Category: {age_category[0]}-{age_category[1]} years")
    plt.axis('off')
    plt.show()

In [None]:
def get_model_filename(version, size):
    """
    Returns the correct YOLO model filename based on version and size.
    Only supports detection models.

    Args:
        version (int or str): YOLO version (8–12)
        size (str): Model size. The generic letters n/s/m/l/x work for every
            version. YOLOv9's own letters t/s/m/c/e and YOLOv10's extra 'b'
            are accepted too.

    Returns:
        str: Filename of the model checkpoint, e.g. 'yolov9e.pt'

    Raises:
        ValueError: If version or size is unsupported
    """
    version = str(version).lower()
    size = str(size).lower()
    standard_sizes = {'n', 's', 'm', 'l', 'x'}

    if version == '9':
        # YOLOv9 names its checkpoints t/s/m/c/e; map the generic scheme onto them
        # and also accept the native letters directly.
        model_map = {
            'n': 'yolov9t.pt',
            's': 'yolov9s.pt',
            'm': 'yolov9m.pt',
            'l': 'yolov9c.pt',
            'x': 'yolov9e.pt',
            't': 'yolov9t.pt',
            'c': 'yolov9c.pt',
            'e': 'yolov9e.pt',
        }
        if size not in model_map:
            raise ValueError(f"Unsupported size {size!r} for YOLOv9 (expected one of {sorted(model_map)})")
        return model_map[size]

    if version in ('8', '10'):
        allowed = standard_sizes | {'b'} if version == '10' else standard_sizes
        if size not in allowed:
            raise ValueError(f"Unsupported size {size!r} for YOLOv{version} (expected one of {sorted(allowed)})")
        return f'yolov{version}{size}.pt'

    if version in ('11', '12'):
        # From v11 on, Ultralytics dropped the 'v' from checkpoint names.
        if size not in standard_sizes:
            raise ValueError(f"Unsupported size {size!r} for YOLO{version} (expected one of {sorted(standard_sizes)})")
        return f'yolo{version}{size}.pt'

    raise ValueError(f"Unsupported YOLO version {version!r} (supported: 8, 9, 10, 11, 12)")


In [None]:
def check_default_boxes(base_dir='data/age_dataset_tune', default_box=(0.5, 0.5, 0.8, 0.8), tol=0.05, splits=('train', 'val')):
    """Report how many labels still carry the fallback box used when no face was detected.

    Only the first line of each label file is inspected. A line matches when
    every coordinate is within ``tol`` of ``default_box``. Malformed lines are
    skipped, but they still count toward the split's total.

    Args:
        base_dir: Dataset root containing ``labels/<split>/*.txt``.
        default_box: ``(x_center, y_center, width, height)`` to look for.
        tol: Absolute tolerance per coordinate.
        splits: Split names to scan.

    Returns:
        dict: ``{split: (default_count, total_files)}`` for every split that has
        label files.
    """
    x0, y0, w0, h0 = default_box
    stats = {}

    for split in splits:
        label_dir = os.path.join(base_dir, 'labels', split)
        label_files = glob.glob(os.path.join(label_dir, '*.txt'))

        if not label_files:
            print(f"  No label files found in {label_dir}. Skipping...")
            continue

        default_boxes = 0
        for file in label_files:
            with open(file, 'r') as f:
                parts = f.readline().split()
            if len(parts) != 5:
                continue
            try:
                _, x, y, w, h = map(float, parts)
            except ValueError:
                # Non-numeric field: not a valid YOLO row, so not a default box.
                continue
            if abs(x - x0) < tol and abs(y - y0) < tol and abs(w - w0) < tol and abs(h - h0) < tol:
                default_boxes += 1

        total_files = len(label_files)
        percent = (default_boxes / total_files) * 100
        print(f"  {split.capitalize()} set: {default_boxes}/{total_files} default boxes ({percent:.2f}%)")
        stats[split] = (default_boxes, total_files)

    return stats

In [None]:
def make_objective(model_path, data_yaml, imgsz, device, epochs_per_trial):
    """Build an Optuna objective that trains one YOLO model per trial.

    The returned callable samples a hyperparameter set from ``trial`` and trains
    ``model_path`` on ``data_yaml`` for ``epochs_per_trial`` epochs with AdamW.
    It returns the training result's ``fitness`` as a float, or 0.0 when the
    result has no ``fitness`` attribute. Any exception raised while training is
    printed and scored as 0.0.

    ``batch`` and ``imgsz`` are suggested with equal bounds. They are effectively
    fixed, but still recorded among each trial's params.
    """
    def objective(trial):
        pick_float = trial.suggest_float
        pick_int = trial.suggest_int

        # Sampled in a fixed order; keep it stable so stored studies stay comparable.
        hyper = {}
        hyper['lr0'] = pick_float('lr0', 1e-5, 1e-1, log=True)
        hyper['lrf'] = pick_float('lrf', 0.01, 1.0)
        hyper['momentum'] = pick_float('momentum', 0.6, 0.98)
        hyper['weight_decay'] = pick_float('weight_decay', 0.0001, 0.001, log=True)
        hyper['warmup_epochs'] = pick_int('warmup_epochs', 1, 5)
        hyper['warmup_momentum'] = pick_float('warmup_momentum', 0.5, 0.95)
        hyper['box'] = pick_float('box', 0.02, 0.2)
        hyper['cls'] = pick_float('cls', 0.2, 4.0)
        hyper['hsv_h'] = pick_float('hsv_h', 0.0, 0.1)
        hyper['hsv_s'] = pick_float('hsv_s', 0.5, 0.9)
        hyper['hsv_v'] = pick_float('hsv_v', 0.5, 0.9)
        hyper['degrees'] = pick_float('degrees', 0.0, 45.0)
        hyper['translate'] = pick_float('translate', 0.0, 0.5)
        hyper['scale'] = pick_float('scale', 0.0, 0.5)
        hyper['fliplr'] = pick_float('fliplr', 0.0, 0.5)
        hyper['mosaic'] = pick_float('mosaic', 0.0, 1.0)
        hyper['batch'] = pick_int('batch', 32, 32)
        hyper['imgsz'] = pick_int('imgsz', imgsz, imgsz)
        hyper['optimizer'] = 'AdamW'

        run_settings = dict(
            data=data_yaml,
            cache='disk',
            workers=1,
            epochs=epochs_per_trial,
            device=device,
            verbose=False,
            plots=True,
        )

        try:
            outcome = YOLO(model_path).train(**run_settings, **hyper)
            if not hasattr(outcome, 'fitness'):
                return 0.0
            return float(outcome.fitness)
        except Exception as err:
            print(f"❌ Trial failed with error: {err}")
            return 0.0

    return objective


In [None]:
def run_optuna_tuning(
    data_yaml,
    model_size='n',
    output_dir='runs/tune_optuna',
    imgsz=416,
    n_trials=40,
    epochs_per_trial=30,
    model_version=8,
    device='0'
):
    """Run (or resume) an Optuna search for one YOLO version/size and save the best params.

    The study lives in ``sqlite:///{STORAGE}.db`` under the name
    ``yolo_v{model_version}{model_size}``. Only COMPLETE trials count toward
    ``n_trials``, so failed or interrupted trials are made up on resume. The
    best params are written to
    ``<output_dir>/optuna_v<ver>_<size>_<timestamp>/best_params_v<ver>_<size>.json``.

    Returns:
        tuple: ``(best_params, best_value)``. Returns ``({}, 0.0)`` when the
        model request is invalid or when no trial has completed.
    """
    # Validate the model first so invalid requests don't leave empty run folders.
    try:
        model_path = get_model_filename(model_version, model_size)
    except ValueError as e:
        print(f"❌ Invalid model request: {e}")
        return {}, 0.0

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = f"optuna_v{model_version}_{model_size}_{timestamp}"
    output_path = os.path.join(output_dir, run_name)
    os.makedirs(output_path, exist_ok=True)

    study_name = f"yolo_v{model_version}{model_size}"
    study_storage = f"sqlite:///{STORAGE}.db"

    # Look the study up without creating it (an empty study would otherwise be
    # added to the database even when nothing is going to run).
    existing_names = {s.study_name for s in optuna.study.get_all_study_summaries(storage=study_storage)}
    if study_name in existing_names:
        study = optuna.load_study(study_name=study_name, storage=study_storage)
        # StudySummary.n_trials also counts FAIL/RUNNING trials; count COMPLETE only.
        completed_trials = len(study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,)))
    else:
        study = None
        completed_trials = 0

    if completed_trials >= n_trials:
        print(f"⏩ Skipping tuning: {completed_trials} completed trials already (target was {n_trials}).")
    else:
        remaining_trials = n_trials - completed_trials
        print(f"🔄 Starting/resuming tuning: {remaining_trials} trials needed.")
        if study is None:
            study = optuna.create_study(
                direction='maximize',
                study_name=study_name,
                storage=study_storage,
                load_if_exists=True
            )

        objective = make_objective(model_path, data_yaml, imgsz, device, epochs_per_trial)
        study.optimize(objective, n_trials=remaining_trials)

    best_params, best_value = {}, 0.0
    if study is not None:
        try:
            best_params, best_value = study.best_params, study.best_value
        except ValueError:
            # Optuna raises ValueError when no trial has completed yet.
            print(f"⚠️ No completed trials for {study_name}; nothing to report.")

    with open(os.path.join(output_path, f'best_params_v{model_version}_{model_size}.json'), 'w') as f:
        json.dump(best_params, f, indent=4)

    print(f"\n✅ Best result for YOLOv{model_version}-{model_size}: {best_value:.4f}")
    return best_params, best_value

In [None]:
def run_optuna_tuning_multi(
    base_dataset_dir='data/age_dataset_tune',
    model_sizes=('n', 's', 'm', 'l', 'x'),
    model_versions=(8, 9, 10, 11, 12),
    imgsz=416,
    n_trials=10,
    epochs_per_trial=30,
    device='0',
    output_base='runs/age_exp'
):
    """Tune every (version, size) combination on the dataset in ``base_dataset_dir``.

    Unsupported combinations (per ``get_model_filename``) are skipped. Each
    remaining combination is tuned with ``run_optuna_tuning`` into
    ``<output_base>/v<version>_<size>``.

    Returns:
        dict: ``{(version, size): (best_params, best_value)}`` for every tuned
        combination. Empty when ``<base_dataset_dir>/data.yaml`` is missing.
    """
    data_yaml = os.path.join(base_dataset_dir, "data.yaml")

    # os.path.join never returns an empty string, so check the file itself.
    if not os.path.isfile(data_yaml):
        print(f"⚠️ No datasets found in: {base_dataset_dir}")
        return {}

    print(f"\n📂 Evaluating dataset: {data_yaml}")
    dataset_name = Path(data_yaml).parent.name
    results = {}

    for version in model_versions:
        for size in model_sizes:
            try:
                get_model_filename(version, size)
            except ValueError as e:
                print(f"⏭️ Skipping unsupported model: YOLOv{version}-{size} ({e})")
                continue

            output_dir = os.path.join(output_base, f"v{version}_{size}")

            print(f"\n{'='*100}")
            print(f"🧪 Tuning: YOLOv{version}-{size} on dataset {dataset_name}")
            print(f"{'='*100}")

            results[(version, size)] = run_optuna_tuning(
                data_yaml=data_yaml,
                model_size=size,
                model_version=version,
                output_dir=output_dir,
                imgsz=imgsz,
                n_trials=n_trials,
                epochs_per_trial=epochs_per_trial,
                device=device
            )

    return results


In [None]:
def merge_optuna_databases(source_db_paths, target_db_path):
    """
    Merge Optuna studies from multiple databases into a target database.

    Args:
        source_db_paths (list): List of SQLite database URIs to merge.
        target_db_path (str): Target SQLite database URI.
    """
    target_storage = optuna.storages.RDBStorage(url=target_db_path)

    for db_path in source_db_paths:
        source_storage = optuna.storages.RDBStorage(url=db_path)
        study_summaries = optuna.get_all_study_summaries(storage=source_storage)
        
        for summary in study_summaries:
            study = optuna.load_study(study_name=summary.study_name, storage=source_storage)
            
            try:
                new_study = optuna.create_study(
                    study_name=study.study_name,
                    storage=target_storage,
                    direction=study.direction
                )
            except optuna.exceptions.DuplicatedStudyError:
                new_study = optuna.load_study(
                    study_name=study.study_name,
                    storage=target_storage
                )
            
            for trial in study.get_trials(deepcopy=True, states=(TrialState.COMPLETE,)):
                new_study.add_trial(trial)

In [None]:
def extract_study_trials_to_dataframe(db_paths, filter_study_name=None, sort_by_value=True):
    """
    Extract trials from Optuna databases into a DataFrame.

    Args:
        db_paths (list): List of database URIs.
        filter_study_name (str, optional): Only include studies matching this name. Default: None.
        sort_by_value (bool): Whether to sort trials by Value (ascending).

    Returns:
        pd.DataFrame: Trials information as a DataFrame.
    """
    study_infos = []

    for db_path in db_paths:
        storage = optuna.storages.RDBStorage(url=db_path)
        summaries = optuna.get_all_study_summaries(storage=storage)

        for summary in summaries:
            if filter_study_name and summary.study_name != filter_study_name:
                continue
            
            study = optuna.load_study(study_name=summary.study_name, storage=storage)
            completed_trials = [t for t in study.get_trials(deepcopy=False) if t.state == TrialState.COMPLETE]
            completed_trials.sort(key=lambda x: x.value)
            
            for rank, trial in enumerate(completed_trials, 1):
                study_infos.append({
                    "Database": db_path,
                    "Study Name": summary.study_name,
                    "Rank": rank,
                    "Trial Number": trial.number,
                    "Value": trial.value,
                    **trial.params
                })

    df = pd.DataFrame(study_infos)
    
    if sort_by_value:
        df = df.sort_values(by=["Value"], ascending=False)

    return df

### Classes

In [None]:
class FaceAgeDatasetCreator:
    """Build YOLO-format age datasets from the Adience faces archive and fold files.

    Every image gets exactly one box: the first face MTCNN returns, or a
    centered fallback box when none is found. The box's class id is the index
    of the image's age bracket in ``age_categories``.
    """

    def __init__(
        self,
        base_dir="data",
        faces_archive=None,
        faces_dir=None,
        output_dir=None,
        fold_files=None,
        imgsz=416,
        max_workers=4
    ):
        """Configure paths and load the MTCNN face detector.

        Args:
            base_dir: Root data folder (created if missing).
            faces_archive: Path to ``faces.tar.gz``; defaults to ``<base_dir>/faces.tar.gz``.
            faces_dir: Extraction folder; defaults to ``<base_dir>/faces``.
            output_dir: Dataset folder that ``create_yolo_dataset`` uses when it
                is called without one; defaults to ``<base_dir>/age_dataset``.
                It is not created here.
            fold_files: Fold annotation files; defaults to ``<base_dir>/fold_{0..4}_data.txt``.
            imgsz: Side length that images are resized to (square).
            max_workers: Stored worker count; a falsy value falls back to half
                the CPU count (at least 1).
        """
        self.base_dir = base_dir
        os.makedirs(self.base_dir, exist_ok=True)

        self.faces_archive = faces_archive or os.path.join(base_dir, "faces.tar.gz")
        self.faces_dir = faces_dir or os.path.join(base_dir, "faces")
        self.output_dir = output_dir or os.path.join(base_dir, "age_dataset")

        if fold_files is None:
            self.fold_files = [os.path.join(base_dir, f"fold_{i}_data.txt") for i in range(5)]
        else:
            self.fold_files = fold_files

        self.img_size = imgsz

        # Inclusive (low, high) age brackets; a label's class id is its index here.
        self.age_categories = [(0, 2), (4, 6), (8, 12), (15, 20), (25, 32), (38, 43), (48, 53), (60, 100)]

        # os.cpu_count() may return None; never end up with zero workers.
        self.max_workers = max_workers or max(1, (os.cpu_count() or 2) // 2)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.mtcnn = MTCNN(keep_all=False, device=self.device)

        print(f"Running on device: {self.device}")

    def extract_faces_archive(self):
        """Extract the ``faces/`` members of ``faces_archive`` into ``faces_dir``.

        Skipped when ``faces_dir`` already exists. Because an existing folder
        counts as "already extracted", the archive is checked before the folder
        is created, and a partially extracted folder is removed on failure.

        Raises:
            FileNotFoundError: If ``faces_archive`` does not exist.
        """
        if os.path.exists(self.faces_dir):
            print(f"{self.faces_dir} already exists. Skipping extraction.")
            return
        if not os.path.isfile(self.faces_archive):
            raise FileNotFoundError(f"Faces archive not found: {self.faces_archive}")

        os.makedirs(self.faces_dir, exist_ok=True)
        print(f"Extracting {self.faces_archive} to {self.faces_dir}...")
        try:
            with tarfile.open(self.faces_archive, 'r:gz') as tar:
                for member in tqdm(tar.getmembers(), desc="Extracting faces"):
                    if not member.name.startswith("faces/"):
                        continue
                    # Drop the top-level "faces/" folder so user folders land directly in faces_dir.
                    member.name = member.name[len("faces/"):]
                    if member.name:
                        tar.extract(member, self.faces_dir, filter='data')
        except BaseException:
            shutil.rmtree(self.faces_dir, ignore_errors=True)
            raise
        print("Extraction complete.")

    def get_age_class(self, age_info):
        """Map an age annotation to an index into ``age_categories``.

        A bracket string such as ``"(25, 32)"`` must match a category exactly.
        A single age (int, or a string holding an integer) must fall inside a
        category. Anything else returns -1, including ``"None"``, NaN and ages
        that fall between brackets.
        """
        if isinstance(age_info, str) and '(' in age_info:
            bounds = re.findall(r'\d+', age_info)
            if len(bounds) >= 2:
                key = (int(bounds[0]), int(bounds[1]))
                if key in self.age_categories:
                    return self.age_categories.index(key)
            return -1

        try:
            age = int(age_info)
        except (TypeError, ValueError, OverflowError):
            return -1
        for i, (low, high) in enumerate(self.age_categories):
            if low <= age <= high:
                return i
        return -1

    def load_fold_data(self, fold_files=None):
        """Read tab-separated fold files, add ``age_class`` and drop rows without a valid class.

        Files that fail to load are reported and skipped. If none load, an
        empty DataFrame with the expected columns is returned.
        """
        if fold_files is None:
            fold_files = self.fold_files
        all_data = []
        column_names = [
            'user_id', 'original_image', 'face_id', 'age', 'gender',
            'x', 'y', 'dx', 'dy', 'tilt_ang', 'fiducial_yaw_angle', 'fiducial_score'
        ]
        for fold_file in fold_files:
            try:
                df = pd.read_csv(fold_file, sep='\t', header=None, names=column_names)
                df['age_class'] = df['age'].apply(self.get_age_class)
                df = df[df['age_class'] != -1]
                all_data.append(df)
            except Exception as e:
                print(f"Error loading {fold_file}: {e}")
        return pd.concat(all_data, ignore_index=True) if all_data else pd.DataFrame(columns=column_names + ['age_class'])

    def get_image_path(self, row):
        """Path of the aligned face image for a fold row: ``<faces_dir>/<user_id>/coarse_tilt_aligned_face.<face_id>.<original_image>``."""
        filename = f"coarse_tilt_aligned_face.{row['face_id']}.{row['original_image']}"
        return os.path.join(self.faces_dir, str(row['user_id']), filename)

    def detect_face(self, image_np):
        """Run MTCNN on a BGR array and return the first box as ``(x, y, w, h)`` in pixels, or None."""
        img_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        img_pil = Image.fromarray(img_rgb)

        boxes, _ = self.mtcnn.detect(img_pil)
        if boxes is not None and len(boxes) > 0:
            x1, y1, x2, y2 = boxes[0]
            return (x1, y1, x2 - x1, y2 - y1)
        return None

    def is_dataset_complete(self, size_dir):
        """
        Check if the dataset in ``size_dir`` is complete and ready.

        Requires ``data.yaml``, ``classes.txt`` and the image/label folders. For
        both splits, the folder must hold at least one ``.jpg`` and as many
        ``.txt`` labels as images.
        """
        if size_dir is None:
            return False
        root = Path(size_dir)
        required = ["data.yaml", "classes.txt", "images/train", "images/val", "labels/train", "labels/val"]
        if not all((root / rel).exists() for rel in required):
            return False

        for split in ("train", "val"):
            n_imgs = sum(1 for _ in (root / "images" / split).glob("*.jpg"))
            n_lbls = sum(1 for _ in (root / "labels" / split).glob("*.txt"))
            if n_imgs == 0 or n_imgs != n_lbls:
                return False
        return True

    @staticmethod
    def _to_yolo_box(face_coords, orig_width, orig_height, fallback=(0.5, 0.5, 0.8, 0.8)):
        """Clamp a pixel ``(x, y, w, h)`` box to the image and normalize it to YOLO ``(xc, yc, w, h)``.

        Returns ``fallback`` when ``face_coords`` is None (no face detected).
        """
        if face_coords is None:
            return fallback
        x, y, w, h = face_coords
        x = max(0, min(x, orig_width))
        y = max(0, min(y, orig_height))
        w = max(0, min(w, orig_width - x))
        h = max(0, min(h, orig_height - y))
        return (
            (x + w / 2) / orig_width,
            (y + h / 2) / orig_height,
            w / orig_width,
            h / orig_height,
        )

    def process_dataset(self, data, img_dir, label_dir):
        """Resize each row's face image to ``img_size`` x ``img_size`` and write its YOLO label.

        Boxes are normalized against the original image size, so they stay
        valid after the square resize. Rows whose image is missing, or whose
        class id is out of range, are skipped with a message. Errors on a
        single row are printed and processing continues.
        """
        os.makedirs(img_dir, exist_ok=True)
        os.makedirs(label_dir, exist_ok=True)

        transform = transforms.Resize((self.img_size, self.img_size))
        n_classes = len(self.age_categories)

        for idx, row in tqdm(data.iterrows(), total=len(data), desc="Processing images"):
            try:
                img_path = self.get_image_path(row)
                if not os.path.exists(img_path):
                    print(f"Warning: Image not found: {img_path}")
                    continue

                class_id = int(row['age_class'])
                # Validate before writing anything so no image is saved without a label.
                if not (0 <= class_id < n_classes):
                    print(f"Invalid age class at index {idx}: {class_id}")
                    continue

                filename = os.path.basename(img_path).replace('coarse_tilt_aligned_face.', '')
                base_filename = f"{idx}_{filename.split('.')[0]}"

                with Image.open(img_path) as src:
                    img = src.convert('RGB')
                orig_width, orig_height = img.size

                face_coords = self.detect_face(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
                x_center, y_center, width_norm, height_norm = self._to_yolo_box(face_coords, orig_width, orig_height)

                transform(img).save(os.path.join(img_dir, f"{base_filename}.jpg"))

                label_path = os.path.join(label_dir, f"{base_filename}.txt")
                with open(label_path, 'w') as f:
                    f.write(f"{class_id} {x_center:.6f} {y_center:.6f} {width_norm:.6f} {height_norm:.6f}\n")

            except Exception as e:
                print(f"Error at index {idx}: {e}")

    def create_data_yaml(self, output_dir):
        """Write ``data.yaml`` (absolute train/val paths, class count, names from ``classes.txt``)."""
        yaml_path = os.path.join(output_dir, 'data.yaml')
        with open(yaml_path, 'w') as f:
            train_dir = os.path.abspath(os.path.join(output_dir, "images/train"))
            val_dir = os.path.abspath(os.path.join(output_dir, "images/val"))
            f.write(f"train: {train_dir}\n")
            f.write(f"val: {val_dir}\n")
            f.write(f"nc: {len(self.age_categories)}\n")
            f.write("names:\n")
            classes_path = os.path.join(output_dir, 'classes.txt')
            with open(classes_path, 'r') as cf:
                for i, line in enumerate(cf):
                    f.write(f"  {i}: '{line.strip()}'\n")

    def create_yolo_dataset(self, train_folds, val_fold, output_dir=None):
        """Build a YOLO dataset from the given fold indices unless one is already complete.

        Args:
            train_folds: Indices into ``fold_files`` used for training.
            val_fold: Index into ``fold_files`` used for validation.
            output_dir: Target folder; defaults to ``self.output_dir``.
        """
        output_dir = output_dir or self.output_dir
        if self.is_dataset_complete(output_dir):
            print(f"✅ {output_dir} already complete. Skipping...")
            return

        img_train, img_val = os.path.join(output_dir, 'images/train'), os.path.join(output_dir, 'images/val')
        lbl_train, lbl_val = os.path.join(output_dir, 'labels/train'), os.path.join(output_dir, 'labels/val')

        for d in [img_train, img_val, lbl_train, lbl_val]:
            os.makedirs(d, exist_ok=True)

        with open(os.path.join(output_dir, 'classes.txt'), 'w') as f:
            for (low, high) in self.age_categories:
                f.write(f"age_{low}_{high}\n")

        train_data = self.load_fold_data([self.fold_files[i] for i in train_folds])
        val_data = self.load_fold_data([self.fold_files[val_fold]])

        print(f"Train images: {len(train_data)}, Val images: {len(val_data)}")

        self.process_dataset(train_data, img_train, lbl_train)
        self.process_dataset(val_data, img_val, lbl_val)
        # data.yaml is written last, so an interrupted build is never "complete".
        self.create_data_yaml(output_dir)
        print(f"✅ Dataset ready at: {output_dir}")


In [None]:
class YOLOMultiTrainer:
    """Train one YOLO configuration several times with Optuna-tuned params and evaluate each run.

    Accuracy counts an image as correct when the class of its
    highest-confidence detection is among the image's ground-truth classes.
    """

    def __init__(self,
                 data_yaml,
                 model_size='n',
                 model_version='8',
                 device='0',
                 project='runs/multi_runs'):
        """Load the dataset config and create the project folder.

        Args:
            data_yaml: Path to the dataset's ``data.yaml`` (must contain ``names`` and ``val``).
            model_size: Model size letter passed to ``get_model_filename``.
            model_version: YOLO version passed to ``get_model_filename``.
            device: Device string forwarded to Ultralytics.
            project: Folder that run directories are created in.
        """
        self.data_yaml = data_yaml
        self.model_size = model_size
        self.model_version = model_version
        self.device = device
        self.project = project
        self.data_config = self._load_data_config()
        self.class_names = self.data_config['names']
        # Filled by load_study_best_params(); train_once() requires them.
        self.best_params = None
        self.best_value = None
        os.makedirs(self.project, exist_ok=True)

    def _load_data_config(self):
        with open(self.data_yaml, 'r') as f:
            return yaml.safe_load(f)

    def load_study_best_params(self, study_path=None, study_name=None, db_path=None):
        """Load the best params/value of an Optuna study.

        Args:
            study_path: Path to a SQLite file on disk. ``study_name`` selects the
                study; when omitted, the file must contain exactly one study.
            study_name: Study name (required together with ``db_path``).
            db_path: Optuna storage URL, e.g. ``sqlite:///YOLO.db``.

        Returns:
            tuple: ``(best_params, best_value)``.

        Raises:
            ValueError: If neither source is given, or the study to load from
                ``study_path`` is ambiguous.
        """
        if study_path and os.path.exists(study_path):
            # Optuna needs a storage URL, not a bare file path.
            storage = f"sqlite:///{os.path.abspath(study_path)}"
            name = study_name
            if name is None:
                names = [s.study_name for s in optuna.get_all_study_summaries(storage=storage)]
                if len(names) != 1:
                    raise ValueError(f"{study_path} contains {len(names)} studies; pass study_name to choose one: {names}")
                name = names[0]
            print(f"Loading Optuna study from file: {study_path}")
            study = optuna.load_study(study_name=name, storage=storage)
        elif db_path and study_name:
            print(f"Loading Optuna study from DB: {db_path}, Study name: {study_name}")
            study = optuna.load_study(study_name=study_name, storage=db_path)
        else:
            raise ValueError("Provide study_path or (db_path + study_name)")

        self.best_params = study.best_params
        self.best_value = study.best_value
        return self.best_params, self.best_value

    def _set_random_seed(self):
        """Seed Python, NumPy and torch with a fresh random seed and return it."""
        seed = random.randint(0, 10000)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
        return seed

    def train_once(self, run_id, epochs=100, base_name=None):
        """Train one model with the loaded best params.

        Returns:
            tuple: ``(model, run_dir, training_time_seconds)``. ``run_dir`` is
            the folder Ultralytics actually wrote to (it may append a suffix
            when the name already exists).

        Raises:
            RuntimeError: If ``load_study_best_params`` has not been called.
        """
        if getattr(self, 'best_params', None) is None:
            raise RuntimeError("No hyperparameters loaded; call load_study_best_params() first")

        seed = self._set_random_seed()
        if base_name is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            base_name = f"train_{timestamp}"
        run_name = f"{base_name}_run{run_id}"

        # Checkpoint names differ by version (yolov8n.pt, yolov9t.pt, yolo11n.pt, ...).
        model_path = get_model_filename(self.model_version, self.model_size)
        model = YOLO(model_path)

        params = dict(self.best_params)
        params.setdefault('imgsz', 416)
        params['optimizer'] = 'AdamW'
        params['seed'] = seed
        params['val'] = False
        params['deterministic'] = False
        params['batch'] = 32

        start_time = time.time()
        model.train(
            data=self.data_yaml,
            epochs=epochs,
            device=self.device,
            project=self.project,
            name=run_name,
            verbose=False,
            **params
        )
        training_time = time.time() - start_time

        save_dir = getattr(getattr(model, 'trainer', None), 'save_dir', None)
        run_dir = str(save_dir) if save_dir else os.path.join(self.project, run_name)

        return model, run_dir, training_time

    def _get_validation_images(self):
        """List validation images from ``data.yaml``'s ``val`` entry (a folder, or a text file of paths)."""
        val_path = self.data_config['val']
        if not os.path.isabs(val_path):
            val_path = os.path.join(os.path.dirname(self.data_yaml), val_path)
        if os.path.isdir(val_path):
            return [os.path.join(root, file)
                    for root, _, files in os.walk(val_path)
                    for file in files if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
        elif os.path.isfile(val_path):
            with open(val_path, 'r') as f:
                return [line.strip() for line in f]
        else:
            raise ValueError(f"Invalid val_path: {val_path}")

    @staticmethod
    def get_label_path(img_path):
        """Map ``.../images/<rest>.jpg`` to ``.../labels/<rest>.txt`` (or swap the suffix in place)."""
        img_path = Path(img_path)
        if 'images' in img_path.parts:
            idx = img_path.parts.index('images')
            label_path = Path(*img_path.parts[:idx], 'labels', *img_path.parts[idx+1:]).with_suffix('.txt')
        else:
            label_path = img_path.with_suffix('.txt')
        return str(label_path)

    def _load_ground_truth(self, img_path):
        """Class ids listed in the image's label file; empty when there is no label file."""
        label_path = self.get_label_path(img_path)
        if not os.path.exists(label_path):
            return []
        with open(label_path, 'r') as f:
            return [int(float(line.split()[0])) for line in f if line.strip()]

    @staticmethod
    def _top_class(model, img_path, conf_threshold):
        """Class id of the highest-confidence detection, or None if nothing passes ``conf_threshold``."""
        results = model(img_path, conf=conf_threshold, verbose=False)[0]
        preds = results.boxes.data.cpu().numpy()
        if len(preds) == 0:
            return None
        # Column 4 is used as the confidence and column 5 as the class id.
        return int(preds[preds[:, 4].argmax()][5])

    def evaluate_accuracy(self, model, conf_threshold=0.1):
        """Fraction of labeled validation images whose top detection has a ground-truth class."""
        images = self._get_validation_images()
        correct, total = 0, 0
        for img_path in tqdm(images, desc="Evaluating accuracy"):
            gt_classes = self._load_ground_truth(img_path)
            if not gt_classes:
                continue
            if self._top_class(model, img_path, conf_threshold) in gt_classes:
                correct += 1
            total += 1
        return correct / total if total else 0

    def generate_confusion_matrix(self, model, output_dir, conf_threshold=0.1):
        """Build and save raw/normalized confusion matrices (PNG) plus the raw matrix (CSV).

        Rows are ground-truth classes. Columns are predicted classes, plus a
        final column counting images where nothing was detected. Class ids
        outside ``[0, n_classes)`` are reported and skipped.
        """
        os.makedirs(output_dir, exist_ok=True)
        images = self._get_validation_images()
        n_classes = len(self.class_names)
        cm = np.zeros((n_classes, n_classes + 1), dtype=int)

        for img_path in tqdm(images, desc="Generating confusion matrix"):
            gt_classes = self._load_ground_truth(img_path)
            if not gt_classes:
                continue

            # One prediction per image, shared by all of its ground-truth boxes.
            pred_class = self._top_class(model, img_path, conf_threshold)
            if pred_class is None:
                col = n_classes
            elif 0 <= pred_class < n_classes:
                col = pred_class
            else:
                print(f"⚠️ Skipping {img_path}: predicted class {pred_class} out of range")
                continue

            for gt_class in gt_classes:
                if 0 <= gt_class < n_classes:
                    cm[gt_class, col] += 1
                else:
                    print(f"⚠️ Skipping label class {gt_class} in {img_path}: out of range")

        self.plot_confusion_matrix(cm, output_path=os.path.join(output_dir, 'confusion_matrix_raw.png'), normalize=False)
        self.plot_confusion_matrix(cm, output_path=os.path.join(output_dir, 'confusion_matrix_normalized.png'), normalize=True)

        raw_cm_path = os.path.join(output_dir, 'confusion_matrix_raw.csv')
        pd.DataFrame(cm).to_csv(raw_cm_path, index=False)

    @staticmethod
    def _row_normalize(cm):
        """Divide each row by its sum; rows that sum to zero stay all-zero instead of NaN."""
        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)

    def _class_name_list(self):
        """Class names in id order (``names`` may be a dict keyed by id or a list)."""
        names = self.class_names
        if isinstance(names, dict):
            return [str(names[k]) for k in sorted(names)]
        return [str(n) for n in names]

    def plot_confusion_matrix(self, cm, output_path, normalize=False):
        """Save a heatmap of ``cm``, labeling axes with class names when the shape matches."""
        cm = np.asarray(cm)
        if normalize:
            cm = self._row_normalize(cm)

        names = self._class_name_list()
        x_labels = names + ['none'] if cm.shape[1] == len(names) + 1 else 'auto'
        y_labels = names if cm.shape[0] == len(names) else 'auto'

        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues',
                    xticklabels=x_labels, yticklabels=y_labels)
        plt.xlabel('Predicted')
        plt.ylabel('Ground Truth')
        plt.title('Confusion Matrix (Normalized)' if normalize else 'Confusion Matrix (Raw)')
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        plt.close()

    def run_multiple_trainings(self,
                               study_path=None,
                               study_name=None,
                               db_path=None,
                               num_runs=5,
                               epochs=100,
                               conf_threshold=0.25):
        """Load the study's best params, then train and evaluate ``num_runs`` times.

        Writes ``metrics_summary.json`` into each run folder and
        ``<project>/multi_run_<timestamp>_metrics.csv`` overall.

        Returns:
            pd.DataFrame: One row per run (``run_id``, ``training_time_sec``, ``accuracy``).
        """
        self.load_study_best_params(study_path, study_name, db_path)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        base_name = f"multi_run_{timestamp}"

        all_metrics = []

        for run_id in range(1, num_runs + 1):
            model, run_dir, training_time = self.train_once(run_id, epochs, base_name)
            acc = self.evaluate_accuracy(model, conf_threshold)
            self.generate_confusion_matrix(model, output_dir=os.path.join(run_dir, 'confusion_matrix'), conf_threshold=conf_threshold)

            run_metrics = {
                'run_id': run_id,
                'training_time_sec': training_time,
                'accuracy': acc
            }
            all_metrics.append(run_metrics)

            with open(os.path.join(run_dir, 'metrics_summary.json'), 'w') as f:
                json.dump(run_metrics, f, indent=4)

        df = pd.DataFrame(all_metrics)
        df.to_csv(os.path.join(self.project, f"{base_name}_metrics.csv"), index=False)

        print(df.describe())
        return df

## Download data

### Session

In [None]:
# A single HTTP session shared by every download below. It sends a
# browser-style User-Agent and a Referer pointing at the dataset root.
session = requests.Session()
session.headers["User-Agent"] = "Mozilla/5.0"
session.headers["Referer"] = BASE_URL

### Folds

In [None]:
# Download the five fold annotation files into DATA_DIR, skipping any already present.
# DATA_DIR must exist before open(dest, 'wb'); the dataset creator that would
# otherwise create it is only constructed in a later cell.
os.makedirs(DATA_DIR, exist_ok=True)

fold_files = [f"fold_{i}_data.txt" for i in range(5)]

for fname in fold_files:
    url = BASE_URL + fname
    dest = os.path.join(DATA_DIR, fname)

    if os.path.exists(dest):
        print(f"{dest} already exist")
        continue

    print(f"Downloading {url}")

    # Without a timeout, a stalled server would hang the cell forever.
    response = session.get(url, auth=HTTPBasicAuth(USERNAME, PASSWORD), timeout=60)

    if response.status_code == 200:
        with open(dest, 'wb') as f:
            f.write(response.content)
        print(f"Saved: {dest}")
    else:
        print(f"Failed: {url} (Status: {response.status_code})")

### Faces

In [None]:
# Stream faces.tar.gz into a temporary ".part" file and rename it to
# ARCHIVE_PATH only after the whole body has arrived. An interrupted download
# therefore never shows up as an existing (but truncated) archive on the next run.
if os.path.exists(ARCHIVE_PATH):
    print(f"{ARCHIVE_PATH} already exist")
else:
    os.makedirs(os.path.dirname(ARCHIVE_PATH) or '.', exist_ok=True)
    print(f"\nDownloading archive: {ARCHIVE_URL}")
    partial_path = ARCHIVE_PATH + ".part"

    # Using the response as a context manager releases the streamed connection.
    with session.get(ARCHIVE_URL, auth=HTTPBasicAuth(USERNAME, PASSWORD), stream=True, timeout=60) as response:
        if response.status_code == 200:
            total_size = int(response.headers.get('content-length', 0))
            chunk_size = 8192

            with open(partial_path, 'wb') as f, tqdm(
                desc="Downloading faces.tar.gz",
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))

            os.replace(partial_path, ARCHIVE_PATH)
            print(f"Downloaded: {ARCHIVE_PATH}")
            session.close()
        else:
            print(f"Failed to download archive (Status: {response.status_code})")

## Preprocess data

In [None]:
# One shared creator used below for loading folds, extracting faces and exporting YOLO datasets.
creator = FaceAgeDatasetCreator(
    max_workers=NUM_WORKERS,
    base_dir=DATA_DIR,
)

### Folds

In [None]:
# Paths of the annotation folds the creator reads.
fold_files = creator.fold_files
print("Looking for fold files:", fold_files)

In [None]:
# All fold rows whose age maps to a known bracket (others are dropped by the creator).
data = creator.load_fold_data()
print("Loaded", len(data), "records from fold files")

In [None]:
if len(data):
    print("\nSample data:")
    display(data)

    # Bar chart: number of faces per age class, in class-id order.
    plt.figure(figsize=(12, 6))
    data['age_class'].value_counts(sort=False).sort_index().plot(kind='bar')
    plt.title('Distribution of Age Classes')
    plt.xlabel('Age Class')
    plt.ylabel('Count')
    plt.xticks(rotation=0)
    plt.grid(axis='y', alpha=0.3)
    plt.show()

### Extract faces

In [None]:
# Unpack faces.tar.gz into the creator's faces folder (skipped if that folder already exists).
creator.extract_faces_archive()

### Generate dataset

In [None]:
# TODO: Optionally rebuild the dataset or download a preprocessed copy instead, e.g.:
# if REBUILD_DATASET:
#     # TODO: delete the existing dataset directory first
#     creator.create_yolo_dataset(train_folds=[0, 1, 2, 3], val_fold=4)
# else:
#     # if it already exists, or similar
#     creator.download_preprocessed_dataset()  # TODO: host it on Google Drive or similar

# Tuning dataset: folds 0-3 for training, fold 4 held out for validation.
creator.create_yolo_dataset(
    output_dir="data/age_dataset_tune",
    train_folds=list(range(4)),
    val_fold=4,
)


In [None]:
# Preview one labeled sample from the tuning dataset built in the previous cell.
# data/age_dataset_test is only generated further down, so reading a hard-coded
# file from it here fails on a fresh run; use the first available val label instead.
_sample_root = Path("data/age_dataset_tune")
_sample_labels = sorted((_sample_root / "labels" / "val").glob("*.txt"))
if _sample_labels:
    _sample_label = _sample_labels[0]
    _sample_image = _sample_root / "images" / "val" / _sample_label.with_suffix(".jpg").name
    show_sample_with_bbox(img_path=str(_sample_image), label_path=str(_sample_label), creator=creator)
else:
    print(f"No labels found under {_sample_root / 'labels' / 'val'}")

## Hyperparameter tuning

### Tune

In [None]:
if not RUN_OPTUNA:
    print("Config RUN_OPTUNA is False")
else:
    run_optuna_tuning_multi(
        base_dataset_dir='data/age_dataset_tune',
        model_sizes=['n'],      # use MODEL_SIZES to sweep every size
        model_versions=['8'],   # use MODEL_VERSIONS to sweep every version
        imgsz=IMAGE_SIZE,
        n_trials=100,           # per (version, size) combination
        epochs_per_trial=30,
        device='0',
        output_base='runs/age_exp',
    )

### Merge Optuna databases

In [None]:
# Per-machine Optuna databases to combine into a single study store.
source_databases = [
    "sqlite:///YOLO_NB_LOCAL_1.db",
    "sqlite:///YOLO_NB_SERVER.db",
    "sqlite:///YOLO_NB_SERVER_1.db",
    "sqlite:///YOLO.db",
]
target_database = "sqlite:///merged.db"

if not MERGE_DB:
    print("Config MERGE_DB is False")
else:
    merge_optuna_databases(source_databases, target_database)
    merged_db_paths = [target_database]

    df = extract_study_trials_to_dataframe(merged_db_paths)
    display(df)


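Manual step: review the tuned hyperparameters (for example in the merged trials table above), then set the chosen study name and database in the training cell below.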

## Train

### Generate training dataset 

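This dataset uses a different fold split than the tuning dataset: folds 0, 1, 2 and 4 are used for training and fold 3 for validation (tuning held out fold 4).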

In [None]:
# Training/evaluation dataset: fold 3 held out for validation (the tuning set held out fold 4).
creator.create_yolo_dataset(
    output_dir="data/age_dataset_test",
    train_folds=[0, 1, 2, 4],
    val_fold=3,
)

In [None]:
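# Command-line alternative to the cell below, via multiple_runs.py. Note that it differs from the cell:
# it uses --model-size n and data/age_dataset_test2, whereas the cell below trains model_size='m' on data/age_dataset_test.
# python multiple_runs.py --data data/age_dataset_test2/data.yaml --model-size n --db-path sqlite:///YOLO.db  --study-name "3 model size n" --num-runs 5 --epochs 30 --device 0 --project runs/multi_runs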

In [None]:
# NOTE(review): the study loaded below is named for model size "n", but this
# trainer uses size "m" — confirm that reusing size-n hyperparameters for the
# medium model is intended.
trainer = YOLOMultiTrainer(
    'data/age_dataset_test/data.yaml',
    model_version='8',
    model_size='m',
    project='runs/multi_runs',
    device='0',
)

trainer.run_multiple_trainings(
    db_path='sqlite:///YOLO.db',
    study_name='3 model size n',
    epochs=30,
    num_runs=5,
    conf_threshold=0.25,
)