In [1]:
import pandas as pd
import numpy as np
import datetime as dt

import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *#avg, count, expr
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler, StandardScaler, RobustScaler,\
    IndexToString, PCA, UnivariateFeatureSelector
from pyspark.ml.classification import *
from pyspark.ml.evaluation import *
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

import mlflow
#from mlflow import pyspark
from mlflow.tracking import MlflowClient

In [2]:
# initialize
sc = pyspark.SparkContext()
spark = SparkSession(sc)
spark.sparkContext.appName = 'classifyHW'
# show the number of cores
print('%d cores'%spark._jsc.sc().getExecutorMemoryStatus().keySet().size())
spark

1 cores


In [3]:
# load the data
fil = '../../data/beatsdataset.csv'
beats = spark.read.format('csv').options(header=True, inferSchema=True).load(fil).drop('_c0')
print('%d records'%beats.count())
display(beats.limit(10).toPandas())
beats.printSchema()

2300 records


Unnamed: 0,1-ZCRm,2-Energym,3-EnergyEntropym,4-SpectralCentroidm,5-SpectralSpreadm,6-SpectralEntropym,7-SpectralFluxm,8-SpectralRolloffm,9-MFCCs1m,10-MFCCs2m,11-MFCCs3m,12-MFCCs4m,13-MFCCs5m,14-MFCCs6m,15-MFCCs7m,16-MFCCs8m,17-MFCCs9m,18-MFCCs10m,19-MFCCs11m,20-MFCCs12m,21-MFCCs13m,22-ChromaVector1m,23-ChromaVector2m,24-ChromaVector3m,25-ChromaVector4m,26-ChromaVector5m,27-ChromaVector6m,28-ChromaVector7m,29-ChromaVector8m,30-ChromaVector9m,31-ChromaVector10m,32-ChromaVector11m,33-ChromaVector12m,34-ChromaDeviationm,35-ZCRstd,36-Energystd,37-EnergyEntropystd,38-SpectralCentroidstd,39-SpectralSpreadstd,40-SpectralEntropystd,41-SpectralFluxstd,42-SpectralRolloffstd,43-MFCCs1std,44-MFCCs2std,45-MFCCs3std,46-MFCCs4std,47-MFCCs5std,48-MFCCs6std,49-MFCCs7std,50-MFCCs8std,51-MFCCs9std,52-MFCCs10std,53-MFCCs11std,54-MFCCs12std,55-MFCCs13std,56-ChromaVector1std,57-ChromaVector2std,58-ChromaVector3std,59-ChromaVector4std,60-ChromaVector5std,61-ChromaVector6std,62-ChromaVector7std,63-ChromaVector8std,64-ChromaVector9std,65-ChromaVector10std,66-ChromaVector11std,67-ChromaVector12std,68-ChromaDeviationstd,69-BPM,70-BPMconf,71-BPMessentia,class
0,0.13644,0.088861,3.201201,0.262825,0.249212,1.114423,0.007003,0.256682,-22.723259,1.594074,0.011276,0.204468,0.042072,0.048552,0.158505,0.118984,-0.147956,-0.186152,-0.026418,-0.007264,-0.0179,0.011581,0.008747,0.041081,0.014497,0.025711,0.012587,0.06017,0.002864,0.004631,0.009576,0.026079,0.004161,0.032185,0.050143,0.047313,0.102995,0.041285,0.017725,0.414831,0.005867,0.133778,0.838302,0.505911,0.356206,0.336074,0.288888,0.278649,0.283437,0.300305,0.287688,0.296692,0.258531,0.238352,0.194701,0.013138,0.011665,0.032049,0.015464,0.020453,0.012943,0.046397,0.003431,0.004981,0.010818,0.024001,0.005201,0.015056,133.333333,0.132792,128.0,BigRoom
1,0.117039,0.108389,3.194001,0.247657,0.250288,1.065668,0.005387,0.199821,-21.775871,1.261364,-0.113015,0.001718,-0.052682,0.20413,0.153013,0.067214,-0.013227,-0.05944,-0.008604,0.114257,0.171009,0.006535,0.002646,0.086485,0.008391,0.016442,0.009006,0.087948,0.002472,0.006549,0.007412,0.015386,0.005978,0.041116,0.043713,0.043721,0.099449,0.039386,0.018946,0.407164,0.003613,0.110334,0.624185,0.476993,0.353151,0.33555,0.283832,0.269621,0.24415,0.24666,0.25719,0.272036,0.269477,0.222393,0.187471,0.006761,0.003152,0.058923,0.009012,0.016106,0.009386,0.071726,0.004461,0.006441,0.007469,0.015499,0.005589,0.019339,120.0,0.112767,126.0,BigRoom
2,0.085308,0.128525,3.123837,0.217205,0.228652,0.789647,0.008247,0.156822,-22.472722,1.425185,0.186749,0.417114,0.076406,0.190803,-0.016302,0.075038,0.10787,0.216874,0.095604,0.020977,-0.037011,0.007143,0.00296,0.220526,0.005639,0.010151,0.007453,0.043907,0.00124,0.004347,0.007989,0.017622,0.002636,0.066049,0.03292,0.037618,0.117704,0.041509,0.022645,0.34013,0.007697,0.085784,1.02874,0.449133,0.297935,0.266731,0.258299,0.275012,0.218368,0.19839,0.210177,0.212533,0.204458,0.197634,0.16491,0.007836,0.003079,0.093865,0.005692,0.008212,0.005451,0.0429,0.001529,0.004556,0.007723,0.017482,0.002901,0.022201,133.333333,0.123373,129.0,BigRoom
3,0.10305,0.167042,3.15083,0.233593,0.245032,0.967082,0.006571,0.168083,-21.470751,1.463686,0.226548,0.404531,0.117699,0.081861,0.053974,0.164865,0.014919,0.11709,0.027778,-0.063173,-0.052606,0.010724,0.00334,0.125459,0.005728,0.014695,0.006322,0.072154,0.001628,0.003493,0.011463,0.032204,0.004738,0.046159,0.036349,0.06196,0.134908,0.032564,0.020036,0.365068,0.005215,0.086336,0.769981,0.425496,0.245312,0.260132,0.22422,0.207597,0.199472,0.207818,0.189912,0.185509,0.187273,0.177629,0.16474,0.00833,0.003528,0.061426,0.005443,0.012382,0.004985,0.057999,0.001591,0.003514,0.009477,0.023162,0.004165,0.015379,133.333333,0.158876,129.0,BigRoom
4,0.15173,0.148405,3.194498,0.29373,0.267231,1.353005,0.003872,0.292055,-21.371157,1.187854,0.184415,0.363724,0.232119,0.112277,0.107335,0.159296,0.067213,-0.018713,0.091529,0.117344,0.091616,0.009624,0.004031,0.076133,0.008175,0.016267,0.009927,0.088364,0.002645,0.004054,0.011083,0.023926,0.002248,0.036761,0.055214,0.041139,0.122271,0.036637,0.017732,0.45874,0.0029,0.148464,1.023154,0.431075,0.352099,0.327842,0.23718,0.230675,0.236538,0.25721,0.27038,0.242086,0.229678,0.211439,0.179589,0.01086,0.00539,0.046999,0.008598,0.015579,0.00918,0.069485,0.003945,0.004131,0.01133,0.028188,0.002639,0.019079,133.333333,0.190708,129.0,BigRoom
5,0.127047,0.153488,3.221987,0.261693,0.257361,1.090034,0.004943,0.230099,-21.234846,1.541917,0.049064,0.194576,0.063895,0.058361,0.050883,0.071518,0.006431,0.011621,0.068274,0.130295,0.104191,0.009467,0.005906,0.079738,0.01158,0.014767,0.012332,0.078991,0.002105,0.005256,0.012025,0.031829,0.002447,0.041918,0.05022,0.048121,0.073129,0.043929,0.019323,0.461871,0.003707,0.138257,0.872498,0.418116,0.323801,0.332065,0.240972,0.212895,0.224638,0.225902,0.236724,0.227541,0.2327,0.216201,0.193793,0.009599,0.008033,0.058464,0.011039,0.011853,0.01143,0.059939,0.002986,0.006533,0.010347,0.025008,0.003035,0.019479,133.333333,0.168933,129.0,BigRoom
6,0.123395,0.106206,3.167861,0.245459,0.245342,1.107316,0.008419,0.217213,-22.683738,1.405072,0.118772,0.646519,0.090651,0.174564,0.053183,0.024106,-0.045683,0.153718,0.126856,0.027705,-0.152214,0.014485,0.002422,0.092604,0.010015,0.021776,0.0064,0.044038,0.003403,0.007906,0.013052,0.022744,0.002966,0.034496,0.038486,0.05634,0.121416,0.036895,0.019332,0.368052,0.007098,0.100009,0.861205,0.553282,0.28581,0.265964,0.253458,0.214164,0.214898,0.208218,0.205747,0.220657,0.234514,0.19475,0.176016,0.016257,0.002349,0.051198,0.0128,0.019958,0.005169,0.040698,0.003394,0.00697,0.011462,0.019006,0.00279,0.01333,133.333333,0.259905,133.0,BigRoom
7,0.140027,0.084697,3.148168,0.267148,0.259155,1.188881,0.007938,0.259598,-22.707077,1.468045,-0.155444,0.39796,0.231832,0.164367,0.048508,0.041399,-0.107773,-0.066729,0.032721,0.070957,0.082827,0.007845,0.002854,0.0771,0.006971,0.016592,0.012131,0.101869,0.002467,0.004143,0.007565,0.025503,0.003043,0.039542,0.046062,0.04236,0.112595,0.042651,0.018279,0.37624,0.006253,0.148667,0.89867,0.373287,0.329866,0.266802,0.280483,0.217481,0.184907,0.194362,0.189401,0.190575,0.187424,0.173264,0.164607,0.007715,0.002954,0.056516,0.006453,0.012729,0.010177,0.059535,0.002542,0.003819,0.006958,0.024556,0.00265,0.014855,133.333333,0.198616,129.0,BigRoom
8,0.117635,0.146972,3.182842,0.243685,0.244968,1.099033,0.005454,0.195807,-20.837815,1.456355,0.016728,0.26685,0.102626,0.084006,-0.004172,-0.009111,-0.077262,0.090645,0.039457,0.082046,0.020676,0.009715,0.003509,0.088294,0.008659,0.012238,0.008928,0.05874,0.00274,0.008298,0.01919,0.027052,0.003149,0.036167,0.032876,0.048499,0.093086,0.029531,0.016448,0.331669,0.00325,0.083071,0.637533,0.3404,0.28449,0.268533,0.29584,0.246099,0.231303,0.212359,0.205208,0.194612,0.205478,0.214176,0.191211,0.009442,0.003292,0.048837,0.008002,0.009425,0.008502,0.053233,0.002812,0.007646,0.0132,0.02458,0.003016,0.013857,133.333333,0.148709,133.0,BigRoom
9,0.1374,0.127166,3.117554,0.297956,0.2793,1.076698,0.006725,0.271957,-22.962992,1.705804,0.135987,0.12311,-0.00529,-0.025106,-0.052423,0.030112,0.090758,0.138337,0.110342,0.041654,0.042206,0.007936,0.003106,0.117962,0.004994,0.012573,0.006303,0.0733,0.001627,0.0053,0.013402,0.024952,0.003663,0.043039,0.06435,0.067663,0.180277,0.057455,0.020675,0.530818,0.005479,0.197681,1.312354,0.466471,0.358931,0.337591,0.310588,0.260669,0.242583,0.232069,0.218763,0.269006,0.208188,0.204668,0.193399,0.00839,0.003943,0.07101,0.005266,0.013335,0.005602,0.061237,0.002407,0.006268,0.01265,0.024557,0.004807,0.019159,133.333333,0.172005,133.0,BigRoom


root
 |-- 1-ZCRm: double (nullable = true)
 |-- 2-Energym: double (nullable = true)
 |-- 3-EnergyEntropym: double (nullable = true)
 |-- 4-SpectralCentroidm: double (nullable = true)
 |-- 5-SpectralSpreadm: double (nullable = true)
 |-- 6-SpectralEntropym: double (nullable = true)
 |-- 7-SpectralFluxm: double (nullable = true)
 |-- 8-SpectralRolloffm: double (nullable = true)
 |-- 9-MFCCs1m: double (nullable = true)
 |-- 10-MFCCs2m: double (nullable = true)
 |-- 11-MFCCs3m: double (nullable = true)
 |-- 12-MFCCs4m: double (nullable = true)
 |-- 13-MFCCs5m: double (nullable = true)
 |-- 14-MFCCs6m: double (nullable = true)
 |-- 15-MFCCs7m: double (nullable = true)
 |-- 16-MFCCs8m: double (nullable = true)
 |-- 17-MFCCs9m: double (nullable = true)
 |-- 18-MFCCs10m: double (nullable = true)
 |-- 19-MFCCs11m: double (nullable = true)
 |-- 20-MFCCs12m: double (nullable = true)
 |-- 21-MFCCs13m: double (nullable = true)
 |-- 22-ChromaVector1m: double (nullable = true)
 |-- 23-ChromaVector2m:

In [4]:
# specify the response colum & tabulate by it
responseVar = 'class'
display(beats.groupBy(responseVar).count().toPandas())

Unnamed: 0,class,count
0,PsyTrance,100
1,HardDance,100
2,Breaks,100
3,HardcoreHardTechno,100
4,IndieDanceNuDisco,100
5,Trance,100
6,DeepHouse,100
7,ElectronicaDowntempo,100
8,ReggaeDub,100
9,Minimal,100


In [5]:
# check for missing values
cnt = beats.count()
nullCounts = {colm:beats.select(colm).where(col(colm).isNull()).count() for colm in beats.columns}
nullCounts = {colm:(ncnt, ncnt/cnt) for (colm, ncnt) in nullCounts.items()}

# pretty print
nullCountsDF = pd.DataFrame(nullCounts).T.reset_index(drop=False).sort_values(1, ascending=False)
nullCountsDF.columns = ['Column', 'Freq.', 'Rel. Freq.']
nullCountsDF = nullCountsDF.merge(pd.DataFrame([[colm.name, colm.dataType] for colm in beats.schema], columns=['Column', 'Type']),
                                how='inner', on=['Column'])
display(nullCountsDF)

Unnamed: 0,Column,Freq.,Rel. Freq.,Type
0,1-ZCRm,0.0,0.0,DoubleType
1,2-Energym,0.0,0.0,DoubleType
2,53-MFCCs11std,0.0,0.0,DoubleType
3,52-MFCCs10std,0.0,0.0,DoubleType
4,51-MFCCs9std,0.0,0.0,DoubleType
5,50-MFCCs8std,0.0,0.0,DoubleType
6,49-MFCCs7std,0.0,0.0,DoubleType
7,48-MFCCs6std,0.0,0.0,DoubleType
8,47-MFCCs5std,0.0,0.0,DoubleType
9,46-MFCCs4std,0.0,0.0,DoubleType


In [36]:
''' prep the data for modeling '''
# set inputs
inpColumns = beats.columns[:-1]

# create the features vector
assr = VectorAssembler(inputCols=inpColumns, outputCol='features_orig')
beatsML = assr.transform(beats)

# scale the features
scalr = RobustScaler(inputCol='features_orig', outputCol='features_scale')
beatsML = scalr.fit(beatsML).transform(beatsML)

# make the response numerical
indxr = StringIndexer(inputCol=responseVar, outputCol='label')
beatsML = indxr.fit(beatsML).transform(beatsML)

# get a distinct list of the original labels
origLabels = beatsML.select('label').distinct().orderBy('label')
rev = IndexToString(inputCol='label', outputCol='label_orig')
origLabels = [c['label_orig'] for c in rev.transform(origLabels).select('label_orig').collect()]

# talk
display(beatsML.limit(10).toPandas())
beatsML.select('features_scale', 'label').show(truncate=True)
beatsML.select('features_scale').take(1)
print('First row features = %s'%beatsML.select('features_scale').take(1)[0])

Unnamed: 0,1-ZCRm,2-Energym,3-EnergyEntropym,4-SpectralCentroidm,5-SpectralSpreadm,6-SpectralEntropym,7-SpectralFluxm,8-SpectralRolloffm,9-MFCCs1m,10-MFCCs2m,11-MFCCs3m,12-MFCCs4m,13-MFCCs5m,14-MFCCs6m,15-MFCCs7m,16-MFCCs8m,17-MFCCs9m,18-MFCCs10m,19-MFCCs11m,20-MFCCs12m,21-MFCCs13m,22-ChromaVector1m,23-ChromaVector2m,24-ChromaVector3m,25-ChromaVector4m,26-ChromaVector5m,27-ChromaVector6m,28-ChromaVector7m,29-ChromaVector8m,30-ChromaVector9m,31-ChromaVector10m,32-ChromaVector11m,33-ChromaVector12m,34-ChromaDeviationm,35-ZCRstd,36-Energystd,37-EnergyEntropystd,38-SpectralCentroidstd,39-SpectralSpreadstd,40-SpectralEntropystd,41-SpectralFluxstd,42-SpectralRolloffstd,43-MFCCs1std,44-MFCCs2std,45-MFCCs3std,46-MFCCs4std,47-MFCCs5std,48-MFCCs6std,49-MFCCs7std,50-MFCCs8std,51-MFCCs9std,52-MFCCs10std,53-MFCCs11std,54-MFCCs12std,55-MFCCs13std,56-ChromaVector1std,57-ChromaVector2std,58-ChromaVector3std,59-ChromaVector4std,60-ChromaVector5std,61-ChromaVector6std,62-ChromaVector7std,63-ChromaVector8std,64-ChromaVector9std,65-ChromaVector10std,66-ChromaVector11std,67-ChromaVector12std,68-ChromaDeviationstd,69-BPM,70-BPMconf,71-BPMessentia,class,features_orig,features_scale,label
0,0.13644,0.088861,3.201201,0.262825,0.249212,1.114423,0.007003,0.256682,-22.723259,1.594074,0.011276,0.204468,0.042072,0.048552,0.158505,0.118984,-0.147956,-0.186152,-0.026418,-0.007264,-0.0179,0.011581,0.008747,0.041081,0.014497,0.025711,0.012587,0.06017,0.002864,0.004631,0.009576,0.026079,0.004161,0.032185,0.050143,0.047313,0.102995,0.041285,0.017725,0.414831,0.005867,0.133778,0.838302,0.505911,0.356206,0.336074,0.288888,0.278649,0.283437,0.300305,0.287688,0.296692,0.258531,0.238352,0.194701,0.013138,0.011665,0.032049,0.015464,0.020453,0.012943,0.046397,0.003431,0.004981,0.010818,0.024001,0.005201,0.015056,133.333333,0.132792,128.0,BigRoom,"[0.136439587512, 0.0888612604609, 3.2012005559...","[2.621763736423596, 1.70103481603904, 32.56325...",0.0
1,0.117039,0.108389,3.194001,0.247657,0.250288,1.065668,0.005387,0.199821,-21.775871,1.261364,-0.113015,0.001718,-0.052682,0.20413,0.153013,0.067214,-0.013227,-0.05944,-0.008604,0.114257,0.171009,0.006535,0.002646,0.086485,0.008391,0.016442,0.009006,0.087948,0.002472,0.006549,0.007412,0.015386,0.005978,0.041116,0.043713,0.043721,0.099449,0.039386,0.018946,0.407164,0.003613,0.110334,0.624185,0.476993,0.353151,0.33555,0.283832,0.269621,0.24415,0.24666,0.25719,0.272036,0.269477,0.222393,0.187471,0.006761,0.003152,0.058923,0.009012,0.016106,0.009386,0.071726,0.004461,0.006441,0.007469,0.015499,0.005589,0.019339,120.0,0.112767,126.0,BigRoom,"[0.117038518483, 0.108389033282, 3.19400106287...","[2.2489612371225074, 2.074846995565889, 32.490...",0.0
2,0.085308,0.128525,3.123837,0.217205,0.228652,0.789647,0.008247,0.156822,-22.472722,1.425185,0.186749,0.417114,0.076406,0.190803,-0.016302,0.075038,0.10787,0.216874,0.095604,0.020977,-0.037011,0.007143,0.00296,0.220526,0.005639,0.010151,0.007453,0.043907,0.00124,0.004347,0.007989,0.017622,0.002636,0.066049,0.03292,0.037618,0.117704,0.041509,0.022645,0.34013,0.007697,0.085784,1.02874,0.449133,0.297935,0.266731,0.258299,0.275012,0.218368,0.19839,0.210177,0.212533,0.204458,0.197634,0.16491,0.007836,0.003079,0.093865,0.005692,0.008212,0.005451,0.0429,0.001529,0.004556,0.007723,0.017482,0.002901,0.022201,133.333333,0.123373,129.0,BigRoom,"[0.0853077737447, 0.128525418596, 3.1238373468...","[1.6392370551487672, 2.4603095954731096, 31.77...",0.0
3,0.10305,0.167042,3.15083,0.233593,0.245032,0.967082,0.006571,0.168083,-21.470751,1.463686,0.226548,0.404531,0.117699,0.081861,0.053974,0.164865,0.014919,0.11709,0.027778,-0.063173,-0.052606,0.010724,0.00334,0.125459,0.005728,0.014695,0.006322,0.072154,0.001628,0.003493,0.011463,0.032204,0.004738,0.046159,0.036349,0.06196,0.134908,0.032564,0.020036,0.365068,0.005215,0.086336,0.769981,0.425496,0.245312,0.260132,0.22422,0.207597,0.199472,0.207818,0.189912,0.185509,0.187273,0.177629,0.16474,0.00833,0.003528,0.061426,0.005443,0.012382,0.004985,0.057999,0.001591,0.003514,0.009477,0.023162,0.004165,0.015379,133.333333,0.158876,129.0,BigRoom,"[0.103049917216, 0.167041735198, 3.15083006899...","[1.9801623628816705, 3.1976117132436874, 32.05...",0.0
4,0.15173,0.148405,3.194498,0.29373,0.267231,1.353005,0.003872,0.292055,-21.371157,1.187854,0.184415,0.363724,0.232119,0.112277,0.107335,0.159296,0.067213,-0.018713,0.091529,0.117344,0.091616,0.009624,0.004031,0.076133,0.008175,0.016267,0.009927,0.088364,0.002645,0.004054,0.011083,0.023926,0.002248,0.036761,0.055214,0.041139,0.122271,0.036637,0.017732,0.45874,0.0029,0.148464,1.023154,0.431075,0.352099,0.327842,0.23718,0.230675,0.236538,0.25721,0.27038,0.242086,0.229678,0.211439,0.179589,0.01086,0.00539,0.046999,0.008598,0.015579,0.00918,0.069485,0.003945,0.004131,0.01133,0.028188,0.002639,0.019079,133.333333,0.190708,129.0,BigRoom,"[0.151729948738, 0.148404713864, 3.19449794602...","[2.9155766635230598, 2.840850825631186, 32.495...",0.0
5,0.127047,0.153488,3.221987,0.261693,0.257361,1.090034,0.004943,0.230099,-21.234846,1.541917,0.049064,0.194576,0.063895,0.058361,0.050883,0.071518,0.006431,0.011621,0.068274,0.130295,0.104191,0.009467,0.005906,0.079738,0.01158,0.014767,0.012332,0.078991,0.002105,0.005256,0.012025,0.031829,0.002447,0.041918,0.05022,0.048121,0.073129,0.043929,0.019323,0.461871,0.003707,0.138257,0.872498,0.418116,0.323801,0.332065,0.240972,0.212895,0.224638,0.225902,0.236724,0.227541,0.2327,0.216201,0.193793,0.009599,0.008033,0.058464,0.011039,0.011853,0.01143,0.059939,0.002986,0.006533,0.010347,0.025008,0.003035,0.019479,133.333333,0.168933,129.0,BigRoom,"[0.127046737192, 0.153487850284, 3.22198710261...","[2.441274812353337, 2.938155230050485, 32.7747...",0.0
6,0.123395,0.106206,3.167861,0.245459,0.245342,1.107316,0.008419,0.217213,-22.683738,1.405072,0.118772,0.646519,0.090651,0.174564,0.053183,0.024106,-0.045683,0.153718,0.126856,0.027705,-0.152214,0.014485,0.002422,0.092604,0.010015,0.021776,0.0064,0.044038,0.003403,0.007906,0.013052,0.022744,0.002966,0.034496,0.038486,0.05634,0.121416,0.036895,0.019332,0.368052,0.007098,0.100009,0.861205,0.553282,0.28581,0.265964,0.253458,0.214164,0.214898,0.208218,0.205747,0.220657,0.234514,0.19475,0.176016,0.016257,0.002349,0.051198,0.0128,0.019958,0.005169,0.040698,0.003394,0.00697,0.011462,0.019006,0.00279,0.01333,133.333333,0.259905,133.0,BigRoom,"[0.123395302003, 0.106206061431, 3.16786066505...","[2.3711104228312765, 2.033059256997651, 32.224...",0.0
7,0.140027,0.084697,3.148168,0.267148,0.259155,1.188881,0.007938,0.259598,-22.707077,1.468045,-0.155444,0.39796,0.231832,0.164367,0.048508,0.041399,-0.107773,-0.066729,0.032721,0.070957,0.082827,0.007845,0.002854,0.0771,0.006971,0.016592,0.012131,0.101869,0.002467,0.004143,0.007565,0.025503,0.003043,0.039542,0.046062,0.04236,0.112595,0.042651,0.018279,0.37624,0.006253,0.148667,0.89867,0.373287,0.329866,0.266802,0.280483,0.217481,0.184907,0.194362,0.189401,0.190575,0.187424,0.173264,0.164607,0.007715,0.002954,0.056516,0.006453,0.012729,0.010177,0.059535,0.002542,0.003819,0.006958,0.024556,0.00265,0.014855,133.333333,0.198616,129.0,BigRoom,"[0.140027382431, 0.0846969282386, 3.1481680737...","[2.6907052422129745, 1.6213187051157385, 32.02...",0.0
8,0.117635,0.146972,3.182842,0.243685,0.244968,1.099033,0.005454,0.195807,-20.837815,1.456355,0.016728,0.26685,0.102626,0.084006,-0.004172,-0.009111,-0.077262,0.090645,0.039457,0.082046,0.020676,0.009715,0.003509,0.088294,0.008659,0.012238,0.008928,0.05874,0.00274,0.008298,0.01919,0.027052,0.003149,0.036167,0.032876,0.048499,0.093086,0.029531,0.016448,0.331669,0.00325,0.083071,0.637533,0.3404,0.28449,0.268533,0.29584,0.246099,0.231303,0.212359,0.205208,0.194612,0.205478,0.214176,0.191211,0.009442,0.003292,0.048837,0.008002,0.009425,0.008502,0.053233,0.002812,0.007646,0.0132,0.02458,0.003016,0.013857,133.333333,0.148709,133.0,BigRoom,"[0.117635200751, 0.146971703134, 3.18284237148...","[2.260426823913964, 2.8134192865010372, 32.376...",0.0
9,0.1374,0.127166,3.117554,0.297956,0.2793,1.076698,0.006725,0.271957,-22.962992,1.705804,0.135987,0.12311,-0.00529,-0.025106,-0.052423,0.030112,0.090758,0.138337,0.110342,0.041654,0.042206,0.007936,0.003106,0.117962,0.004994,0.012573,0.006303,0.0733,0.001627,0.0053,0.013402,0.024952,0.003663,0.043039,0.06435,0.067663,0.180277,0.057455,0.020675,0.530818,0.005479,0.197681,1.312354,0.466471,0.358931,0.337591,0.310588,0.260669,0.242583,0.232069,0.218763,0.269006,0.208188,0.204668,0.193399,0.00839,0.003943,0.07101,0.005266,0.013335,0.005602,0.061237,0.002407,0.006268,0.01265,0.024557,0.004807,0.019159,133.333333,0.172005,133.0,BigRoom,"[0.137400181488, 0.127165866877, 3.11755416966...","[2.6402220922250765, 2.434284252188199, 31.712...",0.0


+--------------------+-----+
|      features_scale|label|
+--------------------+-----+
|[2.62176373642359...|  0.0|
|[2.24896123712250...|  0.0|
|[1.63923705514876...|  0.0|
|[1.98016236288167...|  0.0|
|[2.91557666352305...|  0.0|
|[2.44127481235333...|  0.0|
|[2.37111042283127...|  0.0|
|[2.69070524221297...|  0.0|
|[2.26042682391396...|  0.0|
|[2.64022209222507...|  0.0|
|[2.86002034317062...|  0.0|
|[2.30292299045409...|  0.0|
|[1.51190686506125...|  0.0|
|[2.65818792948456...|  0.0|
|[1.94661756778893...|  0.0|
|[2.55302183832814...|  0.0|
|[2.9458137836587,...|  0.0|
|[2.26746739169909...|  0.0|
|[2.13278345629527...|  0.0|
|[2.38849043835619...|  0.0|
+--------------------+-----+
only showing top 20 rows

First row features = [2.621763736423596,1.70103481603904,32.56325622817311,4.349129196001776,9.769208712661392,2.2303314225910627,1.0969734439957743,2.284572126925484,-13.74254910397548,2.089421920313096,0.029811040877083257,0.8793220327583383,0.23157344027726573,0.298682758833

In [16]:
''' feature selection - PCA '''
# choose the optimal number of components
for k in range(len(beats.columns)):
    print('Evaluating k=%d'%(k+1))
    try:
        featSel = PCA(k=k+1, inputCol='features_scale', outputCol='features')
        PCARes = featSel.fit(beatsML)
        print('\tTotal explained variance: %0.2f'%PCARes.explainedVariance.sum())
        if PCARes.explainedVariance.sum() > 0.95:
            break
    except KeyboardInterrupt as err:
        break

# perform the principle components reduction
pcs = int(input('Enter the number of PCs to keep'))
featSel = PCA(k=pcs, inputCol='features_scale', outputCol='features_pca')
PCARes = featSel.fit(beatsML)
beatsML = PCARes.transform(beatsML)
inpColumns = ['PC%0d'%i for i in range(pcs)]

# need to scale again
scalr = RobustScaler(inputCol='features_pca', outputCol='features')
beatsML = scalr.fit(beatsML).transform(beatsML)

# talk
print('Total explained variance: %0.2f'%PCARes.explainedVariance.sum())
display(beatsML.select('features_orig', 'features_scale', 'features_pca', 'features').limit(10).toPandas())

Evaluating k=1
	Total explained variance: 0.16
Evaluating k=2
	Total explained variance: 0.28
Evaluating k=3
	Total explained variance: 0.38
Evaluating k=4
	Total explained variance: 0.43
Evaluating k=5
	Total explained variance: 0.48
Evaluating k=6
	Total explained variance: 0.53
Evaluating k=7
	Total explained variance: 0.57
Evaluating k=8
	Total explained variance: 0.60
Evaluating k=9
	Total explained variance: 0.64
Evaluating k=10
	Total explained variance: 0.67
Evaluating k=11
	Total explained variance: 0.70
Evaluating k=12


Enter the number of PCs to keep 26


Total explained variance: 0.91


Unnamed: 0,features_orig,features_scale,features_pca,features
0,"[0.136439587512, 0.0888612604609, 3.2012005559...","[2.621763736423596, 1.70103481603904, 32.56325...","[-15.694435513462444, 7.043927742660012, 3.910...","[-4.212873575276903, 2.0630201992225055, 1.364..."
1,"[0.117038518483, 0.108389033282, 3.19400106287...","[2.2489612371225074, 2.074846995565889, 32.490...","[-12.660058914376675, 6.1360768673919255, 6.18...","[-3.3983527229173776, 1.7971295254415984, 2.15..."
2,"[0.0853077737447, 0.128525418596, 3.1238373468...","[1.6392370551487672, 2.4603095954731096, 31.77...","[-8.580487328020894, 5.778008719627879, 6.1758...","[-2.30326909790479, 1.6922587986932773, 2.1543..."
3,"[0.103049917216, 0.167041735198, 3.15083006899...","[1.9801623628816705, 3.1976117132436874, 32.05...","[-8.918801187124075, 4.408268973908173, 5.2905...","[-2.3940830374022104, 1.2910904638757854, 1.84..."
4,"[0.151729948738, 0.148404713864, 3.19449794602...","[2.9155766635230598, 2.840850825631186, 32.495...","[-11.694784486557477, 5.940016260127018, 7.021...","[-3.139243108789378, 1.7397074439216946, 2.449..."
5,"[0.127046737192, 0.153487850284, 3.22198710261...","[2.441274812353337, 2.938155230050485, 32.7747...","[-12.223194465967335, 5.5471792195538265, 5.49...","[-3.281084746690841, 1.6246536302948047, 1.917..."
6,"[0.123395302003, 0.106206061431, 3.16786066505...","[2.3711104228312765, 2.033059256997651, 32.224...","[-11.326408791108037, 5.305823232962465, 4.288...","[-3.0403596394347785, 1.5539654725322403, 1.49..."
7,"[0.140027382431, 0.0846969282386, 3.1481680737...","[2.6907052422129745, 1.6213187051157385, 32.02...","[-9.6182228366001, 5.015229791078525, 5.807243...","[-2.581829514969155, 1.468856686316644, 2.0258..."
8,"[0.117635200751, 0.146971703134, 3.18284237148...","[2.260426823913964, 2.8134192865010372, 32.376...","[-11.790873815922387, 4.244897870116009, 4.917...","[-3.1650364669640165, 1.2432424592673177, 1.71..."
9,"[0.137400181488, 0.127165866877, 3.11755416966...","[2.6402220922250765, 2.434284252188199, 31.712...","[-11.410700418099527, 7.3499729191959915, 7.24...","[-3.0629861281456936, 2.1526544777294445, 2.52..."


In [45]:
''' feature selection - univariate feature selector '''
# setup
ufs = UnivariateFeatureSelector(featuresCol='features_scale', labelCol='label', outputCol='features', selectionMode='percentile')
ufs.setFeatureType('continuous')
ufs.setLabelType('categorical')
ufs.setSelectionThreshold(0.5)

# perform the feature selection
ufsRes = ufs.fit(beatsML)
beatsML = ufsRes.transform(beatsML)
inpColumns = [inpColumns[c] for c in ufsRes.selectedFeatures]

# talk
print('Selected %d / %d Features'%(len(ufsRes.selectedFeatures), len(inpColumns)))
display(beatsML.select('features_orig', 'features_scale', 'features').limit(10).toPandas())

Selected 35 / 71 Features


Unnamed: 0,features_orig,features_scale,features
0,"[0.136439587512, 0.0888612604609, 3.2012005559...","[2.621763736423596, 1.70103481603904, 32.56325...","[2.621763736423596, 1.70103481603904, 32.56325..."
1,"[0.117038518483, 0.108389033282, 3.19400106287...","[2.2489612371225074, 2.074846995565889, 32.490...","[2.2489612371225074, 2.074846995565889, 32.490..."
2,"[0.0853077737447, 0.128525418596, 3.1238373468...","[1.6392370551487672, 2.4603095954731096, 31.77...","[1.6392370551487672, 2.4603095954731096, 31.77..."
3,"[0.103049917216, 0.167041735198, 3.15083006899...","[1.9801623628816705, 3.1976117132436874, 32.05...","[1.9801623628816705, 3.1976117132436874, 32.05..."
4,"[0.151729948738, 0.148404713864, 3.19449794602...","[2.9155766635230598, 2.840850825631186, 32.495...","[2.9155766635230598, 2.840850825631186, 32.495..."
5,"[0.127046737192, 0.153487850284, 3.22198710261...","[2.441274812353337, 2.938155230050485, 32.7747...","[2.441274812353337, 2.938155230050485, 32.7747..."
6,"[0.123395302003, 0.106206061431, 3.16786066505...","[2.3711104228312765, 2.033059256997651, 32.224...","[2.3711104228312765, 2.033059256997651, 32.224..."
7,"[0.140027382431, 0.0846969282386, 3.1481680737...","[2.6907052422129745, 1.6213187051157385, 32.02...","[2.6907052422129745, 1.6213187051157385, 32.02..."
8,"[0.117635200751, 0.146971703134, 3.18284237148...","[2.260426823913964, 2.8134192865010372, 32.376...","[2.260426823913964, 2.8134192865010372, 32.376..."
9,"[0.137400181488, 0.127165866877, 3.11755416966...","[2.6402220922250765, 2.434284252188199, 31.712...","[2.6402220922250765, 2.434284252188199, 31.712..."


In [46]:
''' split for cross-val '''
trainPerc = 0.7
randSeed = 42

tranBeats, testBeats = beatsML.randomSplit([trainPerc, 1.0 - trainPerc], seed=randSeed)

### Perform the modeling

In [47]:
''' set up the estimators & param grids '''
models = {}

# logistic regression
logreg = LogisticRegression()
params = (ParamGridBuilder().addGrid(logreg.threshold, [0.4, 0.5, 0.6])\
         .addGrid(logreg.elasticNetParam, [0.0, 0.25, 0.5, 0.75, 1.0]).build())
paramNames = ['threshold', 'elasticnetparam']
models['logistic regression'] = [logreg, params, paramNames, None, None]

# random forest
ranfor = RandomForestClassifier(numTrees=20)
params = (ParamGridBuilder().addGrid(ranfor.maxBins, [20, 40, 80, 100])\
              .addGrid(ranfor.maxDepth, [5, 10, 30]).build())
paramNames = ['maxbins', 'maxdepth']
models['random forest'] = [ranfor, params, paramNames, None, None]

In [48]:
''' run the models '''
# number of cv folds
folds = 5
# define the evaulation function
acc = MulticlassClassificationEvaluator(metricName='accuracy')

# iterate over models
for (model, stuff) in models.items():
    print('Cross Validator: %s'%model)
    # execute
    cv = CrossValidator(estimator=stuff[0], estimatorParamMaps=stuff[1], evaluator=acc, numFolds=folds)
    fitModel = cv.fit(tranBeats.select('features', 'label'))
    # get the best
    bestModel = fitModel.bestModel
    # evaluate performance on the test set
    testAcc = acc.evaluate(bestModel.transform(testBeats.select('features', 'label')))
    print('\tBest Model Test Accuracy = %0.3f'%testAcc)    
    # get best parameters
    bestParams = bestModel.extractParamMap()
    for (key, val) in bestParams.items():
        for parm in stuff[2]:
            if parm in key.name.lower():
                print('\t%s = %0.2f'%(key, val))
                break
    # save stuff
    models[model][3] = fitModel
    models[model][4] = testAcc

Cross Validator: logistic regression
	Best Model Test Accuracy = 0.340
	LogisticRegression_8ed8acc45092__elasticNetParam = 0.00
	LogisticRegression_8ed8acc45092__threshold = 0.40
Cross Validator: random forest
	Best Model Test Accuracy = 0.270
	RandomForestClassifier_360bb223dbea__maxBins = 100.00
	RandomForestClassifier_360bb223dbea__maxDepth = 30.00


In [55]:
# show logistic regression summary
bm = models['logistic regression'][3].bestModel
summ = bm.summary
summ.predictions.describe().show()
summ.objectiveHistory

# get models coefficients
coefs = bm.coefficientMatrix.toArray()
ints = bm.interceptVector.toArray()
coefs = pd.concat([pd.DataFrame(index=['Intercept'], data=np.atleast_2d(ints), columns=origLabels),
                   pd.DataFrame(index=inpColumns, data=coefs.T, columns=origLabels)])
display(coefs)

+-------+------------------+------------------+
|summary|             label|        prediction|
+-------+------------------+------------------+
|  count|              1671|              1671|
|   mean|11.128665469778575|11.207061639736684|
| stddev| 6.658178390179078| 6.668648458201444|
|    min|               0.0|               0.0|
|    max|              22.0|              22.0|
+-------+------------------+------------------+



Unnamed: 0,BigRoom,Breaks,Dance,DeepHouse,DrumAndBass,Dubstep,ElectroHouse,ElectronicaDowntempo,FunkRAndB,FutureHouse,GlitchHop,HardDance,HardcoreHardTechno,HipHop,House,IndieDanceNuDisco,Minimal,ProgressiveHouse,PsyTrance,ReggaeDub,TechHouse,Techno,Trance
Intercept,-0.048214,-0.046597,0.017225,-0.070909,0.024457,-0.000236,-0.019766,0.012063,0.017053,-0.145886,0.097583,0.006809,0.006144,-0.14713,-0.006739,0.032016,0.022094,0.096628,-0.005688,0.079296,-0.003475,-0.012988,0.096263
1-ZCRm,-0.060531,-0.239946,-0.071521,-0.57486,-0.198974,0.15309,-0.009448,-0.374943,0.229039,-0.029065,0.013956,0.243096,0.932801,0.057421,0.309382,-0.133095,-0.126334,-0.553072,0.052553,-0.203381,0.317964,0.141257,0.124611
2-Energym,0.740315,-0.032206,-0.327812,-1.106487,1.515343,0.709395,0.822504,-1.104928,-1.21437,0.637678,0.336965,1.182418,1.194038,0.445142,-0.081338,-1.642229,0.282976,-1.023269,-0.057913,-0.303297,-0.409947,-0.113265,-0.449713
3-EnergyEntropym,0.582359,0.043989,0.190773,-0.319816,-0.448372,0.001369,0.355372,-0.273616,-0.225882,0.668724,0.063931,0.375326,-0.381306,0.302162,0.067353,-0.22775,-0.482877,-0.16829,0.005937,0.03122,-0.440077,-0.205639,0.485112
4-SpectralCentroidm,0.058099,0.17942,-0.21006,-0.562114,0.906968,0.34108,0.116135,-0.435416,-0.197294,-0.264815,0.142582,0.223199,0.113058,0.038701,-0.203122,-0.206238,-0.061973,-0.399254,0.097276,-0.269672,0.176915,0.051076,0.365448
5-SpectralSpreadm,0.518265,0.076198,0.041688,-0.166467,1.082924,-0.357653,-0.002576,-0.182298,0.145623,-0.652538,-0.075642,0.442732,-0.21499,-0.64147,-0.492337,0.084154,-0.166878,0.0544,0.210749,-0.147986,-0.253383,-0.050218,0.747704
6-SpectralEntropym,0.143249,0.041834,-0.156548,-0.666194,-0.427811,0.32357,0.446582,-0.42254,0.127283,0.102758,0.175282,0.212023,0.203202,0.45436,0.135954,-0.140541,-0.207902,-0.517646,0.266054,-0.096864,-0.021063,0.025903,-0.000944
7-SpectralFluxm,-0.419349,-0.228847,-0.302884,-0.05589,-0.441365,-0.672913,-0.255469,-0.624617,-0.261678,-0.134185,-0.333693,-0.173233,0.996359,-0.462195,0.148557,-0.268614,0.907294,0.22161,1.23138,-0.222082,0.459142,0.65695,0.235723
8-SpectralRolloffm,-0.239768,-0.018695,-0.173417,-0.357057,-0.204579,0.493407,0.108517,-0.332939,-0.285601,0.080893,-0.150429,0.084844,0.59004,0.144339,0.053187,-0.298969,0.226995,-0.218101,0.298975,-0.481036,0.337543,0.215514,0.126337
9-MFCCs1m,1.138564,0.045571,0.516181,-0.785739,0.198884,-0.020799,0.842415,-0.888836,-0.368085,1.091997,0.447434,0.788205,-1.063212,0.213075,-0.041608,-0.529505,-0.603302,-0.475862,-0.135068,-0.166417,-0.679647,-0.259571,0.735324


In [56]:
# view feature importances for random forest
imports = models['random forest'][3].bestModel.featureImportances.toArray()
imports = pd.DataFrame(index=inpColumns, data=imports, columns=['Importance']).sort_values(by='Importance', ascending=False, inplace=False)
display(imports)

Unnamed: 0,Importance
1-ZCRm,0.050281
44-MFCCs2std,0.040492
2-Energym,0.039336
38-SpectralCentroidstd,0.035707
39-SpectralSpreadstd,0.03566
7-SpectralFluxm,0.035609
10-MFCCs2m,0.034525
24-ChromaVector3m,0.032119
41-SpectralFluxstd,0.030983
43-MFCCs1std,0.030415


In [57]:
sc.stop()