In [49]:
import numpy as np
import pandas as pd
import scipy.io.wavfile as wav
from python_speech_features import mfcc
from tempfile import TemporaryFile
import os
import math
import pickle
import random
import operator

In [50]:
# Find the class labels of the k training samples closest to a given instance
def getNeighbors(trainingset, instance, k):
    """Return the labels of the ``k`` training samples nearest to ``instance``.

    Only index 2 of each training sample (its label) is read here; the whole
    sample is handed to ``distance``. The dissimilarity between a sample and
    ``instance`` is ``distance`` evaluated in both directions and summed, so
    the ordering does not depend on which of the two is passed first.

    Args:
        trainingset: sequence of samples to search.
        instance: the query sample, in the layout ``distance`` expects.
        k: number of neighbours to return. It is also forwarded unchanged to
            ``distance`` as its third argument.

    Returns:
        A list of labels ordered from nearest to farthest. It holds at most
        ``k`` entries: fewer when ``trainingset`` has fewer than ``k`` samples,
        and none when ``k`` is zero or negative.
    """
    distances = [
        (sample[2], distance(sample, instance, k) + distance(instance, sample, k))
        for sample in trainingset
    ]
    distances.sort(key=operator.itemgetter(1))
    # A bounded slice (instead of indexing 0..k-1) cannot run past the end of
    # the list when k is larger than the training set. max() keeps a negative
    # k from being read as "everything except the last |k| entries".
    return [label for label, _ in distances[:max(k, 0)]]

In [51]:
# Majority vote: pick the class that occurs most often among the neighbours
def nearestClass(neighbors):
    """Return the most frequent label in ``neighbors``.

    Args:
        neighbors: list of class labels (hashable), e.g. the result of
            ``getNeighbors``.

    Returns:
        The label with the highest count. When several labels share the top
        count, the one that appears first in ``neighbors`` wins, because the
        sort is stable.

    Raises:
        IndexError: if ``neighbors`` is empty.
    """
    tally = {}
    for label in neighbors:
        tally[label] = tally.get(label, 0) + 1

    # Highest vote count first; equal counts keep their first-seen order.
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _ = ranked[0]
    return winner
        

In [52]:
# Evaluate the model: fraction of test samples whose label was predicted correctly
def getAccuracy(testSet, prediction):
    """Return the share of test samples whose predicted label is correct.

    Args:
        testSet: sequence of samples; the last element of each sample is taken
            as its true label.
        prediction: sequence of predicted labels, aligned by position with
            ``testSet``.

    Returns:
        A float between 0.0 and 1.0.

    Raises:
        ZeroDivisionError: if ``testSet`` is empty.
        IndexError: if ``prediction`` is shorter than ``testSet``.
    """
    total = len(testSet)
    hits = sum(1 for idx in range(total) if testSet[idx][-1] == prediction[idx])
    return float(hits) / total

In [53]:
# Change the file path to your file location in the GTZAN folder
directory = "C:\\Users\\omega\\Desktop\\MusicData\\genres_original"

# Extract one feature tuple per audio file and pickle the tuples back to back
# into mydataset.dat (created in the current working directory, i.e. next to
# this notebook when it is run from its own folder).
#
# Each tuple is (mean MFCC vector, MFCC covariance matrix, label), where the
# label is the 1-based position of the genre folder in os.listdir(directory).
# The last `feature` assigned here stays available as a module-level name.
i = 0
# `with` guarantees the output file is closed even if an uncaught error
# (for example os.listdir on an entry that is not a folder) aborts the loop.
with open("mydataset.dat", "wb") as f:
    for folder in os.listdir(directory):
        i += 1
        # Only the first ten folders are used; stop at the eleventh entry.
        if i == 11:
            break
        folder_path = os.path.join(directory, folder)
        for file in os.listdir(folder_path):
            try:
                (rate, sig) = wav.read(os.path.join(folder_path, file))
                mfcc_feat = mfcc(sig, rate, winlen = 0.020, appendEnergy=False)
                # Transpose so each row passed to np.cov is one MFCC
                # coefficient observed across all frames.
                covariance = np.cov(mfcc_feat.T)
                mean_matrix = mfcc_feat.mean(0)
                feature = (mean_matrix, covariance, i)
                pickle.dump(feature, f)
            except Exception as e:
                # Best effort: an unreadable file is reported and skipped so a
                # single corrupt recording does not abort the whole extraction.
                print("Got an exception: ", e, 'in folder: ', folder, ' filename: ', file)

Got an exception:  File format b'\xcb\x15\x1e\x16' not understood. Only 'RIFF' and 'RIFX' supported. in folder:  jazz  filename:  jazz.00054.wav


In [54]:
# Every record read by loadDataset is accumulated in this module-level list.
dataset = []

def loadDataset(filename, split, trset, teset):
    """Load pickled records from ``filename`` and split them at random.

    The file is expected to hold pickled objects written back to back; they are
    read until end of file and appended to the module-level ``dataset`` list.
    Afterwards every record currently in ``dataset`` (including any that were
    already there before this call) is assigned to ``trset`` or ``teset``.

    Args:
        filename: path of the pickle file to read.
        split: probability that a record goes to ``trset``; the remaining
            records go to ``teset``. Uses the global ``random`` generator, so
            the split differs between runs unless ``random.seed`` was called.
        trset: list that receives the training records (mutated in place).
        teset: list that receives the test records (mutated in place).

    Returns:
        None.
    """
    # NOTE(security): pickle.load can execute arbitrary code while loading.
    # Only open files you produced yourself.
    with open(filename, 'rb') as f:
        while True:
            try:
                dataset.append(pickle.load(f))
            except EOFError:
                # End of file reached; the `with` block closes the file.
                break

    for record in dataset:
        if random.random() < split:
            trset.append(record)
        else:
            teset.append(record)
# Split the pickled features at random: each record lands in trainingSet with
# probability 0.68, otherwise in testSet. The split is not seeded, so the two
# lists (and every figure derived from them) change from run to run.
trainingSet, testSet = [], []
loadDataset(filename='mydataset.dat', split=0.68, trset=trainingSet, teset=testSet)

In [57]:
# Calculating the distance between 2 instances (points)
def distance(instance1, instance2, k):
    """Return a directed dissimilarity from ``instance1`` to ``instance2``.

    Each instance is indexed as ``(mean_vector, covariance_matrix, ...)``; only
    indices 0 and 1 are read. The value is the sum of:

    * ``trace(inv(cm2) @ cm1)``
    * ``(mm2 - mm1)^T @ inv(cm2) @ (mm2 - mm1)``
    * ``log(det(cm2)) - log(det(cm1))``

    minus ``k``. The result is not symmetric in its two instances; callers that
    need symmetry add the value for both argument orders.

    NOTE(review): this has the shape of a Gaussian KL-divergence-style measure,
    in which the subtracted constant would normally be the feature dimension -
    confirm. ``getNeighbors`` passes its neighbour count instead; because the
    same constant is subtracted for every pair, the ranking is unaffected.

    Args:
        instance1: first sample (mean at index 0, covariance at index 1).
        instance2: second sample, same layout; its covariance is inverted.
        k: constant subtracted from the result.

    Returns:
        The dissimilarity as a NumPy float.

    Raises:
        numpy.linalg.LinAlgError: if the covariance of ``instance2`` is singular.
    """
    mm1, cm1 = instance1[0], instance1[1]
    mm2, cm2 = instance2[0], instance2[1]

    # Invert once and reuse: both the trace term and the quadratic term need it.
    inv_cm2 = np.linalg.inv(cm2)
    mean_diff = mm2 - mm1

    result = np.trace(np.dot(inv_cm2, cm1))
    result += np.dot(np.dot(mean_diff.transpose(), inv_cm2), mean_diff)
    result += np.log(np.linalg.det(cm2)) - np.log(np.linalg.det(cm1))
    result -= k
    return result

In [58]:
# Make the prediction using KNN
# Every held-out sample is classified by majority vote among its 5 nearest
# training samples; the predictions are then scored against the true labels.
length = len(testSet)
predictions = [
    nearestClass(getNeighbors(trainingSet, sample, 5))
    for sample in testSet
]

accuracy1 = getAccuracy(testSet, predictions)
print(accuracy1)

0.7102803738317757


In [70]:
from collections import defaultdict

directory = "C:\\Users\\omega\\Desktop\\MusicData\\genres_original"

# Map numeric labels back to genre folder names. Labels are the 1-based
# positions in os.listdir(directory), so this is only correct if os.listdir
# returns the folders in the same order as when the features were extracted.
# A label that is missing from the mapping resolves to 0 (the defaultdict(int)
# default) rather than raising.
results = defaultdict(int)
genre_folders = os.listdir(directory)
for i, folder in enumerate(genre_folders, start=1):
    results[i] = folder
i = len(genre_folders) + 1  # leave the counter one past the last label

# `feature` is whatever the feature-extraction cell assigned last, and that
# same record is also part of `dataset`, so it is among its own neighbours.
neighbor_labels = getNeighbors(dataset, feature, 200)
pred = nearestClass(neighbor_labels)
# print("Dataset ", dataset)
print("Feature ", feature)
print(results[pred])

Feature  (array([ 65.8878602 ,   1.20243799, -15.65300447,   9.13743547,
        -5.63116739,   6.63157438,  -9.17150784,   5.31948596,
       -12.85501962,   9.33670177,  -5.47233043,   6.3152848 ,
       -15.40720607]), array([[ 82.45208072, -78.48347185, -65.79196127,   8.14888424,
        -13.24404467,  -1.47596162, -28.76394202,   3.17409845,
         -7.7266519 ,  15.58920197,  -6.9924803 ,  12.46773031,
         17.46582056],
       [-78.48347185, 103.32137501,  42.81593589, -26.55003983,
          1.35737792, -19.29198704,  36.19919277,  -5.41647976,
          8.18334609,  -8.77736842,  20.43782848, -14.31894329,
        -23.33771922],
       [-65.79196127,  42.81593589, 114.56661863,  13.03672767,
         -2.30504257, -15.73869831,   4.21668427,  -9.60216018,
        -10.77743446, -13.29177223, -14.32838794,  -5.9586291 ,
         -4.17779231],
       [  8.14888424, -26.55003983,  13.03672767,  69.95900415,
          9.81494532,  -1.74764621, -44.6233195 ,  -3.74868597,
     