In [None]:
import yaml
import os
import numpy as np
import cv2
import math
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys

sys.path.append("..")
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.optimizers import Adam

import dataloader.coco
import models.fcn

In [None]:
def load_coco_dataset(cfg):
    """Create the COCO data generators and look up the class list.

    Args:
        cfg: Configuration object forwarded unchanged to
            ``dataloader.coco.CocoDataGenerator`` and
            ``dataloader.coco.get_super_class``.

    Returns:
        A 3-tuple ``(training_generator, validation_generator, classes)``
        where the generators are built for the ``"train"`` and
        ``"validate"`` splits respectively, and ``classes`` is whatever
        ``dataloader.coco.get_super_class(cfg)`` returns.
    """
    coco = dataloader.coco

    # Build the generators in a fixed order: training first, then validation.
    train_gen, val_gen = [
        coco.CocoDataGenerator(cfg, split) for split in ("train", "validate")
    ]

    return train_gen, val_gen, coco.get_super_class(cfg)

In [None]:
# Read the experiment configuration.
# - The context manager guarantees the file handle is closed.
# - safe_load is used because yaml.load() without an explicit Loader is
#   deprecated (and raises TypeError on PyYAML >= 6); it also refuses to
#   construct arbitrary Python objects from the YAML file.
with open("../configs/coco_fcn.yaml") as fp:
    cfg = yaml.safe_load(fp)

training_data_generator, validation_data_generator, classes = load_coco_dataset(cfg)

classes_count = len(classes)

# Image geometry as reported by the training data generator.
img_width, img_height = training_data_generator.image_width, training_data_generator.image_height
channels_count = training_data_generator.image_num_chans

# Checkpoint file that holds the FCN weights.
checkpoint_path = "../pretrained/fcn_weights.h5"

In [None]:
# Instantiate the FCN factory and build the network through its resnet50
# builder, sized from the generator's image geometry and the class count.
# NOTE(review): the shape is passed as (width, height, channels); cv2.imread
# yields arrays laid out as (height, width, channels) - confirm the two agree
# for non-square inputs.
fcn = models.fcn.FCN()
model = fcn.resnet50(
    classes=classes_count,
    input_shape=(img_width, img_height, channels_count),
)

# Restore the weights stored at checkpoint_path into the freshly built model.
model.load_weights(checkpoint_path)

In [None]:
# Dataset whose images are run through the model and visualised.
dataset_use = training_data_generator

# Directories holding the ground-truth masks (.npy) and the input images.
path_fetch_output = dataset_use.labels_dir + "/"
path_fetch_input = dataset_use.dataset_dir + "/"

# Half-open range [start_inference, end_inference) of files to visualise.
start_inference = 0
end_inference = 50
# Stretch class indices towards the 0-255 range so classes are distinguishable.
scale_factor = math.floor((255.0 / classes_count))

# os.listdir returns entries in arbitrary order; sort so that the same files
# are selected on every run.
files_input = sorted(os.listdir(path_fetch_input))
files_input = files_input[start_inference:end_inference]

# Iterate over the sliced list itself. Indexing the slice with
# range(start_inference, end_inference) would be offset whenever
# start_inference != 0 and would raise IndexError when the directory holds
# fewer files than end_inference.
for file_name in files_input:
    files_output = file_name.replace(".jpg", ".jpg.npy")

    # cv2.imread returns None (instead of raising) when a file cannot be decoded.
    true_image = cv2.imread(os.path.join(path_fetch_input, file_name))
    if true_image is None:
        print("Skipping unreadable image: " + file_name)
        continue

    true_mask = np.load(os.path.join(path_fetch_output, files_output))

    # Add a leading batch axis; the model receives the image exactly as read
    # by OpenCV (BGR channel order), as before.
    true_image_pass = np.expand_dims(true_image, axis=0)
    prediction = model.predict(true_image_pass, verbose = 1)
    # Collapse the last axis to the index of the highest-scoring class,
    # then drop the batch axis.
    prediction = np.argmax(prediction, -1)
    prediction = prediction[0,:,:]

    prediction = prediction * scale_factor
    true_mask = true_mask * scale_factor

    f, axarr = plt.subplots(1,3)
    # OpenCV loads images as BGR while matplotlib expects RGB; convert for
    # display only so the colours are not swapped.
    axarr[0].imshow(cv2.cvtColor(true_image, cv2.COLOR_BGR2RGB))
    axarr[0].set_title("Image")
    # Pin the colour range so a given class maps to the same colour in both
    # the ground-truth and the predicted mask (imshow otherwise rescales each
    # array to its own min/max).
    axarr[1].imshow(true_mask, vmin=0, vmax=255)
    axarr[1].set_title("Ground truth")
    axarr[2].imshow(prediction, vmin=0, vmax=255)
    axarr[2].set_title("Prediction")
    plt.show()
    # Release the figure explicitly so figures do not accumulate in the loop.
    plt.close(f)