In [1]:
# Import libraries
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Import external functions from the functions folder
sys.path.append('../../functions')
import functions as f

In [2]:
# Folder containing the NSynth .wav audio files.
# Set the NSYNTH_AUDIO_DIR environment variable to point at a local copy of the
# data; when it is unset (or empty) the original hard-coded location is used, so
# existing setups keep working unchanged.
path = os.environ.get('NSYNTH_AUDIO_DIR') or 'C:/Users/lucvo/VScode/Machine_learning/Audio_data/nsynth-valid.jsonwav/nsynth-valid/audio/'

# The default ends with a separator. Give an override the same shape in case the
# loader builds file names by plain concatenation (loader not visible here - TODO confirm).
if not path.endswith(('/', '\\')):
    path += '/'

In [3]:
# Build the spectrogram feature matrix X, the integer labels y and the
# id -> instrument-name mapping for the four instrument families of interest.
instrument_classes = ['keyboard', 'bass', 'flute', 'guitar']
X, y, label_dict = f.create_single_inst_classification_set(
    N_samples=7_500,
    classes=instrument_classes,
    path=path,
)

# Hold out 10% for validation and 10% for testing; the rest is used for training.
(
    X_train, X_val, X_test,
    y_train, y_val, y_test,
) = f.split_data(X, y, val_frac=0.1, test_frac=0.1)

The dataset might be unbalanced


In [12]:
# Sanity check: feature-matrix shape, label-vector shape and the label encoding,
# one per line.
print(X_train.shape, y_train.shape, label_dict, sep='\n')

(16000, 36765)
(16000,)
{0: 'keyboard', 1: 'bass', 2: 'flute', 3: 'guitar'}


### Use an XGBoost gradient-boosted tree classifier

In [14]:
import xgboost as xgb
from sklearn.metrics import confusion_matrix

In [15]:
# Gradient-boosted tree classifier over the instrument classes.
# eval_metric and early_stopping_rounds are configured on the estimator itself:
# passing them to fit() was deprecated in XGBoost 1.6 and removed in 2.0, where
# it raises a TypeError.
model = xgb.XGBClassifier(
    objective='multi:softmax',
    num_class=len(label_dict),
    n_estimators=100,
    max_depth=3,
    learning_rate=0.1,
    n_jobs=-1,
    eval_metric='merror',        # multiclass misclassification rate
    early_stopping_rounds=10,    # stop after 10 rounds without improvement
)

# Evaluate on both sets after every boosting round. Per the XGBoost docs, the
# LAST entry of eval_set (here: validation) is the one that drives early stopping.
eval_set = [(X_train, y_train), (X_val, y_val)]
model.fit(X_train, y_train, eval_set=eval_set, verbose=True)

# Predict the labels of the validation set
y_pred = model.predict(X_val)

# Fraction of validation samples classified correctly.
# Note: the validation set also selected the stopping round, so this estimate is
# slightly optimistic; use the held-out test set for a final figure.
accuracy = np.sum(y_pred == y_val) / len(y_val)
print('Accuracy:', accuracy)



KeyboardInterrupt: 

In [None]:
# Plot the row-normalised confusion matrix for the validation set.
#
# Class ids and tick labels are both taken from label_dict so the axes always
# match the label encoding. The previous hard-coded, alphabetically ordered tick
# labels did not match the encoding printed earlier in this notebook
# (0: keyboard, 1: bass, 2: flute, 3: guitar), so every row/column was mislabelled.
# Assumes the values in y_val / y_pred are the keys of label_dict - TODO confirm.
class_ids = sorted(label_dict)
class_names = [label_dict[class_id] for class_id in class_ids]

# labels= pins the row/column order and keeps the matrix square even if a class
# is absent from the validation split or never predicted.
conf_matrix = confusion_matrix(y_val, y_pred, labels=class_ids)

# Normalise each row by its number of true samples; a class with no validation
# samples gets an all-zero row instead of a division by zero.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_normed = conf_matrix / np.maximum(row_totals, 1)

tick_positions = list(range(len(class_names)))
plt.imshow(conf_normed, cmap='Blues')
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('True')
plt.xticks(tick_positions, class_names)
plt.yticks(tick_positions, class_names)
plt.title('Confusion matrix')
plt.show()