In [1]:
# CELL 1: Imports and Setup
import pandas as pd
import numpy as np
import torch
import yaml
from torch.utils.data import Dataset
from PIL import Image
import os
from sklearn.model_selection import train_test_split
import csv
from tqdm import tqdm
import datetime
import shutil
from pathlib import Path
from ultralytics import YOLO
import matplotlib.pyplot as plt
import random
import logging
from typing import Dict, List, Set, Tuple  # Added Tuple here

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Set GPU memory allocation to 80%
torch.cuda.set_per_process_memory_fraction(0.8)

def set_all_seeds(seed=42):
    """Set all random seeds for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def create_yaml_file(train_dir: str, 
                    val_dir: str, 
                    yaml_path: str, 
                    num_classes: int, 
                    class_names: List[str]):
    """
    Create YAML configuration file for YOLOv8
    Args:
        train_dir: Path to training images directory
        val_dir: Path to validation images directory
        yaml_path: Output path for YAML file
        num_classes: Number of classes
        class_names: List of class names
    """
    train_dir = os.path.abspath(train_dir)
    val_dir = os.path.abspath(val_dir)
    
    yaml_content = {
        'train': os.path.join(train_dir, 'images'),
        'val': os.path.join(val_dir, 'images'),
        'nc': num_classes,
        'names': class_names
    }
    
    with open(yaml_path, 'w') as f:
        yaml.dump(yaml_content, f)
    
    logging.info(f"Created YAML configuration file at {yaml_path}")

def get_background_images(train_image_count: int, 
                         train_annotations: pd.DataFrame, 
                         background_dir: str, 
                         background_percentage: float = 0,
                         random_state: int = 42) -> List[str]:
    """
    Get random background images that don't contain any annotations
    Args:
        train_image_count: Number of training images (not boxes)
        train_annotations: Full annotations DataFrame
        background_dir: Directory containing background images
        background_percentage: Percentage of training images to use as background
        random_state: Random seed
    Returns:
        List of selected background image IDs
    """
    # Calculate number of background images needed
    num_background = int(train_image_count * (background_percentage / 100))
    
    # Get all image IDs that contain any annotations
    annotated_images = set(train_annotations['ImageID'].unique())
    
    # Get all available background images
    try:
        background_files = [f[:-4] for f in os.listdir(background_dir) if f.endswith('.jpg')]
    except FileNotFoundError:
        logging.warning(f"Background directory not found: {background_dir}")
        return []
    
    # Filter out images that have annotations
    valid_backgrounds = list(set(background_files) - annotated_images)
    
    if len(valid_backgrounds) < num_background:
        logging.warning(
            f"Only {len(valid_backgrounds)} valid background images available, "
            f"requested {num_background}"
        )
        num_background = len(valid_backgrounds)
    
    # Randomly select background images
    random.seed(random_state)
    selected_backgrounds = random.sample(valid_backgrounds, num_background)
    
    logging.info(f"Selected {len(selected_backgrounds)} background images")
    return selected_backgrounds

def process_background_images(background_images: List[str], 
                            background_dir: str, 
                            output_dir: str):
    """
    Process background images and add them to the dataset
    Args:
        background_images: List of background image IDs
        background_dir: Source directory for background images
        output_dir: Output directory for dataset
    """
    for image_id in tqdm(background_images, desc="Processing background images"):
        # Create empty label file
        label_path = os.path.join(output_dir, 'labels', f"{image_id}.txt")
        open(label_path, 'w').close()  # Create empty file
        
        # Create symbolic link to background image
        src_img_path = os.path.join(background_dir, f"{image_id}.jpg")
        dst_img_path = os.path.join(output_dir, 'images', f"{image_id}.jpg")
        
        if os.path.exists(src_img_path):
            if not os.path.exists(dst_img_path):
                try:
                    os.symlink(src_img_path, dst_img_path)
                except OSError as e:
                    logging.error(f"Failed to create symlink for {image_id}: {str(e)}")
        else:
            logging.warning(f"Background image not found: {src_img_path}")

def save_results(results: List[Dict], project_dir: str):
    """
    Save training results to CSV file
    Args:
        results: List of dictionaries containing training metrics
        project_dir: Directory to save results
    """
    results_path = os.path.join(project_dir, 'training_metrics.csv')
    
    # Collect all possible column names from all results
    csv_columns = set()
    for result in results:
        csv_columns.update(result.keys())
    
    csv_columns = sorted(list(csv_columns))  # Sort columns for consistency
    
    try:
        with open(results_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=csv_columns)
            writer.writeheader()
            for result in results:
                # Fill in missing values with None
                row = {col: result.get(col, None) for col in csv_columns}
                writer.writerow(row)
        
        logging.info(f"Results saved to {results_path}")
        
    except IOError as e:
        logging.error(f"Failed to save results to CSV: {str(e)}")

def verify_dataset_integrity(dataset_dir: str, 
                           class_mapping: Dict[str, int]) -> bool:
    """
    Verify the integrity of the created dataset
    """
    images_dir = os.path.join(dataset_dir, 'images')
    labels_dir = os.path.join(dataset_dir, 'labels')
    
    # Check directory structure
    if not all(os.path.exists(d) for d in [images_dir, labels_dir]):
        logging.error("Dataset directory structure is invalid")
        return False
    
    # Get image and label files
    image_files = set(f[:-4] for f in os.listdir(images_dir) if f.endswith('.jpg'))
    label_files = set(f[:-4] for f in os.listdir(labels_dir) if f.endswith('.txt'))
    
    # Check for missing pairs
    missing_labels = image_files - label_files
    missing_images = label_files - image_files
    
    if missing_labels:
        logging.error(f"Found {len(missing_labels)} images without labels")
        return False
    
    if missing_images:
        logging.error(f"Found {len(missing_images)} labels without images")
        # Clean up orphaned label files
        for image_id in missing_images:
            os.remove(os.path.join(labels_dir, f"{image_id}.txt"))
        logging.info("Removed orphaned label files")
        
    # Verify label format and class indices
    valid_classes = set(class_mapping.values())
    boxes_per_class = {idx: 0 for idx in valid_classes}
    
    for label_file in os.listdir(labels_dir):
        if not label_file.endswith('.txt'):
            continue
            
        with open(os.path.join(labels_dir, label_file), 'r') as f:
            for line_num, line in enumerate(f, 1):
                try:
                    values = line.strip().split()
                    if len(values) != 5:
                        logging.error(f"Invalid label format in {label_file} line {line_num}")
                        return False
                    
                    class_idx = int(values[0])
                    if class_idx not in valid_classes:
                        logging.error(f"Invalid class index in {label_file} line {line_num}: {class_idx}")
                        return False
                    
                    boxes_per_class[class_idx] += 1
                        
                    # Verify bounding box coordinates are in range [0, 1]
                    coords = [float(v) for v in values[1:]]
                    if not all(0 <= v <= 1 for v in coords):
                        logging.error(f"Invalid coordinates in {label_file} line {line_num}")
                        return False
                        
                except ValueError as e:
                    logging.error(f"Invalid value in {label_file} line {line_num}: {str(e)}")
                    return False
    
    # Log box distribution
    logging.info("\nBounding box distribution:")
    for class_idx, count in boxes_per_class.items():
        logging.info(f"Class {class_idx}: {count} boxes")
    
    logging.info("Dataset integrity verification passed")
    return True

def verify_dataset_distribution(dataset_dir: str, 
                              class_mapping: Dict[str, int],
                              target_images_per_class: int) -> bool:
    """
    Verify the distribution of images and boxes across classes
    """
    images_dir = os.path.join(dataset_dir, 'images')
    labels_dir = os.path.join(dataset_dir, 'labels')
    
    # Count images and boxes per class
    images_per_class = {idx: set() for idx in class_mapping.values()}
    boxes_per_class = {idx: 0 for idx in class_mapping.values()}
    
    for label_file in os.listdir(labels_dir):
        if not label_file.endswith('.txt'):
            continue
            
        image_id = label_file[:-4]
        class_indices_in_image = set()
        
        with open(os.path.join(labels_dir, label_file), 'r') as f:
            for line in f:
                try:
                    class_idx = int(line.strip().split()[0])
                    boxes_per_class[class_idx] += 1
                    class_indices_in_image.add(class_idx)
                except (ValueError, IndexError) as e:
                    logging.error(f"Error parsing label file {label_file}: {str(e)}")
                    return False
        
        # Add image ID to all classes present in the image
        for class_idx in class_indices_in_image:
            images_per_class[class_idx].add(image_id)
    
    # Check class distribution
    logging.info("\nClass Distribution Analysis:")
    for class_idx, images in images_per_class.items():
        image_count = len(images)
        box_count = boxes_per_class[class_idx]
        avg_boxes_per_image = box_count / image_count if image_count > 0 else 0
        
        logging.info(f"Class {class_idx}:")
        logging.info(f"  Images: {image_count} (Target: {target_images_per_class})")
        logging.info(f"  Total boxes: {box_count}")
        logging.info(f"  Average boxes per image: {avg_boxes_per_image:.2f}")
        
        if image_count < target_images_per_class * 0.9:  # Allow 10% tolerance
            logging.warning(f"Class {class_idx} has fewer images than target")
    
    return True

# CELL 2: Data Loading and Preprocessing
class MultiClassDataLoader:
    def __init__(self, class_desc_path: str, annotations_path: str, image_dir: str):
        self.class_descriptions = pd.read_csv(class_desc_path)
        self.train_annotations = pd.read_csv(annotations_path)
        self.image_dir = image_dir
        self.class_mapping = {}  # LabelName to index mapping
        self.reverse_mapping = {}  # index to DisplayName mapping
        
        # Get available image IDs from directory
        self.available_image_ids = set(
            f[:-4] for f in os.listdir(image_dir) 
            if f.endswith('.jpg')
        )
        logging.info(f"Found {len(self.available_image_ids)} images in directory")
        
        # Filter annotations to only include available images
        self.train_annotations = self.train_annotations[
            self.train_annotations['ImageID'].isin(self.available_image_ids)
        ]
        logging.info(f"Filtered annotations to {len(self.train_annotations)} rows with available images")
        
    def get_training_classes(self) -> List[str]:
        """Get list of classes marked for use in model"""
        # Get classes marked for training
        training_classes = self.class_descriptions[
            self.class_descriptions['UseInModel'] == True
        ]['DisplayName'].tolist()
        
        # Filter to only include classes that have sufficient data
        valid_classes = []
        for class_name in training_classes:
            label_name = self.class_descriptions[
                self.class_descriptions['DisplayName'] == class_name
            ]['LabelName'].iloc[0]
            
            # Count annotations for this class
            class_count = len(self.train_annotations[
                self.train_annotations['LabelName'] == label_name
            ])
            
            if class_count > 0:
                valid_classes.append(class_name)
                logging.info(f"Class {class_name}: {class_count} annotations in available images")
            else:
                logging.warning(f"Class {class_name}: no annotations in available images, skipping")
        
        logging.info(f"Found {len(valid_classes)} classes with data in available images")
        return valid_classes
    
    def create_class_mappings(self, training_classes: List[str]):
        """Create mappings between class names and indices"""
        self.class_mapping.clear()
        self.reverse_mapping.clear()
        
        for idx, class_name in enumerate(training_classes):
            label_name = self.class_descriptions[
                self.class_descriptions['DisplayName'] == class_name
            ]['LabelName'].iloc[0]
            self.class_mapping[label_name] = idx
            self.reverse_mapping[idx] = class_name
            
def filter_annotations(class_loader: MultiClassDataLoader,
                      images_per_class: int,
                      fixed_val_images: int) -> Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Filter annotations ensuring proper image counts per class
    Args:
        class_loader: MultiClassDataLoader instance
        images_per_class: Number of training images per class
        fixed_val_images: Number of validation images per class
    Returns:
        Tuple of (train_data, val_data) dictionaries
    """
    train_data = {}
    val_data = {}
    
    # Keep track of selected images across all classes
    all_selected_train_images = set()
    all_selected_val_images = set()
    
    for label_name, class_idx in class_loader.class_mapping.items():
        # Get all annotations for this class
        class_annotations = class_loader.train_annotations[
            class_loader.train_annotations['LabelName'] == label_name
        ]
        
        # Get unique image IDs for this class
        class_image_ids = class_annotations['ImageID'].unique()
        total_images_needed = images_per_class + fixed_val_images
        
        if len(class_image_ids) < total_images_needed:
            logging.warning(
                f"Insufficient images for class {class_loader.reverse_mapping[class_idx]}. "
                f"Required: {total_images_needed}, Available: {len(class_image_ids)}"
            )
            continue
        
        # Shuffle image IDs
        shuffled_image_ids = np.random.permutation(class_image_ids)
        
        # Select validation images
        val_images = set()
        val_idx = 0
        while len(val_images) < fixed_val_images and val_idx < len(shuffled_image_ids):
            img_id = shuffled_image_ids[val_idx]
            if img_id not in all_selected_val_images:  # Avoid validation/train overlap
                val_images.add(img_id)
                all_selected_val_images.add(img_id)
            val_idx += 1
        
        if len(val_images) < fixed_val_images:
            logging.warning(
                f"Could only find {len(val_images)} unique validation images for class "
                f"{class_loader.reverse_mapping[class_idx]}"
            )
        
        # Select training images
        train_images = set()
        train_idx = val_idx
        while len(train_images) < images_per_class and train_idx < len(shuffled_image_ids):
            img_id = shuffled_image_ids[train_idx]
            if img_id not in val_images:  # Ensure no overlap with validation
                train_images.add(img_id)
                all_selected_train_images.add(img_id)
            train_idx += 1
        
        if len(train_images) < images_per_class:
            logging.warning(
                f"Could only find {len(train_images)} unique training images for class "
                f"{class_loader.reverse_mapping[class_idx]}"
            )
        
        # Get all annotations for selected images
        train_data[label_name] = class_annotations[
            class_annotations['ImageID'].isin(train_images)
        ]
        val_data[label_name] = class_annotations[
            class_annotations['ImageID'].isin(val_images)
        ]
        
        # Log statistics
        train_box_count = len(train_data[label_name])
        val_box_count = len(val_data[label_name])
        
        logging.info(f"Class {class_loader.reverse_mapping[class_idx]}:")
        logging.info(f"  Training: {len(train_images)} images with {train_box_count} boxes")
        logging.info(f"  Validation: {len(val_images)} images with {val_box_count} boxes")
    
    # Log overall image usage statistics
    logging.info("\nOverall Dataset Statistics:")
    logging.info(f"Total unique training images: {len(all_selected_train_images)}")
    logging.info(f"Total unique validation images: {len(all_selected_val_images)}")
    
    return train_data, val_data

# CELL 3: Dataset Creation
class MultiClassYOLODataset(Dataset):
    def __init__(self, annotations_dict: Dict[str, pd.DataFrame], 
                 image_dir: str, class_mapping: Dict[str, int],
                 transforms=None):
        self.annotations_dict = annotations_dict
        self.image_dir = image_dir
        self.class_mapping = class_mapping
        self.transforms = transforms
        
        # Combine all annotations and get unique image IDs
        all_annotations = pd.concat(annotations_dict.values())
        self.image_ids = all_annotations['ImageID'].unique()
        self.image_annotations = all_annotations.groupby('ImageID')
        
    def __len__(self):
        return len(self.image_ids)
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, f"{image_id}.jpg")
        image = Image.open(img_path).convert("RGB")
        
        # Get all annotations for this image
        image_annotations = self.image_annotations.get_group(image_id)
        
        # Convert bounding boxes to YOLO format
        labels = []
        for _, row in image_annotations.iterrows():
            class_idx = self.class_mapping[row['LabelName']]
            x_min, x_max = row['XMin'], row['XMax']
            y_min, y_max = row['YMin'], row['YMax']
            
            x_center = (x_min + x_max) / 2
            y_center = (y_min + y_max) / 2
            width = x_max - x_min
            height = y_max - y_min
            
            labels.append([class_idx, x_center, y_center, width, height])
            
        return img_path, torch.tensor(labels)

def create_multi_class_dataset(annotations_dict: Dict[str, pd.DataFrame],
                             image_dir: str,
                             output_dir: str,
                             class_mapping: Dict[str, int]):
    """
    Create YOLO dataset structure for multiple classes, handling multiple bounding boxes per image
    """
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'labels'), exist_ok=True)
    
    # Combine all annotations
    all_annotations = pd.concat(annotations_dict.values())
    
    # Get unique image IDs
    unique_image_ids = all_annotations['ImageID'].unique()
    logging.info(f"Processing {len(unique_image_ids)} unique images")
    
    # Process each unique image
    for image_id in tqdm(unique_image_ids, desc="Creating dataset"):
        # Get all annotations for this image
        image_annots = all_annotations[all_annotations['ImageID'] == image_id]
        
        # Create label file with all bounding boxes
        label_path = os.path.join(output_dir, 'labels', f"{image_id}.txt")
        
        with open(label_path, 'w') as f:
            for _, row in image_annots.iterrows():
                class_idx = class_mapping[row['LabelName']]
                x_center = (row['XMin'] + row['XMax']) / 2
                y_center = (row['YMin'] + row['YMax']) / 2
                width = row['XMax'] - row['XMin']
                height = row['YMax'] - row['YMin']
                
                f.write(f"{class_idx} {x_center} {y_center} {width} {height}\n")
        
        # Create symbolic link to image
        src_img_path = os.path.join(image_dir, f"{image_id}.jpg")
        dst_img_path = os.path.join(output_dir, 'images', f"{image_id}.jpg")
        if os.path.exists(src_img_path):
            if not os.path.exists(dst_img_path):
                try:
                    os.symlink(src_img_path, dst_img_path)
                except OSError as e:
                    logging.error(f"Failed to create symlink for {image_id}: {str(e)}")
                    # If symlink fails, try copying the file
                    try:
                        shutil.copy2(src_img_path, dst_img_path)
                    except IOError as e:
                        logging.error(f"Failed to copy image {image_id}: {str(e)}")
        else:
            logging.warning(f"Image not found: {src_img_path}")
            # Remove the label file if image doesn't exist
            os.remove(label_path)


# CELL 4: Training Functions
def train_model(model, yaml_path, run_dir, box_count_per_class, max_epochs):
    """
    Train the YOLO model with all hyperparameters
    Args:
        model: Initialized YOLO model
        yaml_path: Path to dataset configuration
        run_dir: Directory for saving results
        box_count_per_class: Number of boxes per class used in training
        max_epochs: Maximum number of training epochs
    Returns:
        Training results
    """
    try:
        results = model.train(
            data=yaml_path,
            epochs=max_epochs,
            imgsz=640,
            # Model hyperparameters
            batch=16,              # Batch size
            workers=6,             # Number of worker threads
            device=0,              # GPU device number
            # conf=0.1,           # Confidence threshold
            # iou=0.1,            # IoU threshold
            
            # Training parameters
            project=run_dir,
            name=f'yolo_boxes_{box_count_per_class}',
            patience=20,           # Early stopping patience
            save=True,            # Save checkpoints
            save_period=10,       # Save every N epochs
            
            # Learning rate parameters
            # lr0=0.00001,        # Initial learning rate
            # lrf=0.000001,       # Final learning rate
            # warmup_epochs=3,    # Number of warmup epochs
            # warmup_momentum=0.8,# Warmup momentum
            # warmup_bias_lr=0.1, # Warmup bias learning rate
            
            # Loss weights
            # box=7.5,            # Box loss weight
            # cls=0.5,            # Class loss weight
            # dfl=1.5,            # DFL loss weight
            
            # Visualization and logging
            plots=True,           # Generate plots
            verbose=True,         # Verbose output
            
            # Regularization
            # weight_decay=0.0005,# Weight decay
            # dropout=0.2,        # Dropout rate
            
            # Data augmentation (enabled by default)
            # hsv_h=0,           # HSV-Hue augmentation
            # hsv_s=0,           # HSV-Saturation augmentation
            # hsv_v=0,           # HSV-Value augmentation
            # translate=0,        # Translation augmentation
            # scale=0,           # Scaling augmentation
            # fliplr=0,          # Horizontal flip augmentation
            # mosaic=0,          # Mosaic augmentation
            # erasing=0,         # Random erasing
            # crop_fraction=1     # Crop fraction
        )
        return results
    
    except Exception as e:
        logging.error(f"Training failed with error: {str(e)}")
        raise

def save_training_results(results, run_dir, model_dir, box_count_per_class):
    """Save model checkpoints and training metrics"""
    try:
        # Save models
        yolo_output_dir = os.path.join(run_dir, f'yolo_boxes_{box_count_per_class}')
        
        # Save best model
        if os.path.exists(os.path.join(yolo_output_dir, 'weights', 'best.pt')):
            best_model_path = os.path.join(model_dir, f'model_best_boxes_{box_count_per_class}.pt')
            shutil.copy2(
                os.path.join(yolo_output_dir, 'weights', 'best.pt'),
                best_model_path
            )
        
        # Save final model
        if os.path.exists(os.path.join(yolo_output_dir, 'weights', 'last.pt')):
            final_model_path = os.path.join(model_dir, f'model_final_boxes_{box_count_per_class}.pt')
            shutil.copy2(
                os.path.join(yolo_output_dir, 'weights', 'last.pt'),
                final_model_path
            )
        
        # Collect metrics
        metrics = {
            'BoundingBoxCountPerClass': box_count_per_class,
            'ModelDirectory': model_dir,
            'FinalEpoch': results.epoch
        }
        
        # Add per-class metrics if available
        if hasattr(results, 'results_dict'):
            metrics_dict = results.results_dict
            
            # Overall metrics
            metrics.update({
                'Precision': float(metrics_dict.get('metrics/precision', 0.0)),
                'Recall': float(metrics_dict.get('metrics/recall', 0.0)),
                'mAP50': float(metrics_dict.get('metrics/mAP50', 0.0)),
                'mAP50-95': float(metrics_dict.get('metrics/mAP50-95', 0.0))
            })
            
            # Add per-class metrics if available
            for class_idx in range(results.num_classes):
                class_prefix = f'metrics/precision_{class_idx}'
                if class_prefix in metrics_dict:
                    metrics.update({
                        f'Precision_Class_{class_idx}': float(metrics_dict[f'metrics/precision_{class_idx}']),
                        f'Recall_Class_{class_idx}': float(metrics_dict[f'metrics/recall_{class_idx}']),
                        f'mAP50_Class_{class_idx}': float(metrics_dict[f'metrics/mAP50_{class_idx}']),
                        f'mAP50-95_Class_{class_idx}': float(metrics_dict[f'metrics/mAP50-95_{class_idx}'])
                    })
        
        return metrics
        
    except Exception as e:
        logging.error(f"Failed to save training results: {str(e)}")
        return None


def train_multi_class_yolo(image_count_per_class: int,
                          image_dir: str,
                          project_dir: str,
                          class_loader: MultiClassDataLoader,
                          fixed_val_size: int = 500,
                          max_epochs: int = 300,
                          background_percentage: float = 0):  # Set default to 0
    """Train YOLO model for multiple classes"""
    
    # Create run directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join(project_dir, f'multi_class_run_{image_count_per_class}_images_{timestamp}')
    dataset_dir = os.path.join(run_dir, 'dataset')
    model_dir = os.path.join(run_dir, 'models')
    
    os.makedirs(run_dir, exist_ok=True)
    os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    
    # Get training classes and create mappings
    training_classes = class_loader.get_training_classes()
    class_loader.create_class_mappings(training_classes)
    
    # Filter annotations for each class - now returns tuple (train_data, val_data)
    train_data, val_data = filter_annotations(class_loader, image_count_per_class, fixed_val_size)
    
    if not train_data or not val_data:
        logging.error("No classes had sufficient data for training")
        return None
    
    # Create dataset structure
    train_dir = os.path.join(dataset_dir, 'train')
    val_dir = os.path.join(dataset_dir, 'val')
    
    # Create training dataset
    create_multi_class_dataset(
        train_data, 
        image_dir, 
        train_dir,
        class_loader.class_mapping
    )
    
    # Verify training dataset integrity
    logging.info("Verifying training dataset integrity...")
    if not verify_dataset_integrity(train_dir, class_loader.class_mapping):
        logging.error("Training dataset verification failed. Aborting training.")
        return None
    
    # Create validation dataset
    create_multi_class_dataset(
        val_data, 
        image_dir, 
        val_dir,
        class_loader.class_mapping
    )
    
    # Verify validation dataset integrity
    logging.info("Verifying validation dataset integrity...")
    if not verify_dataset_integrity(val_dir, class_loader.class_mapping):
        logging.error("Validation dataset verification failed. Aborting training.")
        return None
    
    # Create YAML configuration
    yaml_path = os.path.join(run_dir, 'dataset.yaml')
    create_yaml_file(
        train_dir, val_dir, yaml_path,
        num_classes=len(training_classes),
        class_names=[class_loader.reverse_mapping[i] for i in range(len(training_classes))]
    )
    
    # Print dataset summary
    logging.info("\nDataset Summary:")
    logging.info(f"Number of classes: {len(training_classes)}")
    
    # Calculate total images and boxes
    train_images = len(os.listdir(os.path.join(train_dir, 'images')))
    val_images = len(os.listdir(os.path.join(val_dir, 'images')))
    
    logging.info(f"Training images: {train_images}")
    logging.info(f"Validation images: {val_images}")
    
    # Class distribution summary
    logging.info("\nClass Distribution:")
    for class_idx, class_name in class_loader.reverse_mapping.items():
        # Count boxes in training set
        train_count = sum(1 for f in os.listdir(os.path.join(train_dir, 'labels'))
                         for line in open(os.path.join(train_dir, 'labels', f))
                         if line.startswith(f"{class_idx} "))
        
        # Count boxes in validation set
        val_count = sum(1 for f in os.listdir(os.path.join(val_dir, 'labels'))
                       for line in open(os.path.join(val_dir, 'labels', f))
                       if line.startswith(f"{class_idx} "))
        
        logging.info(f"Class {class_name}: {train_count} training boxes, {val_count} validation boxes")
    
    # Initialize and train model
    try:
        model = YOLO('yolov8s.pt')
        results = train_model(model, yaml_path, run_dir, image_count_per_class, max_epochs)
        save_training_results(results, run_dir, model_dir, image_count_per_class)
        return results
    except Exception as e:
        logging.error(f"Error during training: {str(e)}")
        return None



In [2]:
from IPython import get_ipython
from contextlib import redirect_stdout, redirect_stderr
import io

if __name__ == "__main__":
    # Initialize paths and parameters
    image_dir = r'E:\Data 255\YOLO multi-class\image_data\training_images_25_class'
    project_dir = r'E:\Data 255\YOLO multi-class\yolo_training'
    class_desc_path = 'oidv7-class-descriptions-boxable.csv'
    annotations_path = 'oidv6-train-annotations-bbox.csv'
    
    # Create output file path
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(project_dir, f'notebook_output_{timestamp}.txt')
    
    # Capture all output
    output_buffer = io.StringIO()
    with redirect_stdout(output_buffer), redirect_stderr(output_buffer):
        # Your existing code here
        set_all_seeds()
        
        # Create project directory
        os.makedirs(project_dir, exist_ok=True)
        
        # Initialize class loader
        class_loader = MultiClassDataLoader(class_desc_path, annotations_path, image_dir)
        
        # Training parameters
        image_counts = [500,1000]  # Images per class
        fixed_val_size = 500  # Validation images per class
        
        # Train models
        results = []
        for image_count in image_counts:
            try:
                metrics = train_multi_class_yolo(
                    image_count,  # Now represents images per class
                    image_dir,
                    project_dir,
                    class_loader,
                    fixed_val_size  # Now represents images per class
                )
                if metrics:
                    results.append(metrics)
                    save_results(results, project_dir)
            except Exception as e:
                logging.error(f"Error during training with {image_count} images per class: {str(e)}")
                continue
        
        logging.info("All training complete!")
    
    # Save captured output
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(output_buffer.getvalue())

2024-11-24 09:53:08,903 - INFO - Found 102217 images in directory
2024-11-24 09:53:10,058 - INFO - Filtered annotations to 1061477 rows with available images
2024-11-24 09:53:10,132 - INFO - Class Footwear: 75466 annotations in available images
2024-11-24 09:53:10,188 - INFO - Class Suit: 15390 annotations in available images
2024-11-24 09:53:10,246 - INFO - Class Glasses: 9346 annotations in available images
2024-11-24 09:53:10,300 - INFO - Class Dress: 9040 annotations in available images
2024-11-24 09:53:10,360 - INFO - Class Jeans: 15804 annotations in available images
2024-11-24 09:53:10,420 - INFO - Class Tire: 22192 annotations in available images
2024-11-24 09:53:10,509 - INFO - Class Fashion accessory: 19463 annotations in available images
2024-11-24 09:53:10,598 - INFO - Class Microphone: 7432 annotations in available images
2024-11-24 09:53:10,658 - INFO - Class Guitar: 6723 annotations in available images
2024-11-24 09:53:10,716 - INFO - Class Toy: 19766 annotations in avai

New https://pypi.org/project/ultralytics/8.3.36 available  Update with 'pip install -U ultralytics'
Ultralytics 8.3.23  Python-3.10.0 torch-2.5.0+cu118 CUDA:0 (NVIDIA GeForce RTX 3080, 10240MiB)
[34m[1mengine\trainer: [0mtask=detect, mode=train, model=yolov8s.pt, data=E:\Data 255\YOLO multi-class\yolo_training\multi_class_run_500_images_20241124_095310\dataset.yaml, epochs=300, time=None, patience=20, batch=16, imgsz=640, save=True, save_period=10, cache=False, device=0, workers=6, project=E:\Data 255\YOLO multi-class\yolo_training\multi_class_run_500_images_20241124_095310, name=yolo_boxes_500, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False

2024-11-24 15:55:30,067 - ERROR - Failed to save training results: 'DetMetrics' object has no attribute 'epoch'. See valid attributes below.

    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
    object detection model.

    Args:
        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
        names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.

    Attributes:
        save_dir (Path): A path to the directory where the output plots will be saved.
        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
        on_plot (func): An o

New https://pypi.org/project/ultralytics/8.3.36 available  Update with 'pip install -U ultralytics'
Ultralytics 8.3.23  Python-3.10.0 torch-2.5.0+cu118 CUDA:0 (NVIDIA GeForce RTX 3080, 10240MiB)
[34m[1mengine\trainer: [0mtask=detect, mode=train, model=yolov8s.pt, data=E:\Data 255\YOLO multi-class\yolo_training\multi_class_run_1000_images_20241124_155530\dataset.yaml, epochs=300, time=None, patience=20, batch=16, imgsz=640, save=True, save_period=10, cache=False, device=0, workers=6, project=E:\Data 255\YOLO multi-class\yolo_training\multi_class_run_1000_images_20241124_155530, name=yolo_boxes_1000, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=Fa

2024-11-25 02:22:35,778 - ERROR - Failed to save training results: 'DetMetrics' object has no attribute 'epoch'. See valid attributes below.

    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
    object detection model.

    Args:
        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
        names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.

    Attributes:
        save_dir (Path): A path to the directory where the output plots will be saved.
        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
        on_plot (func): An o