# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python

import os
import numpy as np # linear algebra
import tensorflow as tf
from PIL import Image
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
import yaml

# Let TensorFlow allocate GPU memory on demand rather than reserving it all upfront.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

# Directory whose files are run through the model below.
severstal_test_dir = './data/severstal-steel-defect-detection/test_images/'

# Segmentation model checkpoint: more feature channels and trained for longer.
seg_model_path = './seg_model_20190917-233355.h5'

# Per-class postprocessing parameters (one entry per model output class),
# tuned by grid search:
#   thresh        - minimum score for a pixel's winning class to be kept
#   upper_thresh  - score a pixel must exceed to count as "confident"
#   num_px_thresh - minimum number of confident pixels before a class mask is emitted
thresh = [0.9] * 3 + [0.8]
upper_thresh = [0.95] * 4
num_px_thresh = [5000] * 4


def load_img(img_path):
    '''Load an image file as a single-channel numpy array of shape (H, W, 1).

    The colour channels of the source images are assumed identical, so only the
    first channel is kept. Single-channel files, which decode to a 2-D (H, W)
    array, get a trailing channel axis so the result is always (H, W, 1).
    '''
    # Context manager guarantees the underlying file handle is released.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img)
    if img.ndim == 2:
        # Grayscale images have no channel axis; add one so slicing below works.
        img = img[:, :, np.newaxis]
    img_gray = img[:, :, :1]  # All channels are the same
    return img_gray

def onehottify(x, n=None, dtype=float):
    '''One-hot encode the integer labels in x.

    Args:
        x: array-like of integer labels (any shape).
        n: number of classes; when None it is max(x) + 1.
        dtype: dtype of the returned array.

    Returns:
        Array of shape x.shape + (n,) with a single 1 per label position.
    '''
    labels = np.asarray(x)
    if n is None:
        n = np.max(labels) + 1
    # Row k of the identity matrix is the one-hot vector for label k.
    identity = np.eye(n, dtype=dtype)
    return identity[labels]

def postprocess(y, thresh=None, upper_thresh=None, num_px_thresh=None):
    '''Convert per-pixel class scores into a thresholded one-hot mask.

    Args:
        y: 4-D array of shape (batch, height, width, classes) of per-class scores.
        thresh: per-class minimum score for a pixel's argmax class to be kept;
            defaults to 0.85 for each of 4 classes.
        upper_thresh: per-class score a pixel must exceed to count as confident;
            defaults to 0.85 for each of 4 classes.
        num_px_thresh: per-class minimum number of confident pixels an image
            needs before any mask for that class is kept; defaults to 5000 each.

    Returns:
        Integer array with the same shape as y. Each pixel has at most one
        class set to 1; an all-zero pixel is background. Every batch element
        is filtered independently.

    Raises:
        ValueError: if y is not 4-D.
    '''
    if thresh is None:
        thresh = [0.85, 0.85, 0.85, 0.85]
    if upper_thresh is None:
        upper_thresh = [0.85, 0.85, 0.85, 0.85]
    if num_px_thresh is None:
        num_px_thresh = [5000, 5000, 5000, 5000]
    y = np.asarray(y)
    if y.ndim != 4:
        raise ValueError(
            f'Expected y with shape (batch, height, width, classes), got {y.shape}')
    classes = y.shape[-1]

    # Only allow one class at each pixel
    y_argmax = np.argmax(y, axis=-1)
    y_one_hot = onehottify(y_argmax, classes, dtype=int)
    for c in range(classes):
        # The winning class is dropped (pixel -> background) if its score is too low.
        y_one_hot[..., c][y[..., c] < thresh[c]] = 0

    # Making predictions on an empty mask is very costly (score immediately goes from 1 to 0)
    # So, only predict a mask if there are many pixels (num_px_thresh) above a high threshold (upper_thresh).
    # Counts are taken per image (over height and width) so batch elements don't influence each other.
    for c in range(classes):
        pixels_above_upper = np.sum(y[..., c] > upper_thresh[c], axis=(1, 2))
        y_one_hot[pixels_above_upper < num_px_thresh[c], :, :, c] = 0
    return y_one_hot

def dense_to_rle(dense):
    '''Convert the dense np ndarray representation of a single class mask to the equivalent rle
    representation.

    The mask is read in Fortran (column-major) order and encoded as space-separated
    "start length" pairs with 1-based start positions. Any nonzero value counts as
    foreground, so bool, 0/1, 0/255 or mixed-label masks all encode correctly.
    An all-background mask yields the empty string.

    Raises:
        ValueError: if dense is not 2-D.
    '''
    dense = np.asarray(dense)
    if dense.ndim != 2:
        raise ValueError(f'Expected a 2-D mask, got shape {dense.shape}')
    # Use Fortran (column-major) ordering; binarize so every run has exactly one start and one end
    pixels = (dense != 0).ravel(order='F')
    # Pad with background on both sides so runs touching the borders are closed
    pixels = np.concatenate([[False], pixels, [False]])
    # 1-based positions where the value flips: alternating run starts and (exclusive) run ends
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
    # Turn each end position into a run length
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def store_preds(y_one_hot, img_name, preds):
    '''Append one submission row per class for the first image in y_one_hot.

    Each row is "<img_name>_<class_id>,<rle>" with 1-based class ids, where
    <rle> is that class's mask encoded by dense_to_rle. preds is modified in place.
    '''
    num_classes = y_one_hot.shape[-1]
    for class_idx in range(num_classes):
        class_mask = y_one_hot[0, :, :, class_idx]
        class_id = class_idx + 1  # submission class ids are 1-based
        preds.append(f'{img_name}_{class_id},{dense_to_rle(class_mask)}')


# Load the trained segmentation model
seg_model = tf.keras.models.load_model(seg_model_path)

# Run inference one image at a time; the CSV header row comes first
preds = ['ImageId_ClassId,EncodedPixels']
img_names = sorted(os.listdir(severstal_test_dir))
for i, img_name in enumerate(img_names):
    if not i % 100:
        print(f'Running inference on image {i} / {len(img_names)}')
    img_path = os.path.join(severstal_test_dir, img_name)
    img = load_img(img_path)
    img_batch = img[np.newaxis, ...]  # prepend a batch axis of size 1

    y = seg_model.predict(img_batch)
    y_one_hot = postprocess(y, thresh, upper_thresh, num_px_thresh)
    store_preds(y_one_hot, img_name, preds)

# Write the submission file, one row per line
with open('submission.csv', 'w') as f:
    f.writelines(f'{p}\n' for p in preds)