In [1]:
from segment_anything import sam_model_registry, SamPredictor
from ultralytics import YOLO, settings
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import torchvision
import requests
import shutil
import torch
import json
import cv2
import sys
import os

torch.cuda.empty_cache()
CUDA = torch.cuda.is_available()
print("CUDA is available:", CUDA)

CUDA is available: True


In [2]:
MODEL_PATH_SEGMENT = '../models/plant_segmentation_v1.pt'
IMG_PATH_FIXED = '../images/cropped_scales/fixed'
IMG_PATH_RANDOM = '../images/cropped_scales/random'
DATA_PATH = '../data/processed'
DEVICE = "cuda" if CUDA else "cpu"
settings.update({'runs_dir': rf'/home/floris/Projects/NTNU/models/runs'})

model_seg = YOLO(MODEL_PATH_SEGMENT)

In [3]:
def measure_scale_fixed_via_colorboard(image_path):
    """
    Function to measure the scale of the image using a fixed colorboard
    The colorboard is a set of 7 boxes with bright colors
    The width of the colorboard is 4.5 cm
    The function returns the pixels per cm
    """
    # Load the image
    image = cv2.imread(image_path)

    # Convert to HSV (Hue, Saturation, Value) color space for easier color segmentation
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Define range of bright colors in HSV
    lower_color = np.array([0, 100, 100])
    upper_color = np.array([179, 255, 255])

    # Threshold the HSV image to get only bright colors
    mask = cv2.inRange(hsv, lower_color, upper_color)

    # Find contours in the mask
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # Filter out small contours that are not our boxes
    box_contours = [cnt for cnt in contours if cv2.contourArea(cnt) > 100]

    # Calculate bounding boxes for each contour
    bounding_boxes = [cv2.boundingRect(cnt) for cnt in box_contours]

    # Determine the midpoint of the image width
    midpoint = image.shape[1] / 2

    # Keep only the boxes that have an x-coordinate greater than the midpoint
    right_half_boxes = [box for box in bounding_boxes if box[0] > midpoint]

    # Sort these boxes by their x-coordinate to ensure rightmost first
    sorted_right_half_boxes = sorted(right_half_boxes, key=lambda x: x[0], reverse=True)

    # Draw these seven boxes on the image
    for (x, y, w, h) in sorted_right_half_boxes:
        cv2.rectangle(image, (x, y), (x+w, y+h), (0, 255, 0), 2)

    # Calculate the total width in pixels of these seven boxes
    total_width_in_pixels = sum([box[2] for box in sorted_right_half_boxes])

    # Since 7 boxes = 4.5 cm, calculate the pixels per cm
    pixels_per_cm = total_width_in_pixels / 5.79

    return pixels_per_cm

def transform_px_to_cm(box, px_per_cm):
    """
    Function to transform the width and height of a box from pixels to cm
    """
    w = np.abs((box[2] - box[0]).cpu())
    h = np.abs((box[3] - box[1]).cpu())
    return w / px_per_cm, h / px_per_cm

def get_masked_image(image, mask):
    """
    Apply a mask to an image with transparency
    """
    # Remove single-dimensional entry from the shape of the mask
    mask_squeezed = np.squeeze(mask)  # This should change mask shape to (5831, 3391)
    # Generate an alpha channel where mask is True (255) and False (0)
    alpha_channel = np.where(mask_squeezed, 255, 0).astype(np.uint8)
    # Ensure alpha channel is correctly shaped [H, W] -> [H, W, 1]
    alpha_channel_shaped = np.expand_dims(alpha_channel, axis=-1)

    # print("Image shape:", image.size)
    # print("Alpha channel shape:", alpha_channel_shaped.shape)

    # Concatenate the alpha channel with the image to create an RGBA image
    rgba_image = np.concatenate((image, alpha_channel_shaped), axis=-1)
    return rgba_image

def get_cropped_image(image, box):
    """
    Crop an image with a given box
    """
    if isinstance(box, list):
        box = box[0]
    if isinstance(box, torch.Tensor):
        box = box.cpu().numpy() 
    else:
        box = np.array(box) 

    if len(box.shape) > 1:
        box = box[0]
    x, y, w, h = int(box[0]), int(box[1]), int(box[2] - box[0]), int(box[3] - box[1])
    return image[y:y+h, x:x+w]

def apply_crop_mask(image, mask, box):
    """
    Apply a mask to an image and crop the image with a given box
    Returns a list of tuples with the masked image and the cropped image
    """
    images = []
    if len(np.array(mask).shape) == 3:
        for i, m in enumerate(mask):
            m_img = get_masked_image(image, m)
            b = box[i].cpu().numpy() if type(box) == torch.Tensor else box[i]
            crop_img = get_cropped_image(m_img, b)
            images.append((m_img, crop_img))
    else:
        m_img = get_masked_image(image, mask)
        b = box.cpu().numpy() if type(box) == torch.Tensor else box
        crop_img = get_cropped_image(m_img, b)
        images.append((m_img, crop_img))
    return images

def find_dominant_color(image, k=3):
    # Convert image to numpy array
    img_array = np.array(image)
    # Reshape it to a list of RGB values
    img_vector = img_array.reshape((-1, 3))
    # Run k-means on the pixel colors (fit only on a subsample to speed up)
    kmeans = KMeans(n_clusters=k, random_state=0).fit(img_vector[::50])
    # Get the dominant color
    dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
    # Create a mask for pixels within a certain distance from the dominant color
    distances = np.sqrt(np.sum((img_vector - dominant_color) ** 2, axis=1))
    mask = distances < np.std(distances)
    # Turn the dominant color range to white
    img_vector[mask] = [255, 255, 255]
    result_img_array = img_vector.reshape(img_array.shape)
    # turn image back to PIL
    result_img = Image.fromarray(result_img_array.astype(np.uint8))
    return dominant_color, result_img

def calculate_mask_area(masked_pixels, pixels_per_cm):
    area_square_cm = masked_pixels / (pixels_per_cm ** 2)
    return area_square_cm

def get_images(path, range_left=0, range_right=-1):
    if not os.path.exists(path):
        print(f"Path {path} does not exist")
        return []
    if len(os.listdir(path)) == 0:
        print(f"Path {path} is empty")
        return []
    
    images = [os.path.join(path, f) for f in os.listdir(path) if f.endswith('.jpg') and 'only' not in f and 'grid' not in f]
    return images[range_left:range_right]

def calc_non_transparent_pixel_count(image):
    image = Image.open(image)
    img_arr = np.array(image)
    ntp = np.where(img_arr[:, :, 3] != 0)
    return len(ntp[0])

def generate_output(images, model):

    results = model(images, verbose=False, retina_masks=True, conf=0.5)
    output = []

    for result in results:

        image_path = result.path
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        boxes = result.boxes.xyxy
        masks = result.masks.data[0].cpu().numpy()

        output.append({'image_path': image_path, 'image': image, 'boxes': boxes, 'masks': masks})
    
    return output

def main(output):
    """
    Function to display the output of the model
    It displays the image with the boxes and masks, and the width and height of the boxes in cm
    """
    results = []

    for _, res in enumerate(output):

        path = res['image_path']
        image = res['image']
        mask = res['masks']
        boxes = res['boxes']
        name = os.path.basename(path)

        image = find_dominant_color(image)[1]
                
        # look for cropped scale
        scale_path = path.replace('.jpg', '_scale_only.jpg')
        px_per_cm = measure_scale_fixed_via_colorboard(scale_path)

        all_masks_with_sq_cm = []

        if len(mask.shape) == 3:
            for m in mask:
                m_sum = m[0].sum().tolist()
                square_cm = calculate_mask_area(m_sum, px_per_cm)
                all_masks_with_sq_cm.append((m, square_cm))
        else:
            m_sum = mask.sum().tolist()
            square_cm = calculate_mask_area(m_sum, px_per_cm)
            all_masks_with_sq_cm.append((mask, square_cm))
            

        all_boxes = []
        for box in boxes:
            w_cm, h_cm = transform_px_to_cm(box, px_per_cm)
            all_boxes.append((box, {'width_cm': w_cm, 'height_cm': h_cm}))
 
        
        masked_and_cropped_images = apply_crop_mask(image, mask, boxes) # list of tuples (masked_image, cropped_image) per box / mask

        results.append({'image': image, 'image_path': path, 'image_name': name, 'boxes': all_boxes, 'masks_and_sqcm': all_masks_with_sq_cm, 'px_per_cm': px_per_cm, 'masked_and_cropped_images': masked_and_cropped_images})
    
    return results

def save_results(results, path):

    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

    for res in results:

        image = res['image']
        mask = [r[0] for r in res['masks_and_sqcm']]
        box = [box[0] for box in res['boxes']]

        imgs = apply_crop_mask(image, mask, box)
        for idx, i in enumerate(imgs):
            img_name = res['image_name'].replace('.jpg', f'_plant_mask_crop_{chr(idx + 97)}.png')
            img_path = os.path.join(path, img_name)

            # to turn the white pixels to transparent
            threshold = 250
            pil_img = Image.fromarray(i[1])
            datas = pil_img.getdata()
            new_image_data = []
            for item in datas:
                if item[0] > threshold and item[1] > threshold and item[2] > threshold:
                    new_image_data.append((255, 255, 255, 0))
                else:
                    new_image_data.append(item)

            pil_img.putdata(new_image_data)
            pil_img.save(img_path)

In [None]:
images = get_images(IMG_PATH_FIXED)

output = generate_output(images[0:20], model_seg)

results = main(output)

save_results(results, DATA_PATH)

In [16]:
number = 4
sq_cm_bef = results[number]['masks_and_sqcm'][0][1]
print(f'Before: {sq_cm_bef:.2f} cm²')
img_name = os.path.basename(results[number]['image_path']).replace('.jpg', '_plant_mask_crop_a.png')
img_path = os.path.join(DATA_PATH, img_name)
ntpx = calc_non_transparent_pixel_count(img_path)
area = calculate_mask_area(ntpx, results[number]['px_per_cm'])
print(f'After removing white space: {area:.2f} cm²')

Before: 31.71 cm²
After removing white space: 19.50 cm²
