### Exploring data augmentation and classic image segmentation methods
### Ming Ong

#### Import necessary libraries

In [None]:
import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import random
import re
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

#### Set image paths and split the data into train/test sets (assumes the USA_segmentation folder is in the parent directory)

In [None]:
# Each sample shares one base filename across three sibling folders and is
# distinguished only by its prefix: RGB_<name>, NRG_<name>, mask_<name>.
folder_dir = 'USA_segmentation'
masks_dir = 'masks'
nrg_dir = 'NRG_images'
rgb_dir = 'RGB_images'
pwd = Path('../')
rgb_images_dir = pwd / folder_dir / rgb_dir
nrg_images_dir = pwd / folder_dir / nrg_dir
mask_images_dir = pwd / folder_dir / masks_dir

# Recover base names from the mask files ('mask_<name>' -> '<name>').
# - Path.name is used instead of regexing str(path), which broke on Windows
#   (backslash separators, plus the underscore in 'USA_segmentation').
# - Only 'mask_*' files are considered, since later cells rebuild paths as
#   'mask_' + name; anything else would produce nonexistent paths.
# - The result is sorted because Path.iterdir()/glob() order is
#   filesystem-dependent; without it random_state=322 would not reproduce the
#   same train/test split on every machine.
filenames = sorted(
    f.name.removeprefix('mask_')
    for f in mask_images_dir.glob('mask_*')
    if f.is_file()
)
if not filenames:
    raise FileNotFoundError(f'No mask_* files found in {mask_images_dir}')

train_files, test_files = train_test_split(filenames, test_size=0.2, random_state=322)
print(len(train_files))
print(len(test_files))

#### Select a random image and plot its RGB/NRG/mask side-by-side

In [None]:
# Pick one sample at random; uncomment the seed below to make the pick repeatable.
# random.seed(42)
base_filename = random.sample(filenames, 1)[0]

rgb_filename = str(rgb_images_dir / ('RGB_' + base_filename))
nrg_filename = str(nrg_images_dir / ('NRG_' + base_filename))
mask_filename = str(mask_images_dir / ('mask_' + base_filename))

# OpenCV decodes images in BGR channel order; matplotlib expects RGB.
rgb_image = cv2.cvtColor(cv2.imread(rgb_filename), cv2.COLOR_BGR2RGB)
nrg_image = cv2.cvtColor(cv2.imread(nrg_filename), cv2.COLOR_BGR2RGB)
# The mask is loaded with default imread flags and displayed unconverted.
mask_image = cv2.imread(mask_filename)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
fig.suptitle(base_filename)

panels = (('RGB', rgb_image), ('NRG', nrg_image), ('mask', mask_image))
for ax, (panel_title, picture) in zip(axes, panels):
    ax.imshow(picture)
    ax.axis('off')
    ax.set_title(panel_title)

fig.show()

#### Define a function to train a classifier using the RGB images, the NRG images, or both

In [None]:
def train_classifier(train_files: list[str], model, rgb=True, nrg=True, image_size=(256, 256)) -> None:
    """Fit ``model`` as a per-pixel classifier on the given training samples.

    Every pixel of every training image becomes one row of the feature matrix.
    The row holds the 3 BGR values of the RGB image or the NRG image alone, or
    6 values (RGB first, then NRG) when both are used. The pixel's label is 1
    where the resized grayscale mask is > 127, else 0.

    Args:
        train_files: Base filenames (without the 'RGB_'/'NRG_'/'mask_' prefix).
        model: Estimator with a scikit-learn style ``fit(X, y)``; it is fitted
            in place.
        rgb: Include features from the RGB image.
        nrg: Include features from the NRG image.
        image_size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.

    Raises:
        ValueError: If both ``rgb`` and ``nrg`` are False, or ``train_files``
            is empty.
        FileNotFoundError: If an image or mask cannot be read.
    """
    # Raise rather than print-and-return: returning silently left the model
    # unfitted while the caller went on to save it.
    if not (rgb or nrg):
        raise ValueError('At least one of rgb or nrg must be True')
    if not train_files:
        raise ValueError('train_files is empty; nothing to train on')

    def read_resized(path, flags):
        # cv2.imread returns None (instead of raising) for a missing or
        # unreadable file; surface that clearly before cv2.resize fails cryptically.
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.resize(image, image_size)

    X = []
    y = []
    for base_filename in train_files:
        mask = read_resized(str(mask_images_dir / ('mask_' + base_filename)), cv2.IMREAD_GRAYSCALE)
        y.append((mask.reshape(-1) > 127).astype(np.uint8))

        channels = []
        if rgb:
            channels.append(read_resized(str(rgb_images_dir / ('RGB_' + base_filename)), cv2.IMREAD_COLOR))
        if nrg:
            channels.append(read_resized(str(nrg_images_dir / ('NRG_' + base_filename)), cv2.IMREAD_COLOR))
        # Feature layout must match test_classifier: RGB channels first, then NRG.
        image = np.concatenate(channels, axis=2)
        X.append(image.reshape(-1, image.shape[2]))

    model.fit(np.vstack(X), np.hstack(y))

#### Define several models, train them, and save each one to a file

In [None]:
IMAGE_SIZE = (256, 256)

# Train one random forest per input combination and save each to disk; a saved
# model can be restored with joblib.load('random_forest_322_<tag>.joblib').
# The NRG-only model is trained last, so `model` refers to it after this cell;
# the evaluation cell predicts with rgb=False, nrg=True accordingly.
model_configs = (
    (True, True, 'RGB_NRG'),
    (True, False, 'RGB'),
    (False, True, 'NRG'),
)
for use_rgb, use_nrg, tag in model_configs:
    model = RandomForestClassifier(n_jobs=-1)
    train_classifier(train_files, model, use_rgb, use_nrg, IMAGE_SIZE)
    joblib.dump(model, f'random_forest_322_{tag}.joblib')

#### Define a function to test a classifier on unseen test data

In [None]:
def test_classifier(test_file: str, model, rgb=True, nrg=True, image_size=(256, 256)) -> np.ndarray:
    """Predict a per-pixel segmentation mask for one sample.

    Features are built exactly as in ``train_classifier``: the 3 BGR values of
    the RGB image or the NRG image alone, or 6 values (RGB first, then NRG).
    The model must have been trained with the same ``rgb``/``nrg`` flags so the
    feature count matches.

    Args:
        test_file: Base filename (without the 'RGB_'/'NRG_' prefix).
        model: Fitted estimator with a scikit-learn style ``predict``.
        rgb: Include features from the RGB image.
        nrg: Include features from the NRG image.
        image_size: ``dsize`` passed to ``cv2.resize``, i.e. ``(width, height)``.

    Returns:
        Predicted labels shaped ``(height, width)`` == ``(image_size[1],
        image_size[0])``, the same shape as ``cv2.resize(mask, image_size)``.

    Raises:
        ValueError: If both ``rgb`` and ``nrg`` are False.
        FileNotFoundError: If an input image cannot be read.
    """
    if not (rgb or nrg):
        raise ValueError('At least one of rgb or nrg must be True')

    def read_resized(path):
        # cv2.imread returns None (instead of raising) for a missing or
        # unreadable file; surface that clearly before cv2.resize fails cryptically.
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.resize(image, image_size)

    channels = []
    if rgb:
        channels.append(read_resized(str(rgb_images_dir / ('RGB_' + test_file))))
    if nrg:
        channels.append(read_resized(str(nrg_images_dir / ('NRG_' + test_file))))
    # Same feature layout as training: RGB channels first, then NRG.
    image = np.concatenate(channels, axis=2)

    y_pred = model.predict(image.reshape(-1, image.shape[2]))
    # cv2.resize takes dsize as (width, height) but yields arrays shaped
    # (height, width). Reshape to the actual image shape; reshaping to
    # image_size scrambled predictions for non-square sizes.
    return y_pred.reshape(image.shape[:2])


In [None]:
def calculate_iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two binary masks.

    Any nonzero value counts as foreground, so 0/1, boolean and 0/255 masks
    all give the same result. When both masks are empty the IoU is defined
    as 1.0.

    Args:
        y_true: Ground-truth mask (array-like).
        y_pred: Predicted mask (array-like) with the same shape as ``y_true``.

    Returns:
        IoU as a float in [0, 1].

    Raises:
        ValueError: If the two masks have different shapes.
    """
    # Binarize first: multiplying raw uint8 values breaks for 0/255 masks
    # (255 * 255 wraps around to 1).
    true_fg = np.asarray(y_true) != 0
    pred_fg = np.asarray(y_pred) != 0
    # Refuse to broadcast; a transposed or differently-sized prediction would
    # otherwise produce a silently wrong score.
    if true_fg.shape != pred_fg.shape:
        raise ValueError(f'Shape mismatch: y_true {true_fg.shape} vs y_pred {pred_fg.shape}')

    intersection = np.count_nonzero(true_fg & pred_fg)
    union = np.count_nonzero(true_fg | pred_fg)
    if union == 0:
        return 1.0
    return intersection / union

#### Test the classifier on several test files and visualise the outputs

In [None]:
NUM_TEST_SAMPLES = 5

random.seed(69)
test_samples = random.sample(test_files, NUM_TEST_SAMPLES)
# squeeze=False keeps `axes` 2-D even when NUM_TEST_SAMPLES == 1, and the
# figure height scales with the number of rows (4 inches per sample).
fig, axes = plt.subplots(nrows=NUM_TEST_SAMPLES, ncols=4,
                         figsize=(16, 4 * NUM_TEST_SAMPLES), squeeze=False)
ious = []
# enumerate() instead of zip(..., range(5)) so every sample is plotted when
# NUM_TEST_SAMPLES changes.
for i, test_file in enumerate(test_samples):
    mask_filename = str(mask_images_dir / ('mask_' + test_file))
    rgb_filename = str(rgb_images_dir / ('RGB_' + test_file))
    nrg_filename = str(nrg_images_dir / ('NRG_' + test_file))

    mask = cv2.resize(cv2.imread(mask_filename, flags=cv2.IMREAD_GRAYSCALE), IMAGE_SIZE)
    mask = (mask > 127).astype(np.uint8)
    # `model` is the last model bound by the training cell (the NRG-only
    # forest when cells run in order), hence rgb=False, nrg=True here.
    pred_mask = test_classifier(test_file, model, False, True, IMAGE_SIZE)

    # Convert from OpenCV's BGR order to RGB for display.
    rgb_image = cv2.cvtColor(cv2.resize(cv2.imread(rgb_filename), IMAGE_SIZE), cv2.COLOR_BGR2RGB)
    nrg_image = cv2.cvtColor(cv2.resize(cv2.imread(nrg_filename), IMAGE_SIZE), cv2.COLOR_BGR2RGB)

    iou = calculate_iou(mask, pred_mask)
    ious.append(iou)
    print(f'IOU: {iou}')

    panels = (
        (rgb_image, 'RGB', None),
        (nrg_image, 'NRG', None),
        (mask, 'y_true', 'gray'),
        (pred_mask, 'y_pred', 'gray'),
    )
    for ax, (picture, panel_title, cmap) in zip(axes[i], panels):
        ax.imshow(picture, cmap=cmap)
        ax.axis('off')
        ax.set_title(panel_title)

print(f'Mean IOU: {np.mean(ious)}')
fig.show()