In [None]:
from __future__ import print_function

%matplotlib inline

import os
import sys
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.contrib.framework.python.ops import variables

from collections import OrderedDict

import numpy as np
import scipy.misc
from input import *
from model import *
from train import *
import matplotlib.pyplot as plt
tf.reset_default_graph()

In [None]:
# Training input pipeline. `filename` is the feed key for the TFRecord
# path(s) when `iterator.initializer` is run (see the preview cell below).
image_shape = (160, 576)
iterator, filename = get_train_inputs(
    batch_size=100,
    repeat=True,
    num_classes=2,
    image_shape=image_shape,
)

In [None]:
_show = False
if _show:
    with tf.Session() as sess:
        sess.run(iterator.initializer, feed_dict={filename: ['data/kitti_segmentation.tfrecord']})
        next_element = iterator.get_next()
        i = 1
        while i < 10:
            i += 1
            print("*"*10)
            image, label = sess.run(next_element)
            print(image.shape, label.shape)
            plt.imshow(np.uint8(image))
            plt.imshow(label[:,:,0], cmap='jet', alpha=0.5)
            plt.show()

In [None]:
# record_iterator = tf.python_io.tf_record_iterator(path='data/kitti_segmentation.tfrecord')
# string_record = next(record_iterator)
# example = tf.train.Example()
# example.ParseFromString(string_record)

## Model

In [None]:
# Build the VGG-16 encoder (training mode) on the next element of the input
# iterator. `end_points` maps layer names to tensors and is consumed by the
# decoder below; `restore_fn` is handed to trainer.train.
# NOTE(review): confirm encoder.build returns exactly (restore_fn, end_points).
encoder = SlimModelEncoder(name="vgg_16", num_classes=2, is_training=True)
image, label = iterator.get_next()
restore_fn, end_points = encoder.build(image=image, image_shape=image_shape)

In [None]:
# Inspect which encoder layer names are available to tap for the decoder.
end_points.keys()

In [None]:
# Display the regularization-loss tensors registered in the default graph.
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

In [None]:
# Debugging aid: fetch one L2-regularizer tensor by name to confirm it exists.
# `tensor` is not read anywhere else in the visible notebook.
tensor = tf.get_default_graph().get_tensor_by_name('vgg_16/conv1/conv1_1/kernel/Regularizer/l2_regularizer:0')

In [None]:
# Show the encoder feature-map tensors the decoder may tap, one per line.
_tap_names = (
    'vgg_16/conv3/conv3_3',
    'vgg_16/conv4/conv4_3',
    'vgg_16/conv5/conv5_3',
    'vgg_16/fc6',
    'vgg_16/fc7',
)
print(*[end_points[_name] for _name in _tap_names], sep='\n')

In [None]:
# FCN decoder over the encoder's end_points, 2 output classes, built under
# the 'decoder' scope.
decoder = FCNDecoder(end_points, nb_classes=2, scope='decoder')

In [None]:
# Encoder taps to fuse in the decoder, in insertion order, each mapped to an
# (h, w) pair passed to FCNDecoder.build — presumably upsampling factors;
# TODO confirm against FCNDecoder.
tensors_to_connect = OrderedDict([
    ("vgg_16/fc7", (2, 2)),
    ("vgg_16/conv5/conv5_3", (2, 2)),
    ("vgg_16/conv4/conv4_3", (8, 8)),
])

In [None]:
# Wire the selected encoder taps into the decoder; `net` is the decoder output
# later passed to trainer.build as `predictions`.
net = decoder.build(tensors_to_connect)

In [None]:
# Sanity check: the spatial dims of the 'logit' tensor (dims 1..2, presumably
# H and W of an NHWC tensor) must equal the input image size.
_logit = tf.get_default_graph().get_tensor_by_name('logit:0')
_logit_hw = tuple(_logit.get_shape().as_list()[1:3])
assert _logit_hw == image_shape

## Train

In [None]:
# Trainer for 2 classes using Adam with a very small learning rate (1e-6);
# presumably small because the VGG weights are pretrained — TODO confirm.
trainer = Trainer(nb_classes=2, optimizer=tf.train.AdamOptimizer, learning_rate=1e-6)

In [None]:
# Attach the training graph comparing decoder output `net` to ground-truth
# `label` (presumably loss + optimizer step — see Trainer).
trainer.build(predictions=net, labels=label)

In [None]:
# Run training over the TFRecord input; `restore_fn` comes from encoder.build
# (presumably loads pretrained encoder weights — TODO confirm).
trainer.train(iterator, restore_fn=restore_fn, filename=['data/kitti_segmentation.tfrecord'])

## Predict

In [None]:
# Rebuild the encoder/decoder graph in inference mode in a fresh graph,
# restore trained weights, and fetch one batch of predictions with the
# matching ground-truth labels.
TRAIN_DIR = "model_checkpoints/"
with tf.Graph().as_default() as graph:
    image_shape = (160,576)
    iterator, filename = get_train_inputs(batch_size=100,
                                          repeat=False,
                                          num_classes=2,
                                          image_shape=image_shape)
    encoder = SlimModelEncoder(name="vgg_16", num_classes=2, is_training=False)
    image, label = iterator.get_next()
    # Only end_points is needed for inference; it is the last element of
    # what build() returns. NOTE(review): confirm build()'s return arity.
    build_outputs = encoder.build(image=image, image_shape=image_shape)
    end_points = build_outputs[-1]
    # Same decoder taps as the training graph so the checkpoint variables match.
    decoder = FCNDecoder(end_points, nb_classes=2, scope='decoder')
    tensors_to_connect = OrderedDict()
    tensors_to_connect["vgg_16/fc7"] = (2,2)
    tensors_to_connect['vgg_16/conv5/conv5_3'] = (2,2)
    tensors_to_connect['vgg_16/conv4/conv4_3'] = (8,8)
    net = decoder.build(tensors_to_connect)
    saver = tf.train.Saver()
    checkpoint_path = os.path.join(TRAIN_DIR, "model.ckpt-10000")
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        # `filename` is the feed key returned by get_train_inputs; use it
        # directly rather than a hard-coded graph tensor name.
        sess.run(iterator.initializer, feed_dict={filename: ['data/kitti_segmentation.tfrecord']})
        print(net)
        # Fetch predictions and labels in ONE run: every sess.run that touches
        # iterator.get_next() advances the iterator, so two separate runs
        # would pair predictions with labels from a different batch.
        out, label = sess.run([net, label])
        # Drop size-1 dims (same effect as the previous tf.squeeze on `net`).
        out = np.squeeze(out)

In [None]:
# Ground-truth class map: argmax over the trailing class axis. axis=-1 is
# correct whether or not `label` still carries a leading batch dimension.
label_classes = np.argmax(label, axis=-1)
if label_classes.ndim == 3:
    # Batched labels (N, H, W): imshow needs 2-D, so show the first example.
    label_classes = label_classes[0]
plt.imshow(label_classes);

In [None]:
# Predicted class map: argmax over the trailing class axis of the logits.
# axis=-1 is correct whether or not `out` still carries a batch dimension.
pred_classes = np.argmax(out, axis=-1)
if pred_classes.ndim == 3:
    # Batched predictions (N, H, W): imshow needs 2-D, so show the first example.
    pred_classes = pred_classes[0]
plt.imshow(pred_classes);