In [2]:
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

In [9]:
# Directory holding the raw input data. It contains 5 sub-folders, and each
# sub-folder stores all of the images that belong to that class.
INPUT_DATA = 'flower_photos/'
# Path of the output file.
OUTPUT_FILE = 'processed_data.npy'

# Percentage of the data used for the validation set and for the test set.
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

# Read the images and split them into training, validation and test data.
def create_image_lists(sess, testing_percentage, validation_percentage):
    """Load every JPEG under ``INPUT_DATA`` and split the images at random.

    Each sub-directory of ``INPUT_DATA`` is treated as one class. Classes are
    numbered 0, 1, 2, ... in sorted directory order; a directory without any
    JPEG file is skipped and does not consume a label. Every image is decoded,
    converted to float32 and resized to 299x299 so that the Inception-v3 model
    can process it.

    Args:
        sess: TensorFlow session used to run the decode/resize ops.
        testing_percentage: Percent chance (0-100) that an image is put into
            the test set.
        validation_percentage: Percent chance (0-100) that an image is put
            into the validation set.

    Returns:
        A 1-D NumPy object array with six list entries, in this order:
        training images, training labels, validation images, validation
        labels, testing images, testing labels. The training lists are
        shuffled together.
    """
    # os.walk yields the root directory first, so drop it. Sorting the rest
    # makes the class -> label numbering independent of file-system order.
    walked_dirs = [entry[0] for entry in os.walk(INPUT_DATA)]
    sub_dirs = sorted(walked_dirs[1:])
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

    # Initialise the individual data sets.
    training_images, training_labels = [], []
    testing_images, testing_labels = [], []
    validation_images, validation_labels = [], []
    current_label = 0

    # Build the decode -> float32 -> resize pipeline once and feed every file
    # through it, instead of adding new ops to the graph for each image.
    raw_jpeg = tf.placeholder(tf.string)
    # channels=3 makes grayscale JPEGs come out with three channels as well.
    decoded = tf.image.decode_jpeg(raw_jpeg, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    resized = tf.image.resize_images(as_float, [299, 299])

    # Visit every class directory.
    for sub_dir in sub_dirs:
        dir_name = os.path.basename(sub_dir)

        # Collect the image files of ALL extensions before processing any of
        # them, so that each file is handled exactly once. The set removes the
        # duplicates produced on case-insensitive file systems, where '*.jpg'
        # and '*.JPG' match the same files.
        matches = set()
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            matches.update(glob.glob(file_glob))
        file_list = sorted(matches)
        if not file_list:
            continue

        # Process the image data of this class.
        for file_name in file_list:
            # The with-block closes the file handle as soon as it is read.
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image_value = sess.run(resized, feed_dict={raw_jpeg: image_raw_data})

            # Randomly assign the image to one of the three data sets.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)

        # One label per class directory: advance only after all of its files
        # have been handled.
        current_label += 1

    # Shuffle the training data for better training results. A single shared
    # permutation keeps every image aligned with its label.
    order = np.random.permutation(len(training_images))
    training_images = [training_images[i] for i in order]
    training_labels = [training_labels[i] for i in order]

    # The six lists have different lengths, so store them in an explicit
    # object array; np.array() on ragged input raises on recent NumPy.
    parts = [training_images, training_labels,
             validation_images, validation_labels,
             testing_images, testing_labels]
    result = np.empty(len(parts), dtype=object)
    for index, part in enumerate(parts):
        result[index] = part
    return result

def main():
    """Build the image data sets inside a TensorFlow session and save them.

    The value returned by ``create_image_lists`` is written to
    ``OUTPUT_FILE`` with ``np.save`` after the session has been closed.
    """
    with tf.Session() as session:
        dataset = create_image_lists(
            session,
            testing_percentage=TEST_PERCENTAGE,
            validation_percentage=VALIDATION_PERCENTAGE,
        )
    np.save(OUTPUT_FILE, dataset)


# Run the preprocessing when this code is executed as the main program.
if __name__ == '__main__':
    main()

In [10]:
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim

# 加载通过tensorflow-slim定义好的inception-v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

In [4]:
# Path of the preprocessed data file that is loaded for training.
INPUT_DATA = 'processed_data.npy'
# Path prefix under which the newly trained model is saved.
TRAIN_FILE = 'FLOWER_MODEL/'
# Checkpoint file of the pre-trained model.
# NOTE(review): 'incepton' looks like a typo for 'inception' -- confirm the
# file name on disk before changing this value.
CKPT_FILE = 'incepton_v3.ckpt'

# Parameters used during training.
LEARNING_RATE = 0.0001
STEPS = 300
BATCH = 32
N_CLASSES = 5

# Comma-separated scope prefixes: variables under the first list are not
# loaded from the checkpoint; variables under the second list are trained.
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'

# Collect all parameters that must be loaded from Google's pre-trained model.
def get_tuned_variables():
    """Return the model variables to restore from the pre-trained checkpoint.

    A variable is left out when its op name starts with one of the
    comma-separated scope prefixes in ``CHECKPOINT_EXCLUDE_SCOPES``. Blank
    entries in that string are ignored.

    Returns:
        list: The variables from ``slim.get_model_variables()`` that are not
        excluded, in their original order.
    """
    # str.startswith accepts a tuple of prefixes, so a single call checks
    # every exclusion. Empty prefixes are dropped because they would match
    # (and therefore exclude) every variable.
    exclusions = tuple(
        scope.strip()
        for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')
        if scope.strip()
    )
    return [
        var for var in slim.get_model_variables()
        if not var.op.name.startswith(exclusions)
    ]

# Collect the list of all variables that need to be trained.
def get_trainable_variables():
    """Return the trainable variables that live under ``TRAINABLE_SCOPES``.

    ``TRAINABLE_SCOPES`` is a comma-separated list of scope prefixes. For each
    prefix, the matching entries of the graph's trainable-variables collection
    are appended to the result.

    Returns:
        list: The variables to train, grouped in the order of the scopes.
    """
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
    variables_to_train = []
    for scope in scopes:
        # Use TRAINABLE_VARIABLES, the collection trainable variables are
        # added to by default. TRAINABLE_RESOURCE_VARIABLES only lists
        # resource variables and left this function returning nothing.
        variables_to_train.extend(
            tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
    return variables_to_train

def main():
    """Fine-tune the final layers of Inception-v3 on the preprocessed data.

    Steps:
      1. Load the six data lists from ``INPUT_DATA``.
      2. Build the Inception-v3 graph, the loss, the training op and the
         accuracy op.
      3. Restore the pre-trained weights from ``CKPT_FILE`` (all variables
         returned by ``get_tuned_variables``).
      4. Train for ``STEPS`` steps on sequential mini-batches of ``BATCH``
         examples; every 30 steps and on the last step, save a checkpoint
         under ``TRAIN_FILE`` and print the validation accuracy.
      5. Print the accuracy on the test set.

    Raises:
        ValueError: If the loaded data contains no training examples.
    """
    # Load the preprocessed data. The file holds an object array of Python
    # lists, which NumPy >= 1.16.3 only reads with allow_pickle=True.
    # Unpickling can execute arbitrary code: only load files you produced.
    processed_data = np.load(INPUT_DATA, allow_pickle=True)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]
    print('%d training examples,%d validation examples and %d testing examples' % (
        n_training_example, len(validation_labels), len(testing_labels)))
    if n_training_example == 0:
        raise ValueError('No training examples found in %s' % INPUT_DATA)

    # Define the Inception-v3 inputs: `images` are the input pictures and
    # `labels` holds the class label of each picture.
    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(images, num_classes=N_CLASSES)

    # Variables that should be trained.
    trainable_variables = get_trainable_variables()
    # Cross-entropy loss. The regularization losses were already added to the
    # loss collection when the model was defined.
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    # Training op. get_total_loss() sums the cross-entropy and regularization
    # losses into one tensor (get_losses() returns a list and omits
    # regularization). var_list restricts the update to the selected scopes;
    # `or None` falls back to all trainable variables if that list is empty,
    # because minimize() rejects an empty var_list.
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(
        tf.losses.get_total_loss(), var_list=trainable_variables or None)

    # Accuracy computation.
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Function that loads the pre-trained weights into a session.
    load_fn = slim.assign_from_checkpoint_fn(
        CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True)

    # Saver for the newly trained model; make sure its target directory exists.
    saver = tf.train.Saver()
    checkpoint_dir = os.path.dirname(TRAIN_FILE)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    with tf.Session() as sess:
        # Initialise every variable first; load_fn then overwrites the ones
        # that are restored from the checkpoint.
        sess.run(tf.global_variables_initializer())

        # Load Google's pre-trained model.
        print('Loading tuned variables from %s' % CKPT_FILE)
        load_fn(sess)

        start = 0
        end = min(BATCH, n_training_example)
        for i in range(STEPS):
            # Run one training step on the current mini-batch.
            sess.run(train_step, feed_dict={
                images: training_images[start:end],
                labels: training_labels[start:end]
            })

            # Log output and checkpointing.
            if i % 30 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)
                validation_accuracy = sess.run(evaluation_step, feed_dict={
                    images: validation_images,
                    labels: validation_labels
                })
                print('Step %d: Validation accuracy = %.1f%%' % (
                    i, validation_accuracy * 100.0))

            # The data was already shuffled during preprocessing, so the
            # training data is simply consumed in order. This must happen on
            # every step, not only on the logging steps.
            start = end
            if start == n_training_example:
                start = 0
            end = min(start + BATCH, n_training_example)

        # Measure the accuracy on the final test data set.
        test_accuracy = sess.run(evaluation_step, feed_dict={
            images: testing_images,
            labels: testing_labels
        })
        print('Final test accuracy = %.1f%%' % (test_accuracy * 100.0))
            
        
    
    
        
    
    
    
        
        