In [1]:
#import scipy
#rate,data = scipy.io.wavfile.read('./vocalSeparation/origin_mix.wav')
#scipy.io.wavfile.write('./vocalSeparation/morigin_mix.wav',rate,data)

In [2]:
#import soundfile as sf
#data, samplerate = sf.read('./vocalSeparation/pred_mix.wav', dtype='float32')
#data = librosa.resample(data.T, samplerate, 16000)

In [3]:
"""Training script for the WaveNet network on the VCTK corpus.

This script trains a network with the WaveNet using data from the VCTK corpus,
which can be freely downloaded at the following site (~10 GB):
http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
"""

from __future__ import print_function

import argparse
from datetime import datetime
import json
import os
import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
import soundfile as sf

from wavenetVS import WaveNetModel, AudioReader, optimizer_factory


'''import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"'''


'import os\nos.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"\nos.environ["CUDA_VISIBLE_DEVICES"] = "0"'

In [4]:
# Default values for the options parsed by get_arguments().

# Data and output locations.
DATA_DIRECTORY = './vsCorpus'
LOGDIR_ROOT = './logdirVS'
WAVENET_PARAMS = './wavenet_params.json'
# Timestamp of this run, used to name the default log directory.
STARTED_DATESTRING = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

# Training schedule and optimisation.
BATCH_SIZE = 1
NUM_STEPS = 100000
LEARNING_RATE = 1e-3
MOMENTUM = 0.9
L2_REGULARIZATION_STRENGTH = 1e-5

# Input preparation.
SAMPLE_SIZE = 100000
# Thresholds <= EPSILON disable silence trimming; 0.3 was used previously.
SILENCE_THRESHOLD = 0
EPSILON = 0.001

# Checkpointing and debugging.
CHECKPOINT_EVERY = 50
MAX_TO_KEEP = 5
METADATA = False

In [5]:
def get_arguments(argv=None):
    """Parse the training options, falling back to the module-level defaults.

    Args:
        argv: optional list of command-line tokens to parse. When None (the
            default) an empty list is parsed, so every option takes its
            default value; inside a Jupyter kernel ``sys.argv`` typically holds
            the kernel's own flags rather than user options.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError('Argument needs to be a '
                             'boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    parser = argparse.ArgumentParser(description='WaveNet example network')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')
    parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
                        help='The directory containing the training corpus. '
                        'Default: ' + DATA_DIRECTORY + '.')
    # type=bool would turn any non-empty string (including "False") into True,
    # so the flag is parsed with the same strict converter as --histograms.
    parser.add_argument('--store_metadata', type=_str_to_bool, default=METADATA,
                        help='Whether to store advanced debugging information '
                        '(execution time, memory consumption) for use with '
                        'TensorBoard. Default: ' + str(METADATA) + '.')
    parser.add_argument('--logdir', type=str, default=None,
                        help='Directory in which to store the logging '
                        'information for TensorBoard. '
                        'If the model already exists, it will restore '
                        'the state and will continue training. '
                        'Cannot use with --logdir_root and --restore_from.')
    parser.add_argument('--logdir_root', type=str, default=None,
                        help='Root directory to place the logging '
                        'output and generated model. These are stored '
                        'under the dated subdirectory of --logdir_root. '
                        'Cannot use with --logdir.')
    parser.add_argument('--restore_from', type=str, default=None,
                        help='Directory in which to restore the model from. '
                        'This creates the new model under the dated directory '
                        'in --logdir_root. '
                        'Cannot use with --logdir.')
    parser.add_argument('--checkpoint_every', type=int,
                        default=CHECKPOINT_EVERY,
                        help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')
    parser.add_argument('--num_steps', type=int, default=NUM_STEPS,
                        help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')
    parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,
                        help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')
    parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,
                        help='Concatenate and cut audio samples to this many '
                        'samples. Default: ' + str(SAMPLE_SIZE) + '.')
    parser.add_argument('--l2_regularization_strength', type=float,
                        default=L2_REGULARIZATION_STRENGTH,
                        help='Coefficient in the L2 regularization. '
                        'Default: ' + str(L2_REGULARIZATION_STRENGTH) + '.')
    parser.add_argument('--silence_threshold', type=float,
                        default=SILENCE_THRESHOLD,
                        help='Volume threshold below which to trim the start '
                        'and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=optimizer_factory.keys(),
                        help='Select the optimizer specified by this option. Default: adam.')
    parser.add_argument('--momentum', type=float,
                        default=MOMENTUM, help='Specify the momentum to be '
                        'used by sgd or rmsprop optimizer. Ignored by the '
                        'adam optimizer. Default: ' + str(MOMENTUM) + '.')
    parser.add_argument('--histograms', type=_str_to_bool, default=False,
                        help='Whether to store histogram summaries. Default: False')
    parser.add_argument('--gc_channels', type=int, default=None,
                        help='Number of global condition channels. Default: None. Expecting: Int')
    parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,
                        help='Maximum amount of checkpoints that will be kept alive. Default: '
                             + str(MAX_TO_KEEP) + '.')
    return parser.parse_args([] if argv is None else argv)


In [6]:
def save(saver, sess, logdir, step):
    """Persist the session's variables as a checkpoint inside ``logdir``.

    Args:
        saver: object exposing ``save(sess, path, global_step=...)`` (a
            ``tf.train.Saver`` in this notebook).
        sess: session whose variables are written.
        logdir: target directory; created (with parents) if it does not exist.
        step: training step, forwarded to ``saver.save`` as ``global_step``.
    """
    print('Storing checkpoint to {} ...'.format(logdir), end="")
    # Show the progress message before the (possibly slow) write starts.
    sys.stdout.flush()

    if not os.path.exists(logdir):
        os.makedirs(logdir)

    prefix = os.path.join(logdir, 'model.ckpt')
    saver.save(sess, prefix, global_step=step)
    print(' Done.')

In [7]:
def load(saver, sess, logdir):
    """Restore the newest checkpoint recorded in ``logdir`` into ``sess``.

    Args:
        saver: a ``tf.train.Saver`` used to restore the variables.
        sess: session to restore into.
        logdir: directory whose checkpoint state is looked up.

    Returns:
        The global step parsed from the checkpoint file name (the integer after
        the last '-'), or None when no usable checkpoint is recorded.
    """
    print("Trying to restore saved checkpoints from {} ...".format(logdir),
          end="")
    sys.stdout.flush()

    ckpt = tf.train.get_checkpoint_state(logdir)
    # Guard against a checkpoint state whose model_checkpoint_path is empty:
    # parsing the step below would otherwise fail with int('').
    if not (ckpt and ckpt.model_checkpoint_path):
        print(" No checkpoint found.")
        return None

    checkpoint_path = ckpt.model_checkpoint_path
    print("  Checkpoint found: {}".format(checkpoint_path))
    # e.g. '.../model.ckpt-550' -> 550. basename() also copes with the
    # platform's native path separator.
    global_step = int(os.path.basename(checkpoint_path).rsplit('-', 1)[-1])
    print("  Global step was: {}".format(global_step))
    print("  Restoring...", end="")
    saver.restore(sess, checkpoint_path)
    print(" Done.")
    return global_step

In [None]:
def get_default_logdir(logdir_root):
    """Return ``<logdir_root>/train/<STARTED_DATESTRING>`` for this run."""
    parts = (logdir_root, 'train', STARTED_DATESTRING)
    return os.path.join(*parts)


def validate_directories(args):
    """Check the mutually exclusive directory options and resolve defaults.

    Args:
        args: parsed options with ``logdir``, ``logdir_root`` and
            ``restore_from`` attributes.

    Raises:
        ValueError: if --logdir is combined with --logdir_root or with
            --restore_from.

    Returns:
        dict with keys 'logdir', 'logdir_root' and 'restore_from'. Note that
        'logdir_root' echoes the raw argument (None when not given), not the
        LOGDIR_ROOT fallback used to build the default logdir.
    """
    explicit_logdir = args.logdir

    if explicit_logdir and args.logdir_root:
        raise ValueError("--logdir and --logdir_root cannot be "
                         "specified at the same time.")

    if explicit_logdir and args.restore_from:
        raise ValueError(
            "--logdir and --restore_from cannot be specified at the same "
            "time. This is to keep your previous model from unexpected "
            "overwrites.\n"
            "Use --logdir_root to specify the root of the directory which "
            "will be automatically created with current date and time, or use "
            "only --logdir to just continue the training from the last "
            "checkpoint.")

    root = LOGDIR_ROOT if args.logdir_root is None else args.logdir_root

    if explicit_logdir is None:
        resolved_logdir = get_default_logdir(root)
        print('Using default logdir: {}'.format(resolved_logdir))
    else:
        resolved_logdir = explicit_logdir

    # Without --restore_from, restore from the logdir itself. This resumes an
    # explicit --logdir, or finds nothing in a freshly generated default one.
    if args.restore_from is None:
        resolved_restore = resolved_logdir
    else:
        resolved_restore = args.restore_from

    return {
        'logdir': resolved_logdir,
        'logdir_root': args.logdir_root,
        'restore_from': resolved_restore,
    }


In [None]:
args = get_arguments()

try:
    directories = validate_directories(args)
except ValueError as e:
    print("Some arguments are wrong:")
    print(str(e))
    # Stop here. Without a valid `directories`, the lines below would fail with a
    # confusing NameError, or silently reuse a stale value left in the kernel
    # by an earlier run.
    raise

logdir = directories['logdir']
restore_from = directories['restore_from']

# Even if we restored the model, we will treat it as new training
# if the trained model is written into an arbitrary location.
is_overwritten_training = logdir != restore_from

# Network hyper-parameters (filter width, dilations, sample rate, ...) come
# from the JSON file given by --wavenet_params.
with open(args.wavenet_params, 'r') as f:
    wavenet_params = json.load(f)

# Create coordinator.
# It is handed to the AudioReader and later to the queue runners so that all
# of them can be stopped together.
coord = tf.train.Coordinator()

# Build the input pipeline over args.data_dir (the upstream script this was
# derived from read the VCTK corpus).
with tf.name_scope('create_inputs'):
    # Allow silence trimming to be skipped by specifying a threshold near
    # zero.
    silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                  EPSILON else None
    gc_enabled = args.gc_channels is not None
    reader = AudioReader(
        args.data_dir,
        coord,
        sample_rate=wavenet_params['sample_rate'],   #"sample_rate": 16000,
        gc_enabled=gc_enabled,
        # The reader gets the model's receptive field, computed from the same
        # JSON parameters that configure WaveNetModel below.
        receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
                                                               wavenet_params["dilations"],
                                                               wavenet_params["scalar_input"],
                                                               wavenet_params["initial_filter_width"]),
        sample_size=args.sample_size,  #SAMPLE_SIZE = 100000
        silence_threshold=silence_threshold)
    # Presumably dequeues a training batch from the project reader (name
    # suggests "train dequeue") -- confirm in wavenetVS.AudioReader.
    traudio_batch = reader.trdequeue(args.batch_size)  #BATCH_SIZE = 1
    if gc_enabled:
        ##TODO train and val
        # TODO: these global-condition ids are also passed to the validation
        # graph (valloss / generateFile), not only to training.
        gc_id_batch = reader.dequeue_gc(args.batch_size)
    else:
        gc_id_batch = None

# Create network.
# All architecture settings come from the --wavenet_params JSON file; the
# global-condition cardinality is taken from the reader.
net = WaveNetModel(
    batch_size=args.batch_size,
    dilations=wavenet_params["dilations"],
    filter_width=wavenet_params["filter_width"],
    residual_channels=wavenet_params["residual_channels"],
    dilation_channels=wavenet_params["dilation_channels"],
    skip_channels=wavenet_params["skip_channels"],
    quantization_channels=wavenet_params["quantization_channels"],
    use_biases=wavenet_params["use_biases"],
    scalar_input=wavenet_params["scalar_input"],
    initial_filter_width=wavenet_params["initial_filter_width"],
    histograms=args.histograms,
    global_condition_channels=args.gc_channels,
    global_condition_cardinality=reader.gc_category_cardinality)

# A strength of 0 is passed on as None. Presumably this disables the L2 term
# inside net.trloss / net.valloss -- confirm in wavenetVS.
if args.l2_regularization_strength == 0:
    args.l2_regularization_strength = None
trloss = net.trloss(input_batch=traudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)
# Momentum is forwarded to every optimizer; per the --momentum help text it is
# ignored by adam.
optimizer = optimizer_factory[args.optimizer](
                learning_rate=args.learning_rate,
                momentum=args.momentum)
# Snapshot of the trainable variables that exist at this point; variables
# created later in the graph (if any) are not optimised by `optim`.
trainable = tf.trainable_variables()
optim = optimizer.minimize(trloss, var_list=trainable)

# Set up logging for TensorBoard.
# NOTE(review): the graph is written here, before the validation/generation
# ops below are built, so those ops are presumably absent from TensorBoard's
# graph view.
writer = tf.summary.FileWriter(logdir)
writer.add_graph(tf.get_default_graph())
run_metadata = tf.RunMetadata()
# Merges only the summaries defined so far; any summaries created by
# valloss/generateFile below are not part of `summaries`.
summaries = tf.summary.merge_all()

# Set up session
config = tf.ConfigProto(log_device_placement=False)
# Allocate GPU memory on demand rather than reserving it all up front.
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
#sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

# Validation graph: a validation batch from the reader, its loss, and the
# model output written to disk at checkpoint steps.
vaudio_batch = reader.valbatch()
valloss = net.valloss(input_batch=vaudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)
genfile = net.generateFile(input_batch=vaudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)

# Initialise variables only after the whole graph (training, optimizer and
# validation ops) has been built.
init = tf.global_variables_initializer()
sess.run(init)

# Saver for storing checkpoints of the model.
# Only trainable variables are saved, so non-trainable state (e.g. optimizer
# slot variables) is not restored when training resumes.
saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)

try:
    saved_global_step = load(saver, sess, restore_from)
    if is_overwritten_training or saved_global_step is None:
        # The first training step will be saved_global_step + 1,
        # therefore we put -1 here for new or overwritten trainings.
        saved_global_step = -1

except:
    # Deliberately broad: report, then re-raise so that no training (and no
    # checkpoint save) happens after a failed restore.
    print("Something went wrong while restoring checkpoint. "
          "We will terminate training to avoid accidentally overwriting "
          "the previous model.")
    raise

    
    
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
reader.start_threads(sess)

# Generated validation audio goes to <data_dir>/ans/<step>.wav. Create the
# folder up front so the first write (step 0) cannot fail on a missing folder.
ans_dir = os.path.join(args.data_dir, 'ans')
os.makedirs(ans_dir, exist_ok=True)

step = None
last_saved_step = saved_global_step
# Best validation loss so far; only used by the disabled best-model saving below.
minvalloss = 10000
try:
    for step in range(saved_global_step + 1, args.num_steps):
        start_time = time.time()
        if args.store_metadata and step % 50 == 0:
            # Slow run that stores extra information for debugging.
            print('Storing metadata')
            run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            summary, trloss_value, _ = sess.run(
                [summaries, trloss, optim],
                options=run_options,
                run_metadata=run_metadata)
            writer.add_summary(summary, step)
            writer.add_run_metadata(run_metadata,
                                    'step_{:04d}'.format(step))
            tl = timeline.Timeline(run_metadata.step_stats)
            timeline_path = os.path.join(logdir, 'timeline.trace')
            with open(timeline_path, 'w') as f:
                f.write(tl.generate_chrome_trace_format(show_memory=True))
        else:
            summary, trloss_value, _ = sess.run([summaries, trloss, optim])
            writer.add_summary(summary, step)
        duration = time.time() - start_time
        print('step {:d} - trloss = {:.3f}, ({:.3f} sec/step)'
              .format(step, trloss_value, duration))

        if step % args.checkpoint_every == 0:
            # Run the generation graph on the validation batch and store the
            # flattened output at the sample rate the reader was configured with.
            ans = sess.run(genfile)
            sf.write(os.path.join(ans_dir, str(step) + '.wav'),
                     ans.reshape(-1), wavenet_params['sample_rate'])
            print('stored done')
            valloss_value = sess.run(valloss)
            # The timing covers the whole step: training, generation and validation.
            print('validateLoss = {:.3f}, ({:.3f} sec/step)'
                  .format(valloss_value, time.time() - start_time))
            #if(valloss_value < minvalloss):
                #save(saver, sess, logdir, step)
                #last_saved_step = step
                #minvalloss = valloss_value

except KeyboardInterrupt:
    # Introduce a line break after ^C is displayed so save message
    # is on its own line.
    print()
finally:
    # `step` is still None when the loop body never ran (e.g. the restored
    # step already reached --num_steps). Comparing None with an int raises
    # TypeError on Python 3.
    if step is not None and step > last_saved_step:
        save(saver, sess, logdir, step)
    coord.request_stop()
    coord.join(threads)




Using default logdir: ./logdirVS/train/2018-06-06T00-08-54
trdequeue
val ['./vsCorpus/pred_mix.wav', './vsCorpus/pred_vocal.wav']
raw_output Tensor("wavenet_2/postprocessing/Add_1:0", shape=(?, ?, 256), dtype=float32)
ans0 (?,)
Trying to restore saved checkpoints from ./logdirVS/train/2018-06-06T00-08-54 ... No checkpoint found.
step 0 - trloss = 5.702, (18.824 sec/step)
stored done
validateLoss = 5.624, (18.824 sec/step)
step 1 - trloss = 5.667, (2.497 sec/step)
step 2 - trloss = 5.637, (2.487 sec/step)
step 3 - trloss = 5.608, (2.487 sec/step)
step 4 - trloss = 5.583, (2.488 sec/step)
step 5 - trloss = 5.560, (2.487 sec/step)
step 6 - trloss = 5.542, (2.487 sec/step)
step 7 - trloss = 5.525, (2.489 sec/step)
step 8 - trloss = 5.512, (2.499 sec/step)
step 9 - trloss = 5.497, (2.495 sec/step)
step 10 - trloss = 5.488, (2.495 sec/step)
step 11 - trloss = 5.477, (2.494 sec/step)
step 12 - trloss = 5.468, (2.500 sec/step)
step 13 - trloss = 5.461, (2.496 sec/step)
step 14 - trloss = 5.449

step 177 - trloss = 4.426, (2.502 sec/step)
step 178 - trloss = 4.275, (2.506 sec/step)
step 179 - trloss = 4.383, (2.514 sec/step)
step 180 - trloss = 4.304, (2.509 sec/step)
step 181 - trloss = 4.355, (2.502 sec/step)
step 182 - trloss = 4.325, (2.504 sec/step)
step 183 - trloss = 4.333, (2.506 sec/step)
step 184 - trloss = 4.290, (2.502 sec/step)
step 185 - trloss = 4.323, (2.509 sec/step)
step 186 - trloss = 4.291, (2.502 sec/step)
step 187 - trloss = 4.327, (2.503 sec/step)
step 188 - trloss = 4.266, (2.501 sec/step)
step 189 - trloss = 4.317, (2.503 sec/step)
step 190 - trloss = 4.264, (2.501 sec/step)
step 191 - trloss = 4.325, (2.504 sec/step)
step 192 - trloss = 4.246, (2.504 sec/step)
step 193 - trloss = 4.310, (2.503 sec/step)
step 194 - trloss = 4.243, (2.502 sec/step)
step 195 - trloss = 4.315, (2.501 sec/step)
step 196 - trloss = 4.232, (2.501 sec/step)
step 197 - trloss = 4.300, (2.500 sec/step)
step 198 - trloss = 4.229, (2.502 sec/step)
step 199 - trloss = 4.305, (2.50

step 359 - trloss = 3.967, (2.499 sec/step)
step 360 - trloss = 3.908, (2.500 sec/step)
step 361 - trloss = 4.018, (2.498 sec/step)
step 362 - trloss = 3.978, (2.499 sec/step)
step 363 - trloss = 4.020, (2.499 sec/step)
step 364 - trloss = 3.901, (2.498 sec/step)
step 365 - trloss = 4.033, (2.501 sec/step)
step 366 - trloss = 3.932, (2.500 sec/step)
step 367 - trloss = 3.991, (2.512 sec/step)
step 368 - trloss = 3.904, (2.502 sec/step)
step 369 - trloss = 3.981, (2.500 sec/step)
step 370 - trloss = 3.900, (2.501 sec/step)
step 371 - trloss = 3.984, (2.504 sec/step)
step 372 - trloss = 3.898, (2.502 sec/step)
step 373 - trloss = 3.955, (2.509 sec/step)
step 374 - trloss = 3.899, (2.501 sec/step)
step 375 - trloss = 3.954, (2.502 sec/step)
step 376 - trloss = 3.882, (2.501 sec/step)
step 377 - trloss = 3.951, (2.502 sec/step)
step 378 - trloss = 3.880, (2.502 sec/step)
step 379 - trloss = 3.944, (2.501 sec/step)
step 380 - trloss = 3.871, (2.503 sec/step)
step 381 - trloss = 3.961, (2.50

step 542 - trloss = 3.580, (2.503 sec/step)
step 543 - trloss = 3.639, (2.500 sec/step)
step 544 - trloss = 3.584, (2.500 sec/step)
step 545 - trloss = 3.643, (2.500 sec/step)
step 546 - trloss = 3.559, (2.501 sec/step)
step 547 - trloss = 3.628, (2.502 sec/step)
step 548 - trloss = 3.564, (2.500 sec/step)
step 549 - trloss = 3.634, (2.499 sec/step)
step 550 - trloss = 3.572, (2.500 sec/step)
stored done
validateLoss = 5.421, (2.500 sec/step)
step 551 - trloss = 3.627, (2.500 sec/step)
step 552 - trloss = 3.545, (2.500 sec/step)
step 553 - trloss = 3.630, (2.499 sec/step)
step 554 - trloss = 3.569, (2.500 sec/step)
step 555 - trloss = 3.649, (2.500 sec/step)
step 556 - trloss = 3.562, (2.502 sec/step)
step 557 - trloss = 3.606, (2.501 sec/step)
step 558 - trloss = 3.556, (2.500 sec/step)
step 559 - trloss = 3.639, (2.500 sec/step)
step 560 - trloss = 3.589, (2.498 sec/step)
step 561 - trloss = 3.657, (2.499 sec/step)
step 562 - trloss = 3.524, (2.501 sec/step)
step 563 - trloss = 3.634

step 724 - trloss = 3.309, (2.499 sec/step)
step 725 - trloss = 3.388, (2.499 sec/step)
step 726 - trloss = 3.323, (2.500 sec/step)
step 727 - trloss = 3.368, (2.501 sec/step)
step 728 - trloss = 3.344, (2.500 sec/step)
step 729 - trloss = 3.356, (2.501 sec/step)
step 730 - trloss = 3.326, (2.502 sec/step)
step 731 - trloss = 3.328, (2.500 sec/step)
step 732 - trloss = 3.302, (2.501 sec/step)
step 733 - trloss = 3.328, (2.502 sec/step)
step 734 - trloss = 3.272, (2.501 sec/step)
step 735 - trloss = 3.308, (2.503 sec/step)
step 736 - trloss = 3.266, (2.501 sec/step)
step 737 - trloss = 3.299, (2.503 sec/step)
step 738 - trloss = 3.247, (2.503 sec/step)
step 739 - trloss = 3.285, (2.502 sec/step)
step 740 - trloss = 3.227, (2.501 sec/step)
step 741 - trloss = 3.283, (2.502 sec/step)
step 742 - trloss = 3.204, (2.501 sec/step)
step 743 - trloss = 3.276, (2.509 sec/step)
step 744 - trloss = 3.188, (2.501 sec/step)
step 745 - trloss = 3.263, (2.501 sec/step)
step 746 - trloss = 3.182, (2.50

step 906 - trloss = 2.863, (2.498 sec/step)
step 907 - trloss = 2.974, (2.505 sec/step)
step 908 - trloss = 2.831, (2.498 sec/step)
step 909 - trloss = 2.935, (2.500 sec/step)
step 910 - trloss = 2.874, (2.501 sec/step)
step 911 - trloss = 2.895, (2.502 sec/step)
step 912 - trloss = 2.831, (2.501 sec/step)
step 913 - trloss = 2.918, (2.501 sec/step)
step 914 - trloss = 2.841, (2.500 sec/step)
step 915 - trloss = 2.872, (2.501 sec/step)
step 916 - trloss = 2.813, (2.502 sec/step)
step 917 - trloss = 2.905, (2.502 sec/step)
step 918 - trloss = 2.803, (2.501 sec/step)
step 919 - trloss = 2.843, (2.500 sec/step)
step 920 - trloss = 2.840, (2.501 sec/step)
step 921 - trloss = 2.873, (2.502 sec/step)
step 922 - trloss = 2.769, (2.503 sec/step)
step 923 - trloss = 2.862, (2.502 sec/step)
step 924 - trloss = 2.807, (2.501 sec/step)
step 925 - trloss = 2.871, (2.500 sec/step)
step 926 - trloss = 2.792, (2.500 sec/step)
step 927 - trloss = 2.823, (2.501 sec/step)
step 928 - trloss = 2.810, (2.50

step 1087 - trloss = 2.530, (2.499 sec/step)
step 1088 - trloss = 2.414, (2.499 sec/step)
step 1089 - trloss = 2.484, (2.500 sec/step)
step 1090 - trloss = 2.480, (2.499 sec/step)
step 1091 - trloss = 2.509, (2.499 sec/step)
step 1092 - trloss = 2.383, (2.498 sec/step)
step 1093 - trloss = 2.486, (2.497 sec/step)
step 1094 - trloss = 2.504, (2.498 sec/step)
step 1095 - trloss = 2.502, (2.499 sec/step)
step 1096 - trloss = 2.377, (2.499 sec/step)
step 1097 - trloss = 2.504, (2.507 sec/step)
step 1098 - trloss = 2.466, (2.506 sec/step)
step 1099 - trloss = 2.485, (2.500 sec/step)
step 1100 - trloss = 2.420, (2.500 sec/step)
stored done
validateLoss = 8.049, (2.500 sec/step)
step 1101 - trloss = 2.509, (2.507 sec/step)
step 1102 - trloss = 2.419, (2.505 sec/step)
step 1103 - trloss = 2.436, (2.501 sec/step)
step 1104 - trloss = 2.457, (2.500 sec/step)
step 1105 - trloss = 2.503, (2.502 sec/step)
step 1106 - trloss = 2.358, (2.502 sec/step)
step 1107 - trloss = 2.472, (2.502 sec/step)
step

step 1265 - trloss = 2.175, (2.502 sec/step)
step 1266 - trloss = 2.013, (2.508 sec/step)
step 1267 - trloss = 2.191, (2.502 sec/step)
step 1268 - trloss = 2.132, (2.502 sec/step)
step 1269 - trloss = 2.079, (2.503 sec/step)
step 1270 - trloss = 2.059, (2.501 sec/step)
step 1271 - trloss = 2.184, (2.502 sec/step)
step 1272 - trloss = 2.036, (2.503 sec/step)
step 1273 - trloss = 2.179, (2.502 sec/step)
step 1274 - trloss = 2.028, (2.502 sec/step)
step 1275 - trloss = 2.086, (2.500 sec/step)
step 1276 - trloss = 2.054, (2.502 sec/step)
step 1277 - trloss = 2.125, (2.502 sec/step)
step 1278 - trloss = 2.008, (2.501 sec/step)
step 1279 - trloss = 2.087, (2.503 sec/step)
step 1280 - trloss = 2.039, (2.501 sec/step)
step 1281 - trloss = 2.085, (2.500 sec/step)
step 1282 - trloss = 2.005, (2.500 sec/step)
step 1283 - trloss = 2.078, (2.500 sec/step)
step 1284 - trloss = 2.010, (2.501 sec/step)
step 1285 - trloss = 2.070, (2.509 sec/step)
step 1286 - trloss = 1.987, (2.500 sec/step)
step 1287 

step 1444 - trloss = 1.720, (2.500 sec/step)
step 1445 - trloss = 1.749, (2.500 sec/step)
step 1446 - trloss = 1.740, (2.501 sec/step)
step 1447 - trloss = 1.746, (2.501 sec/step)
step 1448 - trloss = 1.693, (2.500 sec/step)
step 1449 - trloss = 1.770, (2.502 sec/step)
step 1450 - trloss = 1.669, (2.498 sec/step)
stored done
validateLoss = 9.800, (2.498 sec/step)
step 1451 - trloss = 1.753, (2.501 sec/step)
step 1452 - trloss = 1.701, (2.500 sec/step)
step 1453 - trloss = 1.700, (2.499 sec/step)
step 1454 - trloss = 1.688, (2.499 sec/step)
step 1455 - trloss = 1.684, (2.501 sec/step)
step 1456 - trloss = 1.664, (2.501 sec/step)
step 1457 - trloss = 1.699, (2.499 sec/step)
step 1458 - trloss = 1.630, (2.500 sec/step)
step 1459 - trloss = 1.684, (2.499 sec/step)
step 1460 - trloss = 1.627, (2.502 sec/step)
step 1461 - trloss = 1.691, (2.500 sec/step)
step 1462 - trloss = 1.633, (2.500 sec/step)
step 1463 - trloss = 1.653, (2.501 sec/step)
step 1464 - trloss = 1.622, (2.502 sec/step)
step

In [None]:
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Return a copy of ``graph_def`` with large constant payloads replaced.

    Every ``Const`` node whose ``tensor_content`` is longer than
    ``max_const_size`` bytes gets that content replaced by a short placeholder
    recording the number of bytes removed. This keeps the graph small enough
    to embed in HTML.

    Args:
        graph_def: the ``tf.GraphDef`` to copy; it is not modified, because
            each node is merged into a fresh ``GraphDef`` before editing.
        max_const_size: largest ``tensor_content`` size, in bytes, kept verbatim.

    Returns:
        A new ``tf.GraphDef``.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # Report the size actually stripped, not the threshold.
                tensor.tensor_content = bytes("<stripped %d bytes>" % size, 'utf-8')

    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize a TensorFlow graph inline in the notebook.

    The (constant-stripped) graph is embedded in an iframe that loads
    TensorBoard's ``tf-graph-basic`` web component from
    tensorboard.appspot.com. This requires network access. NOTE(review): that
    HTML-import URL may no longer be served; confirm before relying on it.

    Args:
        graph_def: a ``tf.GraphDef``, or any object with ``as_graph_def()``
            (e.g. a ``tf.Graph``), which is converted first.
        max_const_size: forwarded to ``strip_consts``.
    """
    import html  # only this helper needs attribute escaping

    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    # The snippet is placed inside a double-quoted srcdoc attribute, so it is
    # escaped fully ('&' first, then quotes and angle brackets). Escaping only
    # '"' would let any '&...;' sequence in the graph text be decoded by the
    # browser.
    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(html.escape(code, quote=True))
    display(HTML(iframe))

In [None]:
# Render the full default graph inline; show_graph calls as_graph_def() itself
# when handed a Graph.
show_graph(tf.get_default_graph())