## ParkItRight

#### Import the required libraries

In [2]:
import os
import random
import shutil
import torch
from ultralytics import YOLO

#### Clear the previously generated training, validation and test data

In [None]:
# Empty every split/label folder (keeping the `.gitkeep` placeholders) and
# delete the per-split cache files, so the next split starts from scratch.
folders_to_clear = ["data/train/correct", "data/train/incorrect", "data/test/correct", "data/test/incorrect", "data/val/correct", "data/val/incorrect"]
for folder in folders_to_clear:
    # Recreate a missing folder instead of crashing: the copy step expects
    # every data/<split>/<label> folder to exist.
    os.makedirs(folder, exist_ok=True)
    for item in os.listdir(folder):
        if item == ".gitkeep":
            continue
        item_path = os.path.join(folder, item)
        # os.unlink() cannot remove directories, so use rmtree for those.
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            shutil.rmtree(item_path)
        else:
            os.unlink(item_path)
# Cache files that sit next to each split folder (presumably written by
# Ultralytics when it scans a split — confirm). test.cache is included
# because the test folder is cleared as well.
for cache_file in ["data/train.cache", "data/val.cache", "data/test.cache"]:
    if os.path.exists(cache_file):
        os.unlink(cache_file)

#### Copy the raw data into random splits: 70% training, 20% validation and 10% test data

In [21]:
# Randomly distribute the raw images of each label into data/<split>/<label>.
labels = ["correct", "incorrect"]
# Fractions of each label's files for the train and val splits; the
# remainder (about 10%, plus any rounding leftovers) becomes the test split.
split_ratios = [0.7, 0.2]
for label in labels:
    folder_path = os.path.join("data_raw", label)
    # Keep only regular, non-hidden files: placeholders like .gitkeep or OS
    # metadata like .DS_Store would skew the split sizes, and a sub-directory
    # would make shutil.copy() fail. Sorting first makes the shuffle
    # reproducible when `random` is seeded, since os.listdir() order is arbitrary.
    folder_content = sorted(
        name
        for name in os.listdir(folder_path)
        if not name.startswith(".") and os.path.isfile(os.path.join(folder_path, name))
    )
    random.shuffle(folder_content)
    split_sizes = [int(len(folder_content) * ratio) for ratio in split_ratios]
    train_end = split_sizes[0]
    val_end = train_end + split_sizes[1]

    splits = {
        "train": folder_content[:train_end],
        "val": folder_content[train_end:val_end],
        "test": folder_content[val_end:],
    }

    for split_name, split_content in splits.items():
        dest_dir = os.path.join("data", split_name, label)
        # The copy below needs the destination folder to exist.
        os.makedirs(dest_dir, exist_ok=True)
        for file in split_content:
            src_path = os.path.join(folder_path, file)
            dest_path = os.path.join(dest_dir, file)
            shutil.copy(src_path, dest_path)

#### Select the best available device for the Torch computations

In [None]:
# Prefer NVIDIA CUDA, then Apple MPS, and fall back to the CPU. The chained
# conditional only queries MPS when CUDA is unavailable.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print({
    "cuda": "Cuda is available. Torch will use Cuda.",
    "mps": "MPS is available. Torch will use MPS.",
    "cpu": "GPU is not available. Torch will fall back to CPU.",
}[device])

#### Train the YOLO model

In [None]:
# Load the YOLO11s classification checkpoint and train it on the "data"
# folder filled by the split step, using the `device` selected earlier.
model = YOLO("base_model/YOLO11s-cls.pt")
training_args = dict(data="data", epochs=30, imgsz=640, device=device)
results = model.train(**training_args)