# In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import xml.etree.ElementTree as ET

def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it is absent.

    Uses ``exist_ok=True`` instead of a separate existence check so the call
    is safe even if another process creates the directory in between.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)

def parse_xml(xml_file, valid_junction_types, other_objects):
    """Read labelled bounding boxes from an annotation XML file.

    Every ``<object>`` element is expected to carry a ``<name>`` label and a
    ``<bndbox>`` with ``xmin``/``ymin``/``xmax``/``ymax`` children. Objects
    missing any of these are skipped rather than aborting the whole parse.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.
        valid_junction_types: Labels whose boxes are returned as junctions.
        other_objects: Labels whose boxes are returned as areas to mask.

    Returns:
        A ``(junctions, mask_areas)`` tuple where ``junctions`` is a list of
        ``(label, xmin, ymin, xmax, ymax)`` and ``mask_areas`` is a list of
        ``(xmin, ymin, xmax, ymax)``. A label present in both input
        collections contributes to both lists.

    Raises:
        ValueError: If a coordinate is present but not numeric.
    """
    root = ET.parse(xml_file).getroot()
    junctions = []
    mask_areas = []

    for obj in root.iter('object'):
        label = obj.findtext('name')
        bndbox = obj.find('bndbox')
        if label is None or bndbox is None:
            continue  # incomplete annotation: nothing usable to record

        raw = [bndbox.findtext(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        if any(value is None for value in raw):
            continue  # a coordinate tag is missing

        # Go through float() so values written as "12.0" are accepted;
        # the fractional part is truncated.
        xmin, ymin, xmax, ymax = (int(float(value)) for value in raw)

        if label in valid_junction_types:
            junctions.append((label, xmin, ymin, xmax, ymax))
        if label in other_objects:
            mask_areas.append((xmin, ymin, xmax, ymax))
    return junctions, mask_areas

def mask_objects(dilated, mask_areas):
    """Black out each rectangle of ``mask_areas`` on ``dilated`` in place.

    Each entry of ``mask_areas`` is an ``(xmin, ymin, xmax, ymax)`` box; the
    box is drawn as a filled rectangle with colour ``(0, 0, 0)`` so any
    contours inside it are erased. Nothing is returned.
    """
    black = (0, 0, 0)
    for box in mask_areas:
        left, top, right, bottom = box
        # cv2.FILLED (-1) makes rectangle() fill the box instead of outlining it.
        cv2.rectangle(dilated, (left, top), (right, bottom), black, cv2.FILLED)

def find_intersections(dilated, bbox):
    """Find where non-zero pixels of ``dilated`` cross the edges of ``bbox``.

    Each of the four sides of the box is scanned for pixels greater than 0.
    For every side with at least two hits, the two hits that lie furthest
    apart (the first and the last along the side) are kept; sides with
    fewer than two hits are omitted from the result.

    The box is clamped to the image bounds first, so boxes touching or
    exceeding the border do not raise ``IndexError`` and negative
    coordinates do not wrap around to the opposite edge.

    Args:
        dilated: 2-D array indexed as ``dilated[y, x]``.
        bbox: ``(xmin, ymin, xmax, ymax)`` with inclusive pixel coordinates.

    Returns:
        Dict mapping a subset of ``'top'``, ``'bottom'``, ``'left'``,
        ``'right'`` to a two-element list of ``(x, y)`` points.
    """
    height, width = dilated.shape[:2]
    xmin, ymin, xmax, ymax = bbox
    xmin, ymin = max(xmin, 0), max(ymin, 0)
    xmax, ymax = min(xmax, width - 1), min(ymax, height - 1)
    if xmin > xmax or ymin > ymax:
        return {}  # box lies entirely outside the image

    # Hits are collected in increasing x (or y), so each list is already
    # ordered along its side.
    points = {
        'top': [(x, ymin) for x in range(xmin, xmax + 1) if dilated[ymin, x] > 0],
        'bottom': [(x, ymax) for x in range(xmin, xmax + 1) if dilated[ymax, x] > 0],
        'left': [(xmin, y) for y in range(ymin, ymax + 1) if dilated[y, xmin] > 0],
        'right': [(xmax, y) for y in range(ymin, ymax + 1) if dilated[y, xmax] > 0],
    }

    # Points on one side are collinear and ordered, so the extremes of the
    # list are the furthest-apart pair.
    return {
        side: [pts[0], pts[-1]]
        for side, pts in points.items()
        if len(pts) >= 2
    }

def draw_bounding_boxes_and_centers(image_path, junctions, mask_areas):
    """Render junction boxes, centres and edge intersections on an edge map.

    The image is converted to grayscale, blurred (5x5 Gaussian), binarised
    with Otsu's threshold, passed through Canny and dilated with a 2x2
    kernel. Areas in ``mask_areas`` are blacked out on that edge map, which
    is then converted to BGR and annotated for every junction with:

    * the bounding box and its index (green),
    * the box centre (BGR ``(255, 0, 0)``),
    * the two intersection points per side returned by
      ``find_intersections`` (BGR ``(255, 255, 0)``) and their midpoint
      (BGR ``(0, 0, 255)``).

    One line per junction side is also printed to stdout.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        junctions: Iterable of ``(label, xmin, ymin, xmax, ymax)``.
        mask_areas: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes to erase.

    Returns:
        The annotated BGR image.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    _, otsu_thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    edges = cv2.Canny(otsu_thresh, 50, 150, apertureSize=3)
    dilated = cv2.dilate(edges, np.ones((2, 2), np.uint8), iterations=1)

    # Masking areas to delete irrelevant contours
    mask_objects(dilated, mask_areas)

    img_with_boxes = cv2.cvtColor(dilated, cv2.COLOR_GRAY2BGR)

    for idx, (label, xmin, ymin, xmax, ymax) in enumerate(junctions):
        # Intersections are searched on the edge map, not on the annotated
        # copy, so earlier drawings cannot be picked up as edges.
        intersections = find_intersections(dilated, (xmin, ymin, xmax, ymax))
        cv2.rectangle(img_with_boxes, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
        # Annotate bbox
        cv2.putText(img_with_boxes, f"{idx}", (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        center_x = (xmin + xmax) // 2
        center_y = (ymin + ymax) // 2
        cv2.circle(img_with_boxes, (center_x, center_y), 5, (255, 0, 0), -1)

        for side, points in intersections.items():
            pt1, pt2 = points
            midpoint = ((pt1[0] + pt2[0]) // 2, (pt1[1] + pt2[1]) // 2)
            cv2.circle(img_with_boxes, pt1, 5, (255, 255, 0), -1)
            cv2.circle(img_with_boxes, pt2, 5, (255, 255, 0), -1)
            cv2.circle(img_with_boxes, midpoint, 5, (0, 0, 255), -1)
            print(f"Junction {idx} ({label}), Side: {side}, Intersection 1: {pt1}, Intersection 2: {pt2}, Midpoint: {midpoint}")

    return img_with_boxes

# Script section: annotate one floor-plan image and save/show the result.

# NOTE: despite its name, xml_dir points to a single annotation file.
xml_dir = './data/xmls/VA_D_00034.xml'
image_path = './data/images/VA_D_00034.PNG'
output_folder_path = './output_images'
ensure_dir(output_folder_path)

# Labels treated as junctions, and labels whose boxes are erased from the
# edge map before intersections are searched.
valid_junction_types = ["junc_I", "junc_I_normal", "junc_I_open", "junc_I_isolation", "junc_L", "junc_T", "junc_X"]
other_objects = ["door_normal", "door_double", "window", "door_hinged"]

junctions, mask_areas = parse_xml(xml_dir, valid_junction_types, other_objects)
final_image = draw_bounding_boxes_and_centers(image_path, junctions, mask_areas)

output_image_path = os.path.join(output_folder_path, 'final_image_with_junctions_and_centers.png')
cv2.imwrite(output_image_path, final_image)

plt.figure(figsize=(10, 10))
# OpenCV images are BGR; matplotlib expects RGB.
plt.imshow(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))
plt.title('Final Image with Junctions, Centers, and Valid Intersections')
plt.axis('off')
# Save only after the title and axis settings are applied so the file on
# disk matches the displayed figure, and before show() ends the figure.
plt.savefig('./111.png')
plt.show()
