In [1]:
import os
import cv2
import numpy as np
import xml.etree.ElementTree as ET
import shutil

In [2]:
import os
import cv2
import shutil
import numpy as np
import xml.etree.ElementTree as ET

def create_full_dataset(root_dir, output_dir, class_mapping):
    """
    Create dataset with mirrored structure for both images and masks
    with strict class name validation
    """
    # Define paths
    input_images = os.path.join(root_dir, "PCB_DATASET", "images")
    annotations_base = os.path.join(root_dir, "PCB_DATASET", "Annotations")
    
    # Create output directories
    output_images = os.path.join(output_dir, "images")
    output_masks = os.path.join(output_dir, "masks")
    os.makedirs(output_images, exist_ok=True)
    os.makedirs(output_masks, exist_ok=True)

    # Get valid class names for error messages
    valid_classes = list(class_mapping.keys())
    
    # Process each defect category
    defects = ["Missing_hole", "Mouse_bite", "Open_circuit", 
              "Short", "Spur", "Spurious_copper"]

    for defect in defects:
        print(f"\nProcessing {defect}...")
        
        # Create output subdirectories
        img_defect_dir = os.path.join(output_images, defect)
        mask_defect_dir = os.path.join(output_masks, defect)
        os.makedirs(img_defect_dir, exist_ok=True)
        os.makedirs(mask_defect_dir, exist_ok=True)

        # Get input paths
        input_img_dir = os.path.join(input_images, defect)
        input_xml_dir = os.path.join(annotations_base, defect)

        # Process each image
        for img_file in os.listdir(input_img_dir):
            if not img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            base_name = os.path.splitext(img_file)[0]
            img_path = os.path.join(input_img_dir, img_file)
            xml_path = os.path.join(input_xml_dir, f"{base_name}.xml")
            
            # Skip if XML is missing
            if not os.path.exists(xml_path):
                print(f"Warning: Missing XML for {img_file}")
                continue

            try:
                # Copy image to output directory
                output_img_path = os.path.join(img_defect_dir, img_file)
                shutil.copy2(img_path, output_img_path)

                # Generate mask
                img = cv2.imread(img_path)
                h, w = img.shape[:2]
                mask = np.zeros((h, w), dtype=np.uint8)

                tree = ET.parse(xml_path)
                root = tree.getroot()

                for obj in root.findall('object'):
                    # Validate class name
                    class_name = obj.find('name').text.strip().lower()
                    
                    if class_name not in class_mapping:
                        raise ValueError(
                            f"Invalid class '{class_name}' in {img_file}\n"
                            f"Valid classes: {valid_classes}\n"
                            f"XML path: {xml_path}"
                        )
                    
                    class_id = class_mapping[class_name]

                    # Process bounding box
                    bndbox = obj.find('bndbox')
                    try:
                        xmin = int(bndbox.find('xmin').text)
                        ymin = int(bndbox.find('ymin').text)
                        xmax = int(bndbox.find('xmax').text)
                        ymax = int(bndbox.find('ymax').text)
                    except AttributeError:
                        print(f"Invalid bndbox in {img_file}")
                        continue

                    # Validate coordinates
                    xmin = max(0, min(xmin, w-1))
                    xmax = max(0, min(xmax, w-1))
                    ymin = max(0, min(ymin, h-1))
                    ymax = max(0, min(ymax, h-1))
                    
                    # Draw rectangle with class ID
                    cv2.rectangle(mask, (xmin, ymin), (xmax, ymax), class_id, -1)

                # Verify and save mask
                unique_values = np.unique(mask)
                print(f"Processed {img_file} - Unique mask values: {unique_values}")
                
                mask_path = os.path.join(mask_defect_dir, f"{base_name}.png")
                cv2.imwrite(mask_path, mask)

            except Exception as e:
                print(f"Error processing {img_file}: {str(e)}")
                raise  # Re-raise the error to stop execution

# Class mapping with lowercase keys
CLASS_MAPPING = {
    "missing_hole": 1,
    "mouse_bite": 2,
    "open_circuit": 3,
    "short": 4,
    "spur": 5,
    "spurious_copper": 6
}

# Usage
create_full_dataset(
    root_dir="/kaggle/input/pcb-defects",
    output_dir="/kaggle/working/pcb_dataset",
    class_mapping=CLASS_MAPPING
)


Processing Missing_hole...
Processed 01_missing_hole_01.jpg - Unique mask values: [0 1]
Processed 04_missing_hole_01.jpg - Unique mask values: [0 1]
Processed 01_missing_hole_17.jpg - Unique mask values: [0 1]
Processed 04_missing_hole_12.jpg - Unique mask values: [0 1]
Processed 07_missing_hole_06.jpg - Unique mask values: [0 1]
Processed 12_missing_hole_08.jpg - Unique mask values: [0 1]
Processed 04_missing_hole_10.jpg - Unique mask values: [0 1]
Processed 01_missing_hole_02.jpg - Unique mask values: [0 1]
Processed 08_missing_hole_06.jpg - Unique mask values: [0 1]
Processed 04_missing_hole_06.jpg - Unique mask values: [0 1]
Processed 06_missing_hole_08.jpg - Unique mask values: [0 1]
Processed 01_missing_hole_20.jpg - Unique mask values: [0 1]
Processed 12_missing_hole_09.jpg - Unique mask values: [0 1]
Processed 04_missing_hole_04.jpg - Unique mask values: [0 1]
Processed 01_missing_hole_08.jpg - Unique mask values: [0 1]
Processed 07_missing_hole_08.jpg - Unique mask values: [0