This notebook trains a network once the dataset has been constructed. The training uses three modules:

1. **input_data**: This module loads the training data created during preprocessing and prepares it to be fed into the neural network.
2. **models**: This module contains several neural network architectures; choose one of them to train.
3. **training_utilities**: This module contains functions to construct a model, set up diagnostics files, and train the model.

First, we import these modules:

In [1]:
from input_data import *
from models import *
from training_utilities import *

Using TensorFlow backend.


Next, we set the parameters that describe the data and the model we're using:

**parent_dir**: The path to the "stem-learning" code.

**data_dir**: The path to the data created in the preprocessing section.

**sess_name**: The name of the session. A folder with this name is created in the "results" directory, and all output is stored there.

**N**: The pixel width and height of the input images (the images are assumed to be square).

**k_fac**: A multiplier on the number of channels per layer in our FCNs: each layer's default channel count is multiplied by k_fac.

**nb_classes**: The number of labels learned at once. For example, if our data contains only the "2Te" labels, then nb_classes = 2 (2Te and no defect).
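To make the role of `k_fac` and `nb_classes` concrete, here is a minimal sketch. The per-layer default channel counts are hypothetical placeholders (the real defaults live in the `models` module), and the output shape assumes the FCN predicts one score per class for every pixel of an N x N image.

```python
# Hypothetical per-layer default channel counts; the real values are
# defined in the `models` module and may differ.
DEFAULT_CHANNELS = [1, 2, 4, 8]


def scaled_channels(default_channels, k_fac):
    """Multiply each layer's default channel count by k_fac."""
    return [c * k_fac for c in default_channels]


def output_shape(N, nb_classes):
    """Assumed FCN output shape: one score per class for each pixel."""
    return (N, N, nb_classes)


print(scaled_channels(DEFAULT_CHANNELS, 16))  # [16, 32, 64, 128]
print(output_shape(256, 2))                   # (256, 256, 2)
```

Doubling `k_fac` doubles every layer's width, so it is a single knob for trading model capacity against memory and training time.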

In [2]:
parent_dir = '/home/abid/Dropbox/Research/Clarksearch/stem/stem-learning/'
data_dir   = parent_dir + 'data/WSeTe/simulated/parsed_label_2Te/'
sess_name  = '2Te'
N          = 256
k_fac      = 16
nb_classes = 2

The next cell creates the session directory and defines the paths of the files that will store the model architecture, the weights, and the diagnostics:

In [3]:
from os import makedirs
sess_dir = parent_dir + "results/" + sess_name + "/"
makedirs(sess_dir, exist_ok=True)

model_weights_fn = sess_dir + "weights.h5"
model_fn         = sess_dir + "model.json"
diagnostics_fn   = sess_dir + "diagnostics.dat"

Now we create the model and set up a diagnostics file:

In [4]:
model = construct_model(N, k_fac, nb_classes, sess_dir, model_fn, model_weights_fn)
step = setup_diagnostics(diagnostics_fn)

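`setup_diagnostics` returns a value that is stored as `step` and later passed to `train`, which suggests it tells training which step to start from. Its implementation is not shown in this notebook, so the sketch below is only one hypothetical way such a function could work: the name `setup_diagnostics_sketch`, the header line, and the column layout are all assumptions, not the behavior of `training_utilities`.

```python
import os


def setup_diagnostics_sketch(diagnostics_fn):
    """Return the step at which training should (re)start.

    Hypothetical behavior: if the diagnostics file does not exist, create it
    with a header and start at step 0; otherwise resume one step after the
    last step recorded in the file.
    """
    if not os.path.exists(diagnostics_fn):
        with open(diagnostics_fn, "w") as f:
            f.write("# step\tloss\tval_loss\n")  # assumed column layout
        return 0

    last_step = -1
    with open(diagnostics_fn) as f:
        for line in f:
            line = line.strip()
            # Skip the header and blank lines; the first column is the step.
            if line and not line.startswith("#"):
                last_step = int(line.split()[0])
    return last_step + 1
```

Making the step counter recoverable from the file means an interrupted run can pick up where it left off instead of restarting from step 0.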

Finally, we train the model:

In [5]:
train(step, data_dir, N, nb_classes, model, diagnostics_fn)

training step: 0	training file: train_0.p




Train on 1058 samples, validate on 118 samples
Epoch 1/1


KeyboardInterrupt: 
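The log reports 1058 training and 118 validation samples for `train_0.p` (1176 in total) before the run was stopped by hand during the first epoch. Those counts are consistent with holding out the last 10% of the file, which is what Keras does when `fit` is called with `validation_split=0.1`. Whether `train` actually uses that option is an assumption; the helper `split_counts` below is hypothetical and only reproduces the arithmetic.

```python
def split_counts(n_samples, validation_split):
    """Train/validation sizes when the last fraction of samples is held out.

    Follows the Keras `validation_split` rule: the split index is
    int(n_samples * (1 - validation_split)).
    """
    split_at = int(n_samples * (1.0 - validation_split))
    return split_at, n_samples - split_at


print(split_counts(1176, 0.1))  # (1058, 118)
```

Because the split index is truncated with `int`, the validation set gets the remainder, which is why it holds 118 samples rather than 117.6.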