In [1]:
import numpy as np
import matplotlib.pyplot as plt
import datetime
import tensorflow as tf
import os
from sklearn.model_selection import train_test_split

# Get the data

In [2]:
# Directory holding the dataset, and the archive loaded from it.
data_dir = "./data"
npz_path = os.path.join(data_dir, 'random.npz')
data_dict = np.load(npz_path)

In [7]:
# Pull the 'wavelengths', 'layers' and 'reflectances' entries out of the
# loaded archive. 'layers' becomes the network input (X_all) and
# 'reflectances' the regression target (Y_all).
wavelengths, X_all, Y_all = (
    data_dict[key] for key in ('wavelengths', 'layers', 'reflectances')
)

In [15]:
# Hold out 1% of all samples as the test set, then 1% of the remainder as the
# validation set.
#
# The splits are seeded so that re-running the notebook (for example to resume
# training from a checkpoint after a kernel restart) reproduces the same
# partitions. Without a fixed seed every run shuffles differently, so rows
# that were held out in one run can end up in the training set of the next.
split_seed = 0
X_train, X_test, Y_train, Y_test = train_test_split(
    X_all, Y_all, test_size=1e-2, random_state=split_seed)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=1e-2, random_state=split_seed)

In [8]:
# Function to get a random batch
def random_batch(batch_size, X, Y):
    """Draw a random mini-batch of matching rows from X and Y.

    Row indices are sampled uniformly with replacement via
    ``np.random.randint``, so a batch may contain the same row more than once.

    Args:
        batch_size: number of rows to draw.
        X: array indexed along its first axis; ``X.shape[0]`` bounds the indices.
        Y: array indexed with the same row indices as ``X``.

    Returns:
        Tuple ``(X_batch, Y_batch)`` holding the selected rows of X and Y.
    """
    n_rows = X.shape[0]
    picks = np.random.randint(0, n_rows, size=batch_size)
    return X[picks], Y[picks]

# Make the graph

In [18]:
# Training params
n_epochs = 100     # number of passes of the outer training loop
batch_size = 256   # rows drawn per optimisation step
# Whole batches that fit into the training set; any remainder is dropped.
n_batches = len(X_train) // batch_size
print("Number of Batches per epoch: {}".format(n_batches))

Number of Batches per epoch: 382


In [22]:
# Saving data
# `name` identifies this model variant; it is used for both the checkpoint
# directory and the TensorBoard log directory.
name = "fully_connected_1"
model_dir = os.path.join('.', 'models', name)
# A UTC timestamp suffix gives every run of this cell its own log directory.
now = datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S")
log_dir = os.path.join('.', 'tf_logs', name+"_"+now)
for d in [model_dir, log_dir]:
    # exist_ok=True creates the directory only when it is missing, without the
    # check-then-create race of a separate isdir() test.
    os.makedirs(d, exist_ok=True)

# Checkpoint written during training, the side file that stores the next epoch
# to run, and the location of the final saved model.
checkpoint_path = os.path.join(model_dir, "model_ckpt.ckpt")
checkpoint_epoch_path = checkpoint_path + ".epoch"
final_model_path = os.path.join(model_dir, "final_model")

In [30]:
# Start from an empty default graph so that re-running this cell does not
# stack a second copy of the network on top of the first.
tf.reset_default_graph()

# Architecture of the fully connected stack.
n_hidden_layers = 5
n_hidden_units = 8

# Make the actual neural net
with tf.name_scope("dnn"):
    # Inputs and outputs: one row per sample, column counts taken from the data
    X_in = tf.placeholder(tf.float32, shape=(None, X_all.shape[1]), name="X")
    y_in = tf.placeholder(tf.float32, shape=(None, Y_all.shape[1]), name="Y")

    # Hidden layers: a chain of ELU-activated dense layers named h0, h1, ...
    h_prev = X_in
    for layernum in range(n_hidden_layers):
        h_prev = tf.layers.dense(h_prev, n_hidden_units, activation=tf.nn.elu,
                                 name="h{}".format(layernum))

    # Output: linear layer with one unit per target column
    output = tf.layers.dense(h_prev, Y_all.shape[1], name="output")

with tf.name_scope("loss"):
    # Root-mean-square error over every element of the batch
    error = y_in - output
    rmse = tf.sqrt(tf.reduce_mean(tf.square(error)), name="rmse")
    loss_summary = tf.summary.scalar('rmse_loss', rmse)
    optimizer = tf.train.MomentumOptimizer(learning_rate=1e-3, momentum=0.9,
                                           use_nesterov=True)
    training_op = optimizer.minimize(rmse)

with tf.name_scope("admin"):
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    # Each writer gets its own sub-directory of log_dir. Two FileWriters that
    # share one directory are shown by TensorBoard as a single run, so their
    # curves cannot be told apart.
    train_writer = tf.summary.FileWriter(os.path.join(log_dir, "train"),
                                         tf.get_default_graph())
    test_writer = tf.summary.FileWriter(os.path.join(log_dir, "test"),
                                        tf.get_default_graph())

In [31]:
# Run the training
#
# Trains for n_epochs epochs of n_batches random mini-batches each. After every
# epoch the validation RMSE is evaluated and printed, and the model plus the
# index of the next epoch are checkpointed so that an interrupted run can
# resume where it stopped.
with tf.Session() as sess:
    if os.path.isfile(checkpoint_epoch_path):
        # if the checkpoint file exists, restore the model and load the epoch number
        with open(checkpoint_epoch_path, "rb") as f:
            start_epoch = int(f.read())
        print("Training was interrupted. Continuing at epoch", start_epoch)
        saver.restore(sess, checkpoint_path)
    else:
        start_epoch = 0
        sess.run(init)

    for epoch in range(start_epoch, n_epochs):

        # Run batches
        for batch_index in range(n_batches):
            X_batch, y_batch = random_batch(batch_size, X_train, Y_train)
            _, summary_str = sess.run(
                [training_op, loss_summary],
                feed_dict={X_in: X_batch, y_in: y_batch})
            # Global step: number of batches processed before this one.
            train_writer.add_summary(summary_str, epoch*n_batches+batch_index)

        # Run validation
        rmse_val, summary_str = sess.run(
            [rmse, loss_summary],
            feed_dict={X_in: X_val, y_in: Y_val})
        # Log the validation loss through test_writer rather than train_writer,
        # and at the number of batches trained so far rather than the bare
        # epoch index. Both loss streams then share one step axis and the
        # validation points do not land among the first training steps.
        test_writer.add_summary(summary_str, (epoch + 1) * n_batches)

        if epoch % 1 == 0:
            print("Epoch:", epoch, "\tLoss:", rmse_val)
            saver.save(sess, checkpoint_path)
            # Store the index of the next epoch to run, for resuming.
            with open(checkpoint_epoch_path, "wb") as f:
                f.write(b"%d" % (epoch + 1))

    saver.save(sess, final_model_path)

# Push any buffered summaries to disk so TensorBoard sees the complete run.
train_writer.flush()
test_writer.flush()

Epoch: 0 	Loss: 0.14297067
Epoch: 1 	Loss: 0.12398439
Epoch: 2 	Loss: 0.122547805
Epoch: 3 	Loss: 0.122333206
Epoch: 4 	Loss: 0.12300115
Epoch: 5 	Loss: 0.122082956
Epoch: 6 	Loss: 0.12307665
Epoch: 7 	Loss: 0.122096345
Epoch: 8 	Loss: 0.12175741
Epoch: 9 	Loss: 0.12169238
Epoch: 10 	Loss: 0.12182747
Epoch: 11 	Loss: 0.122002654
Epoch: 12 	Loss: 0.12159643
Epoch: 13 	Loss: 0.12178228
Epoch: 14 	Loss: 0.121598065
Epoch: 15 	Loss: 0.121555924
Epoch: 16 	Loss: 0.121581286
Epoch: 17 	Loss: 0.12156291
Epoch: 18 	Loss: 0.121615976
Epoch: 19 	Loss: 0.1215406
Epoch: 20 	Loss: 0.12151281
Epoch: 21 	Loss: 0.12145732
Epoch: 22 	Loss: 0.12148879
Epoch: 23 	Loss: 0.1215261
Epoch: 24 	Loss: 0.12159227
Epoch: 25 	Loss: 0.12147492
Epoch: 26 	Loss: 0.12151682
Epoch: 27 	Loss: 0.12148507
Epoch: 28 	Loss: 0.1214394
Epoch: 29 	Loss: 0.121696144
Epoch: 30 	Loss: 0.12137703
Epoch: 31 	Loss: 0.12140573
Epoch: 32 	Loss: 0.121356696
Epoch: 33 	Loss: 0.12136726
Epoch: 34 	Loss: 0.121338524
Epoch: 35 	Loss: 0.12

In [None]:
# NOTE(review): disabled leftover cell. It refers to `accuracy`, `X`, `y` and
# `mnist`, none of which are defined in the cells visible in this notebook
# (the placeholders defined above are `X_in`/`y_in` and the metric is `rmse`).
# TODO: rewrite it against X_test/Y_test or remove it.
# # Test accuracy
# with tf.Session() as sess:
#     # if the checkpoint file exists, restore the model and load the epoch number
#     print("Loading model")
#     saver.restore(sess, final_model_path)

#     acc_val = sess.run(accuracy, feed_dict={
#         X: mnist.test.images, 
#         y: mnist.test.labels
#     })
#     print("Accuracy: {}".format(acc_val))