In [None]:
import math
from typing import cast, Optional

import cv2
import numpy as np
import rasterio

In [None]:
# Input image and the bit depth of its samples (max raw value = 2 ** IMAGE_BITS - 1).
FILENAME = "Test_Images/Test_Image_3.JPG"
IMAGE_BITS = 8

# 1-based rasterio band index of each spectral channel.
# math.inf marks a channel this image does not have: the loader's
# `band_count >= channel` test can never be true for it, so it is never read.
red_channel, green_channel, blue_channel = 1, 2, 3
nir_channel = re_channel = math.inf

## Setup

In [None]:
import PIL.Image

def img_scale(img: np.ndarray) -> np.ndarray:
    """Convert a float image meant to lie in [0, 1] to 8-bit (0-255) for display.

    Values are clipped to [0, 1] and NaN is mapped to 0 before the cast.
    Casting negative, >1 or NaN floats straight to uint8 is platform-dependent
    (effectively undefined) in NumPy and produces wrapped-around pixels. This
    affects, e.g., index images in [-1, 1].

    In-range inputs scale exactly as before: multiply by 255, then truncate.

    Args:
        img: float array of any shape.

    Returns:
        uint8 array of the same shape.
    """
    unit = np.clip(np.nan_to_num(img, nan=0.0), 0.0, 1.0)
    return (unit * 255).astype(np.uint8)

def display_image(img: np.ndarray, _: Optional[str] = None):
    """Render `img` inline in the notebook through PIL.

    Callers in this notebook pass uint8 greyscale or 3-channel arrays.
    NOTE(review): other dtypes depend on what PIL.Image.fromarray accepts.
    The second argument is accepted but ignored.
    """
    pil_image = PIL.Image.fromarray(img)
    display(pil_image)

def display_contours(img: np.ndarray, contours, *, colour_rgb: tuple[int, int, int] = (0, 255, 0)):
    """Show a copy of `img` with every contour outlined 2 px wide in `colour_rgb`.

    `img` itself is not modified; drawing happens on a private copy.
    """
    canvas = img.copy()
    for outline in contours:
        # One contour per call: wrap it in a one-element list and draw index 0.
        cv2.drawContours(canvas, [outline], 0, colour_rgb, 2)
    display_image(canvas)

## Load image

In [None]:
def _read_band_if_present(dataset, available: int, band_index, message: str) -> Optional[np.ndarray]:
    """Read 1-based band `band_index` from `dataset` and print `message`, or return None.

    A band counts as present when its index does not exceed `available`.
    Indices configured as math.inf therefore always yield None without a read.
    """
    if band_index > available:
        return None
    band = dataset.read(band_index)
    print(message)
    return band


# Every raw band starts absent; only bands found in the file are filled in below.
blue_raw: Optional[np.ndarray] = None
green_raw: Optional[np.ndarray] = None
red_raw: Optional[np.ndarray] = None
nir_raw: Optional[np.ndarray] = None
re_raw: Optional[np.ndarray] = None

# TODO should probably use cv2 to load the image
with rasterio.open(FILENAME, 'r') as raster_img:
    raster_img = cast(rasterio.DatasetReader, raster_img)
    band_count = cast(int, raster_img.count)

    # Read order: red, green, blue, nir, re.
    red_raw = _read_band_if_present(raster_img, band_count, red_channel, "red present")
    green_raw = _read_band_if_present(raster_img, band_count, green_channel, "green present")
    blue_raw = _read_band_if_present(raster_img, band_count, blue_channel, "blue present")
    nir_raw = _read_band_if_present(raster_img, band_count, nir_channel, "nir present")
    re_raw = _read_band_if_present(raster_img, band_count, re_channel, "re present")

In [None]:
# Convert raw integer samples to floats in [0, 1] (assuming samples fit in IMAGE_BITS).
image_max_value = 2 ** IMAGE_BITS - 1


def _to_unit_float(raw: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return `raw` as floats divided by image_max_value, or None if the band is absent."""
    return None if raw is None else raw.astype(float) / image_max_value


# Bind every channel name explicitly, to None when its band was not loaded.
# An annotation alone (e.g. `red: Optional[np.ndarray]`) creates no variable. A missing
# band would then make the RGB check that follows raise NameError instead of its
# ValueError. Re-running in a live kernel would also keep a stale channel from an
# earlier image.
red: Optional[np.ndarray] = _to_unit_float(red_raw)
green: Optional[np.ndarray] = _to_unit_float(green_raw)
blue: Optional[np.ndarray] = _to_unit_float(blue_raw)
nir: Optional[np.ndarray] = _to_unit_float(nir_raw)
re: Optional[np.ndarray] = _to_unit_float(re_raw)  # NOTE: shadows the stdlib `re` module name

In [None]:
# Everything downstream needs the three visible channels.
if not (red is not None and green is not None and blue is not None):
    raise ValueError("not all rgb channels available")

# Self-assignment after the None check, kept so static type checkers treat the
# channels as plain arrays (not Optional) in the cells that follow.
red, green, blue = red, green, blue

# The index computations below divide by values that can be zero. Let NumPy produce
# inf/NaN silently instead of emitting divide/invalid warnings.
np.seterr(invalid='ignore', divide='ignore')

In [None]:
# OpenCV works in BGR order, so stack the channels that way.
# The uint8 BGR `img` feeds the masking steps; `img_rgb` is for PIL display.
img_bgr_float = cv2.merge([blue, green, red])
img = img_scale(img_bgr_float)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
display_image(img_rgb)

## Generate mask

In [None]:
# Normalised Green-Red Difference Index: (G - R) / (G + R).
# For non-negative channels this lies in [-1, 1]; it is NaN where G + R == 0.
NGRDI = (green - red) / (green + red)

# Shift [-1, 1] onto [0, 1] (NaN -> 0) before the uint8 conversion. Scaling the raw
# index would send every negative value through an out-of-range float->uint8 cast.
# That cast wraps around and makes the non-vegetation half of the image unreadable.
display_image(img_scale(np.nan_to_num((NGRDI + 1) / 2)))

In [None]:
# Hue index: arctan(2 (R - G - B) / (30.5 (G - B))), in radians within [-pi/2, pi/2].
# A zero denominator gives +/-inf -> +/-pi/2; 0/0 gives NaN.
HUE = np.arctan((2 * (red - green - blue)) / (30.5 * (green - blue)))

# Map [-pi/2, pi/2] onto [0, 1] (NaN -> 0) before the uint8 conversion. The raw angle
# is negative for half its range and exceeds 1 near the top, so it would wrap in the cast.
display_image(img_scale(np.nan_to_num((HUE + np.pi / 2) / np.pi)))

In [None]:
# Binary vegetation mask: 255 where NGRDI > 0 (green exceeds red), 0 elsewhere.
# cv2.threshold keeps the float dtype, so cast to uint8 for use as an OpenCV mask.
NGRDI_mask = cv2.threshold(NGRDI, 0, 255, cv2.THRESH_BINARY)[1]
NGRDI_mask = NGRDI_mask.astype(np.uint8)
display_image(NGRDI_mask)

## Process mask

In [None]:
# Keep the BGR pixels where the vegetation mask is set. Pixels outside the mask are
# zero in the freshly allocated output, the standard bitwise_and masking idiom.
img_masked = cv2.bitwise_and(img, img, mask=NGRDI_mask)
display_image(cv2.cvtColor(img_masked, cv2.COLOR_BGR2RGB))

In [None]:
# Collapse the masked BGR image to a single grey channel for thresholding.
img_masked_grey = cv2.cvtColor(img_masked, cv2.COLOR_BGR2GRAY)
display_image(img_masked_grey)

In [None]:
# Grey levels above 10 become foreground (255); darker pixels, incl. the masked-out
# background, become 0.
threshold_output = cv2.threshold(img_masked_grey, 10, 255, cv2.THRESH_BINARY)
_, img_masked_thresholded = threshold_output
display_image(img_masked_thresholded)

In [None]:
# Two passes of a 4x4 erosion strip away thin specks and noise from the binary mask.
erode_kernel = np.ones((4, 4), dtype=np.uint8)
img_masked_eroded = cv2.erode(img_masked_thresholded, erode_kernel, iterations=2)
display_image(img_masked_eroded)

In [None]:
# Eight passes of a 3x3 dilation grow the surviving regions so nearby fragments merge.
dilate_kernel = np.ones((3, 3), dtype=np.uint8)
img_masked_dilated = cv2.dilate(img_masked_eroded, dilate_kernel, iterations=8)
display_image(img_masked_dilated)

## Count plants

In [None]:
# Three-channel copy of the grey image so contours can be drawn on it in colour.
img_masked_grey_col = cv2.cvtColor(img_masked_grey, cv2.COLOR_GRAY2RGB)

In [None]:
# Outer boundaries only. RETR_TREE also returns the contour of every hole inside a
# blob, and each such hole was counted as an extra plant. RETR_EXTERNAL yields one
# contour per connected region.
contours_initial, _ = cv2.findContours(img_masked_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
print(len(contours_initial))
display_contours(img_masked_grey_col, contours_initial)

In [None]:
# Drop tiny blobs: keep only closed contours whose perimeter is at least 40 px.
# An area-based filter (cv2.contourArea(contour) < 600 -> discard) is a possible
# alternative that is currently not applied.
contours_filtered = [
    contour
    for contour in contours_initial
    if cv2.arcLength(contour, True) >= 40
]

print(len(contours_filtered))
display_contours(img_masked_grey_col, contours_filtered)
display_contours(img_rgb, contours_filtered, colour_rgb=(0, 0, 255))