In [2]:
import os
import glob
import random
import xml.etree.ElementTree as ET
from PIL import Image

In [3]:
# Dataset locations (relative paths, resolved against the working directory).

# Source PNG images.
IMAGE_DIR = "images"

# XML annotation files (one per image, looked up by matching base name).
XML_DIR = "annotations"

# Root of the prepared dataset; per-class crops are written underneath it.
OUTPUT_DIR = "cnn_dataset"

In [4]:
# Collect the distinct <color> labels used by any <object> in the annotations.
# Each label becomes one class (one sub-directory) of the prepared dataset.
unique_colors = set()
for xml_file in glob.glob(os.path.join(XML_DIR, "*.xml")):
    try:
        root = ET.parse(xml_file).getroot()
    except ET.ParseError as e:
        # Skip unreadable annotations instead of aborting the whole scan;
        # the extraction step reports per-file errors the same way.
        print(f"Error parsing {xml_file}: {e}")
        continue

    for obj in root.findall('object'):
        # findtext() returns None when <color> is absent, so objects without
        # a color tag are skipped rather than raising AttributeError.
        color = obj.findtext('color')
        if color:
            unique_colors.add(color)

# Sorted list (lexicographic order) so the class order is deterministic.
color_classes = sorted(unique_colors)

print(f"Found {len(color_classes)} unique brick classes: {color_classes}")

# Create one output directory per class under both the train and test splits.
for color in color_classes:
    os.makedirs(os.path.join(OUTPUT_DIR, 'train', color), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, 'test', color), exist_ok=True)
print(f"Successfully created directories for each color class in {OUTPUT_DIR}/train and {OUTPUT_DIR}/test.")

Found 40 unique brick classes: ['m0', 'm1', 'm10', 'm11', 'm12', 'm13', 'm14', 'm15', 'm16', 'm17', 'm18', 'm19', 'm2', 'm20', 'm21', 'm22', 'm23', 'm24', 'm25', 'm26', 'm27', 'm28', 'm29', 'm3', 'm30', 'm31', 'm32', 'm33', 'm34', 'm35', 'm36', 'm37', 'm38', 'm39', 'm4', 'm5', 'm6', 'm7', 'm8', 'm9']
Successfully created directories for each color class in cnn_dataset/train and cnn_dataset/test.


In [None]:
# TODO: change this to multithreading if possible (on one thread it took 12h to process half of the dataset)
def extract_bricks_from_image(file_id, dataset_type):
    """Crop every annotated brick out of one image and save it under its class.

    Parses ``XML_DIR/<file_id>.xml`` and, for each ``<object>`` with a
    non-empty ``<color>``, crops its ``<bndbox>`` out of
    ``IMAGE_DIR/<file_id>.png`` and saves the crop to
    ``OUTPUT_DIR/<dataset_type>/<color>/``.

    The source image is opened once per file and closed on exit, instead of
    being re-opened (and re-decoded) for every annotated object.

    Args:
        file_id: Base name shared by the image and its annotation, without
            extension.
        dataset_type: Sub-directory of ``OUTPUT_DIR`` to write into
            (the notebook uses ``'train'`` and ``'test'``).

    Returns:
        None. Failures are printed rather than raised: a bad object is skipped
        and the remaining objects are still processed; a bad file is skipped
        as a whole.
    """
    try:
        tree = ET.parse(os.path.join(XML_DIR, f"{file_id}.xml"))
        objects = tree.getroot().findall('object')
        if not objects:
            # Nothing to crop; avoid touching the image at all.
            return

        img_path = os.path.join(IMAGE_DIR, f"{file_id}.png")
        if not os.path.exists(img_path):
            return

        # Open the image once for all objects; the context manager releases
        # the file handle even if an object fails.
        with Image.open(img_path) as img:
            image_width, image_height = img.size

            for obj in objects:
                try:
                    # Objects with a missing or empty <color> have no class.
                    color_class = obj.findtext('color')
                    if not color_class:
                        continue

                    # Bounding box coordinates may be written as floats.
                    bndbox = obj.find('bndbox')
                    xmin, ymin, xmax, ymax = (
                        float(bndbox.find(tag).text)
                        for tag in ('xmin', 'ymin', 'xmax', 'ymax')
                    )

                    # Truncate to whole pixels and clamp to the image bounds.
                    xmin = max(0, int(xmin))
                    ymin = max(0, int(ymin))
                    xmax = min(image_width, int(xmax))
                    ymax = min(image_height, int(ymax))

                    # Skip degenerate boxes (no positive width or height).
                    if xmax <= xmin or ymax <= ymin:
                        continue

                    brick_img = img.crop((xmin, ymin, xmax, ymax))

                    # The clamped box coordinates make the filename unique
                    # per source image.
                    output_filename = f"{file_id}_{xmin}_{ymin}_{xmax}_{ymax}.png"
                    output_path = os.path.join(OUTPUT_DIR, dataset_type, color_class, output_filename)

                    brick_img.save(output_path)
                except Exception as e:
                    print(f"Error processing object in {file_id}: {e}")
    except Exception as e:
        print(f"Error processing file {file_id}: {e}")

In [None]:
# Collect all image files; sorting first makes the seeded shuffle below
# independent of the order in which the filesystem lists them.
image_files = sorted(glob.glob(os.path.join(IMAGE_DIR, "*.png")))
print(f"Found {len(image_files)} image files")

# Randomly split into train and test sets (80:20)
random.seed(42)  # for reproducibility
random.shuffle(image_files)
split_idx = int(0.8 * len(image_files))
train_images = image_files[:split_idx]
test_images = image_files[split_idx:]

# The held-out split is written to the 'test' directory, so it is reported
# as the test set here too.
print(f"Training images: {len(train_images)}")
print(f"Test images: {len(test_images)}")

# Extract the bricks of each split into its own sub-directory of OUTPUT_DIR.
for dataset_type, split_images in (('train', train_images), ('test', test_images)):
    for img_path in split_images:
        file_id = os.path.splitext(os.path.basename(img_path))[0]
        extract_bricks_from_image(file_id, dataset_type)

print("Dataset preparation complete!")