# In [14]:
import glob
import os
import shutil
import xml.etree.ElementTree as ET

import torch
import yaml
from tqdm import tqdm
from ultralytics import YOLO

# Project layout: every path is derived from the parent of the current
# working directory, i.e. this notebook is assumed to be launched from a
# folder one level below the project root.
PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
DATA_DIR, MODEL_DIR = (os.path.join(PROJECT_DIR, sub) for sub in ("data", "models"))
RAW_DIR, PROCESSED_DIR = (os.path.join(DATA_DIR, sub) for sub in ("raw", "processed"))

# Output locations must exist before conversion and training write into them.
for _output_dir in (PROCESSED_DIR, MODEL_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Function to convert XML annotations to YOLO format
def convert_annotation(xml_file, class_map, raw_dir=None, processed_dir=None):
    """Convert one Pascal VOC-style XML annotation into a YOLO label file.

    The label path mirrors ``xml_file``'s location relative to ``raw_dir``
    under ``processed_dir``, with the extension changed to ``.txt``. Each
    output line is ``class_id x_center y_center width height``, normalized
    by the image size from the ``<size>`` element. Box corners are clipped
    to the image first, so all normalized values stay within [0, 1], and
    boxes that are empty after clipping are dropped.

    Args:
        xml_file: Path of the XML annotation; must live under ``raw_dir``.
        class_map: Mapping from class name to integer class id. Objects
            whose name is missing or not in the map are skipped.
        raw_dir: Root of the raw dataset. Defaults to the module-level
            ``RAW_DIR``.
        processed_dir: Root of the processed dataset. Defaults to the
            module-level ``PROCESSED_DIR``.

    Returns:
        The path of the written label file.

    Raises:
        ValueError: if ``xml_file`` is not under ``raw_dir``, or the
            annotation lacks a ``<size>`` element or has a non-positive
            width/height.
    """
    # Resolved at call time so existing two-argument callers keep using the
    # module-level directories.
    if raw_dir is None:
        raw_dir = RAW_DIR
    if processed_dir is None:
        processed_dir = PROCESSED_DIR

    # Build the output path structurally rather than with str.replace, which
    # would also rewrite any other '.xml' / RAW_DIR substrings in the path.
    rel_path = os.path.relpath(os.path.splitext(xml_file)[0] + '.txt', raw_dir)
    if rel_path == os.pardir or rel_path.startswith(os.pardir + os.sep):
        raise ValueError(f"{xml_file} is not inside {raw_dir}")
    out_file = os.path.join(processed_dir, rel_path)

    root = ET.parse(xml_file).getroot()

    size = root.find('size')
    if size is None:
        raise ValueError(f"{xml_file}: missing <size> element")
    w = float(size.find('width').text)
    h = float(size.find('height').text)
    if w <= 0 or h <= 0:
        raise ValueError(f"{xml_file}: invalid image size {w}x{h}")

    lines = []
    for obj in root.iter('object'):
        name_node = obj.find('name')
        cls = name_node.text if name_node is not None else None
        if cls not in class_map:
            continue
        cls_id = class_map[cls]

        xmlbox = obj.find('bndbox')
        # Clip corners to the image so annotations that spill past the
        # border still yield normalized coordinates inside [0, 1].
        xmin = min(max(float(xmlbox.find('xmin').text), 0.0), w)
        xmax = min(max(float(xmlbox.find('xmax').text), 0.0), w)
        ymin = min(max(float(xmlbox.find('ymin').text), 0.0), h)
        ymax = min(max(float(xmlbox.find('ymax').text), 0.0), h)
        if xmax <= xmin or ymax <= ymin:
            continue  # empty or fully off-image box: nothing to learn from

        x_center = (xmin + xmax) / 2 / w
        y_center = (ymin + ymax) / 2 / h
        box_w = (xmax - xmin) / w
        box_h = (ymax - ymin) / h
        lines.append(f"{cls_id} {x_center:.6f} {y_center:.6f} {box_w:.6f} {box_h:.6f}\n")

    # Write only after every object parsed, so a malformed annotation never
    # leaves a half-written label file behind.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    with open(out_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)
    return out_file

# Discover the class vocabulary from the training annotations only. Names are
# sorted so the name -> id mapping is deterministic between runs.
classes = set()
for xml_file in glob.glob(os.path.join(RAW_DIR, "train", "*.xml")):
    for obj in ET.parse(xml_file).getroot().iter('object'):
        name_node = obj.find('name')
        # Skip nameless objects: a None entry would make sorted() fail when
        # compared against the string class names.
        if name_node is not None and name_node.text:
            classes.add(name_node.text)

class_map = {cls: idx for idx, cls in enumerate(sorted(classes))}

# Image extensions copied next to the converted labels (matched
# case-insensitively, so e.g. '.JPG' is included too).
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

# 'train' and 'valid' are always prepared. 'test' is optional; when present it
# is converted the same way so the dataset YAML can point at YOLO-format
# labels instead of the raw XML annotations.
subsets = ['train', 'valid']
if os.path.isdir(os.path.join(RAW_DIR, 'test')):
    subsets.append('test')

# Convert annotations and copy images
for subset in subsets:
    raw_subset_dir = os.path.join(RAW_DIR, subset)
    processed_subset_dir = os.path.join(PROCESSED_DIR, subset)
    os.makedirs(processed_subset_dir, exist_ok=True)

    for xml_file in tqdm(glob.glob(os.path.join(raw_subset_dir, "*.xml")), desc=f"Converting {subset} annotations"):
        convert_annotation(xml_file, class_map)

    image_files = []
    if os.path.isdir(raw_subset_dir):
        image_files = sorted(
            os.path.join(raw_subset_dir, fname)
            for fname in os.listdir(raw_subset_dir)
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTENSIONS
            and os.path.isfile(os.path.join(raw_subset_dir, fname))
        )
    for img_file in tqdm(image_files, desc=f"Copying {subset} images"):
        shutil.copy(img_file, processed_subset_dir)

# Create dataset.yaml file. 'names' is ordered by class id so that entry i
# is the name of class i.
DATASET_YAML = os.path.join(PROCESSED_DIR, 'dataset.yaml')
dataset_config = {
    'path': PROCESSED_DIR,
    'train': os.path.join(PROCESSED_DIR, 'train'),
    'val': os.path.join(PROCESSED_DIR, 'valid'),
    'nc': len(class_map),
    'names': sorted(class_map, key=class_map.get),
}
if 'test' in subsets:
    dataset_config['test'] = os.path.join(PROCESSED_DIR, 'test')

with open(DATASET_YAML, 'w') as f:
    yaml.dump(dataset_config, f)

print("Dataset preparation completed.")

DEVICE = '0' if torch.cuda.is_available() else 'cpu'


def _train_model(label, run_name, **train_kwargs):
    """Build a YOLOv8n model from its config and train it on DATASET_YAML.

    Returns ``(model, results, succeeded)``. A training failure is printed
    and reported via ``succeeded=False`` (with ``results=None``) so later
    steps can skip the untrained model.
    """
    model = YOLO('yolov8n.yaml')
    try:
        results = model.train(
            data=DATASET_YAML,
            device=DEVICE,
            project=MODEL_DIR,
            name=run_name,
            **train_kwargs,
        )
    except Exception as e:
        print(f"Error training {label} model: {e}")
        return model, None, False
    return model, results, True


def _evaluate_model(label, model):
    """Validate ``model`` on DATASET_YAML; print and return the results (None on error)."""
    try:
        results = model.val(data=DATASET_YAML)
    except Exception as e:
        print(f"Error evaluating {label} model: {e}")
        return None
    print(f"{label.capitalize()} model evaluation results:")
    print(results)
    return results


# Project model: 50 epochs at 416 px input.
model_project, results_project, project_trained = _train_model(
    'project', 'number_detection_project', epochs=50, imgsz=416, batch=32)

# Web model: fewer epochs and a smaller 320 px input.
model_web, results_web, web_trained = _train_model(
    'web', 'number_detection_web', epochs=25, imgsz=320, batch=64)

# Evaluate only models whose training completed, each independently so one
# failure does not hide the other model's metrics.
if project_trained:
    results_project = _evaluate_model('project', model_project)
if web_trained:
    print()
    results_web = _evaluate_model('web', model_web)

# Save only models that actually trained, so an untrained network is never
# written out under a trained model's filename.
saved_paths = []
for label, trained_model, trained, filename in (
    ('project', model_project, project_trained, 'number_detection_project.pt'),
    ('web', model_web, web_trained, 'number_detection_web.pt'),
):
    if not trained:
        print(f"Skipping save of {label} model: training did not complete.")
        continue
    save_path = os.path.join(MODEL_DIR, filename)
    try:
        trained_model.save(save_path)
    except Exception as e:
        print(f"Error saving {label} model: {e}")
    else:
        saved_paths.append(save_path)

if saved_paths:
    print("Models saved successfully:", ", ".join(saved_paths))
    print("Training completed. Models saved in:", MODEL_DIR)
else:
    print("Training finished, but no models were saved.")

# --- Captured cell output (run was stopped with KeyboardInterrupt) ---
# Converting train annotations:  11%|█         | 111/1032 [00:00<00:01, 874.83it/s]
#
#
# KeyboardInterrupt: