### Imports

In [None]:
import cv2
import torch
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance, ImageFilter
from PIL.ImageOps import invert

In [None]:
# Flip to True to display intermediate images and detection tables
show_intermediate = False
# Collects every final detection produced by the cells below
output = list()

### Image preprocessing

In [None]:
test_image = '../assets/example_image.jpg'

# Binarize the input drawing in five steps.
# (Presumably the input is dark ink on light paper — confirm for other sources.)
# 1. Load the image and convert it to 8-bit grayscale ('L' mode).
img = Image.open(test_image).convert('L')
# 2. Invert so dark strokes become bright.
img = invert(img)
# 3. Increase contrast by a factor of 2.
img = ImageEnhance.Contrast(img).enhance(2)
# 4. Threshold: pixels above 220 become 255, everything else becomes 0.
#    (Explicit conditional instead of the `cond and 255` idiom, which yields False.)
img = img.point(lambda p: 255 if p > 220 else 0)
# 5. Smooth the hard edges produced by thresholding.
img = img.filter(ImageFilter.SMOOTH)

# show image
if show_intermediate:
    img.show()

### Component inference

In [None]:
# Load the custom-trained YOLOv5 component detector via torch.hub
# (fetches the ultralytics/yolov5 repo on first use) and put it in eval mode.
model_path = '../models/components.pt'
c_model = torch.hub.load(repo_or_dir='ultralytics/yolov5', model='custom', path=model_path)
c_model.eval()

In [None]:
# Run the component detector on the preprocessed image
c_results = c_model(img)

# Append one entry per detected component to `output`.
# `c_results.xyxy` is a list with one tensor per input image; a single image was
# passed, so index 0 holds its detections. Each row of that tensor is
# (x1, y1, x2, y2, confidence, class). Entries are stored as
# [class_index, x1, y1, x2, y2, confidence].
output.extend(
    [int(cls), x1, y1, x2, y2, conf]
    for x1, y1, x2, y2, conf, cls in c_results.xyxy[0].tolist()
)

# Print and show results
if show_intermediate:
    print(c_results.pandas().xyxy)
    c_results.show()

### Junction inference

In [None]:
# Load the custom-trained YOLOv5 junction detector via torch.hub and put it in eval mode
model_path = '../models/junctions.pt'
j_model = torch.hub.load(repo_or_dir='ultralytics/yolov5', model='custom', path=model_path)
j_model.eval()

In [None]:
# Run the junction detector on the same preprocessed `img` the component model saw.
# Components are NOT masked out here; junctions lying inside component boxes are
# filtered afterwards by remove_overlapping_junctions.
j_results = j_model(img)

# Optionally print the detection table and display the annotated image
if show_intermediate:
    print(j_results.pandas().xyxy)
    j_results.show()

### Post-processing

In [None]:
def _box_corners(box):
    """Return the (x1, y1, x2, y2) corners of a detection row, each truncated to int."""
    return (int(box[0]), int(box[1]), int(box[2]), int(box[3]))


def _is_inside(inner, outer):
    """Return True if box ``inner`` lies entirely within box ``outer`` (edges inclusive)."""
    return (inner[0] >= outer[0] and inner[1] >= outer[1]
            and inner[2] <= outer[2] and inner[3] <= outer[3])


# Remove overlapping junctions (move this function to utils later)
def remove_overlapping_junctions(j_results, c_results):
    """Remove junctions whose bounding box lies entirely inside a component box.

    Only the first image's detections (``xyxy[0]``) of each result are used.
    Each box's first four values are taken as corner coordinates and truncated
    to ``int`` before comparing. Partially overlapping junctions are kept; only
    full containment removes a junction.

    Args:
        j_results (yolov5.results): Results of the junction detection model
        c_results (yolov5.results): Results of the component detection model

    Returns:
        list: ``(x1, y1, x2, y2)`` integer tuples of the junctions that are not
        contained in any component box, in their original detection order.
    """
    j_coords = [_box_corners(box) for box in j_results.xyxy[0]]
    c_coords = [_box_corners(box) for box in c_results.xyxy[0]]

    # Build a new list rather than calling list.remove() while iterating:
    # mutating the list mid-loop skipped the junction right after each removal.
    kept = [
        j_coord for j_coord in j_coords
        if not any(_is_inside(j_coord, c_coord) for c_coord in c_coords)
    ]

    # TODO: keep track of which junctions are kept and add those to output list

    return kept

# Junction boxes that are not fully contained in any component box
j_coords = remove_overlapping_junctions(j_results=j_results, c_results=c_results)

# TODO: perform non-maximum suppression on the junctions

### Convert to generated image
Take the final `output` list and generate the digital circuit from it.

In [None]:
# TODO