In [1]:
#import scipy
#rate,data = scipy.io.wavfile.read('./vocalSeparation/origin_mix.wav')
#scipy.io.wavfile.write('./vocalSeparation/morigin_mix.wav',rate,data)

In [2]:
#import soundfile as sf
#data, samplerate = sf.read('./vocalSeparation/pred_mix.wav', dtype='float32')
#data = librosa.resample(data.T, samplerate, 16000)

In [3]:
"""Training script for the WaveNet network on the VCTK corpus.

This script trains a network with the WaveNet using data from the VCTK corpus,
which can be freely downloaded at the following site (~10 GB):
http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html
"""

from __future__ import print_function

import argparse
from datetime import datetime
import json
import os
import sys
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.client import timeline
import soundfile as sf

from wavenetVS import WaveNetModel, AudioReader, optimizer_factory


'''import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"'''


'import os\nos.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"\nos.environ["CUDA_VISIBLE_DEVICES"] = "0"'

In [4]:
# Default values for the training options; get_arguments() exposes each of
# them as the default of the matching command-line flag unless noted.

# Default for --batch_size.
BATCH_SIZE = 1
# Default for --data_dir (directory handed to AudioReader).
DATA_DIRECTORY = './vsCorpus'
# Root under which a dated log directory is created when neither --logdir
# nor --logdir_root is given (see validate_directories).
LOGDIR_ROOT = './logdirVS'
# Default for --checkpoint_every; the training loop runs its generation and
# validation pass whenever `step % checkpoint_every == 0`.
CHECKPOINT_EVERY = 50
# Default for --num_steps (upper bound of the training loop).
NUM_STEPS = int(1e5)
# Default for --learning_rate.
LEARNING_RATE = 1e-3
#L2_REGULARIZATION_STRENGTH = 1e-5
# Default for --l2_regularization_strength; the training cell converts a
# value of 0 to None before passing it to the model.
L2_REGULARIZATION_STRENGTH = 0
# Default for --wavenet_params (JSON file with the network parameters).
WAVENET_PARAMS = './wavenet_params.json'
# Timestamp taken once at definition time; names the default log directory.
STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
# Default for --sample_size.
SAMPLE_SIZE = 100000
#SILENCE_THRESHOLD = 0.3
# Default for --silence_threshold; values not greater than EPSILON make the
# training cell pass silence_threshold=None to AudioReader.
SILENCE_THRESHOLD = 0
# Cut-off used by the training cell to decide whether a silence threshold
# counts as "set".
EPSILON = 0.001
# Default for --momentum.
MOMENTUM = 0.9
# Default for --max_checkpoints (max_to_keep of the tf.train.Saver).
MAX_TO_KEEP = 5
# Default for --store_metadata.
METADATA = False

In [5]:
def get_arguments(argv=None):
    """Build the option parser for training and return the parsed options.

    Args:
        argv: Optional list of argument strings to parse. When omitted an
            empty list is parsed, so every option takes its default value
            and ``sys.argv`` is never consulted (the original notebook
            behaviour).

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError('Argument needs to be a '
                             'boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    parser = argparse.ArgumentParser(description='WaveNet example network')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
                        help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')
    parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
                        help='The directory containing the VCTK corpus.')
    # `type=bool` would turn every non-empty string (even "False") into
    # True, so the explicit string-to-bool converter is used instead.
    parser.add_argument('--store_metadata', type=_str_to_bool, default=METADATA,
                        help='Whether to store advanced debugging information '
                        '(execution time, memory consumption) for use with '
                        'TensorBoard. Default: ' + str(METADATA) + '.')
    parser.add_argument('--logdir', type=str, default=None,
                        help='Directory in which to store the logging '
                        'information for TensorBoard. '
                        'If the model already exists, it will restore '
                        'the state and will continue training. '
                        'Cannot use with --logdir_root and --restore_from.')
    parser.add_argument('--logdir_root', type=str, default=None,
                        help='Root directory to place the logging '
                        'output and generated model. These are stored '
                        'under the dated subdirectory of --logdir_root. '
                        'Cannot use with --logdir.')
    parser.add_argument('--restore_from', type=str, default=None,
                        help='Directory in which to restore the model from. '
                        'This creates the new model under the dated directory '
                        'in --logdir_root. '
                        'Cannot use with --logdir.')
    parser.add_argument('--checkpoint_every', type=int,
                        default=CHECKPOINT_EVERY,
                        help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')
    parser.add_argument('--num_steps', type=int, default=NUM_STEPS,
                        help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
                        help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')
    parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,
                        help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')
    parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,
                        help='Concatenate and cut audio samples to this many '
                        'samples. Default: ' + str(SAMPLE_SIZE) + '.')
    # The help text used to claim "Default: False"; report the real default.
    parser.add_argument('--l2_regularization_strength', type=float,
                        default=L2_REGULARIZATION_STRENGTH,
                        help='Coefficient in the L2 regularization. '
                        'Default: ' + str(L2_REGULARIZATION_STRENGTH) + '.')
    parser.add_argument('--silence_threshold', type=float,
                        default=SILENCE_THRESHOLD,
                        help='Volume threshold below which to trim the start '
                        'and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=optimizer_factory.keys(),
                        help='Select the optimizer specified by this option. Default: adam.')
    parser.add_argument('--momentum', type=float,
                        default=MOMENTUM, help='Specify the momentum to be '
                        'used by sgd or rmsprop optimizer. Ignored by the '
                        'adam optimizer. Default: ' + str(MOMENTUM) + '.')
    parser.add_argument('--histograms', type=_str_to_bool, default=False,
                        help='Whether to store histogram summaries. Default: False')
    parser.add_argument('--gc_channels', type=int, default=None,
                        help='Number of global condition channels. Default: None. Expecting: Int')
    parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,
                        help='Maximum amount of checkpoints that will be kept alive. Default: '
                             + str(MAX_TO_KEEP) + '.')
    # An explicit list is always passed so argparse never reads sys.argv.
    return parser.parse_args([] if argv is None else argv)


In [6]:
def save(saver, sess, logdir, step):
    """Write a checkpoint named ``model.ckpt`` into ``logdir``.

    Args:
        saver: Object exposing ``save(sess, path, global_step=...)``.
        sess: Session handed through to ``saver.save``.
        logdir: Target directory; created when it does not exist yet.
        step: Training step passed to the saver as ``global_step``.
    """
    # Progress message stays on one line; ' Done.' completes it below.
    print('Storing checkpoint to {} ...'.format(logdir), end="")
    sys.stdout.flush()

    target_exists = os.path.exists(logdir)
    if not target_exists:
        os.makedirs(logdir)

    destination = os.path.join(logdir, 'model.ckpt')
    saver.save(sess, destination, global_step=step)
    print(' Done.')

In [7]:
def load(saver, sess, logdir):
    """Restore the checkpoint recorded for ``logdir``, if there is one.

    Args:
        saver: Object exposing ``restore(sess, path)``.
        sess: Session handed through to ``saver.restore``.
        logdir: Directory queried via ``tf.train.get_checkpoint_state``.

    Returns:
        The global step parsed from the checkpoint file name, or None when
        no checkpoint state is found.
    """
    print("Trying to restore saved checkpoints from {} ...".format(logdir),
          end="")

    checkpoint_state = tf.train.get_checkpoint_state(logdir)
    if not checkpoint_state:
        print(" No checkpoint found.")
        return None

    checkpoint_file = checkpoint_state.model_checkpoint_path
    print("  Checkpoint found: {}".format(checkpoint_file))

    # The step is whatever follows the last '-' in the last '/'-separated
    # component of the checkpoint path.
    file_name = checkpoint_file.rsplit('/', 1)[-1]
    restored_step = int(file_name.rsplit('-', 1)[-1])
    print("  Global step was: {}".format(restored_step))

    print("  Restoring...", end="")
    saver.restore(sess, checkpoint_file)
    print(" Done.")
    return restored_step

In [None]:
def get_default_logdir(logdir_root):
    """Return ``<logdir_root>/train/<STARTED_DATESTRING>``."""
    return os.path.join(logdir_root, 'train', STARTED_DATESTRING)


def validate_directories(args):
    """Check the directory options and resolve the paths training will use.

    Args:
        args: Object with ``logdir``, ``logdir_root`` and ``restore_from``
            attributes.

    Returns:
        dict with keys ``'logdir'``, ``'logdir_root'`` and ``'restore_from'``.

    Raises:
        ValueError: if ``logdir`` is combined with ``logdir_root`` or with
            ``restore_from``.
    """
    # --logdir excludes both of the other directory options.
    if args.logdir:
        if args.logdir_root:
            raise ValueError("--logdir and --logdir_root cannot be "
                             "specified at the same time.")
        if args.restore_from:
            raise ValueError(
                "--logdir and --restore_from cannot be specified at the same "
                "time. This is to keep your previous model from unexpected "
                "overwrites.\n"
                "Use --logdir_root to specify the root of the directory which "
                "will be automatically created with current date and time, or use "
                "only --logdir to just continue the training from the last "
                "checkpoint.")

    # Fall back to the module-level root when none was requested.
    root = LOGDIR_ROOT if args.logdir_root is None else args.logdir_root

    target = args.logdir
    if target is None:
        target = get_default_logdir(root)
        print('Using default logdir: {}'.format(target))

    # Without an explicit source, restore from the training directory itself.
    source = target if args.restore_from is None else args.restore_from

    # NOTE(review): 'logdir_root' carries the raw argument (possibly None),
    # not the resolved root computed above - confirm this is intended.
    return dict(logdir=target,
                logdir_root=args.logdir_root,
                restore_from=source)


In [None]:
# Training driver: parse options, build the input pipeline and the network,
# then run the optimisation loop with a periodic generation/validation pass.
args = get_arguments()

try:
    directories = validate_directories(args)
except ValueError as e:
    print("Some arguments are wrong:")
    print(str(e))
    # Nothing below can run without `directories`; stop here instead of
    # failing later with an unrelated NameError.
    raise

logdir = directories['logdir']
restore_from = directories['restore_from']

# Even if we restored the model, we will treat it as new training
# if the trained model is written into an arbitrary location.
is_overwritten_training = logdir != restore_from

with open(args.wavenet_params, 'r') as f:
    wavenet_params = json.load(f)

# Create coordinator.
coord = tf.train.Coordinator()

# Load raw waveform from VCTK corpus.
with tf.name_scope('create_inputs'):
    # Allow silence trimming to be skipped by specifying a threshold near
    # zero.
    silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                  EPSILON else None
    gc_enabled = args.gc_channels is not None
    reader = AudioReader(
        args.data_dir,
        coord,
        sample_rate=wavenet_params['sample_rate'],   #"sample_rate": 16000,
        gc_enabled=gc_enabled,
        receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
                                                               wavenet_params["dilations"],
                                                               wavenet_params["scalar_input"],
                                                               wavenet_params["initial_filter_width"]),
        sample_size=args.sample_size,  #SAMPLE_SIZE = 100000
        silence_threshold=silence_threshold)
    # Presumably the training-batch dequeue op of the reader - confirm in
    # wavenetVS.
    traudio_batch = reader.trdequeue(args.batch_size)  #BATCH_SIZE = 1
    if gc_enabled:
        ##TODO train and val
        gc_id_batch = reader.dequeue_gc(args.batch_size)
    else:
        gc_id_batch = None

# Create network.
net = WaveNetModel(
    batch_size=args.batch_size,
    dilations=wavenet_params["dilations"],
    filter_width=wavenet_params["filter_width"],
    residual_channels=wavenet_params["residual_channels"],
    dilation_channels=wavenet_params["dilation_channels"],
    skip_channels=wavenet_params["skip_channels"],
    quantization_channels=wavenet_params["quantization_channels"],
    use_biases=wavenet_params["use_biases"],
    scalar_input=wavenet_params["scalar_input"],
    initial_filter_width=wavenet_params["initial_filter_width"],
    histograms=args.histograms,
    global_condition_channels=args.gc_channels,
    global_condition_cardinality=reader.gc_category_cardinality)

# A strength of 0 is passed to the model as None.
if args.l2_regularization_strength == 0:
    args.l2_regularization_strength = None
trloss = net.trloss(input_batch=traudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)
optimizer = optimizer_factory[args.optimizer](
                learning_rate=args.learning_rate,
                momentum=args.momentum)
trainable = tf.trainable_variables()
optim = optimizer.minimize(trloss, var_list=trainable)

# Set up logging for TensorBoard.
writer = tf.summary.FileWriter(logdir)
writer.add_graph(tf.get_default_graph())
run_metadata = tf.RunMetadata()
summaries = tf.summary.merge_all()

# Set up session
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth=True
sess = tf.Session(config=config)
#sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

# Validation loss and file-generation ops, both fed from the reader's
# validation batch.
vaudio_batch = reader.valbatch()
valloss = net.valloss(input_batch=vaudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)
genfile = net.generateFile(input_batch=vaudio_batch,
                global_condition_batch=gc_id_batch,
                l2_regularization_strength=args.l2_regularization_strength)

init = tf.global_variables_initializer()
sess.run(init)

# Saver for storing checkpoints of the model.
saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)

try:
    saved_global_step = load(saver, sess, restore_from)
    if is_overwritten_training or saved_global_step is None:
        # The first training step will be saved_global_step + 1,
        # therefore we put -1 here for new or overwritten trainings.
        saved_global_step = -1

except:
    print("Something went wrong while restoring checkpoint. "
          "We will terminate training to avoid accidentally overwriting "
          "the previous model.")
    raise



threads = tf.train.start_queue_runners(sess=sess, coord=coord)
reader.start_threads(sess)

# Directory receiving the audio generated at each validation point; create
# it up front so sf.write cannot fail on a missing directory.
generated_dir = './vsCorpus/ans/'
if not os.path.exists(generated_dir):
    os.makedirs(generated_dir)

step = None
last_saved_step = saved_global_step
minvalloss = 10000
try:
    for step in range(saved_global_step + 1, args.num_steps):
        start_time = time.time()
        if args.store_metadata and step % 50 == 0:
            # Slow run that stores extra information for debugging.
            print('Storing metadata')
            run_options = tf.RunOptions(
                trace_level=tf.RunOptions.FULL_TRACE)
            summary, trloss_value, _ = sess.run(
                [summaries, trloss, optim],
                options=run_options,
                run_metadata=run_metadata)
            writer.add_summary(summary, step)
            writer.add_run_metadata(run_metadata,
                                    'step_{:04d}'.format(step))
            tl = timeline.Timeline(run_metadata.step_stats)
            timeline_path = os.path.join(logdir, 'timeline.trace')
            with open(timeline_path, 'w') as f:
                f.write(tl.generate_chrome_trace_format(show_memory=True))
        else:
            summary, trloss_value, _ = sess.run([summaries, trloss, optim])
            writer.add_summary(summary, step)
        duration = time.time() - start_time
        print('step {:d} - trloss = {:.3f}, ({:.3f} sec/step)'
              .format(step, trloss_value, duration))


        if step % args.checkpoint_every == 0:
            ans = sess.run(genfile)
            #print(ans.shape)
            # NOTE(review): 16000 is hard-coded; presumably it should follow
            # wavenet_params['sample_rate'] - confirm before changing.
            sf.write(generated_dir+str(step)+'.wav',ans.reshape(-1),16000)
            print('stored done')
            valloss_value = sess.run(valloss)
            # The reported time is measured from the start of this step, so
            # it includes the training run and the generation above.
            print('validateLoss = {:.3f}, ({:.3f} sec/step)'
              .format(valloss_value, time.time() - start_time))
            #if(valloss_value < minvalloss):
                #save(saver, sess, logdir, step)
                #last_saved_step = step
                #minvalloss = valloss_value

except KeyboardInterrupt:
    # Introduce a line break after ^C is displayed so save message
    # is on its own line.
    print()
finally:
    # `step` is still None when the loop never started (no steps left, or
    # interrupted before the first iteration); comparing None with an int
    # raises TypeError on Python 3.
    if step is not None and step > last_saved_step:
        save(saver, sess, logdir, step)
    coord.request_stop()
    coord.join(threads)




Using default logdir: ./logdirVS/train/2018-06-06T13-17-48
trdequeue
val ['./vsCorpus/origin_mix.wav', './vsCorpus/origin_vocal.wav']
raw_output Tensor("wavenet_2/postprocessing/Add_1:0", shape=(?, ?, 256), dtype=float32)
ans0 (?,)
Trying to restore saved checkpoints from ./logdirVS/train/2018-06-06T13-17-48 ... No checkpoint found.
step 0 - trloss = 5.630, (18.889 sec/step)
stored done
validateLoss = 5.602, (220.399 sec/step)
step 1 - trloss = 5.602, (2.487 sec/step)
step 2 - trloss = 5.574, (2.484 sec/step)
step 3 - trloss = 5.549, (2.482 sec/step)
step 4 - trloss = 5.527, (2.486 sec/step)
step 5 - trloss = 5.507, (2.485 sec/step)
step 6 - trloss = 5.490, (2.486 sec/step)
step 7 - trloss = 5.475, (2.485 sec/step)
step 8 - trloss = 5.460, (2.484 sec/step)
step 9 - trloss = 5.447, (2.483 sec/step)
step 10 - trloss = 5.435, (2.487 sec/step)
step 11 - trloss = 5.424, (2.485 sec/step)
step 12 - trloss = 5.413, (2.486 sec/step)
step 13 - trloss = 5.402, (2.486 sec/step)
step 14 - trloss = 

step 177 - trloss = 3.766, (2.506 sec/step)
step 178 - trloss = 3.742, (2.501 sec/step)
step 179 - trloss = 3.730, (2.507 sec/step)
step 180 - trloss = 3.689, (2.505 sec/step)
step 181 - trloss = 3.667, (2.504 sec/step)
step 182 - trloss = 3.662, (2.503 sec/step)
step 183 - trloss = 3.648, (2.504 sec/step)
step 184 - trloss = 3.651, (2.506 sec/step)
step 185 - trloss = 3.641, (2.509 sec/step)
step 186 - trloss = 3.633, (2.511 sec/step)
step 187 - trloss = 3.596, (2.504 sec/step)
step 188 - trloss = 3.587, (2.506 sec/step)
step 189 - trloss = 3.562, (2.501 sec/step)
step 190 - trloss = 3.535, (2.501 sec/step)
step 191 - trloss = 3.513, (2.501 sec/step)
step 192 - trloss = 3.498, (2.501 sec/step)
step 193 - trloss = 3.480, (2.498 sec/step)
step 194 - trloss = 3.462, (2.499 sec/step)
step 195 - trloss = 3.445, (2.499 sec/step)
step 196 - trloss = 3.441, (2.499 sec/step)
step 197 - trloss = 3.453, (2.497 sec/step)
step 198 - trloss = 3.535, (2.499 sec/step)
step 199 - trloss = 3.478, (2.49

step 359 - trloss = 0.932, (2.509 sec/step)
step 360 - trloss = 0.922, (2.507 sec/step)
step 361 - trloss = 0.920, (2.501 sec/step)
step 362 - trloss = 0.936, (2.501 sec/step)
step 363 - trloss = 0.929, (2.501 sec/step)
step 364 - trloss = 0.955, (2.501 sec/step)
step 365 - trloss = 0.944, (2.497 sec/step)
step 366 - trloss = 0.944, (2.502 sec/step)
step 367 - trloss = 0.889, (2.500 sec/step)
step 368 - trloss = 0.855, (2.502 sec/step)
step 369 - trloss = 0.841, (2.503 sec/step)
step 370 - trloss = 0.881, (2.503 sec/step)
step 371 - trloss = 0.936, (2.503 sec/step)
step 372 - trloss = 0.902, (2.501 sec/step)
step 373 - trloss = 0.834, (2.503 sec/step)
step 374 - trloss = 0.771, (2.502 sec/step)
step 375 - trloss = 0.789, (2.501 sec/step)
step 376 - trloss = 0.802, (2.499 sec/step)
step 377 - trloss = 0.757, (2.502 sec/step)
step 378 - trloss = 0.721, (2.502 sec/step)
step 379 - trloss = 0.740, (2.506 sec/step)
step 380 - trloss = 0.737, (2.500 sec/step)
step 381 - trloss = 0.702, (2.49

step 542 - trloss = 0.081, (2.498 sec/step)
step 543 - trloss = 0.080, (2.495 sec/step)
step 544 - trloss = 0.079, (2.494 sec/step)
step 545 - trloss = 0.079, (2.493 sec/step)
step 546 - trloss = 0.078, (2.495 sec/step)
step 547 - trloss = 0.077, (2.496 sec/step)
step 548 - trloss = 0.077, (2.494 sec/step)
step 549 - trloss = 0.076, (2.492 sec/step)
step 550 - trloss = 0.076, (2.496 sec/step)
stored done
validateLoss = 0.075, (4.717 sec/step)
step 551 - trloss = 0.075, (2.498 sec/step)
step 552 - trloss = 0.075, (2.493 sec/step)
step 553 - trloss = 0.074, (2.496 sec/step)
step 554 - trloss = 0.073, (2.498 sec/step)
step 555 - trloss = 0.072, (2.495 sec/step)
step 556 - trloss = 0.072, (2.494 sec/step)
step 557 - trloss = 0.071, (2.495 sec/step)
step 558 - trloss = 0.070, (2.496 sec/step)
step 559 - trloss = 0.070, (2.493 sec/step)
step 560 - trloss = 0.069, (2.495 sec/step)
step 561 - trloss = 0.069, (2.501 sec/step)
step 562 - trloss = 0.068, (2.495 sec/step)
step 563 - trloss = 0.068

In [None]:
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def.

    Args:
        graph_def: Graph definition whose ``node`` entries are copied.
        max_const_size: Largest ``tensor_content`` length, in bytes, that is
            kept verbatim for a 'Const' node.

    Returns:
        A new ``tf.GraphDef`` in which the content of every oversized
        constant is replaced by a short placeholder.
    """
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add()
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                # Report how many bytes were actually removed (not the
                # threshold); tensor_content is a bytes field, hence encode.
                placeholder = "<stripped %d bytes>" % size
                tensor.tensor_content = placeholder.encode('utf-8')

    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Render a TensorFlow graph inline as an HTML iframe.

    Args:
        graph_def: A graph definition, or any object offering
            ``as_graph_def()`` (converted first).
        max_const_size: Forwarded to ``strip_consts``.
    """
    definition = graph_def
    if hasattr(definition, 'as_graph_def'):
        definition = definition.as_graph_def()
    stripped = strip_consts(definition, max_const_size=max_const_size)

    # Serialised graph and a random element id fill the page template.
    graph_text = repr(str(stripped))
    element_id = 'graph' + str(np.random.rand())
    page = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=graph_text, id=element_id)

    # Double quotes are escaped because the page sits inside srcdoc="...".
    escaped_page = page.replace('"', '&quot;')
    frame = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(escaped_page)
    display(HTML(frame))

In [None]:
# Display the current default TensorFlow graph in the notebook output.
show_graph(tf.get_default_graph().as_graph_def())