In [None]:
import tensorflow as tf
import os 
from PIL import Image
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from nets import nets_factory
from preprocessing import preprocess_image
from IPython.display import clear_output
slim = tf.contrib.slim

# Number of examples per split (presumably the counts written into these TFRecord files).
SPLITS_TO_SIZES = {'train': 12936, 'query': 3800, 'test': 10000}
# Human-readable description of the fields stored in each TFRecord example
# (the same keys that input_fn parses).
_ITEMS_TO_DESCRIPTIONS = {
    'img_raw': 'A [128 x 64 x 3] color image.',
    'label': 'An integer person-identity label.',
    'cam': 'An integer camera index.',
}
# Number of identity classes in the training split.
_NUM_CLASSES = 751


# Per-architecture name of the tensor used as the image embedding.
# NOTE(review): only 'resnet_v1_50' is used in this notebook. Entries without a ':0'
# suffix are not tensor names and would fail in get_tensor_by_name -- confirm before use.
feature_map = {
    'vgg_16': 'fc7_end_points',
    'inception_v3': 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0',
    'inception_v4': 'PreLogitsFlatten',
    'inception_resnet_v2': 'PreLogitsFlatten',
    'resnet_v1_50': 'resnet_v1_50/pool5:0',
    'resnet_v2_50': 'resnet_v2_50/pool5:0',
    'mobilenet_v1': 'AvgPool_1a',
}

# Model helper functions: input pipeline and variable utilities

In [None]:

def input_fn(record_file, is_training=False, batch_size=32, shuffle_buffer_size=10000):
    """Build a one-shot input pipeline over a Market-1501 TFRecord file.

    Each record is parsed for a raw uint8 image (``img_raw``, reshaped to
    [128, 64, 3]) and int64 ``label`` and ``cam`` fields.

    Args:
        record_file: path of the TFRecord file to read.
        is_training: if True, repeat the data indefinitely, shuffle it and run
            ``preprocess_image`` in training mode.
        batch_size: number of examples per batch (must be positive).
        shuffle_buffer_size: size of the shuffle buffer (used only when training).

    Returns:
        A tuple ``(imgs, labels, cams)`` of batched tensors: float32 images as
        produced by ``preprocess_image(..., 224, ...)``, and int32 labels and cams.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))

    dataset = tf.contrib.data.TFRecordDataset([record_file])

    def _parser(record):
        """Decode one serialized tf.Example into (image, label, cam)."""
        keys_to_features = {
            "img_raw": tf.FixedLenFeature((), tf.string, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64, default_value=0),
            "cam": tf.FixedLenFeature((), tf.int64, default_value=0),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # The image is stored as raw bytes; restore its [128, 64, 3] layout.
        image = tf.decode_raw(parsed["img_raw"], tf.uint8)
        image = tf.reshape(image, [128, 64, 3])
        image = preprocess_image(image, 224, is_training=is_training)
        image = tf.cast(image, tf.float32)
        label = tf.cast(parsed["label"], tf.int32)
        cam = tf.cast(parsed["cam"], tf.int32)
        return image, label, cam

    if is_training:
        # NOTE: shuffling after repeat() lets the buffer mix examples of adjacent epochs.
        dataset = dataset.repeat()
        dataset = dataset.shuffle(shuffle_buffer_size)
    # Map every record to an (image, label, cam) triple, then batch the triples.
    dataset = dataset.map(_parser)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()

    # Each element is a batch of images, a batch of labels and a batch of camera ids.
    imgs, labels, cams = iterator.get_next()
    return imgs, labels, cams


def get_restore_variabels(checkpoint_exclude_scopes=None, variables=None):
    """Return the model variables to restore from a checkpoint.

    A variable is dropped when its op name starts with any excluded scope. The
    default scope is the ``resnet_v1_50/logit`` prefix, presumably the classifier
    head, so that a head of a different size can be trained from scratch.

    Args:
        checkpoint_exclude_scopes: iterable of scope prefixes to skip; surrounding
            whitespace is stripped. Defaults to ``["resnet_v1_50/logit"]``.
        variables: variables to filter. Defaults to ``slim.get_model_variables()``.

    Returns:
        List of the variables not under any excluded scope, in their original order.
    """
    if checkpoint_exclude_scopes is None:
        checkpoint_exclude_scopes = ["resnet_v1_50/logit"]
    if variables is None:
        variables = slim.get_model_variables()

    exclusions = tuple(scope.strip() for scope in checkpoint_exclude_scopes)
    # str.startswith accepts a tuple of prefixes; an empty tuple excludes nothing.
    return [var for var in variables if not var.op.name.startswith(exclusions)]


def initialize_uninitialized_vars(sess):
    """Initialize only the global variables that are not yet initialized.

    Variables that already hold a value (e.g. ones loaded by ``Saver.restore``)
    are left untouched.

    Args:
        sess: the ``tf.Session`` whose graph holds the variables.
    """
    candidates = tf.global_variables()
    # One boolean per variable: True when the variable still needs initializing.
    needs_init_flags = sess.run([~tf.is_variable_initialized(v) for v in candidates])
    pending = [v for v, needs_init in zip(candidates, needs_init_flags) if needs_init]
    if not pending:
        return
    sess.run(tf.variables_initializer(pending))

        
 

# Train the model

In [None]:
# Resume fine-tuning resnet_v1_50 on the Market-1501 training split from a saved
# checkpoint, logging the windowed mean loss and checkpointing periodically.
train_record_file = '/tmp/Market-1501/market-1501_train.tfrecord'
train_ckpt_dir = '/tmp/checkpoints/market-1501/resnet_v1_50'
resume_step = 16000   # checkpoint step to resume from
last_step = 18000     # last training step (inclusive)
log_every = 30        # report the mean loss every `log_every` steps
save_every = 1000     # write a checkpoint every `save_every` steps

with tf.Graph().as_default():
    ##### get train image
    images, labels, cam = input_fn(train_record_file, True)
    labels = tf.contrib.layers.one_hot_encoding(labels, _NUM_CLASSES)

    #### build the network
    network_fn = nets_factory.get_network_fn('resnet_v1_50', num_classes=_NUM_CLASSES,
                                             is_training=True, weight_decay=0.0005)

    ### inference and define the loss
    logits, _ = network_fn(images)
    # softmax_cross_entropy registers itself in the LOSSES collection. get_total_loss()
    # also adds the regularization losses created for `weight_decay`; optimizing only
    # the cross-entropy would silently ignore the weight decay.
    tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)
    total_loss = tf.losses.get_total_loss()

    ### get train_op
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
    # minimize() does not run UPDATE_OPS (e.g. batch-norm moving mean/variance updates);
    # tie them to the train step so the statistics used at inference time are learned.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # NOTE(review): no global step is created in this graph, so get_global_step()
        # presumably returns None; the step number lives only in `i` below.
        train_op = optimizer.minimize(total_loss, global_step=tf.train.get_global_step())

    # Restores every global variable. To start from an ImageNet checkpoint instead,
    # restore only get_restore_variabels() with a separate tf.train.Saver.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        #### initialize all the variables, then overwrite them from the checkpoint
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, '%s/model.ckpt-%d' % (train_ckpt_dir, resume_step))

        ### start training
        window_loss, window_steps = 0.0, 0
        for i in range(resume_step + 1, last_step + 1):
            try:
                _, loss = sess.run([train_op, total_loss])
            except tf.errors.OutOfRangeError:
                break
            window_loss += loss
            window_steps += 1
            print("step: %s , loss: %s" % (i, loss))
            if i % log_every == 0:
                clear_output()
                # Average over the steps actually accumulated: the first window after
                # resuming is shorter than `log_every` unless resume_step is a multiple of it.
                print('last %d step mean loss:%f' % (window_steps, window_loss / window_steps))
                window_loss, window_steps = 0.0, 0
            if i % save_every == 0:
                saver.save(sess, '%s/model.ckpt' % train_ckpt_dir, global_step=i)


In [None]:
checkpoint_dir = '/tmp/checkpoints/market-1501/resnet_v1_50'
# Show the path of the newest checkpoint in the training directory.
tf.train.latest_checkpoint(checkpoint_dir)

# Inspect the graph and operations of a saved checkpoint

In [None]:
# Re-open a saved training checkpoint (graph definition + weights) for interactive
# inspection; `sess` and `feature` stay open for the following cells.
inspect_ckpt = '/tmp/checkpoints/market-1501/resnet_v1_50/model.ckpt-10000'
tf.reset_default_graph()
# Use one path for both the .meta graph and the weights so they cannot drift apart.
saver = tf.train.import_meta_graph(inspect_ckpt + '.meta')

feature_name = feature_map['resnet_v1_50']
feature = tf.get_default_graph().get_tensor_by_name(name=feature_name)
# Squeeze only the spatial axes so a batch of size 1 keeps its batch axis.
# Assumes pool5 is (batch, 1, 1, channels) -- TODO confirm for other architectures.
feature = tf.squeeze(feature, axis=[1, 2])
sess = tf.Session()
saver.restore(sess, inspect_ckpt)

In [None]:
# Keep a list of every op in the current graph (handy for looking up tensor names),
# then push one batch through the feature tensor and show the resulting shape.
ops = tf.get_default_graph().get_operations()
feature_batch = sess.run(feature)
print(feature_batch.shape)

In [None]:
# Debug: feed replacement values into constant tensors of the imported input pipeline
# to pull one batch of labels from the query split instead of the training split.
# NOTE(review): 'filenames', 'count', 'buffer_size', 'seed' and 'seed2' are op names
# auto-generated for the training pipeline -- verify them against `ops` if the graph changes.
img_tensor = sess.graph.get_tensor_by_name('IteratorGetNext:0')
filenames = sess.graph.get_tensor_by_name('filenames:0')
count = sess.graph.get_tensor_by_name('count:0')
# Presumably the shuffle buffer size of the training pipeline -- not a batch size.
shuffle_buffer_size = sess.graph.get_tensor_by_name('buffer_size:0')
# Presumably the labels cast to int64 inside one_hot_encoding -- TODO confirm.
label = sess.graph.get_tensor_by_name('OneHotEncoding/ToInt64:0')
query_labels = sess.run(label, feed_dict={"filenames:0": ['/tmp/Market-1501/market-1501_query.tfrecord'],
                                          "count:0": 1, "seed:0": 0, "seed2:0": 0})
print(query_labels)

# Extract features

In [None]:
# Extract one embedding per image of `split_name` from the checkpoint at `check_step`
# and save features, identity labels and camera ids to an .npz file.
features = []
classes = []
cameras = []
split_name = 'test'
check_step = 18000
record_file = '/tmp/Market-1501/market-1501_%s.tfrecord' % split_name
with tf.Graph().as_default():
    images, labels, cams = input_fn(record_file)
    # Build the network here in inference mode (is_training defaults to False) so batch
    # norm uses its moving statistics. Do not reuse the training cell's `network_fn`,
    # which was built with is_training=True.
    network_fn = nets_factory.get_network_fn('resnet_v1_50', num_classes=_NUM_CLASSES)
    logits, _ = network_fn(images)  # called to build the graph; logits are not used

    feature_name = feature_map['resnet_v1_50']
    feature = tf.get_default_graph().get_tensor_by_name(name=feature_name)
    # Squeeze only the spatial axes so a final batch of size 1 keeps its batch axis.
    # Assumes pool5 is (batch, 1, 1, channels) -- TODO confirm for other architectures.
    feature = tf.squeeze(feature, axis=[1, 2])
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, '/tmp/checkpoints/market-1501/resnet_v1_50/model.ckpt-%s' % check_step)
        initialize_uninitialized_vars(sess)
        # Drain the one-shot iterator; it raises OutOfRangeError after the last batch.
        while True:
            try:
                np_feature, np_label, np_cam = sess.run([feature, labels, cams])
            except tf.errors.OutOfRangeError:
                break
            assert np_feature.shape[0] == np_label.shape[0]
            features.append(np_feature)
            classes.append(np_label)
            cameras.append(np_cam)

if not features:
    raise ValueError('no examples were read from %s' % record_file)
features = np.concatenate(features)
classes = np.concatenate(classes)
cameras = np.concatenate(cameras)
feature_dir = '/tmp/Market-1501/feature'
if not os.path.isdir(feature_dir):
    os.makedirs(feature_dir)
np.savez(os.path.join(feature_dir, split_name), feature=features, label=classes, cam=cameras)

In [None]:
# Replace the default graph with a fresh, empty one.
# NOTE(review): a `sess` opened by an earlier cell still refers to its old graph.
tf.reset_default_graph()

In [None]:
# Compare the model variables recorded in a training checkpoint's meta graph (g1)
# with those of a freshly built network (g2).
# NOTE: uses `record_file` from the feature-extraction cell.
g1 = tf.Graph()
g2 = tf.Graph()
with g1.as_default():
    # Importing the meta graph also restores its collections, including model variables.
    tf.train.import_meta_graph('/tmp/checkpoints/market-1501/resnet_v1_50/model.ckpt-16000.meta')
    v1 = slim.get_model_variables()
with g2.as_default():
    images, labels, cams = input_fn(record_file)
    network_fn = nets_factory.get_network_fn('resnet_v1_50', num_classes=_NUM_CLASSES)
    logits, _ = network_fn(images)
    v2 = slim.get_model_variables()

# zip() stops at the shorter list, so report a length mismatch explicitly.
if len(v1) != len(v2):
    print('variable count differs: %d (checkpoint) vs %d (rebuilt)' % (len(v1), len(v2)))
for x, y in zip(v1, v2):
    print(x, y, '\n\n\n')
    

In [None]:
# Number of model variables collected from the checkpoint's meta graph (g1).
len(v1)