# In [None]:
import numpy as np
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
from scipy.signal import butter, filtfilt

from sklearn.cluster import KMeans
import scipy
import glob
import sklearn
from sklearn import cluster
from sklearn.metrics import silhouette_samples, silhouette_score

# Default figure size for every matplotlib figure created afterwards: wide and short
plt.rcParams['figure.figsize'] = (14, 4)

## Preliminaries

def suppress_vocals(audio_file, bass_cutoff=300, sr=22050):
    """Load an audio file and attenuate centre-panned content (typically vocals).

    The signal is split at ``bass_cutoff`` with a first-order, zero-phase
    Butterworth high-pass filter. Above the cutoff, the right channel is
    subtracted from the left, which cancels anything mixed identically into
    both channels. The low band (original minus high-passed signal) is then
    added back so bass content is not lost.

    Args:
        audio_file: Path handed to ``librosa.load``.
        bass_cutoff: High-pass cutoff in Hz; content below it bypasses the
            channel subtraction.
        sr: Target sample rate handed to ``librosa.load``.

    Returns:
        Tuple ``(audio, sr)`` where ``audio`` is a 1-D mono signal and ``sr``
        is the sample rate reported by ``librosa.load``.
    """
    y, sr = librosa.load(audio_file, sr=sr, mono=False)

    # A single-channel signal carries no left/right difference to exploit,
    # so it is returned unchanged instead of being indexed as if it were stereo.
    if y.ndim == 1:
        return y, sr

    # Butterworth high-pass coefficients; the cutoff is normalised to Nyquist.
    nyquist_freq = 0.5 * sr
    b, a = butter(1, bass_cutoff / nyquist_freq, btype='highpass')

    # filtfilt runs forward and backward, so the high band has no phase shift
    # and can be subtracted from the input to obtain the low band.
    y_filtered = filtfilt(b, a, y)
    bass = y - y_filtered

    # Keep the side signal once. Building both (L - R) and (R - L) and then
    # averaging them to mono sums to zero and would leave only the bass.
    side = y_filtered[0] - y_filtered[1]

    # Mono mix-down of the untouched low band.
    bass_mono = 0.5 * (bass[0] + bass[1])

    return side + bass_mono, sr

def extract_features(y, sr):
    """Compute a three-element feature vector for one audio snippet.

    Args:
        y: Audio samples of the snippet.
        sr: Sample rate, forwarded to the spectral centroid computation.

    Returns:
        ``[zero_crossings, energy, spectral_centroid]`` where the first entry
        is the count of zero crossings, the second is the norm of ``y`` and
        the third is the spectral centroid of the first analysis frame
        (64-point FFT).
    """
    crossing_flags = librosa.zero_crossings(y)
    crossing_count = crossing_flags.sum()

    signal_norm = scipy.linalg.norm(y)

    # Only the first frame of the centroid track is kept.
    centroid_track = librosa.feature.spectral_centroid(y=y, sr=sr, n_fft=64)
    first_centroid = centroid_track[0, 0]

    return [crossing_count, signal_norm, first_centroid]

def extract_and_normalize_features(y, sr=22050):
    """Extract per-onset features from a signal and scale them to [-1, 1].

    Onsets are detected in ``y``; a 20 ms window starting at each onset is
    passed to ``extract_features`` and the resulting rows are min-max scaled
    column by column.

    Args:
        y: Mono audio signal.
        sr: Sample rate of ``y``.

    Returns:
        2-D array with one row per onset and one column per feature, each
        column scaled to the range [-1, 1].

    Raises:
        ValueError: If no usable onset window is found in ``y``.
    """
    # Imported explicitly: a bare ``import sklearn`` does not guarantee that
    # the ``preprocessing`` submodule has been loaded.
    from sklearn.preprocessing import MinMaxScaler

    onset_frames = librosa.onset.onset_detect(y=y, sr=sr, delta=0.04, wait=4)
    onset_samples = librosa.frames_to_samples(onset_frames)

    # 20 ms analysis window, in samples.
    window_len = int(sr * 0.020)

    feature_rows = []
    for start in onset_samples:
        window = y[start:start + window_len]
        # An onset at the very end of the signal yields an empty window,
        # which has no meaningful features.
        if len(window) == 0:
            continue
        feature_rows.append(extract_features(window, sr))

    if not feature_rows:
        raise ValueError("no onsets detected; cannot extract features")

    features = np.array(feature_rows)
    return MinMaxScaler(feature_range=(-1, 1)).fit_transform(features)

def plot_features(features):
    """Scatter-plot the first feature column against the third and show the figure.

    Args:
        features: 2-D array indexed as ``features[:, column]``; column 0 is
            drawn on the x axis and column 2 on the y axis.
    """
    x_values = features[:, 0]
    y_values = features[:, 2]

    plt.scatter(x_values, y_values)
    plt.ylabel('Spectral Centroid (scaled)')
    plt.xlabel('Zero Crossing Rate (scaled)')
    plt.show()

def plot_kmeans_for_different_values_directly(data, max_number_of_instruments=5):
    """Plot K-means inertia and silhouette score for a range of cluster counts.

    K-means is fitted for every k from 2 up to, but not including,
    ``max_number_of_instruments``. Two figures are shown: inertia versus k
    (elbow method) and mean silhouette score versus k.

    Args:
        data: Feature matrix passed to ``KMeans.fit`` and ``silhouette_score``.
        max_number_of_instruments: Exclusive upper bound on the k values tried.
    """
    candidate_ks = range(2, max_number_of_instruments)

    inertia_per_k = []
    silhouette_per_k = []
    for k in candidate_ks:
        model = KMeans(n_clusters=k).fit(data)
        inertia_per_k.append(model.inertia_)
        silhouette_per_k.append(silhouette_score(data, model.labels_))

    # Figure 1: elbow curve.
    plt.plot(candidate_ks, inertia_per_k, marker='o')
    plt.xlabel('Number of clusters')
    plt.ylabel('Inertia')
    plt.title('Elbow method')
    plt.show()

    # Figure 2: silhouette curve.
    plt.plot(candidate_ks, silhouette_per_k)
    plt.title('Silhouette analysis For Optimal k')
    plt.xlabel('Values of K')
    plt.ylabel('Silhouette score')
    plt.show()

def predict_number_of_instruments(data, max_number_of_instruments=6, threshold=1.5,
                                  random_state=None):
    """Estimate how many instruments (clusters) are present in a feature matrix.

    K-means is fitted for every k from 2 up to, but not including,
    ``max_number_of_instruments``. If the inertia at k=2 divided by the
    inertia at k=3 is below ``threshold``, the answer is 1. Otherwise the k
    with the highest mean silhouette score is returned.

    Args:
        data: Feature matrix with one row per sample.
        max_number_of_instruments: Exclusive upper bound on the k values tried.
        threshold: Minimum inertia ratio (k=2 over k=3) for the data to be
            treated as containing more than one cluster.
        random_state: Forwarded to ``KMeans`` so results can be made
            reproducible; ``None`` keeps the library's default seeding.

    Returns:
        The predicted number of instruments as a plain ``int``. Returns 1
        when there are too few samples (or too small a
        ``max_number_of_instruments``) to evaluate any k.
    """
    # silhouette_score needs at least one more sample than clusters, so the
    # candidate range is capped by the number of rows.
    n_samples = len(data)
    upper_k = min(max_number_of_instruments, n_samples)

    inertias = []
    silhouette_avg = []
    for k in range(2, upper_k):
        kmeans = KMeans(n_clusters=k, random_state=random_state).fit(data)
        inertias.append(kmeans.inertia_)
        silhouette_avg.append(silhouette_score(data, kmeans.labels_))

    # Nothing could be evaluated: fall back to a single instrument.
    if not inertias:
        return 1

    # The elbow check needs both k=2 and k=3, and a non-zero denominator.
    if len(inertias) >= 2 and inertias[1] > 0:
        if inertias[0] / inertias[1] < threshold:
            return 1

    # Index 0 of the score list corresponds to k=2.
    return int(np.argmax(silhouette_avg)) + 2

    # model = sklearn.cluster.AffinityPropagation(preference=-1.5, damping=0.9)
    # labels = model.fit_predict(features_scaled)
    # return np.max(labels) - 1
    

def find_true_number_of_instruments(audio_file):
    """Return the ground-truth instrument count for an audio file.

    The count is read from a text file that has the same path as
    ``audio_file`` but with a ``.txt`` extension; every non-blank line in it
    counts as one instrument.

    Args:
        audio_file: Path of the audio file.

    Returns:
        Number of non-blank lines in the companion ``.txt`` file.

    Raises:
        OSError: If the companion text file cannot be opened.
    """
    import os

    # splitext only looks at the final path component, so dots in directory
    # names and extension-less file names are handled correctly.
    base_path, _extension = os.path.splitext(audio_file)
    txt_file = base_path + '.txt'

    # The context manager closes the file handle deterministically.
    with open(txt_file) as label_file:
        return sum(1 for line in label_file if line.strip())

## Testing

# Path of a single audio file to audition (left blank here; fill in before running)
audio_file = ''

# Run the vocal-suppression step on that file
y, sr = suppress_vocals(audio_file)

# Build an IPython audio player object for the processed signal
Audio(data=y, rate=sr)

# Files to evaluate; the glob pattern is a placeholder and must be filled in
test_files = glob.glob(' ')




# Testing for the case when there is more than one instrument present in the audio excerpt
# Running tallies of exact matches and of files processed
correct_count = 0
total_count = 0
print(len(test_files))
# X collects the scaled feature matrix of each file, Y the matching true counts
X = []
Y = []
for test_file in test_files:
    # Ground truth comes from the companion .txt file next to the audio file
    true = find_true_number_of_instruments(test_file)
    y, sr = suppress_vocals(test_file, sr=22050)
    features_scaled = extract_and_normalize_features(y, sr)
    X.append(features_scaled)
    pred = predict_number_of_instruments(features_scaled)
    Y.append(true)
    total_count += 1
    # A prediction counts as correct only on an exact match
    # NOTE(review): the tallies are not reported anywhere in this section
    if true == pred: correct_count +=1
