# Segment and Prediction Notebook

## Setup

### Imports and Global Variables


In [1]:
import os
import threading

from queue import Queue

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.transforms as transforms
import numpy as np 
import pandas as pd 

from PIL import Image
from ultralytics import YOLO

# Class id -> display name used when annotating label polygons
obj_name = dict(enumerate(['Main Pad', 'LED Pad', 'Main Chip', 'LED Chip']))

# Trained run name -> [runs sub-directory, image size (as a string)].
# Each architecture was trained once per data source, in the order listed.
models = {
    f'final_{arch}_{source}_sled': [task_dir, imgsz]
    for arch, task_dir, imgsz in (
        ('yolov8s-seg', 'seg', '320'),
        ('yolov9c-seg', 'seg', '320'),
        ('yolov8s-obb', 'obb', '400'),
    )
    for source in ('cgi', 'human', 'combined')
}

# One result slot per model, written by the validation threads
l_metrics = [None for _ in models]
# Size-1 queue that thread_safe_test uses as a lock (put = acquire, get = release)
q = Queue(maxsize=1)

# Dataset folders and YAML configs, as normalised relative paths
test_cropped_path, test_path = (
    os.path.relpath("./datasets/test_cropped"),
    os.path.relpath("./datasets/test"),
)
test_cropped_yaml_path = os.path.relpath("./datasets_yaml/test_cropped.yaml")
test_yaml_path = os.path.relpath("./datasets_yaml/sled/combined_sled.yaml")

# Filled later: image paths to predict on, and (path, result) pairs
images, no_led = [], []


### Function to visualize labeled data


In [None]:
def label_check(dir, image_name):
    """Display an image with its YOLO polygon labels drawn on top.

    The label file must sit next to the image, sharing its base name with a
    ``.txt`` extension. Each non-empty line is ``class_id x1 y1 x2 y2 ...``
    with coordinates normalised to the image size. Any number of (x, y) pairs
    is accepted, so both 4-corner boxes and longer segmentation polygons plot.

    Parameters
    ----------
    dir : str
        Directory containing both the image and its label file.
    image_name : str
        File name of the image, e.g. ``'bad_9.png'``.
    """
    image_path = os.path.join(dir, image_name)
    # Copy inside the context manager so the file handle is released promptly
    with Image.open(image_path) as img:
        image = img.copy()
    image_width, image_height = image.size

    # Swap only the extension; str.replace would also rewrite 'png'/'jpg'
    # appearing elsewhere in the file name.
    label_path = os.path.join(dir, os.path.splitext(image_name)[0] + '.txt')
    with open(label_path, 'r') as f:
        lines = f.read().splitlines()

    # Parse each line into (class_id, [(x_px, y_px), ...])
    shapes = []
    for line in lines:
        values = line.split()  # split() tolerates repeated / trailing spaces
        if not values:
            continue  # skip blank lines instead of failing on float('')
        class_id = int(float(values[0]))
        coords = [float(v) for v in values[1:]]
        points = [
            (x * image_width, y * image_height)
            for x, y in zip(coords[0::2], coords[1::2])
        ]
        if not points:
            continue  # a bare class id has nothing to draw
        shapes.append((class_id, points))

    _, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')

    for class_id, points in shapes:
        polygon = patches.Polygon(points, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(polygon)
        # Label each polygon at its first vertex using the class-name map
        ax.text(points[0][0], points[0][1], obj_name[class_id], color='white')

    plt.show()


### Thread-safe function to test and validate models


In [None]:
def thread_safe_test(model, type, i, cropped_test, modified_imgsz):
    """Validate one trained YOLO run and store its metrics in ``l_metrics[i]``.

    The module-level ``q`` (a ``Queue(maxsize=1)``) acts as a lock: ``q.put``
    blocks while another thread holds the single slot, so validation runs
    are serialised even though each is launched on its own thread.

    Parameters
    ----------
    model : str
        Run name; weights are loaded from ``runs/<type>/<model>/weights/best.pt``.
    type : str
        Sub-directory under ``runs`` (first entry of each ``models`` value).
    i : int
        Index in ``l_metrics`` that receives the validation result.
    cropped_test : bool
        If True, validate on ``test_cropped_yaml_path`` with ``imgsz=modified_imgsz``.
        Otherwise validate on ``test_yaml_path`` without overriding ``imgsz``.
    modified_imgsz : int or str
        Image size, used only for the cropped test set.
    """
    q.put(1)
    try:
        # NOTE(review): `models` maps segmentation runs to 'seg', but the
        # prediction cell loads from 'runs/segment/...'; confirm the folder name.
        yolo = YOLO(os.path.join("runs", type, model, "weights", "best.pt"))

        if cropped_test:
            l_metrics[i] = yolo.val(data=test_cropped_yaml_path, plots=True, device=0, imgsz=modified_imgsz)
        else:
            l_metrics[i] = yolo.val(data=test_yaml_path, plots=True, device=0)
    finally:
        # Always release the slot; otherwise one failed run (missing weights,
        # CUDA error, ...) would leave every other thread blocked on q.put.
        q.get()


### Function to get a list of all images


In [None]:
def get_all_images(path, out=None, extensions=(".jpg", ".png")):
    """Collect paths of image files found directly inside ``path``.

    Parameters
    ----------
    path : str
        Directory to scan (not recursive).
    out : list, optional
        List to append found paths to. Defaults to the notebook-global
        ``images`` list, matching the original behaviour.
    extensions : tuple of str
        Accepted file suffixes, matched case-insensitively so ``.JPG`` /
        ``.PNG`` files are not silently skipped.

    Returns
    -------
    list
        The list that was appended to (``out`` or the global ``images``).
    """
    target = images if out is None else out
    suffixes = tuple(ext.lower() for ext in extensions)
    # os.listdir order is arbitrary; sort so runs are reproducible
    for file in sorted(os.listdir(path)):
        if file.lower().endswith(suffixes):
            target.append(os.path.join(path, file))
    return target


## Run the Code


### Validating the models


In [None]:
# One thread per model; thread_safe_test serialises the actual validation
# runs through the size-1 queue `q`, so they execute one at a time.
threads = []
for i, (model_name, (task_dir, imgsz)) in enumerate(models.items()):
    t = threading.Thread(target=thread_safe_test, args=(model_name, task_dir, i, False, imgsz))
    t.start()
    threads.append(t)

# Wait for every run to finish; reading l_metrics earlier would still see None
for t in threads:
    t.join()

# Report each model's box mAP (None if that model's validation failed)
for model_name, metric in zip(models, l_metrics):
    print(f"{model_name}: {metric.box.map if metric is not None else None}")


### Label checking


In [None]:
# Pick one labelled sample to inspect; swap in the commented pair to check a
# real ASPL image instead of a rendered one.
img_dir, img_name = 'datasets/rendered/train', 'bad_9.png'
# img_dir, img_name = 'datasets/aspl/train', 'HHF22150120_40_10_56-MD_1.jpg'

label_check(img_dir, img_name)


### Get labeled output


In [None]:
model = YOLO('runs/segment/final_yolov8s-seg_cgi_sled/weights/best.pt')

# `images` is only filled by get_all_images(), which no earlier cell calls, so
# on a fresh kernel it is still empty here. Populate it from the test set.
# NOTE(review): test_path is assumed to be the intended source; confirm.
if not images:
    get_all_images(test_path)

results = model(images, stream=True)

# Reset first so re-running this cell does not duplicate entries
no_led.clear()
for image_path, result in zip(images, results):
    no_led.append((image_path, result))


### Display outputs


In [None]:
# WIP: rotate each detected object upright and overlay candidate sub-regions.

PREVIEW_EVERY = 10  # visualise only every 10th image to keep output short

for img_idx, (image, result) in enumerate(no_led):
    if img_idx % PREVIEW_EVERY != 0:
        continue
    data = result.summary()

    base_img = cv2.imread(image)
    if base_img is None:  # cv2.imread returns None instead of raising
        print(f"Could not read {image}; skipping")
        continue

    # Binary-threshold the grayscale image: pixels > 40 become 128, the rest 0
    gray = cv2.cvtColor(base_img, cv2.COLOR_BGR2GRAY)
    base_th = cv2.threshold(gray, 40, 128, cv2.THRESH_BINARY)[1]
    height, width = base_img.shape[:2]

    for item in data:
        segment = item.get("segments")
        if not segment or len(segment['x']) == 0:
            continue  # no mask polygon for this detection

        # Mask polygon as an (N, 1, 2) int32 contour for OpenCV
        ctr = np.column_stack((segment['x'], segment['y'])).reshape((-1, 1, 2)).astype(np.int32)
        rect = cv2.minAreaRect(ctr)
        # Keep corners as float32 (no np.intp truncation) so the rotations and
        # shrinking below do not accumulate rounding error
        box = cv2.boxPoints(rect)

        # NOTE(review): the angle choice depends on which side of the rect is
        # longer; verify against OpenCV's minAreaRect angle convention, which
        # changed in 4.5.1.
        if rect[1][0] > rect[1][1]:
            angle = rect[2]
        else:
            angle = 90 + rect[2]

        rot = cv2.getRotationMatrix2D(rect[0], angle, 1)
        img = cv2.warpAffine(base_img, rot, (width, height))
        th = cv2.warpAffine(base_th, rot, (width, height))
        box = cv2.transform(np.array([box]), rot)[0]

        # Flip 180 degrees when the thresholded mass sits in the upper half
        moments = cv2.moments(th)
        if moments["m00"] > 0:  # an all-black mask would divide by zero
            cY = int(moments["m01"] / moments["m00"])
            if cY < height / 2:
                flip = cv2.getRotationMatrix2D((width / 2, height / 2), 180, 1)
                img = cv2.warpAffine(img, flip, (width, height))
                box = cv2.transform(np.array([box]), flip)[0]

        box_center = box.mean(axis=0)
        box_width = np.linalg.norm(box[0] - box[1])
        box_height = np.linalg.norm(box[1] - box[2])
        print(box_center, box_width, box_height)

        # Shrink the box toward its centre to outline two candidate sub-regions.
        # box_1 keeps 35% of each corner's horizontal offset from the centre
        # and 70% of its vertical offset; box_2 keeps 33% / 66%.
        box_1 = box_center + (box - box_center) * np.array([0.35, 0.7])
        box_2 = box_center + (box - box_center) * np.array([0.33, 0.66])

        _, ax = plt.subplots()
        # cv2 loads images as BGR; matplotlib expects RGB
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        for poly in (box, box_1, box_2):
            ax.add_patch(patches.Polygon(poly, linewidth=1, edgecolor='r', facecolor='none'))
        ax.axis('off')
        plt.show()
        
