## Context Encoder

Let's begin by importing TensorFlow, the audio and numerical libraries, and the network classes.


In [None]:
# %pylab star-imports numpy and matplotlib names into the notebook namespace;
# later cells rely on it for `plt`, which is never imported explicitly.
%pylab inline
# Default figure size in inches (width, height) for figures that do not set their own.
pylab.rcParams['figure.figsize'] = (14, 28)
import IPython
import os
import librosa
import numpy as np
import tensorflow as tf
import functools
from tensorflow.contrib.signal.python.ops import window_ops

from network.sequentialModel import SequentialModel
from network.stftGapContextEncoder import StftGapContextEncoder
from utils.strechableNumpyArray import StrechableNumpyArray


Now we set the parameters of the audio window and of the STFT, and select the training step we want to use for the reconstruction. The context encoder network itself is built in the cells that follow.

In [None]:
# --- Window geometry (every length below is a number of samples) ---
sr = 16000                # sampling rate, in samples per second
start_in_seconds = 0.1    # where the analysed window starts inside each clip
side_length = 2048        # context kept on each side of the gap
gap_length = 1024         # length of the gap to reconstruct
window_size = side_length*2+gap_length  # left context + gap + right context

# Sample indices (relative to the start of a clip) delimiting the three parts
# of the window: [left side | gap | right side].
starting_sample_left_side = int(sr*start_in_seconds)
ending_sample_left_side = starting_sample_left_side + side_length
starting_sample_right_side = ending_sample_left_side + gap_length
ending_sample_right_side = starting_sample_right_side + side_length

# Step handed to reconstructAudio further down (presumably the training step
# of the saved model to restore -- confirm against StftGapContextEncoder).
# It has to be filled in by hand; None keeps the cell valid Python while
# marking the value as "not chosen yet".
best_step = None  # insert best step
batch_size = 256

# --- STFT parameters shared by the network and the evaluation cells ---
fft_frame_length = 512
fft_frame_step = 128


In [None]:
# Clear the default graph before the two models are built.
tf.reset_default_graph()

# Model fed with whole windows of shape (batch_size, window_size); it is later
# passed to StftGapContextEncoder as `target_model`.
aTargetModel = SequentialModel(shapeOfInput=(batch_size, window_size), name="Target Model")

with tf.name_scope('Remove_unnecesary_sides_before_stft'):
    signal = aTargetModel.output()
    # Keep only the gap plus 384 samples on each side of it:
    # 1664 = 2048 - 384 and 3456 = 2048 + 1024 + 384, i.e. 1792 samples.
    # With frame_length 512 and frame_step 128 that is (1792 - 512) / 128 + 1 = 11
    # frames, matching the shape noted on the last line of this block.
    signal_without_unnecesary_sides = signal[:, 1664:3456]
    aTargetModel.setOutputTo(signal_without_unnecesary_sides)
aTargetModel.addSTFT(frame_length=fft_frame_length, frame_step=fft_frame_step)
aTargetModel.divideComplexOutputIntoRealAndImaginaryParts()  # (256, 11, 257, 2)

# Model that receives the same (batch_size, window_size) windows; here the gap
# is removed before the STFT so that only the two context sides are analysed.
aModel = SequentialModel(shapeOfInput=(batch_size, window_size), name="context encoder")

with tf.name_scope('Remove_gap_before_stft'):
    # Presumably still the raw time-domain window at this point (no layer has
    # been added yet): 2048 context samples, the 1024-sample gap, 2048 context samples.
    signal = aModel.output()
    left_side = signal[:, :2048]
    right_side = signal[:, 2048+1024:]
    
    # Each side is extended with 384 zeros on the edge that faces the gap,
    # giving 2048 + 384 = 2432 samples, i.e. (2432 - 512) / 128 + 1 = 16 STFT frames.
    # 384 equals fft_frame_length - fft_frame_step (512 - 128) and is also the
    # margin kept around the gap in the target model.
    # NOTE(review): the rationale for this exact padding is not documented -- confirm.
    left_side_padded = tf.concat((left_side, tf.zeros((batch_size, 384))), axis=1)
    right_side_padded = tf.concat((tf.zeros((batch_size, 384)), right_side), axis=1)

    # The padded sides are stacked on a new axis 1 (not concatenated in time).
    # NOTE(review): concatenating them one after the other was suggested as an
    # alternative; that would change the layout of the encoder input.
    signal_without_gap = tf.stack((left_side_padded, right_side_padded), axis=1)  # (256, 2, 2432)
    aModel.setOutputTo(signal_without_gap)

aModel.addSTFT(frame_length=fft_frame_length, frame_step=fft_frame_step)  # (256, 2, 16, 257)
# Merge the side axis into the frame axis: 2 * 16 = 32 frames.
aModel.addReshape((batch_size, 32, 257))
aModel.divideComplexOutputIntoRealAndImaginaryParts()  # (256, 32, 257, 2)
# Same number of elements regrouped as 16 frames with 4 channels.
# NOTE(review): if addReshape is a plain row-major reshape, the 4 channels are
# NOT (left re, left im, right re, right im); consecutive values are simply
# regrouped in fours. Confirm before relying on a per-channel interpretation.
aModel.addReshape((batch_size, 16, 257, 4))

# Encoder: six convolution layers described by parallel lists, where entry k
# of every list configures layer k.
with tf.variable_scope("Encoder"):
    filter_shapes = [(7, 89), (3, 17), (2, 11), (1, 9), (1, 5), (2, 5)]
    # input_channels[0] is the 4-channel input built above; afterwards each
    # layer consumes the output_channels of the previous one.
    input_channels = [4, 32, 128, 512, 256, 160]
    output_channels = [32, 128, 512, 256, 160, 128]
    # Presumably in (batch, height, width, channels) order as in tf.nn.conv2d -- TODO confirm.
    strides = [[1, 2, 2, 1], [1, 2, 3, 1], [1, 2, 3, 1], [1, 1, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
    names = ['First_Conv', 'Second_Conv', 'Third_Conv', 'Fourth_Conv', 'Fifth_Conv', 'Sixth_Conv']
    aModel.addSeveralConvLayers(filter_shapes=filter_shapes, input_channels=input_channels,
                                output_channels=output_channels, strides=strides, names=names)

# Bottleneck: flatten to 2048 values per example, apply a 2048 -> 2048 fully
# connected layer followed by ReLU and batch normalisation.
aModel.addReshape((batch_size, 2048))
aModel.addFullyConnectedLayer(2048, 2048, 'Fully')
aModel.addRelu()
aModel.addBatchNormalization()
# 8 * 8 * 32 = 2048: back to a feature map for the decoder.
aModel.addReshape((batch_size, 8, 8, 32))

# Decoder: transposed convolutions that expand the bottleneck into the
# 11-frame, 257-bin, 2-channel (real/imaginary) output.
with tf.variable_scope("Decoder"):
    # Parallel lists, one entry per layer; channels chain 32 -> 128 -> 512 -> 257.
    filter_shapes = [(8, 8), (5, 5), (3, 3)]
    input_channels = [32, 128, 512]
    output_channels = [128, 512, 257]
    strides = [[1, 2, 2, 1], [1, 2, 2, 1], [1, 1, 1, 1]]
    names = ['First_Deconv', 'Second_Deconv', 'Third_Deconv']
    aModel.addSeveralDeconvLayers(filter_shapes=filter_shapes, input_channels=input_channels,
                                  output_channels=output_channels, strides=strides, names=names)

    # Presumably (batch, 32, 32, 257) at this point (two stride-2 layers on an
    # 8x8 map) -- TODO confirm; 32 * 32 * 257 == 8 * 257 * 128, so the element
    # count matches the reshape below.
    aModel.addReshape((batch_size, 8, 257, 128))
    aModel.addDeconvLayer(filter_shape=(5, 67), input_channels=128, output_channels=11, stride=(1, 2, 2, 1),
                          name='Fourth_deconv')
    aModel.addBatchNormalization()

    # 11 * 257 * 32 == 16 * 514 * 11, consistent with a stride-2 upsampling of
    # the (8, 257) map into 11 channels -- TODO confirm the intermediate shape.
    aModel.addReshape((batch_size, 11, 257, 32))

    # Final layer without non-linearity; 2 output channels, which later cells
    # read as real ([..., 0]) and imaginary ([..., 1]) parts.
    aModel.addDeconvLayerWithoutNonLin(filter_shape=(11, 257), input_channels=32, output_channels=2,
                                       stride=(1, 1, 1, 1), name="Last_Deconv")

print(aModel.description())


In [None]:
# Bundle the context-encoder graph and the target graph into the object that
# is used further down to run the reconstruction.
aContextEncoderNetwork = StftGapContextEncoder(
    model=aModel,
    target_model=aTargetModel,
    batch_size=batch_size,
    window_size=window_size,
    gap_length=gap_length,
    learning_rate=1e-4,
    name='nat_stft_gap_big_fma_2_',
)


In [None]:
pathToDatasetFolder = 'fma-test'
clip_length = 8000  # samples kept from the start of every file

# file_names[k] always names the file stored in row k of `audios`: a name is
# recorded only after its clip has passed every filter below.
file_names = []
kept_clips = []
for file_name in os.listdir(pathToDatasetFolder):
    if not file_name.endswith('.mp3'):
        continue
    # sr=None keeps the file's native sampling rate. Note that this rebinds
    # the notebook-level `sr` to the rate of the last file that was loaded.
    audio, sr = librosa.load(os.path.join(pathToDatasetFolder, file_name), sr=None)

    # Length first: a clip that is too short cannot be judged on its gap energy.
    if len(audio) < clip_length:
        continue
    # Skip clips whose gap region is (almost) silent: mean absolute amplitude below 1e-3.
    if np.sum(np.absolute(audio[ending_sample_left_side:starting_sample_right_side])) < gap_length*1e-3:
        print(file_name, "doesn't meet the minimum amplitude requirement")
        continue

    file_names.append(file_name)
    kept_clips.append(audio[:clip_length])

    if len(kept_clips) % 500 == 0:
        print("500 plus!", len(kept_clips))

# Build the array once instead of re-allocating it for every file; the reshape
# keeps the (0, clip_length) shape when no clip was kept.
audios = np.asarray(kept_clips, dtype=np.float32).reshape(-1, clip_length)

print("there were: ", len(audios))

print(audios.shape)

In [None]:
# Cut the analysed window out of every clip, then split that window into
# left context, gap and right context.
window = audios[:, starting_sample_left_side:ending_sample_right_side]

gap_start_in_window = ending_sample_left_side - starting_sample_left_side
gap_end_in_window = starting_sample_right_side - starting_sample_left_side

left_side = window[:, :gap_start_in_window]
original_gaps = window[:, gap_start_in_window:gap_end_in_window]
right_side = window[:, gap_end_in_window:]

# Both context sides next to each other, without the gap.
sides = np.hstack((left_side, right_side))

In [None]:
# Number of batches handed to reconstructAudio as `max_batchs`
# (39 * 256 = 9984 windows when batch_size is 256).
batch_count = 39
# Presumably restores the model saved at `best_step` and returns the predicted
# gap spectrograms -- TODO confirm against StftGapContextEncoder.reconstructAudio.
reconstructed_spec = aContextEncoderNetwork.reconstructAudio(window, best_step, max_batchs=batch_count)

In [None]:
# Sanity check: later cells index this array as [:, :, :, 0] (real part) and
# [:, :, :, 1] (imaginary part).
print(reconstructed_spec.shape)

In [None]:
# STFT of the untouched windows, limited to those that were reconstructed,
# using the same frame length and step as the network.
tf_original_stft = tf.contrib.signal.stft(
    signals=window[:len(reconstructed_spec)],
    frame_length=fft_frame_length,
    frame_step=fft_frame_step,
)

with tf.Session() as sess:
    original_stft = sess.run(tf_original_stft)

print(original_stft.shape)

In [None]:
# Recombine the real and imaginary channels into one complex spectrogram.
gap_real_part = reconstructed_spec[..., 0]
gap_imaginary_part = reconstructed_spec[..., 1]
gap_spec = gap_real_part + 1.0j*gap_imaginary_part

# Replace frames 13..23 of the original STFT with the network output and keep
# the surrounding frames untouched.
frames_before_gap = original_stft[:, :13, :]
frames_after_gap = original_stft[:, 24:, :]
reconstructed_spec_window = np.concatenate(
    (frames_before_gap, gap_spec, frames_after_gap),
    axis=1,
)
print(reconstructed_spec_window.shape)

In [None]:
tf.reset_default_graph()
# Synthesis window matching a forward STFT taken with a periodic Hann window
# (the default window of tf.contrib.signal.stft).
window_fn = functools.partial(window_ops.hann_window, periodic=True)
inverse_window = tf.contrib.signal.inverse_stft_window_fn(fft_frame_step,
                                           forward_window_fn=window_fn)
# Complex spectrogram predicted by the network, and the same 11 frames (13..23)
# taken from the STFT of the original windows.
rec_stft = reconstructed_spec[:,:,:,0] + 1.0j*reconstructed_spec[:,:,:,1]
ori_stft = original_stft[:, 13:24, :]
print(rec_stft.shape)
print(ori_stft.shape)

shape = (batch_size, 11, 257)
stft_to_invert = tf.placeholder(tf.complex64, shape=shape, name='stft_to_invert')
tf_reconstructed_signals = tf.contrib.signal.inverse_stft(stfts=stft_to_invert, frame_length=fft_frame_length, frame_step=fft_frame_step,
                                                         window_fn=inverse_window)

# The placeholder has a fixed batch dimension, so only complete batches are inverted.
full_batches = len(rec_stft) // batch_size

# Collect the per-batch results and join them once at the end instead of
# re-allocating the whole array with np.append on every iteration.
reconstructed_chunks = []
original_chunks = []
with tf.Session() as sess:
    for i in range(full_batches):
        batch = slice(i*batch_size, (i+1)*batch_size)
        reconstructed_chunks.append(
            sess.run(tf_reconstructed_signals, feed_dict={stft_to_invert: rec_stft[batch]}))
        original_chunks.append(
            sess.run(tf_reconstructed_signals, feed_dict={stft_to_invert: ori_stft[batch]}))

# 11 frames of 512 samples with a step of 128 span (11 - 1) * 128 + 512 = 1792 samples.
# The empty array keeps the (0, 1792) float32 result when no full batch exists;
# np.float32 is spelled out instead of relying on the %pylab star-import.
empty_signal = np.zeros([0, 1792], dtype=np.float32)
reconstructed_signal = np.concatenate([empty_signal] + reconstructed_chunks, axis=0)
original_signal = np.concatenate([empty_signal] + original_chunks, axis=0)

In [None]:
# Both arrays hold the time-domain signals obtained from the inverse STFT.
for inverted_signals in (reconstructed_signal, original_signal):
    print(inverted_signals.shape)

In [None]:
# Drop the 384 margin samples on each side so that only the gap remains.
# Note: this rebinds `original_gaps` to the STFT -> inverse STFT round trip of
# the original audio rather than the raw samples.
margin_samples = 384
reconstructed_gaps = reconstructed_signal[:, margin_samples:-margin_samples]
original_gaps = original_signal[:, margin_samples:-margin_samples]

In [None]:
# Debug output: shapes of the time-domain and spectral arrays built so far.
for inspected_array in (reconstructed_signal, original_signal, reconstructed_spec, original_stft):
    print(inspected_array.shape)
print(len(reconstructed_signal))


In [None]:
def _pavlovs_SNR(y_orig, y_inp, onAxis=(1,)):
    norm_y_orig = _squaredEuclideanNorm(y_orig, onAxis)
    norm_y_orig_minus_y_inp = _squaredEuclideanNorm(y_orig - y_inp, onAxis)
    return 10 * np.log10(norm_y_orig / norm_y_orig_minus_y_inp)

def _squaredEuclideanNorm(vector, onAxis=(1,)):
    squared = np.square(vector)
    summed = np.sum(squared, axis=onAxis)
    return summed


In [None]:
# Per-example SNR of the reconstructed gaps against the reference gaps.
fake_a = reconstructed_gaps
evaluated_rows = int(batch_count*batch_size)
gap = original_gaps[:evaluated_rows]

SNRs = _pavlovs_SNR(gap, fake_a)


# Shape first, then mean, standard deviation, minimum, quartiles and maximum.
print(SNRs.shape)
snr_summary = (
    SNRs.mean(),
    SNRs.std(),
    SNRs.min(),
    np.percentile(SNRs, [25, 50, 75]),
    SNRs.max(),
)
for snr_statistic in snr_summary:
    print(snr_statistic)

In [None]:
# Splice each variant of the gap back between the untouched audio that comes
# before and after it (from the start of the clip to its end).
reconstructed_count = len(reconstructed_gaps)
left_side = audios[:reconstructed_count, :ending_sample_left_side]
right_side = audios[:reconstructed_count, starting_sample_right_side:]

reconstructed_signals = np.hstack((left_side, reconstructed_gaps, right_side))
zeroed_signals = np.hstack((left_side, reconstructed_gaps*0, right_side))
reconstructed_original = np.hstack(
    (left_side, original_gaps[:reconstructed_count], right_side))

In [None]:
# Write files to disk
maximum = 256  # upper bound on the number of clips that are exported

# write_wav does not create missing folders, so make sure the target exists.
os.makedirs("recs", exist_ok=True)

int16_limits = np.iinfo(np.int16)
maxv = int16_limits.max
# Assumes file_names[index] names the clip stored in audios[index].
for index in range(min(len(reconstructed_signals), maximum)):
    exports = (
        ("recs/original_", audios[index]),
        ("recs/reconstructed_", reconstructed_signals[index]),
    )
    for path_prefix, samples in exports:
        # Clip before casting: samples outside [-1, 1] (possible in the
        # reconstructed part) would otherwise wrap around in int16.
        pcm = np.clip(samples * maxv, int16_limits.min, int16_limits.max).astype(np.int16)
        librosa.output.write_wav(path_prefix + file_names[index] + ".wav", pcm, sr)


In [None]:
# Row index of the example inspected by the plotting cells below.
reconstructed_signal_to_evaluate = 1

In [None]:
# Waveforms of the selected example: reference gap, reconstruction and their difference.
f, axarr = plt.subplots(1, 3, sharey='row', figsize=(18, 12))

selected_original_gap = original_gaps[reconstructed_signal_to_evaluate]
selected_reconstructed_gap = reconstructed_gaps[reconstructed_signal_to_evaluate]
difference = selected_original_gap-selected_reconstructed_gap

waveform_panels = (
    ('original gap', selected_original_gap),
    ('reconstructed gap', selected_reconstructed_gap),
    ('difference', difference),
)
for panel_axis, (panel_title, panel_samples) in zip(axarr, waveform_panels):
    panel_axis.plot(panel_samples)
    panel_axis.set_title(panel_title, size=24)

# L1 norm of the reference gap, then L1 and L2 norms of the error.
print(np.sum(np.absolute(selected_original_gap)))
print(np.absolute(difference).sum())
print(np.linalg.norm(difference))

print('SNR:', _pavlovs_SNR(selected_original_gap, selected_reconstructed_gap, onAxis=0))



In [None]:
# Magnitude spectrograms of the selected example: original window, window with
# the reconstructed gap, window with a zeroed gap, and the error inside the gap.
f, axarr = plt.subplots(1, 4, sharey='row', figsize=(24, 12))

# STFT frames of the window that are replaced by the network output.
gap_first_frame = 13
gap_frame_count = 11
gap_end_frame = gap_first_frame + gap_frame_count

original_mag_spec = np.abs(original_stft)
rec_mag_spec = np.abs(reconstructed_spec[:, :, :, 0] + 1.0j*reconstructed_spec[:, :, :, 1])

# Transposed so that frames run along the horizontal axis of the plots.
rec_mag_to_plot = rec_mag_spec[reconstructed_signal_to_evaluate].T
window_to_plot = original_mag_spec[reconstructed_signal_to_evaluate].T

difference = window_to_plot[:, gap_first_frame:gap_end_frame]-rec_mag_to_plot
print(window_to_plot.shape)
print(rec_mag_to_plot.shape)

print(np.zeros(reconstructed_spec[reconstructed_signal_to_evaluate].shape).shape)
# Shared colour scale, taken from the original window.
z_min = np.min(window_to_plot)
z_max = np.max(window_to_plot)

frames_left_of_gap = window_to_plot[:, :gap_first_frame]
frames_right_of_gap = window_to_plot[:, gap_end_frame:]

spectrogram_panels = (
    ('original', window_to_plot),
    ('reconstructed', np.concatenate((frames_left_of_gap,
                                      rec_mag_to_plot,
                                      frames_right_of_gap), axis=1)),
    ('zeroed', np.concatenate((frames_left_of_gap,
                               np.zeros(rec_mag_to_plot.shape),
                               frames_right_of_gap), axis=1)),
)
for panel_axis, (panel_title, panel_image) in zip(axarr, spectrogram_panels):
    panel_axis.pcolormesh(panel_image, vmin=z_min, vmax=z_max)
    panel_axis.set_title(panel_title, size=24)

# The error panel uses its own colour scale.
axarr[3].pcolormesh(difference)
axarr[3].set_title('difference', size=24)


In [None]:
# Pick another example (needs at least 3582 reconstructed clips), report its
# SNR and play it back.
reconstructed_signal_to_evaluate = 3581
selected_snr = _pavlovs_SNR(
    original_gaps[reconstructed_signal_to_evaluate],
    reconstructed_gaps[reconstructed_signal_to_evaluate],
    onAxis=0,
)
print('SNR:', selected_snr)

# Keep exactly one Audio(...) line active as the last expression of the cell;
# the commented alternatives play the zeroed and the original clip instead.
IPython.display.Audio(data=reconstructed_signals[reconstructed_signal_to_evaluate], rate=16000)
# IPython.display.Audio(data=zeroed_signals[reconstructed_signal_to_evaluate], rate=16000)
# IPython.display.Audio(data=audios[reconstructed_signal_to_evaluate], rate=16000)
