# Mardi Inference Server Logic

## Diagram
![alt](../images/diagram.png)

## Introduction
This notebook presents a methodology for processing, classifying, and segmenting images using computer vision and deep learning techniques. The project combines *You Only Look Once (YOLO)* and *Segment Anything Model (SAM)* into a single workflow for detecting and segmenting plants in images.

An important aspect of this project was the addition of noise to images to test the model's robustness. Noise refers to random disturbances in a signal; in an image, these appear as random variations in brightness and color. Specifically, we introduced *salt-and-pepper noise*, a type of impulse noise that sets randomly chosen pixels to the minimum or maximum intensity. In this project it was applied to grayscale versions of the images.

Finally, we utilized a *Random Forest* classification model to predict the week of planting based on features extracted from the segmented images. *Random Forest* is an ensemble learning method known for its high accuracy and ability to handle many input features.

The methodology employed in this project demonstrated a comprehensive solution for image processing, classification, segmentation, and annotation using state-of-the-art deep learning and computer vision techniques. By integrating *You Only Look Once (YOLO)* for detection, *Segment Anything Model (SAM)* for segmentation, and *Random Forest* for prediction, the workflow ensured high precision and efficiency.

## Preprocessing
The first step involved preparing the images for analysis. All images were resized to a standard
dimension of 224x224 pixels. This standardization ensured that every image passed to the models had
the same dimensions, matching the input size the classification model was configured for.

In [None]:
import os
import cv2
import matplotlib.pyplot as plt

# Read and process the image
image_path = '../images/week_2.jpg'
original_image = cv2.imread(image_path)
resized_image = cv2.resize(original_image, (224, 224))

# Convert BGR to RGB
resized_image_rgb = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)

# Display the image using matplotlib
plt.imshow(resized_image_rgb)
plt.axis('off')  # Hide the axis
plt.title(os.path.basename(image_path))
plt.show()

## Age Group Classification using YOLO
Image classification involved assigning an entire image to one of a set of predefined classes. The output of an image classifier was a single class label accompanied by a confidence score. This task was fundamental when the objective was to determine the overall category to which an image belonged, without needing to specify the locations or shapes of individual objects within the image. The *You Only Look Once (YOLO)* model was utilized for this purpose due to its efficiency and accuracy in real-time applications. The images were classified into two classes based on the following criteria:
• *Class 1*: images of crops at 1, 2, and 3 weeks of age.
• *Class 2*: images of crops at 4 and 5 weeks of age.
The classifier was a *YOLOv8* classification model, and its training was configured with several parameters to optimize performance.
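
To make the "single class label accompanied by a confidence score" concrete, here is a minimal sketch of turning one image's raw class scores into a label and a confidence. It assumes the model returns unnormalized scores (logits); if the exported model already outputs probabilities, the softmax step is unnecessary. The scores and the `names` mapping are hypothetical values for illustration.

```python
import numpy as np

def top_class_with_confidence(logits, names):
    """Return (label, confidence) for one image's raw class scores."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()                   # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()   # softmax over the classes
    idx = int(probs.argmax())                         # index of the most likely class
    return names[idx], float(probs[idx])

# Hypothetical scores and label names, for illustration only
label, conf = top_class_with_confidence([2.0, 0.5], {0: "Class1", 1: "Class2"})
print(label, round(conf, 3))
```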

In [None]:
import torch
import numpy as np
from models.common import DetectMultiBackend
from utils.general import check_img_size
from utils.torch_utils import select_device

device = select_device('0')

age_group_weights = '/data/models/model_classification.onnx'
age_group_model = DetectMultiBackend(age_group_weights, device=device, dnn=False, data='data/coco128.yaml', fp16=False)
age_group_stride = age_group_model.stride
age_group_imgsz = check_img_size((224, 224), s=age_group_stride)  # check image size
age_group_model.warmup(imgsz=(1, 3, *age_group_imgsz))

def preprocess_image_yolo_classification(image: np.ndarray) -> np.ndarray:
    im = image.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    im = torch.from_numpy(im).to(age_group_model.device)
    im = im.half() if age_group_model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    return im

im = preprocess_image_yolo_classification(resized_image)
results = age_group_model(im)
top_class = results.argmax(dim=1).item()
age_group = age_group_model.names[top_class]

print(f'Predicted age group: {age_group}')

## Object Detection using YOLO
Object detection was performed using the *YOLOv9* model, specifically targeting the detection of plants within the images classified as *Class 1*. This process was automated, utilizing the pre-trained model *"gelan-c.pt"*. The use of this model allowed for high-resolution real-time detection without the need for additional training or manual annotations.

The pre-trained model was trained on the *COCO (Common Objects in Context)* dataset, so its predictions could be used to annotate the images automatically with COCO class labels. The COCO dataset is a large-scale object detection, segmentation, and captioning dataset that is widely used for benchmarking computer vision models. It includes a diverse set of object categories and provides high-quality annotations for research in object detection tasks.
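
The detection cell that follows filters raw predictions with a confidence threshold and non-maximum suppression (`conf_thres`, `iou_thres`). As a rough illustration of what those two thresholds do, here is a simplified, class-agnostic greedy NMS in NumPy. It is a sketch under stated assumptions, not the `non_max_suppression` implementation from the YOLO repository; the demo boxes and scores are made up.

```python
import numpy as np

def iou(box, boxes):
    """IoU between one box and an array of boxes, all as (xmin, ymin, xmax, ymax)."""
    x0 = np.maximum(box[0], boxes[:, 0])
    y0 = np.maximum(box[1], boxes[:, 1])
    x1 = np.minimum(box[2], boxes[:, 2])
    y1 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x1 - x0, 0, None) * np.clip(y1 - y0, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def greedy_nms(boxes, scores, conf_thres=0.1, iou_thres=0.45):
    """Return indices of kept boxes, highest score first."""
    boxes = np.asarray(boxes, dtype=np.float64)
    scores = np.asarray(scores, dtype=np.float64)
    # Drop low-confidence boxes, then visit the rest from highest to lowest score
    order = [int(i) for i in np.argsort(-scores) if scores[i] >= conf_thres]
    kept = []
    while order:
        best = order.pop(0)
        kept.append(best)
        if order:
            overlaps = iou(boxes[best], boxes[order])
            # Suppress boxes that overlap the kept box by more than iou_thres
            order = [i for i, o in zip(order, overlaps) if o <= iou_thres]
    return kept

# Made-up detections: two near-duplicates, one separate box, one low-confidence box
demo_boxes = [[10, 10, 110, 110], [12, 12, 112, 112], [150, 150, 200, 200], [0, 0, 5, 5]]
demo_scores = [0.9, 0.8, 0.6, 0.05]
kept = greedy_nms(demo_boxes, demo_scores)
print(kept)
```

Lowering `iou_thres` suppresses more overlapping boxes, while lowering `conf_thres` lets weaker detections through to the suppression step.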

In [None]:
import yaml
from utils.augmentations import letterbox
from utils.general import non_max_suppression, scale_boxes, xyxy2xywh

# COCO class indices to keep (in the standard 80-class COCO ordering: 25 = umbrella, 58 = potted plant)
desired_classes = [25, 58]

# Random color map for each class
color_map = {}
for class_id in desired_classes:
    color_map[class_id] = tuple(np.random.randint(0, 256, 3).tolist())

conf_thres = 0.1
iou_thres = 0.45
classes = None
agnostic_nms = False
max_det = 1000

object_detection_weights = '/data/yolov9/weights/gelan-c.pt'
object_detection_model = DetectMultiBackend(object_detection_weights, device=device, dnn=False, data='data/coco.yaml', fp16=False)
object_detection_stride, object_detection_pt = object_detection_model.stride, object_detection_model.pt
object_detection_imgsz = check_img_size((640, 640), s=object_detection_stride)  # check image size
object_detection_model.warmup(imgsz=(1, 3, *object_detection_imgsz))  # batch size 1

def preprocess_image_object_detection(image: np.ndarray) -> torch.Tensor:
    # Use the stride and model defined above for the padded resize and device placement
    im = letterbox(image, 640, stride=object_detection_stride, auto=True)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    im = torch.from_numpy(im).to(object_detection_model.device)
    im = im.half() if object_detection_model.fp16 else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    return im

def show_box(image, box, label, conf_score, color):
    x0, y0 = int(box[0]), int(box[1])
    x1, y1 = int(box[2]), int(box[3])
    cv2.rectangle(image, (x0, y0), (x1, y1), color, 2)
    label_text = f'{label} {conf_score:.2f}'
    cv2.putText(image, label_text, (x0, y0 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

# Load class names from YAML file
with open('data/coco.yaml', 'r') as file:
    coco_data = yaml.safe_load(file)
    class_names = coco_data['names']

class_ids = []
bboxes = []
conf_scores = []

# The image size and mask accumulator are used by later cells, so define them for both classes
image_height, image_width, _ = resized_image.shape
gn = torch.tensor(resized_image.shape)[[1, 0, 1, 0]]  # normalization gain whwh
aggregate_mask = np.zeros(resized_image.shape[:2], dtype=np.uint8)

if age_group == "Class1":
    im = preprocess_image_object_detection(resized_image)
    pred = object_detection_model(im, augment=False, visualize=False)
    pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

    # Process predictions
    for i, det in enumerate(pred):
        if len(det):
            # Rescale boxes from img_size to image size
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], resized_image.shape).round()

            # Write results
            for *xyxy, conf, cls in reversed(det):
                if int(cls) not in desired_classes:
                    continue

                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                cx, cy, w, h = xywh

                # Convert from normalized [0, 1] to image scale
                cx *= image_width
                cy *= image_height
                w *= image_width
                h *= image_height

                # Convert center x, y, width and height to xmin, ymin, xmax, ymax
                xmin = cx - w / 2
                ymin = cy - h / 2
                xmax = cx + w / 2
                ymax = cy + h / 2

                # Store plain Python values so they can be used as dictionary keys and formatted later
                class_ids.append(int(cls))
                bboxes.append((xmin, ymin, xmax, ymax))
                conf_scores.append(float(conf))

# Create a copy of the image to draw bounding boxes
image_with_bboxes = resized_image.copy()

# Draw bounding boxes
for class_id, bbox, conf_score in zip(class_ids, bboxes, conf_scores):
    class_name = class_names[class_id]
    color = color_map[class_id]
    show_box(image_with_bboxes, bbox, class_name, conf_score, color)

# Convert BGR to RGB for displaying with matplotlib
image_with_bboxes_rgb = cv2.cvtColor(image_with_bboxes, cv2.COLOR_BGR2RGB)

# Display the image using matplotlib
plt.imshow(image_with_bboxes_rgb)
plt.axis('off')  # Hide the axis
plt.title('Image with Bounding Boxes')
plt.show()

## Object Segmentation using SAM
First, *SAM* was used to extract masks from the detections made by *YOLO* in *Class 1*. This stage was crucial for isolating the objects detected by *YOLO* and preparing them for more detailed segmentation. We called it *SAM1* for this process.
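
Once a mask is available, isolating the object amounts to keeping the pixels inside the mask and replacing everything else with a plain background. The sketch below shows that step on a tiny synthetic image; `composite_on_white` is a hypothetical helper written for illustration, and the mask here is hand-made rather than produced by SAM.

```python
import numpy as np

def composite_on_white(image, mask):
    """Keep pixels where mask is True; paint everything else white."""
    mask3 = np.asarray(mask, dtype=bool)[..., np.newaxis]  # (H, W) -> (H, W, 1) so it broadcasts over channels
    white = np.full_like(image, 255)
    return np.where(mask3, image, white)

# Tiny synthetic example: a 4x4 gray image with a 2x2 mask in the centre
image = np.full((4, 4, 3), 100, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True
segmented = composite_on_white(image, mask)
print(segmented[:, :, 0])
```

Using `np.where` keeps the `uint8` dtype of the input, so no cast is needed before displaying or saving the result.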

In [None]:

from segment_anything import sam_model_registry, SamPredictor

model_type = "vit_h"
sam_checkpoint = "/data/models/sam_vit_h_4b8939.pth"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam_predictor_model = SamPredictor(sam)

# Find the most central bounding box if there are multiple detections
def find_central_bbox(bboxes):
    center_x, center_y = image_width / 2, image_height / 2
    min_distance = float('inf')
    central_bbox = None
    central_class_id = None
    central_conf_score = None

    for class_id, bbox, conf_score in zip(class_ids, bboxes, conf_scores):
        if class_id in desired_classes:
            bbox_center_x = (bbox[0] + bbox[2]) / 2
            bbox_center_y = (bbox[1] + bbox[3]) / 2
            distance = np.sqrt((center_x - bbox_center_x) ** 2 + (center_y - bbox_center_y) ** 2)
            if distance < min_distance:
                min_distance = distance
                central_bbox = bbox
                central_class_id = class_id
                central_conf_score = conf_score

    return central_class_id, central_bbox, central_conf_score


# Get the most central bounding box
central_class_id, central_bbox, central_conf_score = find_central_bbox(bboxes)

if central_class_id is not None and central_bbox is not None:
    # Look up the label and color for the central detection
    class_name = class_names[central_class_id]
    color = color_map[central_class_id]

    image_with_central_bbox = resized_image.copy()
    show_box(image_with_central_bbox, central_bbox, class_name, central_conf_score, color)
    image_with_central_bbox_rgb = cv2.cvtColor(image_with_central_bbox, cv2.COLOR_BGR2RGB)

    # Display the image using matplotlib
    plt.imshow(image_with_central_bbox_rgb)
    plt.axis('off')  # Hide the axis
    plt.title('Image with Central Bounding Box')
    plt.show()

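    # Generate mask for the central bounding box.
    # SamPredictor needs the image embedding before predict(); set_image expects RGB by default.
    sam_predictor_model.set_image(cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB))
    input_box = np.array(central_bbox).reshape(1, 4)
    masks, _, _ = sam_predictor_model.predict(
        point_coords=None,
        point_labels=None,
        box=input_box,
        multimask_output=False,
    )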
    aggregate_mask = np.where(masks[0] > 0.5, 1, aggregate_mask)

    # Convert the aggregated segmentation mask to a binary mask
    binary_mask = np.where(aggregate_mask == 1, 1, 0)

    # Create a white background with the same size as the image
    white_background = np.ones_like(resized_image) * 255

    # Applying the binary mask to the original image
    # Where the binary mask is 0 (background), use white background; otherwise, use the original image.
    segmented_image = white_background * (1 - binary_mask[..., np.newaxis]) + resized_image * binary_mask[..., np.newaxis]

    # Replace resized_image with segmented_image for the next steps
    resized_image = segmented_image.astype(np.uint8)

    # Show the new image
    segmented_image_rgb = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
    plt.imshow(segmented_image_rgb)
    plt.axis('off')  # Hide the axis
    plt.title('Segmented Image')
    plt.show()

## Adding Noise and Processing
*Salt-and-pepper noise* was added to the images to test the model's robustness. Noise refers to random disturbances in a signal, and in our case the signal was an image. *Salt-and-pepper noise* is a type of impulse noise that replaces randomly chosen pixels with white (salt) or black (pepper) values; here it was applied to a grayscale version of the image. Adding this noise simulated real-world conditions where images are not always perfect, so that the pipeline could be checked for robustness under degraded input.
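
The `add_noise` helper used in the next cell comes from `utils.mardi` and its implementation is not shown here. As an illustration only, salt-and-pepper noise on a grayscale image could be sketched as follows; the function name, the `amount` and `salt_ratio` parameters, and their default values are assumptions, not the project's actual settings.

```python
import numpy as np

def add_salt_and_pepper(gray, amount=0.05, salt_ratio=0.5, seed=None):
    """Return a copy of a 2-D uint8 image with roughly `amount` of its pixels set to 0 or 255."""
    rng = np.random.default_rng(seed)
    noisy = gray.copy()
    draw = rng.random(gray.shape)                                  # one uniform draw per pixel
    noisy[draw < amount * salt_ratio] = 255                        # salt: white pixels
    noisy[(draw >= amount * salt_ratio) & (draw < amount)] = 0     # pepper: black pixels
    return noisy

# Synthetic mid-gray image, for illustration only
gray = np.full((64, 64), 128, dtype=np.uint8)
noisy = add_salt_and_pepper(gray, amount=0.1, seed=0)
changed = float((noisy != gray).mean())
print(f"fraction of pixels changed: {changed:.3f}")
```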

In [None]:
from utils.mardi import add_noise

gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
noisy_img = add_noise(gray_image)

image_bgr = cv2.cvtColor(noisy_img, cv2.COLOR_GRAY2BGR)
image_rgb = cv2.cvtColor(noisy_img, cv2.COLOR_GRAY2RGB)

# Display the image using matplotlib
plt.imshow(image_rgb)
plt.axis('off')  # Hide the axis
plt.title('Noisy Image')
plt.show()

## Feature Extraction and Classification
After extracting the initial masks with *SAM1*, a second *SAM* pass was applied to segment objects in both classes (*Class 1* and *Class 2*) in more detail. This second stage ran automatically, without box prompts. We called it *SAM2* for this process.

Feature extraction involved calculating specific characteristics from the segmented images, which were then used as input features for the Random Forest model. The primary features extracted were the number of segments (count) and the area of each segment. For each image, the following statistical measures of the segment areas were calculated:
• *Standard Deviation of Area (area_std)*: This measure provided insights into the variability of segment sizes within an image.
• *Mean of Area (area_mean)*: This measure indicated the average size of the segments.
• *Median of Area (area_median)*: This measure highlighted the central tendency of the segment sizes, less affected by outliers.
These features were essential for capturing the distribution and variability of plant segment sizes within the images.
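
The feature vector described above can be sketched directly from a list of segment areas. This is a minimal illustration, not the notebook's exact code: `area_features` is a hypothetical helper, the zero-filled fallbacks for zero or one segment are assumptions, and the example areas are made up. The sample standard deviation (`ddof=1`) is used because that is the pandas default for `std`.

```python
import numpy as np

def area_features(areas):
    """Return [count, area_std, area_mean, area_median] as a (1, 4) float32 array."""
    areas = np.asarray(areas, dtype=np.float64)
    count = areas.size
    if count == 0:
        # Assumption: report all zeros when nothing was segmented
        return np.zeros((1, 4), dtype=np.float32)
    # Sample standard deviation is undefined for a single value; fall back to 0.0 (assumption)
    std = areas.std(ddof=1) if count > 1 else 0.0
    return np.array([[count, std, areas.mean(), np.median(areas)]], dtype=np.float32)

# Made-up segment areas in pixels
features = area_features([120, 80, 100, 300])
print(features)
```

The column order matches the input expected by the classification step: count, standard deviation, mean, median.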

In [None]:
import pandas as pd
import onnxruntime as rt
from segment_anything import SamAutomaticMaskGenerator
from utils.mardi import filter_out_background, generate_colors, is_white_background

# Classification Group 1 model
classification1_model_path = "/data/models/rf_model_class1.onnx"
classification1_model = rt.InferenceSession(classification1_model_path, providers=["CPUExecutionProvider"])
classification1_input_name = classification1_model.get_inputs()[0].name
classification1_label_name = classification1_model.get_outputs()[0].name

# Classification Group 2 model
classification2_model_path = "/data/models/rf_model_class2.onnx"
classification2_model = rt.InferenceSession(classification2_model_path, providers=["CPUExecutionProvider"])
classification2_input_name = classification2_model.get_inputs()[0].name
classification2_label_name = classification2_model.get_outputs()[0].name

mask_generator = SamAutomaticMaskGenerator(sam)

sam_result = mask_generator.generate(image_rgb)

# Filter out the background
filtered_masks = filter_out_background(sam_result)

annotated_image = image_bgr.copy()
colors = generate_colors(len(filtered_masks))

segment_count = 0
mask_data = []
for i, mask in enumerate(filtered_masks):
    area = mask['area']
    segmentation = mask['segmentation'].astype('uint8')

    # Skip white background areas
    if is_white_background(segmentation, image_rgb):
        continue

    # Color the segmented area
    color = colors[i]
    r, g, b = color
    annotated_image[segmentation > 0] = cv2.addWeighted(annotated_image, 0.5, np.full_like(annotated_image, color), 0.5, 0)[segmentation > 0]

    # Append data
    segment_count += 1
    mask_data.append([area, r, g, b])

columns = ['area', 'r', 'g', 'b']
data = pd.DataFrame(mask_data, columns=columns)

# Create derived features, handling potential division by zero
data['r/g'] = (data['r'] / (data['g'] + 1e-8)).round(4)  # Add a small constant to avoid division by zero and round to 4 decimals
data['r/b'] = (data['r'] / (data['b'] + 1e-8)).round(4)
data['g/b'] = (data['g'] / (data['b'] + 1e-8)).round(4)

# Calculate aggregated statistics for 'area', rounded to 4 decimal places
# Note: pandas 'std' is the sample standard deviation (ddof=1), so it is NaN when only one segment is found
area_mean = round(data['area'].mean(), 4)
area_median = round(data['area'].median(), 4)
area_std = round(data['area'].std(), 4)

classification_input = np.array([[segment_count, area_std, area_mean, area_median]]).astype(np.float32)

if age_group == "Class1":
    age = classification1_model.run([classification1_label_name], {classification1_input_name: classification_input})[0]
else:
    age = classification2_model.run([classification2_label_name], {classification2_input_name: classification_input})[0]

if len(age) > 0:
    age = age[0]
else:
    age = None

print(f"Segment count: {segment_count}")
print(f"Age: {age}")

# Display the annotated image
annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
plt.imshow(annotated_image_rgb)
plt.axis('off')  # Hide the axis
plt.title('Annotated Image')
plt.show()