Now with a dataset handy, we'll use it to fit a model.

This notebook follows closely the `train.py` scripts found in all of the model directories.
First, imports:

In [1]:
import numpy as np
import tensorflow as tf
import sys
import datetime
import os
import time

sys.path.insert(0, '../tfmodels')
import tfmodels

## Let's use a densenet-small
sys.path.insert(0, '../')
from densenet_small import Training

Instantiating and updating the model is simple. First we'll make a bunch of folders to organize the outputs. We'll get back snapshots, logs, and we'll make a few extras for random debugging-related output, and eventually for inference.

In [2]:
!ls

1_dataset_from_image_mask.ipynb    example_data
2_segmentation_cnn_training.ipynb  validating_trained_model.ipynb
deploy_model_to_wsi.ipynb


In [3]:
basedir = './trained_model'
log_dir, save_dir, debug_dir, infer_dir = tfmodels.make_experiment(basedir=basedir)

Creating base experiment directory
Creating ./trained_model/logs/2018_07_27_10_34_32
Creating ./trained_model/snapshots
Creating ./trained_model/debug
Creating ./trained_model/inference


In [4]:
!ls

1_dataset_from_image_mask.ipynb    example_data
2_segmentation_cnn_training.ipynb  trained_model
deploy_model_to_wsi.ipynb	   validating_trained_model.ipynb


Next, set up the path to the dataset and a few constants for training. To keep everything lightweight we'll use an input size of 128px: we crop out a 512px area and resize it by a factor of 0.25. Also set how many iterations to train for, along with the batch size and learning rate (the number of expected classes is passed to the dataset further down):

In [5]:
crop_size = 512
image_ratio = 0.25
record_path = './example_data/image_mask_pairs.tfrecord'
iterations = 1000
batch_size = 4
learning_rate = 1e-4

Let's note a few things. We set `record_path` to the `tfrecord` file created in the previous notebook. This record holds a grand total of 11 image/mask examples, so we'll probably overfit very quickly. Therefore, we choose a low number of `iterations` and a small `batch_size`.

It might not be so bad, for two reasons. The first is the random crop. Each of our 11 examples is $1200 \times 1200$ pixels, leaving plenty of room for randomness to be introduced just by random cropping. The second helpful technique is hidden from us for now. The `tfmodels.TFRecordImageMask` dataset class (optionally) applies color augmentation to each example it loads. The transformations randomly alter the hue, saturation, and brightness of the image within a set percentage of the original. It also randomly applies flips and rotations. These transformations are valid for our problem because they help a small dataset artificially cover more of the distribution of data that exists in the world for this problem. In other words, we never do anything that makes a training image **too** different from a naturally occurring image, like one we might see at inference time.
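
To make the cropping and augmentation idea concrete, here is a minimal NumPy sketch. It is not the `tfmodels.TFRecordImageMask` implementation. The function names, the brightness-only color jitter, and the `max_brightness_delta` parameter are all assumptions made for illustration, and hue and saturation changes are left out for brevity. Note that the geometric transforms are applied to the image and mask together, while the color change touches only the image.

```python
import numpy as np


def count_crop_positions(image_size, crop_size):
    """Number of distinct top-left corners for a square crop of a square image."""
    per_axis = image_size - crop_size + 1
    return per_axis * per_axis


def random_crop_and_augment(image, mask, crop_size, rng, max_brightness_delta=0.1):
    """Illustrative augmentation only (hypothetical helper, not part of tfmodels)."""
    height, width = image.shape[:2]

    # Random crop: the same window is taken from the image and the mask
    top = rng.randint(0, height - crop_size + 1)
    left = rng.randint(0, width - crop_size + 1)
    image = image[top:top + crop_size, left:left + crop_size]
    mask = mask[top:top + crop_size, left:left + crop_size]

    # Random horizontal flip, applied to both
    if rng.rand() < 0.5:
        image, mask = image[:, ::-1], mask[:, ::-1]

    # Random rotation by a multiple of 90 degrees, applied to both
    k = rng.randint(0, 4)
    image, mask = np.rot90(image, k), np.rot90(mask, k)

    # Brightness jitter within a fixed fraction of the original, image only
    scale = 1.0 + rng.uniform(-max_brightness_delta, max_brightness_delta)
    image = np.clip(image.astype(np.float32) * scale, 0.0, 255.0)
    return image, mask


rng = np.random.RandomState(0)
image = rng.randint(0, 256, size=(1200, 1200, 3)).astype(np.uint8)  # stand-in image
mask = rng.randint(0, 5, size=(1200, 1200)).astype(np.uint8)        # stand-in mask

crop_image, crop_mask = random_crop_and_augment(image, mask, 512, rng)
print(crop_image.shape, crop_mask.shape)
print(count_crop_positions(1200, 512))  # crop windows available per image
```

With a 512px crop from a $1200 \times 1200$ image there are $689 \times 689 = 474721$ possible crop windows per image before any flips, rotations, or color changes are counted, although overlapping windows are of course far from independent samples.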

In [6]:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

sess = tf.Session(config=config)

dataset = tfmodels.TFRecordImageMask(
    training_record = record_path,
    sess = sess,
    crop_size = crop_size,
    ratio = image_ratio,
    prefetch = 512, ## How many images to prefetch into memory
    shuffle_buffer = 16,
    n_classes = 5,
    n_threads = 4
)

Dataset TRAINING phase


Now, to instantiate the model. Its `Training` mode was already imported above. It needs a few variables defined in order to set itself up correctly.

We require a fixed shape for the input image at training time. This helps predict how much memory the model will need on the GPU. It also lets us make sure the input is big enough to be down-sampled repeatedly without collapsing to a vector after 2 or 3 convolution-pool layers.
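
The size check can be sketched with a few lines of arithmetic. This is a rough illustration, not code from `densenet_small`. The helper name is hypothetical, and the choice of five 2x reductions is an assumption, picked only because it turns a 128px input into a 4px map (a 32x reduction overall); how the real model splits that reduction across its layers is not stated here.

```python
def spatial_size_after_downsampling(input_size, n_halvings):
    """Side length of a square feature map after repeated 2x down-sampling."""
    size = input_size
    for _ in range(n_halvings):
        if size % 2 != 0:
            raise ValueError('size {} cannot be halved evenly'.format(size))
        size //= 2
    return size


crop_size = 512
image_ratio = 0.25
input_size = int(crop_size * image_ratio)  # 128
x_dims = [input_size, input_size, 3]

# Assumption for illustration: five 2x reductions (32x overall)
n_halvings = 5
for candidate in (128, 64, 32):
    smallest = spatial_size_after_downsampling(candidate, n_halvings)
    print('input {:3d}px -> smallest map {}x{}'.format(candidate, smallest, smallest))
# input 128px -> smallest map 4x4
# input  64px -> smallest map 2x2
# input  32px -> smallest map 1x1
```

Under that assumption a 32px input would already collapse to a single spatial position, which is the situation the fixed `x_dims` is meant to rule out.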

In [7]:
x_dims = [int(crop_size * image_ratio),
          int(crop_size * image_ratio),
          3]
model = Training(
    sess = sess,
    dataset = dataset,
    learning_rate = learning_rate,
    log_dir = log_dir,
    save_dir = save_dir,
    summary_iters = 10, # Log scalars and histograms 
    summary_image_iters = 100, # Log images
    x_dims = x_dims
)

Requesting 4 dense blocks
MINIMIUM DIMENSION:  4
Setting up densenet in training mode
DenseNet Model
Non-linearity: <function selu at 0x7f5d0f653230>
	 x_in (?, 128, 128, 3)
Dense block #0 (dd)
	 Transition Down with k_out= 96
Dense block #1 (dd)
	 Transition Down with k_out= 144
Dense block #2 (dd)
	 Transition Down with k_out= 240
Dense block #3 (dense)
	 Bottleneck:  (?, 4, 4, 528)
	 Transition Up with k_out= 264
Dense block #0 (du)
	 Transition Up with k_out= 96
Dense block #1 (du)
	 Transition Up with k_out= 48
Dense block #2 (du)
Model output y_hat: (?, 128, 128, 5)
Setting up batch norm update ops
Done setting up densenet


`model` inherits from `tfmodels.Segmentation`, so there are a couple of useful methods baked in. One is `train_step`, which does just what it sounds like: it pulls a batch from the dataset, processes the batch, computes gradients, and applies them via the model's optimizer. We can choose the optimizer from any `tf.optimizers`. The default is Adam. The `model` class keeps an internal counter of the number of times `train_step` has been called, which is used to log periodically. By default the current step and loss are printed whenever a log is written.

The second method is `snapshot`. This also does what it sounds like. Given that we instantiated the model with a `save_dir`, we just have to call `model.snapshot()` to save the trainable variables. Be careful not to modify the file with the model code in it, otherwise it might be hard to actually use these snapshots in the future!
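
To show how a step counter, periodic logging, and snapshots can interact, here is a toy stand-in written in plain Python. It is not the `tfmodels.Segmentation` code: `ToyTrainer`, its one-parameter quadratic loss, and the `run` helper are all hypothetical, and a "snapshot" here just records the step and the weight in a list instead of writing a checkpoint.

```python
class ToyTrainer(object):
    """Hypothetical stand-in for a model with train_step() and snapshot()."""

    def __init__(self, learning_rate, summary_iters=10, target=3.0):
        self.weight = 0.0
        self.learning_rate = learning_rate
        self.summary_iters = summary_iters
        self.target = target
        self.global_step = 0   # counts every call to train_step
        self.log_lines = []
        self.snapshots = []

    def train_step(self):
        # Loss is (weight - target)^2; take one plain gradient-descent step
        error = self.weight - self.target
        loss = error ** 2
        self.weight -= self.learning_rate * 2.0 * error
        self.global_step += 1
        # Log whenever the internal counter hits a multiple of summary_iters
        if self.global_step % self.summary_iters == 0:
            self.log_lines.append(
                '[{:07d}] loss={:.3f}'.format(self.global_step, loss))
        return loss

    def snapshot(self):
        # Record the current step and parameter value
        self.snapshots.append((self.global_step, self.weight))


def run(trainer, iterations, snapshot_iterations):
    for itx in range(iterations):
        trainer.train_step()
        if itx % snapshot_iterations == 0:
            trainer.snapshot()
    trainer.snapshot()  # one final snapshot


trainer = ToyTrainer(learning_rate=0.1)
run(trainer, iterations=1000, snapshot_iterations=250)
run(trainer, iterations=1000, snapshot_iterations=250)  # the counter keeps going
print([step for step, _ in trainer.snapshots])
```

Because the counter lives on the object rather than in the loop, the second call starts logging and snapshotting from step 1001 instead of starting over.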

Run the training loop:

In [9]:
# If this cell is run multiple times, the training will pick up where it left off
snapshot_iterations = 250
for itx in range(iterations):
    model.train_step(lr=learning_rate)
    if itx % snapshot_iterations == 0:
        model.snapshot()
        
# Make one final snapshot
model.snapshot()

Snapshotting to [./trained_model/snapshots/densenet.ckpt] step [1001] ./trained_model/snapshots/densenet.ckpt-1001
Done
[0001010] writing scalar summaries (loss=0.432) (lr=1.000000E-04)
[0001020] writing scalar summaries (loss=0.393) (lr=1.000000E-04)
[0001030] writing scalar summaries (loss=0.458) (lr=1.000000E-04)
[0001040] writing scalar summaries (loss=0.429) (lr=1.000000E-04)
[0001050] writing scalar summaries (loss=0.423) (lr=1.000000E-04)
[0001060] writing scalar summaries (loss=0.408) (lr=1.000000E-04)
[0001070] writing scalar summaries (loss=0.393) (lr=1.000000E-04)
[0001080] writing scalar summaries (loss=0.427) (lr=1.000000E-04)
[0001090] writing scalar summaries (loss=0.370) (lr=1.000000E-04)
[0001100] writing scalar summaries (loss=0.387) (lr=1.000000E-04)
[0001100] writing image summaries
[0001110] writing scalar summaries (loss=0.403) (lr=1.000000E-04)
[0001120] writing scalar summaries (loss=0.399) (lr=1.000000E-04)
[0001130] writing scalar summaries (loss=0.434) (lr=1.