In [None]:
import tensorflow as tf
import numpy as np
from skimage import io
import matplotlib.pyplot as plt
from net.refinenet import RefineNet
%matplotlib inline

# Network input: a batch of 224x224 images with 4 channels. In this notebook
# the 4 channels are produced by read_image() (RGB frame concatenated with a mask).
inp = tf.placeholder(tf.float32,shape=[None,224,224,4],name='input')
# Build the RefineNet graph on top of `inp`; 'resnet_v1_50' presumably selects
# the backbone.
# NOTE(review): one checkpoint path is passed here and a different one to
# initialize() below; how RefineNet uses each is defined in net/refinenet.py —
# confirm which weights actually end up loaded.
refine_net = RefineNet(inp,'resnet_v1_50','exp/test3/iters-24309')

net,end_points = refine_net.net,refine_net.end_points
# Endpoint used as the coarse prediction.
coarse_net_end = end_points['tail/linear2']

# Attach sigmoid and reshape. The coarse endpoint is reshaped to
# (batch, 56, 56, 1) so it can be shown as an image; this assumes it holds
# 56*56 values per example. The refined output `net` only gets a sigmoid.
coarse_out= tf.reshape(tf.sigmoid(coarse_net_end),[-1,56,56,1])
refine_out = tf.sigmoid(net)

# InteractiveSession is used so later cells can call sess.run() directly.
sess = tf.InteractiveSession()
refine_net.initialize(sess,'exp/refine-test_1/iters-38609')
# Alternative (disabled): restore every variable from a single checkpoint.
#saver = tf.train.Saver()
#saver.restore(sess, 'exp/test3/iters-20929')

print("Model Loaded")

In [None]:
from dataprovider.preprocess import vgg_preprocess
import os
from skimage import transform
# Root of the DAVIS dataset. Set the DAVIS_ROOT environment variable to point
# at a different copy; the default keeps the original location. Dataset paths
# used in this notebook start with '/' and are appended to this string by
# plain concatenation.
BASE_DIR = os.environ.get('DAVIS_ROOT', os.path.join('/work/george', 'DAVIS'))
# Image-set listing sub-directory (presumably relative to BASE_DIR).
IMAGESETS = os.path.join('ImageSets', '480p')

# Spatial size every frame and mask is resized to before entering the network
# (the input placeholder is declared as 224x224).
RESIZE_HEIGHT = 224
RESIZE_WIDTH = 224

def read_image(imageFile,prev_mask=None):
    """Build one 4-channel network input from a dataset frame and a mask.

    Args:
        imageFile: dataset-relative path of the RGB frame; it is appended to
            BASE_DIR by string concatenation, so it should start with '/'.
        prev_mask: 2-D mask array (H, W) used as the 4th input channel.
            It may be at full resolution (e.g. uint8 as read from disk) or
            already resized; it is resized here either way.

    Returns:
        The result of vgg_preprocess applied to an array of shape
        (1, RESIZE_HEIGHT, RESIZE_WIDTH, 4) whose RGB and mask channels are
        both on a 0..255 scale.

    Raises:
        ValueError: if prev_mask is None.
    """
    if prev_mask is None:
        raise ValueError('read_image requires prev_mask (2-D mask for channel 4)')

    rgb = io.imread(BASE_DIR + imageFile)

    # Resize the RGB frame and the mask separately. transform.resize turns
    # integer images into floats in [0, 1] and leaves [0, 1] floats in place,
    # so both parts end up on the same scale. It also no longer matters whether
    # the mask arrives at full resolution or pre-resized. Concatenating first
    # would mix a 0..255 RGB with a 0..1 mask, or fail on mismatched H/W.
    size = [RESIZE_HEIGHT, RESIZE_WIDTH]
    rgb = transform.resize(rgb, size)
    mask = transform.resize(prev_mask, size)

    # Stack into (H, W, 4) and bring all channels to 0..255.
    image = np.concatenate((rgb, np.expand_dims(mask, axis=2)), axis=2) * 255

    # Add the batch dimension expected by the network input.
    image = np.expand_dims(image, axis=0)

    return vgg_preprocess(image)

def read_raw_image(image_file):
    """Load a dataset image and resize it to the network input size.

    `image_file` is appended to BASE_DIR by plain string concatenation, so it
    is expected to start with '/'. No mean subtraction or channel stacking is
    applied; the visualisation cell uses it to show the target annotation.
    """
    target_shape = [RESIZE_HEIGHT, RESIZE_WIDTH]
    return transform.resize(io.imread(BASE_DIR + image_file), target_shape)

In [None]:
import re
def get_mask_file_name(image_name):
    """Map a DAVIS frame path to the path of its ground-truth annotation.

    The last two path components (sequence directory and frame name) are
    kept. Everything before them is replaced by '/Annotations/480p/', and the
    '.jpg' extension becomes '.png':

        '/JPEGImages/480p/bear/00004.jpg' -> '/Annotations/480p/bear/00004.png'

    Raises:
        ValueError: if image_name does not end in '.jpg' or has too few
            '/'-separated components.
    """
    # The dot is escaped and the end anchored so that names such as
    # '...00004xjpg' or '....jpg.bak' are rejected instead of silently mapped.
    m = re.match(r"(/.*/.*/)(.*/.*)\.jpg$", image_name)
    if m is None:
        raise ValueError('Not a DAVIS frame path: {!r}'.format(image_name))
    frame_no = m.group(2)
    return '/Annotations/480p/{}.png'.format(frame_no)


In [None]:
from skimage import morphology


# Legacy frame-range settings. The loop below iterates over img_list and does
# not read these; they are kept in case later cells refer to them.
img_seq = '/JPEGImages/480p/bear/00004.jpg'
start_frame = 0
end_frame = 15

# Evaluation pairs: [frame to segment, ground-truth annotation of the frame
# before it]. The annotation is fed to the network as the 4th input channel.
img_list = [['/JPEGImages/480p/bear/00004.jpg', '/Annotations/480p/bear/00003.png'],
            ['/JPEGImages/480p/bmx-bumps/00086.jpg', '/Annotations/480p/bmx-bumps/00085.png'],
            ['/JPEGImages/480p/breakdance-flare/00003.jpg', '/Annotations/480p/breakdance-flare/00002.png'],
            ['/JPEGImages/480p/boat/00008.jpg', '/Annotations/480p/boat/00007.png'],
            ['/JPEGImages/480p/blackswan/00007.jpg', '/Annotations/480p/blackswan/00006.png'],
            ['/JPEGImages/480p/bmx-trees/00066.jpg', '/Annotations/480p/bmx-trees/00065.png'],
            ['/JPEGImages/480p/car-roundabout/00042.jpg', '/Annotations/480p/car-roundabout/00041.png'],
            ['/JPEGImages/480p/kite-surf/00034.jpg', '/Annotations/480p/kite-surf/00033.png']]

# Per-channel means added back to the preprocessed RGB for display.
# NOTE(review): assumes vgg_preprocess subtracts exactly these values in
# R, G, B order — confirm in dataprovider/preprocess.py.
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
_M_MEAN = 127


def _show_panels(panels):
    """Draw (title, picture) pairs side by side in one row with hidden axes, then display and close the figure."""
    fig, axes = plt.subplots(1, len(panels), squeeze=False,
                             figsize=(3 * len(panels), 3))
    for ax, (title, picture) in zip(axes[0], panels):
        ax.set_title(title)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.imshow(picture)
    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)


for image_file, prev_mask_file in img_list:
    # Each frame is paired with its own previous-frame annotation, read from
    # the dataset root at full resolution; read_image resizes it.
    # (skimage >= 0.14 names this keyword `as_gray`.)
    prev_mask = io.imread(BASE_DIR + prev_mask_file, as_grey=True)

    image = read_image(image_file, prev_mask)
    result = sess.run([coarse_out, refine_out], feed_dict={inp: image})

    # Undo the mean subtraction for display; clip so the uint8 cast cannot
    # wrap values that fall outside 0..255.
    rgb = image[0, :, :, 0:3] + [_R_MEAN, _G_MEAN, _B_MEAN]

    _show_panels([
        ('Input 1-3', np.clip(rgb, 0, 255).astype(np.uint8)),
        ('Input 4', image[0, :, :, 3]),
        ('Coarse Out.', result[0][0, :, :, 0]),
        ('Refine Out.', result[1][0, :, :, 0]),
        ('Target', read_raw_image(get_mask_file_name(image_file))),
    ])