In [1]:
import tensorflow as tf
import numpy as np
import librosa
import tflearn

seed = 'the best seed ever'
seed = sum([ord(char) for char in seed])
np.random.seed(seed)
tf.set_random_seed(seed)

In [2]:
class AudioGenerator():
    
    def __init__(self, audio_path, sample_rate, audio_frame_size, data_frames_truncation=0):
        data, _ = librosa.core.load(mono=True,
                                    path=audio_path,
                                    sr=sample_rate)
        self.original_length = len(data)
        if data_frames_truncation > 0:
            data = data[:audio_frame_size*data_frames_truncation]
        self.shortened_length = len(data)
        self.audio_frames = []
        skip = audio_frame_size
        for start in range(0, len(data) - audio_frame_size, skip):
            end = start + audio_frame_size
            self.audio_frames.append(data[start:end])
        self.audio_frames = np.array(self.audio_frames)
        self.index = 0
        self.epochs = 0
        np.random.shuffle(self.audio_frames)
        
        
    def print_dataset_stats(self):
        percent = self.shortened_length / self.original_length * 100
        
        print("Dataset stats:")
        print("  * Audio data {}% original size".format(percent, self.shortened_length))
        print("  * Audio stft frames shape", self.audio_frames.shape, "\n")
        
              
    def next_batch(self, batch_size):
        if self.index + batch_size >= len(self.audio_frames):
            self.index = 0
            self.epochs += 1
            np.random.shuffle(self.audio_frames)
        else:
            self.index += batch_size
        return self.audio_frames[self.index:self.index + batch_size], self.epochs

In [3]:
def recurrent_net(net, rec_type, rec_size, return_sequence):
    """
    A quick if else block to build a recurrent layer, based on the type specified
    by the user.
    """
    if rec_type == 'lstm':
        net = tflearn.layers.recurrent.lstm(net, rec_size, return_seq=return_sequence)
    elif rec_type == 'gru':
        net = tflearn.layers.recurrent.gru(net, rec_size, return_seq=return_sequence)
    elif rec_type == 'bi_lstm':
        net = bidirectional_rnn(net, 
                                BasicLSTMCell(rec_size), 
                                BasicLSTMCell(rec_size), 
                                return_seq=return_sequence)
    elif rec_type == 'bi_gru':
        net = bidirectional_rnn(net, 
                                GRUCell(rec_size), 
                                GRUCell(rec_size), 
                                return_seq=return_sequence)
    else:
        raise ValueError('Incorrect rnn type passed. Try lstm, gru, bi_lstm or bi_gru.')
    return net


def get_audio_input_frame_size(sequence_length, window_size, hop_size):
    input_frame_size = window_size
    for _ in range(sequence_length):
        input_frame_size += hop_size
    return input_frame_size

In [4]:
fft_size = 1024
hop_size = 256
window_size = 1024
sequence_length = 15
dataset_truncated_amount = 0
generating = True

rnn_sizes = [1024, 1024]
dense_sizes = [1024, 1024]
dense_activation = tf.nn.relu
weight_decay = 0.000
batch_norm = False
amount_epochs = 10000

batch_size = 32
learning_rate = 0.001
keep_prob = 1.0

input_frame_size = get_audio_input_frame_size(sequence_length,
                                              window_size,
                                              hop_size)

assert fft_size == window_size, "fft size must equal window size for maths to work"

In [5]:
tf.reset_default_graph()

with tf.variable_scope("inputs"):
    audio = tf.placeholder(tf.float32, 
                           shape=[None, input_frame_size])
    
with tf.variable_scope("stft"):
    stfts = tf.contrib.signal.stft(audio, 
                                   frame_length=fft_size, 
                                   frame_step=hop_size,
                                   fft_length=window_size,
                                   pad_end=False)
    
    print('stfts shape', stfts.get_shape())
    
    stft_frames_length = stfts.get_shape()[1]
    assert stft_frames_length == sequence_length + 1, "{} wtf".format()

with tf.variable_scope("cart2polar"):    
    magnitudes = tf.abs(stfts)
    phases = tf.angle(stfts)

with tf.variable_scope("input_target_split"):
    input_magnitudes = magnitudes[:, :sequence_length]
    input_phases = phases[:, :sequence_length]
    target_magnitudes = magnitudes[:, -1]
    target_phases = phases[:, -1]
    target_features = tf.concat([target_magnitudes, target_phases], axis=1)
    
    print(input_magnitudes.get_shape(), target_magnitudes.get_shape())
    print(input_phases.get_shape(), target_phases.get_shape())
    

with tf.variable_scope("mag_phases_concat"):
    features = tf.concat([input_magnitudes, input_phases], axis=2)

    if batch_norm:
        features = tf.contrib.layers.batch_norm(features)
        
net = features

# Recurrent
for layer, size in enumerate(rnn_sizes):
    return_sequence = layer != (len(rnn_sizes) - 1)
    net = recurrent_net(net, 'lstm', size, return_sequence)
    net = tflearn.dropout(net, keep_prob) 
    
    
# Dense + MLP Out
for size in dense_sizes:
    net = tflearn.fully_connected(net, 
                                  size, 
                                  activation=dense_activation,                                            
                                  regularizer='L2', 
                                  weight_decay=0.001)
                      
logits = tflearn.fully_connected(net, 
                                 features.get_shape()[2], 
                                 activation='linear')

split_size = int(features.get_shape()[2]) // 2

with tf.variable_scope("mag_phase_predict_split"):
    predicted_magnitudes = logits[:, :split_size] 
    predicted_phases = logits[:, split_size:]        

predicted_real = predicted_magnitudes * tf.cos(predicted_phases)
predicted_imag = predicted_magnitudes * tf.sin(predicted_phases)      
predicted_stft = tf.complex(predicted_real, predicted_imag)
predicted_audio = tf.contrib.signal.inverse_stft(
    predicted_stft,
    frame_length=window_size, 
    frame_step=hop_size,
    fft_length=fft_size,
    window_fn=None,
    name=None
)

loss = tf.losses.mean_squared_error(logits, target_features)

black_list = ['BatchNorm', 'batch_norm', 'LSTM', 'lstm', 'bias', '/b:']
regulisable_vars = []
for var in tf.trainable_variables():
    if not any([bad in var.name for bad in black_list]):
        regulisable_vars.append(tf.nn.l2_loss(var))

l2_losses = tf.add_n(regulisable_vars)
l2_loss = l2_losses * weight_decay

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

with tf.control_dependencies(update_ops):
    loss += l2_loss


optimiser = tf.train.AdamOptimizer(learning_rate=learning_rate)
        
optimise = optimiser.minimize(loss)

stfts shape (?, 16, 513)
(?, 15, 513) (?, 513)
(?, 15, 513) (?, 513)


In [6]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    audio_generator = AudioGenerator("./assets/electronic_piano/HM_120_AF_EPiano5.wav",
                                     44100,
                                     input_frame_size,
                                     dataset_truncated_amount)
    
    audio_generator.print_dataset_stats()

    print('Started optimisation.')
    
    generation_step = 20
    generation_length = 400
    generated_audio = []

    for epoch in range(amount_epochs):
        
        epoch = 0
        last_epoch = 0
            
        while epoch != amount_epochs:
            
            audio_frames, epoch = audio_generator.next_batch(batch_size)

            sess.run(optimise, feed_dict={
                audio: audio_frames
            })
            iteration_loss = sess.run(loss, feed_dict={
                audio: audio_frames
            })
            
            print(epoch, iteration_loss)
            
            if iteration_loss == 0.0:
                print(audio_frames)
            
            if epoch != last_epoch and generating and iteration_loss < 0.001:
                
                print('Generating.')
                
                index = np.random.randint(len(audio_generator.audio_frames))
                impulse = audio_generator.audio_frames[index]
                impulse_size = len(impulse)

                for _ in range(generation_length):
                    
                    predicted_stft_frames = sess.run(predicted_stft, feed_dict={
                        audio: impulse[-impulse_size:].reshape((1, -1))
                    })
                    
                    impulse = np.concatenate((impulse, predicted_audio_frames))  
                    
                generated_audio.append(impulse)
            
            last_epoch = epoch

Dataset stats:
  * Audio data 100.0% original size
  * Audio stft frames shape (145, 4864) 

Started optimisation.
0 2.54676
0 2.49069
0 1.97749
0 2.36907
1 1.85467
1 2.2966
1 2.05854
1 1.84664
1 1.95637
2 1.81655
2 2.12075
2 1.96342
2 1.94408
2 1.94958
3 1.94135
3 1.85145
3 1.90432
3 1.82602
3 1.74109
4 1.8952
4 1.77655
4 1.80163
4 1.80867
4 1.96874
5 1.69307
5 1.91414
5 1.70816
5 1.80042
5 1.94992
6 1.7003
6 1.69202
6 1.77391
6 1.80675
6 1.73345
7 1.69654
7 1.591
7 1.82861
7 1.59847
7 1.50543
8 1.61845
8 1.60182
8 1.59749
8 1.5055
8 1.46327
9 1.48398
9 1.57352
9 1.35213
9 1.51703
9 1.39681
10 1.33227
10 1.37254
10 1.39592
10 1.49893
10 1.27349
11 1.1625
11 1.5106
11 1.17958
11 1.44257
11 1.25142
12 1.29182
12 1.31165
12 1.23157
12 1.24799
12 1.04166
13 1.20076
13 1.11977
13 1.15991
13 1.29474
13 0.936035
14 1.08307
14 1.17181
14 1.06651
14 1.11984
14 0.986434
15 1.10411
15 1.02312
15 0.950196
15 1.05574
15 0.893401
16 0.813833
16 0.927005
16 0.927631
16 1.04928
16 1.10914
17 0.917992

KeyboardInterrupt: 

In [17]:
generated_audio = np.array(generated_audio).reshape((1, -1))

input_audio = tf.placeholder(tf.float32, shape=[1, None])
frames = tf.contrib.signal.frame(input_audio, frame_length=window_size, frame_step=hop_size)
reconstructed_signals = tf.contrib.signal.overlap_and_add(generated_audio, 
                                                          frame_step=hop_size)

with tf.Session() as sess:
    
    preview = sess.run(input_audio, feed_dict={
        input_audio: generated_audio
    })

print(generated_audio.shape)
print(preview.shape)
librosa.output.write_wav("preview.wav", preview, 44100)

(1, 1657856)
(1, 1657856)


error: ushort format requires 0 <= number <= (0x7fff * 2 + 1)

In [18]:
import IPython.display as ipd
ipd.Audio(preview.reshape((-1))[:44100 * 5], rate=44100)