Imports

In [None]:
from datetime import datetime
import gc
import IPython
import matplotlib.pyplot as plt
import numpy
import os
import random 
import signal
import tensorflow as tf
import tensorflow_io as tfio
import threading
import time
# Select the 'cuda_malloc_async' allocator through the TF_GPU_ALLOCATOR
# environment variable.
# NOTE(review): this assignment runs after `import tensorflow` above — confirm
# the variable is still honoured when set at this point.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

Parameters

In [None]:
# --- Audio and windowing ---
# Sample rate the audio is resampled to before the STFT is taken.
model_sample_rate = 22050
# Number of STFT frames per window; also used as the stride between windows.
sliding_window_step = 5
# Length of the excerpt taken from each track.
pre_process_chunk_duration_seconds = 10
# The training loop runs while `epoch <= epochs`.
epochs = 6000

# Half-width of the random factor (1 +/- this value) applied to the assumed
# source sample rate of each excerpt.
random_speed_multiplier = 0.1

# --- STFT geometry ---
stft_fft_length = 1024
stft_fft_unique_bins = 1 + stft_fft_length // 2
stft_window = 512
stft_step = stft_window // 4

# --- Background pre-processing ---
# Tracks sampled per preparation round (train / test) and tracks per worker thread.
multithreaded_prep_train_batch_size = 40
multithreaded_prep_test_batch_size = 15
multithreaded_prep_chunk_size = 5

# --- Output locations ---
model_path = "./model"
losses_path = "./model/losses"

Data set-up

In [None]:
# MUSDB18 (compressed stems):
# https://sigsep.github.io/datasets/musdb.html#musdb18-compressed-stems
# The train and test subsets are only opened when this file runs as the main module.
if __name__ == "__main__":
    import musdb
    db_train = musdb.DB(root="./musdb18", subsets="train")
    db_test = musdb.DB(root="./musdb18", subsets="test")

# Default source sample rate used when resampling audio taken from the dataset.
db_sample_rate = 44100

# Re-batching approach taken from
# https://stackoverflow.com/questions/62558696/how-do-i-re-batch-a-tensor-in-tensorflow
@tf.function(reduce_retracing=True)
def sliding_window(x, axis=0):
    """Cut `x` into consecutive windows of `sliding_window_step` items along `axis`.

    Window length and stride are both `sliding_window_step`, so the windows do
    not overlap and a trailing remainder shorter than one window is dropped.
    """
    step = sliding_window_step
    length = tf.shape(x)[axis]
    # Number of complete windows, clamped to zero when the input is shorter
    # than a single window.
    window_count = tf.math.maximum((length - step) // step + 1, 0)
    # Column of window start positions plus a row of in-window offsets gives
    # one row of gather indices per window.
    window_starts = tf.expand_dims(tf.range(window_count), 1) * step
    offsets = tf.range(step)
    return tf.gather(x, window_starts + offsets, axis=axis)

@tf.function(reduce_retracing=True)
def downsample(audio, source_sample_rate=db_sample_rate):
    """Resample `audio` from `source_sample_rate` to `model_sample_rate`."""
    target_rate = model_sample_rate
    resampled = tfio.audio.resample(audio, source_sample_rate, target_rate)
    return resampled

@tf.function(reduce_retracing=True)
def separate_imaginary(stft, return_angle=False):
    """Split a complex STFT into its magnitude and, optionally, its phase angle.

    A pair is always returned. The first item is the magnitude; the second is
    the phase angle when `return_angle` is true and the magnitude again
    otherwise, so callers can unpack two values in both cases.
    """
    magnitude = tf.math.abs(stft)
    if not return_angle:
        return magnitude, magnitude
    return magnitude, tf.math.angle(stft)

def pre_process(audio, return_angle=False, source_sample_rate=db_sample_rate, is_audio_mono=False):
    """Turn raw audio into windowed STFT magnitudes (and optionally phase angles).

    The audio is reduced to its first row unless `is_audio_mono` is set,
    resampled by `downsample`, transformed by `get_stft`, split by
    `separate_imaginary` and finally windowed by `sliding_window`.

    Returns the windowed magnitudes, or a (magnitudes, angles) pair when
    `return_angle` is true.
    """
    if not is_audio_mono:
        # Keep only index 0 of the leading axis.
        audio = audio[0, :]

    waveform = tf.convert_to_tensor(audio, dtype=tf.float32)
    rate = tf.convert_to_tensor(source_sample_rate, dtype=tf.int64)
    spectrum = get_stft(downsample(waveform, source_sample_rate=rate))
    magnitude, angle = separate_imaginary(spectrum, return_angle=return_angle)

    if not return_angle:
        return sliding_window(magnitude)
    return sliding_window(magnitude), sliding_window(angle)

def pre_process_track(track):
    """Build one (input, targets) example from a random excerpt of `track`.

    The excerpt lasts `pre_process_chunk_duration_seconds` and starts at a
    uniformly random position. A random factor within
    1 +/- `random_speed_multiplier` scales the source sample rate handed to
    `pre_process`, and the same rate is used for the mixture and every stem.

    Returns:
        (x, y) where x is the pre-processed 'linear_mixture' target and y is a
        list with the pre-processed drums, bass, other and vocals targets, in
        that order.
    """
    track.chunk_duration = pre_process_chunk_duration_seconds
    track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)

    speed_factor = random.uniform(
        1 - random_speed_multiplier,
        1 + random_speed_multiplier
    )
    sample_rate = int(db_sample_rate * speed_factor)

    def prepared(target_name):
        # Pre-process one named target of the track at the shared sample rate.
        return pre_process(track.targets[target_name].audio.T, source_sample_rate=sample_rate)

    x = prepared('linear_mixture')
    y = [prepared(stem_name) for stem_name in ('drums', 'bass', 'other', 'vocals')]

    return x, y

@tf.function
def post_process(magnitudes, angles):
    """Rebuild complex STFT values from magnitudes and phase angles.

    The result is reshaped to two dimensions with `stft_fft_unique_bins`
    columns.
    """
    real_part = tf.math.multiply(magnitudes, tf.math.cos(angles))
    imaginary_part = tf.math.multiply(magnitudes, tf.math.sin(angles))
    frames = tf.complex(real_part, imaginary_part)
    return tf.reshape(frames, [-1, stft_fft_unique_bins])

def post_process_track(stfts, angles):
    """Convert several magnitude arrays that share `angles` back to STFTs and audio.

    Returns:
        (complex_stfts, waveforms): the `post_process` result for every item of
        `stfts`, and the `get_inverse_stft` result for each of those.
    """
    complex_stfts = [post_process(magnitudes, angles) for magnitudes in stfts]
    waveforms = [get_inverse_stft(frames) for frames in complex_stfts]
    return complex_stfts, waveforms

STFT set-up

In [None]:
@tf.function(reduce_retracing=True)
def get_stft(audio):
    """Return the STFT of `audio` using the module-level STFT settings."""
    return tf.signal.stft(
        audio,
        frame_length=stft_window,
        frame_step=stft_step,
        fft_length=stft_fft_length
    )

def get_inverse_stft(stft):
    """Convert a complex STFT back to a waveform and return it as a NumPy array.

    The frame length, frame step and FFT length are the same module-level
    settings `get_stft` uses. The synthesis window comes from
    `tf.signal.inverse_stft_window_fn`, which builds the window complementary
    to the forward (Hann) window for this frame step. Without it the
    overlap-added frames are not a properly scaled reconstruction of the
    original signal.
    """
    return tf.signal.inverse_stft(
        stft,
        stft_window,
        stft_step,
        fft_length=stft_fft_length,
        window_fn=tf.signal.inverse_stft_window_fn(stft_step)
    ).numpy()

def display_stfts(stfts, names=("Drums", "Bass", "Other", "Vocals")):
    """Plot log-magnitude spectrograms side by side in one figure.

    Args:
        stfts: sequence of STFTs; each is shown as log(1 + |stft|), transposed.
        names: subplot titles, matched to `stfts` by position. An immutable
            default is used so the value can never be shared and mutated
            between calls. Spectrograms beyond the end of `names` get no title.
    """
    # One column per spectrogram instead of a fixed four, so any number of
    # inputs fits; at least one column keeps the call valid for empty input.
    column_count = max(len(stfts), 1)
    fig = plt.figure(figsize=(16,4))
    for i, stft in enumerate(stfts):
        axes = fig.add_subplot(1, column_count, i + 1)
        axes.imshow(
            numpy.transpose(tf.math.log(1 + tf.abs(stft)).numpy()),
            aspect='auto',
            origin='lower'
        )
        if i < len(names):
            axes.set_title(names[i])
        axes.set_xlabel('Time frame')
        axes.set_ylabel('Frequency bin')
    plt.show()

def display_audio(audio, max_seconds = 60, sr = model_sample_rate):
    """Show an inline audio player for the start of `audio`.

    The transposed audio is cut to its first `sr * max_seconds` entries and
    played back at rate `sr`.
    """
    entry_limit = sr * max_seconds
    clip = numpy.transpose(audio)[:entry_limit]
    player = IPython.display.Audio(clip, rate=sr)
    IPython.display.display(player)

Creating the model

In [None]:
@tf.keras.saving.register_keras_serializable()
class Mask(tf.keras.layers.Layer):
    """Scale the model input by one branch's share of the summed branch outputs."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        """Return `layer_input / sum(all_dense) * model_input`.

        Args:
            inputs: a triple `(layer_input, model_input, all_dense)` where
                `all_dense` is the list of tensors that are summed to form the
                denominator.

        Elements whose denominator is exactly zero produce 0 instead of
        inf/NaN, so a single zero sum cannot poison the loss.
        """
        layer_input, model_input, all_dense = inputs
        denominator = tf.math.add_n(all_dense)
        # divide_no_nan equals a plain division everywhere except where the
        # denominator is zero, in which case it yields 0.
        share = tf.math.divide_no_nan(layer_input, denominator)
        return tf.math.multiply(share, model_input)

def create_model():
    """Build the separation model.

    Layout: an input of shape (sliding_window_step, stft_fft_unique_bins),
    three stacked sequence-returning LSTMs, one Dense head per stem and one
    Mask output per stem. Every Mask receives its own head, the model input
    and the list of all four heads. Outputs are ordered drums, bass, other,
    vocals.
    """
    bin_count = stft_fft_unique_bins
    stft_input = tf.keras.layers.Input(shape=(sliding_window_step, bin_count), name="input_stft")

    hidden = stft_input
    for lstm_name in ("1st_lstm", "2nd_lstm", "3rd_lstm"):
        hidden = tf.keras.layers.LSTM(bin_count * 2, return_sequences=True, name=lstm_name)(hidden)

    # All Dense heads are created before any Mask layer, keeping the layer
    # creation order unchanged.
    dense_heads = [
        tf.keras.layers.Dense(bin_count, name=dense_name)(hidden)
        for dense_name in ("drums_dense", "bass_dense", "other_dense", "vocals_dense")
    ]

    output_names = ("drums_output", "bass_output", "other_output", "vocals_output")
    outputs = [
        Mask(name=output_name)((head, stft_input, list(dense_heads)))
        for head, output_name in zip(dense_heads, output_names)
    ]

    return tf.keras.Model(stft_input, outputs)

# Build the model when run as the main module, print its summary and keep a
# copy of its initial weights (the "Reset training" cell re-applies them).
if __name__ == "__main__":
    model = create_model()
    
    model.summary()
    
    # Snapshot of the weights straight after construction.
    starting_weights = model.get_weights()

Multithreading set-up

In [None]:
# Pattern from
# https://alexandra-zaharia.github.io/posts/how-to-return-a-result-from-a-python-thread/
class ReturnValueThread(threading.Thread):
    """Thread that keeps its target's return value and hands it back from join()."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stays None until the target has returned.
        self.result = None

    def run(self):
        """Call the target, if one was given, and store what it returns."""
        target = self._target
        if target is not None:
            self.result = target(*self._args, **self._kwargs)

    def join(self, *args, **kwargs):
        """Wait like `threading.Thread.join`, then return the stored result.

        The result is still None when the wait timed out before the target
        returned.
        """
        super().join(*args, **kwargs)
        return self.result

def pre_process_batch(batch):
    """Return the `pre_process_track` result for every track in `batch`, in order."""
    return [pre_process_track(track) for track in batch]
        
def start_prep_multithreaded():
    """Start background pre-processing of a fresh random selection of tracks.

    Samples `multithreaded_prep_train_batch_size` training tracks and
    `multithreaded_prep_test_batch_size` test tracks, splits each sample into
    slices of `multithreaded_prep_chunk_size` tracks and starts one
    `ReturnValueThread` per slice running `pre_process_batch`.

    Returns:
        A zero-argument function. Calling it waits for every worker while
        printing progress and returns `(train_batch, test_batch)`, two lists
        with the `pre_process_track` results collected from the workers.
    """
    gc.collect()

    chunk_size = multithreaded_prep_chunk_size

    def spawn_workers(tracks):
        # Start one worker thread per chunk_size-sized slice of `tracks`.
        workers = []
        for offset in range(0, len(tracks), chunk_size):
            worker = ReturnValueThread(
                target=pre_process_batch,
                args=(tracks[offset:offset + chunk_size],)
            )
            worker.start()
            workers.append(worker)
        return workers

    db_train_subset = random.sample(db_train.tracks, k=multithreaded_prep_train_batch_size)
    db_test_subset = random.sample(db_test.tracks, k=multithreaded_prep_test_batch_size)

    train_batch = []
    test_batch = []

    # Each pending entry pairs a worker with the list its results belong in.
    pending = [(worker, train_batch) for worker in spawn_workers(db_train_subset)]
    pending += [(worker, test_batch) for worker in spawn_workers(db_test_subset)]

    def join_threads():
        joined_count = 0
        thread_count = len(pending)
        print(f"\r{joined_count} out of {thread_count} threads finished", end = "")
        while pending:
            still_running = []
            for worker, batch in pending:
                worker.join(0.01)
                # Completion is decided by the thread state, not by the
                # truthiness of the result: a worker that raised leaves its
                # result at None and would otherwise be polled forever.
                if worker.is_alive():
                    still_running.append((worker, batch))
                    continue
                joined_count += 1
                print(f"\r{joined_count} out of {thread_count} threads finished", end = "")
                # Read the result only after the thread is known to be dead,
                # so a finish between join() timing out and the liveness check
                # cannot be mistaken for a failure.
                if worker.result is None:
                    print("\nA pre-processing thread ended without a result; its tracks are skipped")
                else:
                    batch.extend(worker.result)
            # The list is rebuilt rather than having items removed while it
            # is being iterated.
            pending[:] = still_running
        print("")
        return train_batch, test_batch

    return join_threads


Training the model

In [None]:
# Losses
# One loss object for the gradient step and two separate running metrics, one
# updated by the training step and one by the test step.
if __name__ == "__main__":
    loss_fn = tf.keras.losses.MeanSquaredError()
    train_acc_metric = tf.keras.metrics.MeanSquaredError()
    test_acc_metric = tf.keras.metrics.MeanSquaredError()

def read_losses(file_name):
    """Load the loss history stored in `{losses_path}/{file_name}.csv`.

    Every non-blank line is expected to hold at least two comma-separated
    numbers; the first is read as the training loss and the second as the
    test loss.

    Args:
        file_name: name of the CSV file without its `.csv` extension.

    Returns:
        (train_losses, test_losses): two lists of floats in file order. Both
        are empty when the file has no data lines.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field is not a number.
        IndexError: if a non-blank line has fewer than two fields.
    """
    train_values = []
    test_values = []
    with open(f"{losses_path}/{file_name}.csv", "r") as losses_file:
        for line in losses_file:
            # Blank lines (for example a trailing newline) carry no data.
            if not line.strip():
                continue
            vals = line.split(",")
            train_values.append(float(vals[0]))
            test_values.append(float(vals[1]))
    return train_values, test_values

In [None]:
# Reset training
# Puts the training state back to its starting point: epoch counter, loss
# histories, model weights and a new optimizer. Setting `restore_name` resumes
# from a saved checkpoint instead.
if __name__ == "__main__":
    epoch = 1
    train_losses = []
    test_losses = []
    model.set_weights(starting_weights)
    
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=0.001,
        beta_1=0.9,
        beta_2=0.999
    )
    
    # Tracks both the model and the optimizer.
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    
    # Name of the checkpoint to resume from; None starts from scratch.
    restore_name = None
    # restore_name = ''
    
    if restore_name:
        checkpoint.restore(f'{model_path}/{restore_name}')
        # The second '_'-separated field of the name is read as the epoch it
        # was saved at; training continues with the following epoch.
        epoch = int(restore_name.split('_')[1]) + 1
        # The loss history is read from the CSV file of the same name.
        train_losses, test_losses = read_losses(restore_name)

def print_model_predict(track_index = None, track_start = None):
    """Run the model on one test-set excerpt and display the result.

    Shows the predicted spectrograms for the four stems, then an audio player
    for the original excerpt and for each predicted stem.

    Args:
        track_index: index into `db_test`; a random track is used when None.
        track_start: value assigned to the track's `chunk_start`; a random
            position is used when None. An explicit 0 is honoured, where a
            plain `or` would have replaced it with a random start.
    """
    display_track = db_test[track_index] if track_index is not None else random.choice(db_test)
    display_track.chunk_duration = pre_process_chunk_duration_seconds
    if track_start is None:
        track_start = random.uniform(0, display_track.duration - display_track.chunk_duration)
    display_track.chunk_start = track_start

    magnitudes, angle = pre_process(display_track.audio.T, return_angle=True)
    stft_input = post_process(magnitudes, angle)
    results = model.predict(magnitudes)

    # The first four outputs are drums, bass, other and vocals; each reuses
    # the phase angles of the input.
    stem_stfts = [post_process(prediction, angle) for prediction in results[:4]]
    display_stfts(stem_stfts)

    labels = ("Original: ", "Drums: ", "Bass: ", "Other: ", "Vocals: ")
    for label, stft in zip(labels, [stft_input] + stem_stfts):
        print(label)
        display_audio(get_inverse_stft(stft))

In [None]:
@tf.function(reduce_retracing=True)
def train_step(x, y):
    """Apply one optimizer update for a single example and return its loss.

    Targets and model outputs are each converted to a single tensor before the
    loss is computed; the same pair also updates `train_acc_metric`.
    """
    targets = tf.convert_to_tensor(y)
    with tf.GradientTape() as tape:
        predictions = tf.convert_to_tensor(model(x, training=True))
        loss_value = loss_fn(targets, predictions)
    weights = model.trainable_weights
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    train_acc_metric.update_state(targets, predictions)
    return loss_value

@tf.function(reduce_retracing=True)
def test_step(x, y):
    """Evaluate the model on one example and add the result to `test_acc_metric`.

    Targets and model outputs are converted to single tensors explicitly, the
    same way `train_step` does, instead of leaving the conversion to the
    metric.
    """
    y = tf.convert_to_tensor(y)
    logits = tf.convert_to_tensor(model(x, training=False))
    test_acc_metric.update_state(y, logits)

# Main training loop. Data for the next epoch is prepared in background
# threads while the current epoch trains.
if __name__ == "__main__":
    gc.collect()
    
    # Start preparing the data for the first epoch.
    join_threads = start_prep_multithreaded()
    
    while epoch <= epochs:
        print(f"\nStart of epoch {epoch}")
        start = time.time()
    
        train_batch, test_batch = join_threads()
        # Prefetch the next epoch's data; after the final epoch nothing would
        # ever collect it, so no new workers are started.
        if epoch < epochs:
            join_threads = start_prep_multithreaded()
        random.shuffle(train_batch)
        
        for step, track in enumerate(train_batch):
            
            train_x, train_y = track
            loss_value = train_step(train_x, train_y)
            
            print(f"\rTraining step {step + 1}. Loss: {float(loss_value):.4f} Time taken (seconds): {float(time.time()-start):.2f}", end="")
    
            if tf.math.is_nan(loss_value): 
                raise ValueError(f'loss_value became NaN during epoch {epoch}, step {step + 1} at {datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}')
        
        # Stored as plain floats, the same type read_losses returns, so the
        # history does not mix tensors with floats after a restore.
        train_acc = float(train_acc_metric.result())
        train_losses.append(train_acc)
        print("\nTraining loss over epoch: %.4f" % (train_losses[-1],), end="")
        # reset_state() is the current name of the metric reset method;
        # reset_states() is its deprecated spelling.
        train_acc_metric.reset_state()
    
        for test_x, test_y in test_batch:
            test_step(test_x, test_y)
    
        test_acc = float(test_acc_metric.result())
        test_losses.append(test_acc)
        print("\tTest loss: %.4f" % (float(test_acc),))    
        test_acc_metric.reset_state()
    
        # Every 40 epochs: save a checkpoint and the loss history next to it.
        if epoch % 40 == 0:
            new_filename = f'epoch_{epoch}_with_{optimizer.name}-{optimizer.learning_rate.numpy():.3f}_at_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
            checkpoint.save(f"{model_path}/{new_filename}")
            # The save counter is appended so the CSV name matches the name
            # the "Reset training" cell passes to read_losses.
            with open(f"{losses_path}/{new_filename}-{checkpoint.save_counter.numpy()}.csv", "w") as losses_file:
                for train_loss, test_loss in zip(train_losses, test_losses):
                    losses_file.write(f"{train_loss},{test_loss}\n")
        
        # Every 20 epochs: show a prediction preview.
        if epoch % 20 == 0:
            print_model_predict()
        
        epoch += 1

In [None]:
# Save the model to a single .keras file.
if __name__ == "__main__":
    model.save('./model_6000.keras')

In [None]:
# Show the prediction preview for test tracks 0 to 3, each with track_start=30
# (presumably seconds into the track — confirm against musdb's chunk_start).
if __name__ == "__main__":
    for i in range(4):
        print_model_predict(track_index = i, track_start = 30)