In [3]:
import os
import xml.etree.ElementTree as ET
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Function to parse XML files and extract bounding boxes
def parse_voc_xml(xml_file):
    """Read the bounding boxes from a Pascal VOC annotation.

    Args:
        xml_file: Anything ``ET.parse`` accepts (a path or a file object).

    Returns:
        A ``(boxes, tree)`` tuple. ``boxes`` holds one
        ``[xmin, ymin, xmax, ymax]`` list of ints per ``<object>`` element,
        in document order. ``tree`` is the parsed ``ElementTree``, returned so
        the caller can rewrite and save it later.
    """
    tree = ET.parse(xml_file)
    root = tree.getroot()

    boxes = []
    for member in root.findall("object"):
        bndbox = member.find("bndbox")
        # Some annotation tools emit sub-pixel values such as "12.0"; a bare
        # int() rejects those, so parse as float and round to the nearest pixel.
        boxes.append(
            [round(float(bndbox.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        )

    return boxes, tree


# Augmentation pipeline: the only step is a resize to 350x350 pixels.
# Boxes are Pascal VOC pixel corners (xmin, ymin, xmax, ymax). Their labels
# travel separately under the "class_labels" key so that albumentations keeps
# each label paired with its box.
transform = A.Compose(
    transforms=[A.Resize(width=350, height=350)],
    bbox_params=A.BboxParams(
        label_fields=["class_labels"],
        format="pascal_voc",
    ),
)


# Helper: keep the VOC <size> element in sync with the image actually saved
def _update_voc_size(tree, image):
    """Set ``<size>/<width>`` and ``<size>/<height>`` from ``image.shape``.

    ``image`` is the array returned by OpenCV / albumentations, with height
    first and width second in ``shape``. Missing elements are left alone.
    """
    size = tree.getroot().find("size")
    if size is None:
        return
    height, width = image.shape[:2]
    for tag, value in (("width", width), ("height", height)):
        node = size.find(tag)
        if node is not None:
            node.text = str(value)


# Function to process a single image and its XML annotation
def process_image(image_path, xml_path):
    """Resize one image with ``transform`` and save it with an updated annotation.

    Output paths are derived by replacing EVERY occurrence of ``"images"`` in
    the input paths with ``"augmented_images"``. For example,
    ``.../images/collectedimages/...`` becomes
    ``.../augmented_images/collectedaugmented_images/...``.

    Args:
        image_path: Path of the source image.
        xml_path: Path of its Pascal VOC annotation.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        ValueError: if a derived output path equals its input path, which
            would overwrite the original data.
        OSError: if OpenCV fails to write the augmented image.
    """
    # cv2.imread reports failure by returning None rather than raising.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Parse the XML file to get bounding boxes
    boxes, tree = parse_voc_xml(xml_path)
    class_labels = [0] * len(boxes)  # Replace with actual class labels if available

    # Apply the augmentation
    transformed = transform(image=image, bboxes=boxes, class_labels=class_labels)
    transformed_image = transformed["image"]
    transformed_bboxes = transformed["bboxes"]

    augmented_image_path = image_path.replace("images", "augmented_images")
    augmented_xml_path = xml_path.replace("images", "augmented_images")
    # Without an "images" component the replace is a no-op, and the code
    # below would overwrite the originals with their resized copies.
    if augmented_image_path == image_path or augmented_xml_path == xml_path:
        raise ValueError(
            f"Refusing to overwrite source data: no 'images' in {image_path!r} / {xml_path!r}"
        )

    for out_path in (augmented_image_path, augmented_xml_path):
        out_dir = os.path.dirname(out_path)
        if out_dir:  # makedirs("") raises for bare file names
            os.makedirs(out_dir, exist_ok=True)

    # cv2.imwrite returns False instead of raising on failure.
    if not cv2.imwrite(augmented_image_path, transformed_image):
        raise OSError(f"Could not write image: {augmented_image_path}")

    # The resize changes the image dimensions, so refresh <size> before saving.
    _update_voc_size(tree, transformed_image)
    create_augmented_xml(augmented_xml_path, transformed_bboxes, tree)


# Function to create a new XML file with updated bounding boxes
def create_augmented_xml(xml_path, bboxes, original_tree):
    """Write ``original_tree`` to ``xml_path`` with replaced box coordinates.

    The tree is modified in place: the i-th ``<object>``'s ``<bndbox>`` takes
    the values of ``bboxes[i]``. Coordinates are rounded to whole pixels
    because ``parse_voc_xml`` reads them back with integer parsing.

    Args:
        xml_path: Destination file for the rewritten annotation.
        bboxes: Sequence of ``(xmin, ymin, xmax, ymax)`` boxes, one per
            ``<object>`` element and in the same order. Extra trailing values
            in a box are ignored.
        original_tree: The ``ElementTree`` to update, e.g. the one returned by
            ``parse_voc_xml``.

    Raises:
        ValueError: if the number of boxes differs from the number of
            ``<object>`` elements; pairing them by position would then be wrong.
    """
    root = original_tree.getroot()
    objects = root.findall("object")
    if len(objects) != len(bboxes):
        raise ValueError(
            f"Got {len(bboxes)} boxes for {len(objects)} <object> elements in {xml_path!r}"
        )

    for member, box in zip(objects, bboxes):
        bndbox = member.find("bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), box):
            # Augmentation yields float coordinates; store whole pixels so that
            # integer-based VOC readers can parse the file.
            bndbox.find(tag).text = str(int(round(float(value))))

    original_tree.write(xml_path)


# Traverse directories and process images
def iter_annotated_images(directory, extensions=(".jpg", ".jpeg")):
    """Yield ``(image_path, xml_path)`` pairs found under ``directory``.

    A pair is produced for every file whose extension (compared in lowercase)
    is in ``extensions`` and that has an ``.xml`` file with the same base name
    in the same folder. Only the file's own extension is swapped; a plain
    ``str.replace`` would also rewrite ``.jpg`` appearing in directory names.

    Args:
        directory: Root folder to walk recursively.
        extensions: Lowercase image extensions to accept, including the dot.
    """
    for dirpath, _dirnames, filenames in os.walk(directory):
        for filename in filenames:
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in extensions:
                continue
            xml_path = os.path.join(dirpath, stem + ".xml")
            if os.path.exists(xml_path):
                yield os.path.join(dirpath, filename), xml_path


base_dir = "Tensorflow/workspace/images/collectedimages"
for image_path, xml_path in iter_annotated_images(base_dir):
    process_image(image_path, xml_path)