Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Extending PdM

In this exercise, we address the following concerns:
  - **A different feature set**: We go back to the original telemetry data. This time, instead of computing rolling means and standard deviations, we run PCA (principal component analysis) on the telemetry data and then run the K-means clustering algorithm on the PCs (principal components). This gives us a set of K clusters based on the telemetry data. Our hope is that some of the clusters will capture cases where one or more telemetry readings "go off the charts". We can then use the cluster assignments (after one-hot encoding them) as features for the classification model, instead of the raw telemetry or its rolling means and standard deviations. If this works, we can argue that we have found a simpler feature set for the model.
  - **Multi-class classification**: We extend the binary classification problem to multi-class classification. Recall that the PdM data flags failures *by component*, so we know which component of a machine failed at any given time. In previous notebooks, we built binary classifiers that predict failure for a single component; now we extend this to all the components. Our model should be able to predict which component fails, given the machine's telemetry, metadata, and the time elapsed since the last maintenance and failure (for each component).

## Getting Started

Run the following cell to configure our "classroom."

```python
%run "../includes/setup_env"
```

## Reading the data

We begin by reading the raw telemetry data.

```python
keys = ['machineID', 'datetime']  # join keys
keep_left = ['volt', 'rotate', 'pressure', 'vibration']  # raw telemetry columns
df_raw = spark.read.parquet("dbfs:/FileStore/tables/raw").select(*keys + keep_left).cache()
display(df_raw)
```

machineID,datetime,volt,rotate,pressure,vibration
55,2015-12-24T10:00:00.000+0000,182.981972694613,530.909454475822,107.032899900613,44.7708995665951
55,2015-12-24T11:00:00.000+0000,178.401564032386,512.55873161194,114.291457338468,43.1882366579731
55,2015-12-24T12:00:00.000+0000,167.79139911156,437.039303681645,99.2952398974995,48.6065707027871
55,2015-12-24T13:00:00.000+0000,176.790060700523,358.5477698768,90.7554437374637,39.4746076009144
55,2015-12-24T14:00:00.000+0000,179.199446588841,455.204987383169,89.6945438813705,43.5401805425595
55,2015-12-24T15:00:00.000+0000,168.616399339852,459.96591812524,118.416352872497,46.8677057181414
55,2015-12-24T16:00:00.000+0000,167.768062331583,394.956831508565,113.010355713941,55.1758848115815
55,2015-12-24T17:00:00.000+0000,166.258281598561,358.710892333647,96.2296444369677,38.1302037667352
55,2015-12-24T18:00:00.000+0000,193.375789842157,461.953792578062,116.506075575404,40.9969526519648
55,2015-12-24T19:00:00.000+0000,182.237680868359,472.667427441534,105.801674730209,36.4958493640103


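We also load the data we pre-processed in a prior notebook. From it we keep only the machine age, the "time elapsed since" features (`diff_*`), and the per-component labels (`y_*`); we ignore the rolling mean and standard deviation features.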

```python
df_processed = spark.read.parquet("dbfs:/FileStore/tables/processed")
# keep age, the "time elapsed since" features (diff_*), and the per-component labels (y_*)
keep_right = ['age'] + [c for c in df_processed.columns if c.startswith('diff_') or c.startswith('y_')]
df_processed = df_processed.select(*keys + keep_right).cache()
display(df_processed)
```

machineID,datetime,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,y_0,y_1,y_2,y_3
68,2015-06-02T02:00:00.000+0000,10,886.0,202.0,788.0,437.0,560.0,3744.0,1460.0,3744.0,3744.0,3620.0,1460.0,740.0,20.0,0,0,0,0
68,2015-06-02T03:00:00.000+0000,10,887.0,203.0,789.0,438.0,561.0,3745.0,1461.0,3745.0,3745.0,3621.0,1461.0,741.0,21.0,0,0,0,0
68,2015-06-02T04:00:00.000+0000,10,888.0,204.0,790.0,439.0,562.0,3746.0,1462.0,3746.0,3746.0,3622.0,1462.0,742.0,22.0,0,0,0,0
68,2015-06-02T05:00:00.000+0000,10,889.0,205.0,791.0,440.0,563.0,3747.0,1463.0,3747.0,3747.0,3623.0,1463.0,743.0,23.0,0,0,0,0
68,2015-06-02T06:00:00.000+0000,10,890.0,206.0,792.0,441.0,564.0,3748.0,1464.0,3748.0,3748.0,3624.0,1464.0,744.0,24.0,0,0,0,0
68,2015-06-02T07:00:00.000+0000,10,891.0,207.0,793.0,442.0,565.0,3749.0,1465.0,3749.0,3749.0,3625.0,1465.0,745.0,25.0,0,0,0,0
68,2015-06-02T08:00:00.000+0000,10,892.0,208.0,794.0,443.0,566.0,3750.0,1466.0,3750.0,3750.0,3626.0,1466.0,746.0,26.0,0,0,0,0
68,2015-06-02T09:00:00.000+0000,10,893.0,209.0,795.0,444.0,567.0,3751.0,1467.0,3751.0,3751.0,3627.0,1467.0,747.0,27.0,0,0,0,0
68,2015-06-02T10:00:00.000+0000,10,894.0,210.0,796.0,445.0,568.0,3752.0,1468.0,3752.0,3752.0,3628.0,1468.0,748.0,28.0,0,0,0,0
68,2015-06-02T11:00:00.000+0000,10,895.0,211.0,797.0,446.0,569.0,3753.0,1469.0,3753.0,3753.0,3629.0,1469.0,749.0,29.0,0,0,0,0


We now join the two datasets into one.

```python
# an inner join on machine and timestamp keeps only the hours present in both datasets
df = df_raw.join(df_processed, on=keys, how='inner').cache()
display(df)
```

machineID,datetime,volt,rotate,pressure,vibration,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,y_0,y_1,y_2,y_3
1,2015-01-01T06:00:00.000+0000,176.217853015625,418.504078221616,113.077935462083,45.0876857639276,18,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,0,0,0,0
1,2015-01-05T10:00:00.000+0000,177.278402308539,403.199388877549,100.858613460566,38.7760211849951,18,51.0,200.0,38.0,200.0,28.0,200.0,200.0,200.0,4.0,4.0,200.0,200.0,200.0,0,0,0,0
1,2015-01-10T18:00:00.000+0000,186.897240935137,422.737517067894,98.0771661207512,38.7639287237452,18,179.0,328.0,166.0,3.0,156.0,328.0,328.0,328.0,132.0,132.0,328.0,328.0,328.0,0,0,0,0
1,2015-01-10T20:00:00.000+0000,174.387051176101,429.223630669626,107.870530003791,49.6670282651719,18,181.0,330.0,168.0,5.0,158.0,330.0,330.0,330.0,134.0,134.0,330.0,330.0,330.0,0,0,0,0
1,2015-01-11T13:00:00.000+0000,134.832298929645,477.836714609516,114.442089715359,36.7648107286152,18,198.0,347.0,185.0,22.0,175.0,347.0,347.0,347.0,151.0,151.0,347.0,347.0,347.0,0,0,0,0
1,2015-01-13T21:00:00.000+0000,156.509533873056,472.048627141898,111.061221529413,45.7417388789156,18,254.0,403.0,241.0,78.0,231.0,403.0,403.0,403.0,207.0,207.0,403.0,403.0,403.0,0,0,0,0
1,2015-01-14T21:00:00.000+0000,171.73330160665,558.710536977449,105.417088202919,30.0208108862427,18,278.0,427.0,265.0,102.0,255.0,427.0,427.0,427.0,231.0,231.0,427.0,427.0,427.0,0,0,0,0
1,2015-01-15T23:00:00.000+0000,163.892357793177,563.318565580099,98.328707914028,39.4101458505182,18,304.0,453.0,291.0,128.0,281.0,453.0,453.0,453.0,257.0,257.0,453.0,453.0,453.0,0,0,0,0
1,2015-03-03T03:00:00.000+0000,153.476099534812,387.033760595707,88.27099729395,36.6588897801185,18,839.0,1561.0,1399.0,876.0,1389.0,1561.0,1561.0,1561.0,1365.0,1005.0,1561.0,285.0,1561.0,0,0,0,0
1,2015-03-09T18:00:00.000+0000,192.236915377542,520.685672565613,94.6224605133961,34.5455545574382,18,108.0,140.0,1558.0,1035.0,1548.0,84.0,1720.0,1720.0,1524.0,84.0,1720.0,444.0,1720.0,0,0,0,0


To perform multi-class classification, we need to create a single column that encodes the four classes. We will call it `label` and let `label = 1` when `y_0 = 1` (component 1 fails), `label = 2` when `y_1 = 1` (component 2 fails), `label = 3` when `y_2 = 1` (component 3 fails), `label = 4` when `y_3 = 1` (component 4 fails), and `label = 0` when no component fails.
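The encoding rule can be sketched in plain Python for a single row. This is only an illustration (the notebook does the work in Spark), and `encode_label` is a hypothetical helper. It applies the same order of precedence as the Spark loop below, so a row in which several components fail in the same hour gets the label of the highest-numbered component.

```python
def encode_label(y):
    """Collapse per-component failure flags [y_0, y_1, y_2, y_3] into one label.

    Like the Spark loop, the label starts at 0 and is overwritten with i + 1
    whenever y_i == 1, so the highest-numbered failing component wins.
    """
    label = 0
    for i, flag in enumerate(y):
        if flag == 1:
            label = i + 1
    return label

print(encode_label([0, 0, 0, 0]))  # 0: no component fails
print(encode_label([0, 1, 0, 0]))  # 2: component 2 fails
print(encode_label([1, 0, 1, 0]))  # 3: components 1 and 3 fail, the later one wins
```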

```python
from pyspark.sql.functions import when, lit, col

df = df.withColumn("label", lit(0))
for i in range(4): # iterate over the four components
    label = 'y_' + str(i) # name of target column (one per component)
    # set label to i+1 where component i fails, otherwise keep the current label;
    # if several components fail in the same hour, the highest-numbered one wins
    find_labels = when((col(label) == 1), lit(i+1)).otherwise(col("label"))
    df = df.withColumn("label", find_labels)

df = df.drop("y_0", "y_1", "y_2", "y_3").cache()
display(df)
```

machineID,datetime,volt,rotate,pressure,vibration,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,label
1,2015-01-01T06:00:00.000+0000,176.217853015625,418.504078221616,113.077935462083,45.0876857639276,18,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,0
1,2015-01-05T10:00:00.000+0000,177.278402308539,403.199388877549,100.858613460566,38.7760211849951,18,51.0,200.0,38.0,200.0,28.0,200.0,200.0,200.0,4.0,4.0,200.0,200.0,200.0,0
1,2015-01-10T18:00:00.000+0000,186.897240935137,422.737517067894,98.0771661207512,38.7639287237452,18,179.0,328.0,166.0,3.0,156.0,328.0,328.0,328.0,132.0,132.0,328.0,328.0,328.0,0
1,2015-01-10T20:00:00.000+0000,174.387051176101,429.223630669626,107.870530003791,49.6670282651719,18,181.0,330.0,168.0,5.0,158.0,330.0,330.0,330.0,134.0,134.0,330.0,330.0,330.0,0
1,2015-01-11T13:00:00.000+0000,134.832298929645,477.836714609516,114.442089715359,36.7648107286152,18,198.0,347.0,185.0,22.0,175.0,347.0,347.0,347.0,151.0,151.0,347.0,347.0,347.0,0
1,2015-01-13T21:00:00.000+0000,156.509533873056,472.048627141898,111.061221529413,45.7417388789156,18,254.0,403.0,241.0,78.0,231.0,403.0,403.0,403.0,207.0,207.0,403.0,403.0,403.0,0
1,2015-01-14T21:00:00.000+0000,171.73330160665,558.710536977449,105.417088202919,30.0208108862427,18,278.0,427.0,265.0,102.0,255.0,427.0,427.0,427.0,231.0,231.0,427.0,427.0,427.0,0
1,2015-01-15T23:00:00.000+0000,163.892357793177,563.318565580099,98.328707914028,39.4101458505182,18,304.0,453.0,291.0,128.0,281.0,453.0,453.0,453.0,257.0,257.0,453.0,453.0,453.0,0
1,2015-03-03T03:00:00.000+0000,153.476099534812,387.033760595707,88.27099729395,36.6588897801185,18,839.0,1561.0,1399.0,876.0,1389.0,1561.0,1561.0,1561.0,1365.0,1005.0,1561.0,285.0,1561.0,0
1,2015-03-09T18:00:00.000+0000,192.236915377542,520.685672565613,94.6224605133961,34.5455545574382,18,108.0,140.0,1558.0,1035.0,1548.0,84.0,1720.0,1720.0,1524.0,84.0,1720.0,444.0,1720.0,0


Let's look at some summary statistics for the labels in the data.

```python
display(df.groupBy("label").count())
```

| label | count |
|---|---|
| 1 | 11880 |
| 3 | 8943 |
| 4 | 12735 |
| 2 | 17178 |
| 0 | 825364 |
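The classes are heavily imbalanced. A quick calculation from the counts above (computed on the full dataset, before the train/test split) shows the share of rows carrying each label. A model that always predicts `label = 0` would already be right about 94% of the time, so overall accuracy alone says little about how well failures are predicted.

```python
# label counts copied from the output above (full dataset)
counts = {0: 825364, 1: 11880, 2: 17178, 3: 8943, 4: 12735}
total = sum(counts.values())

for label in sorted(counts):
    share = counts[label] / total
    print(f"label {label}: {counts[label]:>7} rows ({share:.2%})")

# share of rows in which some component fails
failure_share = 1 - counts[0] / total
print(f"rows with a failure label: {failure_share:.2%}")
```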


Next, we divide the data into training and test sets. With time-series data, we usually split on a time cut-off, and to avoid **leakage** we also leave a gap between the training and test data (2 weeks in this case: training data ends on 2015-10-01 and test data starts on 2015-10-15). Because the data is collected hourly, we can also downsample it if we do not need such a high frequency for modeling, for example by keeping only every third hour; this option is commented out below.

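```python
from datetime import datetime  # pandas.datetime is deprecated; use the standard library
from pyspark.sql.functions import col, hour

# training set: all hours before 2015-10-01
df_train = df.filter(col('datetime') < datetime(2015, 10, 1))
# optional downsampling: keep only every 3rd hour of the training data
# df_train = df_train.filter(hour(col('datetime')) % 3 == 0)
# test set: all hours after 2015-10-15, leaving a two-week gap to avoid leakage
df_test = df.filter(col('datetime') > datetime(2015, 10, 15))
```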
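The splitting rule can be sketched in plain Python on synthetic hourly timestamps for 2015 (not the PdM data). `time_split` is a hypothetical helper; it keeps rows strictly before the training cut-off, optionally only every `every_nth_hour`-th clock hour, and rows strictly after the test start. Everything in between is dropped, which is the gap.

```python
from datetime import datetime, timedelta

def time_split(timestamps, train_end, test_start, every_nth_hour=1):
    """Split timestamps into train/test sets separated by a gap.

    Train: t < train_end, keeping only clock hours divisible by every_nth_hour.
    Test:  t > test_start. Rows between the two cut-offs are discarded.
    """
    train = [t for t in timestamps if t < train_end and t.hour % every_nth_hour == 0]
    test = [t for t in timestamps if t > test_start]
    return train, test

# one timestamp per hour for all of 2015, mimicking the telemetry frequency
start = datetime(2015, 1, 1)
hours = [start + timedelta(hours=h) for h in range(365 * 24)]

train, test = time_split(hours, datetime(2015, 10, 1), datetime(2015, 10, 15), every_nth_hour=3)
gap = min(test) - max(train)
print(len(train), len(test), gap)
```

With 3-hour sampling and the strict `>` comparison, the effective gap comes out slightly longer than two weeks.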

Let's make sure we don't have any null values in our DataFrame.

```python
# compare row counts before and after dropping rows that contain any null value
recordCount = df_train.count()
noNullsRecordCount = df_train.na.drop().count()

print("We have {} records that contain null values.".format(recordCount - noNullsRecordCount))
```

```python
display(df_train)
```

machineID,datetime,volt,rotate,pressure,vibration,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,label
1,2015-01-01T06:00:00.000+0000,176.217853015625,418.504078221616,113.077935462083,45.0876857639276,18,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,0
1,2015-01-05T10:00:00.000+0000,177.278402308539,403.199388877549,100.858613460566,38.7760211849951,18,51.0,200.0,38.0,200.0,28.0,200.0,200.0,200.0,4.0,4.0,200.0,200.0,200.0,0
1,2015-01-10T18:00:00.000+0000,186.897240935137,422.737517067894,98.0771661207512,38.7639287237452,18,179.0,328.0,166.0,3.0,156.0,328.0,328.0,328.0,132.0,132.0,328.0,328.0,328.0,0
1,2015-01-10T20:00:00.000+0000,174.387051176101,429.223630669626,107.870530003791,49.6670282651719,18,181.0,330.0,168.0,5.0,158.0,330.0,330.0,330.0,134.0,134.0,330.0,330.0,330.0,0
1,2015-01-11T13:00:00.000+0000,134.832298929645,477.836714609516,114.442089715359,36.7648107286152,18,198.0,347.0,185.0,22.0,175.0,347.0,347.0,347.0,151.0,151.0,347.0,347.0,347.0,0
1,2015-01-13T21:00:00.000+0000,156.509533873056,472.048627141898,111.061221529413,45.7417388789156,18,254.0,403.0,241.0,78.0,231.0,403.0,403.0,403.0,207.0,207.0,403.0,403.0,403.0,0
1,2015-01-14T21:00:00.000+0000,171.73330160665,558.710536977449,105.417088202919,30.0208108862427,18,278.0,427.0,265.0,102.0,255.0,427.0,427.0,427.0,231.0,231.0,427.0,427.0,427.0,0
1,2015-01-15T23:00:00.000+0000,163.892357793177,563.318565580099,98.328707914028,39.4101458505182,18,304.0,453.0,291.0,128.0,281.0,453.0,453.0,453.0,257.0,257.0,453.0,453.0,453.0,0
1,2015-03-03T03:00:00.000+0000,153.476099534812,387.033760595707,88.27099729395,36.6588897801185,18,839.0,1561.0,1399.0,876.0,1389.0,1561.0,1561.0,1561.0,1365.0,1005.0,1561.0,285.0,1561.0,0
1,2015-03-09T18:00:00.000+0000,192.236915377542,520.685672565613,94.6224605133961,34.5455545574382,18,108.0,140.0,1558.0,1035.0,1548.0,84.0,1720.0,1720.0,1524.0,84.0,1720.0,444.0,1720.0,0


## Feature engineering using PCA and K-Means

In this section, we use two unsupervised learning algorithms, [PCA (principal component analysis)](https://en.wikipedia.org/wiki/Principal_component_analysis) and [K-means Clustering](https://www.google.com/search?client=firefox-b-1-ab&q=k-means+clustering), to expand on what we have learned about feature engineering so far. If you are new to these algorithms, we invite you to learn more about them before running this exercise.

- [Documentation for `KMeans`](https://spark.apache.org/docs/2.4.0/ml-clustering.html#k-means)
- [Documentation for `PCA`](https://spark.apache.org/docs/2.4.0/ml-features.html)
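The core of K-means (Lloyd's algorithm) fits in a few lines of plain Python: assign each point to its nearest center, move each center to the mean of its points, and repeat. This is a teaching sketch under simplifying assumptions: `kmeans` is a hypothetical helper, the initial centers are passed in by hand, and it runs a fixed number of iterations, whereas Spark's `KMeans` chooses initial centers itself (k-means|| by default) and runs distributed. The toy data has one group of "normal" readings and one group that is "off the charts".

```python
import math

def kmeans(points, centers, n_iter=10):
    """Lloyd's algorithm with caller-supplied initial centers."""
    centers = [list(c) for c in centers]
    for _ in range(n_iter):
        # assignment step: index of the nearest center for every point
        assignments = [
            min(range(len(centers)), key=lambda k: math.dist(p, centers[k]))
            for p in points
        ]
        # update step: move each center to the mean of its members (skip empty clusters)
        for k in range(len(centers)):
            members = [p for p, a in zip(points, assignments) if a == k]
            if members:
                centers[k] = [sum(coord) / len(members) for coord in zip(*members)]
    return assignments, centers

points = [(0.0, 0.1), (0.2, -0.1), (-0.1, 0.0), (5.0, 5.2), (5.1, 4.9)]
assignments, centers = kmeans(points, centers=[points[0], points[-1]])
print(assignments, centers)
```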

```python
from pyspark.ml.feature import StandardScaler, VectorAssembler, OneHotEncoderEstimator, MinMaxScaler
from pyspark.ml.feature import PCA
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
# note: OneHotEncoderEstimator is the Spark 2.3/2.4 name; Spark 3.x renamed it to OneHotEncoder

PCA_features = ['volt', 'rotate', 'pressure', 'vibration']
diff_features = [c for c in df.columns if c.startswith('diff_')]

stages = []
# create a single vector feature from telemetry data
stages.append(VectorAssembler(inputCols = PCA_features, outputCol = "pca_raw_features"))
# extract principal components from the telemetry data (we chose k = 4 so there's no dimensionality reduction, just orthogonalization)
stages.append(PCA(k = 4, inputCol = "pca_raw_features", outputCol = "pca_features"))
# rescale principal components prior to running k-means (divide by standard deviation only, no centering)
stages.append(StandardScaler(inputCol = "pca_features", outputCol="scaled_pca_features", withStd = True, withMean = False))
# run k-means on rescaled principal components (we chose K = 3 to keep it simple for now)
stages.append(KMeans(featuresCol = "scaled_pca_features", predictionCol = "cluster").setK(3).setSeed(1))
# run one-hot encoding on cluster feature (dropLast = False keeps one column per cluster)
stages.append(OneHotEncoderEstimator(inputCols = ["cluster"], outputCols = ["cluster_vec"], dropLast = False))
# combine all "time-elapsed-since" features into single vector
stages.append(VectorAssembler(inputCols = diff_features, outputCol = "diff_features"))
# rescale all "time-elapsed-since" features to [0, 1] using the training data's min and max
stages.append(MinMaxScaler(inputCol = "diff_features", outputCol="scaled_diff_features"))
# create one vector with all final features
stages.append(VectorAssembler(inputCols = ['scaled_diff_features', 'age', 'cluster_vec'], outputCol = "final_features"))

data_pipeline = Pipeline(stages = stages)
print(data_pipeline.getStages())
```
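Two properties of the PCA and scaling stages can be checked by hand. With `k = 4` and four inputs, PCA keeps every component, so the transform is an orthogonal rotation and should preserve each vector's length. `StandardScaler` with `withMean = False` only divides each principal component by a fixed number (its training standard deviation), so the ratio of unscaled to scaled values should be the same in every row. The plain-Python sketch below checks both, using values copied from the first two rows of the test-set predictions displayed at the end of this section; `norm` is a helper defined here, not a Spark function. That the lengths match is consistent with Spark's `PCA` projecting the raw vectors without subtracting the mean first.

```python
import math

# first two test-set rows, copied from the prediction output at the end of this section
pca_raw_features = [
    [170.120714877836, 412.697404999149, 84.8003277747275, 40.4846724095414],
    [168.397208150711, 517.958039049934, 77.592313718511, 46.4860391294064],
]
pca_features = [
    [412.6440767828122, -170.380151829108, -84.55407091022357, 40.452303588871175],
    [517.9034058163573, -168.66783995301668, -77.36155036346875, 46.49818755930855],
]
scaled_pca_features = [
    [7.846600908158874, -10.995354997639364, -7.670295894688139, 7.521337075275031],
    [9.848151380484014, -10.88485223812086, -7.017828660073146, 8.645454300385833],
]

def norm(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))

# an orthogonal rotation preserves vector lengths
for raw, pc in zip(pca_raw_features, pca_features):
    print(norm(raw), norm(pc))

# per-component ratio unscaled / scaled: should equal the training std in every row
ratios = [[p / s for p, s in zip(pc, sc)] for pc, sc in zip(pca_features, scaled_pca_features)]
print(ratios[0])
print(ratios[1])
```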

```python
# fit every featurization stage (PCA, scalers, k-means, encoder) on the training data only
featurizer = data_pipeline.fit(df_train)

df_kmeans = featurizer.transform(df_train).select(*keys + PCA_features + ["label", "cluster", "final_features"])
display(df_kmeans)
```

machineID,datetime,volt,rotate,pressure,vibration,label,cluster,final_features
1,2015-01-01T06:00:00.000+0000,176.217853015625,418.504078221616,113.077935462083,45.0876857639276,0,2,"List(1, 17, List(), List(0.018241517694272163, 0.015048908954100828, 0.015048908954100828, 0.01511258878645912, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 0.015048908954100828, 18.0, 0.0, 0.0, 1.0))"
1,2015-01-05T10:00:00.000+0000,177.278402308539,403.199388877549,100.858613460566,38.7760211849951,0,1,"List(1, 17, List(), List(0.009303174024078804, 0.030097817908201655, 0.005718585402558315, 0.03022517757291824, 0.004213694507148232, 0.030097817908201655, 0.030097817908201655, 0.030097817908201655, 6.019563581640331E-4, 6.019563581640331E-4, 0.030097817908201655, 0.030097817908201655, 0.030097817908201655, 18.0, 0.0, 1.0, 0.0))"
1,2015-01-10T18:00:00.000+0000,186.897240935137,422.737517067894,98.0771661207512,38.7639287237452,0,1,"List(1, 17, List(), List(0.03265231667274717, 0.049360421369450715, 0.024981188863807374, 4.533776635937736E-4, 0.02347629796839729, 0.049360421369450715, 0.049360421369450715, 0.049360421369450715, 0.019864559819413093, 0.019864559819413093, 0.049360421369450715, 0.049360421369450715, 0.049360421369450715, 18.0, 0.0, 1.0, 0.0))"
1,2015-01-10T20:00:00.000+0000,174.387051176101,429.223630669626,107.870530003791,49.6670282651719,0,2,"List(1, 17, List(), List(0.033017147026632616, 0.04966139954853273, 0.02528216704288939, 7.55629439322956E-4, 0.023777276147479307, 0.04966139954853273, 0.04966139954853273, 0.04966139954853273, 0.02016553799849511, 0.02016553799849511, 0.04966139954853273, 0.04966139954853273, 0.04966139954853273, 18.0, 0.0, 0.0, 1.0))"
1,2015-01-11T13:00:00.000+0000,134.832298929645,477.836714609516,114.442089715359,36.7648107286152,0,0,"List(1, 17, List(), List(0.03611820503465888, 0.052219714070729874, 0.02784048156508653, 0.0033247695330210067, 0.02633559066967645, 0.052219714070729874, 0.052219714070729874, 0.052219714070729874, 0.02272385252069225, 0.02272385252069225, 0.052219714070729874, 0.052219714070729874, 0.052219714070729874, 18.0, 1.0, 0.0, 0.0))"
1,2015-01-13T21:00:00.000+0000,156.509533873056,472.048627141898,111.061221529413,45.7417388789156,0,2,"List(1, 17, List(), List(0.04633345494345129, 0.060647103085026334, 0.036267870579382994, 0.011787819253438114, 0.03476297968397291, 0.060647103085026334, 0.060647103085026334, 0.060647103085026334, 0.031151241534988713, 0.031151241534988713, 0.060647103085026334, 0.060647103085026334, 0.060647103085026334, 18.0, 0.0, 0.0, 1.0))"
1,2015-01-14T21:00:00.000+0000,171.73330160665,558.710536977449,105.417088202919,30.0208108862427,0,0,"List(1, 17, List(), List(0.050711419190076615, 0.06425884123401053, 0.0398796087283672, 0.015414840562188304, 0.03837471783295711, 0.06425884123401053, 0.06425884123401053, 0.06425884123401053, 0.03476297968397291, 0.03476297968397291, 0.06425884123401053, 0.06425884123401053, 0.06425884123401053, 18.0, 1.0, 0.0, 0.0))"
1,2015-01-15T23:00:00.000+0000,163.892357793177,563.318565580099,98.328707914028,39.4101458505182,0,0,"List(1, 17, List(), List(0.055454213790587376, 0.06817155756207675, 0.04379232505643341, 0.019344113646667674, 0.042287434161023325, 0.06817155756207675, 0.06817155756207675, 0.06817155756207675, 0.03867569601203913, 0.03867569601203913, 0.06817155756207675, 0.06817155756207675, 0.06817155756207675, 18.0, 1.0, 0.0, 0.0))"
1,2015-03-03T03:00:00.000+0000,153.476099534812,387.033760595707,88.27099729395,36.6588897801185,0,0,"List(1, 17, List(), List(0.15304633345494345, 0.23491346877351393, 0.21053423626787057, 0.1323862777693819, 0.20902934537246048, 0.23491346877351393, 0.23491346877351393, 0.23491346877351393, 0.2054176072234763, 0.15124153498871332, 0.23491346877351393, 0.04288939051918736, 0.23491346877351393, 18.0, 1.0, 0.0, 0.0))"
1,2015-03-09T18:00:00.000+0000,192.236915377542,520.685672565613,94.6224605133961,34.5455545574382,0,1,"List(1, 17, List(), List(0.019700839109813937, 0.021068472535741158, 0.23446200150489088, 0.1564152939398519, 0.23295711060948082, 0.012641083521444696, 0.2588412340105342, 0.2588412340105342, 0.2293453724604966, 0.012641083521444696, 0.2588412340105342, 0.06681715575620767, 0.2588412340105342, 18.0, 0.0, 1.0, 0.0))"
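Each `final_features` vector above has 17 entries: the 13 min-max-scaled `diff_*` features, the unscaled `age`, and the 3-element one-hot `cluster_vec`. The plain-Python sketch below rebuilds the first displayed row. `min_max_scale` and `one_hot` are hypothetical helpers mirroring `MinMaxScaler` (default output range [0, 1]) and one-hot encoding with `dropLast = False`. The training minimum (0) and maximum (5482 for `diff_error_0`, 6645 for `diff_fail_2`) are **inferred** from the displayed values, not read from the fitted model.

```python
def min_max_scale(x, col_min, col_max):
    """Rescale x to [0, 1] using a column's training min and max."""
    return (x - col_min) / (col_max - col_min)

def one_hot(index, size):
    """One-hot encoding that keeps all categories (like dropLast=False)."""
    return [1.0 if i == index else 0.0 for i in range(size)]

# first displayed row: machine 1, 2015-01-01 06:00, diff_error_0 = 100.0, age = 18, cluster = 2
scaled_diff_error_0 = min_max_scale(100.0, 0.0, 5482.0)  # min/max inferred, see lead-in
# the other 12 scaled diff values, copied from the displayed final_features
other_scaled_diffs = [0.015048908954100828, 0.015048908954100828, 0.01511258878645912] \
    + [0.015048908954100828] * 9

final_features = [scaled_diff_error_0] + other_scaled_diffs + [18.0] + one_hot(2, 3)
print(len(final_features), final_features[-4:])

# the scaler is fit on training data only, so later test rows can exceed 1
print(min_max_scale(7250.0, 0.0, 6645.0))
```

The last line reproduces a `diff_fail_2` value of about 1.091 that appears in the test-set predictions at the end of this section: test rows later in time than any training row fall outside the [0, 1] range learned during training.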


```python
display(df_kmeans.groupBy("cluster").mean())
```

| cluster | avg(machineID) | avg(volt) | avg(rotate) | avg(pressure) | avg(vibration) | avg(label) | avg(cluster) |
|---|---|---|---|---|---|---|---|
| 1 | 50.508063739539146 | 185.52874665010123 | 459.1183907669744 | 99.59418879229914 | 38.64851926953644 | 0.1128327410294623 | 1.0 |
| 2 | 50.53246452472052 | 169.20489882328712 | 413.16828602262393 | 105.68700537747344 | 44.402088354050136 | 0.2481170720409926 | 2.0 |
| 0 | 50.46298218678577 | 158.2075781365755 | 465.1162695482352 | 97.60024971652518 | 38.43803462313168 | 0.0824580839890433 | 0.0 |


## Train a Logistic Regression Model

Next, we train a multi-class logistic regression classifier on the featurized training data and apply it to the held-out test set.

[Logistic Regression Docs](https://spark.apache.org/docs/latest/ml-classification-regression.html#logistic-regression)

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

# with more than two label values, the default family ("auto") fits a multinomial model
lr = (LogisticRegression()
     .setLabelCol("label")
     .setFeaturesCol("final_features"))

model_pipeline = Pipeline(stages = [lr])
assert len(model_pipeline.getStages()) == 1 # make sure it's one stage only
print(model_pipeline.getStages())

lr_model = model_pipeline.fit(df_kmeans)

# featurize the held-out test set with the pipeline fit on training data, then apply the model
df_pred = lr_model.transform(featurizer.transform(df_test))
display(df_pred)
```

machineID,datetime,volt,rotate,pressure,vibration,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,label,pca_raw_features,pca_features,scaled_pca_features,cluster,cluster_vec,diff_features,scaled_diff_features,final_features,rawPrediction,probability,prediction
1,2015-10-26T04:00:00.000+0000,170.120714877836,412.697404999149,84.8003277747275,40.4846724095414,18,263.0,1019.0,238.0,3789.0,1318.0,5614.0,214.0,7250.0,1294.0,574.0,214.0,3454.0,1294.0,0,"List(1, 4, List(), List(170.120714877836, 412.697404999149, 84.8003277747275, 40.4846724095414))","List(1, 4, List(), List(412.6440767828122, -170.380151829108, -84.55407091022357, 40.452303588871175))","List(1, 4, List(), List(7.846600908158874, -10.995354997639364, -7.670295894688139, 7.521337075275031))",0,"List(0, 3, List(0), List(1.0))","List(1, 13, List(), List(263.0, 1019.0, 238.0, 3789.0, 1318.0, 5614.0, 214.0, 7250.0, 1294.0, 574.0, 214.0, 3454.0, 1294.0))","List(1, 13, List(), List(0.04797519153593579, 0.15334838224228745, 0.03581640331075997, 0.5726159891189361, 0.1983446200150489, 0.8448457486832205, 0.032204665161775774, 1.09104589917231, 0.19473288186606472, 0.08638073739653875, 0.032204665161775774, 0.5197893152746426, 0.19473288186606472))","List(1, 17, List(), List(0.04797519153593579, 0.15334838224228745, 0.03581640331075997, 0.5726159891189361, 0.1983446200150489, 0.8448457486832205, 0.032204665161775774, 1.09104589917231, 0.19473288186606472, 0.08638073739653875, 0.032204665161775774, 0.5197893152746426, 0.19473288186606472, 18.0, 1.0, 0.0, 0.0))","List(1, 5, List(), List(4.2930492390421735, -1.2093905406769245, -0.30293091309839365, -2.7744874071722556, -0.006240378094600985))","List(1, 5, List(), List(0.9721955304872955, 0.003963459153443546, 0.009811712205954773, 8.286315291671223E-4, 0.013200666624139013))",0.0
1,2015-10-27T04:00:00.000+0000,168.397208150711,517.958039049934,77.592313718511,46.4860391294064,18,287.0,1043.0,262.0,3813.0,1342.0,5638.0,238.0,7274.0,1318.0,598.0,238.0,3478.0,1318.0,0,"List(1, 4, List(), List(168.397208150711, 517.958039049934, 77.592313718511, 46.4860391294064))","List(1, 4, List(), List(517.9034058163573, -168.66783995301668, -77.36155036346875, 46.49818755930855))","List(1, 4, List(), List(9.848151380484014, -10.88485223812086, -7.017828660073146, 8.645454300385833))",0,"List(0, 3, List(0), List(1.0))","List(1, 13, List(), List(287.0, 1043.0, 262.0, 3813.0, 1342.0, 5638.0, 238.0, 7274.0, 1318.0, 598.0, 238.0, 3478.0, 1318.0))","List(1, 13, List(), List(0.05235315578256111, 0.15696012039127163, 0.039428141459744166, 0.5762430104276862, 0.2019563581640331, 0.8484574868322047, 0.03581640331075997, 1.0946576373212942, 0.1983446200150489, 0.08999247554552295, 0.03581640331075997, 0.5234010534236267, 0.1983446200150489))","List(1, 17, List(), List(0.05235315578256111, 0.15696012039127163, 0.039428141459744166, 0.5762430104276862, 0.2019563581640331, 0.8484574868322047, 0.03581640331075997, 1.0946576373212942, 0.1983446200150489, 0.08999247554552295, 0.03581640331075997, 0.5234010534236267, 0.1983446200150489, 18.0, 1.0, 0.0, 0.0))","List(1, 5, List(), List(4.288429145789501, -1.2041708070154686, -0.30971150716040324, -2.764627963912838, -0.009918867700792422))","List(1, 5, List(), List(0.972154176930541, 0.004002481153990907, 0.009790120423916856, 8.406812536597558E-4, 0.013212540237891597))",0.0
1,2015-10-27T23:00:00.000+0000,224.097654096511,453.604244578796,95.4468648899599,48.2354156550226,18,306.0,1062.0,281.0,3832.0,1361.0,5657.0,257.0,7293.0,1337.0,617.0,257.0,3497.0,1337.0,0,"List(1, 4, List(), List(224.097654096511, 453.604244578796, 95.4468648899599, 48.2354156550226))","List(1, 4, List(), List(453.5373402262138, -224.389854048014, -95.11191186366136, 48.16769928776213))","List(1, 4, List(), List(8.62420353503836, -14.480830523035666, -8.6280471093861, 8.955868278003637))",1,"List(0, 3, List(1), List(1.0))","List(1, 13, List(), List(306.0, 1062.0, 281.0, 3832.0, 1361.0, 5657.0, 257.0, 7293.0, 1337.0, 617.0, 257.0, 3497.0, 1337.0))","List(1, 13, List(), List(0.05581904414447282, 0.1598194130925508, 0.042287434161023325, 0.5791144022971135, 0.20481565086531225, 0.8513167795334838, 0.03867569601203913, 1.0975169300225733, 0.20120391271632806, 0.09285176824680211, 0.03867569601203913, 0.5262603461249059, 0.20120391271632806))","List(1, 17, List(), List(0.05581904414447282, 0.1598194130925508, 0.042287434161023325, 0.5791144022971135, 0.20481565086531225, 0.8513167795334838, 0.03867569601203913, 1.0975169300225733, 0.20120391271632806, 0.09285176824680211, 0.03867569601203913, 0.5262603461249059, 0.20120391271632806, 18.0, 0.0, 1.0, 0.0))","List(1, 5, List(), List(3.9067363563336936, -0.1923390125544004, -0.4767947833890159, -2.872751471514304, -0.36485108887597506))","List(1, 5, List(), List(0.9577024246370371, 0.015886373661455094, 0.011953282998215187, 0.0010887707204390057, 0.013369147982853493))",0.0
1,2015-10-28T13:00:00.000+0000,165.692303883705,476.441195322744,111.891569855552,44.0942416640864,18,320.0,1076.0,295.0,3846.0,1375.0,5671.0,271.0,7307.0,1351.0,631.0,271.0,3511.0,1351.0,0,"List(1, 4, List(), List(165.692303883705, 476.441195322744, 111.891569855552, 44.0942416640864))","List(1, 4, List(), List(476.384924322686, -166.01463309062612, -111.6606688584693, 44.075142965832406))","List(1, 4, List(), List(9.058659969063399, -10.71362952807528, -10.129262384688937, 8.194935207056114))",2,"List(0, 3, List(2), List(1.0))","List(1, 13, List(), List(320.0, 1076.0, 295.0, 3846.0, 1375.0, 5671.0, 271.0, 7307.0, 1351.0, 631.0, 271.0, 3511.0, 1351.0))","List(1, 13, List(), List(0.058372856621670924, 0.1619262603461249, 0.04439428141459744, 0.5812301647272178, 0.20692249811888638, 0.853423626787058, 0.040782543265613244, 1.0996237772761475, 0.2033107599699022, 0.09495861550037622, 0.040782543265613244, 0.52836719337848, 0.2033107599699022))","List(1, 17, List(), List(0.058372856621670924, 0.1619262603461249, 0.04439428141459744, 0.5812301647272178, 0.20692249811888638, 0.853423626787058, 0.040782543265613244, 1.0996237772761475, 0.2033107599699022, 0.09495861550037622, 0.040782543265613244, 0.52836719337848, 0.2033107599699022, 18.0, 0.0, 0.0, 1.0))","List(1, 5, List(), List(3.4937865563854253, -1.5561896892045746, -0.07428795676485955, -2.4077845942726483, 0.5444756838566556))","List(1, 5, List(), List(0.917658050731259, 0.005881716151815837, 0.025887243393982413, 0.0025099270194699316, 0.048063062703473027))",0.0
1,2015-11-21T16:00:00.000+0000,152.455332119677,459.061761737657,82.3002971341767,44.6643915771304,18,899.0,1655.0,276.0,4425.0,1954.0,6250.0,850.0,7886.0,1930.0,1210.0,130.0,4090.0,1930.0,0,"List(1, 4, List(), List(152.455332119677, 459.061761737657, 82.3002971341767, 44.6643915771304))","List(1, 4, List(), List(459.01027566127493, -152.72222027348087, -82.09184081491193, 44.66537930488436))","List(1, 4, List(), List(8.728273707303666, -9.855813660847522, -7.446935467584647, 8.304678437137758))",0,"List(0, 3, List(0), List(1.0))","List(1, 13, List(), List(899.0, 1655.0, 276.0, 4425.0, 1954.0, 6250.0, 850.0, 7886.0, 1930.0, 1210.0, 130.0, 4090.0, 1930.0))","List(1, 13, List(), List(0.16399124407150675, 0.24905944319036868, 0.04153498871331828, 0.6687320538008161, 0.2940556809631302, 0.9405568096313017, 0.12791572610985705, 1.1867569601203913, 0.290443942814146, 0.18209179834462003, 0.019563581640331076, 0.6155003762227238, 0.290443942814146))","List(1, 17, List(), List(0.16399124407150675, 0.24905944319036868, 0.04153498871331828, 0.6687320538008161, 0.2940556809631302, 0.9405568096313017, 0.12791572610985705, 1.1867569601203913, 0.290443942814146, 0.18209179834462003, 0.019563581640331076, 0.6155003762227238, 0.290443942814146, 18.0, 1.0, 0.0, 0.0))","List(1, 5, List(), List(4.229365201948448, -1.2203823939672427, -0.31999030480213503, -2.721475782240328, 0.03248327906125681))","List(1, 5, List(), List(0.9700528742122595, 0.0041686947765657055, 0.010257355651569693, 9.291450611931563E-4, 0.014591930298411935))",0.0
1,2015-11-23T22:00:00.000+0000,175.242096239815,490.958338505211,103.239394721891,35.8038424763724,18,953.0,1709.0,330.0,4479.0,2008.0,6304.0,904.0,7940.0,1984.0,1264.0,184.0,4144.0,1984.0,0,"List(1, 4, List(), List(175.242096239815, 490.958338505211, 103.239394721891, 35.8038424763724))","List(1, 4, List(), List(490.9042025971526, -175.54467190704077, -102.9884032863612, 35.786743732366695))","List(1, 4, List(), List(9.334750160354828, -11.32864341791418, -9.342560546453221, 6.653864886737022))",1,"List(0, 3, List(1), List(1.0))","List(1, 13, List(), List(953.0, 1709.0, 330.0, 4479.0, 2008.0, 6304.0, 904.0, 7940.0, 1984.0, 1264.0, 184.0, 4144.0, 1984.0))","List(1, 13, List(), List(0.1738416636264137, 0.25718585402558314, 0.04966139954853273, 0.676892851745504, 0.3021820917983446, 0.9486832204665162, 0.13604213694507147, 1.1948833709556057, 0.2985703536493604, 0.19021820917983445, 0.027689992475545523, 0.6236267870579383, 0.2985703536493604))","List(1, 17, List(), List(0.1738416636264137, 0.25718585402558314, 0.04966139954853273, 0.676892851745504, 0.3021820917983446, 0.9486832204665162, 0.13604213694507147, 1.1948833709556057, 0.2985703536493604, 0.19021820917983445, 0.027689992475545523, 0.6236267870579383, 0.2985703536493604, 18.0, 0.0, 1.0, 0.0))","List(1, 5, List(), List(3.8409347764991604, -0.20093848791655172, -0.4969619473711785, -2.815220935088477, -0.3278134061229545))","List(1, 5, List(), List(0.9547576518628419, 0.016769877585191283, 0.01247293150569973, 0.0012278956235529067, 0.014771643422714233))",0.0
1,2015-11-30T21:00:00.000+0000,196.902567481234,487.154051936604,112.7574730389,35.5491877878255,18,1120.0,1876.0,497.0,4646.0,2175.0,6471.0,1071.0,8107.0,2151.0,1431.0,351.0,4311.0,2151.0,0,"List(1, 4, List(), List(196.902567481234, 487.154051936604, 112.7574730389, 35.5491877878255))","List(1, 4, List(), List(487.0952645027496, -197.22118288278887, -112.46733876583414, 35.50760354750271))","List(1, 4, List(), List(9.262321598327038, -12.72752075620669, -10.202439191106675, 6.601964074289985))",1,"List(0, 3, List(1), List(1.0))","List(1, 13, List(), List(1120.0, 1876.0, 497.0, 4646.0, 2175.0, 6471.0, 1071.0, 8107.0, 2151.0, 1431.0, 351.0, 4311.0, 2151.0))","List(1, 13, List(), List(0.20430499817584824, 0.2823175319789315, 0.07479307750188112, 0.7021308750188907, 0.327313769751693, 0.9738148984198646, 0.16117381489841986, 1.220015048908954, 0.32370203160270883, 0.21534988713318284, 0.05282167042889391, 0.6487584650112866, 0.32370203160270883))","List(1, 17, List(), List(0.20430499817584824, 0.2823175319789315, 0.07479307750188112, 0.7021308750188907, 0.327313769751693, 0.9738148984198646, 0.16117381489841986, 1.220015048908954, 0.32370203160270883, 0.21534988713318284, 0.05282167042889391, 0.6487584650112866, 0.32370203160270883, 18.0, 0.0, 1.0, 0.0))","List(1, 5, List(), List(3.8087866276159814, -0.16461784118892064, -0.5441435810526635, -2.746615642408365, -0.35340956296603454))","List(1, 5, List(), List(0.9535851346274099, 0.017936257784630136, 0.012271733138546266, 0.0013563895641809315, 0.014850484885232792))",0.0
1,2015-12-07T11:00:00.000+0000,205.231975468309,483.093126108574,107.295244865741,40.7328885052642,18,1278.0,2034.0,655.0,4804.0,2333.0,6629.0,1229.0,8265.0,2309.0,1589.0,509.0,149.0,2309.0,0,"List(1, 4, List(), List(205.231975468309, 483.093126108574, 107.295244865741, 40.7328885052642))","List(1, 4, List(), List(483.0313812533932, -205.54442251744965, -106.99261002112499, 40.685704148206))","List(1, 4, List(), List(9.185045146805756, -13.26465476818544, -9.705800898437595, 7.564733473615072))",1,"List(0, 3, List(1), List(1.0))","List(1, 13, List(), List(1278.0, 2034.0, 655.0, 4804.0, 2333.0, 6629.0, 1229.0, 8265.0, 2309.0, 1589.0, 509.0, 149.0, 2309.0))","List(1, 13, List(), List(0.23312659613279824, 0.3060948081264108, 0.09857035364936043, 0.7260087653014962, 0.3510910458991723, 0.9975921745673438, 0.18495109104589919, 1.2437923250564333, 0.3474793077501881, 0.23912716328066216, 0.07659894657637321, 0.022422874341610232, 0.3474793077501881))","List(1, 17, List(), List(0.23312659613279824, 0.3060948081264108, 0.09857035364936043, 0.7260087653014962, 0.3510910458991723, 0.9975921745673438, 0.18495109104589919, 1.2437923250564333, 0.3474793077501881, 0.23912716328066216, 0.07659894657637321, 0.022422874341610232, 0.3474793077501881, 18.0, 0.0, 1.0, 0.0))","List(1, 5, List(), List(4.677977140767688, 0.3709922413808563, -0.39101342841404224, -4.748894927729077, 0.09093897399457385))","List(1, 5, List(), List(0.9708489176360122, 0.01308132920692881, 0.006105440704312713, 7.81831438137727E-5, 0.009886129308932203))",0.0
1,2015-12-13T03:00:00.000+0000,143.933361613718,377.860973853712,89.6852174468,37.3869524527543,18,1414.0,2170.0,791.0,4940.0,2469.0,6765.0,1365.0,8401.0,2445.0,1725.0,645.0,285.0,2445.0,0,"List(1, 4, List(), List(143.933361613718, 377.860973853712, 89.6852174468, 37.3869524527543))","List(1, 4, List(), List(377.81310026393356, -144.19267731908005, -89.4810023756405, 37.360860263734935))","List(1, 4, List(), List(7.184275220326591, -9.305366019238905, -8.117240929809185, 6.946541940397829))",0,"List(0, 3, List(0), List(1.0))","List(1, 13, List(), List(1414.0, 2170.0, 791.0, 4940.0, 2469.0, 6765.0, 1365.0, 8401.0, 2445.0, 1725.0, 645.0, 285.0, 2445.0))","List(1, 13, List(), List(0.2579350601970084, 0.326561324303988, 0.11903686982693755, 0.7465618860510805, 0.37155756207674945, 1.018058690744921, 0.2054176072234763, 1.2642588412340106, 0.36794582392776526, 0.2595936794582393, 0.09706546275395034, 0.04288939051918736, 0.36794582392776526))","List(1, 17, List(), List(0.2579350601970084, 0.326561324303988, 0.11903686982693755, 0.7465618860510805, 0.37155756207674945, 1.018058690744921, 0.2054176072234763, 1.2642588412340106, 0.36794582392776526, 0.2595936794582393, 0.09706546275395034, 0.04288939051918736, 0.36794582392776526, 18.0, 1.0, 0.0, 0.0))","List(1, 5, List(), List(5.029831827966653, -0.6071287731833087, -0.26772148883590985, -4.577095849077208, 0.42211428312977173))","List(1, 5, List(), List(0.9817306550688994, 0.0034985772560288814, 0.004912405163762315, 6.603233792973401E-5, 0.009792330173379738))",0.0
1,2015-12-15T04:00:00.000+0000,160.518917354301,423.11039056228,97.2897478115814,61.8990673445743,18,1463.0,2219.0,840.0,4989.0,2518.0,6814.0,1414.0,8450.0,2494.0,1774.0,694.0,334.0,2494.0,4,"List(1, 4, List(), List(160.518917354301, 423.11039056228, 97.2897478115814, 61.8990673445743))","List(1, 4, List(), List(423.0496937527122, -160.82042375913198, -97.07273074837585, 61.872028093914174))","List(1, 4, List(), List(8.044468097244797, -10.378425134143805, -8.805922176544268, 11.503927775159884))",2,"List(0, 3, List(2), List(1.0))","List(1, 13, List(), List(1463.0, 2219.0, 840.0, 4989.0, 2518.0, 6814.0, 1414.0, 8450.0, 2494.0, 1774.0, 694.0, 334.0, 2494.0))","List(1, 13, List(), List(0.26687340386720176, 0.3339352896914974, 0.12641083521444696, 0.7539670545564455, 0.37893152746425884, 1.0254326561324305, 0.2127915726109857, 1.27163280662152, 0.37531978931527465, 0.2669676448457487, 0.10443942814145975, 0.05026335590669676, 0.37531978931527465))","List(1, 17, List(), List(0.26687340386720176, 0.3339352896914974, 0.12641083521444696, 0.7539670545564455, 0.37893152746425884, 1.0254326561324305, 0.2127915726109857, 1.27163280662152, 0.37531978931527465, 0.2669676448457487, 0.10443942814145975, 0.05026335590669676, 0.37531978931527465, 18.0, 0.0, 0.0, 1.0))","List(1, 5, List(), List(4.232109176394129, -0.9556678329314443, -0.03681833448170613, -4.213679517264074, 0.9740565082830934))","List(1, 5, List(), List(0.9449480606484107, 0.00527697467023569, 0.013226236265794767, 2.0297778692309536E-4, 0.03634575062863578))",0.0


## Evaluate the Model

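Because the label now has more than two classes, we must use `MulticlassClassificationEvaluator` rather than `BinaryClassificationEvaluator`. The two evaluators support different metrics: the binary evaluator offers `areaUnderROC` and `areaUnderPR`, while the multi-class evaluator offers metrics such as `f1`, `accuracy`, `weightedPrecision`, and `weightedRecall`, as the parameter listing below shows.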

In [32]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator()
print(evaluator.explainParams())
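Both weighted metrics are computed per class and then averaged, with each class weighted by its share of the true labels. The plain-Python sketch below illustrates that definition on a handful of made-up (label, prediction) pairs. `weighted_precision_recall` is a hypothetical helper, not a PySpark function, and it assumes the convention that a class which is never predicted has precision 0.

```python
from collections import Counter


def weighted_precision_recall(pairs):
    """Return (weighted precision, weighted recall) for (label, prediction) pairs.

    Each class's precision and recall are weighted by the fraction of rows
    whose *true* label is that class.
    """
    n = len(pairs)
    true_counts = Counter(label for label, _ in pairs)
    pred_counts = Counter(pred for _, pred in pairs)
    true_pos = Counter(label for label, pred in pairs if label == pred)

    w_precision = 0.0
    w_recall = 0.0
    for cls, n_true in true_counts.items():
        weight = n_true / n
        # Assumed convention: precision is 0 for a class that is never predicted.
        precision = true_pos[cls] / pred_counts[cls] if pred_counts[cls] else 0.0
        recall = true_pos[cls] / n_true
        w_precision += weight * precision
        w_recall += weight * recall
    return w_precision, w_recall


# Made-up labels: imagine 0 = no failure and 1-4 = one of four components.
pairs = [(0, 0), (0, 0), (0, 1), (1, 1), (2, 2), (2, 0), (3, 3), (4, 4)]
print(weighted_precision_recall(pairs))
```

Because each class's recall is weighted by that class's share of the true labels, the weights cancel the recall denominators, so weighted recall always equals overall accuracy (6 of 8 rows correct, or 0.75, in the toy data).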

In [33]:
def printEval(df, labelCol = "label"):
  """Print the weighted recall and weighted precision of the predictions in df."""
  evaluator = MulticlassClassificationEvaluator()
  evaluator.setLabelCol(labelCol)
  
  wrecall = evaluator.setMetricName("weightedRecall").evaluate(df)
  wprecis = evaluator.setMetricName("weightedPrecision").evaluate(df)
  print("weighted recall: {}\nweighted precision: {}".format(wrecall, wprecis))

In [34]:
printEval(df_pred)

## Train a Random Forest Model

Let's now compare this to what we get if we use a random forest with cross-validation instead.

In [36]:
from pyspark.ml.classification import RandomForestClassifier

rf = (RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("final_features")
      .setSeed(27))

# print(rf.explainParams())

We set up a single-stage pipeline for the model.

In [38]:
from pyspark.ml import Pipeline

model_pipeline = Pipeline(stages = [rf])

model_pipeline.getStages()

# model_pipeline.getStages()[0].extractParamMap()

We perform a grid search over `maxDepth` and `numTrees` for the random forest, using 3-fold cross-validation to estimate how well each combination performs.
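A grid search evaluates every combination of the candidate values. The grid built in the next cell has 3 `maxDepth` values and 2 `numTrees` values, so it contains 3 × 2 = 6 combinations. With 3 folds, `CrossValidator` fits 6 × 3 = 18 models and then refits the best combination once more on all of the training data. The plain-Python sketch below only reproduces this bookkeeping with `itertools.product`; it does not call Spark, and Spark may enumerate the combinations in a different order.

```python
from itertools import product

# Candidate values copied from the rf.maxDepth / rf.numTrees grid in this section.
grid = {"maxDepth": [5, 10, 20], "numTrees": [20, 50]}
num_folds = 3

# Cartesian product of the value lists: one dict per parameter combination.
param_maps = [dict(zip(grid, values)) for values in product(*grid.values())]
for pm in param_maps:
    print(pm)

# One fit per (combination, fold), plus a final refit of the best combination
# on the full training set.
total_fits = len(param_maps) * num_folds + 1
print(len(param_maps), "combinations,", total_fits, "model fits")
```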

In [40]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 3 values of maxDepth x 2 values of numTrees = 6 parameter combinations
paramGrid = (ParamGridBuilder()
             .addGrid(rf.maxDepth, [5, 10, 20])
             .addGrid(rf.numTrees, [20, 50])
             .build())


evaluator = MulticlassClassificationEvaluator().setMetricName("weightedPrecision")

cv = (CrossValidator()
      .setEstimator(model_pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
      .setSeed(27))

cv_model = cv.fit(df_kmeans)
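Conceptually, `CrossValidator` splits the training data into `numFolds` disjoint folds. For each parameter combination it trains on all but one fold, scores the held-out fold with the evaluator, and averages the scores; these averages are what `avgMetrics` holds. The combination with the best average wins, where "best" means largest for a metric like `weightedPrecision`. The sketch below reproduces that loop in plain Python under simplifying assumptions: `cross_validate` and `threshold_scorer` are hypothetical helpers, the scorer stands in for fitting the pipeline and running the evaluator, and rows are assigned to folds round-robin, whereas Spark assigns them randomly using the seed.

```python
from statistics import mean


def k_fold_indices(n_rows, k):
    """Assign row i to fold i % k (a deterministic stand-in for a random split)."""
    return [[i for i in range(n_rows) if i % k == fold] for fold in range(k)]


def cross_validate(rows, param_maps, k, train_and_score, larger_is_better=True):
    """Return (best_param_map, avg_metrics) from k-fold cross-validation."""
    folds = k_fold_indices(len(rows), k)
    avg_metrics = []
    for params in param_maps:
        scores = []
        for held_out in folds:
            held_out_set = set(held_out)
            train = [r for i, r in enumerate(rows) if i not in held_out_set]
            valid = [rows[i] for i in held_out]
            scores.append(train_and_score(params, train, valid))
        # Average the held-out scores for this parameter combination.
        avg_metrics.append(mean(scores))
    pick = max if larger_is_better else min
    best_index = pick(range(len(param_maps)), key=lambda i: avg_metrics[i])
    return param_maps[best_index], avg_metrics


def threshold_scorer(params, train, valid):
    """Hypothetical stand-in for fitting the pipeline and evaluating it.

    The 'model' is a fixed rule (predict 1 when x >= threshold), so for brevity
    it ignores the training split and returns accuracy on the held-out fold.
    """
    t = params["threshold"]
    correct = sum(int(x >= t) == y for x, y in valid)
    return correct / len(valid)


# Toy data: x in 0..29, and the label is 1 exactly when x >= 15.
rows = [(x, int(x >= 15)) for x in range(30)]
candidates = [{"threshold": t} for t in (5, 15, 25)]

best, avg_metrics = cross_validate(rows, candidates, 3, threshold_scorer)
print(best, avg_metrics)
```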

We can see the result of the grid search and the average (cross-validated) evaluation metric here:

In [42]:
# Pair each parameter combination's values with its average cross-validated metric
[(list(params.values()), metric)
 for params, metric in zip(cv_model.getEstimatorParamMaps(), cv_model.avgMetrics)]

## Evaluate the Random Forest Model

Finally, we transform the test set with the same featurizer and evaluate the best model selected by cross-validation on it.

In [44]:
df_pred = cv_model.transform(featurizer.transform(df_test))

printEval(df_pred)
