In [None]:
# [INFO] Install necessary packages
!pip install -q ultralytics segmentation-models

import os
import sys
import time
import shutil
import yaml
import random
import re
import multiprocessing
import numpy as np
import pandas as pd
import cv2
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
from ultralytics import YOLO
from google.colab import drive

# Mount Drive.
# Check MyDrive rather than the bare mount point: /content/drive may be left behind
# without an active mount (e.g. after an unmount), which would skip mounting and
# make every Drive path below fail.
if not os.path.isdir('/content/drive/MyDrive'):
    drive.mount('/content/drive')

# [INFO] Reproducibility Setup
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (plus all CUDA devices if present).

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

# [INFO] Directory Configuration
BASE_SAVE_DIR = "/content/drive/MyDrive/Model/Final_Experiments_V2"
DRIVE_YAML_PATH = "/content/drive/MyDrive/Dataset/FINAL_YOLO_SPLIT/dataset.yaml"
# Training reads the dataset.yaml produced by the split step, which writes it into
# FINAL_YOLO_SPLIT on Drive. (There is no LOCAL_DATA_DIR defined in this notebook.)
DATASET_YAML = DRIVE_YAML_PATH

# Ensure the experiment root exists before any run writes into it.
if not os.path.isdir(BASE_SAVE_DIR):
    os.makedirs(BASE_SAVE_DIR, exist_ok=True)

In [None]:
# Configuration: source dataset (pre-split) used by the bounding-box analysis
DATASET_DIR = '/content/drive/MyDrive/Dataset/PROCESSED_YOLO'
IMG_DIR, LABEL_DIR = (os.path.join(DATASET_DIR, sub) for sub in ("images", "labels"))

# YOLO class ID -> class name (the ID is the position in this list)
YOLO_ID_TO_NAME = dict(enumerate(["Brain", "CSP", "LV"]))

def get_box_statistics(img_path, label_dir=None, id_to_name=None):
    """Read the YOLO label file paired with `img_path` and return one record per box.

    The label file is `<label_dir>/<image stem>.txt`. Each line is read as
    `class x_center y_center width height` (normalized); only class, width and height are used.

    Args:
        img_path: Path to an image; only its basename/stem is used.
        label_dir: Folder holding the .txt label files. Defaults to LABEL_DIR.
        id_to_name: Class-ID -> name mapping. Defaults to YOLO_ID_TO_NAME; IDs not in the
            mapping are reported as "Unknown".

    Returns:
        list[dict]: one dict per box with keys filename, class_id, class_name,
        width_norm, height_norm, area_norm. Empty when the label file is missing or
        cannot be read/decoded. Malformed lines are skipped individually rather than
        discarding every box in the file.
    """
    if label_dir is None:
        label_dir = LABEL_DIR
    if id_to_name is None:
        id_to_name = YOLO_ID_TO_NAME

    basename = os.path.basename(img_path)
    label_path = os.path.join(label_dir, os.path.splitext(basename)[0] + ".txt")

    boxes = []
    try:
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue  # blank or truncated line
                # NOTE(review): polygon (segmentation) labels also have >= 5 fields, but then
                # parts[3:5] are coordinates, not width/height; assumes detection-format labels.
                try:
                    cls_id = int(parts[0])
                    norm_w = float(parts[3])
                    norm_h = float(parts[4])
                except ValueError:
                    continue  # skip only this malformed line

                boxes.append({
                    "filename": basename,
                    "class_id": cls_id,
                    "class_name": id_to_name.get(cls_id, "Unknown"),
                    "width_norm": norm_w,
                    "height_norm": norm_h,
                    "area_norm": norm_w * norm_h,  # normalized area
                })
    except (OSError, UnicodeDecodeError):
        # Missing/unreadable label file: treat the image as having no boxes (best-effort scan).
        return []
    return boxes

def _collect_boxes(image_paths):
    """Parse the label file of every image in parallel; return one flat list of box records (input order)."""
    workers = multiprocessing.cpu_count()
    # Batch many tiny label files per task to amortise inter-process overhead.
    chunksize = max(1, len(image_paths) // (workers * 4))
    all_boxes = []
    with ProcessPoolExecutor(max_workers=workers) as executor:
        # map() keeps input order, so the resulting DataFrame is deterministic.
        results = executor.map(get_box_statistics, image_paths, chunksize=chunksize)
        for boxes in tqdm(results, total=len(image_paths), desc="Scanning Labels"):
            all_boxes.extend(boxes)
    return all_boxes


def _plot_class_distribution(df):
    """Bar chart of object counts per class (Brain, CSP, LV), each bar annotated with its count."""
    class_order = ['Brain', 'CSP', 'LV']
    fig, ax = plt.subplots(figsize=(8, 6))
    # seaborn >= 0.13 deprecates `palette` without `hue`; mapping x onto hue with the
    # legend off gives the same per-bar viridis colouring.
    sns.countplot(
        data=df, x='class_name', order=class_order,
        hue='class_name', hue_order=class_order, palette='viridis', legend=False, ax=ax
    )
    ax.set(title="Class Distribution", xlabel="Class", ylabel="Object Count")
    for container in ax.containers:
        ax.bar_label(container, fmt='%d', fontsize=10, color='black', padding=2)
    plt.show()


def _plot_area_distribution(df_sub):
    """Per-class density histogram + KDE of normalized box area (CSP vs LV)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.histplot(
        data=df_sub,
        x='area_norm',
        hue='class_name',
        kde=True,
        element="step",
        stat="density",
        common_norm=False,  # each class normalized separately so shapes are comparable
        palette={'CSP': '#fc8d62', 'LV': '#66c2a5'},  # Orange & Teal Green
        alpha=0.3,
        ax=ax,
    )
    ax.set(title="Object Size Distribution (CSP vs LV)", xlabel="Normalized Area (W x H)", ylabel="Density")
    plt.show()


def _plot_width_vs_height(df_sub):
    """Scatter of normalized box width vs height with a 1:1 aspect-ratio reference line."""
    fig, ax = plt.subplots(figsize=(8, 8))
    sns.scatterplot(
        data=df_sub,
        x='width_norm',
        y='height_norm',
        hue='class_name',
        style='class_name',
        alpha=0.6,
        palette={'CSP': '#fc8d62', 'LV': '#2b8cbe'},  # Orange & Blue/Teal
        ax=ax,
    )
    ax.plot([0, 1], [0, 1], 'r-', alpha=0.4, linewidth=1)  # aspect ratio 1:1
    ax.set(
        title="Box Width vs Height Distribution",
        xlabel="Width (Normalized)",
        ylabel="Height (Normalized)",
        xlim=(0, 1.0),
        ylim=(0, 1.0),
    )
    ax.legend(title="Class Name")
    ax.grid(True)
    plt.show()


def run_dataset_analysis():
    """Scan all labels under IMG_DIR/LABEL_DIR, plot class and box-size statistics, print mean areas."""
    print("Starting Bounding Box Analysis (Normalized)")
    image_paths = sorted(glob(os.path.join(IMG_DIR, "*")))

    all_boxes = _collect_boxes(image_paths)
    if not all_boxes:
        print("No boxes found.")
        return

    df = pd.DataFrame(all_boxes)
    sns.set_theme(style="whitegrid")

    _plot_class_distribution(df)

    # Size plots only compare the two small structures.
    df_sub = df[df['class_name'].isin(['CSP', 'LV'])]
    if df_sub.empty:
        print("[WARN] No CSP/LV boxes found; skipping size plots.")
    else:
        _plot_area_distribution(df_sub)
        _plot_width_vs_height(df_sub)

    print("\n[INFO] Average Normalized Area:")
    print(df.groupby('class_name')['area_norm'].mean().round(4))

# Run Analysis
run_dataset_analysis()

In [None]:
SOURCE_DATA_DIR = "/content/drive/MyDrive/Dataset/PROCESSED_YOLO"
OUTPUT_DIR = "/content/drive/MyDrive/Dataset/FINAL_YOLO_SPLIT"
SPLIT_RATIOS = (0.70, 0.15, 0.15) # Train, Val, Test
CLASS_NAMES = ['Brain', 'CSP', 'LV']

# Fail fast on a mistyped ratio tuple: the split takes train/val from the first two
# ratios and the test split silently absorbs whatever subjects are left over.
assert len(SPLIT_RATIOS) == 3 and abs(sum(SPLIT_RATIOS) - 1.0) < 1e-6, \
    "SPLIT_RATIOS must be (train, val, test) and sum to 1"

def extract_subject_id(filename):
    """Return the subject (patient) key used to keep all of a patient's images in one split.

    Tries `Diverse_<digits>_` first (returns the digits), then `Patient<digits>`
    (returns e.g. "Patient12"); a filename matching neither is its own subject.
    """
    for pattern in (r"Diverse_(\d+)_", r"(Patient\d+)"):
        found = re.search(pattern, filename)
        if found:
            return found.group(1)
    return filename

def copy_file_worker(args):
    """Copy one file with metadata; `args` is unpacked positionally into `shutil.copy2`.

    Returns:
        True on success, otherwise the error message as a string (never raises an Exception).
    """
    try:
        shutil.copy2(*args)
    except Exception as err:
        return str(err)
    return True

def _group_images_by_subject(img_src):
    """Map subject ID -> list of image filenames found directly in `img_src` (sorted order)."""
    subject_map = {}
    # Sorted so the grouping (and hence the shuffled split) does not depend on os.listdir order.
    for fname in sorted(os.listdir(img_src)):
        # Case-insensitive, so files like "x.PNG" are not silently left out of every split.
        if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        subject_map.setdefault(extract_subject_id(fname), []).append(fname)
    return subject_map


def _assign_subjects_to_splits(subjects, seed):
    """Shuffle subject IDs with a private RNG and cut them by SPLIT_RATIOS into train/val/test."""
    ordered = sorted(subjects)
    # A dedicated RNG makes the split depend only on `seed`, not on how many other
    # random calls happened in the kernel since set_seed().
    random.Random(seed).shuffle(ordered)

    n = len(ordered)
    n_train = int(n * SPLIT_RATIOS[0])
    n_val = int(n * SPLIT_RATIOS[1])
    return {
        'train': ordered[:n_train],
        'val': ordered[n_train:n_train + n_val],
        'test': ordered[n_train + n_val:],  # remainder
    }


def _build_copy_tasks(splits, subject_map, img_src, lbl_src):
    """Create the split output folders and list (src, dst) copy pairs for each image and its label."""
    tasks = []
    for split_name, subj_list in splits.items():
        dst_img = os.path.join(OUTPUT_DIR, split_name, "images")
        dst_lbl = os.path.join(OUTPUT_DIR, split_name, "labels")
        os.makedirs(dst_img, exist_ok=True)
        os.makedirs(dst_lbl, exist_ok=True)

        for subj in subj_list:
            for fname in subject_map[subj]:
                tasks.append((os.path.join(img_src, fname), os.path.join(dst_img, fname)))
                txt_name = os.path.splitext(fname)[0] + ".txt"
                lbl_path = os.path.join(lbl_src, txt_name)
                # Images without a label file are still copied (no label file = no objects).
                if os.path.exists(lbl_path):
                    tasks.append((lbl_path, os.path.join(dst_lbl, txt_name)))
    return tasks


def split_dataset(seed=42):
    """Split SOURCE_DATA_DIR into train/val/test by subject and write OUTPUT_DIR/dataset.yaml.

    All images of one subject (see extract_subject_id) land in the same split to avoid
    patient-level leakage. The yaml is written only after every copy succeeded, because
    its presence is what marks the split as done.

    Args:
        seed: Seed for the subject shuffle; the same seed and source files give the same split.
    """
    print(f"[INFO] Initializing dataset split at {OUTPUT_DIR}...")
    img_src = os.path.join(SOURCE_DATA_DIR, "images")
    lbl_src = os.path.join(SOURCE_DATA_DIR, "labels")

    if not os.path.isdir(img_src):
        print(f"[ERROR] Source image directory not found: {img_src}")
        return

    subject_map = _group_images_by_subject(img_src)
    if not subject_map:
        print(f"[ERROR] No .png/.jpg/.jpeg images found in {img_src}.")
        return

    splits = _assign_subjects_to_splits(subject_map.keys(), seed)
    print(f"[INFO] Split Counts: Train={len(splits['train'])}, Val={len(splits['val'])}, Test={len(splits['test'])}")
    for split_name, subj_list in splits.items():
        if not subj_list:
            print(f"[WARN] '{split_name}' split has no subjects (only {len(subject_map)} subjects total).")

    tasks = _build_copy_tasks(splits, subject_map, img_src, lbl_src)

    with ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 12)) as exc:
        results = list(tqdm(exc.map(copy_file_worker, tasks, chunksize=32), total=len(tasks), desc="Copying Files"))

    failures = [(task[0], res) for task, res in zip(tasks, results) if res is not True]
    if failures:
        # Leave dataset.yaml unwritten so the next run redoes the split; the split is
        # deterministic, so files already copied land in the same split again.
        print(f"[ERROR] {len(failures)} file(s) failed to copy; dataset.yaml not written.")
        for src, err in failures[:10]:
            print(f"    {src}: {err}")
        return

    # Generate YAML
    yaml_data = {
        'path': OUTPUT_DIR,
        'train': 'train/images',
        'val': 'val/images',
        'test': 'test/images',
        'names': dict(enumerate(CLASS_NAMES))
    }
    with open(os.path.join(OUTPUT_DIR, 'dataset.yaml'), 'w') as f:
        yaml.dump(yaml_data, f, sort_keys=False)

    print("[INFO] Dataset split complete.")

# The split's dataset.yaml doubles as a "done" marker: only split when it is absent.
_split_yaml_path = os.path.join(OUTPUT_DIR, 'dataset.yaml')
if os.path.exists(_split_yaml_path):
    print("[INFO] Dataset split already exists.")
else:
    split_dataset()

In [None]:
DEVICE = 0 if torch.cuda.is_available() else 'cpu'
print(f"[INFO] Computation Device: {DEVICE}")

# Settings shared by every benchmark run; per-run dicts (augmentation profile / baseline)
# are merged on top of these by the training wrapper.
COMMON_TRAIN_ARGS = {
    'imgsz': 640,
    'epochs': 100,
    'patience': 15,
    'batch': 16,
    'workers': 4,
    'project': BASE_SAVE_DIR,   # runs are saved under BASE_SAVE_DIR/<run name>
    'device': DEVICE,
    'verbose': False,
    'exist_ok': True,           # reuse an existing run folder instead of creating a numbered copy
    'optimizer': 'AdamW',
    'lr0': 0.001,
    'lrf': 0.00001,             # final LR as a fraction of lr0
    'cos_lr': True,
    'cache': True
}

# Augmentation Profile
# Only training-time augmentation hyperparameters belong here. `augment` is intentionally
# not set: Ultralytics documents it as a predict/val test-time-augmentation flag, not a
# switch for training augmentation.
AUG_NAME = "Aug_Profile_Tuned"
AUG_PARAMS = {
    'degrees': 45.0,
    'shear': 5.0,
    'perspective': 0.0005,
    'translate': 0.2,
    'scale': 0.6,
    'fliplr': 0.5,
    'flipud': 0.0
}

MODELS = [
    'yolov5su.pt', 'yolov5mu.pt',
    'yolov8s.pt',  'yolov8m.pt',
    'yolo11s.pt',  'yolo11m.pt',
    'yolo12s.pt',  'yolo12m.pt'
]

In [None]:
def get_map50_from_csv(run_folder):
    """Return the highest validation mAP@50 logged in `run_folder/results.csv`.

    Args:
        run_folder: Training run directory containing results.csv.

    Returns:
        float: max of the 'metrics/mAP50(B)' column over all logged epochs, or 0.0 when
        the file is missing, empty or unparsable, the column is absent, or it holds no
        numeric values (e.g. a header-only CSV from a run that died in epoch 1).
    """
    csv_path = os.path.join(run_folder, 'results.csv')
    if not os.path.exists(csv_path):
        return 0.0

    try:
        # Fields may be space-padded (hence also the header strip below).
        df = pd.read_csv(csv_path, skipinitialspace=True)
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        return 0.0

    df.columns = [c.strip() for c in df.columns]
    col = 'metrics/mAP50(B)'
    if col not in df.columns:
        return 0.0

    # Coerce stray non-numeric cells to NaN; an all-NaN/empty column yields NaN -> 0.0.
    best = pd.to_numeric(df[col], errors='coerce').max()
    return 0.0 if pd.isna(best) else float(best)

def train_wrapper(run_name, model_weights, params):
    """Train one YOLO model and report its best validation mAP@50.

    Args:
        run_name: Run folder name under BASE_SAVE_DIR.
        model_weights: Weights file / model name passed to YOLO().
        params: Per-run training overrides merged on top of COMMON_TRAIN_ARGS.

    Returns:
        (score, run_dir, duration_seconds). On any training error: (0.0, None, 0.0).
        GPU memory is released on both the success and the failure path.
    """
    import gc

    print(f"[INFO] Training: {run_name}")
    # Later keys win: per-run params override shared defaults; name/data are always set here.
    args = {**COMMON_TRAIN_ARGS, **params, 'name': run_name, 'data': DATASET_YAML}

    model = None
    try:
        model = YOLO(model_weights)
        start = time.perf_counter()
        model.train(**args)
        duration = time.perf_counter() - start

        # Prefer the directory the trainer actually wrote to; fall back to the expected path.
        trainer = getattr(model, 'trainer', None)
        save_dir = getattr(trainer, 'save_dir', None)
        run_dir = str(save_dir) if save_dir else os.path.join(BASE_SAVE_DIR, run_name)

        score = get_map50_from_csv(run_dir)
        return score, run_dir, duration
    except Exception as e:
        print(f"[ERROR] Training failed for {run_name}: {e}")
        return 0.0, None, 0.0
    finally:
        # Cleanup to save VRAM, also after a failed run, so the next model starts clean.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# Zeroes Ultralytics' training-augmentation hyperparameters so the baseline is really
# un-augmented. (`augment=False` is a predict/val test-time-augmentation flag and does not
# turn training augmentation off.)
# NOTE(review): Ultralytics may still apply low-probability Albumentations transforms when
# that package is installed; confirm for the installed version.
NO_AUG_PARAMS = {
    'hsv_h': 0.0, 'hsv_s': 0.0, 'hsv_v': 0.0,
    'degrees': 0.0, 'translate': 0.0, 'scale': 0.0,
    'shear': 0.0, 'perspective': 0.0,
    'flipud': 0.0, 'fliplr': 0.0,
    'mosaic': 0.0, 'mixup': 0.0, 'copy_paste': 0.0,
}


def parse_model_variant(model_name):
    """Split a model stem such as 'yolov8s', 'yolov5su' or 'yolo11m' into (version, SIZE).

    Unrecognised names fall back to (first 6 chars, upper-cased last char).
    """
    match = re.match(r"(yolo(?:v)?\d+)([nsmlx])", model_name)
    if match:
        return match.group(1), match.group(2).upper()
    return model_name[:6], model_name[-1].upper()


def run_benchmark(aug_params, aug_name, models=None, baseline_params=None):
    """Train every model twice (baseline, then augmented) and collect mAP@50 and wall time.

    Args:
        aug_params: Training overrides for the augmented condition.
        aug_name: Label used in the augmented run names.
        models: Weights files to benchmark. Defaults to MODELS.
        baseline_params: Overrides for the baseline condition. Defaults to NO_AUG_PARAMS
            plus optimizer='auto'.

    Returns:
        pd.DataFrame with columns Model, Size, Version, Condition, mAP50, Time_Min
        (one row per model x condition).
    """
    print(f"\n[INFO] Starting Benchmark: Baseline vs. {aug_name}")
    if models is None:
        models = MODELS
    if baseline_params is None:
        # NOTE(review): the baseline also switches to optimizer='auto' while the augmented run
        # keeps COMMON_TRAIN_ARGS' AdamW/lr0, so the optimizer differs between conditions too.
        baseline_params = {**NO_AUG_PARAMS, 'optimizer': 'auto'}

    # (condition label, run-name suffix, training overrides), in execution order.
    conditions = (
        ('Baseline', 'Raw', baseline_params),
        ('Augmented', aug_name, aug_params),
    )

    results = []
    for model_file in models:
        model_name = model_file.replace('.pt', '')
        version, size = parse_model_variant(model_name)

        for condition, run_suffix, params in conditions:
            run_name = f"Benchmark_{model_name}_{run_suffix}"
            score, _, seconds = train_wrapper(run_name, model_file, params)
            results.append({
                'Model': model_name,
                'Size': size,
                'Version': version,
                'Condition': condition,
                'mAP50': score,
                'Time_Min': seconds / 60
            })

    return pd.DataFrame(results)

In [None]:
def plot_results(df):
    """Bar chart of mAP@50 per model, Baseline vs. Augmented; saved to BASE_SAVE_DIR and shown.

    Args:
        df: Benchmark table with columns Model, Version, Size, Condition, mAP50.
            It is not modified.
    """
    if df.empty:
        return

    labels = df['Version'] + '-' + df['Size']
    # A Version-Size label is only usable if it identifies a single model; otherwise seaborn
    # would merge different models into one (averaged) bar, so fall back to the model stem.
    if (df.groupby(labels)['Model'].nunique() > 1).any():
        labels = df['Model']
    # assign() returns a copy, so the caller's DataFrame does not gain a helper column.
    plot_df = df.assign(Display_Name=labels)

    fig, ax = plt.subplots(figsize=(14, 8))
    sns.barplot(
        data=plot_df,
        x='Display_Name',
        y='mAP50',
        hue='Condition',
        palette={'Baseline': 'gray', 'Augmented': 'firebrick'},
        ax=ax
    )

    ax.set_title("Impact of Data Augmentation on Detection Performance")
    ax.set_ylabel("mAP@50")
    ax.set_xlabel("Model Variant")
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', alpha=0.3)

    save_path = os.path.join(BASE_SAVE_DIR, "Benchmark_Comparison.png")
    # bbox_inches='tight' keeps axis labels and the legend from being clipped in the file.
    fig.savefig(save_path, bbox_inches='tight')
    print(f"[INFO] Plot saved to {save_path}")
    plt.show()

if __name__ == "__main__":
    # Train every model under both conditions, then persist, report and plot the results.
    df_results = run_benchmark(AUG_PARAMS, AUG_NAME)

    csv_path = os.path.join(BASE_SAVE_DIR, "final_benchmark_metrics.csv")
    df_results.to_csv(csv_path, index=False)

    ranked = df_results.sort_values(by='mAP50', ascending=False)
    print("\n[INFO] Final Results:")
    print(ranked)

    plot_results(df_results)