# Configuration and imports

In [None]:
# Weights to load. Either a dict with a "path" key plus optional overrides of
# any DEFAULT hyper-parameter, or a bare weights-path string.
MODEL_V3_BEST = dict(
    path='models/BEST_V3.DEEP5.1-saved_weights.h5',
    net_depth=5,
)
MODEL_TO_IMPORT = MODEL_V3_BEST  # can be either a model (dict) or a string
# Bare-path alternative:
# MODEL_TO_IMPORT = 'models/V2.2.004_step17_BEST.h5'

# Optional post-processing model weights, e.g. 'models/V3.PP.1.h5'.
MODEL_PP = None  # set None to deactivate Post Processing model

# Fallback hyper-parameters for every key MODEL_TO_IMPORT does not define.
DEFAULT = dict(
    mode='rgb',
    seed=1000,
    lk_alpha=.1,
    channels_size=3,
    batchnorm=True,
    residual=True,
    net_depth=4,
)
IMAGE_SIZE = 608    # passed to CNN as window_size
PATCH_SIZE = 16     # side of the patches scored in the submission
PTHRESHOLD = 0.25   # foreground_threshold handed to get_ground_img
NB_TST_IMG = 50     # test images are numbered 1..NB_TST_IMG

SUBMISSION_FILE = "submission.csv"

In [None]:
%matplotlib inline
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import os, sys
from PIL import Image
import re
from IPython import display

IN_COLAB = 'google.colab' in sys.modules.keys()  # True when running inside Google Colab

In [None]:
# Locally, data is read relative to the current working directory; on Colab it
# lives under the mounted Google Drive folder.
if not IN_COLAB:
    root_dir = "."
else:
    from google.colab import drive
    drive.mount('/gdrive')
    root_dir = "/gdrive/My Drive/ML/training/"
# Prefix of every test image folder; get_path_for_img_nb appends "<n>/test_<n>.png".
test_dir = os.path.join(root_dir, "test_set_images/test_")

# Load Model

In [None]:
def get_model_prop(prop_name):
    """Look up a hyper-parameter of the model being imported.

    A dict MODEL_TO_IMPORT may override any key; a bare-string
    MODEL_TO_IMPORT (or a missing key) falls back to DEFAULT.
    Raises KeyError when neither source defines `prop_name`.
    """
    has_override = not isinstance(MODEL_TO_IMPORT, str) and prop_name in MODEL_TO_IMPORT
    if has_override:
        return MODEL_TO_IMPORT[prop_name]
    return DEFAULT[prop_name]

In [None]:
from CNNv2 import CNN

# Each hyper-parameter comes from MODEL_TO_IMPORT when it defines the key,
# otherwise from DEFAULT (via get_model_prop).
cnn_kwargs = dict(
    rootdir=root_dir,
    window_size=IMAGE_SIZE,
    lk_alpha=get_model_prop('lk_alpha'),
    random_seed=get_model_prop('seed'),
    channels_size=get_model_prop('channels_size'),
    batchnorm=get_model_prop('batchnorm'),
    residual=get_model_prop('residual'),
    net_depth=get_model_prop('net_depth'),
)
model = CNN(**cnn_kwargs)

# MODEL_TO_IMPORT is either the weights path itself or a dict holding it under 'path'.
if isinstance(MODEL_TO_IMPORT, str):
    model_path = MODEL_TO_IMPORT
else:
    model_path = MODEL_TO_IMPORT['path']
model.load(os.path.join(root_dir, model_path))

In [None]:
# Optional post-processing network. It is fed the main model's raw prediction
# (see get_steps_for_img), hence a single input channel.
model_pp = None
if MODEL_PP is not None:
    # NOTE(review): reuses the main model's hyper-parameters and does not pass
    # net_depth, so CNN's own default depth applies. Confirm this matches how
    # the post-processing weights were trained.
    pp_kwargs = dict(
        rootdir=root_dir,
        window_size=IMAGE_SIZE,
        lk_alpha=get_model_prop('lk_alpha'),
        random_seed=get_model_prop('seed'),
        channels_size=1,
        batchnorm=get_model_prop('batchnorm'),
        residual=get_model_prop('residual'),
    )
    model_pp = CNN(**pp_kwargs)
    model_pp.load(os.path.join(root_dir, MODEL_PP))

# Visualization 

In [None]:
from helpers import *
from skimage.color import rgb2hsv, rgb2lab, rgb2hed, rgb2yuv

def load_image(filename, mode = 'rgb'):
    """Read an image file and convert it to the requested colour space.

    Parameters
    ----------
    filename : str
        Path passed to ``mpimg.imread``.
    mode : str
        One of 'rgb' (no conversion), 'hsv', 'lab', 'hed' or 'yuv'
        (converted with the matching ``skimage.color.rgb2*`` function).

    Returns
    -------
    numpy.ndarray
        The image with an extra leading axis of size 1 (a batch of one).

    Raises
    ------
    NotImplementedError
        If `mode` is not one of the supported colour spaces.
    """
    # None means "keep the image as read".
    converters = {
        'rgb': None,
        'hsv': rgb2hsv,
        'lab': rgb2lab,
        'hed': rgb2hed,
        'yuv': rgb2yuv,
    }
    if mode not in converters:
        # `raise NotImplemented` (a constant, not an exception) would only
        # produce a confusing TypeError; raise the intended exception instead.
        raise NotImplementedError(f"Unsupported colour mode: {mode!r}")
    img = mpimg.imread(filename)
    convert = converters[mode]
    if convert is not None:
        img = convert(img)
    return np.expand_dims(img, axis=0)

def get_path_for_img_nb(img_nb):
    """Return the path of test image `img_nb`: ``<test_dir><n>/test_<n>.png``."""
    number = str(img_nb)
    return ''.join((test_dir, number, '/test_', number, '.png'))

def get_image_filenames(img_nb = None):
    """Return the test-image paths for the requested image number(s).

    Parameters
    ----------
    img_nb : None, int or iterable of int
        None selects every image 1..NB_TST_IMG. An integer (Python or numpy)
        selects that single image. Any other iterable (list, range, numpy
        array, ...) selects those images, in the given order.

    Returns
    -------
    list of str
        One path per selected image, built by get_path_for_img_nb.
    """
    # `is None` rather than `== None`: with a numpy array, `== None` is
    # element-wise and the `if` would raise "truth value is ambiguous".
    if img_nb is None:
        indices = range(1, NB_TST_IMG + 1)
    elif isinstance(img_nb, (int, np.integer)):
        # Also accept numpy integers, which `type(x) is int` rejected.
        indices = [img_nb]
    else:
        indices = img_nb
    return [get_path_for_img_nb(i) for i in indices]
    
def visualize_step(idx, input_image, Xi_raw, Xi, ground, mask = None, animate = False):
    """Draw one row of panels for a single image of the pipeline.

    Panels, left to right: input image (titled ``image <idx+1>``), raw
    prediction, thresholded prediction, the post-processing mask (only
    when `mask` is given) and the patch-level `ground` result. Previous
    cell output is cleared; when `animate` is True the figure is shown
    immediately instead of waiting for the caller's plt.show().
    """
    panels = [
        (np.squeeze(input_image), f'image {idx+1}'),
        (np.squeeze(Xi_raw), 'real prediction'),
        (np.squeeze(Xi), 'thresholded prediction'),
    ]
    if mask is not None:
        panels.append((mask, 'mask'))
    panels.append((ground, 'label prediction'))

    fig, axs = plt.subplots(1, len(panels), figsize=(15, 15))
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_title(title)
        # images only: hide the pixel-index ticks
        ax.set_xticks([])
        ax.set_yticks([])

    display.clear_output(wait=True)
    if animate:
        plt.show()

def get_predicted_mask(image, mode, model, threshold = 0.5):
    """Run `model` on an image and binarise its output.

    Parameters
    ----------
    image : str or array
        A path (loaded with load_image in `mode`) or an already-loaded
        array that ``model.model.predict`` accepts directly.
    mode : str
        Colour mode used only when `image` is a path.
    model : object
        Wrapper exposing the underlying network as ``model.model``.
    threshold : float, default 0.5
        Values strictly greater than this become 1, the rest 0.

    Returns
    -------
    (Xi_raw, Xi)
        The unmodified ``predict`` output, and the 0/1 integer mask with
        singleton axes squeezed out.
    """
    if isinstance(image, str):
        image = load_image(image, mode)
    Xi_raw = model.model.predict(image)
    Xi = np.squeeze(np.where(Xi_raw > threshold, 1, 0))
    return Xi_raw, Xi

def get_steps_for_img(filename, mode, model = model, model_pp = model_pp):
    """Compute every intermediate result displayed for one test image.

    Note: the `model` / `model_pp` defaults are bound when this cell is
    executed. Re-run this cell after reloading either model, or pass the
    models explicitly.

    Returns
    -------
    (img_input, Xi_raw, Xi, ground, mask)
        The loaded image, the raw and thresholded main-model prediction,
        the patch-level result from get_ground_img, and the
        post-processing mask (None when `model_pp` is None).
    """
    img_input = load_image(filename, mode)
    # Feed the already-loaded array instead of the filename, so the image is
    # not read and colour-converted a second time.
    Xi_raw, Xi = get_predicted_mask(img_input, mode, model=model)
    mask = None
    if model_pp is not None:
        # The post-processing model consumes the raw prediction directly.
        _, mask = get_predicted_mask(Xi_raw, mode, model=model_pp)
    ground = get_ground_img(
        Xi if mask is None else mask,
        patch_size = PATCH_SIZE,
        foreground_threshold = PTHRESHOLD)
    return (img_input, Xi_raw, Xi, ground, mask)

def visualize(img_nb = None, mode = DEFAULT['mode'], save_masks = False):
    """Plot the prediction pipeline for one or more test images.

    Parameters
    ----------
    img_nb : None, int or iterable of int
        Which test images to show (see get_image_filenames).
    mode : str
        Colour mode the images are loaded in (see load_image). It should
        match the mode the model expects.
    save_masks : bool
        If True, also save each raw prediction as
        ``<root_dir>/masks/test_<n>.png``, where ``n`` is the test-image number.
    """
    image_filenames = get_image_filenames(img_nb)
    if save_masks:
        # Make sure the output directory exists before saving into it.
        masks_dir = os.path.join(root_dir, 'masks')
        os.makedirs(masks_dir, exist_ok=True)
    for i, filename in enumerate(image_filenames):
        # Use the real test-image number (from the "test_<n>.png" name built by
        # get_path_for_img_nb) instead of the position in the selection, so
        # e.g. visualize([6]) is titled and saved as image 6, not image 1.
        match = re.search(r'test_(\d+)\.png$', filename)
        img_id = int(match.group(1)) if match else i + 1
        img_input, Xi_raw, Xi, ground, mask = get_steps_for_img(filename, mode)
        visualize_step(img_id - 1, img_input, Xi_raw, Xi, ground, mask)
        if save_masks:
            mpimg.imsave(os.path.join(masks_dir, f'test_{img_id}.png'), np.squeeze(Xi_raw[0]))
    plt.show()
        
def generate_submission(img_nb = None, plot = True, submission_filename = SUBMISSION_FILE, mode = DEFAULT['mode']):
    """ Generate a .csv containing the classification of the test set.

    The file is written to ``<root_dir>/<submission_filename>``. After the
    ``id,prediction`` header, every line is ``<nnn>_<x>_<y>,<value>``:
    ``nnn`` is the zero-padded test-image number, ``x = j*PATCH_SIZE`` and
    ``y = k*PATCH_SIZE`` for column j / row k of `ground`, and ``value`` is
    ``ground[k, j]``. Assumes get_ground_img returns one value per
    PATCH_SIZE patch (TODO confirm).

    Parameters
    ----------
    img_nb : None, int or iterable of int
        Which test images to include (see get_image_filenames).
    plot : bool
        If True, redraw the panels for each image; otherwise print progress.
    submission_filename : str
        Output file name, relative to root_dir.
    mode : str
        Colour mode the images are loaded in (see load_image).
    """
    image_filenames = get_image_filenames(img_nb)
    submission_path = os.path.join(root_dir, submission_filename)
    print(f'Generating file: {submission_path}...')
    with open(submission_path, 'w') as f:
        f.write('id,prediction\n')
        for i, filename in enumerate(image_filenames):
            # The submission id must be the real test-image number (parsed from
            # the "test_<n>.png" name), not the position in the selection;
            # otherwise generate_submission(img_nb=[5]) would label image 5 as 001.
            match = re.search(r'test_(\d+)\.png$', filename)
            img_id = int(match.group(1)) if match else i + 1
            img_input, Xi_raw, Xi, ground, mask = get_steps_for_img(filename, mode)
            if plot:
                visualize_step(img_id - 1, img_input, Xi_raw, Xi, ground, mask, True)
            else:
                print(f'Img {img_id}...')
            f.writelines(
                "{:03d}_{}_{},{}\n".format(img_id, j*PATCH_SIZE, k*PATCH_SIZE, ground[k, j])
                for j in range(ground.shape[1]) for k in range(ground.shape[0])
            )
    print(f'Submission generated at {submission_path}!')

In [None]:
# Visualize predictions on the input images.
# `img_nb` (first argument of visualize) can be:
#     int: 1-50 for a specific image
#     list: e.g. [1, 2, 4, 6] for a specific selection
#     None: run all images
# Use the model's own colour mode, exactly as the submission cell does;
# the DEFAULT mode would feed the wrong colour space to a non-RGB model.
visualize([1, 2, 6, 10], mode=get_model_prop('mode'))

# Submission

In [None]:
# Guard: stop "Run All" here so the submission cell below only runs on purpose.
# An explicit raise is used because `assert` statements are stripped under python -O.
raise AssertionError("prevent next cells execution when run all")

In [None]:
submission_mode = get_model_prop('mode')  # colour mode the imported model is configured with
generate_submission(mode=submission_mode)