# Segment and Prediction Notebook

## Setup

### Imports and Global Variables


In [None]:
import os
import threading

from queue import Queue

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.transforms as transforms
import numpy as np 
import pandas as pd 

from PIL import Image
from ultralytics import YOLO

# Class id -> human readable class name (ids follow the tuple order below).
obj_name = dict(enumerate((
    'Main Pad',
    'LED Pad',
    'Main Chip',
    'LED Chip',
)))

# Trained runs grouped by task; the order here defines the index each model
# gets in `l_metrics`.
_SEGMENT_MODELS = (
    'final_yolov8s-seg_cgi_sled',
    'final_yolov8s-seg_human_sled',
    'final_yolov8s-seg_combined_sled',
    'final_yolov9c-seg_cgi_sled',
    'final_yolov9c-seg_human_sled',
    'final_yolov9c-seg_combined_sled',
)
_OBB_MODELS = (
    'final_yolov8s-obb_cgi_sled',
    'final_yolov8s-obb_human_sled',
    'final_yolov8s-obb_combined_sled',
)

# Run name -> [task / runs sub-directory, image size used for cropped tests].
models = {run: ['segment', '320'] for run in _SEGMENT_MODELS}
models.update({run: ['OBB', '400'] for run in _OBB_MODELS})

# One metrics slot per model, filled in by `thread_safe_test`.
l_metrics = [None for _ in models]
# Single-slot queue: at most one validation may hold the slot at a time.
q = Queue(1)

# Image folders.
test_cropped_path = os.path.relpath("./datasets/test_cropped")
test_path = os.path.relpath("./datasets/test")

# Dataset description files handed to `model.val`.
test_cropped_yaml_path = os.path.relpath("./datasets_yaml/test_cropped.yaml")
test_yaml_path = os.path.relpath("./datasets_yaml/sled/combined_sled.yaml")

# Filled later: image paths and (image path, prediction) pairs.
images = list()
no_led = list()


### Thread-safe function to test and validate models


In [None]:
def thread_safe_test(model_name, type, i, cropped_test, modified_imgsz):
    """Validate one trained model on the test split and store its metrics.

    The result of `model.val` is written to `l_metrics[i]`. If the output
    directory of this validation run already exists, the function returns
    early and leaves `l_metrics[i]` untouched (i.e. `None` unless it was set
    before).

    Args:
        model_name: Name of the training run; weights are read from
            ``runs/<type>/<model_name>/weights/best.pt``.
        type: Sub-directory of ``runs`` the run lives in; also forwarded to
            `model.val` as ``task``. (The name shadows the builtin but is kept
            for compatibility with existing callers.)
        i: Index of the slot in `l_metrics` to fill.
        cropped_test: If true, validate against `test_cropped_yaml_path` with
            ``imgsz=modified_imgsz`` and name the run ``ctest_<model_name>``;
            otherwise validate against `test_yaml_path` and name the run
            ``test_<model_name>``.
        modified_imgsz: Image size used only when `cropped_test` is true.
    """
    # Take the single slot of the module-level queue; the caller releases it
    # with `q.get()` once the thread has been joined.
    q.put(1)

    name = ('ctest_' if cropped_test else 'test_') + model_name

    # Check for existing output *before* loading the weights so that a skipped
    # run does not pay for the model load.
    if os.path.exists(os.path.join("runs", type, name)):
        return

    model = YOLO(os.path.join("runs", type, model_name, "weights", "best.pt"))

    val_kwargs = dict(plots=True, device=0, split='test', name=name, task=type)
    if cropped_test:
        val_kwargs.update(data=test_cropped_yaml_path, imgsz=modified_imgsz)
    else:
        val_kwargs.update(data=test_yaml_path)

    l_metrics[i] = model.val(**val_kwargs)



### Function to get a list of all images


In [None]:
def get_all_images(path):
    """Append every ``.jpg`` / ``.png`` file directly inside `path` to `images`.

    Entries are added as ``os.path.join(path, <file name>)`` in the order
    returned by `os.listdir`; sub-directories are not descended into.
    """
    images.extend(
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith((".jpg", ".png"))
    )


## Run Code


### Validating the models


In [None]:
# Validate the models one after another: each run gets its own thread, which
# is joined before the next one starts, and the queue slot taken by the worker
# is released afterwards.
for slot, (run_name, (task, imgsz)) in enumerate(models.items()):
    worker = threading.Thread(
        target=thread_safe_test,
        args=(run_name, task, slot, True, imgsz),
    )
    worker.start()
    worker.join()
    q.get()


In [None]:
# Print the box mAP of every validated model. A slot is still `None` when
# `thread_safe_test` skipped the run (output directory already present) or the
# validation raised inside its thread, so those slots are reported instead of
# crashing on `None.box`.
for i, metric in enumerate(l_metrics):
    if metric is None:
        print(f"[{i}] no metrics available (validation skipped or failed)")
        continue
    print(metric.box.map)


### Label checking


In [None]:
# Overlay the ground-truth polygons of one image on top of that image.

# img_dir = 'datasets/rendered/train'
# img_name = 'bad_9.png'

img_dir = 'datasets/test_cropped'
img_name = 'HHF22220109_5_20_72-MW_1.jpg'

# Load the image
image_path = os.path.join(img_dir, img_name)
image = Image.open(image_path)

image_width, image_height = image.size

# The label file shares the image's stem. Swap only the extension: replacing
# 'png'/'jpg' anywhere in the name would corrupt stems containing those
# letters.
label_name = os.path.splitext(img_name)[0] + '.txt'
label_path = os.path.join(img_dir, label_name)
with open(label_path, 'r') as f:
    labels = f.read().splitlines()

# Parse the labels. Each line is expected to look like
#     id x1 y1 x2 y2 ... xn yn
# with coordinates normalised to [0, 1]; any number of vertices is accepted.
shapes = []
for line in labels:
    fields = line.split()  # tolerant of repeated / trailing whitespace
    if not fields:
        continue  # skip blank lines
    values = [float(x) for x in fields]
    class_id = int(values[0])
    coords = values[1:]
    # Pair up x/y values and scale them to pixel coordinates.
    points = [
        (x * image_width, y * image_height)
        for x, y in zip(coords[0::2], coords[1::2])
    ]
    shapes.append((class_id, points))


# Plot the image
fig, ax = plt.subplots()
ax.imshow(image)
ax.axis('off')

for class_id, shape in shapes:
    polygon = patches.Polygon(shape, linewidth=1, edgecolor='r', facecolor='none')

    # Add the patch to the Axes
    ax.add_patch(polygon)

    # Add label to the patches according to obj_name
    label = obj_name[class_id]
    ax.text(shape[0][0], shape[0][1], label, color='white')

plt.show()


### Get labelled output


In [None]:
# Weights used for the qualitative predictions below.
model = YOLO('runs/segment/final_yolov9c-seg_combined_sled/weights/best.pt')

# Gather every test image into the module-level `images` list ...
get_all_images(test_path)


# interest: 22

# ... then rebind `images` to a hand-picked subset; predictions are run on
# this list only.
images = [
    "C:/Users/KohCo/Desktop/FYP/ASPLProject/datasets/test/HHH23630057_9_13_55-DPI_1.jpg",
    "C:/Users/KohCo/Desktop/FYP/ASPLProject/datasets/test/HHH23630078_40_11_55-DPI_1.jpg",
    "C:/Users/KohCo/Desktop/FYP/ASPLProject/datasets/test/HHH23740006_4_9_55-DPI_1.jpg",
]

results = model(images, stream=True)

# Drain the streamed predictions once, then pair each with its source path.
results_list = list(results)
no_led = [(images[idx], prediction) for idx, prediction in enumerate(results_list)]
    


### Display outputs


In [None]:
# WIP

# for i, result in enumerate(results_list):
#     result.show()


def _scaled_box(box, center, x_scale, y_scale, x_offset):
    """Return a copy of `box` shrunk towards `center` and shifted along x.

    The result keeps the (integer) dtype of `box`, so the scaled coordinates
    are truncated when they are stored.
    """
    scaled = box.copy()
    scaled[:, 0] = box[:, 0] * x_scale + (1 - x_scale) * center[0] + x_offset
    scaled[:, 1] = box[:, 1] * y_scale + (1 - y_scale) * center[1]
    return scaled


# `savefig` does not create missing directories, so make sure the output
# folder exists before the first figure is written.
os.makedirs('./results', exist_ok=True)

# For every (image path, prediction) pair: draw the detected LED pads
# (class 1) and LED chips (class 3) and save the figure as
# ./results/<image stem>_predict.png.
for image, result in no_led:

    fig, ax = plt.subplots()

    data = result.summary()
    # basename/splitext instead of splitting on '/', which breaks on paths
    # built with os.path.join on Windows.
    img_name = os.path.splitext(os.path.basename(image))[0]

    # Only needed by the (currently disabled) `plt.imshow(img)` below.
    img = cv2.imread(image)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # thresholding
    # blur = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # th = cv2.threshold(blur, 40, 128, cv2.THRESH_BINARY)[1]

    # Sub-boxes of the LED pad. Reset per image so that an image without a
    # class-1 detection neither raises NameError nor re-draws the boxes of the
    # previous image. With several class-1 detections the last one wins.
    box_1 = None
    box_2 = None

    for item in data:

        segment = item["segments"]
        xs = np.asarray(segment['x'], dtype=float)
        ys = np.asarray(segment['y'], dtype=float)
        if xs.size == 0:
            continue  # nothing to fit a rectangle to

        # Contour in the (N, 1, 2) int32 layout expected by OpenCV.
        ctr = np.column_stack((xs, ys)).reshape((-1, 1, 2)).astype(np.int32)

        rect = cv2.minAreaRect(ctr)
        box = np.intp(cv2.boxPoints(rect))

        if item["class"] == 1:
        # if True:
            # rect[1] is the (width, height) pair of the fitted rectangle and
            # rect[2] its angle; rotate the box about the rectangle centre.
            if rect[1][0] > rect[1][1]:
                angle = rect[2]
            else:
                angle = 90 - rect[2]

            M = cv2.getRotationMatrix2D(rect[0], angle, 1)
            # img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
            # th = cv2.warpAffine(th, M, (th.shape[1], th.shape[0]))
            box = cv2.transform(np.array([box]), M)[0]

        # M = cv2.moments(th)
        # cX = int(M["m10"] / M["m00"])
        # cY = int(M["m01"] / M["m00"])

        # # Check if the centroid is in the upper half of the image
        # if cY < img.shape[0] / 2:
        #     # Rotate the image by 180 degrees
        #     M = cv2.getRotationMatrix2D((img.shape[1] / 2, img.shape[0] / 2), 180, 1)
        #     img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        #     box = cv2.transform(np.array([box]), M)[0]

        # Centre and axis-aligned extent of the (possibly rotated) box.
        box_center = np.mean(box, axis=0)
        box_width = box[:, 0].max() - box[:, 0].min()
        box_height = box[:, 1].max() - box[:, 1].min()

        if item["class"] == 1:

            # Two shrunken copies of the pad box, shifted a quarter of the
            # box width to the left and to the right of the centre.
            box_1 = _scaled_box(box, box_center, 0.35, 0.7, -0.25 * box_width)
            box_2 = _scaled_box(box, box_center, 0.33, 0.66, 0.25 * box_width)

            box_polygon = patches.Polygon(box, linewidth=4, edgecolor='salmon', facecolor='none')
            ax.add_patch(box_polygon)

            label = "Segmented Pad"
            ax.text(box_center[0]-box_width/2+2, box_center[1]-box_height/2-15, label, color='salmon')

            # One scatter call for the whole contour instead of one per point.
            ax.scatter(xs, ys, c='white', s=0.5)

        elif item["class"] == 3:
            chip_polygon = patches.Polygon(box, linewidth=2, edgecolor='skyblue', facecolor='none', alpha=0.3)
            ax.add_patch(chip_polygon)

            label = "Segmented Chip"
            ax.text(box_center[0]-box_width/2+2, box_center[1]-box_height/2-15, label, color='skyblue')

            ax.scatter(xs, ys, c='white', s=0.5)

    if box_1 is not None:
        box_polygon_1 = patches.Polygon(box_1, linewidth=2, edgecolor='lightgreen', facecolor='none')
        box_polygon_2 = patches.Polygon(box_2, linewidth=2, edgecolor='lightgreen', facecolor='none')

        ax.add_patch(box_polygon_1)
        ax.add_patch(box_polygon_2)

    ax.autoscale_view()
    ax.set_aspect('equal')

    # plt.imshow(img)
    ax.axis('off')
    fig.savefig(f'./results/{img_name}_predict.png', bbox_inches='tight', pad_inches=0, transparent=True)



        

