In [1]:
from __future__ import print_function, division

import tensorflow as tf

print(tf.__version__)

import keras

import matplotlib.pyplot as plt

import sys

from sklearn import preprocessing
from sklearn.cross_validation import train_test_split

import time
from tqdm import tqdm

from imutils import paths
from numpy import *
import numpy as np

from matplotlib.colors import Normalize

import argparse
import imutils,sklearn
import os, cv2, re, random, shutil, imageio, pickle

%matplotlib inline  

1.3.0


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
#定义最大灰度级数
gray_level = 128

def maxGrayLevel(img):
    max_gray_level=0
    (height,width)=img.shape
    for y in range(height):
        for x in range(width):
            if img[y][x] > max_gray_level:
                max_gray_level = img[y][x]
    return max_gray_level+1

def getGlcm(input,d_x,d_y):
    srcdata=input.copy()
    ret=[[0.0 for i in range(gray_level)] for j in range(gray_level)]
    (height,width) = input.shape

    max_gray_level=maxGrayLevel(input)

    #若灰度级数大于gray_level，则将图像的灰度级缩小至gray_level，减小灰度共生矩阵的大小
    if max_gray_level > gray_level:
        for j in range(height):
            for i in range(width):
                srcdata[j][i] = srcdata[j][i]*gray_level / max_gray_level

    for j in range(height-d_y):
        for i in range(width-d_x):
            rows = srcdata[j][i]
            cols = srcdata[j + d_y][i+d_x]
            ret[rows][cols]+=1.0

    for i in range(gray_level):
        for j in range(gray_level):
            ret[i][j]/=float(height*width)

    return np.array(ret)
    
def batch_GLCM(images):
    greycomatrix_list = []
    for i in tqdm(range(len(images))):
        img = (images.astype(np.float32)* 255)[i, :, :, :].astype(np.uint8)
        gray=cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        glcm_list = []
        
        glcm_0=getGlcm(gray, 1,0)
        glcm_0=cv2.normalize(glcm_0, glcm_0, 0, 1, cv2.NORM_MINMAX) 
        glcm_list.append(glcm_0)
        
        glcm_1=getGlcm(gray, 0,1)
        glcm_1=cv2.normalize(glcm_1, glcm_1, 0, 1, cv2.NORM_MINMAX) 
        glcm_list.append(glcm_1)
        
        glcm_2=getGlcm(gray, 1,1)
        glcm_2=cv2.normalize(glcm_2, glcm_2, 0, 1, cv2.NORM_MINMAX) 
        glcm_list.append(glcm_2)

        glcm_list=np.array(glcm_list,dtype = float32)
        greycomatrix_list.append(glcm_list)
        
    greycomatrix_list = np.array(greycomatrix_list,dtype = float32).reshape(
        [images.shape[0],images.shape[1],images.shape[2],3])
    
    print (greycomatrix_list.shape)
    return greycomatrix_list

In [3]:
def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split('(\d+)', text) ]

def load_flower_data():
    # grab the list of images that we'll be describing
    print("[INFO] handling images...")
    TRAIN_ORIGINAL_DIR = '../train/'
    TRAIN_SUB_DIR = '../subsample/'
    TRAIN_GAN = '../../image_gan/'
    TEST_DIR = '../../test/'

    # use this for full dataset
    train_images_gan = [TRAIN_GAN + i for i in os.listdir(TRAIN_GAN)]
    test_images = [TEST_DIR + i for i in os.listdir(TEST_DIR)]
    
    train_images = train_images_gan
    
    train_images.sort(key=natural_keys)
    test_images.sort(key=natural_keys)

    # initialize the features matrix and labels list
    trainImage = []
    trainLabels = []
    testImage = []
    testLabels = []

    # loop over the input images
    for (i, imagePath) in enumerate(train_images):
        # extract the class label
        # get the labels from the name of the images by extract the string before "_"
        label = imagePath.split(os.path.sep)[-1].split("_")[0]

        # read and resize image
        img = cv2.imread(imagePath)
        img = cv2.resize(img, (128,128))

        # add the messages we got to features and labels matricies
        trainImage.append(img)
        trainLabels.append(label)

        # show an update every 100 images until the last image
        if i > 0 and ((i + 1) % 1000 == 0 or i == len(train_images) - 1):
            print("[INFO] processed {}/{}".format(i + 1, len(train_images)))
            
      # loop over the input images
    for (i, imagePath) in enumerate(test_images):
        # extract the class label
        # our images were named as labels.image_number.format
        # get the labels from the name of the images by extract the string before "."
        label = imagePath.split(os.path.sep)[-1].split("_")[0]

        # extract CNN features in the image
        img = cv2.imread(imagePath)
        img = cv2.resize(img, (128,128))

        # add the messages we got to features and labels matricies
        testImage.append(img)
        testLabels.append(label)

        # show an update every 100 images until the last image
        if i > 0 and ((i + 1) % 1000 == 0 or i == len(test_images) - 1):
            print("[INFO] processed {}/{}".format(i + 1, len(test_images)))


    trainImage = np.array(trainImage,dtype = float32)
    trainLabels = np.array(trainLabels)
    testImage = np.array(testImage,dtype = float32)
    testLabels = np.array(testLabels)
    print (trainImage.shape)
    
    trainImage = trainImage.astype(np.float32) / 255
    testImage = testImage.astype(np.float32) / 255
    
    le = preprocessing.LabelEncoder()
    le.fit(trainLabels)
    list(le.classes_)
    trainLabels = le.transform(trainLabels) 
    testLabels = le.transform(testLabels) 
    
    return trainImage, trainLabels, testImage, testLabels

In [4]:
trainImage, trainLabels, testImage, testLabels = load_flower_data()

trainImage_GLCM = batch_GLCM(trainImage)
testImage_GLCM  = batch_GLCM(testImage)
nb_classes = 2

# Convert class vectors to binary class matrices.
trainLabels = keras.utils.to_categorical(trainLabels, nb_classes)
print (trainLabels)
testLabels = keras.utils.to_categorical(testLabels, nb_classes)
print (testLabels)
print (testLabels.shape)

np.save('../trainImage.npy', trainImage)
np.save('../trainLabels.npy', trainLabels)
np.save('../testImage.npy', testImage)
np.save('../testLabels.npy', testLabels)
np.save('../trainImage_GLCM.npy', trainImage_GLCM)
np.save('../testImage_GLCM.npy', testImage_GLCM)

print("[INFO] trainImage matrix: {:.2f}MB".format(
    (trainImage.nbytes) / (1024 * 1000.0)))
print("[INFO] trainLabels matrix: {:.4f}MB".format(
    (trainLabels.nbytes) / (1024 * 1000.0)))
print("[INFO] testImage matrix: {:.2f}MB".format(
    (testImage.nbytes) / (1024 * 1000.0)))
print("[INFO] testLabels matrix: {:.4f}MB".format(
    (testLabels.nbytes) / (1024 * 1000.0)))
print("[INFO] trainImage_GLCM matrix: {:.2f}MB".format(
    (trainImage_GLCM.nbytes) / (1024 * 1000.0)))
print("[INFO] testImage_GLCM matrix: {:.4f}MB".format(
    (testImage_GLCM.nbytes) / (1024 * 1000.0)))


[INFO] handling images...
[INFO] processed 1000/6000
[INFO] processed 2000/6000
[INFO] processed 3000/6000
[INFO] processed 4000/6000
[INFO] processed 5000/6000
[INFO] processed 6000/6000
[INFO] processed 154/154
(6000, 128, 128, 3)


100%|██████████| 6000/6000 [1:32:08<00:00,  1.09it/s]
  1%|          | 1/154 [00:00<00:27,  5.54it/s]

(6000, 128, 128, 3)


100%|██████████| 154/154 [00:28<00:00,  5.36it/s]


(154, 128, 128, 3)
[[1. 0.]
 [1. 0.]
 [1. 0.]
 ...
 [0. 1.]
 [0. 1.]
 [0. 1.]]
[[1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0. 1.]
 [0

In [5]:
class CNN(object):
    def __init__(self, sess, epoch, batch_size, dataset_name, checkpoint_dir, log_dir, trainhist_dir):
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir
        self.trainhist_dir = trainhist_dir
        self.epoch = epoch
        self.batch_size = batch_size
        
        self.classname = ['Iris', 'Pansy']

        # parameters
        self.input_height = 128
        self.input_width = 128
        self.c_dim = 3  # color dimension
        self.nb_class = 2
        
        # number of convolutional filters to use  
        self.nb_CNN = [32, 64, 64, 64, 128]  
        # number of dense filters to use  
        self.nb_Dense = [256] 
        # size of pooling area for max pooling  
        self.pool_size = (2, 2)  
        # convolution kernel size  
        self.kernel_size = (3, 3)
        self.batch_normalization_control = True
        
        # name for checkpoint
        self.model_name = 'CNN_GLCM_C%d_D%d_Kernel(%d,%d)_%d_lrdecay' % (len(self.nb_CNN), len(self.nb_Dense),
                                                          self.kernel_size[0], self.kernel_size[1], max(self.nb_CNN))

        # train
        #设置一个全局的计数器
        self.global_step = tf.Variable(0, trainable=False)
        self.lr = tf.train.exponential_decay(0.001, 
                                             global_step=self.global_step, 
                                             decay_steps=10, 
                                             decay_rate=0.9, 
                                             staircase=True)
        self.beta1 = 0.5
        #max model to keep saving
        self.max_to_keep = 300
        
        # test

        #load_flower_data
        self.train_x = np.load('../trainImage.npy')
        self.train_y = np.load('../trainLabels.npy')
        self.test_x = np.load('../testImage.npy')
        self.test_y = np.load('../testLabels.npy')
        self.train_x_glcm = np.load('../trainImage_GLCM.npy')
        self.test_x_glcm = np.load('../testImage_GLCM.npy')
        
        #记录
        self.train_hist = {}
        self.train_hist['losses'] = []
        self.train_hist['accuracy'] = []
        self.train_hist['learning_rate'] = []
        self.train_hist['per_epoch_ptimes'] = []
        self.train_hist['total_ptime'] = []
        
        # get number of batches for a single epoch
        self.num_batches_train = len(self.train_x) // self.batch_size
        self.num_batches_test= len(self.test_x) // self.batch_size

    def cnn_model(self, x, x_GLCM, keep_prob, is_training=True, reuse=False):
        with tf.variable_scope("cnn", reuse=reuse):
             
            #初始化参数
            W = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
            B = tf.constant_initializer(0.0)
        
            print("CNN:x",x.get_shape()) # 128, 128, 3 
            print("CNN:x_GLCM",x_GLCM.get_shape()) # 128, 128, 3 
            
            #输入x,卷积核为3*3 输出维度为32
            net1_1 = tf.layers.conv2d(inputs = x,                 # 输入,
                                    filters = self.nb_CNN[0],      # 卷积核个数,
                                    kernel_size = self.kernel_size,          # 卷积核尺寸
                                    strides = (1, 1),
                                    padding = 'same',              # padding方法
                                    kernel_initializer = W,
                                    bias_initializer = B,
                                    kernel_regularizer = None,
                                    bias_regularizer = None,
                                    activity_regularizer = None,
                                    name = 'conv_1_1'               # 命名用于获取变量
                                    )
            print("CNN:",net1_1.get_shape())
            
            #输入x,卷积核为3*3 输出维度为32
            net1_2 = tf.layers.conv2d(inputs = x_GLCM,                 # 输入,
                                    filters = self.nb_CNN[0],      # 卷积核个数,
                                    kernel_size = self.kernel_size,          # 卷积核尺寸
                                    strides = (1, 1),
                                    padding = 'same',              # padding方法
                                    kernel_initializer = W,
                                    bias_initializer = B,
                                    kernel_regularizer = None,
                                    bias_regularizer = None,
                                    activity_regularizer = None,
                                    name = 'conv_1_2'               # 命名用于获取变量
                                    )
            print("CNN:",net1_2.get_shape())

            #把数据和边缘进行连接
            net = tf.concat([net1_1, net1_2], 3)
            net = tf.layers.batch_normalization(net, training=is_training)
            net = tf.nn.relu(net, name = 'relu_conv_1')
            print("CNN:",net.get_shape())
            net = tf.layers.max_pooling2d(inputs = net,
                                              pool_size = self.pool_size,
                                              strides = (2, 2),
                                              padding = 'same',
                                              name = 'pool_conv_1'
                                             )
            
            for i in range(2,len(self.nb_CNN)+1):
                net = tf.layers.conv2d(inputs = net,                 # 输入,
                                       filters = self.nb_CNN[i-1],      # 卷积核个数,
                                       kernel_size = self.kernel_size,          # 卷积核尺寸
                                       strides = (1, 1),
                                       padding = 'same',              # padding方法
                                       kernel_initializer = W,
                                       bias_initializer = B,
                                       kernel_regularizer = None,
                                       bias_regularizer = None,
                                       activity_regularizer = None,
                                       name = 'conv_'+ str(i)        # 命名用于获取变量
                                       )
                print("CNN:",net.get_shape())
                if self.batch_normalization_control:
                    net = tf.layers.batch_normalization(net, training=is_training)
                net = tf.nn.relu( net, name = 'relu_conv_' + str(i))
                net = tf.layers.max_pooling2d(inputs = net,
                                              pool_size = self.pool_size,
                                              strides = (2, 2),
                                              padding = 'same',
                                              name = 'pool_conv_' + str(i)
                                             )
                print("CNN:",net.get_shape())
            
            #flatten
            net = tf.reshape(net, [-1, int(net.get_shape()[1]*net.get_shape()[2]*net.get_shape()[3])],name='flatten')
            print("CNN:",net.get_shape())
            
            #dense layer
            for i in range(1,len(self.nb_Dense)+1):
                net = tf.layers.dense(inputs = net,
                                    units = self.nb_Dense[i-1],
                                    kernel_initializer = W,
                                    bias_initializer = B,
                                    kernel_regularizer=None,
                                    bias_regularizer=None,
                                    activity_regularizer=None,
                                    name = 'dense_' + str(i)
                                    )
#                 net = tf.layers.batch_normalization(net, training=is_training)
                net = tf.nn.relu( net, name = 'relu_dense_' + str(i))
                net = tf.layers.dropout(inputs = net,
                                        rate=keep_prob,
                                        noise_shape=None,
                                        seed=None,
                                        training = is_training,
                                        name= 'dropout_dense_' + str(i)
                                        )
            #output
            logit = tf.layers.dense(inputs = net,
                                    units = self.nb_class,
                                    kernel_initializer = W,
                                    bias_initializer = B,
                                    kernel_regularizer=None,
                                    bias_regularizer=None,
                                    activity_regularizer=None,
                                    name = 'dense_output'
                                    )
            out_logit = tf.nn.softmax(logit, name="softmax")
            print("CNN:out_logit",out_logit.get_shape())
            print("------------------------")    

            return out_logit, logit


    def build_model(self):

        """ Graph Input """
        # images
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size,self.input_height, self.input_width, self.c_dim], 
                                name='x_image')
        
        self.x_GLCM = tf.placeholder(tf.float32, shape=[self.batch_size,self.input_height, self.input_width, self.c_dim], 
                                name='x_GLCM')

        self.y = tf.placeholder(tf.float32, shape=[self.batch_size, self.nb_class], name='y_label')
        
        self.keep_prob = tf.placeholder(tf.float32)
        
        self.add_global = self.global_step.assign_add(1)

        """ Loss Function """

        # output of cnn_model
        self.out_logit, self.logit = self.cnn_model(self.x, self.x_GLCM, self.keep_prob, is_training=True, reuse=False)
        
        self.loss_cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y,
                                                                                         logits =self.logit))
        
        """ Training """
        # trainable variables into a group
        tf_vars = tf.trainable_variables()
        cnn_vars = [var for var in tf_vars if var.name.startswith('cnn')]

        # optimizers
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.cnn_optim = tf.train.AdamOptimizer(self.lr, beta1=self.beta1).minimize(self.loss_cross_entropy,
                                                                                        var_list=cnn_vars)
        """" Testing """
        # for test
        # output of cnn_model
        self.out_logit_test, self.logit_test = self.cnn_model(self.x, self.x_GLCM, self.keep_prob, is_training=False, reuse=True)
        self.correct_prediction = tf.equal(tf.argmax(self.logit_test, 1), tf.argmax(self.y, 1))
        self.predict = tf.argmax(self.logit_test, 1)
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))

        """ Summary """
        self.loss_sum = tf.summary.scalar("loss", self.loss_cross_entropy)


    def train(self):

        # initialize all variables
        tf.global_variables_initializer().run()
#         sess.run(tf.global_variables_initializer())

        # saver to save model
        self.saver = tf.train.Saver(max_to_keep = self.max_to_keep)

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_epoch = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_epoch) + 1
            counter = 1
            f = open(self.trainhist_dir + '/' + self.model_name+'.pkl', 'rb') 
            self.train_hist = pickle.load(f)
            f.close()
            print(" [*] Load SUCCESS")
            print(" [!] START_EPOCH is ", start_epoch)
        else:
            start_epoch = 1
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        for epoch_loop in range(start_epoch, self.epoch + 1):

            CNN_losses = []
  
            epoch_start_time = time.time()
            shuffle_idxs = random.sample(range(0, self.train_x.shape[0]), self.train_x.shape[0])
            shuffled_set = self.train_x[shuffle_idxs]
            shuffled_set_glcm = self.train_x_glcm[shuffle_idxs]
            shuffled_label = self.train_y[shuffle_idxs]
    
            # get batch data
            for idx in range(self.num_batches_train):
                batch_x = shuffled_set[idx*self.batch_size:(idx+1)*self.batch_size]
#                 batch_x_GLCM = self.batch_GLCM(batch_x)
                batch_x_GLCM = shuffled_set_glcm[idx*self.batch_size:(idx+1)*self.batch_size]
                batch_y = shuffled_label[idx*self.batch_size:(idx+1)*self.batch_size].reshape(
                                        [self.batch_size, self.nb_class])
                

                # update D network
                _, summary_str, cnn_loss = self.sess.run([self.cnn_optim, self.loss_sum, self.loss_cross_entropy],
                                               feed_dict={self.x: batch_x,
                                                          self.x_GLCM: batch_x_GLCM,
                                                          self.y: batch_y,
                                                          self.keep_prob: 0.5}
                                                      )
                self.writer.add_summary(summary_str, counter)

                CNN_losses.append(cnn_loss)

                # display training status
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f" % (epoch_loop, idx, self.num_batches_train, 
                                                                          time.time() - start_time, cnn_loss))

            # After an epoch
            # Evaluates accuracy on test set
            test_accuracy_list = []
            for idx_test in range(self.num_batches_test):
                batch_x_test = self.test_x[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
                batch_x_GLCM_test =self.test_x_glcm[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
                batch_y_tes = self.test_y[idx_test*self.batch_size:(idx_test+1)*self.batch_size].reshape(
                                        [self.batch_size, self.nb_class])
                accuracy = self.sess.run([self.accuracy],feed_dict={self.x: batch_x_test, 
                                                                    self.x_GLCM: batch_x_GLCM_test,
                                                                    self.y: batch_y_tes,
                                                                    self.keep_prob: 1.0})
                test_accuracy_list.append(accuracy)
            test_accuracy = np.sum(test_accuracy_list)/self.num_batches_test
        
            #update learning rate
            _, rate = sess.run([self.add_global, self.lr])
            
            epoch_end_time = time.time()
            per_epoch_ptime = epoch_end_time - epoch_start_time
            
            print('[%d/%d] - ptime: %.4f loss: %.8f acc: %.5f lr: %.8f'% (epoch_loop, self.epoch, per_epoch_ptime, 
                                                                    np.mean(CNN_losses), test_accuracy, rate))
            
            self.train_hist['losses'].append(np.mean(CNN_losses))
            self.train_hist['accuracy'].append( test_accuracy)
            self.train_hist['learning_rate'].append(rate)
            self.train_hist['per_epoch_ptimes'].append(per_epoch_ptime)
            
            # save model
            self.save(self.checkpoint_dir, epoch_loop)
            
            # save trainhist for train
            f = open(self.trainhist_dir + '/' + self.model_name + '.pkl', 'wb') 
            pickle.dump(self.train_hist, f)
            f.close()
            self.show_train_hist(self.train_hist, save=True, path= self.trainhist_dir + '/' 
                                 + self.model_name + '.png')

        end_time = time.time()
        total_ptime = end_time - start_time
        self.train_hist['total_ptime'].append(total_ptime)
        print('Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f' % (np.mean(self.train_hist['per_epoch_ptimes']), 
                                                                          self.epoch, total_ptime))
        print(" [*] Training finished!")
        
        """test after train"""
        best_acc = max(self.train_hist['accuracy'])
        beat_epoch = self.train_hist['accuracy'].index(best_acc) + 1
        print (" [*] Best Epoch: ", beat_epoch, ", Accuracy: ", best_acc)
        path_model = self.checkpoint_dir + '/' + self.model_name + '/'+ self.model_name +'-'+ str(beat_epoch)

        """ restore epoch """
        new_saver = tf.train.import_meta_graph(path_model + '.meta' )
        new_saver.restore(self.sess,path_model)

        # Evaluates accuracy on test set
        test_accuracy_list = []
        for idx_test in range(self.num_batches_test):
            batch_x_test = self.test_x[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
            batch_x_GLCM_test =self.test_x_glcm[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
            batch_y_tes = self.test_y[idx_test*self.batch_size:(idx_test+1)*self.batch_size].reshape(
                                    [self.batch_size, self.nb_class])
            accuracy = self.sess.run([self.accuracy],feed_dict={self.x: batch_x_test, 
                                                                self.x_GLCM: batch_x_GLCM_test,
                                                                self.y: batch_y_tes,
                                                                self.keep_prob: 1.0})
            test_accuracy_list.append(accuracy)
        test_accuracy = np.sum(test_accuracy_list)/self.num_batches_test
        print(" [*] Finished testing Best Epoch:", beat_epoch, ", accuracy: ",test_accuracy, "!")

    def test(self, epoch):
        path_model = self.checkpoint_dir + '/' + self.model_name + '/'+ self.model_name +'-'+ str(epoch)

        """ restore epoch """
        new_saver = tf.train.import_meta_graph(path_model + '.meta' )
        new_saver.restore(self.sess,path_model)

        # Evaluates accuracy on test set
        test_accuracy_list = []
        for idx_test in range(self.num_batches_test):
            batch_x_test = self.test_x[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
            batch_x_GLCM_test =self.test_x_glcm[idx_test*self.batch_size:(idx_test+1)*self.batch_size]
            batch_y_tes = self.test_y[idx_test*self.batch_size:(idx_test+1)*self.batch_size].reshape(
                                    [self.batch_size, self.nb_class])
            accuracy = self.sess.run([self.accuracy],feed_dict={self.x: batch_x_test, 
                                                                self.x_GLCM: batch_x_GLCM_test,
                                                                self.y: batch_y_tes,
                                                                self.keep_prob: 1.0})
            test_accuracy_list.append(accuracy)
        test_accuracy = np.sum(test_accuracy_list)/self.num_batches_test
        print(" [*] Finished testing Epoch:", epoch, ", accuracy: ",test_accuracy, "!")
        
    def show_all_variables(self):
        model_vars = tf.trainable_variables()
        tf.contrib.slim.model_analyzer.analyze_vars(model_vars, print_info=True) 

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_name)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,os.path.join(checkpoint_dir, self.model_name), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_name)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            epoch = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
            print(" [*] Success to read [{}], epoch [{}]".format(ckpt_name,epoch))
            return True, epoch
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
        
    def show_train_hist(self, hist, show = False, save = False, path = 'Train_hist.png'):
        x = range(1, len(hist['losses'])+1)

        y1 = hist['losses']
        y2 = hist['accuracy']
        
        fig, ax1 = plt.subplots()
                            
        ax2 = ax1.twinx()  

        ax1.plot(x, y1, 'b')
        ax2.plot(x, y2, 'r')
                            
        ax1.set_xlabel('Epoch')
                            
        ax1.set_ylabel('CNN_loss')    
        ax2.set_ylabel('accuracy')

        plt.grid(True)
        plt.tight_layout()

        if save:
            plt.savefig(path, dpi = 400)

        if show:
            plt.show()
        else:
            plt.close()
   

In [6]:
dataset = '4_Flowers_1s'
epoch = 50
batch_size = 100
checkpoint_dir = 'checkpoint'
log_dir = 'logs'
trainhist_dir = 'trainhist'

if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

# --log_dir
if not os.path.exists(log_dir):
    os.makedirs(log_dir)
    
# --trainhist_dir
if not os.path.exists(trainhist_dir):
    os.makedirs(trainhist_dir)

# open session
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    
    # declare instance for GAN
    CNN = CNN(sess, epoch=epoch, batch_size=batch_size, dataset_name=dataset, checkpoint_dir=checkpoint_dir, 
                log_dir=log_dir, trainhist_dir=trainhist_dir)

    # build graph
    CNN.build_model()

    # show network architecture
    CNN.show_all_variables()

    # launch the graph in a session
    CNN.train()
    
#     CNN.test(epoch)
        
sess.close()
        
# lrdecay
# Avg per epoch ptime: 16.17, total 50 epochs ptime: 830.44
#  [*] Training finished!
#  [*] Best Epoch:  6 , Accuracy:  0.8500000238418579
# INFO:tensorflow:Restoring parameters from checkpoint/CNN_GLCM_C5_D1_Kernel(3,3)_128_lrdecay/CNN_GLCM_C5_D1_Kernel(3,3)_128_lrdecay-6
#  [*] Finished testing Best Epoch: 6 , accuracy:  0.8500000238418579 !

CNN:x (100, 128, 128, 3)
CNN:x_GLCM (100, 128, 128, 3)
CNN: (100, 128, 128, 32)
CNN: (100, 128, 128, 32)
CNN: (100, 128, 128, 64)
CNN: (100, 64, 64, 64)
CNN: (100, 32, 32, 64)
CNN: (100, 32, 32, 64)
CNN: (100, 16, 16, 64)
CNN: (100, 16, 16, 64)
CNN: (100, 8, 8, 64)
CNN: (100, 8, 8, 128)
CNN: (100, 4, 4, 128)
CNN: (100, 2048)
CNN:out_logit (100, 2)
------------------------
CNN:x (100, 128, 128, 3)
CNN:x_GLCM (100, 128, 128, 3)
CNN: (100, 128, 128, 32)
CNN: (100, 128, 128, 32)
CNN: (100, 128, 128, 64)
CNN: (100, 64, 64, 64)
CNN: (100, 32, 32, 64)
CNN: (100, 32, 32, 64)
CNN: (100, 16, 16, 64)
CNN: (100, 16, 16, 64)
CNN: (100, 8, 8, 64)
CNN: (100, 8, 8, 128)
CNN: (100, 4, 4, 128)
CNN: (100, 2048)
CNN:out_logit (100, 2)
------------------------
---------
Variables: name (type shape) [size]
---------
cnn/conv_1_1/kernel:0 (float32_ref 3x3x3x32) [864, bytes: 3456]
cnn/conv_1_1/bias:0 (float32_ref 32) [32, bytes: 128]
cnn/conv_1_2/kernel:0 (float32_ref 3x3x3x32) [864, bytes: 3456]
cnn/conv_1_2

Epoch: [ 2] [  40/  60] time: 35.8982, loss: 0.07541696
Epoch: [ 2] [  41/  60] time: 36.1432, loss: 0.15556292
Epoch: [ 2] [  42/  60] time: 36.3899, loss: 0.06230674
Epoch: [ 2] [  43/  60] time: 36.6330, loss: 0.11221454
Epoch: [ 2] [  44/  60] time: 36.8800, loss: 0.09476766
Epoch: [ 2] [  45/  60] time: 37.1262, loss: 0.03699632
Epoch: [ 2] [  46/  60] time: 37.3737, loss: 0.09664087
Epoch: [ 2] [  47/  60] time: 37.6228, loss: 0.02136010
Epoch: [ 2] [  48/  60] time: 37.8694, loss: 0.04935395
Epoch: [ 2] [  49/  60] time: 38.1170, loss: 0.02233471
Epoch: [ 2] [  50/  60] time: 38.3658, loss: 0.05430487
Epoch: [ 2] [  51/  60] time: 38.6123, loss: 0.04340994
Epoch: [ 2] [  52/  60] time: 38.8597, loss: 0.03378976
Epoch: [ 2] [  53/  60] time: 39.1070, loss: 0.04248616
Epoch: [ 2] [  54/  60] time: 39.3542, loss: 0.07656178
Epoch: [ 2] [  55/  60] time: 39.6010, loss: 0.08834961
Epoch: [ 2] [  56/  60] time: 39.8477, loss: 0.09795763
Epoch: [ 2] [  57/  60] time: 40.0969, loss: 0.1

Epoch: [ 5] [   3/  60] time: 76.5586, loss: 0.01816551
Epoch: [ 5] [   4/  60] time: 76.8060, loss: 0.04573194
Epoch: [ 5] [   5/  60] time: 77.0534, loss: 0.03545017
Epoch: [ 5] [   6/  60] time: 77.3013, loss: 0.03636487
Epoch: [ 5] [   7/  60] time: 77.5471, loss: 0.02704026
Epoch: [ 5] [   8/  60] time: 77.7962, loss: 0.01861658
Epoch: [ 5] [   9/  60] time: 78.0438, loss: 0.04615198
Epoch: [ 5] [  10/  60] time: 78.2922, loss: 0.02500204
Epoch: [ 5] [  11/  60] time: 78.5418, loss: 0.00325396
Epoch: [ 5] [  12/  60] time: 78.7897, loss: 0.05094259
Epoch: [ 5] [  13/  60] time: 79.0391, loss: 0.02229457
Epoch: [ 5] [  14/  60] time: 79.2864, loss: 0.05882150
Epoch: [ 5] [  15/  60] time: 79.5389, loss: 0.02417917
Epoch: [ 5] [  16/  60] time: 79.7859, loss: 0.04267686
Epoch: [ 5] [  17/  60] time: 80.0337, loss: 0.01654977
Epoch: [ 5] [  18/  60] time: 80.2833, loss: 0.08175111
Epoch: [ 5] [  19/  60] time: 80.5329, loss: 0.02372357
Epoch: [ 5] [  20/  60] time: 80.7749, loss: 0.0

Epoch: [ 7] [  26/  60] time: 115.0206, loss: 0.00462979
Epoch: [ 7] [  27/  60] time: 115.2692, loss: 0.01728161
Epoch: [ 7] [  28/  60] time: 115.5205, loss: 0.00963710
Epoch: [ 7] [  29/  60] time: 115.7687, loss: 0.02206315
Epoch: [ 7] [  30/  60] time: 116.0168, loss: 0.01959185
Epoch: [ 7] [  31/  60] time: 116.2651, loss: 0.00494419
Epoch: [ 7] [  32/  60] time: 116.5154, loss: 0.02525906
Epoch: [ 7] [  33/  60] time: 116.7632, loss: 0.04734979
Epoch: [ 7] [  34/  60] time: 117.0106, loss: 0.02699576
Epoch: [ 7] [  35/  60] time: 117.2578, loss: 0.04857077
Epoch: [ 7] [  36/  60] time: 117.5089, loss: 0.07557540
Epoch: [ 7] [  37/  60] time: 117.7559, loss: 0.01559114
Epoch: [ 7] [  38/  60] time: 118.0049, loss: 0.05790457
Epoch: [ 7] [  39/  60] time: 118.2520, loss: 0.00827674
Epoch: [ 7] [  40/  60] time: 118.5004, loss: 0.05701111
Epoch: [ 7] [  41/  60] time: 118.7499, loss: 0.02615828
Epoch: [ 7] [  42/  60] time: 118.9986, loss: 0.01501617
Epoch: [ 7] [  43/  60] time: 1

Epoch: [ 9] [  48/  60] time: 153.3432, loss: 0.01356924
Epoch: [ 9] [  49/  60] time: 153.5923, loss: 0.02026898
Epoch: [ 9] [  50/  60] time: 153.8396, loss: 0.01877137
Epoch: [ 9] [  51/  60] time: 154.0892, loss: 0.02442599
Epoch: [ 9] [  52/  60] time: 154.3369, loss: 0.02708188
Epoch: [ 9] [  53/  60] time: 154.5864, loss: 0.00103279
Epoch: [ 9] [  54/  60] time: 154.8359, loss: 0.02573194
Epoch: [ 9] [  55/  60] time: 155.0827, loss: 0.01969676
Epoch: [ 9] [  56/  60] time: 155.3316, loss: 0.00105105
Epoch: [ 9] [  57/  60] time: 155.5806, loss: 0.00248933
Epoch: [ 9] [  58/  60] time: 155.8315, loss: 0.00226348
Epoch: [ 9] [  59/  60] time: 156.0792, loss: 0.00066864
[9/50] - ptime: 16.0152 loss: 0.01297797 acc: 0.74000 lr: 0.00100000
Epoch: [10] [   0/  60] time: 157.8635, loss: 0.00420440
Epoch: [10] [   1/  60] time: 158.1112, loss: 0.01074447
Epoch: [10] [   2/  60] time: 158.3620, loss: 0.02695555
Epoch: [10] [   3/  60] time: 158.6106, loss: 0.00378522
Epoch: [10] [   4/ 

Epoch: [12] [   9/  60] time: 192.9312, loss: 0.03926517
Epoch: [12] [  10/  60] time: 193.1801, loss: 0.00502549
Epoch: [12] [  11/  60] time: 193.4311, loss: 0.00731752
Epoch: [12] [  12/  60] time: 193.6787, loss: 0.03726403
Epoch: [12] [  13/  60] time: 193.9282, loss: 0.04325581
Epoch: [12] [  14/  60] time: 194.1766, loss: 0.00306457
Epoch: [12] [  15/  60] time: 194.4296, loss: 0.00479924
Epoch: [12] [  16/  60] time: 194.6777, loss: 0.01356514
Epoch: [12] [  17/  60] time: 194.9279, loss: 0.02400257
Epoch: [12] [  18/  60] time: 195.1776, loss: 0.02050855
Epoch: [12] [  19/  60] time: 195.4206, loss: 0.00502273
Epoch: [12] [  20/  60] time: 195.6693, loss: 0.00818706
Epoch: [12] [  21/  60] time: 195.9191, loss: 0.00124636
Epoch: [12] [  22/  60] time: 196.1671, loss: 0.04346419
Epoch: [12] [  23/  60] time: 196.4174, loss: 0.05480448
Epoch: [12] [  24/  60] time: 196.6658, loss: 0.05810576
Epoch: [12] [  25/  60] time: 196.9158, loss: 0.00420589
Epoch: [12] [  26/  60] time: 1

Epoch: [14] [  31/  60] time: 231.2543, loss: 0.00453988
Epoch: [14] [  32/  60] time: 231.5062, loss: 0.00039320
Epoch: [14] [  33/  60] time: 231.7557, loss: 0.00095430
Epoch: [14] [  34/  60] time: 232.0067, loss: 0.00213502
Epoch: [14] [  35/  60] time: 232.2546, loss: 0.00089446
Epoch: [14] [  36/  60] time: 232.5046, loss: 0.01218012
Epoch: [14] [  37/  60] time: 232.7538, loss: 0.00003953
Epoch: [14] [  38/  60] time: 233.0041, loss: 0.00668534
Epoch: [14] [  39/  60] time: 233.2524, loss: 0.00028774
Epoch: [14] [  40/  60] time: 233.5006, loss: 0.00141204
Epoch: [14] [  41/  60] time: 233.7489, loss: 0.00042209
Epoch: [14] [  42/  60] time: 233.9995, loss: 0.00011753
Epoch: [14] [  43/  60] time: 234.2488, loss: 0.00313328
Epoch: [14] [  44/  60] time: 234.4978, loss: 0.01455283
Epoch: [14] [  45/  60] time: 234.7462, loss: 0.00065173
Epoch: [14] [  46/  60] time: 234.9969, loss: 0.00380584
Epoch: [14] [  47/  60] time: 235.2455, loss: 0.00123090
Epoch: [14] [  48/  60] time: 2

Epoch: [16] [  53/  60] time: 269.5869, loss: 0.01322745
Epoch: [16] [  54/  60] time: 269.8362, loss: 0.00176093
Epoch: [16] [  55/  60] time: 270.0850, loss: 0.02203233
Epoch: [16] [  56/  60] time: 270.3348, loss: 0.02070765
Epoch: [16] [  57/  60] time: 270.5851, loss: 0.01710228
Epoch: [16] [  58/  60] time: 270.8344, loss: 0.01668903
Epoch: [16] [  59/  60] time: 271.0851, loss: 0.00665305
[16/50] - ptime: 16.0174 loss: 0.03012851 acc: 0.45000 lr: 0.00090000
Epoch: [17] [   0/  60] time: 272.8182, loss: 0.00539756
Epoch: [17] [   1/  60] time: 273.0678, loss: 0.00484510
Epoch: [17] [   2/  60] time: 273.3170, loss: 0.00137802
Epoch: [17] [   3/  60] time: 273.5695, loss: 0.00111750
Epoch: [17] [   4/  60] time: 273.8161, loss: 0.04469505
Epoch: [17] [   5/  60] time: 274.0693, loss: 0.01855184
Epoch: [17] [   6/  60] time: 274.3186, loss: 0.02208290
Epoch: [17] [   7/  60] time: 274.5692, loss: 0.00801853
Epoch: [17] [   8/  60] time: 274.8183, loss: 0.00306018
Epoch: [17] [   9/

Epoch: [19] [  14/  60] time: 309.2375, loss: 0.00899740
Epoch: [19] [  15/  60] time: 309.4874, loss: 0.00086017
Epoch: [19] [  16/  60] time: 309.7370, loss: 0.00501702
Epoch: [19] [  17/  60] time: 309.9869, loss: 0.00013800
Epoch: [19] [  18/  60] time: 310.2365, loss: 0.00053703
Epoch: [19] [  19/  60] time: 310.4923, loss: 0.00036822
Epoch: [19] [  20/  60] time: 310.7416, loss: 0.00007370
Epoch: [19] [  21/  60] time: 310.9907, loss: 0.00168416
Epoch: [19] [  22/  60] time: 311.2408, loss: 0.00025762
Epoch: [19] [  23/  60] time: 311.4915, loss: 0.00336827
Epoch: [19] [  24/  60] time: 311.7429, loss: 0.00059473
Epoch: [19] [  25/  60] time: 311.9919, loss: 0.00091946
Epoch: [19] [  26/  60] time: 312.2414, loss: 0.00059420
Epoch: [19] [  27/  60] time: 312.4929, loss: 0.00059215
Epoch: [19] [  28/  60] time: 312.7452, loss: 0.00017579
Epoch: [19] [  29/  60] time: 312.9949, loss: 0.00108786
Epoch: [19] [  30/  60] time: 313.2484, loss: 0.00008851
Epoch: [19] [  31/  60] time: 3

Epoch: [21] [  36/  60] time: 347.6352, loss: 0.00002035
Epoch: [21] [  37/  60] time: 347.8828, loss: 0.00004479
Epoch: [21] [  38/  60] time: 348.1308, loss: 0.00009389
Epoch: [21] [  39/  60] time: 348.3810, loss: 0.00023043
Epoch: [21] [  40/  60] time: 348.6256, loss: 0.00063448
Epoch: [21] [  41/  60] time: 348.8755, loss: 0.00043728
Epoch: [21] [  42/  60] time: 349.1278, loss: 0.00152881
Epoch: [21] [  43/  60] time: 349.3762, loss: 0.00299275
Epoch: [21] [  44/  60] time: 349.6251, loss: 0.00040386
Epoch: [21] [  45/  60] time: 349.8733, loss: 0.00048496
Epoch: [21] [  46/  60] time: 350.1251, loss: 0.00001564
Epoch: [21] [  47/  60] time: 350.3729, loss: 0.00094547
Epoch: [21] [  48/  60] time: 350.6226, loss: 0.00015855
Epoch: [21] [  49/  60] time: 350.8726, loss: 0.00294307
Epoch: [21] [  50/  60] time: 351.1212, loss: 0.00004233
Epoch: [21] [  51/  60] time: 351.3713, loss: 0.00816099
Epoch: [21] [  52/  60] time: 351.6213, loss: 0.00015709
Epoch: [21] [  53/  60] time: 3

Epoch: [23] [  58/  60] time: 385.9547, loss: 0.00166073
Epoch: [23] [  59/  60] time: 386.2032, loss: 0.00000348
[23/50] - ptime: 16.0010 loss: 0.00052871 acc: 0.64000 lr: 0.00081000
Epoch: [24] [   0/  60] time: 387.9457, loss: 0.00033579
Epoch: [24] [   1/  60] time: 388.1955, loss: 0.00000260
Epoch: [24] [   2/  60] time: 388.4460, loss: 0.00003893
Epoch: [24] [   3/  60] time: 388.6947, loss: 0.00001433
Epoch: [24] [   4/  60] time: 388.9434, loss: 0.00119804
Epoch: [24] [   5/  60] time: 389.1930, loss: 0.00000291
Epoch: [24] [   6/  60] time: 389.4455, loss: 0.00011475
Epoch: [24] [   7/  60] time: 389.6934, loss: 0.00002436
Epoch: [24] [   8/  60] time: 389.9422, loss: 0.00005962
Epoch: [24] [   9/  60] time: 390.1907, loss: 0.00002776
Epoch: [24] [  10/  60] time: 390.4412, loss: 0.01245717
Epoch: [24] [  11/  60] time: 390.6896, loss: 0.00007000
Epoch: [24] [  12/  60] time: 390.9397, loss: 0.00018436
Epoch: [24] [  13/  60] time: 391.1895, loss: 0.02014407
Epoch: [24] [  14/

Epoch: [26] [  19/  60] time: 425.5002, loss: 0.00688508
Epoch: [26] [  20/  60] time: 425.7493, loss: 0.00097684
Epoch: [26] [  21/  60] time: 425.9985, loss: 0.00060224
Epoch: [26] [  22/  60] time: 426.2489, loss: 0.00153275
Epoch: [26] [  23/  60] time: 426.4976, loss: 0.00066586
Epoch: [26] [  24/  60] time: 426.7470, loss: 0.00231057
Epoch: [26] [  25/  60] time: 426.9943, loss: 0.00137057
Epoch: [26] [  26/  60] time: 427.2430, loss: 0.00090436
Epoch: [26] [  27/  60] time: 427.4889, loss: 0.00086279
Epoch: [26] [  28/  60] time: 427.7336, loss: 0.00479460
Epoch: [26] [  29/  60] time: 427.9830, loss: 0.00917822
Epoch: [26] [  30/  60] time: 428.2326, loss: 0.03278366
Epoch: [26] [  31/  60] time: 428.4811, loss: 0.00194893
Epoch: [26] [  32/  60] time: 428.7303, loss: 0.00143270
Epoch: [26] [  33/  60] time: 428.9781, loss: 0.01581748
Epoch: [26] [  34/  60] time: 429.2266, loss: 0.00717533
Epoch: [26] [  35/  60] time: 429.4747, loss: 0.00009361
Epoch: [26] [  36/  60] time: 4

Epoch: [28] [  41/  60] time: 463.8374, loss: 0.00017683
Epoch: [28] [  42/  60] time: 464.0862, loss: 0.00069336
Epoch: [28] [  43/  60] time: 464.3350, loss: 0.00005680
Epoch: [28] [  44/  60] time: 464.5834, loss: 0.00008873
Epoch: [28] [  45/  60] time: 464.8318, loss: 0.00002059
Epoch: [28] [  46/  60] time: 465.0821, loss: 0.00005423
Epoch: [28] [  47/  60] time: 465.3337, loss: 0.00029791
Epoch: [28] [  48/  60] time: 465.5833, loss: 0.00011515
Epoch: [28] [  49/  60] time: 465.8335, loss: 0.00003860
Epoch: [28] [  50/  60] time: 466.0826, loss: 0.00002656
Epoch: [28] [  51/  60] time: 466.3350, loss: 0.01388472
Epoch: [28] [  52/  60] time: 466.5801, loss: 0.00032024
Epoch: [28] [  53/  60] time: 466.8289, loss: 0.00002237
Epoch: [28] [  54/  60] time: 467.0774, loss: 0.00017106
Epoch: [28] [  55/  60] time: 467.3269, loss: 0.00002556
Epoch: [28] [  56/  60] time: 467.5780, loss: 0.00048712
Epoch: [28] [  57/  60] time: 467.8285, loss: 0.00005417
Epoch: [28] [  58/  60] time: 4

Epoch: [31] [   2/  60] time: 503.3977, loss: 0.00074971
Epoch: [31] [   3/  60] time: 503.6444, loss: 0.00027791
Epoch: [31] [   4/  60] time: 503.8935, loss: 0.00036056
Epoch: [31] [   5/  60] time: 504.1411, loss: 0.00002220
Epoch: [31] [   6/  60] time: 504.3907, loss: 0.00055971
Epoch: [31] [   7/  60] time: 504.6385, loss: 0.00018867
Epoch: [31] [   8/  60] time: 504.8875, loss: 0.00029199
Epoch: [31] [   9/  60] time: 505.1355, loss: 0.00016970
Epoch: [31] [  10/  60] time: 505.3823, loss: 0.00133084
Epoch: [31] [  11/  60] time: 505.6321, loss: 0.00051713
Epoch: [31] [  12/  60] time: 505.8791, loss: 0.00052908
Epoch: [31] [  13/  60] time: 506.1278, loss: 0.00021523
Epoch: [31] [  14/  60] time: 506.3770, loss: 0.00057828
Epoch: [31] [  15/  60] time: 506.6248, loss: 0.00283157
Epoch: [31] [  16/  60] time: 506.8762, loss: 0.00133582
Epoch: [31] [  17/  60] time: 507.1244, loss: 0.00011881
Epoch: [31] [  18/  60] time: 507.3726, loss: 0.00069702
Epoch: [31] [  19/  60] time: 5

Epoch: [33] [  24/  60] time: 541.7739, loss: 0.00000864
Epoch: [33] [  25/  60] time: 542.0231, loss: 0.00002577
Epoch: [33] [  26/  60] time: 542.2715, loss: 0.00003762
Epoch: [33] [  27/  60] time: 542.5249, loss: 0.00035655
Epoch: [33] [  28/  60] time: 542.7725, loss: 0.00001525
Epoch: [33] [  29/  60] time: 543.0209, loss: 0.00005562
Epoch: [33] [  30/  60] time: 543.2690, loss: 0.00027940
Epoch: [33] [  31/  60] time: 543.5226, loss: 0.00002641
Epoch: [33] [  32/  60] time: 543.7700, loss: 0.00002807
Epoch: [33] [  33/  60] time: 544.0174, loss: 0.00017710
Epoch: [33] [  34/  60] time: 544.2695, loss: 0.00032188
Epoch: [33] [  35/  60] time: 544.5191, loss: 0.00759858
Epoch: [33] [  36/  60] time: 544.7681, loss: 0.00049335
Epoch: [33] [  37/  60] time: 545.0160, loss: 0.00000799
Epoch: [33] [  38/  60] time: 545.2661, loss: 0.00010551
Epoch: [33] [  39/  60] time: 545.5155, loss: 0.00062947
Epoch: [33] [  40/  60] time: 545.7667, loss: 0.00066868
Epoch: [33] [  41/  60] time: 5

Epoch: [35] [  46/  60] time: 580.2495, loss: 0.00010675
Epoch: [35] [  47/  60] time: 580.5030, loss: 0.00004506
Epoch: [35] [  48/  60] time: 580.7552, loss: 0.00002254
Epoch: [35] [  49/  60] time: 581.0157, loss: 0.00003588
Epoch: [35] [  50/  60] time: 581.2645, loss: 0.00000849
Epoch: [35] [  51/  60] time: 581.5149, loss: 0.00041402
Epoch: [35] [  52/  60] time: 581.7649, loss: 0.00043826
Epoch: [35] [  53/  60] time: 582.0141, loss: 0.00003853
Epoch: [35] [  54/  60] time: 582.2632, loss: 0.00000624
Epoch: [35] [  55/  60] time: 582.5147, loss: 0.00003194
Epoch: [35] [  56/  60] time: 582.7646, loss: 0.00004817
Epoch: [35] [  57/  60] time: 583.0258, loss: 0.00009814
Epoch: [35] [  58/  60] time: 583.2776, loss: 0.00000843
Epoch: [35] [  59/  60] time: 583.5295, loss: 0.00004656
[35/50] - ptime: 16.1097 loss: 0.00012675 acc: 0.72000 lr: 0.00072900
Epoch: [36] [   0/  60] time: 585.2887, loss: 0.00000798
Epoch: [36] [   1/  60] time: 585.5384, loss: 0.00037849
Epoch: [36] [   2/

Epoch: [38] [   7/  60] time: 620.0399, loss: 0.00001718
Epoch: [38] [   8/  60] time: 620.2896, loss: 0.00002052
Epoch: [38] [   9/  60] time: 620.5343, loss: 0.00003468
Epoch: [38] [  10/  60] time: 620.7864, loss: 0.00012069
Epoch: [38] [  11/  60] time: 621.0360, loss: 0.00000155
Epoch: [38] [  12/  60] time: 621.2858, loss: 0.00000809
Epoch: [38] [  13/  60] time: 621.5363, loss: 0.00001918
Epoch: [38] [  14/  60] time: 621.7865, loss: 0.00001255
Epoch: [38] [  15/  60] time: 622.0310, loss: 0.00065126
Epoch: [38] [  16/  60] time: 622.2817, loss: 0.00005088
Epoch: [38] [  17/  60] time: 622.5302, loss: 0.00000080
Epoch: [38] [  18/  60] time: 622.7791, loss: 0.00013463
Epoch: [38] [  19/  60] time: 623.0285, loss: 0.00000765
Epoch: [38] [  20/  60] time: 623.2777, loss: 0.00023231
Epoch: [38] [  21/  60] time: 623.5288, loss: 0.00001933
Epoch: [38] [  22/  60] time: 623.7792, loss: 0.00014261
Epoch: [38] [  23/  60] time: 624.0292, loss: 0.00000367
Epoch: [38] [  24/  60] time: 6

Epoch: [40] [  29/  60] time: 658.5049, loss: 0.00001067
Epoch: [40] [  30/  60] time: 658.7542, loss: 0.00001497
Epoch: [40] [  31/  60] time: 659.0044, loss: 0.00000761
Epoch: [40] [  32/  60] time: 659.2536, loss: 0.00001331
Epoch: [40] [  33/  60] time: 659.5065, loss: 0.00006610
Epoch: [40] [  34/  60] time: 659.7567, loss: 0.00000457
Epoch: [40] [  35/  60] time: 660.0054, loss: 0.00001751
Epoch: [40] [  36/  60] time: 660.2581, loss: 0.00005690
Epoch: [40] [  37/  60] time: 660.5095, loss: 0.00001039
Epoch: [40] [  38/  60] time: 660.7586, loss: 0.00000322
Epoch: [40] [  39/  60] time: 661.0109, loss: 0.00001514
Epoch: [40] [  40/  60] time: 661.2622, loss: 0.00000710
Epoch: [40] [  41/  60] time: 661.5114, loss: 0.00001430
Epoch: [40] [  42/  60] time: 661.7623, loss: 0.00004062
Epoch: [40] [  43/  60] time: 662.0106, loss: 0.00001981
Epoch: [40] [  44/  60] time: 662.2598, loss: 0.00001192
Epoch: [40] [  45/  60] time: 662.5137, loss: 0.00000862
Epoch: [40] [  46/  60] time: 6

Epoch: [42] [  51/  60] time: 697.0175, loss: 0.00000773
Epoch: [42] [  52/  60] time: 697.2667, loss: 0.00000222
Epoch: [42] [  53/  60] time: 697.5178, loss: 0.00000152
Epoch: [42] [  54/  60] time: 697.7672, loss: 0.00001156
Epoch: [42] [  55/  60] time: 698.0149, loss: 0.00001183
Epoch: [42] [  56/  60] time: 698.2642, loss: 0.00000860
Epoch: [42] [  57/  60] time: 698.5127, loss: 0.00006229
Epoch: [42] [  58/  60] time: 698.7614, loss: 0.00001078
Epoch: [42] [  59/  60] time: 699.0058, loss: 0.00000459
[42/50] - ptime: 16.0427 loss: 0.00004287 acc: 0.74000 lr: 0.00065610
Epoch: [43] [   0/  60] time: 700.7708, loss: 0.00011721
Epoch: [43] [   1/  60] time: 701.0310, loss: 0.00000427
Epoch: [43] [   2/  60] time: 701.2817, loss: 0.00000248
Epoch: [43] [   3/  60] time: 701.5332, loss: 0.00000584
Epoch: [43] [   4/  60] time: 701.7849, loss: 0.00003378
Epoch: [43] [   5/  60] time: 702.0328, loss: 0.00000385
Epoch: [43] [   6/  60] time: 702.2822, loss: 0.00001073
Epoch: [43] [   7/

Epoch: [45] [  12/  60] time: 736.6552, loss: 0.00001479
Epoch: [45] [  13/  60] time: 736.9045, loss: 0.00000616
Epoch: [45] [  14/  60] time: 737.1522, loss: 0.00000678
Epoch: [45] [  15/  60] time: 737.4005, loss: 0.00008044
Epoch: [45] [  16/  60] time: 737.6479, loss: 0.00000159
Epoch: [45] [  17/  60] time: 737.8954, loss: 0.00000247
Epoch: [45] [  18/  60] time: 738.1450, loss: 0.00013459
Epoch: [45] [  19/  60] time: 738.3921, loss: 0.00006239
Epoch: [45] [  20/  60] time: 738.6420, loss: 0.00000017
Epoch: [45] [  21/  60] time: 738.8915, loss: 0.00000246
Epoch: [45] [  22/  60] time: 739.1406, loss: 0.00000680
Epoch: [45] [  23/  60] time: 739.3877, loss: 0.00000029
Epoch: [45] [  24/  60] time: 739.6377, loss: 0.00000173
Epoch: [45] [  25/  60] time: 739.8858, loss: 0.00000613
Epoch: [45] [  26/  60] time: 740.1370, loss: 0.00011214
Epoch: [45] [  27/  60] time: 740.3852, loss: 0.00000060
Epoch: [45] [  28/  60] time: 740.6344, loss: 0.00002003
Epoch: [45] [  29/  60] time: 7

Epoch: [47] [  34/  60] time: 775.0195, loss: 0.00000416
Epoch: [47] [  35/  60] time: 775.2664, loss: 0.00000175
Epoch: [47] [  36/  60] time: 775.5150, loss: 0.00000836
Epoch: [47] [  37/  60] time: 775.7633, loss: 0.00001123
Epoch: [47] [  38/  60] time: 776.0129, loss: 0.00000485
Epoch: [47] [  39/  60] time: 776.2629, loss: 0.00000837
Epoch: [47] [  40/  60] time: 776.5126, loss: 0.00017417
Epoch: [47] [  41/  60] time: 776.7618, loss: 0.00000562
Epoch: [47] [  42/  60] time: 777.0107, loss: 0.00000336
Epoch: [47] [  43/  60] time: 777.2598, loss: 0.00000174
Epoch: [47] [  44/  60] time: 777.5103, loss: 0.00001099
Epoch: [47] [  45/  60] time: 777.7590, loss: 0.00001040
Epoch: [47] [  46/  60] time: 778.0079, loss: 0.00003303
Epoch: [47] [  47/  60] time: 778.2557, loss: 0.00000026
Epoch: [47] [  48/  60] time: 778.5070, loss: 0.00005315
Epoch: [47] [  49/  60] time: 778.7546, loss: 0.00016019
Epoch: [47] [  50/  60] time: 779.0043, loss: 0.00004775
Epoch: [47] [  51/  60] time: 7

Epoch: [49] [  56/  60] time: 813.0298, loss: 0.00000167
Epoch: [49] [  57/  60] time: 813.2772, loss: 0.00000084
Epoch: [49] [  58/  60] time: 813.5254, loss: 0.00000779
Epoch: [49] [  59/  60] time: 813.7731, loss: 0.00000265
[49/50] - ptime: 15.7355 loss: 0.00002752 acc: 0.73000 lr: 0.00065610
Epoch: [50] [   0/  60] time: 815.4110, loss: 0.00003114
Epoch: [50] [   1/  60] time: 815.6581, loss: 0.00000081
Epoch: [50] [   2/  60] time: 815.9037, loss: 0.00000335
Epoch: [50] [   3/  60] time: 816.1522, loss: 0.00002260
Epoch: [50] [   4/  60] time: 816.3999, loss: 0.00000158
Epoch: [50] [   5/  60] time: 816.6462, loss: 0.00000783
Epoch: [50] [   6/  60] time: 816.8929, loss: 0.00001025
Epoch: [50] [   7/  60] time: 817.1416, loss: 0.00000348
Epoch: [50] [   8/  60] time: 817.3888, loss: 0.00000065
Epoch: [50] [   9/  60] time: 817.6372, loss: 0.00000081
Epoch: [50] [  10/  60] time: 817.8834, loss: 0.00001666
Epoch: [50] [  11/  60] time: 818.1303, loss: 0.00022239
Epoch: [50] [  12/

In [7]:
import pygame

file='/home/huiqy/Music/CloudMusic/All_Time_Low.mp3' #文件名是完整路径名
pygame.mixer.init() #初始化音频
track = pygame.mixer.music.load(file)#载入音乐文件
pygame.mixer.music.play()#开始播放
time.sleep(60)#播放10秒
pygame.mixer.music.stop()#停止播放