In [None]:
import os
import sys
import warnings
import itertools
import numpy as np
import pandas as pd
import IPython.display as ipd
from sklearn.linear_model import Lasso, LinearRegression, Ridge
import scipy.spatial.distance as dist
import time

import librosa
from spleeter.separator import Separator
from spleeter.audio import STFTBackend
from spleeter.audio.adapter import AudioAdapter

import tensorflow as tf
from tensorflow import keras


In [None]:
# Notebook-wide configuration: filesystem locations and audio tooling.

# Root directory that holds the dataset, its metadata and the training logs.
data_base_path = '/data'

# Track metadata table and the directory with the MP3 files.
metadata_file = os.path.join(data_base_path, 'fma_metadata', 'tracks.csv')
audio_data_path = os.path.join(data_base_path, 'fma_medium')

# Saved model of one specific training run (identified by its timestamp folder).
model_path = os.path.join(
    data_base_path,
    'logs',
    'DropOriginalMultiChannelParallelCRNN',
    'fma_medium_SpleeterGPUPreprocessor_spleeter:4stems_keepOriginal_LibrosaCPUSTFT',
    '12-03-2022-14-05-56',
    'trained_model',
)

# Spleeter separator using the 4-stem configuration and the TensorFlow STFT
# backend, with multiprocessing switched off.
seperator = Separator("spleeter:4stems", STFTBackend.TENSORFLOW, multiprocess=False)

# Spleeter's default audio adapter.
audio_loader = AudioAdapter.default()

# Sample rate constant (the cells below currently hard-code the same value).
sample_rate = 44100

In [None]:
 # Get samples
# Track ids whose audio is empty or shorter than the expected excerpt length.
# Source: https://github.com/mdeff/fma/wiki#excerpts-shorter-than-30s-and-erroneous-audio-length-metadata (last visit 07.02.2022)
excluded_shorter_tracks = [
    "1486",
    "5574",
    "65753",
    "80391",
    "98558",
    "98559",
    "98560",
    "98565",
    "98566",
    "98567",
    "98568",
    "98569",
    "98571",
    "99134",
    "105247",
    "108924",
    "108925",
    "126981",
    "127336",
    "133297",
    "143992",
]

# Load the metadata (two header rows -> two-level column index), keep the
# small/medium subsets and drop the faulty tracks.
df = pd.read_csv(metadata_file, index_col=0, header=[0, 1])
df = df[df[("set", "subset")].isin(["small", "medium"])]
# The ids above are strings. Compare them against the stringified index so the
# filter also matches when pandas parsed the track ids as integers; comparing
# an integer index with strings never matches and would exclude nothing.
df = df[~df.index.astype(str).isin(excluded_shorter_tracks)]

# Map every top-level genre to a class index, in order of first appearance.
# NOTE(review): this order has to match the label encoding used in training - confirm.
genres = df[("track", "genre_top")].dropna().unique()
genres_dict = {genre_name: position for position, genre_name in enumerate(genres)}
print(genres_dict)

# Keep only the test split and only the genre column.
df = df[df[("set", "split")] == "test"]
df = df[[("track", "genre_top")]]


def _tracks_of_genre(genre_name):
    """Return the rows of the test split whose top-level genre is ``genre_name``."""
    return df[df[("track", "genre_top")] == genre_name]


# Per-genre views of the test split, handy for picking example tracks.
df_hip_hop = _tracks_of_genre('Hip-Hop')
df_rock = _tracks_of_genre('Rock')
df_pop = _tracks_of_genre('Pop')
df_jazz = _tracks_of_genre('Jazz')
df_electronic = _tracks_of_genre('Electronic')

# Hand-picked example track per genre.
example_hip_hop_id = 14360
example_rock_id = 152671
example_pop_id = 12387
# example_jazz_id = 67045  good example of a possible mix-up with classical
# example_jazz_id = 47504  example of jazz whose ending reads as rock because of a distorted guitar and a pushing beat
example_jazz_id = 12374
example_electronic_id = 151534


def _example_track_path(track_id):
    """Build the MP3 path of ``track_id`` below ``audio_data_path``.

    The file is named after the id zero-padded to six digits and lives in a
    sub-folder named after the first three digits of that padded id.
    """
    padded_id = "{:06d}".format(track_id)
    return os.path.join(audio_data_path, padded_id[:3], padded_id + ".mp3")


example_hip_hop_path = _example_track_path(example_hip_hop_id)
example_rock_path = _example_track_path(example_rock_id)
example_pop_path = _example_track_path(example_pop_id)
example_jazz_path = _example_track_path(example_jazz_id)
example_electronic_path = _example_track_path(example_electronic_id)

In [None]:
# Hip-Hop example: play the full mix followed by every separated stem.
with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    hiphop_audio, sr = librosa.load(example_hip_hop_path, mono=False, sr=44100, duration=30.0)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(hiphop_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(hiphop_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Rock example: play the full mix followed by every separated stem.
with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    rock_audio, sr = librosa.load(example_rock_path, mono=False, sr=44100, duration=30.0)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(rock_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(rock_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Electronic example: play the full mix followed by every separated stem.
with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    electronic_audio, sr = librosa.load(example_electronic_path, mono=False, sr=44100, duration=30.0)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(electronic_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(electronic_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Pop example: play the full mix followed by every separated stem.
with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    pop_audio, sr = librosa.load(example_pop_path, mono=False, sr=44100, duration=30.0)

# (A bare `ipd.Audio(...)` expression used to follow the load. It was not the
# last expression of the cell, so it displayed nothing and has been removed.)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(pop_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(pop_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Jazz example: play the full mix followed by every separated stem.
with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    jazz_audio, sr = librosa.load(example_jazz_path, mono=False, sr=44100, duration=30.0)

# (A bare `ipd.Audio(...)` expression used to follow the load. It was not the
# last expression of the cell, so it displayed nothing and has been removed.)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(jazz_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(jazz_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Deathweight example (a track from outside the dataset): play the full mix
# followed by every separated stem. Uncomment one of the alternatives to
# listen to a different file.
# Darkness
deathweigth_path = os.path.join('/data','Deathweight', 'Deathweight-Darkness.mp3')
# Volition
#deathweigth_path = os.path.join('/data','Deathweight', 'Deathweight-Volition.mp3')

#deathweigth_path = os.path.join('/data','Deathweight', 'Megadeth-Tornado_Of_Souls.mp3')

with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    deathweigth_audio, sr = librosa.load(deathweigth_path, mono=False, sr=44100, duration=30.0)

# librosa returns (channels, samples) while Spleeter works on
# (samples, channels). The axes must be swapped with a transpose; reshape only
# reinterprets the flat buffer and would mix both channels into each column.
source_separation = seperator.separate(np.ascontiguousarray(deathweigth_audio.T), "")

print("Full mix")
ipd.display(ipd.Audio(deathweigth_audio, rate=sr))

for key, value in source_separation.items():
    print(key)
    # Transpose back to (channels, samples) for the audio widget.
    ipd.display(ipd.Audio(value.T, rate=sr))

In [None]:
# Choose the track to analyse together with its ground-truth genre label.
# Swap in one of the commented pairs below to analyse a different example.
file = example_hip_hop_path
label = 'Hip-Hop'
#file = example_rock_path
#label = 'Rock'
#file = example_pop_path
#label = 'Pop'
#file = example_jazz_path
#label = 'Jazz'
#file = example_electronic_path
#label = 'Electronic'

#file = deathweigth_path
#label = 'Rock'

with warnings.catch_warnings():
    # Suppress any warnings raised while the file is loaded.
    warnings.simplefilter("ignore")
    audio, sr = librosa.load(file, mono=False, sr=44100, duration=30.0)

# A mono file comes back without a leading channel axis of size 2; duplicate
# the signal so that the rest of the pipeline always sees two channels.
if audio.shape[0] != 2:
    audio = np.array([audio, audio])

# Zero-pad both channels at the end up to the full 30 seconds.
missing_samples = 30*44100 - audio.shape[1]
if missing_samples > 0:
    audio = np.array([
        np.pad(channel_samples, (0, missing_samples), mode='constant', constant_values=0)
        for channel_samples in (audio[0], audio[1])
    ])


# Cut the excerpt into 10 second slices. Samples that do not fill a whole
# slice are dropped from the start of the signal.
slice_length = 10 # seconds
samples_per_slice = 44100 * slice_length
surplus = audio.shape[1] % samples_per_slice
number_of_slices = audio.shape[1] // samples_per_slice
audio_slices = np.array([
    np.array_split(audio[channel_index][surplus:], number_of_slices)
    for channel_index in (0, 1)
])

# Regroup into one two-channel array per slice.
prepared_slices = [
    np.array([audio_slices[0][slice_index], audio_slices[1][slice_index]])
    for slice_index in range(audio_slices.shape[1])
]


# Run the source separation on every slice.
# NOTE(review): the array is reshaped (not transposed) from (2, n) to (n, 2),
# which reinterprets the flat buffer rather than swapping axes. `to_mono`
# applies the inverse reshape afterwards. Switching to a transpose would change
# the features fed to the trained model - confirm against the training
# preprocessor before touching either place.
source_seperation_slices = [
    seperator.separate(
        stereo_slice.reshape(stereo_slice.shape[1], stereo_slice.shape[0]),
        ""
    )
    for stereo_slice in prepared_slices
]

# Helper: collapse a two-channel waveform into a single channel.
def to_mono(waveform):
    """Return the mono mix of ``waveform`` via ``librosa.to_mono``.

    An input whose first axis already has length 2 is passed on unchanged.
    Any other input is first reshaped to ``(shape[1], shape[0])``.

    NOTE(review): this is a reshape, not a transpose. It mirrors the reshape
    applied before the source separation, so the two cancel out; change both
    together or neither.
    """
    if waveform.shape[0] == 2:
        return librosa.to_mono(waveform)
    swapped_shape = (waveform.shape[1], waveform.shape[0])
    return librosa.to_mono(waveform.reshape(swapped_shape))

# Compute spectrograms: one dict per slice, mapping each separated stem to its
# mel spectrogram converted to dB relative to that spectrogram's maximum.
spectrogram_slices = []
for source_seperation_slice in source_seperation_slices:
    stem_spectrograms = {}
    for key, prediction in source_seperation_slice.items():
        # The signal is passed as `y=`: recent librosa releases accept the
        # arguments of melspectrogram by keyword only.
        mel_power = librosa.feature.melspectrogram(
            y=to_mono(prediction),
            sr=44100,
            n_fft=2048,
            hop_length=1024
        )
        stem_spectrograms[key] = librosa.power_to_db(mel_power, ref=np.max)
    spectrogram_slices.append(stem_spectrograms)

# Prepare data for prediction: stack the stems of every slice along a new
# leading axis and drop the last index of the final axis.
# NOTE(review): presumably trims the last frame to match the model's input
# width - confirm.
x_test = [
    np.array(list(spectrogram_slice.values()))[:, :, 0:-1]
    for spectrogram_slice in spectrogram_slices
]

print('Preprocessing finished')

In [None]:
# Load the trained model (4 stems, original mix not included as a channel).
# `tf` is handed in as a custom object so it can be resolved while the saved
# model is rebuilt.
custom_objects = {'tf': tf}
model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

In [None]:
# Uncomment to check that the model was loaded correctly:
# model.summary()

# Reverse lookup: class index -> genre name.
index_to_genre = {class_index: genre_name for genre_name, class_index in genres_dict.items()}

# Predict every slice on its own and report the most confident genre.
for x in x_test:
    model_prediction = model.predict(np.array([x]))
    class_scores = np.squeeze(model_prediction)
    predictet_label = np.argmax(class_scores)
    print(f'Model predictet {index_to_genre[predictet_label]}')

In [None]:
# Simple game theory channel wise permutation interpretation.
#
# For every slice, each on/off combination of the four stems is fed to the
# model with the switched-off stems blanked out. A linear regression from the
# on/off pattern to the score of the target genre then gives one coefficient
# per stem. (Ridge and LinearRegression are already imported in the first
# cell; the duplicate import that used to sit here was dropped.)

# All on/off patterns (1 = keep stem, 0 = blank stem) except the two trivial
# ones where every stem is blanked or every stem is kept.
l = [0,1]
x_ = list(itertools.product(l, repeat=4))
x_.remove((0,0,0,0))
x_.remove((1,1,1,1))

# Reverse lookup (class index -> genre name) and index of the target genre.
index_to_genre = {class_index: genre_name for genre_name, class_index in genres_dict.items()}
target_index = genres_dict[label]

# Value written into a blanked stem.
# NOTE(review): presumably the -80 dB floor of the dB-scaled spectrograms - confirm.
silence_db = -80

local_group = []
for slice_number, x in enumerate(x_test, start=1):
    print(f'Slice number {slice_number}')
    Z_y = []

    # Get prediction for full input
    full_scores = np.squeeze(model.predict(np.array([x])))
    full_predictet_confidence = np.amax(full_scores)
    full_predictet_label = index_to_genre[np.argmax(full_scores)]
    print(f'Predicted {full_predictet_label} with confidence of {full_predictet_confidence}')

    # Create one masked sample per pattern and record the target-genre score.
    for permutation in x_:
        z = np.copy(x)
        # Dedicated loop names: the original reused `index` here and thereby
        # shadowed the slice counter of the outer loop.
        for stem_index, keep_stem in enumerate(permutation):
            if keep_stem == 0:
                z[stem_index] = silence_db

        scores = np.squeeze(model.predict(np.array([z])))

        # Get most confident value
        predictet_label = np.argmax(scores)
        target_score = scores[target_index]
        Z_y.append(target_score)
        print(f'Model predictet {index_to_genre[predictet_label]} for permutation {permutation}; Target label with {round(target_score, 2)}')

    # One coefficient per stem for this slice.
    linear_reg_clf = LinearRegression()
    linear_reg_clf.fit(x_, Z_y)
    local_group.append(linear_reg_clf.coef_)


In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Stem names on the x-axis.
# NOTE(review): assumed to match the order of the stems in the model input - confirm.
channel = ["vocals", "drums", "bass", "other"]

# Not used by the figure below; kept because other cells may rely on them.
cmap = mpl.cm.viridis
norm = mpl.colors.Normalize(vmin=0, vmax=1)


def _as_minutes_seconds(seconds):
    """Format a number of seconds as ``MM:SS``."""
    return time.strftime("%M:%S", time.gmtime(seconds))


# One line per slice: the regression coefficient of every stem.
fig, ax = plt.subplots()
for index, line in enumerate(local_group):
    # Derive the time window from the slice length used during preprocessing
    # instead of repeating the literal 10.
    window_start = index * slice_length
    window_end = (index + 1) * slice_length
    ax.plot(channel, line, label=f'{_as_minutes_seconds(window_start)} - {_as_minutes_seconds(window_end)}')

# Label the figure so it can be read on its own.
ax.set_xlabel("Stem")
ax.set_ylabel("Linear regression coefficient")
ax.set_title(f"Stem contribution to the '{label}' score per slice")
ax.legend(title="Slice (mm:ss)")

plt.show()