In [1]:
import numpy as np
import tensorflow as tf
import tensornets as nets
import os
from PIL import Image
from tqdm import tqdm, tqdm_notebook

In [2]:
# Sanity check: list every compute device TensorFlow can see in this kernel
# (the recorded output below shows CPU, XLA_CPU and XLA_GPU entries).
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 2945214265914033555
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 10973274766645313657
physical_device_desc: "device: XLA_CPU device"
, name: "/device:XLA_GPU:0"
device_type: "XLA_GPU"
memory_limit: 17179869184
locality {
}
incarnation: 17201280164647085454
physical_device_desc: "device: XLA_GPU device"
]


In [15]:
def getData(path, batchsize = 16):
    """Load one pre-batched dataset from the directory ``batchedData/<path>/``.

    The directory must contain a file named ``labels`` and feature files
    named ``0``, ``1``, ``2``, ... (consecutive integers, no extension).
    Every file is read with ``np.load``.

    Args:
        path: Name of the sub-directory under ``batchedData/``.
        batchsize: Unused. Kept only so existing callers keep working.

    Returns:
        A ``(features, labels)`` tuple, where ``features`` is the list of
        arrays loaded from the numbered files (in index order) and ``labels``
        is whatever ``np.load`` returned for the labels file.
        Returns the integer ``0`` (legacy behaviour) when the labels file
        cannot be opened; callers that unpack the result must check for this.

    Raises:
        Any error raised by ``np.load`` for a feature file that exists but
        cannot be read (for example a corrupt file). Such errors used to be
        swallowed, which silently truncated the dataset.
    """
    base_dir = "batchedData/" + path
    labels_path = base_dir + "/labels"

    try:
        labels = np.load(labels_path)
    except OSError:
        # Missing/unreadable labels file: keep the historical "print and
        # return 0" contract, but no longer hide unrelated exceptions
        # (KeyboardInterrupt, MemoryError, ...).
        print("No labels found", labels_path)
        return 0

    features = []
    i = 0
    while True:
        try:
            features.append(np.load(base_dir + '/' + str(i)))
        except FileNotFoundError:
            # The first missing index marks the end of the dataset.
            break
        i += 1
    return features, labels

In [16]:
# Load the validation split once up front; it is reused after every training shard.
# NOTE(review): getData returns 0 (not a tuple) when the labels file is missing,
# in which case this unpacking raises TypeError.
valid_features, valid_labels = getData('preprocess_validation.p')

In [17]:
# NOTE(review): this bare expression displays the entire list of loaded arrays;
# a summary such as len(valid_features) would keep the output short.
valid_features

[array([[[175, 171, 168],
         [176, 172, 169],
         [172, 168, 165],
         ...,
         [129, 119, 107],
         [131, 121, 111],
         [132, 122, 112]],
 
        [[172, 168, 165],
         [174, 170, 167],
         [174, 170, 167],
         ...,
         [130, 120, 108],
         [137, 127, 117],
         [134, 124, 114]],
 
        [[172, 168, 165],
         [174, 170, 167],
         [172, 168, 165],
         ...,
         [130, 120, 108],
         [132, 122, 112],
         [140, 130, 120]],
 
        ...,
 
        [[161, 146, 125],
         [160, 151, 136],
         [161, 154, 142],
         ...,
         [ 97,  79,  66],
         [104,  88,  74],
         [138, 123, 106]],
 
        [[152, 137, 116],
         [134, 125, 110],
         [144, 137, 126],
         ...,
         [133,  97,  71],
         [135, 105,  79],
         [127,  98,  68]],
 
        [[104,  94,  84],
         [115, 105,  95],
         [106,  97,  86],
         ...,
         [130,  98,  73],
  

In [18]:
# Graph inputs: a batch of 224x224 3-channel images and the matching 5-way targets.
# The leading None leaves the batch dimension unspecified so any batch size can be fed.
# NOTE(review): y is presumably one-hot encoded (accuracy takes argmax over axis 1) —
# confirm against the stored label files.
x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input_x')
y = tf.placeholder(tf.float32, shape=(None, 5), name='output_y')

In [19]:
### HYPER-PARAMETERS
learning_rate = 0.00001  # step size handed to the Adam optimizer
epochs = 6  # number of full passes over all training shards
batch_size = 16  # examples fed per sess.run call (training and validation)

In [20]:
# Build VGG19 on the input placeholder with a 5-class output head.
# NOTE(review): is_training=True is fixed when the graph is built, and this same graph
# is used for the validation accuracy in the training cell — confirm whether
# training-only behaviour (e.g. dropout) should be disabled for evaluation.
logits = nets.VGG19(x, is_training=True, classes=5)
# Re-export the network output under the stable op name 'logits'.
model = tf.identity(logits,name='logits')
# NOTE(review): tf.losses.softmax_cross_entropy expects unscaled (pre-softmax) values as
# its second argument; verify that nets.VGG19 returns those rather than probabilities
# — TODO confirm against the tensornets version in use.
loss = tf.losses.softmax_cross_entropy(y,logits)
train = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(loss)

# Fraction of examples in the fed batch whose arg-max prediction matches the arg-max target.
correct_pred = tf.equal(tf.argmax(model,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

In [21]:
# Print the name and shape of each intermediate tensor in the network's scope.
logits.print_outputs()

Scope: vgg19
conv1/1/conv/BiasAdd:0 (?, 224, 224, 64)
conv1/1/Relu:0 (?, 224, 224, 64)
conv1/2/conv/BiasAdd:0 (?, 224, 224, 64)
conv1/2/Relu:0 (?, 224, 224, 64)
conv1/pool/MaxPool:0 (?, 112, 112, 64)
conv2/1/conv/BiasAdd:0 (?, 112, 112, 128)
conv2/1/Relu:0 (?, 112, 112, 128)
conv2/2/conv/BiasAdd:0 (?, 112, 112, 128)
conv2/2/Relu:0 (?, 112, 112, 128)
conv2/pool/MaxPool:0 (?, 56, 56, 128)
conv3/1/conv/BiasAdd:0 (?, 56, 56, 256)
conv3/1/Relu:0 (?, 56, 56, 256)
conv3/2/conv/BiasAdd:0 (?, 56, 56, 256)
conv3/2/Relu:0 (?, 56, 56, 256)
conv3/3/conv/BiasAdd:0 (?, 56, 56, 256)
conv3/3/Relu:0 (?, 56, 56, 256)
conv3/4/conv/BiasAdd:0 (?, 56, 56, 256)
conv3/4/Relu:0 (?, 56, 56, 256)
conv3/pool/MaxPool:0 (?, 28, 28, 256)
conv4/1/conv/BiasAdd:0 (?, 28, 28, 512)
conv4/1/Relu:0 (?, 28, 28, 512)
conv4/2/conv/BiasAdd:0 (?, 28, 28, 512)
conv4/2/Relu:0 (?, 28, 28, 512)
conv4/3/conv/BiasAdd:0 (?, 28, 28, 512)
conv4/3/Relu:0 (?, 28, 28, 512)
conv4/4/conv/BiasAdd:0 (?, 28, 28, 512)
conv4/4/Relu:0 (?, 28, 28, 5

In [22]:
# Print the layer, weight and parameter totals for the network's scope.
logits.print_summary()

Scope: vgg19
Total layers: 19
Total weights: 114
Total parameters: 418,772,175


In [23]:
def batch_features_labels(features, labels, batch_size):
    """
    Lazily yield aligned (features, labels) slices of at most batch_size items.

    Slices are taken at offsets 0, batch_size, 2*batch_size, ... over
    len(features); the final slice is shorter when the length is not an
    exact multiple of batch_size.
    """
    total = len(features)
    for lo in range(0, total, batch_size):
        hi = lo + batch_size
        if hi > total:
            # Clamp the last window to the end of the data.
            hi = total
        yield features[lo:hi], labels[lo:hi]

In [24]:
def load_preprocess_training_batch(batch_id, batch_size):
    """
    Read training shard <batch_id> via getData and return a generator over it
    that yields (features, labels) chunks of <batch_size> items or fewer.
    """
    shard_name = 'preprocess_batch_%s.p' % (batch_id,)
    loaded = getData(shard_name)
    features, labels = loaded

    # Hand the shard to the generic batching generator.
    return batch_features_labels(features, labels, batch_size)

In [25]:
save_model_path = './image_classification'
# Number of on-disk training shards, loaded as preprocess_batch_1.p .. preprocess_batch_<n>.p
n_batches = 4

print('Training...')
with tf.Session() as sess:
    # Initializing the variables
    sess.run(tf.global_variables_initializer())
    print('global_variables_initializer ... done ...')
    # Run the model's pretrained-weights op after the generic initializer
    # (the order matters if pretrained() assigns over the initialized variables).
    sess.run(logits.pretrained())
    print('model.pretrained ... done ... ')

    # Training cycle
    print('starting training ... ')
    for epoch in range(epochs):
        # Loop over all training shards
        for batch_i in range(1, n_batches + 1):
            for batch_features, batch_labels in load_preprocess_training_batch(batch_i, batch_size):
                sess.run(train, {x: batch_features, y: batch_labels})

            print('Epoch {:>2}, Batch {}:  '.format(epoch + 1, batch_i), end='')

            # Mean accuracy over the whole validation set. Each mini-batch
            # accuracy is weighted by the number of examples in that batch, so
            # a shorter final batch is counted correctly. (Dividing the summed
            # batch accuracies by len/batch_size is only right when the set
            # size is an exact multiple of batch_size.)
            n_valid = len(valid_features)
            weighted_acc = 0.0
            for batch_valid_features, batch_valid_labels in batch_features_labels(valid_features, valid_labels, batch_size):
                batch_acc = sess.run(accuracy, {x: batch_valid_features, y: batch_valid_labels})
                weighted_acc += batch_acc * len(batch_valid_features)

            # An empty validation set has no defined accuracy; report NaN instead of dividing by zero.
            valid_acc = weighted_acc / n_valid if n_valid else float('nan')
            print('Validation Accuracy: {:.6f}'.format(valid_acc))

    # Save Model
    saver = tf.train.Saver()
    save_path = saver.save(sess, save_model_path)


Training...
global_variables_initializer ... done ...
model.pretrained ... done ... 
starting training ... 
Epoch  1, Batch 1:  Validation Accuracy: 0.133333
Epoch  1, Batch 2:  

KeyboardInterrupt: 