In [1]:
import pandas as pd
import numpy as np
import datetime as dt

import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.clustering import *
from pyspark.ml.evaluation import *
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.stat import Correlation
from pyspark.ml import Pipeline

# iplot won't work because I've not installed the extension
import chart_studio.plotly as ply
import plotly.offline as plyoff
import plotly.graph_objects as go
import plotly.subplots as plysub

plyoff.init_notebook_mode(connected=True)
init = go.Figure(data=[go.Scatter(x=[1,2], y=[42,42])], layout=go.Layout(title='Init'))
plyoff.iplot(init)

pd.set_option('display.max_columns', None)

In [2]:
# initialize
sc = pyspark.SparkContext()
spark = SparkSession(sc)
spark.sparkContext.appName = 'gmm'
# show the number of cores
print('%d cores'%spark._jsc.sc().getExecutorMemoryStatus().keySet().size())
spark

1 cores


In [3]:
''' get the data '''
# load the data - can't parse the timestamp for some reason, so just let it be string
fil = '../data/sales_data_sample.csv'
schem = StructType([StructField('ORDERNUMBER', IntegerType()), StructField('QUANTITYORDERED', IntegerType()),
                    StructField('PRICEEACH', FloatType()), StructField('ORDERLINENUMBER', IntegerType()),
                    StructField('SALES', FloatType()), StructField('ORDERDATE', StringType()),# TimestampType()),
                    StructField('STATUS', StringType()), StructField('QTR_ID', IntegerType()),
                    StructField('MONTH_ID', IntegerType()), StructField('YEAR_ID', IntegerType()),
                    StructField('PRODUCTLINE', StringType()), StructField('MSRP', FloatType()),
                    StructField('PRODUCTCODE', StringType()), StructField('CUSTOMERNAME', StringType()),
                    StructField('PHONE', IntegerType()), StructField('ADDRESSLINE1', StringType()),
                    StructField('ADDRESSLINE2', StringType()), StructField('CITY', StringType()),
                    StructField('STATE', StringType()), StructField('POSTALCODE', StringType()),
                    StructField('COUNTRY', StringType()), StructField('TERRITORY', StringType()),
                    StructField('CONTACTLASTNAME', FloatType()), StructField('CONTACTFIRSTNAME', StringType()),
                    StructField('DEALSIZE', StringType())])
sales = spark.read.format('csv').options(header=True, timestampFormat='M/d/yyyy HH:MM').schema(schem).load(fil)

# talk
cnt = sales.count()
print('%d records'%cnt)
sales.show(truncate=False)

2823 records
+-----------+---------------+---------+---------------+-------+---------------+-------+------+--------+-------+-----------+----+-----------+--------------------------+----------+-----------------------------+------------+-------------+--------+----------+---------+---------+---------------+----------------+--------+
|ORDERNUMBER|QUANTITYORDERED|PRICEEACH|ORDERLINENUMBER|SALES  |ORDERDATE      |STATUS |QTR_ID|MONTH_ID|YEAR_ID|PRODUCTLINE|MSRP|PRODUCTCODE|CUSTOMERNAME              |PHONE     |ADDRESSLINE1                 |ADDRESSLINE2|CITY         |STATE   |POSTALCODE|COUNTRY  |TERRITORY|CONTACTLASTNAME|CONTACTFIRSTNAME|DEALSIZE|
+-----------+---------------+---------+---------------+-------+---------------+-------+------+--------+-------+-----------+----+-----------+--------------------------+----------+-----------------------------+------------+-------------+--------+----------+---------+---------+---------------+----------------+--------+
|10107      |30             |95.7

In [4]:
# check out an order
display(sales.where(col('ORDERNUMBER')==10159).orderBy(col('ORDERLINENUMBER')).toPandas())

Unnamed: 0,ORDERNUMBER,QUANTITYORDERED,PRICEEACH,ORDERLINENUMBER,SALES,ORDERDATE,STATUS,QTR_ID,MONTH_ID,YEAR_ID,PRODUCTLINE,MSRP,PRODUCTCODE,CUSTOMERNAME,PHONE,ADDRESSLINE1,ADDRESSLINE2,CITY,STATE,POSTALCODE,COUNTRY,TERRITORY,CONTACTLASTNAME,CONTACTFIRSTNAME,DEALSIZE
0,10159,50,69.800003,1,3490.0,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,61.0,S24_3371,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Medium
1,10159,41,100.0,2,8296.349609,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,194.0,S12_1099,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Large
2,10159,24,73.419998,3,1762.079956,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,79.0,S12_3990,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small
3,10159,25,100.0,4,3638.0,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,146.0,S18_3482,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Medium
4,10159,21,81.209999,5,1705.410034,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,80.0,S18_3278,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small
5,10159,23,67.099998,6,1543.300049,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,80.0,S24_4620,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small
6,10159,32,100.0,7,4618.879883,10/10/2003 0:00,Shipped,4,10,2003,Classic Cars,148.0,S18_4721,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Medium
7,10159,21,64.660004,8,1357.859985,10/10/2003 0:00,Shipped,4,10,2003,Motorcycles,62.0,S18_3782,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small
8,10159,35,35.400002,9,1239.0,10/10/2003 0:00,Shipped,4,10,2003,Motorcycles,40.0,S32_2206,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small
9,10159,31,71.599998,10,2219.600098,10/10/2003 0:00,Shipped,4,10,2003,Motorcycles,81.0,S50_4713,Corporate Gift Ideas Co.,,7734 Strong St.,,San Francisco,CA,,USA,,,Julie,Small


### Data Prep

In [5]:
# drop some columns that I have no doubt will not be useful for modeling - the latter because I can't aggregate them
# dropping quarter as it's highly correlated with month
sales = sales.drop('CONTACTLASTNAME', 'ADDRESSLINE2', 'CUSTOMERNAME', 'CONTACTFIRSTNAME', 'ADDRESSLINE1',
                   'ORDERDATE', 'PRODUCTLINE', 'PRODUCTCODE', 'QTR_ID')

In [6]:
''' handle missing values '''
# presumably important columns (for modeling)
importantCols = ['STATE', 'POSTALCODE', 'ORDERNUMBER', 'TERRITORY', 'COUNTRY',
                 'CITY', 'QUANTITYORDERED', 'MSRP', 'YEAR_ID', 'MONTH_ID',
                 'STATUS', 'SALES', 'ORDERLINENUMBER', 'PRICEEACH',
                 'DEALSIZE']

# check for missing values
nullCounts = {colm:sales.select(colm).where(col(colm).isNull()).count() for colm in sales.columns}
nullCounts = {colm:(ncnt, ncnt/cnt) for (colm, ncnt) in nullCounts.items()}
nullCountsDF = pd.DataFrame(nullCounts).T.reset_index(drop=False).sort_values(1, ascending=False)
nullCountsDF.columns = ['Column', 'Freq.', 'Rel. Freq.']
nullCountsDF = nullCountsDF.merge(pd.DataFrame([[colm.name, colm.dataType] for colm in sales.schema], columns=['Column', 'Type']),
                                how='inner', on=['Column'])
nullCountsDF['Important'] = [c in importantCols for c in nullCountsDF['Column']]

# talk
display(nullCountsDF)

# fill null states with 'NA' as they are ex-US, and the missing zip codes for california to the range
sales = sales.fillna(value='NA', subset='STATE').fillna(value='90001_96162', subset='POSTALCODE')

# talk some more
print('%d records'%sales.count())
display(sales.limit(10).toPandas())

Unnamed: 0,Column,Freq.,Rel. Freq.,Type,Important
0,PHONE,2589.0,0.917109,IntegerType,False
1,STATE,1486.0,0.52639,StringType,True
2,POSTALCODE,76.0,0.026922,StringType,True
3,ORDERNUMBER,0.0,0.0,IntegerType,True
4,QUANTITYORDERED,0.0,0.0,IntegerType,True
5,PRICEEACH,0.0,0.0,FloatType,True
6,ORDERLINENUMBER,0.0,0.0,IntegerType,True
7,SALES,0.0,0.0,FloatType,True
8,STATUS,0.0,0.0,StringType,True
9,MONTH_ID,0.0,0.0,IntegerType,True


2823 records


Unnamed: 0,ORDERNUMBER,QUANTITYORDERED,PRICEEACH,ORDERLINENUMBER,SALES,STATUS,MONTH_ID,YEAR_ID,MSRP,PHONE,CITY,STATE,POSTALCODE,COUNTRY,TERRITORY,DEALSIZE
0,10107,30,95.699997,2,2871.0,Shipped,2,2003,95.0,2125558000.0,NYC,NY,10022,USA,,Small
1,10121,34,81.349998,5,2765.899902,Shipped,5,2003,95.0,,Reims,,51100,France,EMEA,Small
2,10134,41,94.739998,2,3884.340088,Shipped,7,2003,95.0,,Paris,,75508,France,EMEA,Medium
3,10145,45,83.260002,6,3746.699951,Shipped,8,2003,95.0,,Pasadena,CA,90003,USA,,Medium
4,10159,49,100.0,14,5205.27002,Shipped,10,2003,95.0,,San Francisco,CA,90001_96162,USA,,Medium
5,10168,36,96.660004,1,3479.76001,Shipped,10,2003,95.0,,Burlingame,CA,94217,USA,,Medium
6,10180,29,86.129997,9,2497.77002,Shipped,11,2003,95.0,,Lille,,59000,France,EMEA,Small
7,10188,48,100.0,1,5512.319824,Shipped,11,2003,95.0,,Bergen,,N 5804,Norway,EMEA,Medium
8,10201,22,98.57,2,2168.540039,Shipped,12,2003,95.0,,San Francisco,CA,90001_96162,USA,,Small
9,10211,41,100.0,14,4708.439941,Shipped,1,2004,95.0,,Paris,,75016,France,EMEA,Medium


In [7]:
''' see some value counts '''
for colm in ['STATUS', 'STATE', 'POSTALCODE', 'COUNTRY', 'TERRITORY', 'DEALSIZE']:
    print(colm)
    sales.select(colm).groupBy(colm).count().show()

STATUS
+----------+-----+
|    STATUS|count|
+----------+-----+
|   Shipped| 2617|
|   On Hold|   44|
| Cancelled|   60|
|  Resolved|   47|
|In Process|   41|
|  Disputed|   14|
+----------+-----+

STATE
+-------------+-----+
|        STATE|count|
+-------------+-----+
|           NJ|   21|
|           NA| 1486|
|     Victoria|   78|
|           BC|   48|
|           NH|   34|
|           NV|   29|
|        Tokyo|   32|
|           CA|  416|
|           CT|   61|
|           PA|   75|
|           NY|  178|
|       Quebec|   22|
|Isle of Wight|   26|
|        Osaka|   20|
|           MA|  190|
|   Queensland|   15|
|          NSW|   92|
+-------------+-----+

POSTALCODE
+----------+-----+
|POSTALCODE|count|
+----------+-----+
|     28034|  259|
|   WX1 6LT|   29|
|     97823|   17|
|     51003|   44|
|     97562|  205|
|     67000|   19|
|    B-6000|    8|
|     92561|    3|
|     80686|   14|
|     41101|   15|
|      8200|   27|
|  530-0003|   20|
|     44000|   60|
|     69045|   36|

In [8]:
''' index the string columns '''
# set the string columns to index
strCols = ['STATUS', 'STATE', 'POSTALCODE', 'COUNTRY', 'TERRITORY', 'DEALSIZE']
# do the indexing
indxr = StringIndexer(inputCols=strCols, outputCols=[c+'_int' for c in strCols])
sales = indxr.fit(sales).transform(sales)
# talk
display(sales.limit(10).toPandas())

Unnamed: 0,ORDERNUMBER,QUANTITYORDERED,PRICEEACH,ORDERLINENUMBER,SALES,STATUS,MONTH_ID,YEAR_ID,MSRP,PHONE,CITY,STATE,POSTALCODE,COUNTRY,TERRITORY,DEALSIZE,STATUS_int,POSTALCODE_int,TERRITORY_int,STATE_int,COUNTRY_int,DEALSIZE_int
0,10107,30,95.699997,2,2871.0,Shipped,2,2003,95.0,2125558000.0,NYC,NY,10022,USA,,Small,0.0,2.0,1.0,3.0,0.0,1.0
1,10121,34,81.349998,5,2765.899902,Shipped,5,2003,95.0,,Reims,,51100,France,EMEA,Small,0.0,17.0,0.0,0.0,2.0,1.0
2,10134,41,94.739998,2,3884.340088,Shipped,7,2003,95.0,,Paris,,75508,France,EMEA,Medium,0.0,60.0,0.0,0.0,2.0,0.0
3,10145,45,83.260002,6,3746.699951,Shipped,8,2003,95.0,,Pasadena,CA,90003,USA,,Medium,0.0,33.0,1.0,1.0,0.0,0.0
4,10159,49,100.0,14,5205.27002,Shipped,10,2003,95.0,,San Francisco,CA,90001_96162,USA,,Medium,0.0,4.0,1.0,1.0,0.0,0.0
5,10168,36,96.660004,1,3479.76001,Shipped,10,2003,95.0,,Burlingame,CA,94217,USA,,Medium,0.0,3.0,1.0,1.0,0.0,0.0
6,10180,29,86.129997,9,2497.77002,Shipped,11,2003,95.0,,Lille,,59000,France,EMEA,Small,0.0,59.0,0.0,0.0,2.0,1.0
7,10188,48,100.0,1,5512.319824,Shipped,11,2003,95.0,,Bergen,,N 5804,Norway,EMEA,Medium,0.0,36.0,0.0,0.0,7.0,0.0
8,10201,22,98.57,2,2168.540039,Shipped,12,2003,95.0,,San Francisco,CA,90001_96162,USA,,Small,0.0,4.0,1.0,1.0,0.0,1.0
9,10211,41,100.0,14,4708.439941,Shipped,1,2004,95.0,,Paris,,75016,France,EMEA,Medium,0.0,38.0,0.0,0.0,2.0,0.0


In [9]:
''' aggregate by order '''
# set the output features
features = ['AvgQuantity', 'AvgPrice', 'ItemCount', 'Month', 'Year', 'AvgMSRP', 'Status', 'State', 'Country',
            'PostCode', 'DealSize', 'Territory', 'TotalSale', 'ItemCount']
# agg
sales = sales.groupBy('ORDERNUMBER').agg(mean('QUANTITYORDERED').alias('AvgQuantity'), mean('PRICEEACH').alias('AvgPrice'),
                                         max('ORDERLINENUMBER').alias('LineCount'), min('MONTH_ID').alias('Month'),
                                         min('YEAR_ID').alias('Year'), mean('MSRP').alias('AvgMSRP'),
                                         min('STATUS_int').alias('Status'), min('STATE_int').alias('State'),
                                         min('COUNTRY_int').alias('Country'), min('POSTALCODE_int').alias('PostCode'),
                                         min('DEALSIZE_int').alias('DealSize'), min('TERRITORY_int').alias('Territory'),
                                         sum('SALES').alias('TotalSale'), sum('QUANTITYORDERED').alias('ItemCount'))\
    .select('ORDERNUMBER', *features)
# talk
display(sales.limit(10).toPandas())

Unnamed: 0,ORDERNUMBER,AvgQuantity,AvgPrice,ItemCount,Month,Year,AvgMSRP,Status,State,Country,PostCode,DealSize,Territory,TotalSale,ItemCount.1
0,10206,32.272727,86.77,355,12,2003,104.0,0.0,8.0,9.0,56.0,0.0,1.0,38662.209717,355
1,10362,29.25,86.647499,117,1,2005,128.75,0.0,1.0,0.0,3.0,0.0,1.0,13529.57019,117
2,10121,37.0,83.963998,185,5,2003,99.4,0.0,0.0,2.0,17.0,0.0,0.0,18971.959717,185
3,10230,42.25,83.42125,338,3,2004,108.375,0.0,0.0,11.0,54.0,0.0,0.0,37266.48938,338
4,10395,39.0,92.280001,156,3,2005,129.0,0.0,0.0,2.0,60.0,0.0,0.0,20321.529785,156
5,10416,32.857143,82.759285,460,5,2005,88.428571,0.0,0.0,5.0,20.0,0.0,0.0,41509.940063,460
6,10257,41.6,78.972,208,6,2004,92.6,0.0,1.0,0.0,3.0,0.0,1.0,16128.100098,208
7,10264,36.142857,74.734285,253,6,2004,83.142857,0.0,2.0,0.0,14.0,0.0,1.0,19548.350037,253
8,10128,39.25,97.290001,157,6,2003,102.25,0.0,0.0,1.0,0.0,0.0,0.0,17448.080078,157
9,10183,32.583333,87.916666,391,11,2003,109.416667,0.0,6.0,0.0,15.0,0.0,1.0,40061.660034,391


In [10]:
''' prepare the features '''
# create & scale the features vector
assr = VectorAssembler(inputCols=features, outputCol='features_raw')
scalr = StandardScaler(inputCol='features_raw', outputCol='features')
pipe = Pipeline(stages=[assr, scalr]).fit(sales)
sales = pipe.transform(sales).drop('features_raw')

# talk
display(sales.limit(10).toPandas())
sales.select('features').take(1)
print('First row features = %s'%sales.select('features').take(1)[0])

Unnamed: 0,ORDERNUMBER,AvgQuantity,AvgPrice,ItemCount,Month,Year,AvgMSRP,Status,State,Country,PostCode,DealSize,Territory,TotalSale,ItemCount.1,features
0,10206,32.272727,86.77,355,12,2003,104.0,0.0,8.0,9.0,56.0,0.0,1.0,38662.209717,355,"[6.695761042216421, 9.9101624751318, 2.0200486..."
1,10362,29.25,86.647499,117,1,2005,128.75,0.0,1.0,0.0,3.0,0.0,1.0,13529.57019,117,"[6.068622860093334, 9.89617140593235, 0.665762..."
2,10121,37.0,83.963998,185,5,2003,99.4,0.0,0.0,2.0,17.0,0.0,0.0,18971.959717,185,"[7.676548575160798, 9.589683822417818, 1.05270..."
3,10230,42.25,83.42125,338,3,2004,108.375,0.0,0.0,11.0,54.0,0.0,0.0,37266.48938,338,"[8.76578857569037, 9.52769555979781, 1.9233139..."
4,10395,39.0,92.280001,156,3,2005,129.0,0.0,0.0,2.0,60.0,0.0,0.0,20321.529785,156,"[8.091497146791111, 10.539469849491002, 0.8876..."
5,10416,32.857143,82.759285,460,5,2005,88.428571,0.0,0.0,5.0,20.0,0.0,0.0,41509.940063,460,"[6.817012248212291, 9.452091237298191, 2.61752..."
6,10257,41.6,78.972,208,6,2004,92.6,0.0,1.0,0.0,3.0,0.0,1.0,16128.100098,208,"[8.63093028991052, 9.01953844871407, 1.1835778..."
7,10264,36.142857,74.734285,253,6,2004,83.142857,0.0,2.0,0.0,14.0,0.0,1.0,19548.350037,253,"[7.498713473033521, 8.535541162608858, 1.43964..."
8,10128,39.25,97.290001,157,6,2003,102.25,0.0,0.0,1.0,0.0,0.0,0.0,17448.080078,157,"[8.143365718244901, 11.111671257871407, 0.8933..."
9,10183,32.583333,87.916666,391,11,2003,109.416667,0.0,6.0,0.0,15.0,0.0,1.0,40061.660034,391,"[6.760203812810523, 10.04112535062834, 2.22489..."


First row features = [6.695761042216421,9.9101624751318,2.020048649824036,3.2888780296672513,2801.675686425482,5.885823940998557,0.0,2.251454956624162,1.8889699073198103,2.474031956714512,0.0,1.2449667681321885,2.1921892435008092,2.020048649824036]


In [11]:
# check for multicollinearity
corr = Correlation.corr(sales, column='features', method='pearson')
corrdf = pd.DataFrame(index=features, data=corr.collect()[0][0].toArray(), columns=features)
display(corrdf)

Unnamed: 0,AvgQuantity,AvgPrice,ItemCount,Month,Year,AvgMSRP,Status,State,Country,PostCode,DealSize,Territory,TotalSale,ItemCount.1
AvgQuantity,1.0,0.062665,0.030248,-0.088645,0.184497,0.082597,0.183027,-0.05836,-0.155987,-0.071344,-0.017657,-0.049534,0.04359,0.030248
AvgPrice,0.062665,1.0,0.037275,-0.008084,-0.04724,0.782062,-0.000776,0.020867,0.083265,0.068795,-0.380235,0.022937,0.161578,0.037275
ItemCount,0.030248,0.037275,1.0,0.099039,-0.06398,0.04108,-0.008664,-0.035982,-0.056716,-0.210889,-0.326147,-0.02085,0.966176,1.0
Month,-0.088645,-0.008084,0.099039,1.0,-0.441549,-0.030964,-0.099557,0.00202,-0.004976,0.057231,-0.032925,-0.032708,0.110859,0.099039
Year,0.184497,-0.04724,-0.06398,-0.441549,1.0,-0.038607,0.318256,0.014285,-0.001681,-0.085471,0.064561,0.013328,-0.066069,-0.06398
AvgMSRP,0.082597,0.782062,0.04108,-0.030964,-0.038607,1.0,0.005358,0.035913,0.089052,0.071918,-0.278742,0.021671,0.216201,0.04108
Status,0.183027,-0.000776,-0.008664,-0.099557,0.318256,0.005358,1.0,-0.007418,0.017101,-0.035605,0.176784,-0.021786,-0.019045,-0.008664
State,-0.05836,0.020867,-0.035982,0.00202,0.014285,0.035913,-0.007418,1.0,-0.060046,0.144013,-0.071679,0.509668,-0.042118,-0.035982
Country,-0.155987,0.083265,-0.056716,-0.004976,-0.001681,0.089052,0.017101,-0.060046,1.0,0.474515,0.15503,-0.002851,-0.035172,-0.056716
PostCode,-0.071344,0.068795,-0.210889,0.057231,-0.085471,0.071918,-0.035605,0.144013,0.474515,1.0,0.137353,-0.148703,-0.204198,-0.210889


## Modeling

In [None]:
# split for cross-val
trainPerc = 0.7
randSeed = 42
tran, test = cc.select('ORDERNUMBER', 'features').randomSplit([trainPerc, 1.0 - trainPerc], seed=randSeed)

# talk
print('Training Cases')
tran.select('ORDERNUMBER').show()
print('Testing Cases')
test.select('ORDERNUMBER').show()

In [None]:
''' set up the estimators & param grids '''
models = {}
kMax = 10

# GMM
GMM = GauissianMixture()
params = (ParamGridBuilder().addGrid(GMM.k, list(range(2, kMax+1))).build())
paramNames = ['k']
models['GMM'] = [GMM, params, paramNames, None, None]

In [None]:
''' run the models '''
# number of cv folds
folds = 5
# define the evaulation function
evl = ClusteringEvaluator()

# iterate over models
for (model, stuff) in models.items():
    print('Cross Validator: %s'%model)
    # execute
    cv = CrossValidator(estimator=stuff[0], estimatorParamMaps=stuff[1], evaluator=evl, numFolds=folds)
    fitModel = cv.fit(house.select('features', 'label'))
    # get the best
    bestModel = fitModel.bestModel
    # evaluate performance on the test set
    testScore = evl.evaluate(bestModel.transform(test.select('features', 'label')))
    print('\tBest Model Test Score = %0.3f'%testScore)    
    # get best parameters
    bestParams = bestModel.extractParamMap()
    for (key, val) in bestParams.items():
        for parm in stuff[2]:
            if parm in key.name.lower():
                print('\t%s = %0.2f'%(key, val))
                break
    # save stuff
    models[model][3] = fitModel
    models[model][4] = testScore

In [None]:
''' evaluate different clustering cardinalities '''
# setup range that will be tried
kMax = 41
xs = list(range(2, kMax))
kCost = np.ones(kMax)*np.inf
kSil = np.ones(kMax)*-1

# iterate over k
models = [None]*kMax
for k in range(2, kMax):
    print('Trying k = %d'%k)
    # fit the model on the training set
    GMM = GaussianMixture(k=k, seed=randSeed, featuresCol='features')
    models[k] = GMM.fit(tran)
    # eval the model on the training set
    kCost[k] = models[k].summary.trainingCost
    print('\tTraining Cost = %0.3f'%kCost[k])
    # eval the model on the testing set 
    testPred = models[k].transform(test)
    kSil[k] = ClusteringEvaluator().evaluate(testPred)
    print('\tSilhouette score = %0.3f'%kSil[k])
    
# show the scree plot
fig = plysub.make_subplots(rows=2, cols=1, print_grid=False, subplot_titles=('Train SSE vs. k', 'Test Silhouette vs. k'))
fig.add_trace(go.Scatter(x=xs, y=kCost[2:], mode='markers+lines'), 1, 1)
fig.add_trace(go.Scatter(x=xs, y=kSil[2:], mode='markers+lines'), 2, 1)
fig['layout']['title'] = '|GMM Results'
plyoff.plot(fig)

# find the min
bestK = np.argmin(kCost)
print('Best model has %d clusters, with a cost of %0.3f'%(bestK, kCost[bestK]))

In [None]:
''' evaluate different clustering cardinalities  - bisecting kmeans '''
# setup range that will be tried
bkCost = np.ones(kMax)*np.inf
bkSil = np.ones(kMax)*-1

# iterate over k
bmodels = [None]*kMax
for k in range(2, kMax):
    print('Trying k = %d'%k)
    # fit the bisecting kmeans model on the training set
    kmeans = BisectingKMeans(k=k, seed=randSeed, featuresCol='features')
    bmodels[k] = kmeans.fit(tran)
    # eval the model on the training set
    bkCost[k] = bmodels[k].summary.trainingCost
    print('\tTraining Cost = %0.3f'%bkCost[k])
    # eval the model on the testing set 
    testPred = bmodels[k].transform(test)
    bkSil[k] = ClusteringEvaluator().evaluate(testPred)
    print('\tSilhouette score = %0.3f'%bkSil[k])
    
# show the scree plot
fig = plysub.make_subplots(rows=2, cols=1, print_grid=False, subplot_titles=('Train SSE vs. k', 'Test Silhouette vs. k'))
fig.add_trace(go.Scatter(x=xs, y=bkCost[2:], mode='markers+lines'), 1, 1)
fig.add_trace(go.Scatter(x=xs, y=bkSil[2:], mode='markers+lines'), 2, 1)
fig['layout']['title'] = 'Bisecting Kmeans Results'
plyoff.plot(fig)

# find the min
bestBK = np.argmin(bkCost)
print('Best model has %d clusters, with a cost of %0.3f'%(bestBK, bkCost[bestBK]))

In [None]:
''' Evaluate best model on test set '''
# get the best
bestK = int(input('Enter the "best" k'))
bestMod = input('Enter the best model ("k" or bk")')
if bestMod == 'k':
    # kmeans
    bestModel = models[bestK]
    cst = kCost[bestK]
else:
    # bisecting kmeans
    bestModel = bmodels[bestK]
    cst = bkCost[bestK]
print('Best %s model has %d clusters, with a cost of %0.3f'%(bestMod, bestK, cst))

# predict
testPred = bestModel.transform(test)
# eval
evalSil = ClusteringEvaluator()
silhouette = evalSil.evaluate(testPred)
print('Silhouette score for %s model with %d cluster = [-1, %0.3f, 1]'%(bestMod, bestK, silhouette))
# get the centers
cents = pd.DataFrame(index=list(range(bestK)), data=bestModel.clusterCenters(), columns=features)
display(cents)

In [None]:
# add predictions to entire dataset
ccpred = bestModel.transform(cc)
for feat in features:
    ccpred.groupBy('prediction').agg(min(col(feat)), mean(col(feat)), max(col(feat))).show()

In [None]:
sc.stop()