In [None]:
import sys
import os
sys.path.append("..")
from cores.utils.CRF import CRF
from cores.utils import misc
from cores.utils.voc_cmap import get_cmap
import numpy as np
import mxnet as mx
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from cores.config import conf
import cores.symbols.final_fcn_resnet50 as net_symbol
%matplotlib inline

In [None]:
# Inputs and run options
IM_PATH = "2007_000925.jpg"  # image to segment, relative to the notebook's working directory
EPOCH_NUM = 23               # snapshot epoch of final_fcn_resnet50 to load
USE_CRF = True               # refine the network scores with the CRF below before the argmax
CTX = mx.gpu(0)              # MXNet context: the GPU with index 0

# Label colormap for matplotlib: get_cmap() is reshaped into (N, 3) rows and scaled from
# [0, 255] to [0, 1] (presumably the PASCAL VOC palette stored as 0..255 RGB — confirm).
cmap = LinearSegmentedColormap.from_list("my_colormap", get_cmap().reshape(-1, 3) / 255.0)

# CRF post-processor; every kernel parameter is taken from the project config.
# pos_* / bi_* presumably parameterize the positional and bilateral kernels — see cores.utils.CRF.
crf = CRF(
    pos_xy_std=conf.CRF_POS_XY_STD,
    pos_w=conf.CRF_POS_W,
    bi_xy_std=conf.CRF_BI_XY_STD,
    bi_rgb_std=conf.CRF_BI_RGB_STD,
    bi_w=conf.CRF_BI_W,
)
#preprocessing
# Force 3-channel RGB. Grayscale ("L"), palette ("P") or RGBA files would otherwise break the
# per-channel mean subtraction and HWC->CHW transpose below, and the CRF's RGB input.
# The context manager closes the underlying file once the pixels have been decoded.
with Image.open(IM_PATH) as im_file:
    ori_im = im_file.convert("RGB")
im_w, im_h = ori_im.size  # PIL reports (width, height)

# Subtract the per-channel mean from conf.MEAN_RGB, broadcast over every pixel (H, W, 3).
im_arr = np.array(ori_im).astype(np.float32)
im_arr -= np.array(conf.MEAN_RGB).reshape(1, 1, 3)

# HWC -> CHW, then add a leading batch axis of size 1: (1, 3, H, W) MXNet NDArray.
im_arr = mx.nd.array(np.expand_dims(im_arr.transpose([2, 0, 1]), 0))

#Initialize network
seg_net = net_symbol.create_infer(conf.CLASS_NUM, conf.WORKSPACE)
seg_net_prefix = os.path.join("..", conf.SNAPSHOT_FOLDER, "final_fcn_resnet50")
arg_dict, aux_dict, _ = misc.load_checkpoint(seg_net_prefix, EPOCH_NUM)

# Pad the input (presumably so H and W become multiples of 8 — confirm in misc.pad_image).
# orig_size appears to be the pre-padding (H, W), used for the "orig_data" input shape.
im, orig_size = misc.pad_image(im_arr, 8)
mod = mx.mod.Module(seg_net, data_names=["data", "orig_data"], label_names=[], context=CTX)
# Bind with the shape of the *padded* tensor `im`, which is what forward() is fed below.
# Binding with the unpadded im_arr.shape mismatches whenever padding was applied. Module.forward
# then has to re-bind/reshape the executor implicitly, or raise on versions without that fallback.
mod.bind(data_shapes=[("data", im.shape), ("orig_data", (1, 3, orig_size[0], orig_size[1]))],
         for_training=False, grad_req="null")
# NOTE(review): with allow_missing=True, any parameter absent from the checkpoint is silently
# filled by the Normal() initializer, producing garbage predictions instead of an error.
# Consider allow_missing=False once it is confirmed the snapshot covers every parameter.
initializer = mx.init.Normal()
mod.init_params(initializer=initializer, arg_params=arg_dict, aux_params=aux_dict, allow_missing=True)

#do forward and get prediction
# "orig_data" is fed as zeros. Presumably only its shape matters to the network (e.g. to crop
# the padded output back to the original size) — confirm in final_fcn_resnet50.create_infer.
mod.forward(mx.io.DataBatch(data=[im, mx.nd.zeros((1, 3, orig_size[0], orig_size[1]))]))

# Copy the NCHW output to the CPU, move channels last, then drop the size-1 batch axis -> (H, W, C).
score = mx.nd.transpose(mod.get_outputs()[0].copyto(mx.cpu()), [0, 2, 3, 1])
score = mx.nd.reshape(score, (score.shape[1], score.shape[2], score.shape[3]))
# Resize to the input image size (interp=1 is bilinear), then back to channel-first (C, H, W).
up_score = mx.nd.transpose(mx.image.imresize(score, im_w, im_h, interp=1), [2, 0, 1])

if USE_CRF:
    # The log implies up_score holds non-negative (probability-like) scores. Clamp first so an
    # exact 0 (e.g. float32 underflow) yields a large finite negative value rather than -inf,
    # which could propagate NaNs through the CRF's normalization.
    final_scoremaps = mx.nd.log(mx.nd.maximum(up_score, 1e-8)).asnumpy()
    final_scoremaps = crf.inference(np.array(ori_im), final_scoremaps)
else:
    final_scoremaps = up_score.asnumpy()
# Class index per pixel: argmax over the leading (class) axis -> (H, W) label map.
pred_label = final_scoremaps.argmax(0)

#show results
# Input image on the left, predicted label map on the right.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(121)
ax.imshow(ori_im)
ax.set_title("Input image")
ax.axis("off")
ax = fig.add_subplot(122)
# vmin/vmax span 0..255 so label k maps onto entry k of the colormap
# (assumes the palette from get_cmap() has 256 entries indexed by class id — confirm).
ax.matshow(pred_label, vmin=0, vmax=255, cmap=cmap)
ax.set_title("Prediction (CRF)" if USE_CRF else "Prediction")
ax.axis("off")
# Render explicitly instead of leaving the AxesImage repr as the cell's output.
plt.show()

