In [1]:
import numpy as np

def validate(model, X, Y, crop_size=227):
    """Five-crop top-1 / top-5 accuracy of `model` on (X, Y).

    For every image, five ``crop_size`` x ``crop_size`` crops (four corners and
    the centre) are scored in a single ``model.predict`` call. The per-crop
    scores are summed over the crop axis, and the resulting class ranking is
    compared with the one-hot label.

    Args:
        model: object with a ``predict(batch)`` method. Assumed to return an
            array of shape (5, n_classes) for a batch of 5 crops -- TODO confirm.
        X: sequence of images, each at least ``crop_size`` along its first two
            axes (assumed (height, width, channels) layout).
        Y: sequence of one-hot label vectors, aligned with ``X``.
        crop_size: side length of the square crops; 227 matches the network's
            input layer.

    Returns:
        (top1, top5): fraction of images whose label is the highest-scoring
        class, and fraction whose label is among the five highest-scoring ones.

    Raises:
        ValueError: if ``X`` is empty, ``X`` and ``Y`` differ in length, or an
            image is smaller than ``crop_size``.
    """
    def get_crops(image):
        """Return bottom-left, bottom-right, top-right, top-left and centre crops.

        The result is a float64 array of shape (5, crop_size, crop_size, ...).
        """
        height, width = image.shape[0], image.shape[1]
        if height < crop_size or width < crop_size:
            raise ValueError("image of size %dx%d is smaller than crop_size=%d"
                             % (height, width, crop_size))
        bottom = height - crop_size  # first row of the bottom-aligned crops
        right = width - crop_size    # first column of the right-aligned crops
        mid_row = bottom // 2
        mid_col = right // 2
        crops = [
            image[bottom:bottom + crop_size, :crop_size],                # bottom-left
            image[bottom:bottom + crop_size, right:right + crop_size],   # bottom-right
            # Previously `image[:, ...]`: that kept *every* row and could not
            # fit a crop_size x crop_size slot for images taller than crop_size.
            image[:crop_size, right:right + crop_size],                  # top-right
            image[:crop_size, :crop_size],                               # top-left
            image[mid_row:mid_row + crop_size, mid_col:mid_col + crop_size],  # centre
        ]
        return np.stack(crops).astype(np.float64)

    n_samples = len(X)
    if n_samples == 0:
        raise ValueError("validate() needs at least one sample")
    if len(Y) != n_samples:
        raise ValueError("X and Y must have the same length (%d != %d)"
                         % (n_samples, len(Y)))

    good1 = 0
    good5 = 0
    for image, one_hot in zip(X, Y):
        # Sum the class scores of the five crops -> one score per class.
        scores = np.sum(model.predict(get_crops(image)), axis=0)
        label = np.argmax(one_hot)
        ranked = np.argsort(scores)[::-1]  # class indices, best first
        if label in ranked[:1]:
            good1 += 1
        if label in ranked[:5]:
            good5 += 1
    # float() keeps this a true division even under Python 2, where this cell
    # runs before the `from __future__ import division` in the next cell.
    return (float(good1) / n_samples, float(good5) / n_samples)

In [2]:
from __future__ import division, print_function, absolute_import

import tflearn
import tensorflow as tf
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
from tflearn.data_augmentation import ImageAugmentation
from tflearn.layers.normalization import batch_normalization
from tflearn.activations import softmax
from tflearn.optimizers import Adam

from sklearn.model_selection import train_test_split
import tflearn.datasets.oxflower17 as oxflower17
# Load the oxflower17 dataset, resized to 256x256, with one-hot labels.
X, Y = oxflower17.load_data(resize_pics=(256, 256), one_hot=True)

# Hold out 10% as a test split; the fixed random_state makes it reproducible.
trainX, testX, trainY, testY = train_test_split(
    X, Y, random_state=42, test_size=0.1)


def conv(network, filters, kernel, strides, activation):
    """Batch-normalize `network`, then apply a 2-D convolution.

    `binarize` is not a parameter: it is looked up in the notebook's global
    scope at call time and forwarded to `conv_2d`.
    NOTE(review): stock tflearn `conv_2d` has no `binarize` argument; this
    presumably relies on a patched tflearn build -- confirm.
    """
    normalized = batch_normalization(network)
    return conv_2d(normalized, filters, kernel,
                   strides=strides, activation=activation, binarize=binarize)

def fc(network, n_units, activation):
    """Batch-normalize `network`, then apply a fully connected layer.

    `binarize` is not a parameter: it is looked up in the notebook's global
    scope at call time and forwarded to `fully_connected`.
    NOTE(review): stock tflearn `fully_connected` has no `binarize` argument;
    this presumably relies on a patched tflearn build -- confirm.
    """
    normalized = batch_normalization(network)
    return fully_connected(normalized, n_units,
                           activation=activation, binarize=binarize)

def inc(x):
    """Identity hook applied to every conv layer's filter count below.

    Returning `x` unchanged keeps the original layer widths; change this body
    to rescale all conv widths in one place.
    """
    return x

# Choose model binarization type: None, 'weights' or 'full'.
# Read as a global by conv() and fc() when the network is built below.
binarize = None

# Training schedule: one fit() epoch per iteration, validated after each.
N_EPOCHS = 50
BATCH_SIZE = 64

# set model name (passed to fit() as run_id)
name = 'alexnet-flower17'

# Real-time data preprocessing: featurewise zero-centering, per channel
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Real-time data augmentation: random left-right flips and random 227x227
# crops of the (256x256) training images
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([227, 227], padding=0)

# Building 'AlexNet'
network = input_data(shape=[None, 227, 227, 3],
                     data_preprocessing=img_prep,
                     data_augmentation=img_aug)
network = conv(network, inc(96), 11, 4, 'relu')
network = max_pool_2d(network, 3, strides=2)
network = conv(network, inc(256), 5, 1, 'relu')
network = max_pool_2d(network, 3, strides=2)
network = conv(network, inc(384), 3, 1, 'relu')
network = conv(network, inc(384), 3, 1, 'relu')
network = conv(network, inc(256), 3, 1, 'relu')
network = max_pool_2d(network, 3, strides=2)
network = fc(network, 4096, 'tanh')
network = dropout(network, 0.5)
network = fc(network, 4096, 'tanh')
network = dropout(network, 0.5)
network = fc(network, 17, None)
network = softmax(network)

if binarize is None or binarize == 'weights':
    network = regression(network, optimizer='adam', loss='categorical_crossentropy', learning_rate=0.001)
elif binarize == 'full':
    # Fully binarized nets use a 10x smaller learning rate.
    # NOTE(review): positional args presumably (learning_rate, beta1, beta2).
    adam = Adam(0.0001, 0.9, 0.99)
    network = regression(network, optimizer=adam, loss='categorical_crossentropy', learning_rate=0.0001)
else:
    # Without this, an unknown value silently skipped regression() and the
    # DNN was built on a bare softmax layer.
    raise ValueError("binarize must be None, 'weights' or 'full', got %r" % (binarize,))

# (top-1, top-5) five-crop accuracy on the held-out split after each epoch.
acc = []
model = tflearn.DNN(network, tensorboard_verbose=0)
for epoch in range(N_EPOCHS):
    model.fit(trainX, trainY, n_epoch=1, shuffle=True, show_metric=True, batch_size=BATCH_SIZE,
              run_id=name)
    acc.append(validate(model, testX, testY))

Training Step: 999  | total loss: [1m[32m0.30323[0m[0m | time: 16.451s
| Adam | epoch: 050 | loss: 0.30323 - acc: 0.9090 -- iter: 1216/1224
Training Step: 1000  | total loss: [1m[32m0.32860[0m[0m | time: 16.662s
| Adam | epoch: 050 | loss: 0.32860 - acc: 0.9009 -- iter: 1224/1224
--


In [4]:
# Best (top-1, top-5) pair over all epochs. Tuples compare lexicographically,
# so this picks the epoch with the highest top-1 (ties broken by top-5).
# NOTE(review): the epoch is selected on the same held-out split it is
# reported on, so this figure is optimistically biased.
best_acc = max(acc)
best_acc

(0.7647058823529411, 0.9852941176470589)