In [1]:
from datetime import datetime
import tensorflow as tf
import numpy as np
from PIL import Image
from tqdm import tqdm

from alexnet import AlexNet
from image_loader import load_images
from util import shuffle, get_batch

In [2]:
# Optimisation hyper-parameters.
learning_rate, num_epochs, batch_size = 0.001, 2, 128

# Model hyper-parameters:
#   dropout_rate - value fed to the `dropout` placeholder during training steps
#   num_classes  - width of the one-hot labels (molecule vs. non-molecule)
dropout_rate, num_classes = 0.5, 2

# Evaluate and write tf.summary to disk every `display_step` training steps.
display_step = 20

In [3]:
# Graph inputs. The batch dimension is fixed to `batch_size`, so every batch fed
# to X / Y must contain exactly that many examples (partial batches cannot be fed).
# Image shape is (batch_size, 150, 150, 1) -- presumably height, width, channels
# (single-channel images); TODO confirm against load_images / AlexNet.
X = tf.placeholder(tf.float32, shape=(batch_size, 150, 150, 1))
# One-hot labels, one column per class.
Y = tf.placeholder(tf.float32, shape=(batch_size, num_classes))
# Scalar dropout control fed at run time: `dropout_rate` on training steps and 0
# when computing summaries / test accuracy.
# NOTE(review): whether AlexNet treats this as a drop rate or a keep probability
# is decided in alexnet.py (not visible here) -- confirm that 0 means "no dropout".
dropout = tf.placeholder(tf.float32)

model = AlexNet(X, dropout, num_classes)

# Mean softmax cross-entropy between the model logits and the one-hot labels.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=model.logits, labels=Y))

with tf.name_scope("train"):
    optimizer = tf.train.AdamOptimizer(learning_rate)
    optimization_op = optimizer.minimize(loss)

tf.summary.scalar('loss', loss)

# Accuracy = fraction of the batch whose arg-max prediction matches the arg-max label.
with tf.name_scope("accuracy"):
    predictions = tf.argmax(model.logits, axis=1)
    correct_predictions = tf.equal(predictions, tf.argmax(Y, axis=1))
    accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

# Logged under 'training_accuracy' because the summary op is only evaluated on
# training batches in the training loop.
tf.summary.scalar('training_accuracy', accuracy)

# Single op that evaluates every summary registered above; event files are
# written to ./tensorboard, and `saver` is used to checkpoint the trained model.
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter("./tensorboard")
saver = tf.train.Saver()

In [4]:
# Load the positive (molecule) and negative (non-molecule) image sets from their
# glob patterns. The layout of the returned arrays is defined by
# image_loader.load_images; presumably (n, 150, 150, 1) to match X -- TODO confirm.
molecule_images, non_molecule_images = (
    load_images(pattern)
    for pattern in ("./positives/images/*.png", "./negatives/tiles/*.png")
)

In [5]:
# Build the labelled dataset and hold out a fixed-size test split.
# One-hot labels: column 0 = molecule (positive), column 1 = non-molecule (negative).
n_pos = molecule_images.shape[0]
# Fix: the negative count must come from the negative set. Using
# molecule_images here gave negative_labels the wrong length, so x and y no
# longer lined up whenever the two image sets differ in size.
n_neg = non_molecule_images.shape[0]

positive_labels = np.stack((np.ones(n_pos), np.zeros(n_pos)), axis=1)
negative_labels = np.stack((np.zeros(n_neg), np.ones(n_neg)), axis=1)

# Images and labels are stacked in the same order (positives first), so row i
# of y is the label of row i of x.
x = np.vstack((molecule_images, non_molecule_images))
y = np.vstack((positive_labels, negative_labels))

m = x.shape[0]
assert y.shape[0] == m, "image and label counts disagree"

# Size of the held-out test split (previously the literal 2000).
n_test = 2000
assert m > n_test, "dataset has too few examples for the {}-example test split".format(n_test)

# NOTE(review): util.shuffle is assumed to apply one shared permutation to x
# and y -- confirm in util.py.
x, y = shuffle(x, y)

train_x, train_y = x[:-n_test], y[:-n_test]
test_x, test_y = x[-n_test:], y[-n_test:]

In [6]:
# Number of full batches per pass. Because the X/Y placeholders have a fixed
# batch dimension, any remainder smaller than batch_size is skipped (e.g. with a
# 2000-example test split and batch_size 128, the last 80 test examples are not
# evaluated).
training_batches_per_epoch = train_x.shape[0] // batch_size
test_batches_per_epoch = test_x.shape[0] // batch_size

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    writer.add_graph(sess.graph)

    for epoch in range(num_epochs):
        print("{} epoch number: {}".format(datetime.now(), epoch + 1))
        # Reshuffle the training split at the start of every epoch.
        train_x, train_y = shuffle(train_x, train_y)

        for step in tqdm(range(training_batches_per_epoch)):
            x_batch, y_batch = get_batch(train_x, train_y, step, batch_size)
            sess.run(optimization_op, feed_dict={X: x_batch, Y: y_batch, dropout: dropout_rate})

            if step % display_step == 0:
                # Summaries are computed on the batch just trained on, with dropout
                # fed as 0. NOTE(review): 0 disables dropout only if AlexNet treats
                # `dropout` as a drop rate (not a keep probability) -- confirm.
                summ = sess.run(merged_summary, feed_dict={X: x_batch, Y: y_batch, dropout: 0})
                # Global step index so summaries from later epochs do not overlap.
                writer.add_summary(summ, epoch * training_batches_per_epoch + step)
                writer.flush()

        # Test accuracy = unweighted mean of per-batch accuracies over the full
        # test batches (all batches have the same size, so this equals the
        # per-example accuracy over the evaluated examples).
        test_acc = 0
        test_count = 0
        for step in range(test_batches_per_epoch):
            x_batch, y_batch = get_batch(test_x, test_y, step, batch_size)
            acc = sess.run(accuracy, feed_dict={X: x_batch, Y: y_batch, dropout: 0})
            test_acc += acc
            test_count += 1

        # Guard: a test split smaller than batch_size yields no full batch and
        # would otherwise raise ZeroDivisionError here.
        if test_count:
            print("test accuracy = {:.4f}".format(test_acc / test_count))
        else:
            print("test accuracy: skipped (test split has fewer than {} examples)".format(batch_size))

    saver.save(sess, "./model.ckpt")

2017-12-26 20:41:33.834493 epoch number: 1


  6%|▌         | 8/138 [00:19<05:15,  2.42s/it]

KeyboardInterrupt: 