In [1]:
import tensorflow as tf
from tensorflow.contrib.slim import nets
slim = tf.contrib.slim
import numpy as np

# This notebook uses the Food-101 dataset, available from https://www.vision.ee.ethz.ch/datasets_extra/food-101 .
# The pre-trained checkpoint resnet_v1_50.ckpt can be downloaded from https://github.com/tensorflow/models/tree/master/research/slim .

  from ._conv import register_converters as _register_converters


In [None]:
class Res50(object):
    """Fine-tune a TF-Slim ResNet-v1-50 as a 101-class Food-101 classifier.

    The constructor resets the default graph and builds the model in it:
    placeholders, the ResNet backbone, dropout, a dense 101-way output layer,
    the loss, a gradient-descent train step, an accuracy op and a Saver.
    `train()` fine-tunes starting from the pre-trained checkpoint and saves
    the result; `test()` restores that saved model and reports the mean
    accuracy over the test list.
    """

    # Locations used by train() and test(), relative to the working directory.
    IMAGE_PATH_TEMPLATE = 'food-101/images/%s.jpg'
    CLASS_FILE = r'food-101/meta/classes.txt'
    TRAIN_FILE = r'food-101/meta/train.txt'
    TEST_FILE = r'food-101/meta/test.txt'
    PRETRAINED_CKPT = r'pre_trained/resnet_v1_50.ckpt'
    CHECKPOINT_DIR = 'model'
    CHECKPOINT_PATH = 'model/foodModel'

    def __init__(self, lr, batch_size, iter_num):
        """Build the graph.

        Args:
            lr: learning rate passed to the gradient-descent optimizer.
            batch_size: number of images per batch produced by the input queues.
            iter_num: total number of training iterations run by `train()`.
        """
        self.lr = lr
        self.batch_size = batch_size
        self.iter_num = iter_num

        # Start from an empty graph. Re-running this cell would otherwise try
        # to create the same tensors again in the graph left over from the
        # previous run and fail with "tensor already exists"-style errors.
        tf.reset_default_graph()

        self.X = tf.placeholder(tf.float32, [None, 224, 224, 3])
        self.y = tf.placeholder(tf.float32, [None, 101])   # one-hot labels; Food-101 has 101 classes
        # NOTE: despite its name, this placeholder is used as the dropout
        # *keep* probability (fed 0.5 while training and 1.0 while evaluating).
        self.dropRate = tf.placeholder(tf.float32)
        self.isTraining = tf.placeholder(tf.bool)

        # Use the stock slim ResNet-50; num_classes=None requests no
        # classification layer, so our own head is attached below.
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, _ = nets.resnet_v1.resnet_v1_50(self.X, is_training=self.isTraining, num_classes=None)
        net = tf.reshape(net, [-1, 2048])

        # Classification head: dropout followed by a dense softmax layer.
        net = tf.nn.dropout(net, keep_prob=self.dropRate)
        logits = tf.layers.dense(net, 101, use_bias=True,
                                 kernel_initializer=tf.constant_initializer(0),
                                 bias_initializer=tf.constant_initializer(0))
        self.logits = logits
        self.loss = tf.losses.softmax_cross_entropy(onehot_labels=self.y, logits=logits)

        # Batch-norm layers register their moving-average updates in the
        # UPDATE_OPS collection instead of running them implicitly. Tie them to
        # the train step so the statistics used when isTraining is False track
        # the fine-tuned network.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train_step = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)

        # Accuracy of the current batch.
        self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(logits, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))

        # Used to save / restore the trained model.
        self.saver = tf.train.Saver()

    def read_image_label_list(self, class_file, data_file):
        """Read the image paths and integer labels described by two text files.

        Args:
            class_file: text file with one class name per line; a class's label
                is the zero-based index of its line.
            data_file: text file with one `<class_name>/<image_id>` entry per
                line. Blank lines are ignored.

        Returns:
            (img_list, label_list): image paths of the form
            'food-101/images/<entry>.jpg' and the matching integer labels.

        Raises:
            KeyError: if an entry's class name is not listed in `class_file`.
        """
        with open(class_file) as fr:
            class_dict = {line.strip(): i for i, line in enumerate(fr)}

        img_list = []
        label_list = []
        with open(data_file) as fr:
            for line in fr:
                entry = line.strip()
                if not entry:
                    # Blank (e.g. trailing) lines carry no sample.
                    continue
                img_list.append(self.IMAGE_PATH_TEMPLATE % entry)
                label_list.append(int(class_dict[entry.split('/')[0]]))
        return img_list, label_list

    def _build_input_pipeline(self, class_file, data_file):
        """Create the queue-based input ops; returns (X, Y, number of listed images)."""
        image_list, label_list = self.read_image_label_list(class_file, data_file)
        imagepaths, labels = tf.train.slice_input_producer([image_list, label_list], shuffle=True)
        image = tf.read_file(imagepaths)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.image.resize_images(image, [224, 224])
        # Scale pixel values from [0, 255] to [-1, 1].
        # NOTE(review): the pre-trained checkpoint may expect a different
        # preprocessing (e.g. per-channel mean subtraction) - confirm.
        image = (image * 1.0 / 127.5 - 1)
        label = tf.one_hot(labels, 101)
        X, Y = tf.train.batch([image, label], batch_size=self.batch_size, num_threads=2, capacity=self.batch_size*4)
        return X, Y, len(image_list)

    def read_file(self, class_file, data_file):
        """Build batched input ops for the images listed in `data_file`.

        Returns:
            (X, Y): ops yielding a batch of preprocessed 224x224 images and the
            matching one-hot labels. Queue runners must be started in the
            session before these ops are run.
        """
        X, Y, _ = self._build_input_pipeline(class_file, data_file)
        return X, Y

    def _evaluate_batch(self, sess, images_op, labels_op):
        """Fetch one batch from the given input ops and return its accuracy.

        Dropout is disabled (keep probability 1) and batch norm runs in
        inference mode, so the result does not depend on the statistics of the
        batch being evaluated.
        """
        images, labels = sess.run([images_op, labels_op])
        return sess.run(self.accuracy, feed_dict={self.X: images,
                                                  self.y: labels,
                                                  self.dropRate: 1.,
                                                  self.isTraining: False})

    def train(self):
        """Fine-tune for `iter_num` iterations and save the model.

        Weights present in the pre-trained checkpoint are loaded after the
        global initializer has run; variables missing from it keep their
        initial values. Every 100 iterations the loss and the accuracy on one
        training and one test batch are printed.
        """
        training_images, training_labels = self.read_file(self.CLASS_FILE, self.TRAIN_FILE)
        test_images, test_labels = self.read_file(self.CLASS_FILE, self.TEST_FILE)

        # Saving fails if the target directory does not exist; create it up
        # front instead of finding out after the whole training run.
        tf.gfile.MakeDirs(self.CHECKPOINT_DIR)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                variables_to_restore = slim.get_variables_to_restore()
                init_fn = slim.assign_from_checkpoint_fn(self.PRETRAINED_CKPT,
                                                         variables_to_restore,
                                                         ignore_missing_vars=True)
                init_fn(sess)

                for i in range(self.iter_num):
                    images, labels = sess.run([training_images, training_labels])

                    feed_dict = {self.dropRate: 0.5,
                                 self.X: images,
                                 self.y: labels,
                                 self.isTraining: True}

                    # Each sess.run call is like opening a tap: every op that
                    # self.loss and self.train_step depend on is executed once.
                    loss, _ = sess.run([self.loss, self.train_step],
                                       feed_dict=feed_dict)

                    if i % 100 == 0:
                        # Accuracy on a fresh training batch and on a test batch.
                        train_accuracy = self._evaluate_batch(sess, training_images, training_labels)
                        test_accuracy = self._evaluate_batch(sess, test_images, test_labels)

                        print ('iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f' % (i,loss,train_accuracy, test_accuracy))

                self.saver.save(sess, self.CHECKPOINT_PATH)   # save the model
            finally:
                # Always stop the input threads, even if training failed.
                coord.request_stop()
                coord.join(threads)

    def test(self):
        """Restore the saved model and print its mean accuracy on the test list.

        Runs as many full batches as fit into the test list; a final partial
        batch is not evaluated.
        """
        test_images, test_labels, num_examples = self._build_input_pipeline(self.CLASS_FILE, self.TEST_FILE)
        with tf.Session() as sess:
            self.saver.restore(sess, self.CHECKPOINT_PATH)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                Accuracy = []
                for i in range(int(num_examples / self.batch_size)):
                    Accuracy.append(self._evaluate_batch(sess, test_images, test_labels))

                print('==' * 15)
                print( 'Test Accuracy: ', np.mean(np.array(Accuracy))   )
            finally:
                # Always stop the input threads, even if evaluation failed.
                coord.request_stop()
                coord.join(threads)

In [None]:
# Learning rate 0.1, 64 images per batch, 10000 training iterations.
model = Res50(lr=0.1, batch_size=64, iter_num=10000)
# Train the model.
model.train()

In [None]:
# Evaluate on the test list. Requires `model` from the training cell above;
# test() restores the checkpoint that train() saved to 'model/foodModel'.
model.test()