In [1]:
import tensorflow as tf
import numpy as np
import cv2
from IPython.display import clear_output, Image, display, HTML


In [2]:
# Load the input image as a single-channel grayscale array (imread flag 0).
img = cv2.imread('IMG_2.jpg', 0)
# cv2.imread does not raise on a missing/unreadable file; it returns None.
# Fail here with a clear message instead of an AttributeError on `img.shape`.
if img is None:
    raise IOError("Could not read image 'IMG_2.jpg'")
print(img.shape)
# Centre and scale the pixel values: (v - 127.5) / 128.
# For 8-bit input this maps 0..255 to roughly [-1, 1]; the result is float64.
# (imread already returns a NumPy array, so no extra np.array() copy is needed.)
img = (img - 127.5) / 128

(684, 1024)


In [3]:
#Visualize the graph
  
def strip_consts(graph_def, max_const_size=32):
    """Return a copy of `graph_def` with large constant payloads removed.

    Every node of `graph_def` is copied into a fresh `tf.GraphDef`. For nodes
    whose op is 'Const', if the serialized `tensor_content` is longer than
    `max_const_size` bytes it is replaced by a short placeholder of the form
    ``<stripped N bytes>``. The input `graph_def` is not modified.

    Args:
        graph_def: GraphDef-like object exposing an iterable `node` field.
        max_const_size: largest `tensor_content` length (in bytes) to keep.

    Returns:
        The new, stripped `tf.GraphDef`.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # `tensor_content` is a protobuf bytes field: assigning a text
                # string raises TypeError on Python 3, so encode the placeholder.
                placeholder = "<stripped %d bytes>" % size
                tensor.tensor_content = placeholder.encode("utf-8")
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Display a TensorFlow graph inline in the notebook.

    Args:
        graph_def: a GraphDef, or any object with an `as_graph_def()` method
            (in which case that method is called to obtain the GraphDef).
        max_const_size: forwarded to `strip_consts`; constants whose payload
            is larger than this many bytes are replaced by a placeholder.

    The stripped graph is embedded as text in an HTML snippet that loads the
    `tf-graph-basic` component, and that snippet is shown inside an iframe.
    """
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    slim_def = strip_consts(graph_def, max_const_size=max_const_size)

    widget_template = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """
    # Text form of the graph, quoted so it can be dropped into the script.
    pbtxt_literal = repr(str(slim_def))
    # A random suffix gives each rendered graph its own element id.
    element_id = 'graph' + str(np.random.rand())
    widget_html = widget_template.format(data=pbtxt_literal, id=element_id)

    frame_template = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """
    # The widget markup lives inside a double-quoted srcdoc attribute, so its
    # own double quotes have to be entity-escaped.
    escaped_widget = widget_html.replace('"', '&quot;')
    display(HTML(frame_template.format(escaped_widget)))

In [9]:
def conv2d(x, w):
    """Convolve `x` with filter `w` using unit strides and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_strides, padding='SAME')

def max_pool_2x2(x):
    """Max-pool `x` with a 2x2 window, stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')

def _conv_relu(x, tag, kernel, in_channels, out_channels):
    """Conv + bias + ReLU layer whose variables are named w_conv<tag> / b_conv<tag>."""
    w = tf.get_variable('w_conv' + tag, [kernel, kernel, in_channels, out_channels])
    b = tf.get_variable('b_conv' + tag, [out_channels])
    return tf.nn.relu(conv2d(x, w) + b)


def _column(x, index, first_kernel, kernel, width):
    """One column: conv -> pool -> conv -> pool -> conv -> conv.

    Channel counts go 1 -> width -> 2*width -> width -> width/2. The first
    convolution uses `first_kernel`, the other three use `kernel`. Variables
    are suffixed with `_<index>` (e.g. 'w_conv3_2').
    """
    h = _conv_relu(x, '1_%d' % index, first_kernel, 1, width)
    h = max_pool_2x2(h)
    h = _conv_relu(h, '2_%d' % index, kernel, width, 2 * width)
    h = max_pool_2x2(h)
    h = _conv_relu(h, '3_%d' % index, kernel, 2 * width, width)
    return _conv_relu(h, '4_%d' % index, kernel, width, width // 2)


def inf(x):
    """Build the three-column convolutional network and return its output tensor.

    Three parallel columns ("s", "m" and "l" nets) process the same input with
    increasingly large kernels; their outputs are concatenated along axis 3
    and reduced to a single channel by a 1x1 convolution (no activation).

    Args:
        x: input tensor. The first-layer filters have one input channel and
            the columns are merged on axis 3, so this presumably expects a
            4-D, channels-last tensor with one channel -- TODO confirm.

    Returns:
        The single-channel output tensor of the final 1x1 convolution. Each
        column contains two stride-2 poolings.

    Variable names ('w_conv1_1' ... 'b_conv5'), their shapes, and the order in
    which layers are created are kept exactly as before. Only the invalid
    `self.` references (this is a plain function, so `self` was undefined and
    every call raised NameError) are replaced by the module-level helpers.
    """
    # s net: 5x5 first kernel, 3x3 afterwards, 24 base channels -> 12 out.
    h_conv4_1 = _column(x, 1, first_kernel=5, kernel=3, width=24)
    # m net: 7x7 first kernel, 5x5 afterwards, 20 base channels -> 10 out.
    h_conv4_2 = _column(x, 2, first_kernel=7, kernel=5, width=20)
    # l net: 9x9 first kernel, 7x7 afterwards, 16 base channels -> 8 out.
    h_conv4_3 = _column(x, 3, first_kernel=9, kernel=7, width=16)

    # merge: 12 + 10 + 8 = 30 channels, fused by a 1x1 convolution.
    h_conv4_merge = tf.concat([h_conv4_1, h_conv4_2, h_conv4_3], 3)

    w_conv5 = tf.get_variable('w_conv5', [1, 1, 30, 1])
    b_conv5 = tf.get_variable('b_conv5', [1])
    y_pre = conv2d(h_conv4_merge, w_conv5) + b_conv5

    return y_pre

In [13]:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# Restore the saved model and run it on the preprocessed image.
#
# The meta graph is imported into its own tf.Graph rather than the shared
# default graph, so this cell can be re-run, and other cells can build ops,
# without affecting the node names looked up below.
#
# To list what the checkpoint contains while debugging:
#   print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'),
#                                    all_tensors=True, tensor_name='')
with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
    new_saver = tf.train.import_meta_graph('model.ckpt.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))

    # Output and input tensors are looked up by their names in the saved graph.
    op_to_restore = graph.get_tensor_by_name("add_12:0")
    print(op_to_restore)
    x = graph.get_tensor_by_name('Placeholder:0')

    # Add batch and channel axes: (H, W) -> (1, H, W, 1), then cast to float32
    # to match the dtype of the restored output tensor printed above.
    x_in = np.reshape(img, (1, img.shape[0], img.shape[1], 1))
    print(x_in.shape, x_in.dtype)
    x_in = np.float32(x_in)
    print(x_in.shape, x_in.dtype)

    y_pred = sess.run(op_to_restore, feed_dict={x: x_in})

INFO:tensorflow:Restoring parameters from ./model.ckpt
Tensor("add_12:0", shape=(?, ?, ?, 1), dtype=float32)
(1, 684, 1024, 1) float64
(1, 684, 1024, 1) float32


In [15]:
# Sum every element of the network output computed in the previous cell.
# NOTE(review): presumably this total is the model's overall estimate for the
# image -- confirm against how the model was trained.
print(np.sum(y_pred))

1246.12


In [None]:
#sess.graph.get_operations()