In [None]:
import os
import time
import torch
from ultralytics import YOLO
import json

# Path to the dataset's data.yaml, resolved as
# <parent of the current working directory>/Versions/MDD-AFL-Yolov8/data.yaml.
# Because it is relative to the CWD, the notebook must be started from a
# directory that sits next to "Versions" (presumably a sibling folder in the
# repo — verify for your checkout).
data_yaml_path = os.path.join(
    os.path.abspath(os.path.join(os.getcwd(), os.pardir)),
    "Versions",
    "MDD-AFL-Yolov8",
    "data.yaml",
)

# Hyperparameters passed to model.train()
epochs = 50
imgsz = 640  # input image size
batch_size = 8
experiment_name = "MDD-AFL-Yolov8_TEST"  # run name used for the training output

def select_device():
    """Return the preferred torch device string: "cuda", then "mps", then "cpu".

    Prints a one-line message describing which device was chosen.
    """
    if torch.cuda.is_available():
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        return "cuda"  # Use GPU with CUDA
    # torch.backends.mps only exists in PyTorch >= 1.12; look it up defensively
    # so older installs fall back to the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using MPS (Metal Performance Shaders) for acceleration.")
        return "mps"  # Use Apple Metal Performance Shaders
    print("Using CPU. Consider enabling GPU or MPS for faster training.")
    return "cpu"  # Fallback to CPU


# Device string handed to model.train()
device = select_device()

# Build a YOLOv8-nano model from its architecture file only ("yolov8n.yaml"
# names the architecture, not a checkpoint), so training starts from scratch.
try:
    model = YOLO("yolov8n.yaml")
except FileNotFoundError:
    print("Error: YOLOv8 configuration file 'yolov8n.yaml' not found. Check your setup.")
    raise

# Time the training run with a monotonic clock: time.time() follows the wall
# clock and can jump (NTP sync, manual changes), which would skew the result.
start_time = time.perf_counter()

# Train; on failure, report the error and re-raise so the cell stops here.
try:
    model.train(
        data=data_yaml_path,
        epochs=epochs,
        imgsz=imgsz,
        batch=batch_size,
        name=experiment_name,
        pretrained=False,  # do not load pretrained weights; train from scratch
        device=device,  # device chosen by the device-selection cell
    )
except Exception as e:
    print(f"Error during training: {e}")
    raise

end_time = time.perf_counter()
print(f"Training completed in {end_time - start_time:.2f} seconds.")

# Evaluate the trained model and save the headline metrics next to data.yaml.
try:
    # Pass data/device explicitly rather than relying on values carried over
    # from the training run.
    metrics = model.val(data=data_yaml_path, device=device)

    # model.val() returns an ultralytics metrics object (e.g. DetMetrics for a
    # detection model), which json cannot serialize. Its results_dict property
    # is a plain {metric name: value} mapping. Values are cast to float
    # because they may be NumPy scalars.
    results = {name: float(value) for name, value in metrics.results_dict.items()}
    print("Validation Metrics:", results)

    # Write into the same directory as data.yaml (Versions/MDD-AFL-Yolov8).
    # The payload is built first, so a failure cannot leave a truncated file.
    metrics_path = os.path.join(os.path.dirname(data_yaml_path), "validation_metrics.json")
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    print(f"Validation metrics saved to: {metrics_path}")
except Exception as e:
    print(f"Error during validation: {e}")
    raise