In [6]:
import os
import tarfile
import cv2
import numpy as np
import tensorflow as tf
import time

In [2]:
class DeepLabModel(object):
    """
    Loads a frozen DeepLab graph from a tar archive and runs inference with it.

    Attributes set by ``__init__``:
        graph: the ``tf.Graph`` the frozen graph definition was imported into.
        sess: a ``tf.Session`` bound to ``graph``; it stays open for the
            lifetime of the object.
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """
        Load the pretrained model stored in ``tarball_path``.

        The first readable archive member whose base name contains
        ``FROZEN_GRAPH_NAME`` is parsed as a serialized ``GraphDef``.

        Args:
            tarball_path: path to the tar archive holding the frozen graph.

        Raises:
            RuntimeError: if the archive has no readable frozen-graph member.
        """
        graph_def = None
        # The context manager closes the archive even when parsing fails.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
                    continue
                file_handle = tar_file.extractfile(tar_info)
                if file_handle is None:
                    # extractfile() yields None for members that are not
                    # regular files or links (e.g. a directory with a
                    # matching name); keep searching.
                    continue
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break

        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        # Build the graph and session only once the archive is known to be valid.
        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def run(self, image):
        """
        Run the model on a single image.

        Args:
            image: array-like image; it is converted with ``np.asarray`` and
                fed as a batch of one to ``INPUT_TENSOR_NAME``.

        Returns:
            The first (only) element of the batch fetched from
            ``OUTPUT_TENSOR_NAME``.
        """
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(image)]})
        return batch_seg_map[0]

    
#===============================================================================================================
def create_pascal_label_colormap():
    """
    Build the 256-entry PASCAL VOC label colormap.

    Each label index is consumed three bits at a time: bit ``3*k + c`` of the
    label becomes bit ``7 - k`` of colour channel ``c``.

    Returns:
        An integer array of shape (256, 3); row ``i`` is the colour of label ``i``.
    """
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)

    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        # Drop the three bits just used only after all three channels have
        # read them; shifting inside the channel loop would make the green
        # and blue channels read the wrong bits.
        ind >>= 3
    return colormap

def label_to_color_image(label):
    """
    Colour a 2-D label array using the table from create_pascal_label_colormap().

    Args:
        label: 2-D array of integer labels.

    Returns:
        The colormap rows selected by ``label`` (one colour triple per element).

    Raises:
        ValueError: if ``label`` is not 2-D, or its largest value is not a
            valid row index of the colormap.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')

    palette = create_pascal_label_colormap()
    num_colors = len(palette)

    if np.max(label) >= num_colors:
        raise ValueError('label value too large.')

    color_image = palette[label]
    return color_image

def load_model(model_path='./deeplabmodels/deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz'):
    """
    Construct a DeepLabModel from a model archive.

    Args:
        model_path: path to the tar archive passed to ``DeepLabModel``.
            Defaults to the archive this notebook originally hard-coded, so
            calling ``load_model()`` with no arguments behaves as before.

    Returns:
        The constructed ``DeepLabModel`` instance.
    """
    return DeepLabModel(model_path)


In [26]:
# File name of the checkpoint to restore. Other names previously used here:
#   "composition_vii.model", "cubist.model", "feathers.model", "la_muse.model",
#   "the_scream.model", "udnie.model", "wave.model"
title = "mosaic.model"


class render():
    """
    Wrapper around a saved TensorFlow meta-graph and checkpoint (the
    style-transfer model used by this notebook).

    The default ``modeldir`` is computed from ``title`` when the class is
    defined, so changing ``title`` afterwards has no effect until this class
    definition is executed again.
    """

    def __init__(self, modeldir="./models/" + title,
                 archdir="./models/model.meta"):
        """Remember the checkpoint path (``modeldir``) and meta-graph path (``archdir``)."""
        self.arch = archdir
        self.model = modeldir

    def re_store(self, session):
        """
        Import the meta-graph, restore the checkpoint into ``session`` and
        look up the model's input/output tensors.

        The tensors are the first entries of the graph collections named
        "inputs" and "output".
        """
        meta_saver = tf.train.import_meta_graph(self.arch, clear_devices=True)
        self.saver = meta_saver
        meta_saver.restore(session, self.model)
        self.inputs = tf.get_collection("inputs")[0]
        self.outputs = tf.get_collection("output")[0]

    def run(self, input_image):
        """
        Evaluate the output tensor for one image.

        ``input_image`` is cast to float32 and given a leading batch axis
        before being fed. The evaluated output is clipped to [0, 255], cast to
        uint8 and the batch axis is removed again. The fed batch and the final
        result are also kept on ``self.image`` / ``self.result``.

        ``Tensor.eval`` is used, so a default session must be active.
        """
        batch = np.expand_dims(input_image.astype(np.float32), axis=0)
        self.image = batch

        raw_output = self.outputs.eval({self.inputs: batch})
        as_bytes = np.clip(raw_output, 0.0, 255.0).astype(np.uint8)
        self.result = np.squeeze(as_bytes, 0)

        return self.result


In [27]:
# Build the DeepLab segmentation model; load_model() constructs a DeepLabModel
# from the archive path it defines.
model = load_model()
# Style-transfer wrapper using render's default checkpoint and meta-graph paths.
render_test = render()
# Target length, in pixels, of the longer frame side; the video cell derives
# its resize ratio from this value.
INPUT_SIZE = 513

In [28]:
# Open the source video and fail early with a clear error if it is unusable.
video_src = './video/source/test.flv'
vc = cv2.VideoCapture(video_src)
if not vc.isOpened():
    vc.release()
    raise IOError('Cannot open input video: ' + video_src)

# Keep the frame rate as a float: truncating e.g. 29.97 to 29 would change the
# playback speed of the output video.
fps = vc.get(cv2.CAP_PROP_FPS)

rval, frame = vc.read()
if not rval or frame is None:
    vc.release()
    raise IOError('Cannot read a frame from input video: ' + video_src)

# Scale so that the longer side becomes INPUT_SIZE pixels, keeping the aspect ratio.
height, width = frame.shape[:2]
resize_ratio = 1.0 * INPUT_SIZE / max(width, height)
image_size = (int(resize_ratio * width), int(resize_ratio * height))  # (width, height)

# Output video path and codec (cv2.VideoWriter_fourcc is the OpenCV 3+ spelling).
video_dir = './video/output/test.avi'
os.makedirs(os.path.dirname(video_dir), exist_ok=True)
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
videoWriter = cv2.VideoWriter(video_dir, fourcc, fps, image_size)
if not videoWriter.isOpened():
    vc.release()
    videoWriter.release()
    raise IOError('Cannot open output video for writing: ' + video_dir)

#============================================================================================

tf.reset_default_graph()
try:
    with tf.Session() as sess:
        render_test.re_store(sess)  # load the style-transfer model into this session
        time_1 = time.time()  # start timing once the model has been restored

        while rval:
            # Resize the frame to the working resolution.
            frame = cv2.resize(frame, image_size)

            # Segmentation map. OpenCV decodes frames as BGR, whereas the
            # DeepLab frozen graph expects RGB input, so convert before inference.
            seg_map = model.run(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Coloured segmentation image; not used by the compositing below,
            # kept so the name stays defined for any later inspection.
            seg_image = label_to_color_image(seg_map).astype(np.uint8)

            # Style transfer on the whole frame.
            # NOTE(review): the frame is fed in BGR order, as before; confirm
            # which channel order the style model expects.
            styback = render_test.run(frame)
            # Match the frame size again in case the network output differs.
            styback = cv2.resize(styback, (frame.shape[1], frame.shape[0]))

            # Composite: keep original pixels wherever the label is non-zero,
            # use stylized pixels where the label is 0 (presumably background).
            foreground = seg_map > 0
            result = np.where(foreground[..., np.newaxis], frame, styback)

            # Append to the output video and fetch the next frame.
            videoWriter.write(result)
            rval, frame = vc.read()
finally:
    # Always release the capture and finalize the output file, even on error.
    vc.release()
    videoWriter.release()
#=============================================================================================

time_2 = time.time()  # stop timing
print("Time used: ", time_2 - time_1)

'model_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from ./models/wave.model
Time used:  127.42070317268372
