In [278]:
from matplotlib import pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, MaxPooling1D, UpSampling1D
from keras.layers.convolutional import Conv1D, Conv2D
import medleydb as mdb
from scipy import signal
from scipy.io import wavfile

# Load the mix

In [279]:
multitracks = mdb.load_multitracks(["Phoenix_ColliersDaughter"])

In [280]:
mix = next(multitracks)

In [281]:
sample_rate, mix_audio = wavfile.read(mix.mix_path)
mix_audio = mix_audio.mean(1)
mix_audio = mix_audio[sample_rate*8:int(-sample_rate*1.5)]

In [282]:
samples_per_period = 256
overlap = int(samples_per_period/2)

In [283]:
freqs, times, s_mix = signal.stft(mix_audio, fs=sample_rate, nfft=samples_per_period, return_onesided=False,
                                         noverlap=overlap, nperseg=samples_per_period)

freqs = freqs[:-1]
s_mix = s_mix[:-1,:]

SyntaxError: invalid syntax (<ipython-input-283-7743a0115cc1>, line 2)

In [None]:
plt.pcolormesh(times, freqs, 20*np.log10(np.abs(s_mix)))

# Load the Flute

In [None]:
flutes = mdb.get_files_for_instrument("flute", [mix])

In [None]:
flute = next(flutes)

In [None]:
sample_rate, flute_audio = wavfile.read(flute)
flute_audio = flute_audio.mean(1)
flute_audio = flute_audio[sample_rate*8:int(-sample_rate*1.5)]

In [None]:
freqs, times, s_flute = signal.stft(flute_audio, fs=sample_rate, nfft=samples_per_period, return_onesided=False,
                                           noverlap=overlap, nperseg=samples_per_period)

freqs = freqs[:-1]
s_flute = s_flute[:-1,:]

In [None]:
plt.pcolormesh(times, freqs, 20*np.log10(np.abs(s_flute)))

## Create a mask for the flute

In [None]:
mask_flute = s_flute / (s_flute + s_mix + 1e-9)

In [None]:
plt.pcolormesh(times, freqs, 20*np.log10(np.abs(mask_flute)))

# Prepare the training data

In [None]:
#s_mix_train = s_mix.reshape(*s_mix.T.shape, 1)
#s_mix_train.shape

In [None]:
num_test  = int((sample_rate * 10)/samples_per_period)
num_train = s_mix.shape[1] - num_test
print(num_train, num_test)

In [None]:
mix_train = s_mix.T[:num_train,:,np.newaxis]
mix_train = np.concatenate((mix_train.real, mix_train.imag), axis=-1)
mix_train.shape

In [None]:
flute_train = mask_flute.T[:num_train,:,np.newaxis]
flute_train = np.concatenate((flute_train.real, flute_train.imag), axis=-1)
flute_train.shape

In [None]:
mix_test = s_mix.T[-num_test:,:,np.newaxis]
mix_test = np.concatenate((mix_test.real, mix_test.imag), axis=-1)
mix_test.shape

In [None]:
flute_test = mask_flute.T[-num_test:,:,np.newaxis]
flute_test = np.concatenate((flute_test.real, flute_test.imag), axis=-1)
flute_test.shape

In [None]:
model = Sequential()

model.add(Conv1D(10, 2, padding="same", input_shape=mix_train.shape[1:], activation="relu", name="Conv1D_1"))

model.add(Conv1D(2, 2, padding="same",  name="Conv1D_2"))

model.summary()

In [None]:
model.compile('adam', loss='mean_squared_error', metrics=['accuracy'])

In [None]:
history = model.fit(mix_train, flute_train, batch_size=200, epochs=10)

In [None]:
results = model.evaluate(mix_test, flute_test)

print(model.metrics_names)
print(results)

# Results

In [None]:
mask_prediction = model.predict(mix_test)
mask_prediction.shape

In [None]:
mask_prediction = np.sqrt((mask_prediction**2).sum(-1)).T

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1,3,sharey=True,figsize=(12, 4))

ax1.set_title("Input Mixture")
ax2.set_title("Target Flute Mask")
ax3.set_title("Generated Flute Mask")

ax1.pcolormesh(times[-num_test:], freqs, 20*np.log10(np.abs(s_mix[:,-num_test:])))
ax2.pcolormesh(times[-num_test:], freqs, 20*np.log10(np.abs(mask_flute[:,-num_test:])))
ax3.pcolormesh(times[-num_test:], freqs, 20*np.log10(mask_prediction))

In [None]:
target = s_mix[:,-num_test:] * mask_flute[:,-num_test:]

In [None]:
prediction = s_mix[:,-num_test:] * mask_prediction

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(12, 4))

ax1.set_title("Input Mixture")
ax2.set_title("Target")
ax3.set_title("Prediction")

ax1.pcolormesh(times[-num_test:], freqs, 20*np.log10(np.abs(s_mix[:,-num_test:])))
ax2.pcolormesh(times[-num_test:], freqs, 20*np.log10(np.abs(target)))
ax3.pcolormesh(times[-num_test:], freqs, 20*np.log10(np.abs(prediction)))

## Let's hear it

In [None]:
_, predicted_audio = signal.istft(prediction, fs=sample_rate)

In [None]:
wavfile.write("basic_model_flute_prediction.wav", sample_rate, predicted_audio.astype(np.int16))

In [None]:
_, target_audio = signal.istft(target, fs=sample_rate)

In [None]:
wavfile.write("basic_model_flute_target.wav", sample_rate, target_audio.astype(np.int16))

In [None]:
_, mix_audio = signal.istft(s_mix[:,-num_test:], fs=sample_rate)

In [None]:
wavfile.write("basic_model_original.wav", sample_rate, mix_audio.astype(np.int16))