<a href="https://colab.research.google.com/github/Alex-Jung-HB/0813_python_object-detection-using-segformer-YOLO11n/blob/main/0813_python_segmentation_training_(YOLO11n).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

YOLO11 Object Detection Training

In [None]:
"""
YOLO11 Training Tool - Fixed Dataset Handling

For Jupyter/Colab users, use the simple function:
    train_yolo_simple("/path/to/data.zip", classes="all", epochs=100)
    train_yolo_simple("/path/to/data.zip", classes="0,2,5", epochs=50)

For command line usage:
    python yolo11_trainer.py --cli
    python yolo11_trainer.py --zip data.zip --classes all --epochs 100
"""

import os
import sys
import zipfile
import json
import yaml
from pathlib import Path
import threading
import shutil
import argparse
import subprocess
import random
import time
import warnings

# Suppress common warnings that clutter the output
warnings.filterwarnings("ignore", category=UserWarning, module="torch.*")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.*")

# These will be imported after checking if they're available
# import torch
# from ultralytics import YOLO

# Try to import tkinter, handle if no display available.
# The import and the display probe are guarded separately: if the import itself
# fails, `tk` is unbound, so a combined `except (ImportError, tk.TclError)` clause
# would raise NameError while building its exception tuple.
GUI_AVAILABLE = True
_gui_unavailable_reason = None
try:
    import tkinter as tk
    from tkinter import ttk, filedialog, messagebox, scrolledtext
except ImportError as e:
    _gui_unavailable_reason = e
else:
    # Test if display is available by creating (and immediately destroying) a root window
    try:
        root_test = tk.Tk()
        root_test.withdraw()
        root_test.destroy()
    except tk.TclError as e:
        _gui_unavailable_reason = e

if _gui_unavailable_reason is not None:
    GUI_AVAILABLE = False
    print("=" * 60)
    print("YOLO11 Training Tool")
    print("=" * 60)
    print(f"GUI not available: {_gui_unavailable_reason}")
    print("Running in command-line mode...")
    print("\nTo enable GUI:")
    print("- On Linux/WSL: Install X server (Xming, VcXsrv, or X11)")
    print("- On SSH: Use 'ssh -X' for X11 forwarding")
    print("- On headless servers: Use CLI mode with --cli flag")
    print("=" * 60)

def check_and_install_packages():
    """Check for PyTorch and Ultralytics and pip-install whichever is missing.

    Installation runs ``<current interpreter> -m pip install ...`` so packages
    land in the same environment this script runs in. A bare ``pip`` found on
    PATH may belong to a different Python (virtualenvs, conda, Colab kernels).

    Returns:
        bool: True when both packages are importable at the end; False if an
        installation failed or timed out, or the import still fails afterwards.
    """
    import importlib

    # (package name, pip install arguments) for every package that failed to import
    missing = []

    print("🔍 Checking required packages...")

    # Check PyTorch
    try:
        import torch
        print(f"✅ PyTorch {torch.__version__} is available")

        # Check CUDA availability with better error handling
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"✅ CUDA available with {gpu_count} GPU(s)")
            for i in range(gpu_count):
                try:
                    gpu_name = torch.cuda.get_device_name(i)
                    print(f"   GPU {i}: {gpu_name}")
                except Exception as e:
                    print(f"   GPU {i}: Unknown (error: {e})")
        else:
            print("⚠️  CUDA not available - will use CPU")

    except ImportError:
        print("❌ PyTorch not found")
        missing.append((
            "torch",
            ["torch", "torchvision", "torchaudio",
             "--index-url", "https://download.pytorch.org/whl/cu124"],
        ))

    # Check Ultralytics
    try:
        import ultralytics
        print(f"✅ Ultralytics {ultralytics.__version__} is available")
    except ImportError:
        print("❌ Ultralytics not found")
        missing.append(("ultralytics", ["ultralytics"]))

    if not missing:
        return True

    # Install missing packages
    print(f"\n📦 Installing {len(missing)} missing package(s)...")
    python_exe = sys.executable or "python"  # sys.executable can be empty in embedded interpreters

    for package_name, pip_args in missing:
        print(f"\n⏳ Installing {package_name}...")
        cmd = [python_exe, "-m", "pip", "install", *pip_args]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300  # 5 minute timeout
            )
        except subprocess.TimeoutExpired:
            print(f"❌ Installation of {package_name} timed out")
            return False
        except OSError as e:
            print(f"❌ Error installing {package_name}: {e}")
            return False

        if result.returncode == 0:
            print(f"✅ {package_name} installed successfully")
        else:
            print(f"❌ Failed to install {package_name}")
            print(f"Error: {result.stderr}")
            return False

    print("\n🔄 Reloading modules...")
    # Newly installed packages may be invisible to the import system's cached
    # directory listings until the finder caches are invalidated.
    importlib.invalidate_caches()
    try:
        import torch
        from ultralytics import YOLO
        print("✅ All packages loaded successfully")
    except ImportError as e:
        print(f"❌ Still missing packages after installation: {e}")
        return False

    return True

class DatasetManager:
    """Handles all dataset operations with robust error handling.

    Scans an extracted archive and pairs images with YOLO label files. Lays
    them out as ``train/``, ``val/`` and ``test/`` directories, each holding
    ``images/`` and ``labels/``, under the extraction directory. All progress
    and warnings go through the ``log_func`` callable given to the constructor.
    """

    # Archive-tool artifact directories that never hold real samples.
    # Zips made by macOS Finder carry AppleDouble "._*" copies of every file
    # under __MACOSX.
    _IGNORED_DIR_NAMES = frozenset({'__macosx'})

    # Float slack for ratio * count products such as 100 * (1 - 0.7 - 0.2).
    # That product evaluates to 9.999... and would otherwise truncate to 9.
    _RATIO_EPS = 1e-9

    def __init__(self, log_func=print):
        """
        Args:
            log_func: Callable taking a single message string; defaults to ``print``.
        """
        self.log = log_func
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp'}

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @classmethod
    def _is_ignored_dir(cls, name):
        """Return True for archive-artifact or hidden directories that must not be scanned."""
        return name.lower() in cls._IGNORED_DIR_NAMES or name.startswith('.')

    @staticmethod
    def _is_ignored_file(name):
        """Return True for hidden files, which includes macOS AppleDouble '._*' files."""
        return name.startswith('.')

    @staticmethod
    def _stem(path):
        """Return the file name of ``path`` without its directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @classmethod
    def _index_labels(cls, label_paths):
        """Group label paths by file stem, preserving discovery order within each group."""
        index = {}
        for path in label_paths:
            index.setdefault(cls._stem(path), []).append(path)
        return index

    @classmethod
    def _match_label(cls, img_path, labels_by_stem):
        """Return the label path for ``img_path``, or None if no label shares its stem.

        Several label files may share the image's stem, e.g.
        ``a/labels/x.txt`` and ``b/labels/x.txt``. In that case the preference
        order is:

        1. the ``labels/`` sibling of the image's ``images/`` directory;
        2. the image's own directory;
        3. the first candidate discovered.
        """
        candidates = labels_by_stem.get(cls._stem(img_path))
        if not candidates:
            return None
        if len(candidates) == 1:
            return candidates[0]

        img_dir = os.path.dirname(img_path)
        parent, leaf = os.path.split(img_dir)
        preferred_dirs = []
        if leaf.lower() == 'images':
            preferred_dirs.append(os.path.join(parent, 'labels').lower())
        preferred_dirs.append(img_dir.lower())

        for wanted in preferred_dirs:
            for candidate in candidates:
                if os.path.dirname(candidate).lower() == wanted:
                    return candidate
        return candidates[0]

    @staticmethod
    def _collect_class_ids(label_path, class_ids):
        """Add the class id of every well-formed row in ``label_path`` to ``class_ids``.

        A row counts when it has at least 5 whitespace-separated fields and the
        first field parses as ``int``.

        Returns:
            bool: False if the file could not be read or decoded, True otherwise.
        """
        try:
            # utf-8-sig tolerates a leading BOM, which would otherwise make
            # int() reject the first class id.
            with open(label_path, 'r', encoding='utf-8-sig') as f:
                for line in f:
                    parts = line.split()
                    if len(parts) >= 5:
                        try:
                            class_ids.add(int(parts[0]))
                        except ValueError:
                            continue
        except (OSError, UnicodeDecodeError):
            return False
        return True

    @staticmethod
    def _managed_dirs(yolo_dirs):
        """Return the normalized absolute paths of the YOLO output directories."""
        return {os.path.normcase(os.path.abspath(d)) for d in yolo_dirs.values()}

    @staticmethod
    def _place_file(src, dst_dir, managed_dirs):
        """Put ``src`` into ``dst_dir``.

        Behavior depends on where ``src`` currently lives:

        - Already at the destination: left alone. Copying a file onto itself
          raises ``shutil.SameFileError``, which happens when the archive
          already uses the ``train/images`` layout at its root.
        - Inside another YOLO output directory: moved, so a sample reassigned
          to a different split does not also remain in its old split.
        - Anywhere else: copied, preserving metadata.
        """
        dst = os.path.join(dst_dir, os.path.basename(src))
        if os.path.exists(dst) and os.path.samefile(src, dst):
            return
        src_dir = os.path.normcase(os.path.abspath(os.path.dirname(src)))
        if src_dir in managed_dirs:
            shutil.move(src, dst)
        else:
            shutil.copy2(src, dst)

    @staticmethod
    def _split_for_path(path, root=None):
        """Guess the split ('train', 'val' or 'test') from the directory of ``path``.

        The path is made relative to ``root`` when given, so that directory
        names above the dataset root cannot influence the result. Unclear
        locations default to 'train'.
        """
        dir_path = os.path.dirname(path)
        if root:
            dir_path = os.path.relpath(dir_path, root)
        dir_path = dir_path.lower()
        if 'train' in dir_path:
            return 'train'
        if 'val' in dir_path:  # also covers 'valid' / 'validation'
            return 'val'
        if 'test' in dir_path:
            return 'test'
        return 'train'

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def analyze_dataset_structure(self, extract_path):
        """Scan ``extract_path`` for images and label files and log a summary.

        Hidden entries are skipped, including macOS ``__MACOSX`` / ``._*``
        artifacts. Any ``.txt`` file except ``classes.txt`` / ``readme.txt``
        counts as a label. Class ids are collected from label rows with at
        least 5 fields.

        Returns:
            dict with keys:

            - ``'root'``: ``extract_path``.
            - ``'images'`` / ``'labels'``: lists of file paths.
            - ``'image_dirs'`` / ``'label_dirs'``: mappings of relative dir
              (``'root'`` for the top) to file count.
            - ``'total_images'`` / ``'total_labels'``: int counts.
            - ``'class_ids'``: set of ints.
        """
        self.log("🔍 Analyzing dataset structure...")

        structure_info = {
            'root': extract_path,
            'images': [],
            'labels': [],
            'image_dirs': {},
            'label_dirs': {},
            'total_images': 0,
            'total_labels': 0,
            'class_ids': set()
        }

        image_exts = tuple(self.image_extensions)
        unreadable_labels = 0

        for root, dirs, files in os.walk(extract_path):
            # Prune in place so os.walk does not descend into artifact dirs
            dirs[:] = [d for d in dirs if not self._is_ignored_dir(d)]

            rel_root = os.path.relpath(root, extract_path)
            if rel_root == '.':
                rel_root = 'root'

            images_in_dir = 0
            labels_in_dir = 0

            for file in files:
                if self._is_ignored_file(file):
                    continue
                file_lower = file.lower()
                full_path = os.path.join(root, file)
                if file_lower.endswith(image_exts):
                    images_in_dir += 1
                    structure_info['images'].append(full_path)
                elif file_lower.endswith('.txt') and file_lower not in ('classes.txt', 'readme.txt'):
                    labels_in_dir += 1
                    structure_info['labels'].append(full_path)
                    if not self._collect_class_ids(full_path, structure_info['class_ids']):
                        unreadable_labels += 1

            if images_in_dir:
                structure_info['image_dirs'][rel_root] = images_in_dir
                structure_info['total_images'] += images_in_dir

            if labels_in_dir:
                structure_info['label_dirs'][rel_root] = labels_in_dir
                structure_info['total_labels'] += labels_in_dir

        # Report findings
        self.log(f"📊 Dataset Analysis Complete:")
        self.log(f"   🖼️  Total images: {structure_info['total_images']}")
        self.log(f"   🏷️  Total labels: {structure_info['total_labels']}")
        self.log(f"   🎯 Unique classes: {len(structure_info['class_ids'])}")
        if unreadable_labels:
            self.log(f"⚠️  {unreadable_labels} label file(s) could not be read")

        if structure_info['image_dirs']:
            self.log(f"   📁 Image directories:")
            for dir_name, count in structure_info['image_dirs'].items():
                self.log(f"      {dir_name}: {count} images")

        if structure_info['label_dirs']:
            self.log(f"   📂 Label directories:")
            for dir_name, count in structure_info['label_dirs'].items():
                self.log(f"      {dir_name}: {count} labels")

        return structure_info

    def create_yolo_structure(self, base_path):
        """Create ``{train,val,test}/{images,labels}`` under ``base_path``.

        Returns:
            dict mapping keys such as ``'train_images'`` to directory paths.
        """
        yolo_dirs = {
            'train_images': os.path.join(base_path, 'train', 'images'),
            'train_labels': os.path.join(base_path, 'train', 'labels'),
            'val_images': os.path.join(base_path, 'val', 'images'),
            'val_labels': os.path.join(base_path, 'val', 'labels'),
            'test_images': os.path.join(base_path, 'test', 'images'),
            'test_labels': os.path.join(base_path, 'test', 'labels')
        }

        for dir_path in yolo_dirs.values():
            os.makedirs(dir_path, exist_ok=True)
            self.log(f"   📁 Created: {os.path.relpath(dir_path, base_path)}")

        return yolo_dirs

    def find_image_label_pairs(self, structure_info):
        """Pair each discovered image with its label file.

        Returns:
            list of ``(image_path, label_path, image_file_name)`` tuples.
            Images without a label are logged and left out.
        """
        self.log("🔍 Finding image-label pairs...")

        # Index labels once instead of scanning every label per image
        labels_by_stem = self._index_labels(structure_info['labels'])

        image_label_pairs = []
        unmatched_images = []

        for img_path in structure_info['images']:
            img_name = os.path.basename(img_path)
            label_path = self._match_label(img_path, labels_by_stem)
            if label_path and os.path.exists(label_path):
                image_label_pairs.append((img_path, label_path, img_name))
            else:
                unmatched_images.append(img_name)

        self.log(f"✅ Found {len(image_label_pairs)} valid image-label pairs")
        if unmatched_images:
            self.log(f"⚠️  {len(unmatched_images)} images without labels")
            if len(unmatched_images) <= 10:
                for img in unmatched_images:
                    self.log(f"      {img}")
            else:
                for img in unmatched_images[:5]:
                    self.log(f"      {img}")
                self.log(f"      ... and {len(unmatched_images) - 5} more")

        return image_label_pairs

    def split_dataset(self, image_label_pairs, train_ratio=0.7, val_ratio=0.2):
        """Randomly split pairs into train/val/test with a minimum validation size.

        The input list is not modified; a shuffled copy is split. Small inputs
        are handled specially: 1 pair goes to train only, and 2 pairs give one
        train and one val. Larger inputs get at least one val and, where it
        fits, one test sample.

        Returns:
            dict with keys ``'train'``, ``'val'``, ``'test'`` mapping to lists of pairs.
        """
        if len(image_label_pairs) == 0:
            return {'train': [], 'val': [], 'test': []}

        # Shuffle a copy so the caller's list keeps its order
        shuffled = list(image_label_pairs)
        random.shuffle(shuffled)
        total_pairs = len(shuffled)

        if total_pairs < 3:
            # 1 pair: nothing left for validation; 2 pairs: one train, one val
            splits = {'train': shuffled[:1], 'val': shuffled[1:], 'test': []}
        else:
            eps = self._RATIO_EPS
            # At least 1 and at most 10 validation samples (10% of the dataset in between)
            min_val_size = max(1, min(10, total_pairs // 10))
            val_size = max(min_val_size, int(total_pairs * val_ratio + eps))
            test_size = max(1, int(total_pairs * (1 - train_ratio - val_ratio) + eps))
            train_size = total_pairs - val_size - test_size

            # Ensure train_size is positive by dropping the test split
            if train_size < 1:
                train_size = total_pairs - val_size
                test_size = 0

            val_end = train_size + val_size
            splits = {
                'train': shuffled[:train_size],
                'val': shuffled[train_size:val_end],
                'test': shuffled[val_end:] if test_size > 0 else []
            }

        self.log(f"📊 Dataset split:")
        self.log(f"   🏋️  Training: {len(splits['train'])} samples ({len(splits['train'])/total_pairs*100:.1f}%)")
        self.log(f"   ✅ Validation: {len(splits['val'])} samples ({len(splits['val'])/total_pairs*100:.1f}%)")
        if splits['test']:
            self.log(f"   🧪 Test: {len(splits['test'])} samples ({len(splits['test'])/total_pairs*100:.1f}%)")

        return splits

    def copy_files_to_splits(self, splits, yolo_dirs):
        """Place each split's image and label files into the matching YOLO directories.

        Files already inside the YOLO output tree are moved rather than copied,
        so they cannot end up in two splits.

        Returns:
            bool: True if at least one image was placed.
        """
        self.log("📋 Copying files to YOLO structure...")

        managed_dirs = self._managed_dirs(yolo_dirs)
        total_copied = 0

        for split_name, pairs in splits.items():
            if not pairs:
                continue

            img_dir = yolo_dirs[f'{split_name}_images']
            lbl_dir = yolo_dirs[f'{split_name}_labels']
            split_copied = 0

            for img_path, lbl_path, img_name in pairs:
                if not os.path.exists(img_path):
                    self.log(f"⚠️  Image not found: {img_path}")
                    continue
                try:
                    self._place_file(img_path, img_dir, managed_dirs)
                    split_copied += 1

                    if lbl_path and os.path.exists(lbl_path):
                        self._place_file(lbl_path, lbl_dir, managed_dirs)
                    else:
                        self.log(f"⚠️  Label not found for: {img_name}")
                except OSError as e:
                    self.log(f"❌ Error copying {img_name}: {e}")

            total_copied += split_copied
            self.log(f"   {split_name}: {split_copied} files copied")

        return total_copied > 0

    def reorganize_dataset(self, extract_path, structure_info):
        """Lay out the analyzed dataset in YOLO format under ``extract_path``.

        Keeps existing train/val directories if present. Otherwise creates a
        random train/val/test split. Finishes by verifying the result.

        Returns:
            bool: True if the final structure is usable for training.
        """
        self.log("🔄 Reorganizing dataset to YOLO format...")

        yolo_dirs = self.create_yolo_structure(extract_path)

        image_label_pairs = self.find_image_label_pairs(structure_info)
        if len(image_label_pairs) == 0:
            self.log("❌ No valid image-label pairs found!")
            return False

        if self.check_existing_splits(structure_info):
            self.log("✅ Preserving existing train/val splits")
            success = self.preserve_existing_splits(structure_info, yolo_dirs)
        else:
            self.log("🔀 Creating new train/val/test splits")
            splits = self.split_dataset(image_label_pairs)
            success = self.copy_files_to_splits(splits, yolo_dirs)

        if not success:
            return False

        return self.verify_dataset_structure(yolo_dirs)

    def check_existing_splits(self, structure_info):
        """Return True if image directories named like both 'train' and 'val'/'valid' exist."""
        dir_names = [name.lower() for name in structure_info['image_dirs']]
        has_train = any('train' in name for name in dir_names)
        has_val = any('val' in name for name in dir_names)  # 'val' also matches 'valid'
        return has_train and has_val

    def preserve_existing_splits(self, structure_info, yolo_dirs):
        """Place every image and its label into the split implied by its directory.

        Images in directories matching none of train/val/test go to train.

        Returns:
            bool: True if at least one image was placed.
        """
        root = structure_info.get('root')
        labels_by_stem = self._index_labels(structure_info['labels'])
        managed_dirs = self._managed_dirs(yolo_dirs)
        files_copied = 0

        for img_path in structure_info['images']:
            img_name = os.path.basename(img_path)
            if not os.path.exists(img_path):
                continue

            split = self._split_for_path(img_path, root)
            dest_img_dir = yolo_dirs[f'{split}_images']
            dest_lbl_dir = yolo_dirs[f'{split}_labels']

            try:
                self._place_file(img_path, dest_img_dir, managed_dirs)
                files_copied += 1

                lbl_path = self._match_label(img_path, labels_by_stem)
                if lbl_path and os.path.exists(lbl_path):
                    self._place_file(lbl_path, dest_lbl_dir, managed_dirs)
            except OSError as e:
                self.log(f"❌ Error copying {img_name}: {e}")

        self.log(f"✅ Copied {files_copied} files preserving splits")
        return files_copied > 0

    def verify_dataset_structure(self, yolo_dirs):
        """Check that the train/val image and label directories exist and are populated.

        An empty train directory fails verification. An empty val directory
        triggers ``create_validation_from_training``, whose result is returned
        immediately; the remaining directories are not re-checked.

        Returns:
            bool: True if the structure is usable for training.
        """
        self.log("🔍 Verifying dataset structure...")

        image_exts = tuple(self.image_extensions)
        required_dirs = ['train_images', 'train_labels', 'val_images', 'val_labels']

        for dir_name in required_dirs:
            dir_path = yolo_dirs[dir_name]

            if not os.path.exists(dir_path):
                self.log(f"❌ Missing directory: {dir_path}")
                return False

            # Case-insensitive, matching how files were classified during analysis
            wanted_exts = image_exts if 'images' in dir_name else ('.txt',)
            file_count = sum(1 for f in os.listdir(dir_path) if f.lower().endswith(wanted_exts))
            self.log(f"   ✅ {dir_name}: {file_count} files")

            if file_count == 0:
                if dir_name in ('train_images', 'train_labels'):
                    self.log(f"❌ Critical directory is empty: {dir_name}")
                    return False
                self.log(f"⚠️  Validation directory is empty: {dir_name}")
                return self.create_validation_from_training(yolo_dirs)

        self.log("✅ Dataset structure verification passed")
        return True

    def create_validation_from_training(self, yolo_dirs):
        """Move a random 20% of training samples (min 1, max 20) into the val split.

        Each moved image takes its ``<stem>.txt`` label along if one exists.

        Returns:
            bool: True if at least one sample was moved; False when training
            holds fewer than two images.
        """
        self.log("🔄 Creating validation set from training data...")

        train_images_dir = yolo_dirs['train_images']
        train_labels_dir = yolo_dirs['train_labels']
        val_images_dir = yolo_dirs['val_images']
        val_labels_dir = yolo_dirs['val_labels']

        image_exts = tuple(self.image_extensions)
        train_images = [f for f in os.listdir(train_images_dir) if f.lower().endswith(image_exts)]

        if len(train_images) < 2:
            self.log("❌ Not enough training images to create validation set")
            return False

        val_count = max(1, min(20, len(train_images) // 5))

        random.shuffle(train_images)
        val_images = train_images[:val_count]

        self.log(f"📦 Moving {len(val_images)} samples to validation...")

        moved_count = 0
        for img_file in val_images:
            src_img = os.path.join(train_images_dir, img_file)
            if not os.path.exists(src_img):
                continue

            shutil.move(src_img, os.path.join(val_images_dir, img_file))
            moved_count += 1

            label_file = os.path.splitext(img_file)[0] + '.txt'
            src_label = os.path.join(train_labels_dir, label_file)
            if os.path.exists(src_label):
                shutil.move(src_label, os.path.join(val_labels_dir, label_file))

        self.log(f"✅ Created validation set with {moved_count} samples")
        return moved_count > 0


class YOLO11Trainer:
    """GUI interface for YOLO11 training with improved dataset handling"""

    def __init__(self, root):
        """Build the YOLO11 trainer window inside the Tk ``root`` window.

        Sets the window title and size, then wires a ``DatasetManager`` to the
        GUI log. It creates the Tk variables, probes the GPU and builds all
        widgets.

        Raises:
            RuntimeError: if the module-level probe found no usable display.
        """
        if not GUI_AVAILABLE:
            raise RuntimeError("GUI not available - no display found")

        self.root = root
        window = root
        window.title("YOLO11 Training Tool")
        window.geometry("900x700")

        # Dataset operations report their progress into the GUI log pane
        self.dataset_manager = DatasetManager(self.log_message)

        # Widget-bound state (StringVars) plus plain Python bookkeeping
        self.data_path = tk.StringVar()
        self.selected_classes, self.all_classes = [], []
        self.device_info = tk.StringVar()

        # Fill device_info before the widgets that display it are created
        self.check_gpu()
        self.create_widgets()

    def check_gpu(self):
        """Detect CUDA devices and show a one-line summary in ``self.device_info``.

        At most two device names are listed, followed by a "(+N more)" suffix.
        A warning string is shown instead when PyTorch cannot be imported or
        the CUDA query raises; nothing is raised to the caller.
        """
        try:
            import torch
        except ImportError:
            self.device_info.set("⚠️ PyTorch not available")
            return

        try:
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                gpu_names = []
                for i in range(gpu_count):
                    try:
                        gpu_names.append(torch.cuda.get_device_name(i))
                    except Exception:
                        gpu_names.append(f"GPU {i}")

                device_text = f"✅ {gpu_count} GPU(s): {', '.join(gpu_names[:2])}"
                if len(gpu_names) > 2:
                    device_text += f" (+{len(gpu_names)-2} more)"
                # Only fill in CUDA_VISIBLE_DEVICES when the user has not set it.
                # device_count() counts devices *within* an existing mask, so an
                # existing "2,3" yields 2 devices. Overwriting it with "0,1" would
                # silently retarget child processes, which inherit os.environ, to
                # different physical GPUs.
                os.environ.setdefault('CUDA_VISIBLE_DEVICES', ','.join(map(str, range(gpu_count))))
            else:
                device_text = "⚠️ No GPU available, will use CPU"
        except Exception as e:
            device_text = f"⚠️ GPU check failed: {str(e)[:50]}..."

        self.device_info.set(device_text)

    def create_widgets(self):
        """Create and grid every widget of the main window.

        Each row of ``main_frame`` holds one part of the UI:

        0. device status label (bound to ``self.device_info``)
        1. ZIP path entry (bound to ``self.data_path``) and a Browse button
        2. "Load & Analyze Data" button
        3. multi-select class listbox with Select All / Clear Selection
        4. training parameter entries and a model combobox
        5. "Start Training" button, created disabled
        6. indeterminate progress bar
        7. scrolling training log

        Widgets and variables used by other methods are stored on ``self``:
        ``classes_listbox``, ``epochs_var``, ``imgsz_var``, ``batch_var``,
        ``model_var``, ``train_button``, ``progress`` and ``log_text``.
        Creation order also fixes keyboard focus traversal, so keep it stable.
        """
        # Main frame holding all content (the frame itself does not scroll)
        main_frame = ttk.Frame(self.root, padding="10")
        main_frame.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # GPU info
        ttk.Label(main_frame, text="Device Status:").grid(row=0, column=0, sticky=tk.W, pady=5)
        ttk.Label(main_frame, textvariable=self.device_info, wraplength=600).grid(row=0, column=1, columnspan=2, sticky=tk.W, pady=5)

        # File selection
        ttk.Label(main_frame, text="Training Data (ZIP):").grid(row=1, column=0, sticky=tk.W, pady=5)
        ttk.Entry(main_frame, textvariable=self.data_path, width=60).grid(row=1, column=1, sticky=(tk.W, tk.E), pady=5)
        ttk.Button(main_frame, text="Browse", command=self.select_zip_file).grid(row=1, column=2, pady=5, padx=(5,0))

        # Load data button
        ttk.Button(main_frame, text="🔍 Load & Analyze Data", command=self.load_data).grid(row=2, column=0, columnspan=3, pady=10)

        # Classes selection frame
        classes_frame = ttk.LabelFrame(main_frame, text="Select Objects to Train", padding="10")
        classes_frame.grid(row=3, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        # Classes listbox with scrollbar
        list_frame = ttk.Frame(classes_frame)
        list_frame.grid(row=0, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S))

        # MULTIPLE: each click toggles a row, so several classes can be chosen
        self.classes_listbox = tk.Listbox(list_frame, selectmode=tk.MULTIPLE, height=8)
        # Two-way link: the scrollbar drives yview and the listbox updates the scrollbar thumb
        scrollbar_classes = ttk.Scrollbar(list_frame, orient=tk.VERTICAL, command=self.classes_listbox.yview)
        self.classes_listbox.configure(yscrollcommand=scrollbar_classes.set)

        self.classes_listbox.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))
        scrollbar_classes.grid(row=0, column=1, sticky=(tk.N, tk.S))

        # Buttons for class selection
        ttk.Button(classes_frame, text="Select All", command=self.select_all_classes).grid(row=1, column=0, pady=5, padx=(0,5))
        ttk.Button(classes_frame, text="Clear Selection", command=self.clear_selection).grid(row=1, column=1, pady=5, padx=5)

        # Training parameters (kept as StringVars; parsing happens elsewhere)
        params_frame = ttk.LabelFrame(main_frame, text="Training Parameters", padding="10")
        params_frame.grid(row=4, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=5)

        # Row 0: Epochs and Image Size
        ttk.Label(params_frame, text="Epochs:").grid(row=0, column=0, sticky=tk.W)
        self.epochs_var = tk.StringVar(value="100")
        ttk.Entry(params_frame, textvariable=self.epochs_var, width=10).grid(row=0, column=1, sticky=tk.W, padx=5)

        ttk.Label(params_frame, text="Image Size:").grid(row=0, column=2, sticky=tk.W, padx=(20,0))
        self.imgsz_var = tk.StringVar(value="640")
        ttk.Entry(params_frame, textvariable=self.imgsz_var, width=10).grid(row=0, column=3, sticky=tk.W, padx=5)

        # Row 1: Batch Size and Model
        ttk.Label(params_frame, text="Batch Size:").grid(row=1, column=0, sticky=tk.W)
        self.batch_var = tk.StringVar(value="16")
        ttk.Entry(params_frame, textvariable=self.batch_var, width=10).grid(row=1, column=1, sticky=tk.W, padx=5)

        ttk.Label(params_frame, text="Model:").grid(row=1, column=2, sticky=tk.W, padx=(20,0))
        self.model_var = tk.StringVar(value="yolo11n.pt")
        # readonly: the user can only pick one of the listed checkpoint names
        model_combo = ttk.Combobox(params_frame, textvariable=self.model_var, width=12, state="readonly")
        model_combo['values'] = ("yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt")
        model_combo.grid(row=1, column=3, sticky=tk.W, padx=5)

        # Train button: disabled until load_data succeeds
        self.train_button = ttk.Button(main_frame, text="🚀 Start Training", command=self.start_training, state=tk.DISABLED)
        self.train_button.grid(row=5, column=0, columnspan=3, pady=20)

        # Progress and log
        self.progress = ttk.Progressbar(main_frame, mode='indeterminate')
        self.progress.grid(row=6, column=0, columnspan=3, sticky=(tk.W, tk.E), pady=5)

        # Log text area
        log_frame = ttk.LabelFrame(main_frame, text="Training Log", padding="5")
        log_frame.grid(row=7, column=0, columnspan=3, sticky=(tk.W, tk.E, tk.N, tk.S), pady=5)

        self.log_text = scrolledtext.ScrolledText(log_frame, height=12, width=80)
        self.log_text.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        # Configure grid weights for responsiveness: on resize, the path column (1),
        # the class list row (3) and the log row (7) absorb the extra space
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        main_frame.columnconfigure(1, weight=1)
        main_frame.rowconfigure(3, weight=1)
        main_frame.rowconfigure(7, weight=1)
        classes_frame.columnconfigure(0, weight=1)
        classes_frame.rowconfigure(0, weight=1)
        list_frame.columnconfigure(0, weight=1)
        list_frame.rowconfigure(0, weight=1)
        log_frame.columnconfigure(0, weight=1)
        log_frame.rowconfigure(0, weight=1)

    def log_message(self, message):
        """Append ``message`` to the log pane with an ``[HH:MM:SS]`` prefix.

        Scrolls the pane to the newest line. It then flushes pending idle tasks
        so the line appears right away, even while a long synchronous step
        (such as extraction) is running.
        NOTE(review): this touches Tk widgets directly, and Tk is not
        thread-safe. start_training runs train_model on a worker thread, so
        confirm that this method is only ever called from the main thread.
        """
        widget = self.log_text
        stamped_line = f"[{time.strftime('%H:%M:%S')}] {message}\n"
        widget.insert(tk.END, stamped_line)
        widget.see(tk.END)
        self.root.update_idletasks()

    def select_zip_file(self):
        """Ask the user to pick a ZIP archive and store its path in ``self.data_path``.

        The stored path is left unchanged when the dialog is cancelled.
        """
        chosen_path = filedialog.askopenfilename(
            title="Select Training Data ZIP File",
            filetypes=[("ZIP files", "*.zip"), ("All files", "*.*")],
        )
        if not chosen_path:
            return  # dialog cancelled
        self.data_path.set(chosen_path)

    def load_data(self):
        """Extract the selected ZIP, reorganize it to YOLO layout and list its classes.

        The archive is extracted into ``./temp_data``, replacing any previous
        contents. The dataset is then analyzed and reorganized with
        ``self.dataset_manager``. On success the class listbox gets one
        ``"<id>: <name>"`` row per class and "Start Training" is enabled. Any
        failure is logged and shown in an error dialog.
        """
        zip_path = self.data_path.get()
        if not zip_path:
            messagebox.showerror("Error", "Please select a ZIP file first")
            return

        # ./temp_data is wiped below, so results from a previous load are about
        # to become invalid. Reset them up front so a failure in this load cannot
        # leave an enabled Train button and a class list pointing at deleted data.
        self.train_button.config(state=tk.DISABLED)
        self.classes_listbox.delete(0, tk.END)
        self.all_classes = []

        try:
            self.log_message("🚀 Starting data loading process...")

            # Extract ZIP file into a fresh directory
            extract_path = "./temp_data"
            if os.path.exists(extract_path):
                shutil.rmtree(extract_path)
            os.makedirs(extract_path)

            self.log_message("📦 Extracting ZIP file...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)

            self.log_message("✅ ZIP file extracted successfully")

            structure_info = self.dataset_manager.analyze_dataset_structure(extract_path)

            if structure_info['total_images'] == 0:
                messagebox.showerror("Error", "No image files found in the dataset")
                return

            if structure_info['total_labels'] == 0:
                messagebox.showerror("Error", "No label files (.txt) found in the dataset")
                return

            if not self.dataset_manager.reorganize_dataset(extract_path, structure_info):
                messagebox.showerror("Error", "Failed to reorganize dataset properly")
                return

            all_class_ids = sorted(structure_info['class_ids'])
            if not all_class_ids:
                messagebox.showerror("Error", "No valid class labels found in dataset")
                return

            self.all_classes = self.load_class_names(extract_path, all_class_ids)

            # load_class_names returns one name per id, in the same order
            for class_id, class_name in zip(all_class_ids, self.all_classes):
                self.classes_listbox.insert(tk.END, f"{class_id}: {class_name}")

            self.log_message(f"✅ Found {len(self.all_classes)} object classes")
            self.train_button.config(state=tk.NORMAL)

        except Exception as e:
            error_msg = f"Failed to load data: {str(e)}"
            self.log_message(f"❌ {error_msg}")
            messagebox.showerror("Error", error_msg)

    def load_class_names(self, extract_path, all_class_ids):
        """Resolve a display name for every id in ``all_class_ids``.

        Walks ``extract_path`` and returns names from the first usable source:

        - a ``classes.txt`` with one name per line (blank lines dropped) and
          enough entries to cover the largest id;
        - a ``.yaml`` / ``.yml`` mapping whose ``names`` entry is a long-enough
          list, or a dict keyed by int or numeric string.

        Otherwise ``Class_<id>`` placeholders are returned.

        Args:
            extract_path: Root directory of the extracted dataset.
            all_class_ids: Non-empty sequence of integer class ids.

        Returns:
            list of names, parallel to ``all_class_ids``.
        """
        class_names = [f"Class_{i}" for i in all_class_ids]
        needed = max(all_class_ids) + 1  # list sources must cover the largest id

        for root, dirs, files in os.walk(extract_path):
            # Skip macOS archive artifacts and hidden dirs. Their "._*.yaml"
            # AppleDouble copies are binary, not YAML.
            dirs[:] = [d for d in dirs if d.lower() != '__macosx' and not d.startswith('.')]

            if 'classes.txt' in files:
                try:
                    # Explicit UTF-8: the locale default (e.g. cp949/cp1252 on
                    # Windows) mangles or rejects non-ASCII class names.
                    with open(os.path.join(root, 'classes.txt'), 'r', encoding='utf-8') as f:
                        names = [line.strip() for line in f if line.strip()]
                    if len(names) >= needed:
                        self.log_message("📝 Loaded class names from classes.txt")
                        return [names[i] for i in all_class_ids]
                except (OSError, UnicodeDecodeError) as e:
                    self.log_message(f"⚠️ Error reading classes.txt: {e}")

            for file in files:
                if file.startswith('.') or not file.endswith(('.yaml', '.yml')):
                    continue
                try:
                    with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                        data = yaml.safe_load(f)
                except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
                    self.log_message(f"⚠️ Error reading {file}: {e}")
                    continue

                # Empty files load as None; list/scalar documents carry no 'names'
                if not isinstance(data, dict) or 'names' not in data:
                    continue

                names = data['names']
                if isinstance(names, list) and len(names) >= needed:
                    self.log_message(f"📝 Loaded class names from {file}")
                    return [names[i] for i in all_class_ids]
                if isinstance(names, dict):
                    self.log_message(f"📝 Loaded class names from {file}")
                    # Keys are normally ints; also accept string keys such as "0"
                    return [names.get(i, names.get(str(i), f"Class_{i}")) for i in all_class_ids]

        self.log_message("📝 Using default class names")
        return class_names

    def select_all_classes(self):
        """Mark every row of the class listbox as selected."""
        listbox = self.classes_listbox
        # selection_set is tkinter's canonical name; select_set is its alias.
        listbox.selection_set(0, tk.END)

    def clear_selection(self):
        """Deselect every row of the class listbox."""
        listbox = self.classes_listbox
        # select_clear is tkinter's alias of selection_clear.
        listbox.select_clear(0, tk.END)

    def start_training(self):
        """Launch ``train_model`` on a daemon thread for the listbox selection.

        Each selected row reads "<label id>: <name>", so the label id is the
        text before the first ':'.  With nothing selected an error dialog is
        shown and nothing starts; otherwise the train button is disabled and
        the progress bar started before the worker thread begins.
        """
        listbox = self.classes_listbox
        chosen_ids = []
        for row in listbox.curselection():
            row_text = listbox.get(row)
            chosen_ids.append(int(row_text.split(':')[0]))

        if len(chosen_ids) == 0:
            messagebox.showerror("Error", "Please select at least one class to train")
            return

        # Lock the UI for the duration of the run.
        self.train_button.config(state=tk.DISABLED)
        self.progress.start()

        worker = threading.Thread(target=self.train_model, args=(chosen_ids,), daemon=True)
        worker.start()

    def train_model(self, selected_indices):
        """Train the YOLO model on the selected classes (runs on a worker thread).

        Steps: write dataset.yaml, remap/filter label files, verify the dataset,
        then call ``model.train``.  Every exit path hands the UI reset (stop the
        progress bar, re-enable the train button, show a dialog) to
        ``_finish_training``, which schedules it via ``self.root.after``.

        Args:
            selected_indices: raw label ids selected in the class listbox.
        """
        try:
            import torch
            from ultralytics import YOLO
        except ImportError as e:
            error_msg = f"Required packages not available: {e}"
            self.log_message(f"❌ {error_msg}")
            self._finish_training(messagebox.showerror, "Error", error_msg)
            return

        try:
            self.log_message(f"🚀 Starting training with {len(selected_indices)} classes...")

            # Create dataset configuration, then keep only the selected classes.
            class_mapping = self.create_dataset_yaml(selected_indices)
            self.filter_labels(class_mapping)

            if not self.final_dataset_check():
                error_msg = "Dataset verification failed"
                self.log_message(f"❌ {error_msg}")
                self._finish_training(messagebox.showerror, "Error", error_msg)
                return

            model_name = self.model_var.get()
            self.log_message(f"🤖 Loading {model_name} model...")
            model = YOLO(model_name)

            epochs = int(self.epochs_var.get())
            imgsz = int(self.imgsz_var.get())
            batch = int(self.batch_var.get())

            use_cuda = torch.cuda.is_available()
            if use_cuda:
                try:
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    if gpu_memory < 8 and batch > 8:
                        batch = 8
                        self.log_message(f"📉 Reduced batch size to {batch} for GPU memory")
                except Exception:
                    pass  # memory query is best-effort; keep the requested batch size

            self.log_message("🏃 Training started...")
            self.log_message(f"Parameters: epochs={epochs}, imgsz={imgsz}, batch={batch}")

            yaml_path = os.path.abspath('./dataset.yaml')
            self.log_message(f"📄 Using dataset config: {yaml_path}")

            results = model.train(
                data=yaml_path,
                epochs=epochs,
                imgsz=imgsz,
                batch=batch,
                device='0' if use_cuda else 'cpu',
                project='./runs/train',
                name='yolo11_custom',
                exist_ok=True,
                verbose=True,
                patience=20,
                save_period=max(10, epochs // 10)
            )

            # model.train() may return None (Ultralytics documents an optional
            # metrics result); fall back to the trainer's directory, then to the
            # configured project/name path, instead of failing a finished run.
            save_dir = getattr(results, 'save_dir', None)
            if save_dir is None:
                save_dir = (getattr(getattr(model, 'trainer', None), 'save_dir', None)
                            or './runs/train/yolo11_custom')

            self.log_message("🎉 Training completed successfully!")
            self.log_message(f"📁 Model saved to: {save_dir}")

            self._finish_training(messagebox.showinfo, "Success",
                                  f"Training completed!\nModel saved to: {save_dir}")

        except Exception as e:
            error_msg = f"Training failed: {str(e)}"
            self.log_message(f"❌ {error_msg}")
            self.log_message("💡 Troubleshooting tips:")
            self.log_message("• Check dataset paths and file permissions")
            self.log_message("• Try reducing batch size or image size")
            self.log_message("• Verify all images and labels are valid")
            self._finish_training(messagebox.showerror, "Error", error_msg)

    def _finish_training(self, show_dialog, title, message):
        """Re-enable the train button, stop the progress bar and show a dialog.

        The work is scheduled with ``self.root.after(0, ...)`` so it runs from
        the Tk event loop rather than directly on the calling (worker) thread.
        """
        def reset_and_report():
            self.train_button.config(state=tk.NORMAL)
            self.progress.stop()
            show_dialog(title, message)

        self.root.after(0, reset_and_report)

    def create_dataset_yaml(self, selected_indices):
        """Write ./dataset.yaml for the selected classes and return the label remapping.

        Args:
            selected_indices: raw label ids as shown in the class listbox
                (``start_training`` parses them from the "id: name" rows).
                They are NOT positions into ``self.all_classes`` once the
                dataset's ids are non-contiguous, so names are resolved per id.

        Returns:
            dict: {raw label id: new contiguous class id}.
        """
        # Deduplicate while keeping order so nc, names and the mapping agree.
        unique_ids = list(dict.fromkeys(selected_indices))
        class_mapping = {old_id: new_id for new_id, old_id in enumerate(unique_ids)}
        selected_names = [self._class_name_for_id(class_id) for class_id in unique_ids]

        temp_data_path = os.path.abspath('./temp_data')

        dataset_config = {
            'path': temp_data_path,
            'train': 'train/images',
            'val': 'val/images',
            'nc': len(unique_ids),
            'names': selected_names
        }
        # Only reference a test split that actually exists on disk.
        if os.path.isdir(os.path.join(temp_data_path, 'test', 'images')):
            dataset_config['test'] = 'test/images'

        yaml_path = './dataset.yaml'
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(dataset_config, f, default_flow_style=False)

        self.log_message(f"📄 Created dataset.yaml with {len(unique_ids)} classes")

        return class_mapping

    def _class_name_for_id(self, class_id):
        """Resolve the display name of a raw label id.

        Prefers an explicit {label id: name} dict on the instance
        (``class_names_by_id``) when one exists; otherwise treats the id as a
        position in ``self.all_classes`` when in range, else ``Class_<id>``.
        """
        by_id = getattr(self, 'class_names_by_id', None)
        if by_id and class_id in by_id:
            return by_id[class_id]
        if 0 <= class_id < len(self.all_classes):
            return self.all_classes[class_id]
        return f"Class_{class_id}"

    def filter_labels(self, class_mapping):
        """Rewrite the train/val/test label files so they only keep selected classes.

        Each split's ``labels`` folder under ./temp_data is walked recursively,
        so label files in nested sub-folders are remapped too.  ``classes.txt``
        (a name list, not an annotation file) is left untouched, and splits
        without a labels folder are skipped.

        NOTE(review): ``filter_label_file`` overwrites files in place, so a second
        training run in the same session sees already-remapped ids. Confirm the
        dataset is re-extracted before training a different class selection.

        Args:
            class_mapping: {original label id: new class id}.
        """
        self.log_message("🔄 Filtering labels for selected classes...")

        total_filtered = 0
        for split in ('train', 'val', 'test'):
            label_dir = os.path.join('./temp_data', split, 'labels')
            # os.walk yields nothing for a missing directory.
            for root, _dirs, files in os.walk(label_dir):
                for file in files:
                    if not file.endswith('.txt') or file == 'classes.txt':
                        continue
                    if self.filter_label_file(os.path.join(root, file), class_mapping):
                        total_filtered += 1

        self.log_message(f"✅ Filtered {total_filtered} label files")

    def filter_label_file(self, label_path, class_mapping):
        """Keep only annotations whose class id is in *class_mapping*, renumbering them.

        A line survives when it has at least five whitespace-separated fields
        and an integer first field that is a key of *class_mapping*; that field
        is replaced by the mapped id and the fields are re-joined with single
        spaces.  Blank, short or non-integer-class lines are dropped.  The file
        is overwritten in place.

        Returns:
            True on success; False (after logging a warning) if reading or
            writing failed.
        """
        try:
            with open(label_path, 'r') as handle:
                raw_lines = handle.readlines()

            kept = []
            for raw in raw_lines:
                fields = raw.split()
                if len(fields) < 5:
                    continue
                try:
                    old_id = int(fields[0])
                except ValueError:
                    continue
                if old_id not in class_mapping:
                    continue
                fields[0] = str(class_mapping[old_id])
                kept.append(' '.join(fields) + '\n')

            with open(label_path, 'w') as handle:
                handle.writelines(kept)
            return True

        except Exception as e:
            self.log_message(f"⚠️ Error filtering {label_path}: {e}")
            return False

    def final_dataset_check(self):
        """Last sanity check of ./temp_data and ./dataset.yaml before training.

        Passes only when the train/val image and label folders and dataset.yaml
        all exist, and both image folders contain at least one entry whose
        lower-cased name ends with one of
        ``self.dataset_manager.image_extensions``.  The first failure is logged
        and ends the check.

        Returns:
            bool: True if training may start.
        """
        self.log_message("🔍 Final dataset verification...")

        must_exist = (
            './temp_data/train/images',
            './temp_data/train/labels',
            './temp_data/val/images',
            './temp_data/val/labels',
            './dataset.yaml',
        )
        missing = next((p for p in must_exist if not os.path.exists(p)), None)
        if missing is not None:
            self.log_message(f"❌ Missing: {missing}")
            return False

        def count_images(folder):
            # Extensions are re-read per entry, exactly as before.
            return sum(
                1 for name in os.listdir(folder)
                if name.lower().endswith(tuple(self.dataset_manager.image_extensions))
            )

        train_count = count_images('./temp_data/train/images')
        val_count = count_images('./temp_data/val/images')

        if not train_count:
            self.log_message("❌ No training images found")
            return False
        if not val_count:
            self.log_message("❌ No validation images found")
            return False

        self.log_message(f"✅ Final check passed: {train_count} train, {val_count} val images")
        return True


class YOLO11TrainerCLI:
    """Command-line interface for YOLO11 training with improved dataset handling.

    Classes are selected by *position* in ``self.all_classes`` (0..N-1).
    ``self.all_class_ids`` holds, in the same order, the raw label ids found in
    the dataset.  Positions are translated to those ids when label files are
    remapped, so datasets whose ids are not contiguous (e.g. 0, 2, 5) keep the
    right annotations.
    """

    def __init__(self):
        self.all_classes = []
        # Raw label ids, parallel to self.all_classes (filled by load_data).
        self.all_class_ids = []
        self.data_path = ""
        self.dataset_manager = DatasetManager(print)
        self.check_gpu()

    def check_gpu(self):
        """Print which CUDA GPUs are available; never raises.

        CUDA_VISIBLE_DEVICES is deliberately left alone: torch's device count
        is already relative to any existing value, so rewriting it as "0..n-1"
        would silently remap a user-restricted GPU set (e.g. "2,3") for any
        child process.
        """
        try:
            import torch
        except ImportError:
            print("⚠️ PyTorch not available, cannot check GPU status")
            return

        try:
            if not torch.cuda.is_available():
                print("⚠️ No GPU available, will use CPU")
                return

            gpu_count = torch.cuda.device_count()
            gpu_names = []
            for i in range(gpu_count):
                try:
                    gpu_names.append(torch.cuda.get_device_name(i))
                except Exception:
                    gpu_names.append(f"GPU {i}")

            print(f"✅ {gpu_count} GPU(s) available: {', '.join(gpu_names)}")
        except Exception as e:
            print(f"⚠️ GPU check failed: {e}")

    def is_valid_zip(self, file_path):
        """Return True if *file_path* opens as a ZIP and every member passes its CRC check."""
        try:
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                bad_file = zip_ref.testzip()
                if bad_file:
                    print(f"⚠️ ZIP file contains corrupted file: {bad_file}")
                    return False
                return True
        except zipfile.BadZipFile:
            print("❌ File is not a valid ZIP file")
            return False
        except Exception as e:
            print(f"❌ Error reading ZIP file: {e}")
            return False

    def load_data(self, zip_path):
        """Extract *zip_path* into ./temp_data, reorganize it and read its classes.

        On success ``self.all_classes`` holds the class names and
        ``self.all_class_ids`` the matching raw label ids (ascending).

        Returns:
            bool: True when the dataset is ready for class selection.
        """
        self.data_path = zip_path

        if not self.is_valid_zip(zip_path):
            return False

        try:
            print("🚀 Starting data loading process...")

            # Start from a clean directory so files from an earlier run
            # (including already-remapped labels) cannot leak in.
            extract_path = "./temp_data"
            if os.path.exists(extract_path):
                shutil.rmtree(extract_path)
            os.makedirs(extract_path)

            print("📦 Extracting ZIP file...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_path)

            print("✅ ZIP file extracted successfully")

            structure_info = self.dataset_manager.analyze_dataset_structure(extract_path)

            if structure_info['total_images'] == 0:
                print("❌ No image files found in dataset")
                return False

            if structure_info['total_labels'] == 0:
                print("❌ No label files found in dataset")
                return False

            if not self.dataset_manager.reorganize_dataset(extract_path, structure_info):
                print("❌ Failed to reorganize dataset")
                return False

            all_class_ids = sorted(structure_info['class_ids'])
            if not all_class_ids:
                print("❌ No valid class labels found")
                return False

            self.all_classes = self.load_class_names(extract_path, all_class_ids)
            self.all_class_ids = all_class_ids

            print(f"✅ Successfully loaded dataset with {len(self.all_classes)} classes:")
            for class_id, class_name in zip(all_class_ids, self.all_classes):
                print(f"   {class_id:2d}: {class_name}")

            return True

        except Exception as e:
            print(f"❌ Failed to load data: {str(e)}")
            return False

    def load_class_names(self, extract_path, all_class_ids):
        """Return one display name per id in *all_class_ids*, in the same order.

        Searches *extract_path* recursively for a ``classes.txt`` (one name per
        non-blank line) or a YAML file whose top-level ``names`` entry is a list
        or an ``{id: name}`` dict.  The first usable source wins; directories and
        files are visited in sorted order so the choice is reproducible.  Falls
        back to ``Class_<id>`` names.
        """
        max_id = max(all_class_ids) if all_class_ids else -1

        for root, dirs, files in os.walk(extract_path):
            dirs.sort()  # deterministic traversal order

            if 'classes.txt' in files:
                try:
                    with open(os.path.join(root, 'classes.txt'), 'r', encoding='utf-8') as f:
                        names = [line.strip() for line in f if line.strip()]
                    if len(names) >= max_id + 1:
                        print("📝 Loaded class names from classes.txt")
                        return [names[i] for i in all_class_ids]
                except Exception as e:
                    print(f"⚠️ Error reading classes.txt: {e}")

            for file in sorted(files):
                if not file.endswith(('.yaml', '.yml')):
                    continue
                try:
                    with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                        data = yaml.safe_load(f)
                    # Empty or non-mapping YAML documents carry no class names.
                    if not isinstance(data, dict) or 'names' not in data:
                        continue
                    names = data['names']
                    if isinstance(names, list) and len(names) >= max_id + 1:
                        print(f"📝 Loaded class names from {file}")
                        return [names[i] for i in all_class_ids]
                    if isinstance(names, dict):
                        print(f"📝 Loaded class names from {file}")
                        # Accept both int keys (usual) and quoted string keys.
                        return [names.get(i, names.get(str(i), f"Class_{i}")) for i in all_class_ids]
                except Exception as e:
                    print(f"⚠️ Error reading {file}: {e}")

        print("📝 Using default class names")
        return [f"Class_{i}" for i in all_class_ids]

    def _class_ids(self):
        """Raw label id for each position of ``self.all_classes``.

        Falls back to the positions themselves when no parallel id list is
        available (e.g. ``all_classes`` was assigned directly).
        """
        ids = list(getattr(self, 'all_class_ids', None) or [])
        if len(ids) != len(self.all_classes):
            return list(range(len(self.all_classes)))
        return ids

    def _build_class_mapping(self, selected_indices):
        """Map raw label id -> new contiguous id for the selected positions.

        Duplicate positions are ignored (first occurrence wins) so the new ids
        stay contiguous and agree with the generated ``names`` list.
        """
        class_ids = self._class_ids()
        unique = list(dict.fromkeys(selected_indices))
        return {class_ids[pos]: new_id for new_id, pos in enumerate(unique)}

    def select_classes_interactive(self):
        """Prompt until at least one class is chosen; return sorted positional indices.

        Ctrl+C or end-of-input exits the process with status 0.  Handling
        EOFError matters: otherwise a closed stdin makes every ``input()`` fail
        and the prompt loops forever.
        """
        print(f"\n📋 Available Classes ({len(self.all_classes)} total):")

        class_ids = self._class_ids()
        for i, class_name in enumerate(self.all_classes):
            # Show the raw label id too when it differs from the position.
            suffix = f"  [label id {class_ids[i]}]" if class_ids[i] != i else ""
            print(f"   {i:2d}: {class_name}{suffix}")

        print("\n🎯 Select classes to train:")
        print("   • Type 'all' for all classes")
        print("   • Type numbers separated by commas (e.g., 0,2,5)")
        print("   • Type ranges with dashes (e.g., 0-5,8,10-12)")
        print("   • Press Ctrl+C to cancel")

        while True:
            try:
                selection = input("\n➤ Enter your selection: ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\n🚫 Selection cancelled by user.")
                sys.exit(0)

            if selection.lower() == 'all':
                selected_indices = list(range(len(self.all_classes)))
                break

            try:
                selected_indices = self.parse_selection(selection)
            except Exception as e:
                print(f"❌ Error: {e}. Please try again.")
                continue

            if selected_indices:
                break
            print("❌ No valid classes selected. Please try again.")

        print(f"\n✅ Selected {len(selected_indices)} classes:")
        for i in selected_indices:
            print(f"   {i:2d}: {self.all_classes[i]}")

        return selected_indices

    def parse_selection(self, selection):
        """Parse "0,2,5" / "0-5,8" style input into sorted, valid positional indices.

        Blank items (e.g. from a trailing comma) are skipped and a reversed
        range such as "5-2" is read as "2-5".  Malformed items and
        out-of-range indices are reported and ignored.
        """
        requested = set()
        for part in selection.split(','):
            part = part.strip()
            if not part:
                continue
            if '-' in part:
                try:
                    start, end = map(int, part.split('-'))
                except ValueError:
                    print(f"❌ Invalid range format: {part}")
                    continue
                requested.update(range(min(start, end), max(start, end) + 1))
            else:
                try:
                    requested.add(int(part))
                except ValueError:
                    print(f"❌ Invalid number: {part}")

        valid = {i for i in requested if 0 <= i < len(self.all_classes)}
        invalid = requested - valid
        if invalid:
            print(f"⚠️ Invalid indices ignored: {sorted(invalid)}")

        return sorted(valid)

    def get_training_parameters(self):
        """Prompt for epochs, image size, batch size and model; return them as a dict.

        Empty answers keep the defaults shown in parentheses; any non-integer
        answer resets *all* values to the defaults.  Ctrl+C or end-of-input
        exits the process with status 0.
        """
        print("\n⚙️ Training Configuration:")
        print("Press Enter to use default values shown in parentheses")

        defaults = {'epochs': 100, 'imgsz': 640, 'batch': 16, 'model': 'yolo11n.pt'}
        models = ["yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"]
        params = {}

        try:
            epochs_input = input("Epochs (100): ").strip()
            params['epochs'] = int(epochs_input) if epochs_input else defaults['epochs']

            imgsz_input = input("Image size (640): ").strip()
            params['imgsz'] = int(imgsz_input) if imgsz_input else defaults['imgsz']

            batch_input = input("Batch size (16): ").strip()
            params['batch'] = int(batch_input) if batch_input else defaults['batch']

            print("\nAvailable models:")
            for i, model in enumerate(models):
                print(f"   {i}: {model}")

            model_input = input("Select model (0 for nano): ").strip()
            model_idx = int(model_input) if model_input else 0
            params['model'] = models[model_idx] if 0 <= model_idx < len(models) else models[0]

        except ValueError:
            print("⚠️ Invalid input detected. Using default parameters...")
            params = dict(defaults)
        except (KeyboardInterrupt, EOFError):
            print("\n🚫 Configuration cancelled by user.")
            sys.exit(0)

        return params

    def train_model(self, selected_indices, params):
        """Train a YOLO model on the selected classes.

        Args:
            selected_indices: positional indices into ``self.all_classes``.
            params: dict with 'epochs', 'imgsz', 'batch' and 'model'.  It is
                copied, so the caller's dict is not modified.

        Returns:
            bool: True if training finished, False otherwise (errors are printed).
        """
        try:
            import torch
            from ultralytics import YOLO
        except ImportError as e:
            print(f"❌ Required packages not available: {e}")
            return False

        # Work on a copy: the batch size may be lowered below.
        params = dict(params)

        try:
            print(f"🚀 Starting training with {len(selected_indices)} classes...")

            class_mapping = self.create_dataset_yaml(selected_indices)
            self.filter_labels(class_mapping)

            if not self.final_dataset_check():
                print("❌ Dataset verification failed")
                return False

            print(f"🤖 Loading {params['model']} model...")
            model = YOLO(params['model'])

            use_cuda = torch.cuda.is_available()
            if use_cuda:
                try:
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
                    if gpu_memory < 8 and params['batch'] > 8:
                        params['batch'] = 8
                        print(f"📉 Reduced batch size to {params['batch']} for GPU memory")
                except Exception:
                    pass  # memory query is best-effort; keep the requested batch size

            print("🏃 Training started...")
            print(f"Parameters: epochs={params['epochs']}, imgsz={params['imgsz']}, batch={params['batch']}")

            yaml_path = os.path.abspath('./dataset.yaml')
            print(f"📄 Using dataset config: {yaml_path}")

            results = model.train(
                data=yaml_path,
                epochs=params['epochs'],
                imgsz=params['imgsz'],
                batch=params['batch'],
                device='0' if use_cuda else 'cpu',
                project='./runs/train',
                name='yolo11_custom',
                exist_ok=True,
                verbose=True,
                patience=20,
                save_period=max(10, params['epochs'] // 10)
            )

            print("🎉 Training completed successfully!")
            print(f"📁 Model saved to: {self._training_save_dir(results, model)}")
            return True

        except Exception as e:
            print(f"❌ Training failed: {str(e)}")
            print("\n🔧 Troubleshooting tips:")
            print("• Check dataset paths and file permissions")
            print("• Try reducing batch size or image size")
            print("• Verify all images and labels are valid")
            print("• Check available disk space")
            return False

    @staticmethod
    def _training_save_dir(results, model):
        """Best-effort output directory of a finished ``model.train()`` call.

        ``model.train()`` may return None (Ultralytics documents an optional
        metrics result), so fall back to the trainer's ``save_dir`` and finally
        to the configured project/name path.
        """
        save_dir = getattr(results, 'save_dir', None)
        if save_dir is None:
            save_dir = getattr(getattr(model, 'trainer', None), 'save_dir', None)
        return save_dir if save_dir is not None else './runs/train/yolo11_custom'

    def create_dataset_yaml(self, selected_indices):
        """Write ./dataset.yaml for the selected classes and return the label remapping.

        Args:
            selected_indices: positional indices into ``self.all_classes``;
                duplicates are ignored.

        Returns:
            dict: {raw label id: new contiguous class id}.
        """
        unique = list(dict.fromkeys(selected_indices))
        class_mapping = self._build_class_mapping(unique)
        selected_names = [self.all_classes[i] for i in unique]

        data_root = os.path.abspath('./temp_data')
        dataset_config = {
            'path': data_root,
            'train': 'train/images',
            'val': 'val/images',
            'nc': len(unique),
            'names': selected_names
        }
        # Only reference a test split that actually exists on disk.
        if os.path.isdir(os.path.join(data_root, 'test', 'images')):
            dataset_config['test'] = 'test/images'

        with open('./dataset.yaml', 'w', encoding='utf-8') as f:
            yaml.dump(dataset_config, f, default_flow_style=False)

        print(f"📄 Created dataset.yaml with {len(unique)} classes")
        return class_mapping

    def filter_labels(self, class_mapping):
        """Remap/filter every label file of the train, val and test splits.

        Walks each split's ``labels`` folder recursively rather than every
        directory whose path merely contains "labels".  ``classes.txt`` (a name
        list, not annotations) is skipped.
        """
        print("🔄 Filtering labels...")

        total_filtered = 0
        for split in ('train', 'val', 'test'):
            label_dir = os.path.join('./temp_data', split, 'labels')
            # os.walk yields nothing for a missing directory.
            for root, _dirs, files in os.walk(label_dir):
                for file in files:
                    if not file.endswith('.txt') or file == 'classes.txt':
                        continue
                    if self.filter_label_file(os.path.join(root, file), class_mapping):
                        total_filtered += 1

        print(f"✅ Filtered {total_filtered} label files")

    def filter_label_file(self, label_path, class_mapping):
        """Keep only annotations whose class id is in *class_mapping*, renumbering them.

        Lines need at least five fields and an integer class id; others are
        dropped.  The file is overwritten in place.

        Returns:
            bool: True on success, False if reading/writing failed (a warning is printed).
        """
        try:
            with open(label_path, 'r') as f:
                lines = f.readlines()

            filtered_lines = []
            for line in lines:
                parts = line.split()
                if len(parts) < 5:
                    continue
                try:
                    class_id = int(parts[0])
                except ValueError:
                    continue
                if class_id in class_mapping:
                    parts[0] = str(class_mapping[class_id])
                    filtered_lines.append(' '.join(parts) + '\n')

            with open(label_path, 'w') as f:
                f.writelines(filtered_lines)
            return True

        except Exception as e:
            print(f"⚠️ Error filtering {label_path}: {e}")
            return False

    def final_dataset_check(self):
        """Last check before training: required paths, non-empty splits, dataset.yaml keys.

        Returns:
            bool: True if training may start.
        """
        print("🔍 Final dataset verification...")

        required_paths = [
            './temp_data/train/images',
            './temp_data/train/labels',
            './temp_data/val/images',
            './temp_data/val/labels',
            './dataset.yaml'
        ]

        for path in required_paths:
            if not os.path.exists(path):
                print(f"❌ Missing: {path}")
                return False

        image_suffixes = tuple(self.dataset_manager.image_extensions)
        train_images = len([f for f in os.listdir('./temp_data/train/images')
                            if f.lower().endswith(image_suffixes)])
        train_labels = len([f for f in os.listdir('./temp_data/train/labels') if f.endswith('.txt')])
        val_images = len([f for f in os.listdir('./temp_data/val/images')
                          if f.lower().endswith(image_suffixes)])
        val_labels = len([f for f in os.listdir('./temp_data/val/labels') if f.endswith('.txt')])

        print("📊 Dataset summary:")
        print(f"   🏋️ Training: {train_images} images, {train_labels} labels")
        print(f"   ✅ Validation: {val_images} images, {val_labels} labels")

        if train_images == 0:
            print("❌ No training images found")
            return False

        if val_images == 0:
            print("❌ No validation images found")
            return False

        try:
            with open('./dataset.yaml', 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except Exception as e:
            print(f"❌ Error reading dataset.yaml: {e}")
            return False

        if not isinstance(config, dict):
            print("❌ Error reading dataset.yaml: not a mapping")
            return False

        for key in ('path', 'train', 'val', 'nc', 'names'):
            if key not in config:
                print(f"❌ Missing key in dataset.yaml: {key}")
                return False

        print("✅ Final verification passed - ready for training!")
        return True

    def run_interactive(self):
        """Run the interactive flow: ZIP path, load, class selection, parameters, confirm, train.

        Ctrl+C or end-of-input at the ZIP-path or confirmation prompt cancels
        cleanly instead of raising.
        """
        print("\n🚀 YOLO11 Training Tool (Interactive Mode)")
        print("=" * 50)

        while True:
            print("\n📁 Training Data:")
            try:
                zip_path = input("Enter path to ZIP file: ").strip().strip('"\'')
            except (KeyboardInterrupt, EOFError):
                print("\n🚫 Training cancelled by user.")
                return

            if not os.path.exists(zip_path):
                print("❌ File not found. Please try again.")
                continue

            if not zip_path.lower().endswith('.zip'):
                print("❌ Please provide a ZIP file.")
                continue

            break

        if not self.load_data(zip_path):
            print("❌ Failed to load data. Exiting.")
            return

        selected_indices = self.select_classes_interactive()
        params = self.get_training_parameters()

        print("\n📋 Training Summary:")
        print(f"   📊 Dataset: {os.path.basename(zip_path)}")
        print(f"   🎯 Classes: {len(selected_indices)} selected")
        print(f"   🤖 Model: {params['model']}")
        print(f"   📈 Epochs: {params['epochs']}")
        print(f"   🖼️ Image size: {params['imgsz']}")
        print(f"   📦 Batch size: {params['batch']}")

        print("\n⚡ Ready to start training!")
        try:
            confirm = input("Continue? (y/N): ").strip().lower()
            if confirm in ['y', 'yes']:
                success = self.train_model(selected_indices, params)
                if success:
                    print("\n🎉 Training completed successfully!")
                    print("📁 Check './runs/train/yolo11_custom' for results")
                else:
                    print("\n❌ Training failed.")
            else:
                print("🚫 Training cancelled.")
        except (KeyboardInterrupt, EOFError):
            print("\n🚫 Training cancelled by user.")


def train_yolo_simple(zip_path=None, classes="all", epochs=100, imgsz=640, batch=16, model="yolo11n.pt"):
    """
    Simple function for Jupyter/Colab environments with robust dataset handling.

    Args:
        zip_path: Path to the dataset ZIP file. ``None`` starts the interactive
            CLI instead.
        classes: ``"all"``, a comma-separated string such as ``"0,2,5"``, or a
            list/tuple of integers. Values are positions in the loaded class
            list (``0 .. len(cli.all_classes) - 1``). Out-of-range or
            non-integer entries are ignored with a warning; duplicates are
            dropped, keeping the first occurrence.
        epochs, imgsz, batch, model: Training settings forwarded to
            ``YOLO11TrainerCLI.train_model``.

    Returns:
        None. Progress and failures are reported with ``print``.
    """
    import numbers  # local: only used to validate list/tuple class input

    print("🚀 YOLO11 Simple Training")
    print("=" * 30)

    if not check_and_install_packages():
        print("❌ Failed to install required packages")
        return

    cli = YOLO11TrainerCLI()

    if zip_path is None:
        print("Interactive mode - please provide input when prompted")
        cli.run_interactive()
        return

    if not os.path.exists(zip_path):
        print(f"❌ File not found: {zip_path}")
        return

    if not zip_path.lower().endswith('.zip'):
        print(f"❌ File is not a ZIP file: {zip_path}")
        return

    if not cli.load_data(zip_path):
        print("❌ Failed to load data")
        return

    num_classes = len(cli.all_classes)

    # Parse classes into a list of candidate indices.
    if isinstance(classes, str):
        if classes.strip().lower() == "all":
            requested = list(range(num_classes))
        else:
            try:
                # Blank items (e.g. a trailing comma) are skipped.
                requested = [int(x) for x in classes.split(',') if x.strip()]
            except ValueError:
                print("❌ Invalid class format. Use 'all' or '0,1,2'")
                return
    elif isinstance(classes, (list, tuple)):
        requested = list(classes)
    else:
        print("❌ Classes must be 'all', '0,1,2', or [0,1,2]")
        return

    # Keep integer indices in range; numbers.Integral also admits numpy ints.
    in_range = [int(i) for i in requested
                if isinstance(i, numbers.Integral) and 0 <= i < num_classes]
    if len(in_range) != len(requested):
        print("⚠️ Some class indices were invalid and ignored")

    # A repeated index would give duplicate names and a wrong class count.
    valid_indices = list(dict.fromkeys(in_range))

    if not valid_indices:
        print("❌ No valid class indices provided")
        return

    print(f"✅ Selected classes: {[cli.all_classes[i] for i in valid_indices]}")

    params = {
        'epochs': epochs,
        'imgsz': imgsz,
        'batch': batch,
        'model': model
    }

    success = cli.train_model(valid_indices, params)

    if success:
        print("\n🎉 Training completed successfully!")
        print("📁 Check './runs/train/yolo11_custom' for results")
    else:
        print("\n❌ Training failed. Check error messages above.")


def is_jupyter_environment():
    """Return True when IPython reports an active shell, else False.

    ``get_ipython()`` returns a shell object inside Jupyter/Colab kernels
    (and also in a plain IPython terminal); it returns None otherwise.
    """
    try:
        from IPython import get_ipython
        shell = get_ipython()
    except ImportError:
        # IPython is not installed, so no IPython shell can be running.
        return False
    return shell is not None


def main():
    """Main entry point: install dependencies, then launch notebook, CLI or GUI mode.

    Mode selection:

    * Jupyter/Colab -> interactive CLI (command-line arguments are not parsed).
    * ``--cli``, or no display available and ``--gui`` not given -> CLI.
      It runs non-interactively when both ``--zip`` and ``--classes`` are
      supplied, and interactively otherwise.
    * Every other case -> Tkinter GUI. If ``--gui`` was forced but no display
      is available, an error is reported instead.

    Unparseable arguments fall back to the interactive CLI. ``--help`` only
    prints usage and returns.
    """
    print("🚀 YOLO11 Training Tool - Initializing...")

    # Check and install packages first
    if not check_and_install_packages():
        print("❌ Failed to set up required packages. Exiting.")
        return

    # Check environment
    if is_jupyter_environment():
        print("🔬 Detected Jupyter/Colab environment")
        print("💡 Use train_yolo_simple() function for easy training:")
        print("   train_yolo_simple('/path/to/data.zip', classes='all', epochs=100)")
        print("\n📋 Starting interactive CLI mode...")
        cli = YOLO11TrainerCLI()
        cli.run_interactive()
        return

    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description='YOLO11 Training Tool - Fixed dataset handling version',
        epilog='''
Examples:
  python yolo11_trainer.py --cli
  python yolo11_trainer.py --zip data.zip --classes all --epochs 50
  python yolo11_trainer.py --zip data.zip --classes 0,2,5 --epochs 100
        ''',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument('--gui', action='store_true', help='Force GUI mode')
    parser.add_argument('--cli', action='store_true', help='Force CLI mode')
    parser.add_argument('--zip', type=str, help='Path to training data ZIP file')
    parser.add_argument('--classes', type=str, help='Class indices ("all" or "0,1,2")')
    parser.add_argument('--epochs', type=int, default=100, help='Training epochs')
    parser.add_argument('--imgsz', type=int, default=640, help='Image size')
    parser.add_argument('--batch', type=int, default=16, help='Batch size')
    parser.add_argument('--model', type=str, default='yolo11n.pt', help='YOLO model')

    try:
        args = parser.parse_args()
    except SystemExit as exc:
        if exc.code in (0, None):
            # --help exits with status 0 after printing usage. Honour that
            # instead of starting an interactive session on top of it.
            return
        # Invalid arguments (argparse already printed the error): fall back.
        print("Starting interactive CLI mode...")
        cli = YOLO11TrainerCLI()
        cli.run_interactive()
        return

    use_cli = args.cli or (not GUI_AVAILABLE and not args.gui)

    if not use_cli:
        # GUI mode. This is the only remaining case, so no extra fallback
        # branch is needed.
        if not GUI_AVAILABLE:
            print("❌ GUI not available. Use --cli flag.")
            return
        root = tk.Tk()
        app = YOLO11Trainer(root)  # keep a reference while the main loop runs
        root.mainloop()
        return

    # CLI mode
    cli = YOLO11TrainerCLI()

    if not (args.zip and args.classes):
        # Interactive mode
        cli.run_interactive()
        return

    # Non-interactive mode
    if not cli.load_data(args.zip):
        print("❌ Failed to load data")
        return

    if args.classes.strip().lower() == 'all':
        selected_indices = list(range(len(cli.all_classes)))
    else:
        selected_indices = cli.parse_selection(args.classes)

    if not selected_indices:
        print("❌ No valid class indices provided")
        return

    params = {
        'epochs': args.epochs,
        'imgsz': args.imgsz,
        'batch': args.batch,
        'model': args.model
    }
    cli.train_model(selected_indices, params)


if __name__ == "__main__":
    main()
elif is_jupyter_environment():
    # Imported as a module inside a notebook: show the quick-start hints.
    for hint_line in (
        "🔬 YOLO11 Training Tool loaded in Jupyter/Colab",
        "💡 Quick start:",
        "   train_yolo_simple('/path/to/data.zip', classes='all', epochs=100)",
        "   train_yolo_simple('/path/to/data.zip', classes='0,2,5', epochs=50)",
    ):
        print(hint_line)

Trained Model Improvement with Additional Epochs

In [None]:
"""
YOLO11 Model Improvement - Strategy 1: Optimized Fine-tuning (FIXED)
===================================================================

This script continues training your existing YOLO11 model with optimized settings
to improve performance, especially for Traffic Light and Central Line detection.

Current Performance:
- Overall mAP50: 77.2%
- Traffic Light: 57.6% mAP50 (CRITICAL - needs major improvement)
- Central Line: 75.4% mAP50 (needs improvement)

Target Performance:
- Overall mAP50: 83%+
- Traffic Light: 76%+ mAP50
- Central Line: 83%+ mAP50
"""

import os
import time
from ultralytics import YOLO

def print_current_status():
    """Print the baseline metrics of the current model next to the improvement targets."""
    divider = "=" * 60
    report_lines = (
        divider,
        "🚀 YOLO11 Traffic Detection Model Improvement",
        divider,
        "📊 Current Performance:",
        "   Overall mAP50: 77.2%",
        "   Overall Recall: 65.1%",
        "   Traffic Light: 57.6% mAP50 (🚨 CRITICAL)",
        "   Central Line: 75.4% mAP50 (⚠️ POOR)",
        "   Lane: 85.8% mAP50 (✅ GOOD)",
        "",
        "🎯 Target Goals:",
        "   Overall mAP50: 83%+",
        "   Overall Recall: 75%+",
        "   Traffic Light: 76%+ mAP50",
        "   Central Line: 83%+ mAP50",
        divider,
    )
    for text in report_lines:
        print(text)

def verify_model_exists(model_path='./runs/train/yolo11_custom/weights/best.pt'):
    """Check that the trained weights file to fine-tune from is present.

    Args:
        model_path: Path to the ``.pt`` weights file. Defaults to the best
            checkpoint of the original ``yolo11_custom`` training run.

    Returns:
        True if ``model_path`` is an existing regular file. Otherwise False,
        after printing alternative locations to check.
    """
    # isfile rather than exists: a directory with this name is not a
    # loadable checkpoint.
    if not os.path.isfile(model_path):
        print("❌ ERROR: Trained model not found!")
        print(f"   Expected location: {model_path}")
        print("")
        print("💡 Alternative locations to check:")
        print("   • ./runs/train/yolo11_custom/weights/last.pt")
        print("   • ./runs/train/*/weights/best.pt")
        print("   • ./best.pt")
        return False

    print(f"✅ Found trained model: {model_path}")
    return True

def _load_trained_model():
    """Load the checkpoint to fine-tune: ``best.pt``, falling back to ``last.pt``.

    Returns:
        The loaded ``YOLO`` model, or ``None`` if neither checkpoint loads.
    """
    weights_dir = './runs/train/yolo11_custom/weights'
    try:
        model = YOLO(f'{weights_dir}/best.pt')
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        print("💡 Try using 'last.pt' instead:")
    else:
        print("✅ Model loaded successfully!")
        return model

    try:
        model = YOLO(f'{weights_dir}/last.pt')
    except Exception:
        # Narrower than a bare except, so Ctrl+C (KeyboardInterrupt) still
        # stops the notebook.
        print("❌ Could not load model. Please check the path.")
        return None
    print("✅ Model loaded from last.pt!")
    return model


def _select_training_device():
    """Return ``'0'`` (first CUDA GPU) when CUDA is available, otherwise ``'cpu'``."""
    import torch
    return '0' if torch.cuda.is_available() else 'cpu'


def train_strategy_1_optimized():
    """
    Strategy 1: Optimized Fine-tuning for Traffic Detection

    This strategy focuses on:
    1. Lower learning rate for fine-tuning
    2. Enhanced augmentation for small objects (traffic lights)
    3. Optimized loss weights for better localization
    4. Extended training with patience

    Returns:
        The metrics object returned by ``model.train``, or ``None`` if the
        starting checkpoint is missing or cannot be loaded, or if training
        raises.
    """

    print_current_status()

    # Verify model exists
    if not verify_model_exists():
        return None

    print("\n🔄 Loading your trained model...")
    model = _load_trained_model()
    if model is None:
        return None

    # A hard-coded device='0' makes Ultralytics raise on machines without
    # CUDA. Fall back to the CPU instead.
    device = _select_training_device()

    print("\n🚀 Starting Strategy 1: Optimized Fine-tuning...")
    print("⏱️  Estimated training time: ~45-60 minutes")
    print("📊 Training for 200 additional epochs")

    start_time = time.time()

    try:
        results = model.train(
            # 📁 CORE CONFIGURATION
            data='./dataset.yaml',              # Your existing dataset
            epochs=200,                         # Extended training (200 MORE epochs)
            imgsz=640,                          # Image size
            batch=16,                           # Batch size
            device=device,                      # GPU '0' if available, else 'cpu'
            project='./runs/train',             # Project directory
            name='yolo11_traffic_optimized',    # New experiment name
            exist_ok=True,                      # Overwrite if exists

            # 🎯 LEARNING RATE OPTIMIZATION (Fine-tuning settings)
            lr0=0.0003,                         # Lower initial learning rate for fine-tuning
            lrf=0.005,                          # Very low final learning rate (0.5% of initial)
            momentum=0.9,                       # Momentum for SGD optimizer
            weight_decay=0.0005,                # Weight decay for regularization
            warmup_epochs=5,                    # Warmup epochs for stable start
            warmup_momentum=0.8,                # Warmup momentum
            warmup_bias_lr=0.1,                 # Warmup bias learning rate

            # ⏰ TRAINING CONTROL
            patience=50,                        # Early stopping patience (more than default)
            close_mosaic=10,                    # Epochs to close mosaic augmentation

            # 🔧 LOSS FUNCTION WEIGHTS (Optimized for traffic detection)
            box=10.0,                           # Box loss weight (INCREASED for better localization)
            cls=1.0,                            # Classification loss weight (balanced)
            dfl=2.0,                            # Distribution focal loss weight (better edges)

            # 📈 DATA AUGMENTATION (Optimized for road scenes and small objects)
            hsv_h=0.015,                        # Hue augmentation (slight color variation)
            hsv_s=0.8,                          # Saturation augmentation (strong - traffic lights vary)
            hsv_v=0.6,                          # Value/brightness augmentation (day/night variation)
            degrees=10,                         # Rotation augmentation (roads don't rotate much)
            translate=0.15,                     # Translation augmentation (camera movement)
            scale=0.8,                          # Scale augmentation (distance changes)
            shear=1.0,                          # Shear augmentation (minimal for road scenes)
            perspective=0.0,                    # Perspective augmentation (disabled)
            flipud=0.0,                         # Vertical flip (disabled - roads don't flip)
            fliplr=0.3,                         # Horizontal flip (some roads are bidirectional)

            # 🎯 ADVANCED AUGMENTATION (Critical for small object detection)
            mosaic=1.0,                         # Mosaic augmentation (always enabled)
            mixup=0.2,                          # Mixup augmentation (increased for variety)
            copy_paste=0.4,                     # Copy-paste augmentation (helps small objects)

            # 🔍 DETECTION SETTINGS
            conf=0.001,                         # Confidence threshold for NMS
            iou=0.6,                            # IoU threshold for NMS
            max_det=300,                        # Maximum detections per image

            # 📊 TRAINING SETTINGS
            workers=8,                          # Number of worker threads
            seed=0,                             # Random seed for reproducibility
            deterministic=True,                 # Deterministic training
            single_cls=False,                   # Multi-class training
            rect=False,                         # Rectangular training (disabled for augmentation)
            cos_lr=True,                        # Cosine learning rate scheduler

            # 💾 SAVING AND MONITORING
            save=True,                          # Save checkpoints
            save_period=20,                     # Save every 20 epochs
            cache=False,                        # Cache images to RAM (disabled to save memory)
            plots=True,                         # Generate training plots
            overlap_mask=True,                  # Overlap mask for segmentation
            mask_ratio=4,                       # Mask downsample ratio
            dropout=0.0,                        # Dropout (disabled)
            val=True,                           # Validate during training
            split='val',                        # Validation split
            verbose=True,                       # Verbose output

            # 🎛️ OPTIMIZER SETTINGS
            optimizer='SGD',                    # Optimizer type
            amp=True,                           # Automatic Mixed Precision
            fraction=1.0,                       # Dataset fraction to use
            profile=False,                      # Profile ONNX and TensorRT speeds
            freeze=None,                        # Freeze layers
            multi_scale=False,                  # Multi-scale training

            # 🎯 ADDITIONAL SETTINGS
            nbs=64,                             # Nominal batch size
        )
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        print("\n🔧 Troubleshooting tips:")
        print("   • Check GPU memory (reduce batch size if needed)")
        print("   • Verify dataset.yaml exists and is correct")
        print("   • Ensure sufficient disk space")
        print("   • Try reducing image size to 416 if memory issues")
        return None

    training_time = time.time() - start_time

    # Take the run directory from the trainer rather than from the returned
    # metrics object. NOTE(review): not every Ultralytics version exposes
    # ``save_dir`` on the metrics. Reading it there could raise after a
    # successful run. The fallback is project/name from the call above
    # (exist_ok=True keeps that path stable).
    trainer = getattr(model, 'trainer', None)
    save_dir = getattr(trainer, 'save_dir', None) or './runs/train/yolo11_traffic_optimized'

    print("\n" + "=" * 60)
    print("🎉 TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print(f"⏱️  Training time: {training_time/3600:.1f} hours")
    print(f"📁 Results saved to: {save_dir}")
    print(f"🏆 Best model: {save_dir}/weights/best.pt")
    print(f"📊 Last model: {save_dir}/weights/last.pt")

    # Print improvement expectations
    print("\n📈 Expected Improvements:")
    print("   • Traffic Light mAP50: 57.6% → 76%+ (target)")
    print("   • Central Line mAP50: 75.4% → 83%+ (target)")
    print("   • Overall mAP50: 77.2% → 83%+ (target)")
    print("   • Overall Recall: 65.1% → 75%+ (target)")

    print("\n🔍 To validate your improved model:")
    print(f"   model = YOLO('{save_dir}/weights/best.pt')")
    print("   results = model.val(data='./dataset.yaml')")

    return results

def validate_improved_model(model_path='./runs/train/yolo11_traffic_optimized/weights/best.pt',
                            data='./dataset.yaml'):
    """Validate the fine-tuned model on the dataset.

    Args:
        model_path: Weights to validate. Defaults to the best checkpoint of
            the Strategy 1 run.
        data: Dataset YAML passed to ``model.val``.

    Returns:
        The metrics returned by ``model.val``. Returns ``None`` if
        ``model_path`` is not an existing file or if loading/validation raises.
    """
    print("\n🔍 Validating improved model...")

    if not os.path.isfile(model_path):
        print("❌ Improved model not found. Train the model first.")
        return None

    try:
        model = YOLO(model_path)
        results = model.val(data=data)
    except Exception as e:
        print(f"❌ Validation failed: {e}")
        return None

    print("✅ Validation completed!")
    print("📊 Check the results above to see improvement")
    return results

def main():
    """Run Strategy 1 end to end: report the GPU, fine-tune, then validate.

    Returns early, after printing a message, if PyTorch cannot be imported.
    """
    print("🚀 YOLO11 Traffic Detection - Strategy 1: Optimized Fine-tuning")
    print("================================================================")

    # Check if GPU is available
    try:
        import torch
    except ImportError:
        print("⚠️  PyTorch not found. Please install PyTorch first.")
        return

    if torch.cuda.is_available():
        print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
        print(f"   GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        print("⚠️  No GPU detected. Training will use CPU (much slower).")

    # Start training
    results = train_strategy_1_optimized()

    # train_strategy_1_optimized() signals failure with None. Compare against
    # None explicitly instead of relying on the truthiness of a metrics object.
    if results is not None:
        # Validate the improved model
        validate_improved_model()

        print("\n✅ Training Strategy 1 completed successfully!")
        print("🎯 Your model should now have significantly better performance")
        print("   especially for Traffic Light and Central Line detection.")
    else:
        print("\n❌ Training failed. Please check the error messages above.")

if __name__ == "__main__":
    # Run the Strategy 1 fine-tuning workflow (train, then validate) when executed directly.
    main()

SegFormer Transfer Learning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import glob
import yaml
from pathlib import Path
from typing import List, Tuple, Optional, Dict
import json
import random

class AutoConfig:
    """Automatically detect and configure paths and settings.

    Construction scans ``base_dir`` and stores the result in ``self.config``,
    a dict with these keys:

    - ``'yolo_model'``: path or name of the YOLO weights to load.
    - ``'dataset_config'``: ``{'num_classes', 'class_names', 'config_path'}``.
    - ``'data_paths'``: ``{'train_images', 'train_masks', 'val_images',
      'val_masks'}``. Mask lists are aligned index-by-index with the image
      lists and are empty when no pairs were found.
    - ``'output_dirs'``: output directories (created on disk) as strings.

    Args:
        base_dir: Project root to scan.
        seed: Seed for the train/val shuffle, so the split is identical
            across runs. ``None`` gives a fresh random split each time.
    """

    # Used when no dataset YAML (or no ``names`` entry) is available.
    DEFAULT_CLASS_NAMES = ('Background', 'Central line', 'Crosswalk', 'Lane',
                           'Separation', 'Traffic light', 'Traffic sign')

    # A file counts as a mask when its path relative to base_dir contains one
    # of these substrings.
    MASK_KEYWORDS = ('mask', 'label', 'gt', 'seg')

    # Top-level output folders written by the YOLO / SegFormer tools in this
    # notebook. Their plots and predictions are not training data.
    EXCLUDED_DIRS = ('runs', 'segformer_results')

    def __init__(self, base_dir: str = ".", seed: Optional[int] = 42):
        self.base_dir = Path(base_dir)
        self.seed = seed
        self.config = self._auto_detect_setup()

    def _auto_detect_setup(self) -> Dict:
        """Run every detection step and collect the results into one dict."""
        config = {}

        # Auto-detect YOLO model
        config['yolo_model'] = self._find_yolo_model()

        # Auto-detect dataset configuration
        config['dataset_config'] = self._find_dataset_config()

        # Auto-detect data paths
        config['data_paths'] = self._find_data_paths()

        # Set up output directories
        config['output_dirs'] = self._setup_output_dirs()

        return config

    def _find_yolo_model(self) -> str:
        """Find the best YOLO model to use.

        Priority order:

        1. The most recently modified trained ``best.pt`` under either
           ``runs/detect/train*/`` or ``runs/train/*/``. The YOLO cells in
           this notebook write to the latter.
        2. Local ``yolo11n.pt``, then ``yolo11s.pt``, then ``yolo11m.pt``.
        3. The name ``yolo11n.pt`` (Ultralytics downloads it if needed).
        """
        priority_tiers = [
            ["runs/detect/train*/weights/best.pt", "runs/train/*/weights/best.pt"],
            ["yolo11n.pt"],
            ["yolo11s.pt"],
            ["yolo11m.pt"],
        ]

        for patterns in priority_tiers:
            matches = [match for pattern in patterns for match in self.base_dir.glob(pattern)]
            if matches:
                # Get the most recent if multiple matches
                latest = max(matches, key=lambda p: p.stat().st_mtime)
                print(f"Found YOLO model: {latest}")
                return str(latest)

        # Fallback to downloading yolo11n.pt
        print("No YOLO model found, will use yolo11n.pt (will download if needed)")
        return "yolo11n.pt"

    def _pick_dataset_yaml(self) -> Optional[Path]:
        """Choose the dataset YAML to read.

        Order: ``dataset.yaml``, then ``data.yaml``, then the alphabetically
        first other ``*.yaml`` in ``base_dir``. Returns None if there is none.
        """
        for preferred in ('dataset.yaml', 'data.yaml'):
            candidate = self.base_dir / preferred
            if candidate.is_file():
                return candidate
        others = sorted(self.base_dir.glob("*.yaml"))
        return others[0] if others else None

    def _find_dataset_config(self) -> Dict:
        """Find and parse the dataset configuration (class count and names)."""
        yaml_file = self._pick_dataset_yaml()

        if yaml_file is not None:
            print(f"Found dataset config: {yaml_file}")

            try:
                with open(yaml_file, 'r') as f:
                    # An empty file loads as None.
                    dataset_config = yaml.safe_load(f) or {}

                classes = dataset_config.get('names')
                if isinstance(classes, dict):
                    # YOLO style {0: 'name', 1: ...}: order by class id, not
                    # by the order the keys were written in.
                    class_names = [classes[key] for key in sorted(classes)]
                elif classes:
                    class_names = list(classes)
                else:
                    class_names = list(self.DEFAULT_CLASS_NAMES)

                return {
                    'num_classes': len(class_names),
                    'class_names': class_names,
                    'config_path': str(yaml_file)
                }
            except Exception as e:
                print(f"Error reading YAML: {e}")

        # Default configuration
        return {
            'num_classes': len(self.DEFAULT_CLASS_NAMES),
            'class_names': list(self.DEFAULT_CLASS_NAMES),
            'config_path': None
        }

    def _find_data_paths(self) -> Dict:
        """Auto-detect training and validation data paths.

        Scans a few conventional folders (and ``base_dir`` itself)
        recursively and separates masks from images. Images are paired with
        masks by file name, and the pairs are split 80/20 into train/val.
        When no pairs are found, the images alone are split (inference-only
        use) and the mask lists stay empty.
        """
        data_paths = {
            'train_images': [],
            'train_masks': [],
            'val_images': [],
            'val_masks': []
        }

        # Common data directory patterns
        possible_dirs = ['sample_data', 'temp_data', 'data', 'dataset', 'images', '.']
        img_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
        mask_suffixes = {'.png', '.jpg', '.jpeg'}

        candidates = set()
        for data_dir in possible_dirs:
            data_path = self.base_dir / data_dir
            if not data_path.exists():
                continue
            print(f"Scanning directory: {data_path}")
            for ext in img_extensions:
                for file_path in glob.glob(str(data_path / "**" / ext), recursive=True):
                    # normpath makes './data/x.jpg' (found via '.') and
                    # 'data/x.jpg' (found via 'data') a single entry.
                    candidates.add(os.path.normpath(file_path))

        all_images = []
        all_masks = []
        for file_path in sorted(candidates):
            rel_path = os.path.relpath(file_path, self.base_dir).lower()
            if Path(rel_path).parts[0] in self.EXCLUDED_DIRS:
                continue
            # Keywords are matched on the path relative to base_dir. A project
            # folder such as '/content/segmentation' would otherwise turn
            # every file into a mask.
            is_mask = (Path(file_path).suffix.lower() in mask_suffixes
                       and any(keyword in rel_path for keyword in self.MASK_KEYWORDS))
            # A file is either a mask or an image, never both, so a mask can
            # never be "matched" to itself.
            (all_masks if is_mask else all_images).append(file_path)

        print(f"Found {len(all_images)} total images")
        print(f"Found {len(all_masks)} total masks")

        # Try to match images with masks
        matched_pairs = self._match_images_and_masks(all_images, all_masks)

        if matched_pairs:
            # Seeded shuffle: the same files land in train/val on every run.
            random.Random(self.seed).shuffle(matched_pairs)
            # Keep at least one training pair, even for tiny datasets.
            split_idx = max(1, int(0.8 * len(matched_pairs)))

            train_pairs = matched_pairs[:split_idx]
            val_pairs = matched_pairs[split_idx:]

            data_paths['train_images'] = [img for img, _ in train_pairs]
            data_paths['train_masks'] = [msk for _, msk in train_pairs]
            data_paths['val_images'] = [img for img, _ in val_pairs]
            data_paths['val_masks'] = [msk for _, msk in val_pairs]

            print(f"Matched {len(matched_pairs)} image-mask pairs")
            print(f"Train: {len(data_paths['train_images'])}, Val: {len(data_paths['val_images'])}")

        # Fallback: use all images without masks for inference
        elif all_images:
            print("No masks found, using images for inference only")
            split_idx = int(0.8 * len(all_images))
            data_paths['train_images'] = all_images[:split_idx]
            data_paths['val_images'] = all_images[split_idx:]
            data_paths['train_masks'] = []
            data_paths['val_masks'] = []

        return data_paths

    def _match_images_and_masks(self, images: List[str], masks: List[str]) -> List[Tuple[str, str]]:
        """Pair each image with its best-scoring mask.

        Each (image, mask) candidate is scored as follows:

        - +3 when one file stem contains the other.
        - A further +5 when the stems are identical.
        - +2 when the mask's folder mentions 'mask' and the image's folder
          mentions 'image'.

        The highest score wins, and the first mask wins on ties. A pair is
        kept only when the stems are related (score >= 3). The folder bonus
        alone would pair an image with an arbitrary mask.
        """
        # Per-mask values are loop-invariant; compute them once.
        mask_info = [
            (mask_path, Path(mask_path).stem, 'mask' in str(Path(mask_path).parent).lower())
            for mask_path in masks
        ]

        matched_pairs = []
        for img_path in images:
            img_name = Path(img_path).stem
            img_in_image_dir = 'image' in str(Path(img_path).parent).lower()

            best_match = None
            best_score = 0

            for mask_path, mask_name, mask_in_mask_dir in mask_info:
                score = 0
                if img_name in mask_name or mask_name in img_name:
                    score += 3
                if img_name == mask_name:
                    score += 5
                if mask_in_mask_dir and img_in_image_dir:
                    score += 2

                if score > best_score:
                    best_score = score
                    best_match = mask_path

            if best_match is not None and best_score >= 3:
                matched_pairs.append((img_path, best_match))

        return matched_pairs

    def _setup_output_dirs(self) -> Dict:
        """Create the output directories and return them as strings."""
        output_dirs = {
            'segformer_results': self.base_dir / 'segformer_results',
            'models': self.base_dir / 'segformer_results' / 'models',
            'predictions': self.base_dir / 'segformer_results' / 'predictions'
        }

        for dir_path in output_dirs.values():
            dir_path.mkdir(parents=True, exist_ok=True)

        return {k: str(v) for k, v in output_dirs.items()}

class AutoSegmentationDataset(Dataset):
    """Image/mask dataset that feeds SegFormer.

    Every item is resized to ``TARGET_SIZE`` x ``TARGET_SIZE`` and returned
    as a dict with these entries:

    - ``'pixel_values'``: CHW tensor. Numpy images are converted to
      ImageNet-normalized float32.
    - ``'labels'``: int64 HxW tensor of class ids.
    - ``'image_path'``: source image path.

    Images without a readable mask (and all images of an inference-only
    dataset with no masks) get an all-zero label map.

    Args:
        image_paths: Image file paths.
        mask_paths: Mask file paths, aligned index-by-index with
            ``image_paths``. May be empty for inference-only use.
        processor: SegFormer image processor. It is stored, but not used by
            ``__getitem__``.
        num_classes: Number of classes. Mask values are clipped to
            ``[0, num_classes - 1]``.
        augmentations: Albumentations pipeline, applied only when
            ``is_training`` is True.
        is_training: Whether augmentations are enabled.
    """

    TARGET_SIZE = 512
    # ImageNet statistics as float32, so normalized images stay float32
    # (float64 inputs would not match the model's float32 weights).
    IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, image_paths: List[str], mask_paths: List[str],
                 processor: SegformerImageProcessor, num_classes: int,
                 augmentations: Optional[A.Compose] = None, is_training: bool = True):

        if mask_paths and len(mask_paths) == len(image_paths):
            # Sort the pairs jointly. Sorting images and masks separately
            # breaks the image<->mask correspondence whenever their names
            # sort differently.
            pairs = sorted(zip(image_paths, mask_paths))
            self.image_paths = [img for img, _ in pairs]
            self.mask_paths = [msk for _, msk in pairs]
        else:
            if mask_paths:
                print(f"⚠️  {len(image_paths)} images but {len(mask_paths)} masks; "
                      "pairing by sorted order")
            self.image_paths = sorted(image_paths)
            self.mask_paths = sorted(mask_paths) if mask_paths else []

        self.processor = processor
        self.num_classes = num_classes
        self.augmentations = augmentations if is_training else None
        self.is_training = is_training
        self.has_masks = len(self.mask_paths) > 0

        print(f"Dataset created: {len(self.image_paths)} images, {len(self.mask_paths)} masks")

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, idx):
        """Read image ``idx`` as RGB uint8. Return a black square if it cannot be read."""
        path = self.image_paths[idx]
        try:
            image = cv2.imread(path)
            if image is None:
                raise ValueError(f"Could not load image: {path}")
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            # Return a dummy black image
            return np.zeros((self.TARGET_SIZE, self.TARGET_SIZE, 3), dtype=np.uint8)

    def _load_mask(self, idx, shape):
        """Read mask ``idx`` as a grayscale label map.

        Returns zeros of ``shape`` when the dataset has no mask for ``idx``
        or the file cannot be read.
        """
        if self.has_masks and idx < len(self.mask_paths):
            path = self.mask_paths[idx]
            try:
                mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            except Exception:
                mask = None
            if mask is not None:
                return mask
            # Warn instead of silently training on an all-background label map.
            print(f"⚠️  Could not load mask {path}; using an empty mask")
        # Create dummy mask for inference
        return np.zeros(shape, dtype=np.uint8)

    def _normalize(self, image):
        """Scale a uint8 HWC RGB image to [0, 1] and apply ImageNet normalization (float32)."""
        image = image.astype(np.float32) / 255.0
        return (image - self.IMAGENET_MEAN) / self.IMAGENET_STD

    def __getitem__(self, idx):
        size = self.TARGET_SIZE
        image = self._load_image(idx)
        mask = self._load_mask(idx, image.shape[:2])

        # Ensure image and mask have the same dimensions
        if image.shape[:2] != mask.shape[:2]:
            original_shape = mask.shape
            mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
            print(f"⚠️  Resized mask from {original_shape} to match image {image.shape[:2]}")

        # Resize to the standard size BEFORE augmentations. Nearest-neighbour
        # for labels so that no new class ids are invented.
        image = cv2.resize(image, (size, size), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (size, size), interpolation=cv2.INTER_NEAREST)

        # Ensure mask values are within valid range.
        # NOTE(review): values such as 255 are clipped to the last class id
        # rather than ignored. Confirm this matches the mask encoding.
        mask = np.clip(mask, 0, self.num_classes - 1)

        # Apply augmentations (image and mask are now the same size)
        if self.augmentations:
            try:
                augmented = self.augmentations(image=image, mask=mask)
                image = augmented['image']
                mask = augmented['mask']
            except Exception as e:
                # Keep the un-augmented uint8 image. It is normalized exactly
                # once below (previously it was normalized twice).
                print(f"⚠️  Augmentation failed: {e}")

        # Convert to tensor format
        if isinstance(image, np.ndarray):
            # uint8 means the pipeline did not normalize; do it here.
            if image.dtype == np.uint8:
                image = self._normalize(image)
            pixel_values = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).permute(2, 0, 1)
        else:
            # Already converted by augmentations (e.g. ToTensorV2)
            pixel_values = image

        # Ensure consistent tensor size
        if pixel_values.shape[-2:] != torch.Size([size, size]):
            pixel_values = torch.nn.functional.interpolate(
                pixel_values.unsqueeze(0),
                size=(size, size),
                mode='bilinear',
                align_corners=False
            ).squeeze(0)

        # Convert mask to an int64 tensor. The contiguous copy guards against
        # negative-stride (flipped) numpy views.
        if isinstance(mask, np.ndarray):
            mask = np.ascontiguousarray(mask)
        labels = torch.as_tensor(mask).long()
        if labels.shape != torch.Size([size, size]):
            labels = torch.nn.functional.interpolate(
                labels.unsqueeze(0).unsqueeze(0).float(),
                size=(size, size),
                mode='nearest'
            ).squeeze().long()

        return {
            'pixel_values': pixel_values,
            'labels': labels,
            'image_path': self.image_paths[idx]
        }

class SafeSegformerModel(nn.Module):
    """SegFormer semantic-segmentation model whose head has ``num_classes`` outputs.

    Two loading attempts are made:

    1. ``from_pretrained(model_name, num_labels=num_classes,
       ignore_mismatched_sizes=True)``.
    2. If that raises, ``from_pretrained(model_name)`` is called unchanged.
       Its ``decode_head.classifier`` is then replaced with a new 1x1
       ``nn.Conv2d`` with ``num_classes`` output channels, and
       ``config.num_labels`` is updated.

    Raises:
        RuntimeError: if the second attempt also fails.
    """

    def __init__(self, model_name: str, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        print(f"Loading SegFormer model for {num_classes} classes...")

        try:
            self.segformer = self._load_with_resized_head(model_name, num_classes)
        except Exception as first_err:
            print(f"❌ Failed to load with ignore_mismatched_sizes: {first_err}")
            try:
                self.segformer = self._load_and_swap_head(model_name, num_classes)
            except Exception as second_err:
                print(f"❌ Manual replacement failed: {second_err}")
                raise RuntimeError(f"Could not load SegFormer model: {second_err}")

    @staticmethod
    def _load_with_resized_head(model_name, num_classes):
        """Attempt 1: let ``from_pretrained`` resize the mismatched head."""
        loaded = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        print("✅ Model loaded successfully with ignore_mismatched_sizes=True")
        return loaded

    @staticmethod
    def _load_and_swap_head(model_name, num_classes):
        """Attempt 2: load the checkpoint unchanged, then replace the classifier by hand."""
        print("🔄 Trying manual classifier replacement...")
        loaded = SegformerForSemanticSegmentation.from_pretrained(model_name)

        head = loaded.decode_head
        head.classifier = nn.Conv2d(
            head.classifier.in_channels, num_classes, kernel_size=1, stride=1, padding=0
        )
        loaded.config.num_labels = num_classes
        print("✅ Model loaded with manual classifier replacement")
        return loaded

    def forward(self, pixel_values, labels=None):
        """Delegate to the wrapped SegFormer model. Its loss is computed when ``labels`` is given."""
        return self.segformer(pixel_values=pixel_values, labels=labels)

class AutoSegformerPipeline:
    """Automated Segformer + YOLO pipeline.

    Builds a SegFormer segmentation model and a YOLO detector from an
    auto-detected configuration. It trains SegFormer on the discovered data and
    runs combined detection + segmentation inference.

    ``self.config`` is read for these keys:
      - ``dataset_config``: ``num_classes``, ``class_names``
      - ``yolo_model``
      - ``data_paths``: ``train_images``, ``train_masks``, ``val_images``,
        ``val_masks``
      - ``output_dirs``: ``models``, ``predictions``
    """

    def __init__(self, auto_mode: bool = True):
        """Detect the configuration (when ``auto_mode``) and load all models.

        NOTE(review): with ``auto_mode=False`` the config is empty, so
        ``_setup_models`` fails with ``KeyError``.
        """
        self.config = AutoConfig().config if auto_mode else {}
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Populated by _setup_models().
        self.processor = None
        self.segformer_model = None
        self.yolo_model = None

        self._setup_models()

    def _setup_models(self):
        """Load the SegFormer processor/model and the YOLO detector."""
        # SegFormer: processor and backbone come from the same pretrained id.
        model_name = "nvidia/segformer-b2-finetuned-ade-512-512"
        self.processor = SegformerImageProcessor.from_pretrained(model_name)

        num_classes = self.config['dataset_config']['num_classes']
        self.segformer_model = SafeSegformerModel(model_name, num_classes)
        self.segformer_model.to(self.device)

        # YOLO: fall back to the stock nano weights if the configured ones fail.
        yolo_path = self.config['yolo_model']
        try:
            self.yolo_model = YOLO(yolo_path)
            print(f"✅ YOLO model loaded: {yolo_path}")
        except Exception as e:
            print(f"⚠️  YOLO loading warning: {e}")
            self.yolo_model = YOLO("yolo11n.pt")  # Fallback

        print(f"✅ SegFormer initialized for {num_classes} classes")
        print(f"Class names: {self.config['dataset_config']['class_names']}")

    def create_dataloaders(self, batch_size: int = 4):
        """Build the train/validation ``DataLoader``s from ``config['data_paths']``.

        Args:
            batch_size: Samples per batch for both loaders.

        Returns:
            ``(train_loader, val_loader)``.
              - ``val_loader`` is ``None`` when there is no validation data.
              - Both are ``None`` when there are no training images.
        """
        data_paths = self.config['data_paths']

        # Check if we have training data
        if not data_paths['train_images']:
            print("❌ No training images found!")
            return None, None

        # No Resize here: per the original notes, images are resized manually
        # before augmentation.
        # Shape checking is disabled; any leftover size mismatch is fixed in
        # safe_collate_fn.
        train_augs = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.Blur(blur_limit=3, p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ], is_check_shapes=False)

        val_augs = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ], is_check_shapes=False)

        train_dataset = AutoSegmentationDataset(
            data_paths['train_images'],
            data_paths['train_masks'],
            self.processor,
            self.config['dataset_config']['num_classes'],
            train_augs,
            is_training=True
        )

        # Validation dataset is built whenever there are images (even without masks).
        val_dataset = None
        if data_paths['val_images']:
            val_dataset = AutoSegmentationDataset(
                data_paths['val_images'],
                data_paths['val_masks'],
                self.processor,
                self.config['dataset_config']['num_classes'],
                val_augs,
                is_training=False
            )

        def safe_collate_fn(batch):
            """Stack samples after forcing images to 3x512x512 and labels to 512x512.

            Images are resized bilinearly. Labels use nearest-neighbour
            interpolation so class ids are not blended.
            """
            pixel_values = []
            labels = []
            image_paths = []

            for item in batch:
                pv = item['pixel_values']
                if pv.shape != torch.Size([3, 512, 512]):
                    pv = torch.nn.functional.interpolate(
                        pv.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False
                    ).squeeze(0)
                pixel_values.append(pv)

                lbl = item['labels']
                if lbl.shape != torch.Size([512, 512]):
                    lbl = torch.nn.functional.interpolate(
                        lbl.unsqueeze(0).unsqueeze(0).float(), size=(512, 512), mode='nearest'
                    ).squeeze().long()
                labels.append(lbl)

                image_paths.append(item['image_path'])

            return {
                'pixel_values': torch.stack(pixel_values),
                'labels': torch.stack(labels),
                'image_path': image_paths
            }

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            collate_fn=safe_collate_fn
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=safe_collate_fn
        ) if val_dataset else None

        print(f"✅ Created dataloaders:")
        print(f"   Training: {len(train_dataset)} samples")
        print(f"   Validation: {len(val_dataset) if val_dataset else 0} samples")

        return train_loader, val_loader

    def _train_one_epoch(self, train_loader, optimizer, epoch: int, num_epochs: int):
        """Run one training epoch.

        Failing batches are printed and skipped.

        Returns:
            The mean loss over successful batches, or ``None`` if none succeeded.
        """
        self.segformer_model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            try:
                pixel_values = batch['pixel_values'].to(self.device)
                labels = batch['labels'].to(self.device)

                optimizer.zero_grad()
                outputs = self.segformer_model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                if batch_idx % 5 == 0:
                    print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx+1}, Loss: {loss.item():.4f}')

            except Exception as e:
                print(f"⚠️  Training batch error: {e}")

        return total_loss / num_batches if num_batches else None

    def _validate(self, val_loader):
        """Return the mean validation loss, or ``None`` if no batch succeeded.

        Failing batches are printed and skipped. Returning ``None`` (rather than
        0.0) keeps a fully failed validation from looking like a perfect model.
        """
        self.segformer_model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                try:
                    pixel_values = batch['pixel_values'].to(self.device)
                    labels = batch['labels'].to(self.device)

                    outputs = self.segformer_model(pixel_values=pixel_values, labels=labels)
                    total_loss += outputs.loss.item()
                    num_batches += 1
                except Exception as e:
                    print(f"⚠️  Validation batch error: {e}")

        return total_loss / num_batches if num_batches else None

    def train(self, num_epochs: int = 20, learning_rate: float = 1e-4, batch_size: int = 4):
        """Auto-training pipeline: fine-tune SegFormer and keep the best checkpoint.

        Uses AdamW with polynomial LR decay stepped once per epoch. The checkpoint
        with the lowest validation loss is saved as
        ``output_dirs['models']/best_segformer.pth``. If validation loss is
        unavailable, the training loss is used instead.

        Args:
            num_epochs: Number of passes over the training data.
            learning_rate: Initial AdamW learning rate.
            batch_size: Samples per batch.
        """
        print("🚀 Starting automated training...")

        train_loader, val_loader = self.create_dataloaders(batch_size)

        if train_loader is None or len(train_loader) == 0:
            print("❌ No training data found! Please check your data paths.")
            return

        optimizer = optim.AdamW(self.segformer_model.parameters(), lr=learning_rate, weight_decay=0.01)
        scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=num_epochs, power=0.9)

        # Make sure the checkpoint directory exists before the first save.
        models_dir = self.config['output_dirs']['models']
        os.makedirs(models_dir, exist_ok=True)
        model_path = os.path.join(models_dir, 'best_segformer.pth')

        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            avg_train_loss = self._train_one_epoch(train_loader, optimizer, epoch, num_epochs)

            val_loss = None
            if val_loader and len(val_loader) > 0:
                val_loss = self._validate(val_loader)
            if val_loss is None:
                # No usable validation this epoch: fall back to the training loss.
                val_loss = avg_train_loss

            train_str = f"{avg_train_loss:.4f}" if avg_train_loss is not None else "n/a"
            val_str = f"{val_loss:.4f}" if val_loss is not None else "n/a"
            print(f'📊 Epoch {epoch+1}/{num_epochs}: Train Loss: {train_str}, Val Loss: {val_str}')

            if val_loss is None:
                # Every batch failed: there is no loss to compare, so never
                # record this epoch as the best one.
                print("⚠️  No successful batches this epoch; best model not updated")
            elif val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.segformer_model.state_dict(), model_path)
                print(f'💾 New best model saved to: {model_path}')

            scheduler.step()

        print("🎉 Training completed!")

    def predict_combined(self, image_path: str, save_results: bool = True):
        """Run YOLO detection and SegFormer segmentation on one image.

        Args:
            image_path: Path of the image to analyse.
            save_results: When true, write visualisations to ``output_dirs['predictions']``.

        Returns:
            ``(yolo_results, segmentation_mask)``. The mask is a 2-D array of class
            ids at the original image resolution. Returns ``(None, None)`` if the
            image is missing or unreadable, or if prediction fails.
        """
        if not os.path.exists(image_path):
            print(f"❌ Image not found: {image_path}")
            return None, None

        try:
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread returns None for unreadable/unsupported files.
                print(f"❌ Could not read image: {image_path}")
                return None, None
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            yolo_results = self.yolo_model(image_path)

            inputs = self.processor(image_rgb, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            self.segformer_model.eval()
            with torch.no_grad():
                outputs = self.segformer_model(**inputs)

            # Upsample logits to the original image size before the per-pixel argmax.
            logits = outputs.logits
            upsampled_logits = nn.functional.interpolate(
                logits,
                size=image_rgb.shape[:2],
                mode="bilinear",
                align_corners=False,
            )
            segmentation_mask = upsampled_logits.argmax(dim=1).cpu().numpy()[0]

            if save_results:
                self._save_prediction_results(image_path, yolo_results, segmentation_mask, image_rgb)

            return yolo_results, segmentation_mask

        except Exception as e:
            print(f"❌ Prediction error: {e}")
            return None, None

    @staticmethod
    def _mask_to_display(seg_mask: np.ndarray, num_classes: int) -> np.ndarray:
        """Map class ids to visible grey levels without uint8 overflow.

        Each id is multiplied by a step of at most 50 (the original scale). The
        step shrinks when needed so that id ``num_classes - 1`` still fits in
        0..255. Results are computed in int64 and clipped before the uint8 cast.
        """
        step = max(1, min(50, 255 // max(num_classes - 1, 1)))
        scaled = seg_mask.astype(np.int64) * step
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _save_prediction_results(self, image_path: str, yolo_results, seg_mask: np.ndarray, original_image: np.ndarray):
        """Write the segmentation mask and the YOLO visualisation to disk.

        Files go to ``output_dirs['predictions']`` as ``<stem>_segmentation.png``
        and ``<stem>_yolo.jpg``. ``original_image`` is currently unused. Errors
        are printed, not raised.
        """
        try:
            output_dir = self.config['output_dirs']['predictions']
            os.makedirs(output_dir, exist_ok=True)
            base_name = Path(image_path).stem

            num_classes = self.config['dataset_config']['num_classes']
            mask_path = os.path.join(output_dir, f'{base_name}_segmentation.png')
            cv2.imwrite(mask_path, self._mask_to_display(seg_mask, num_classes))

            if yolo_results:
                yolo_img = yolo_results[0].plot()
                yolo_path = os.path.join(output_dir, f'{base_name}_yolo.jpg')
                cv2.imwrite(yolo_path, yolo_img)

            print(f"💾 Results saved to: {output_dir}")

        except Exception as e:
            print(f"⚠️  Error saving results: {e}")

def main():
    """Automated main function.

    Interactive entry point. It builds an ``AutoSegformerPipeline``, prints the
    detected configuration, then asks on stdin whether to train (1), run
    inference (2) or both (3).

    Any exception is printed with its traceback and not re-raised.
    """

    def ask_positive_int(prompt, default):
        # Empty input selects the default. Non-numeric or non-positive input
        # also falls back to the default instead of aborting the pipeline.
        raw = input(prompt).strip()
        if not raw:
            return default
        try:
            value = int(raw)
        except ValueError:
            print(f"⚠️  Invalid number '{raw}', using default {default}")
            return default
        if value <= 0:
            print(f"⚠️  Value must be positive, using default {default}")
            return default
        return value

    print("🚀 === Automated SegFormer + YOLO11 Pipeline ===")

    try:
        pipeline = AutoSegformerPipeline(auto_mode=True)

        print("\n📋 Configuration detected:")
        print(f"   YOLO model: {pipeline.config['yolo_model']}")
        print(f"   Number of classes: {pipeline.config['dataset_config']['num_classes']}")
        print(f"   Classes: {pipeline.config['dataset_config']['class_names']}")
        print(f"   Training images: {len(pipeline.config['data_paths']['train_images'])}")
        print(f"   Validation images: {len(pipeline.config['data_paths']['val_images'])}")

        print("\n🎯 What would you like to do?")
        print("1. Train SegFormer")
        print("2. Run inference only")
        print("3. Both (recommended)")

        choice = input("Enter choice (1/2/3): ").strip()

        if choice not in ('1', '2', '3'):
            # Previously an unknown choice did nothing yet reported success.
            print(f"❌ Invalid choice '{choice}'. Please run again and enter 1, 2 or 3.")
            return

        if choice in ('1', '3'):
            epochs = ask_positive_int("Enter number of epochs (default 20): ", 20)
            batch_size = ask_positive_int("Enter batch size (default 8 for T4): ", 8)

            pipeline.train(num_epochs=epochs, batch_size=batch_size)

        if choice in ('2', '3'):
            test_images = input("Enter path to test image (or press Enter to use sample): ").strip()

            if not test_images:
                # Default to the first validation image, else the first training image.
                if pipeline.config['data_paths']['val_images']:
                    test_images = pipeline.config['data_paths']['val_images'][0]
                elif pipeline.config['data_paths']['train_images']:
                    test_images = pipeline.config['data_paths']['train_images'][0]

            if test_images and os.path.exists(test_images):
                print(f"\n🔍 Running inference on: {test_images}")
                yolo_results, seg_mask = pipeline.predict_combined(test_images)

                if yolo_results:
                    print("🎯 YOLO Detection Results:")
                    if len(yolo_results[0].boxes) > 0:
                        for i, box in enumerate(yolo_results[0].boxes):
                            class_id = int(box.cls[0])
                            confidence = float(box.conf[0])
                            print(f"   Object {i+1}: Class {class_id}, Confidence: {confidence:.3f}")
                    else:
                        print("   No objects detected")

                if seg_mask is not None:
                    print(f"\n🎨 Segmentation Results:")
                    print(f"   Mask shape: {seg_mask.shape}")
                    print(f"   Unique classes: {np.unique(seg_mask)}")

            else:
                print("❌ No test images found!")

        print("\n✅ Pipeline completed successfully!")

    except Exception as e:
        print(f"❌ Pipeline error: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

COMPLETE SEGFORMER VIDEO SEGMENTATION

In [None]:
# =============================================================================
# FIXED SEGFORMER VIDEO SEGMENTATION FOR TRAFFIC DATASET
# Upload video → Process with trained SegFormer → Download results
# =============================================================================

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import display, HTML, clear_output
from google.colab import files
import ipywidgets as widgets

# =============================================================================
# FIXED CONFIGURATION FOR YOUR TRAFFIC DATASET
# =============================================================================

CONFIG = dict(
    # Expected location of the checkpoint saved by training; the base id
    # supplies the architecture and the image processor.
    model_path='./segformer_results/models/best_segformer.pth',
    base_model='nvidia/segformer-b2-finetuned-ade-512-512',

    # Position i is the display name of class id i (0 = background).
    class_names=['Background', 'Central line', 'Crosswalk', 'Lane',
                 'Separation', 'Traffic light', 'Traffic sign'],

    # Longest side frames are shrunk to before inference (640 = fast, 1024 = quality).
    max_size=640,
    # Weight of the color mask when blending over the frame (0.0-1.0).
    overlay_alpha=0.6,
    # Frame rate written into the output video.
    fps_output=30,

    # RGB color per class id, in the same order as class_names.
    colors=dict(enumerate([
        [0, 0, 0],        # Background - black
        [255, 0, 0],      # Central line - red
        [0, 255, 0],      # Crosswalk - green
        [0, 0, 255],      # Lane - blue
        [255, 255, 0],    # Separation - yellow
        [255, 0, 255],    # Traffic light - magenta
        [0, 255, 255],    # Traffic sign - cyan
    ])),
)

print("🎯 SEGFORMER VIDEO SEGMENTATION - TRAFFIC DATASET VERSION")
print("=" * 60)

# =============================================================================
# FIXED MODEL LOADER CLASS
# =============================================================================

class SafeSegformerModel(nn.Module):
    """SegFormer wrapper whose head size can follow a saved checkpoint.

    If ``model_path`` points to an existing file, the first state-dict entry
    whose key contains ``'classifier.weight'`` decides the number of classes.
    Otherwise the ``num_classes`` argument is used. The pretrained base model
    is then created with that many labels (``ignore_mismatched_sizes=True``).
    """

    def __init__(self, base_model_name: str, num_classes: int, model_path: str = None):
        super().__init__()
        self.num_classes = num_classes

        if model_path and os.path.exists(model_path):
            detected = self._classes_from_checkpoint(model_path)
            if detected is not None:
                self.num_classes = detected

        print(f"🏗️  Creating model with {self.num_classes} classes...")
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            base_model_name,
            num_labels=self.num_classes,
            ignore_mismatched_sizes=True
        )

    @staticmethod
    def _classes_from_checkpoint(model_path):
        """Read the classifier output width from a saved state dict.

        Returns ``None`` (after printing why) when detection is not possible.
        """
        try:
            print("🔍 Auto-detecting number of classes from saved model...")
            checkpoint = torch.load(model_path, map_location='cpu')

            # The classifier's weight has one output row per class.
            head_key = next((k for k in checkpoint.keys() if 'classifier.weight' in k), None)
            if head_key:
                count = checkpoint[head_key].shape[0]
                print(f"📊 Detected {count} classes in saved model")
                return count
            print("⚠️  Could not detect classes from model, using provided number")
        except Exception as err:
            print(f"⚠️  Error detecting classes: {err}, using provided number")
        return None

    def forward(self, pixel_values):
        """Run the wrapped SegFormer on a batch of pixel values."""
        return self.segformer(pixel_values=pixel_values)

# =============================================================================
# FIXED VIDEO PROCESSOR CLASS
# =============================================================================

class VideoSegFormer:
    """Run a SegFormer model on video frames and render a colored overlay plus legend.

    Args:
        model_path: Path to a ``state_dict`` saved by training. If it is missing
            or ``None``, the base pretrained weights are used.
        base_model: Hugging Face model id used for the processor and architecture.
        class_names: Expected class names, where index i names class id i.
        colors: Mapping of class id to RGB color (list of three ints).
    """

    def __init__(self, model_path, base_model, class_names, colors):
        self.model_path = model_path
        self.base_model = base_model
        self.original_class_names = class_names
        self.original_colors = colors
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print(f"🧠 Loading SegFormer model...")
        print(f"   Device: {self.device}")

        try:
            self.processor = SegformerImageProcessor.from_pretrained(base_model)

            # The wrapper may override the class count with the one in the checkpoint.
            initial_num_classes = len(class_names)
            self.model = SafeSegformerModel(base_model, initial_num_classes, model_path)
            self.num_classes = self.model.num_classes

            self.class_names, self.colors = self._adjust_classes_and_colors()

            print(f"   Final classes: {self.num_classes}")
            print(f"   Class mapping: {self.class_names}")

            if model_path and os.path.exists(model_path):
                print(f"📥 Loading trained weights from: {model_path}")
                state_dict = torch.load(model_path, map_location=self.device)

                try:
                    self.model.load_state_dict(state_dict, strict=True)
                    print("✅ Trained weights loaded successfully (strict)!")
                except RuntimeError as e:
                    print(f"⚠️  Strict loading failed: {e}")
                    print("🔄 Trying flexible loading...")
                    result = self.model.load_state_dict(state_dict, strict=False)
                    print("✅ Trained weights loaded successfully (flexible)!")
                    # Make partial loads visible instead of silently ignoring them.
                    if result.missing_keys or result.unexpected_keys:
                        print(f"   Missing keys: {len(result.missing_keys)}, "
                              f"unexpected keys: {len(result.unexpected_keys)}")
            else:
                print(f"⚠️  Trained weights not found: {model_path}")
                print("🔄 Using base pretrained model...")

            self.model.to(self.device)
            self.model.eval()
            print("✅ Model ready for inference!")

        except Exception as e:
            print(f"❌ Model loading error: {e}")
            raise Exception(f"Failed to load model: {e}") from e

    def _has_background_class(self):
        """Return False when the model predicts one class fewer than the names list.

        That case is treated as "trained without background", where id 0 is a
        real class.
        """
        return self.num_classes != len(self.original_class_names) - 1

    def _adjust_classes_and_colors(self):
        """Align class names and colors with the number of classes the model predicts.

        Returns:
            ``(class_names, colors)`` where ``colors`` has an entry for every id
            in ``range(self.num_classes)``.
        """
        n_expected = len(self.original_class_names)

        if self.num_classes == n_expected:
            return self.original_class_names, self.original_colors

        if self.num_classes == n_expected - 1:
            # Likely trained without background: shift names/colors down by one.
            print("🔄 Adjusting for model without background class...")
            adjusted_names = self.original_class_names[1:]
            adjusted_colors = {i: self.original_colors[i+1] for i in range(self.num_classes)}
            return adjusted_names, adjusted_colors

        print(f"🔄 Custom class adjustment: {self.num_classes} detected vs {n_expected} expected")

        if self.num_classes <= n_expected:
            adjusted_names = self.original_class_names[:self.num_classes]
            adjusted_colors = {i: self.original_colors[i] for i in range(self.num_classes)}
        else:
            adjusted_names = self.original_class_names + [
                f"Class_{i}" for i in range(n_expected, self.num_classes)
            ]
            adjusted_colors = self.original_colors.copy()
            # Fixed seed so extra-class colors are identical across runs.
            rng = np.random.default_rng(0)
            for i in range(self.num_classes):
                if i not in adjusted_colors:
                    adjusted_colors[i] = rng.integers(0, 256, size=3).tolist()

        return adjusted_names, adjusted_colors

    def segment_frame(self, frame):
        """Return a 2-D array of class ids with the same height/width as ``frame``.

        Frames whose longer side exceeds ``CONFIG['max_size']`` are downscaled for
        inference. The prediction is then resized back with nearest-neighbour.
        On error, an all-zero mask is returned.
        """
        try:
            original_h, original_w = frame.shape[:2]
            if max(original_h, original_w) > CONFIG['max_size']:
                scale = CONFIG['max_size'] / max(original_h, original_w)
                new_w, new_h = int(original_w * scale), int(original_h * scale)
                resized = cv2.resize(frame, (new_w, new_h))
            else:
                resized = frame
                scale = 1.0

            # OpenCV frames are BGR; the processor is fed RGB.
            rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

            inputs = self.processor(rgb, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits

                upsampled_logits = nn.functional.interpolate(
                    logits,
                    size=rgb.shape[:2],
                    mode="bilinear",
                    align_corners=False,
                )
                prediction = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

            if scale != 1.0:
                pred_pil = Image.fromarray(prediction.astype(np.uint8))
                pred_pil = pred_pil.resize((original_w, original_h), Image.NEAREST)
                prediction = np.array(pred_pil)

            return prediction

        except Exception as e:
            print(f"⚠️ Frame processing error: {e}")
            return np.zeros((frame.shape[0], frame.shape[1]), dtype=np.uint8)

    def create_overlay(self, frame, prediction):
        """Blend class colors over ``frame`` (BGR) and return a new BGR image.

        Only pixels belonging to a drawn class are tinted with
        ``CONFIG['overlay_alpha']``; all other pixels keep their original value.
        The background id is skipped only when the model has a background class.
        """
        colored_mask = np.zeros_like(frame)
        painted = np.zeros(prediction.shape, dtype=bool)
        skip_id = 0 if self._has_background_class() else None

        for class_id, color in self.colors.items():
            if class_id == skip_id:
                continue
            mask = prediction == class_id
            colored_mask[mask] = color
            painted |= mask

        # Colors are RGB, so blend in RGB space, then convert back to BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        blended = cv2.addWeighted(frame_rgb, 1-CONFIG['overlay_alpha'],
                                  colored_mask, CONFIG['overlay_alpha'], 0)
        overlay = np.where(painted[..., None], blended, frame_rgb)

        return cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)

    def _legend_color(self, class_id):
        """Return the class color as a BGR tuple of ints for OpenCV drawing calls.

        Unknown ids map to white.
        """
        rgb = self.colors.get(class_id, [255, 255, 255])
        return tuple(int(v) for v in reversed(rgb))

    def add_legend(self, frame, prediction):
        """Draw a legend of the classes present in ``prediction`` onto ``frame``.

        ``frame`` is BGR (as returned by ``create_overlay``) and is modified in
        place and returned. The background id is not listed. When nothing is
        listed, ``frame`` is returned untouched.
        """
        unique_classes = np.unique(prediction)

        # With a background class, id 0 is background and is not listed.
        first_listed = 1 if self._has_background_class() else 0
        detected = [int(c) for c in unique_classes
                    if first_listed <= c < len(self.class_names)]

        if not detected:
            return frame

        legend_width = 250
        legend_height = 30 + len(detected) * 25
        # Keep the box on-screen for frames narrower than the legend.
        x = max(frame.shape[1] - legend_width - 10, 5)
        y = 10

        cv2.rectangle(frame, (x-5, y-5), (x+legend_width+5, y+legend_height+5),
                      (0, 0, 0), -1)
        cv2.rectangle(frame, (x-5, y-5), (x+legend_width+5, y+legend_height+5),
                      (255, 255, 255), 2)

        cv2.putText(frame, "Traffic Elements:", (x, y+20), cv2.FONT_HERSHEY_SIMPLEX,
                    0.6, (255, 255, 255), 2)

        for i, class_id in enumerate(detected):
            y_pos = y + 40 + i * 25
            # The frame is BGR, so the RGB class color must be reversed.
            color = self._legend_color(class_id)

            if class_id < len(self.class_names):
                class_name = self.class_names[class_id]
            else:
                class_name = f"Class{class_id}"

            cv2.rectangle(frame, (x, y_pos-10), (x+15, y_pos), color, -1)
            cv2.putText(frame, class_name, (x+20, y_pos), cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (255, 255, 255), 1)

        return frame

# =============================================================================
# FIXED PROCESSING FUNCTIONS
# =============================================================================

def find_model_automatically():
    """Return the first existing checkpoint path from a fixed search list, else None.

    Paths are relative to the current working directory and are checked in
    priority order. The outcome is printed either way.
    """
    search_order = (
        './segformer_results/models/best_segformer.pth',
        './segformer_results/best_segformer.pth',
        './best_segformer.pth',
        './models/best_segformer.pth',
    )

    hit = next((candidate for candidate in search_order if os.path.exists(candidate)), None)

    if hit is None:
        print("⚠️  No trained model found. Will use base pretrained model.")
        return None

    print(f"✅ Found model: {hit}")
    return hit

def process_video(input_file):
    """Segment every frame of ``input_file`` and write an annotated MP4 next to it.

    The output path is ``<input without extension>_traffic_segmented.mp4``.
    Frames that fail to process are written unmodified.

    Args:
        input_file: Path of the video to process.

    Returns:
        The output path, or ``None`` if the video could not be opened or written,
        or if processing failed. Capture and writer handles are always released.
    """
    model_path = find_model_automatically()
    if model_path:
        CONFIG['model_path'] = model_path

    cap = None
    out = None
    try:
        print("🚀 Initializing SegFormer...")
        segformer = VideoSegFormer(
            CONFIG['model_path'],
            CONFIG['base_model'],
            CONFIG['class_names'],
            CONFIG['colors']
        )

        cap = cv2.VideoCapture(input_file)
        if not cap.isOpened():
            print(f"❌ Cannot open video: {input_file}")
            return None

        # Container metadata; FPS and frame count may be reported as 0 (or
        # negative) by some files and streams.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"📹 Video Info:")
        print(f"   Resolution: {width}x{height}")
        print(f"   FPS: {fps:.1f}")
        print(f"   Frames: {total_frames}")
        if fps > 0:
            print(f"   Duration: {total_frames/fps:.1f}s")
        else:
            print("   Duration: unknown (FPS not reported)")

        base_name = os.path.splitext(input_file)[0]
        output_file = f"{base_name}_traffic_segmented.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Use the configured FPS. If it is unset, use the source FPS, or 30
        # when the source does not report one.
        out_fps = CONFIG.get('fps_output') or (fps if fps > 0 else 30)
        out = cv2.VideoWriter(output_file, fourcc, out_fps, (width, height))
        if not out.isOpened():
            print(f"❌ Cannot create output video: {output_file}")
            return None

        print(f"\n🔄 Processing video for traffic segmentation...")

        frame_count = 0
        success_count = 0

        with tqdm(total=total_frames if total_frames > 0 else None,
                  desc="Segmenting frames") as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                try:
                    prediction = segformer.segment_frame(frame)
                    overlay = segformer.create_overlay(frame, prediction)
                    final_frame = segformer.add_legend(overlay, prediction)
                    out.write(final_frame)
                    success_count += 1

                except Exception as e:
                    print(f"⚠️ Error in frame {frame_count}: {e}")
                    out.write(frame)  # Use original frame

                frame_count += 1
                pbar.update(1)

                if frame_count % 30 == 0:
                    pbar.set_description(f"Segmenting frames - {success_count/frame_count*100:.1f}% success")

        # Rate over frames actually read; the header count can be missing or wrong.
        print(f"✅ Processing complete!")
        print(f"   Success rate: {success_count/max(frame_count, 1)*100:.1f}%")
        print(f"   Output: {output_file}")

        return output_file

    except Exception as e:
        print(f"❌ Processing failed: {e}")
        import traceback
        traceback.print_exc()
        return None

    finally:
        # Release on every path so the output file is finalized and handles don't leak.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()

def show_sample_frames(video_file, num_samples=3):
    """Plot ``num_samples`` evenly spaced frames of ``video_file`` side by side.

    Frames are taken at ``k * total // (num_samples + 1)`` for
    ``k = 1..num_samples``. The default of 3 gives the quarter, half and
    three-quarter marks. Each subplot title shows the index of the frame it
    actually displays.

    Errors are printed, not raised.
    """
    if num_samples < 1:
        return

    cap = None
    try:
        cap = cv2.VideoCapture(video_file)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            print("⚠️ Cannot show samples: frame count unavailable")
            return

        indices = [k * total_frames // (num_samples + 1) for k in range(1, num_samples + 1)]

        # Keep each frame paired with its index so titles stay correct when a read fails.
        samples = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                samples.append((idx, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))

        cap.release()
        cap = None

        if samples:
            fig, axes = plt.subplots(1, len(samples), figsize=(15, 5))
            if len(samples) == 1:
                axes = [axes]

            for ax, (idx, frame) in zip(axes, samples):
                ax.imshow(frame)
                ax.set_title(f"Frame {idx} - Traffic Segmentation")
                ax.axis('off')

            plt.suptitle("🚗 Sample Traffic Segmented Frames", fontsize=16)
            plt.tight_layout()
            plt.show()

    except Exception as e:
        print(f"⚠️ Cannot show samples: {e}")

    finally:
        if cap is not None:
            cap.release()

# =============================================================================
# FIXED INTERFACE
# =============================================================================

def upload_and_segment_traffic():
    """Run the interactive upload -> segment -> preview -> download flow.

    The steps are:
    1. Report which model will be used (from ``find_model_automatically``)
       and the configured class names.
    2. Prompt for a file via ``files.upload()``.
    3. Run ``process_video`` on the first uploaded file.
    4. On success, print the output size, preview frames with
       ``show_sample_frames``, and trigger a download of the result.

    Always returns ``None``.
    """
    print("🚗 UPLOAD AND SEGMENT TRAFFIC VIDEO")
    print("=" * 40)

    # NOTE(review): the discovered path is only reported here; process_video
    # is called without it, so presumably it resolves the model itself -- confirm.
    trained_model = find_model_automatically()
    if not trained_model:
        print("⚠️  No trained model found, using base pretrained model")
        print("💡 The model will still work but might not be optimal for your specific traffic dataset")
    else:
        print(f"✅ Using trained model: {trained_model}")

    class_list = ', '.join(CONFIG['class_names'])
    print(f"🎯 Traffic classes: {class_list}")

    print("\n📁 Select your video file:")
    uploaded_files = files.upload()
    if not uploaded_files:
        print("❌ No file uploaded!")
        return

    # Only the first uploaded file is processed.
    video_name = next(iter(uploaded_files))
    print(f"✅ Processing {video_name} for traffic segmentation...")

    result_path = process_video(video_name)
    if not result_path:
        print("❌ Processing failed!")
        return

    print("\n🎉 Complete! Traffic segmentation finished!")
    size_mb = os.path.getsize(result_path) / 1024 / 1024
    print(f"📊 File size: {size_mb:.1f} MB")

    show_sample_frames(result_path)

    print(f"\n⬇️ Downloading {result_path}...")
    files.download(result_path)

# =============================================================================
# STATUS CHECK AND AUTO-START
# =============================================================================

def _unwrap_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    Accepts three forms:
    - a bare state dict, returned as-is;
    - a dict wrapping one under ``'state_dict'`` or ``'model_state_dict'``;
    - an object exposing a callable ``state_dict()``, such as a whole pickled
      module.

    Anything else is returned unchanged, so callers see the same errors they
    would have seen before.
    """
    state_dict_method = getattr(checkpoint, "state_dict", None)
    if callable(state_dict_method):
        return state_dict_method()
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            inner = checkpoint.get(wrapper_key)
            if isinstance(inner, dict):
                return inner
    return checkpoint


def _report_model_classes(model_path):
    """Print the class count inferred from the checkpoint's classifier weights.

    Loads ``model_path`` on CPU and finds the first parameter whose name
    contains ``'classifier.weight'``. Its first dimension is reported as the
    number of classes. Any failure is printed as a warning, not raised.
    """
    try:
        # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
        state_dict = _unwrap_state_dict(torch.load(model_path, map_location='cpu'))
        classifier_key = next(
            (key for key in state_dict.keys() if 'classifier.weight' in key),
            None,
        )
        if classifier_key is None:
            # Previously this case printed nothing; make it visible.
            print("⚠️  Could not detect classes: no 'classifier.weight' found in checkpoint")
            return

        actual_num_classes = state_dict[classifier_key].shape[0]
        print(f"📊 Detected classes: {actual_num_classes}")

        if actual_num_classes == 6:
            print("🔄 Model trained with 6 classes (likely without background)")
            # Assumes CONFIG['class_names'][0] is the background class -- confirm.
            print(f"📝 Adjusted class list: {CONFIG['class_names'][1:]}")
        elif actual_num_classes == 7:
            print(f"📝 Model classes: {CONFIG['class_names']}")
        else:
            print(f"⚠️  Custom class count: {actual_num_classes}")

    except Exception as e:
        print(f"⚠️  Could not detect classes: {e}")


def check_setup():
    """Print a readiness report: model location, detected class count, device.

    Uses ``find_model_automatically`` to locate a trained model. When one is
    found, its classifier head is inspected to report the number of classes.
    Finally reports whether CUDA (GPU) or CPU will be used. Returns ``None``.
    """
    print("⚙️ SETUP STATUS CHECK")
    print("=" * 30)

    model_path = find_model_automatically()
    if model_path:
        print(f"✅ Model: {model_path}")
        _report_model_classes(model_path)
    else:
        print("⚠️  Model: Using base pretrained model")

    device = 'GPU' if torch.cuda.is_available() else 'CPU'
    print(f"✅ Device: {device}")

    print("\n🚀 Ready to process traffic videos!")
    print("=" * 30)

# Report model/device readiness before anything else runs.
check_setup()

# Usage banner: one print call, one line per argument.
print(
    "\n" + "=" * 60,
    "🚗 HOW TO USE FOR TRAFFIC SEGMENTATION:",
    "=" * 60,
    ">>> upload_and_segment_traffic()",
    "=" * 60,
    sep="\n",
)

# Launch the interactive upload/segment/download flow immediately
# (this cell is meant to be run directly in a notebook).
print("\n🎯 Starting traffic video segmentation interface...")
upload_and_segment_traffic()