# Spark MLlib Example: Clustering

### Download the [spreadsheet](WSSSE-versus-k.xlsx)

Let's look at a clustering example in Spark MLlib.

We are going to use the mtcars dataset, which contains statistics on different car models taken from the 1974 *Motor Trend* magazine. We will load the CSV file as a Spark DataFrame and view it.

Here are some of the columns:
* model - name of the car
* mpg   - Miles/(US) gallon
* cyl   - Number of cylinders
* disp  - Displacement (cu.in.)
* hp    - Gross horsepower
* drat  - Rear axle ratio

Are there any natural clusters you can identify from this data?

We are going to use the **MPG and CYL** attributes to cluster.

You can also download and view the raw data in Excel: [cars.csv](/data/cars/mtcars_header.csv)

<img src="../../assets/images/6.1-cars2.png" style="border: 5px solid grey; max-width:100%;" />

In [None]:
# initialize Spark Session
import os
import sys
top_dir = os.path.abspath(os.path.join(os.getcwd(), "../../"))
if top_dir not in sys.path:
    sys.path.append(top_dir)

from init_spark import init_spark
spark = init_spark()
sc = spark.sparkContext

## Step 1 : Load Data

In [None]:
## Imports
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

In [None]:
dataset = spark.read.csv("/data/cars/mtcars_header.csv", header=True, inferSchema=True)

In [None]:
## TODO : print schema
## Hint : printSchema()
dataset.???()

In [None]:
## TODO : display the data
## Hint : show
dataset.???()

## Step 2 : Extract data
We only care about the 'model', 'mpg' and 'cyl' columns.

In [None]:
## TODO : extract the columns we need : model, mpg and cyl
dataset2 = dataset.select(["model", "???", "???"])
dataset2.show()

## Step 3 : Creating Vectors

Now that we have ourselves a dataframe, let's work on turning it into vectors.  We're going to vectorize 2 columns:

1. MPG
2. Number of cylinders.

We'll use the VectorAssembler class to create a new column named `features`. Each value in this column is a Vector holding that row's mpg and cyl values.
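
To make the idea concrete, here is a plain-Python sketch of what assembling a feature vector amounts to. It does not use Spark at all: the `assemble` helper and the two sample rows are hypothetical stand-ins invented for illustration, not values read from the CSV and not the VectorAssembler implementation.

```python
# Plain-Python illustration of combining columns into one feature vector.
# The rows below are made-up values, not read from the CSV.
rows = [
    {"model": "car A", "mpg": 21.0, "cyl": 6},
    {"model": "car B", "mpg": 30.4, "cyl": 4},
]

def assemble(row, input_cols):
    """Return a copy of the row with a 'features' list built from input_cols."""
    out = dict(row)
    out["features"] = [float(row[c]) for c in input_cols]
    return out

assembled = [assemble(r, ["mpg", "cyl"]) for r in rows]
for r in assembled:
    print(r["model"], r["features"])
```

The original columns stay in place; the new `features` entry simply sits alongside them, which mirrors how the assembled DataFrame keeps its input columns.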

In [None]:
## TODO : create featureVector with 'mpg' and 'cyl'
## Hint :  inputCols=['mpg', 'cyl']
assembler = VectorAssembler(inputCols=["mpg", "???"], outputCol="features")
featureVector = assembler.transform(dataset2)
featureVector.show()

## Step 4 : Running Kmeans

Now it's time to run k-means on the resulting dataframe. We don't know what value of k to use, so let's just start with k=2. This means we will cluster into two groups.

We will fit (that is, train) a model on the data, and then compute its WSSSE.
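
For intuition about what fitting does, here is a minimal plain-Python sketch of the basic k-means loop (Lloyd's algorithm): assign each point to its nearest center, move each center to the mean of its points, and repeat. This is a simplification, not Spark's implementation. In particular, the naive "first k points" initialization and the toy data are assumptions made for illustration.

```python
# Minimal k-means (Lloyd's algorithm) on toy 2-D points; values are made up.
def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def closest(point, centers):
    """Index of the center nearest to point."""
    return min(range(len(centers)), key=lambda i: sq_dist(point, centers[i]))

def mean(points):
    """Coordinate-wise mean of a non-empty list of points."""
    n = len(points)
    return [sum(p[d] for p in points) / n for d in range(len(points[0]))]

def kmeans(points, k, max_iter=10):
    centers = [list(p) for p in points[:k]]  # naive init: first k points
    for _ in range(max_iter):
        assignments = [closest(p, centers) for p in points]
        new_centers = []
        for i in range(k):
            members = [p for p, a in zip(points, assignments) if a == i]
            # Keep the old center if a cluster ends up empty
            new_centers.append(mean(members) if members else centers[i])
        if new_centers == centers:  # converged: centers stopped moving
            break
        centers = new_centers
    return centers, [closest(p, centers) for p in points]

points = [[30.0, 4.0], [32.0, 4.0], [15.0, 8.0], [14.0, 8.0]]
centers, labels = kmeans(points, k=2)
print(centers, labels)
```

The `max_iter` argument plays the same role as `setMaxIter(10)` in the notebook: it caps the number of assign-and-update rounds.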

In [None]:
k = 2
kmeans = KMeans().setK(k).setMaxIter(10)
model = kmeans.fit(featureVector)
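# Note: computeCost() was deprecated in Spark 2.4 and removed in Spark 3.0.
# On newer versions, model.summary.trainingCost gives the cost on the training data.
wssse = model.computeCost(featureVector)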

print(wssse)

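WSSSE (Within Set Sum of Squared Errors) measures how tightly the points sit around their cluster centers; lower is better. With only two clusters the value is likely to be high, so we will probably need to change k.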
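
To make the metric concrete, the following sketch computes the sum of squared distances from each point to its nearest center by hand. It is plain Python on made-up points; the `wssse` helper is a hypothetical name for illustration and is not part of Spark.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def wssse(points, centers):
    """Sum over all points of the squared distance to the nearest center."""
    return sum(min(sq_dist(p, c) for c in centers) for p in points)

# Toy (mpg, cyl)-style points; values are made up for illustration
points = [[30.0, 4.0], [32.0, 4.0], [15.0, 8.0], [14.0, 8.0]]

one_center = [[22.75, 6.0]]                # the mean of all four points
two_centers = [[31.0, 4.0], [14.5, 8.0]]   # one center per visible group

print(wssse(points, one_center))   # 290.75
print(wssse(points, two_centers))  # 2.5
```

Giving each visible group its own center shrinks the error sharply, which is the effect we look for when trying different values of k.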



## Step 5 : Display grouping
Let's take a look at the transformed dataset.  Notice the new column `prediction`, which holds the cluster assigned to each car.

In [None]:
predicted = model.transform(featureVector)
predicted.orderBy(['prediction', 'mpg']).show(32)

Notice what we have here: two clusters. One contains the smaller, fuel-efficient cars like the Fiat and the Corolla (remember, we cluster on two variables only: mpg and cyl). The other contains basically all the other cars. We can probably get better results with a different value of k.

## Step 6 : Adjust K

In [None]:
k = 3
kmeans = KMeans().setK(k).setMaxIter(10)
model = kmeans.fit(featureVector)
wssse = model.computeCost(featureVector)

print('WSSSE: ' + str(wssse))

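This WSSSE should be considerably lower than the one for k=2 (lower is better). Keep in mind that WSSSE generally keeps decreasing as k grows, so a lower value alone does not tell us that we have found the best k.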

In [None]:
predicted = model.transform(featureVector)
predicted.orderBy(['prediction', 'mpg']).show(32)

## Step 7 : Iterate over K
We are going to calculate WSSSE for various values of K:

In [None]:
kvals = []
wssses = []

for k in range(2,33):
    kmeans = KMeans().setK(k).setMaxIter(10)
    model = kmeans.fit(featureVector)
    wssse = model.computeCost(featureVector)
    print("k", k, "wssse", wssse)
    kvals.append(k)
    wssses.append(wssse)

## Step 8 : Plot K vs WSSSE

In [None]:
%matplotlib inline
from matplotlib import pyplot

# Plot WSSSE against k and label the axes
pyplot.plot(kvals, wssses)
pyplot.xlabel("k")
pyplot.ylabel("WSSSE")
pyplot.show()
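
A curve like this is usually read by eye: look for the "elbow", the value of k after which adding more clusters barely lowers WSSSE. As a rough sketch of the same idea in code, the hypothetical `elbow` helper below picks the k where the curve bends most sharply (largest second difference). It assumes evenly spaced k values, and the numbers are invented for illustration, not results from this dataset.

```python
def elbow(kvals, wssses):
    """Return the k where the curve bends most sharply (largest second difference)."""
    best_k, best_bend = None, float("-inf")
    for i in range(1, len(kvals) - 1):
        # Drop gained by reaching this k, minus the drop gained by the next k
        bend = (wssses[i - 1] - wssses[i]) - (wssses[i] - wssses[i + 1])
        if bend > best_bend:
            best_k, best_bend = kvals[i], bend
    return best_k

# Hypothetical values for illustration only
sample_kvals = [2, 3, 4, 5, 6]
sample_wssses = [400.0, 150.0, 120.0, 100.0, 90.0]

print(elbow(sample_kvals, sample_wssses))
```

This is only a heuristic. On noisy curves it can pick a poor k, so it is worth checking the result against the plot.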