In [1]:
import os
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import copy

from PConv.libs.pconv_model import PConvUnet
from PConv.libs.util import MaskGenerator, ImageChunker

%load_ext autoreload
%autoreload 2

Using TensorFlow backend.


# Model Loading

In [3]:
# Build the PConv U-Net in inference-only mode (no VGG weights requested) and
# restore the saved weights from disk with train_bn switched off.
PRETRAINED_WEIGHTS = './PConv/pretrained_models/pconv_imagenet.26-1.07.h5'

model = PConvUnet(inference_only=True, vgg_weights=None)
model.load(PRETRAINED_WEIGHTS, train_bn=False)

# Useful Function Declaration

In [4]:
# Used for chunking up images & stitching them back together.
# NOTE(review): the arguments are presumably chunk height, chunk width and
# overlap in pixels — confirm against ImageChunker's signature.
chunker = ImageChunker(512, 512, 30)

def plot_images(images, s=5):
    """Show the given images side by side in a single row.

    Args:
        images: Sized collection of objects accepted by ``Axes.imshow``.
        s: Edge length of each panel; the figure is ``s * len(images)`` wide
            and ``s`` tall.
    """
    panel_count = len(images)
    _fig, panels = plt.subplots(1, panel_count, figsize=(s * panel_count, s))

    # With a single column, subplots hands back a bare Axes rather than a
    # collection, so wrap it to keep the loop below uniform.
    panels = [panels] if panel_count == 1 else panels

    for picture, panel in zip(images, panels):
        panel.imshow(picture)

    plt.show()

# Image Path Declaration

In [11]:
# Walk the mask directory and, for every mask file found, record the path of
# the image it refers to, keyed by image name.
path = os.path.join('masks')

# Every referenced image is looked up in this one ADE20K folder.
bedroom_dir = os.path.join('ADE20K_2016_07_26', 'images', 'training', 'b', 'bedroom')

img_path = {}

for _root, _dirs, filenames in os.walk(path):
    for filename in filenames:
        if filename != '.DS_Store':
            # Image name: the text before the first '.', cut to 18 characters.
            # NOTE(review): presumably the ADE20K image id — confirm.
            img_name = filename.split('.')[0][:18]
            img_path[img_name] = os.path.join(bedroom_dir, img_name + '.jpg')

# Mask Path Declaration

In [13]:
# Group every mask file found under the mask directory by image name, so that
# each image maps to the list of all of its mask files.
path = os.path.join('masks')

msk_path = {}

for root, _dirs, filenames in os.walk(path):
    for filename in filenames:
        if filename == '.DS_Store':
            continue
        # Same key as the image lookup: text before the first '.', cut to
        # 18 characters.
        img_name = filename.split('.')[0][:18]
        msk_path.setdefault(img_name, []).append(os.path.join(root, filename))

# Prediction

In [36]:
def _as_uint8(unit_array):
    """Scale an array from [0, 1] to [0, 255], round, clip and cast to uint8."""
    return np.clip(np.rint(unit_array * 255), 0, 255).astype(np.uint8)


# For every image, run the model once per mask and write the prediction and
# the mask that was used to results/<img_name>/<index>.jpg and
# results/<img_name>/<index>_mask.png.
for img_name in tqdm(img_path):

    # Load the image resized to 512x512 and scaled to [0, 1]. The context
    # manager closes the file as soon as the pixels have been read.
    with Image.open(img_path[img_name]) as img_file:
        img = np.array(img_file.resize((512, 512))) / 255

    # One output folder per image; exist_ok avoids the check-then-create race.
    out_dir = os.path.join('results', img_name)
    os.makedirs(out_dir, exist_ok=True)

    for img_id, msk in enumerate(msk_path[img_name]):

        # Load the mask with the same size and scaling as the image.
        with Image.open(msk) as msk_file:
            mask = np.array(msk_file.resize((512, 512))) / 255

        # Image with the masked-out pixels (mask == 0) painted white.
        # NOTE(review): not consumed in this cell — the model is fed `img`,
        # not `masked_img`; kept in case a later cell displays it.
        masked_img = deepcopy(img)
        masked_img[mask==0] = 1

        # Process sample
        chunked_images = chunker.dimension_preprocess(deepcopy(img))
        chunked_masks = chunker.dimension_preprocess(deepcopy(mask))
        pred_imgs = model.predict([chunked_images, chunked_masks])
        # NOTE(review): the stitched result is not what gets saved below;
        # only the first predicted chunk is written. Presumably equivalent
        # for a single 512x512 chunk — confirm.
        reconstructed_image = chunker.dimension_postprocess(pred_imgs, img)

        # Save results. Convert to uint8 explicitly rather than relying on
        # OpenCV's implicit handling of float arrays; the prediction is also
        # reordered from RGB to BGR, the channel order imwrite expects.
        cv2.imwrite(
            os.path.join(out_dir, str(img_id) + '.jpg'),
            cv2.cvtColor(_as_uint8(pred_imgs[0]), cv2.COLOR_RGB2BGR),
        )
        cv2.imwrite(
            os.path.join(out_dir, str(img_id) + '_mask.png'),
            _as_uint8(chunked_masks[0]),
        )

100%|██████████| 208/208 [1:12:29<00:00, 22.98s/it]
