This script trains a neural network on the dataset built during preprocessing. Training relies on three modules:

1. **input_data**: loads the training data created during preprocessing and prepares it to be fed into the neural network.
2. **models**: contains several neural network architectures. Choose a model from this module to train it.
3. **training_utilities**: contains functions to construct a model, set up diagnostic files, and train the model.
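The notebook does not show the internals of `input_data`. As a rough illustration of the "prepares it" step, the sketch below assumes that each training file (such as the `train_2.p` named in the training log) is a pickle holding an `(images, labels)` pair of NumPy arrays. The function name `load_training_file`, that file layout, and the per-image min-max scaling are hypothetical choices, not the module's actual API.

```python
import pickle
import numpy as np

def load_training_file(path):
    """Hypothetical loader. Assumes the pickle holds (images, labels) with
    images shaped (num_images, N, N) and labels (num_images, N, N, nb_classes)."""
    with open(path, "rb") as f:
        images, labels = pickle.load(f)
    images = np.asarray(images, dtype=np.float32)
    # Keras convolution layers expect a trailing channel axis: (num_images, N, N, 1)
    if images.ndim == 3:
        images = images[..., np.newaxis]
    # Min-max scale each image to [0, 1] (an assumed normalization choice)
    mins = images.min(axis=(1, 2, 3), keepdims=True)
    maxs = images.max(axis=(1, 2, 3), keepdims=True)
    images = (images - mins) / np.maximum(maxs - mins, 1e-12)
    return images, np.asarray(labels, dtype=np.float32)
```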

First, we import these modules:

In [1]:
from input_data import *
from models import *
from training_utilities import *

Using TensorFlow backend.


Next, we set the parameters that describe the data and the model.

**parent_dir**: the path to the "stem-learning" code.

**data_dir**: the path to the data created in the preprocessing section.

**sess_name**: a folder named `sess_name` is created in the "results" directory, and all output is stored there.

**N**: the pixel width and height of the input images (the images are assumed to be square).

**k_fac**: a multiplier on the number of channels per layer in the fully convolutional networks (FCNs). Each layer's default channel count is multiplied by k_fac.

**nb_classes**: the number of classes learned at once, including the background. For example, if the data contain only the "2Te" labels, then nb_classes = 2 (2Te and no defect).
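A per-pixel classifier with `nb_classes` outputs is typically trained against one-hot labels, one channel per class. The sketch below assumes the labels start as integer class maps (0 = no defect, 1 = 2Te); whether `input_data` stores them this way is not shown in the notebook, and `to_one_hot` is a hypothetical helper.

```python
import numpy as np

def to_one_hot(label_map, nb_classes):
    """Convert an integer label map of shape (N, N) into (N, N, nb_classes)."""
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(nb_classes, dtype=np.float32)[label_map]

# Toy 3x3 label map: 1 marks a 2Te defect pixel, 0 marks "no defect"
labels = np.array([[0, 0, 1],
                   [0, 1, 0],
                   [0, 0, 0]])
one_hot = to_one_hot(labels, nb_classes=2)
print(one_hot.shape)   # (3, 3, 2)
print(one_hot[0, 2])   # [0. 1.]
```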

In [2]:
parent_dir = '/home/abid/Dropbox/Research/Clarksearch/stem/stem-learning/'
data_dir   = parent_dir + 'data/WSeTe/simulated/parsed_label_2Te/'
sess_name  = '2Te'
N          = 256
k_fac      = 16
nb_classes = 2
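To make the effect of `k_fac = 16` concrete, the sketch below scales a set of per-layer default channel counts and counts the parameters of a single 3x3 convolution. The defaults `[1, 2, 4, 8]` are hypothetical and are not read from `models`.

```python
def scaled_channels(default_channels, k_fac):
    """Multiply each layer's default channel count by k_fac."""
    return [c * k_fac for c in default_channels]

def conv3x3_params(c_in, c_out):
    """Trainable parameters of a 3x3 convolution: weights plus one bias per output channel."""
    return 3 * 3 * c_in * c_out + c_out

# Hypothetical defaults for a four-level encoder (not taken from models.py)
defaults = [1, 2, 4, 8]
print(scaled_channels(defaults, 16))  # [16, 32, 64, 128]
print(conv3x3_params(32, 64))         # 18496 (k_fac = 16, second -> third level)
print(conv3x3_params(16, 32))         # 4640  (same layer with k_fac = 8)
```

For layers whose input and output widths both scale with `k_fac`, doubling `k_fac` roughly quadruples the weight count, so memory use and training time grow quickly with this setting.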

Next, we create the session directory and define the paths of the output files stored in it:

In [3]:
from os import makedirs
sess_dir = parent_dir + "results/" + sess_name + "/"
makedirs(sess_dir, exist_ok=True)

model_weights_fn = sess_dir + "weights.h5"
model_fn         = sess_dir + "model.json"
diagnostics_fn   = sess_dir + "diagnostics.dat"
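The same paths can be built with `pathlib`, which inserts path separators itself, so the result does not depend on `parent_dir` ending in "/". This is a minimal alternative sketch: `session_paths` is a hypothetical helper, and if `construct_model` or `train` expect plain strings, pass `str(path)` to them.

```python
from pathlib import Path

def session_paths(parent_dir, sess_name):
    """Create results/<sess_name>/ under parent_dir and return the output file paths."""
    sess_dir = Path(parent_dir) / "results" / sess_name
    # Equivalent to makedirs(sess_dir, exist_ok=True)
    sess_dir.mkdir(parents=True, exist_ok=True)
    return {
        "sess_dir": sess_dir,
        "model_weights_fn": sess_dir / "weights.h5",
        "model_fn": sess_dir / "model.json",
        "diagnostics_fn": sess_dir / "diagnostics.dat",
    }
```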

Now we create the model and set up a diagnostics file:

In [4]:
model = construct_model(N, k_fac, nb_classes, sess_dir, model_fn, model_weights_fn)
step = setup_diagnostics(diagnostics_fn)
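The notebook does not show what `setup_diagnostics` does. Because it returns a `step`, one plausible reading is that it lets an interrupted session resume from the last recorded step. The sketch below illustrates that idea only; `setup_diagnostics_sketch` and its file format (a header line, then one whitespace-separated record per step beginning with the step number) are hypothetical.

```python
import os

def setup_diagnostics_sketch(diagnostics_fn):
    """Hypothetical stand-in: create the diagnostics file if it does not exist,
    and return the step at which training should (re)start."""
    if not os.path.exists(diagnostics_fn):
        with open(diagnostics_fn, "w") as f:
            f.write("# step accuracy\n")  # assumed header
        return 0
    last_step = -1
    with open(diagnostics_fn) as f:
        for line in f:
            line = line.strip()
            # Skip blank lines and the header; the first field is the step number
            if line and not line.startswith("#"):
                last_step = int(line.split()[0])
    return last_step + 1
```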


Finally, we train the model:

In [None]:
train(step, data_dir, N, nb_classes, model, diagnostics_fn, model_weights_fn)

training step: 0	training file: train_2.p
	grabbing data
	done
	calculating accuracy

Train on 1058 samples, validate on 118 samples
Epoch 1/1
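The log prints "calculating accuracy" before fitting, but the notebook does not show which metric `training_utilities` computes. For per-pixel defect labeling, one common choice is pixel accuracy, and a per-class recall is a useful complement. The sketch below assumes labels and predictions are arrays shaped `(num_images, N, N, nb_classes)` holding one-hot labels and per-class scores; `pixel_accuracy` and `class_recall` are hypothetical helpers.

```python
import numpy as np

def pixel_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted class matches the labeled class."""
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))

def class_recall(y_true, y_pred, cls):
    """Fraction of pixels labeled `cls` that are also predicted as `cls`."""
    true_cls = np.argmax(y_true, axis=-1)
    pred_cls = np.argmax(y_pred, axis=-1)
    mask = true_cls == cls
    if not mask.any():
        return float("nan")  # class absent from the labels
    return float(np.mean(pred_cls[mask] == cls))
```

Because defect pixels are usually a small minority, pixel accuracy can stay high even when most defects are missed. A model that predicts "no defect" everywhere on a 2x2 image containing one 2Te pixel scores 0.75 pixel accuracy but 0.0 recall for the defect class.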
