



Built using guidance from https://arxiv.org/pdf/1512.09300.pdf

Features:
  * Uses ELU activations
  * Deconvolution uses upscaling unpool layer before affine operator, rather than spacing with zeros
  * Batch normalization after each transformation
  * Dropout layer after activation
  * abs-sum image loss rather than cross-entropy loss

# Summary

Best result so far is 10 epochs of the first training batch, with the prior and the prediction weighted equally.

Running another 10 epochs doesn't hurt the similarity results, but it does make the reconstructions worse.

Not quite as good after 10 epochs with regularization=0.1 (down-weighted prior)

In [31]:
import tensorflow as tf
 
print("running TensorFlow version {}".format(tf.__version__))

running TensorFlow version 1.9.0


In [32]:
# Control memory usage

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

In [33]:
# Report OOM details

run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)


# Import data

In [34]:
import os
import re

from operator import itemgetter

In [61]:
LOG_ROOT = '../../data/'

RUN_NAME = 'vae/vae_016'
SUMMARY_DIR = os.path.join(LOG_ROOT, 'logs', RUN_NAME)
MODEL_DIR = os.path.join(LOG_ROOT, 'models', RUN_NAME)
MODEL_GRAPH = os.path.join(MODEL_DIR, 'vae.meta')
MODEL_PREFIX = os.path.join(MODEL_DIR, 'vae')

DATA_SIZE = 'all_packs_14'
EXPT_NAME = 'expt_016'
VIS_NAME = 'vis_001'
DATA_ROOT = '/var/data/processed'

EXPT_DIR = os.path.join(DATA_ROOT, EXPT_NAME, 'data')
TFRECORDS_DIR = os.path.join('/var/data/original/tfrecords/', DATA_SIZE)
JSON_DIR = os.path.join('/var/data/original/labels/', "all_data_14.json")
EMBEDDINGS_DIR = os.path.join(EXPT_DIR, 'embeddings')
GOLDEN_EMBEDDINGS_DIR = os.path.join(EXPT_DIR, 'golden_embeddings')
ENCODED_IMAGES_DIR =  os.path.join(EXPT_DIR, 'images')

In [62]:
train_256_pattern = re.compile('^train_(?P<block_id>[0-9]{3}).tfrecords')
validate_256_pattern = re.compile('^validate_(?P<block_id>[0-9]{3}).tfrecords')
test_256_pattern = re.compile('^test_(?P<block_id>[0-9]{3}).tfrecords')
#golden_256_pattern1 = re.compile('golden_(?P<block_id>[0-9]{3})_src.tfrecords')
#golden_256_pattern2 = re.compile('golden_(?P<block_id>[0-9]{3})_dr.tfrecords')


In [63]:

ALL_TFRECORDS = os.listdir(TFRECORDS_DIR)
def get_sorted_records(pattern, directory):
     return [ \
         os.path.join(directory, _file) \
         for _file in \
         sorted([_m[0] for _m in \
             [pattern.match(_f) for _f in os.listdir(directory)] if _m]) \
     ]

#GOLDEN_TFRECORDS_SRC = get_sorted_records(golden_256_pattern1, TFRECORDS_DIR)
#GOLDEN_TFRECORDS_DR = get_sorted_records(golden_256_pattern2, TFRECORDS_DIR)
TRAIN_TFRECORDS = get_sorted_records(train_256_pattern, TFRECORDS_DIR)
VALIDATE_TFRECORDS = get_sorted_records(validate_256_pattern, TFRECORDS_DIR)
TEST_TFRECORDS = get_sorted_records(test_256_pattern, TFRECORDS_DIR)

In [64]:
TRAIN_TFRECORDS, VALIDATE_TFRECORDS, TEST_TFRECORDS

(['/var/data/original/tfrecords/all_packs_14/train_001.tfrecords'],
 ['/var/data/original/tfrecords/all_packs_14/validate_001.tfrecords'],
 ['/var/data/original/tfrecords/all_packs_14/test_001.tfrecords'])

In [65]:
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(SUMMARY_DIR, exist_ok=True)

In [66]:
def _decode(serialized_example):
    '''Parses an image and label from the given `serialized_example`
    '''
    features = tf.parse_single_example(
        serialized_example,
        features={
            'filename': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
            #'view': tf.FixedLenFeature([], tf.string),
            #'gender': tf.FixedLenFeature([], tf.string),
            #'age': tf.FixedLenFeature([], tf.int64),
            'labels': tf.FixedLenSequenceFeature( [], dtype=tf.int64, default_value=-1,allow_missing=True)
            })
       

    # Convert from a scalar string tensor
    filename = tf.cast(features['filename'], tf.string)
    image = tf.decode_raw(features['image'], tf.float32)
    #view = tf.cast(features['view'], tf.string)
    #gender = tf.cast(features['gender'], tf.string)
    #age = tf.cast(features['age'], tf.int32)
    labels = tf.cast(features['labels'], tf.int32)
    
    return filename, image, labels

def _filter(filename, image, labels):
    
    sub_string = tf.substr(filename,3,1)
    return tf.equal(sub_string, "0")

def _augment(filename, image, labels):
    '''Placeholder for data augmentation
    '''
    image = tf.reshape(image, [256, 256, 1])
    return filename, image, labels


def _normalize(filename, image, labels):
    '''Convert `image` from [0, 255] -> [-0.5, 0.5] floats
    '''
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    return filename, image, labels

In [67]:
def inputs(filenames, batch_size, num_epochs, num_shards, shard_index):
    ''' Reads input data num_epochs times or forever if num_epochs is None
        returns dataset, iterator pair
    '''

    with tf.name_scope('input'):
        # TFRecordDataset opens a binary file and reads one record at a time.
        # `filename` could also be a list of filenames, which will be read in order.
       
        
        dataset = tf.data.TFRecordDataset(filenames)
        
        # The map transformation takes a function and applies it to every element
        # of the dataset.
        
        dataset = dataset.map(_decode)
        dataset = dataset.shard(num_shards, shard_index)
        #dataset = dataset.filter(_filter)
        dataset = dataset.map(_augment)
        dataset = dataset.map(_normalize)

        # The shuffle transformation uses a finite-sized buffer to shuffle elements
        # in memory. The parameter is the number of elements in the buffer. For
        # completely uniform shuffling, set the parameter to be the same as the
        # number of elements in the dataset.
        dataset = dataset.shuffle(1000 + 3 * batch_size)

        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()

    return dataset, iterator

In [68]:
from collections import OrderedDict


## Define label keys
LABEL_KEYS = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pleural_Thickening",
    "Pneumonia",
    "Pneumothorax",
]

nb_label_keys = len(LABEL_KEYS)
label_keys_by_id = LABEL_KEYS
ids_by_label_key = OrderedDict([(k,v) for k, v in enumerate(label_keys_by_id)])

In [69]:
def count_labels():
    
    with open(JSON_DIR, 'r') as j:
        docs = json.load(j)

    label2score = {key: list([]) for key in LABEL_KEYS}
    count = {key: 0 for key in LABEL_KEYS}    
    total = 0
    
    for doc in docs:

        for i,l in enumerate(doc['labels']):
            if(l == 1):
                label = ids_by_label_key[i]
                count[label] +=1
                total += 1
    return count, total

c, total = count_labels()
print(c)

WEIGHTS = []
for key in c:
    WEIGHTS.append((total-c.get(key))/(total))
#weights = tf.constant(weights)
print(WEIGHTS)


{'Atelectasis': 11535, 'Cardiomegaly': 2772, 'Consolidation': 4667, 'Edema': 2303, 'Effusion': 13307, 'Emphysema': 2516, 'Fibrosis': 1686, 'Hernia': 227, 'Infiltration': 19870, 'Mass': 5746, 'Nodule': 6323, 'Pleural_Thickening': 3385, 'Pneumonia': 1353, 'Pneumothorax': 5298}
[0.8575714920728997, 0.9657727070677138, 0.9423741788906999, 0.9715636884476713, 0.835691707413444, 0.9689336691855583, 0.9791821010520078, 0.9971971156220675, 0.7546550106188571, 0.9290512174643157, 0.9219267051908925, 0.9582036844964686, 0.9832938213068603, 0.9345829011705438]


# VAE Model

In [70]:
# training parameters
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
NUM_EPOCHS = 20

DROPOUT = 0.7
REGULARIZATION1 = 0.1
REGULARIZATION2 = 0.5

DISPLAY_EVERY = 100

## Loggers

In [71]:
def variable_summary(x, name):
    with tf.variable_scope(name):
        mean = tf.reduce_mean(x)
        tf.summary.scalar('mean', mean)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(x - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(x))
        tf.summary.scalar('min', tf.reduce_min(x))
        tf.summary.histogram('histogram', x)

## Component layers

In [72]:
# unpool operation doesn't yet exist in TF

def unpool_op(x, stride, name='unpool'):

    with tf.name_scope(name) as scope:

        if stride==1:
            return x

        shape = x.get_shape().as_list()
        dim = len(shape[1:-1])
        out = (tf.reshape(x, [-1] + shape[-dim:]))
        for i in range(dim, 0, -1):
            out = tf.concat([out]*stride, i)
        out_size = [-1] + [s * stride for s in shape[1:-1]] + [shape[-1]]
        out = tf.reshape(out, out_size, name=scope)
    return out

In [73]:
def convolution_layer(x, dims, stride, train, bias=None, name='conv', activation=tf.nn.elu):

    with tf.variable_scope(name):

        # Parameters
        weights = tf.get_variable('w', dims,
                    initializer=tf.contrib.layers.xavier_initializer())

        if bias is not None:
            biases = tf.get_variable('b', [dims[-1]],
                        initializer=tf.random_normal_initializer())

        # Layer structure
        if bias is None:
             conv = tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1], padding='SAME', name='conv')
        else:
             conv = tf.nn.bias_add(tf.nn.conv2d(x, weights, strides=[1, stride, stride, 1],
                        padding='SAME'), biases, name='conv')
        normalized = tf.layers.batch_normalization(conv, axis=3, training=train,
                    name='spatial_batch_norm')        
        activations = activation(normalized, name='activation')
        activations = tf.layers.dropout(activations, rate=DROPOUT, training=train, name='dropout')

        # Variable summaries
        variable_summary(weights, 'weights')
        if bias is not None:
            variable_summary(biases, 'biases')
        tf.summary.histogram('pre-activations', normalized)
        tf.summary.histogram('activations', activations)

        return activations

In [74]:
def deconvolution_layer(x, dims, stride, train, bias=None, name='deconv', activation=tf.nn.elu):

    with tf.variable_scope(name):

        # Parameters
        weights = tf.get_variable('w', dims,
                    initializer=tf.contrib.layers.xavier_initializer())
        if bias is not None:
            biases = tf.get_variable('b', [dims[-1]],
                        initializer=tf.random_normal_initializer())

        # Layer structure
        unpool = unpool_op(x, stride, name='unpool')
        if bias is None:
            deconv = tf.nn.conv2d(unpool, weights, strides=[1, 1, 1, 1], padding='SAME', name='deconv')
        else:
            deconv = tf.nn.bias_add(tf.nn.conv2d(unpool, weights, strides=[1, 1, 1, 1],
                        padding='SAME'), biases, name='deconv')
        normalized = tf.layers.batch_normalization(deconv, axis=3, training=train,
                    name='spatial_batch_norm')        
        activations = activation(deconv, name='activation')
        activations = tf.layers.dropout(activations, rate=DROPOUT, training=train, name='dropout')

        # Variable summaries
        variable_summary(weights, 'weights')
        if bias is not None:
            variable_summary(biases, 'biases')
        tf.summary.histogram('pre-activations', deconv)
        tf.summary.histogram('activations', activations)

        return activations

In [75]:
def dense_layer(x, dims, train, bias=None, name='fc', activation=tf.nn.elu):

    with tf.variable_scope(name):

        # Parameters
        weights = tf.get_variable('w', dims,
                    initializer=tf.contrib.layers.xavier_initializer())
        if bias is not None:
            biases = tf.get_variable('b', [dims[-1]],
                        initializer=tf.random_normal_initializer())

        # Layer structure
        if bias is None:
            dense = tf.matmul(x, weights, name='dense')
        else:
            dense = tf.nn.bias_add(tf.matmul(x, weights), biases, name='dense')

        normalized = tf.layers.batch_normalization(dense, axis=1, training=train,
                    name='batch_norm')
        activations = activation(normalized, name='activation')

        # Variable summaries
        variable_summary(weights, 'weights')
        if bias is not None:
            variable_summary(biases, 'biases')
        tf.summary.histogram('pre-activations', normalized)
        tf.summary.histogram('activations', activations)

        return activations

In [76]:
def gaussian_sample(mean, stddev, name):

    with tf.variable_scope(name):

        # mean is unconstrained; stddev must be strictly positive
        stddev = 1e-6 + tf.nn.softplus(stddev)

        # actually sample
        z = mean + stddev * tf.random_normal(tf.shape(mean), 0, 1, dtype=tf.float32)

        # Variable summaries
        tf.summary.histogram('mean', mean)
        tf.summary.histogram('stddev', stddev)
        tf.summary.histogram('z', z)

        return mean, stddev, z

## Loss

In [115]:
def evaluate(x, xhat, mu, sigma, logits_14, labels, weights):

    with tf.variable_scope('loss'):
        # Structure
        pred = tf.losses.absolute_difference(x, xhat,
                reduction=tf.losses.Reduction.MEAN)

        # offsetx = x + 0.5
        # pred = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        #         labels=offsetx, logits=xhat))

        KLdiv = 0.5 * tf.reduce_mean(tf.square(mu) + \
                    tf.square(sigma) - tf.log(1e-8 + tf.square(sigma)) - 1)
        
        labels = tf.cast(labels, tf.float32)
        
        
        
        #ce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = logits_14, labels = labels))
        #ce =  tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits_14, weights=(9*labels+1))
        #ce = tf.reduce_sum(tf.nn.weighted_cross_entropy_with_logits(logits=logits_14, targets=labels, pos_weight=weights))
        #ce =  tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits_14)
        #ce =  tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits_14, weights=x)
        
        weighted_logits = tf.multiply(logits_14, weights) # shape [batch_size, 14]
        ce = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = weighted_logits, labels = labels))
        
        
        loss = tf.add(REGULARIZATION1 * KLdiv, pred)
        loss = tf.add(loss, REGULARIZATION2 * ce)

        #loss = tf.add(REGULARIZATION * KLdiv, pred)
        # Summaries
        tf.summary.scalar('prediction', pred)
        tf.summary.scalar('prior', KLdiv)
        tf.summary.scalar('cross_entropy', ce)
        tf.summary.scalar('loss', loss)

    return loss, pred, KLdiv, ce

## Autoencoder

In [103]:
def encoder(img, train):

    with tf.variable_scope('encoder'):

        # convolution
        conv1 = convolution_layer(img, [5, 5, 1, 64], 2, train, name='conv1')
        conv2 = convolution_layer(conv1, [5, 5, 64, 128], 2, train, name='conv2')
        conv3 = convolution_layer(conv2, [5, 5, 128, 256], 2, train, name='conv3')

        # transition
        conv3 = tf.reshape(conv3, [-1, 32*32*256], name='reshape1')

        # dense output
        fc1 = dense_layer(conv3, [32*32*256, 64], train,
                activation=tf.identity, name='fc1')

        # sample
        mu, sigma, z = gaussian_sample(fc1[:, :32], fc1[:, 32:], name='output')

    return mu, sigma, z

In [104]:
def transform(z, train):
    
    with tf.variable_scope('transfrom'):
        
        
        dense1 = dense_layer(z, [32, 32], train,
                activation=tf.identity, name='fc2')
        logits = dense_layer(dense1, [32, 14], train,
                activation=tf.identity, name='fc4')
        
    return logits

In [105]:
def decoder(z, train):

    with tf.variable_scope('decoder'):

        # dense input
        fc1 = dense_layer(z, [32, 32*32*256], train, name='fc1')

        # transition
        fc1 = tf.reshape(fc1, [-1, 32, 32, 256], name='reshape1')

        # deconvolution
        deconv1 = deconvolution_layer(fc1, [5, 5, 256, 128], 2, train,
                            name='deconv1')
        deconv2 = deconvolution_layer(deconv1, [5, 5, 128, 64], 2, train, 
                            name='deconv2')
        deconv3 = deconvolution_layer(deconv2, [5, 5, 64, 32], 2, train,
                            name='deconv3')
        logits = deconvolution_layer(deconv3, [5, 5, 32, 1], 1, train,
                            activation=tf.identity, name='logits')

        # put into image range for display
        with tf.name_scope('range'):
            xhat = 0.5 * tf.nn.tanh(logits)

    return xhat, logits

In [106]:
def decoder_alternative(z, train):

    with tf.variable_scope('decoder'):

        # dense input
        fc1 = dense_layer(z, [32, 32*32*256], train, name='fc1')

        # transition
        fc1 = tf.reshape(fc1, [-1, 32, 32, 256], name='reshape1')

        # deconvolution
        deconv1 = deconvolution_layer(fc1, [5, 5, 256, 128], 2, train,
                            name='deconv1')
        deconv1a = deconvolution_layer(deconv1, [5, 5, 128, 128], 2, train,
                            name='deconv1a')
        deconv2 = deconvolution_layer(deconv1a, [5, 5, 128, 64], 2, train, 
                            name='deconv2')
        deconv3 = deconvolution_layer(deconv2, [5, 5, 64, 32], 2, train,
                            name='deconv3')
        logits = deconvolution_layer(deconv3, [5, 5, 32, 1], 1, train,
                            activation=tf.identity, name='logits')

        # put into image range for display
        with tf.name_scope('range'):
            xhat = 0.5 * tf.nn.tanh(logits)

    return xhat, logits

# Build and run

## Initialize

In [107]:
import sys
import os
import argparse
from datetime import datetime

import logging
logging.basicConfig(
    datefmt="%Y-%m-%dT%H:%M:%S%z",
    format="%(asctime)s [train/initialize] %(levelname)-8s %(message)s",
    level=logging.INFO
)

NOW_STR = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S%z")
RUN_DESC = "cross-entropy loss, 256x256 images, no bias"
RANDOM_SEED = 42


In [112]:
def initialize():

    logging.info("initializing run: {}".format(RUN_NAME))

    # write a note regarding this run
    os.makedirs(SUMMARY_DIR, exist_ok=True)
    with open(os.path.join(SUMMARY_DIR, "description.txt"), 'w') as fh:
        fh.write(NOW_STR+" "+RUN_DESC)

    # Control memory usage
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # Report OOM details
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)

    # Build graph
    with tf.Graph().as_default() as graph:

        # Repeatable results
        tf.set_random_seed(RANDOM_SEED)

        # Get Data
        train_dataset, train_iterator = inputs(filenames=TRAIN_TFRECORDS,
                batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, num_shards = 1, shard_index = 0)

        # Data placeholder
        data_handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            data_handle, train_dataset.output_types, train_dataset.output_shapes)
        
        

        filename, image, labels = iterator.get_next()
        weights = tf.constant(WEIGHTS)
        # Train/validate flag
        train1 = tf.placeholder(tf.bool)
        train2 = tf.placeholder(tf.bool)

        # Global counter
        global_step = tf.train.get_or_create_global_step(graph)

        # Dropout
        dropout = tf.placeholder(tf.float32)

        # Autoencoder
        mu, sigma, z = encoder(image, train1)
        pred_labels = transform(mu, train2)
        xhat, logits = decoder(z, train1)
        loss, recon, reg, ce = evaluate(image, xhat, mu, sigma, pred_labels, labels, weights)
        
        # Training branch - control dependencies so batchnorm params are updated
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)\
                .minimize(loss, global_step=global_step, name='optimizer')

        # Log output for Tensorboard
        merged = tf.summary.merge_all()
        train_summary_logger = tf.summary.FileWriter(SUMMARY_DIR+'/train',
                        graph=graph, flush_secs=30)

        # Initializer
        init_variables = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

        # Save state
        tf.add_to_collection('optimizer', optimizer)

        tf.add_to_collection('filename', filename)
        tf.add_to_collection('image', image)
        tf.add_to_collection('labels', labels)
        tf.add_to_collection('pred_labels', pred_labels)

        
        tf.add_to_collection('mu', mu)
        
        tf.add_to_collection('sigma', sigma)
        tf.add_to_collection('xhat', xhat)

        tf.add_to_collection('loss', loss)
        tf.add_to_collection('recon', recon)
        tf.add_to_collection('reg', reg)
        tf.add_to_collection('ce', ce)

        tf.add_to_collection('data_handle', data_handle)
        tf.add_to_collection('train1', train1)
        tf.add_to_collection('train2', train2)


        tf.add_to_collection('merged', merged)

        writer = tf.train.Saver()

        # Run one step: this initializes the graph and saves our starting statistics
        with tf.Session(config=config) as session:

            session.run(init_variables)

            train_handle = session.run(train_iterator.string_handle())

            # Output header
            logging.info("  step      loss      recon     reg     ce")

            _, step = session.run([optimizer, global_step],
                    feed_dict = { data_handle: train_handle, train1: 0 , train2: 1},
                    options = run_options)

            loss_, recon_, reg_, ce_, summary = \
                session.run([loss, recon, reg, ce, merged],
                           feed_dict = { data_handle: train_handle, train1: 0 , train2: 0})
            train_summary_logger.add_summary(summary, step)

            logging.info("{: 6d} {:9.3g} {:9.3g} {:9.3g} {:9.3g}".format(step, loss_, recon_, reg_, ce_))

            # Save graph
            logging.info("saving graph")
            writer.save(session, MODEL_PREFIX, global_step=step, write_meta_graph=False)
            writer.export_meta_graph(MODEL_GRAPH)
            


def main():

    parser = argparse.ArgumentParser(description='initialize training graph')
    initialize()
    sys.exit(0)


if __name__ == '__main__':
    main()

2018-10-12T18:03:14+0000 [train/initialize] INFO     initializing run: vae/vae_016
2018-10-12T18:03:19+0000 [train/initialize] INFO       step      loss      recon     reg     ce
2018-10-12T18:03:26+0000 [train/initialize] INFO          0     0.215     0.205     0.104     0.698
2018-10-12T18:03:26+0000 [train/initialize] INFO     saving graph


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## Train

In [118]:
with tf.Graph().as_default() as graph:
    
    # Repeatable results
    tf.set_random_seed(0)

    # Get Data
    train_dataset, train_iterator = inputs(filenames=TRAIN_TFRECORDS,
            batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, num_shards = 1, shard_index = 0)
        
    # Log output for Tensorboard
    train_summary_logger = tf.summary.FileWriter(SUMMARY_DIR+'/train', flush_secs=30)

    # Run
    with tf.Session(config=config) as session:

        # restore
        reader = tf.train.import_meta_graph(MODEL_GRAPH)
        reader.restore(session, tf.train.latest_checkpoint(MODEL_DIR))

        # must be called after reader so that the graph is populated
        writer = tf.train.Saver()

        # get references to graph endpoints
        global_step = tf.train.get_global_step(graph)

        optimizer = tf.get_collection('optimizer')[0]
        loss = tf.get_collection('loss')[0]
        recon = tf.get_collection('recon')[0]
        reg = tf.get_collection('reg')[0]
        ce = tf.get_collection('ce')[0]

        data_handle = tf.get_collection('data_handle')[0]
        train1 = tf.get_collection('train1')[0]
        train2 = tf.get_collection('train2')[0]


        merged = tf.get_collection('merged')[0]

        train_handle = session.run(train_iterator.string_handle())

        # Output header
        logging.info("  step      loss      recon     reg         ce")

        while True:
            try:
                _, step = session.run([optimizer, global_step], 
                            feed_dict = { data_handle: train_handle, train1: 0 , train2: 1})

                if not step%DISPLAY_EVERY:

                    loss_, recon_, reg_, ce_, summary = session.run([loss, recon, reg, ce, merged],
                            feed_dict = { data_handle: train_handle, train1: 0 , train2: 0})
                    train_summary_logger.add_summary(summary, step)

                    logging.info("{: 6d} {:9.3g} {:9.3g} {:9.3g} {:9.3g}".format(step, loss_, recon_, reg_, ce_))

                    writer.save(session, MODEL_PREFIX, global_step=step, write_meta_graph=False)

            except tf.errors.OutOfRangeError:
                print("done")
                break

INFO:tensorflow:Restoring parameters from ../../data/models/vae/vae_016/vae-15400


2018-10-12T20:09:22+0000 [train/initialize] INFO     Restoring parameters from ../../data/models/vae/vae_016/vae-15400
2018-10-12T20:09:23+0000 [train/initialize] INFO       step      loss      recon     reg         ce
2018-10-12T20:10:01+0000 [train/initialize] INFO      15500      0.11    0.0942     0.158     0.807
2018-10-12T20:10:38+0000 [train/initialize] INFO      15600     0.109    0.0916     0.173     0.843
2018-10-12T20:11:23+0000 [train/initialize] INFO      15700     0.105    0.0865      0.19     0.845
2018-10-12T20:12:06+0000 [train/initialize] INFO      15800     0.109    0.0896     0.199      0.82
2018-10-12T20:12:49+0000 [train/initialize] INFO      15900     0.112    0.0932     0.187     0.809
2018-10-12T20:13:33+0000 [train/initialize] INFO      16000     0.111    0.0918     0.193     0.773
2018-10-12T20:14:17+0000 [train/initialize] INFO      16100     0.107    0.0883     0.191       0.8
2018-10-12T20:15:00+0000 [train/initialize] INFO      16200     0.111    0.0932  

2018-10-12T21:06:44+0000 [train/initialize] INFO      23800    0.0839    0.0653     0.186     0.797
2018-10-12T21:07:28+0000 [train/initialize] INFO      23900    0.0978    0.0751     0.227     0.817
2018-10-12T21:08:13+0000 [train/initialize] INFO      24000    0.0929    0.0719      0.21     0.804


KeyboardInterrupt: 

# Test

Try out the autoencoder by running it on some test set samples.

For a collection of test images:

1. Create (image, label, mu, sigma) tuples
1. For a seed image, compute the 10 nearest images using (mu, sigma)
1. View the nearby images and their labels, comparing them to the seed image

If the VAE has worked as expected, we should find that the nearby images match the seed visually, and perhaps even match according to their labels.

## Generate document vectors

In [36]:


def gen_documents(tfrecords, dir_name, shard_index):
    
    # Store the documents
    documents = []
    def extend(docs, m, s, xh, f, i, l, p):
        start_id = len(docs)
        docs.extend([
            {
                'filename':f_.decode('ascii') ,
                'id_': k + start_id,
                'image': i_.reshape(256, 256)+0.5,
                'labels': l_,
                'pred_labels': p_,
                'sigma': s_,
                'mu': m_,
                'xhat': y_.reshape(256, 256)+0.5
            }
            for k, (m_, s_, y_, f_, i_,  l_, p_) in enumerate(zip(m, s, xh, f, i, l, p))
        ])
        return docs

    with tf.Graph().as_default() as graph:

        # Repeatable results
        tf.set_random_seed(0)

        # Get Data
        dataset, iterator = inputs(filenames = tfrecords, batch_size=BATCH_SIZE, num_epochs=1, num_shards=1, shard_index = 0)
        # Log output for Tensorboard
        summary_logger = tf.summary.FileWriter(SUMMARY_DIR + dir_name, flush_secs=30)

        print(SUMMARY_DIR + dir_name)
        # Run
        with tf.Session(config=config) as session:

            # restore
            reader = tf.train.import_meta_graph(MODEL_GRAPH)
            reader.restore(session, tf.train.latest_checkpoint(MODEL_DIR))

            # get references to graph endpoints
            filename = tf.get_collection('filename')[0]
            images = tf.get_collection('image')[0]
            labels = tf.get_collection('labels')[0]
            pred_labels = tf.get_collection('pred_labels')[0]

            mu = tf.get_collection('mu')[0]
            sigma = tf.get_collection('sigma')[0]
            xhat = tf.get_collection('xhat')[0]

            merged = tf.get_collection('merged')[0]

            data_handle = tf.get_collection('data_handle')[0]
            train = tf.get_collection('train')[0]
            handle = session.run(iterator.string_handle())
            
            step = 0
            while True:
                try:
                    mu_, sigma_, xhat_, filename_, images_, labels_, pred_labels_, summary = \
                        session.run([mu, sigma, xhat, filename, images, labels, pred_labels, merged],
                                feed_dict = { data_handle: handle, train: 0 })
                    documents = extend(documents, mu_, sigma_, xhat_, filename_, images_, labels_, pred_labels_)
                    step += 1
                    summary_logger.add_summary(summary, step)                        
                    print("{}.".format(step), end="", flush=True)

                except tf.errors.OutOfRangeError:
                    print(".done")
                    break
    return documents

In [37]:
SUBSET = "golden_dr"

In [38]:
if SUBSET == "golden_dr":
    golden_dr_documents = gen_documents(GOLDEN_TFRECORDS_DR, '/golden', 1)
if SUBSET == "golden_src":
    golden_src_documents = gen_documents(GOLDEN_TFRECORDS_SRC, '/golden', 1)
elif SUBSET == "validate":
    validate_documents = gen_documents(VALIDATE_TFRECORDS, '/validate')
elif SUBSET == "test":
    test_documents = gen_documents(TEST_TFRECORDS, '/test')
elif SUBSET == "train":
    train_documents = gen_documents(TRAIN_TFRECORDS, '/train', 4)


../../data/logs/vae/vae_007/golden
INFO:tensorflow:Restoring parameters from ../../data/models/vae/vae_007/vae-11300


2018-08-02T19:00:16+0000 [train/initialize] INFO     Restoring parameters from ../../data/models/vae/vae_007/vae-11300


1.2.3.4..done


In [40]:
for doc in golden_dr_documents:
    print(doc['filename'])
    print(doc['labels'])
    print(doc['pred_labels'])
    

00018562_003.png
[1 0 0 0 1 0 0 0 0 0 0 0 1 0 0]
[-1.0720762  -0.87022126 -0.83928794 -0.89940995 -0.48253763 -1.0618614
 -1.0444778  -1.102079   -0.18380427 -0.79615617 -1.0838246  -0.4062531
 -1.0027723  -0.99242157 -0.83510464]
00012681_005.png
[0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
[-0.86500937 -1.221415   -0.9388064  -0.91931546 -0.83375335 -1.1990775
 -0.99700654 -1.0121012  -0.9131926  -1.1462178  -0.92739207 -0.5020547
 -1.0355849  -1.1243302  -0.927076  ]
00012625_008.png
[0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
[-0.8392162  -0.9658433  -0.9229138  -1.12591    -0.64069045 -0.8182095
 -0.9011565  -1.0855676  -1.184177   -0.86749274 -1.0675727  -0.10186093
 -0.9949252  -1.0366486  -0.9219177 ]
00012681_009.png
[0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
[-0.57942396 -1.1269436  -0.8235245  -0.89944476 -0.70706785 -1.1780941
 -1.2847226  -0.9287429  -0.70435673 -1.1891603  -0.9602176  -0.72983325
 -1.084122   -1.0870798  -1.121264  ]
00004085_010.png
[0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
[-1.0110269  -1.0217295  

## Generate Json

In [71]:
import json
import scipy.misc as misc

In [72]:
def gen_json(documents):
    
    for doc in documents:
        json_item = {
            #'id':doc['id_'],
            'filename': doc['filename'],
            'subset': SUBSET,
            'labels': doc['labels'].tolist(),
            #'age': int(doc['age']),
            #'gender': doc['gender'],
            #'view': doc['view'],
            'mu':doc['mu'].tolist(),
            'sigma':doc['sigma'].tolist()
        }

        if(SUBSET == 'golden_src'):
            JSON_DIR = os.path.join(GOLDEN_EMBEDDINGS_DIR, doc['filename'][:12] + "_src.json")
        elif(SUBSET == 'golden_dr'):
            JSON_DIR = os.path.join(GOLDEN_EMBEDDINGS_DIR, doc['filename'][:12] + "_dr.json")
        else:
            JSON_DIR = os.path.join(GOLDEN_EMBEDDINGS_DIR, doc['filename'][:12] + ".json")
            
        with open(JSON_DIR, 'w') as outfile:
            json.dump(json_item, outfile, separators=(',', ':'), indent = 2)
    

In [73]:
if SUBSET == "golden_dr":
    gen_json(golden_dr_documents)
    
elif SUBSET == "golden_src":
    gen_json(golden_src_documents)

elif SUBSET == "validate":
    gen_json(validate_documents)
    
elif SUBSET == "test":
    gen_json(test_documents)
    
elif SUBSET == "train":
    gen_json(train_documents)

## save images

In [40]:

#save the encoded 256x256 vectors as png file
def save_images(documents):
    for doc in documents:

        path = os.path.join(ENCODED_IMAGES_DIR, doc['filename'])
        misc.imsave(path, doc['xhat'])


In [101]:
if SUBSET == "golden":
    save_images(golden_documents)

elif SUBSET == "validate":
    save_images(validate_documents)
    
elif SUBSET == "test":
    save_images(test_documents)
    
elif SUBSET == "train":
    save_images(train_documents)

`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.
  import sys


In [102]:
for i,x in enumerate(train_documents):
    f = 0
print(i)

14619
