In [1]:
filePath = "/FileStore/tables/Group/" #put your own file path if necessary

#Importing the files one by one
Complaints = spark.read\
  .format("csv")\
  .option("inferSchema","true")\
  .option("header","true")\
  .option("delimiter",",")\
  .option("0","NA")\
  .load(filePath + "BDT2_1920_Complaints.csv")\

Delivery=spark.read\
  .format("csv")\
  .option("header","true")\
  .option("inferSchema","true")\
  .option("delimiter",",")\
  .load(filePath + "BDT2_1920_Delivery.csv")

Subscriptions=spark.read\
  .format("csv")\
  .option("header","true")\
  .option("inferSchema","true")\
  .option("delimiter",",")\
  .load(filePath + "BDT2_1920_Subscriptions.csv")

Customers=spark.read\
  .format("csv")\
  .option("header","true")\
  .option("inferSchema","true")\
  .option("delimiter",",")\
  .load(filePath + "BDT2_1920_Customers.csv")

Formula=spark.read\
  .format("csv")\
  .option("header","true")\
  .option("inferSchema","true")\
  .option("delimiter",",")\
  .load(filePath + "BDT2_1920_Formula.csv")

In [2]:
from pyspark.sql.functions import *

#Replacing NA in Complaints
#replacing NA with meaningfull value when possible.
#unknown ID will take the value 0
#unknown numeric values like quantities will take the value 999
#NA values in string type column will take a "NA" value or a "no response"/"no solution" depending on the context

#Replacing NA in Complaints
Complaints = Complaints.withColumn("ProductID", when(Complaints["ProductID"] == "NA", 0).otherwise(Complaints["ProductID"]))\
  .withColumn("ProductName", when(Complaints["ProductName"] == "NA", "NA").otherwise(Complaints["ProductName"]))\
  .withColumn("FeedbackTypeID", when(Complaints["FeedbackTypeID"] == "NA", 0).otherwise(Complaints["FeedbackTypeID"]))\
  .withColumn("FeedbackTypeDesc", when(Complaints["FeedbackTypeDesc"] == "NA", "no response").otherwise(Complaints["FeedbackTypeDesc"]))\
  .withColumn("SolutionTypeID", when(Complaints["SolutionTypeID"] == "NA", 0).otherwise(Complaints["SolutionTypeID"]))\
  .withColumn("SolutionTypeDesc", when(Complaints["SolutionTypeDesc"] == "NA", "no solution").otherwise(Complaints["SolutionTypeDesc"]))

#Replacing NA in Delivery
Delivery = Delivery.na.fill("NA", "DeliveryClass")

#Replacing NA in Subscriptions
#NbrMeals_EXCEP NAs have been replaced by the mean NbrMeals_EXCEP ordered by the same NbrMeals_REG category
Subscriptions = Subscriptions.withColumn("NbrMeals_EXCEP",when((Subscriptions["NbrMeals_EXCEP"] == "NA") & (Subscriptions["NbrMeals_REG"]==76), 12).otherwise(Subscriptions["NbrMeals_EXCEP"]))
Subscriptions = Subscriptions.withColumn("NbrMeals_EXCEP",when((Subscriptions["NbrMeals_EXCEP"] == "NA") & (Subscriptions["NbrMeals_REG"]==304), 13).otherwise(Subscriptions["NbrMeals_EXCEP"]))
Subscriptions = Subscriptions.withColumn("NbrMeals_EXCEP",when((Subscriptions["NbrMeals_EXCEP"] == "NA") & (Subscriptions["NbrMeals_REG"]==329), 17).otherwise(Subscriptions["NbrMeals_EXCEP"]))
Subscriptions = Subscriptions.withColumn("NbrMeals_EXCEP",when((Subscriptions["NbrMeals_EXCEP"] == "NA") & (Subscriptions["NbrMeals_REG"]==152), 13).otherwise(Subscriptions["NbrMeals_EXCEP"]))

#RenewalDate 1 and 0 (so if a client renewed 6 times his subscription, the value can be summed to 6)
Subscriptions = Subscriptions.withColumn("RenewalDate",when(Subscriptions["RenewalDate"] == "NA",0).otherwise(1))

#PaymentDate Redondant with PaymentStatus
  #GrossFormulaPrice
  #NetFormulaPrice
  #NbrMealsPrice
  #ProductDiscount
  #FormulaDiscount
  #TotalDiscount
  #TotalPrice
  #TotalCredit
  #All of those are codependent. Maybe after grouping the NbrMeals_REG/EXCEP features, we can replace NA's by the mean of 
  # the category they belong to


In [3]:
#Complaints
Complaints = Complaints.withColumn("ProductID", Complaints["ProductID"].cast("integer"))\
  .withColumn("SolutionTypeID", Complaints["SolutionTypeID"].cast("integer"))\
  .withColumn("FeedbackTypeID", Complaints["FeedbackTypeID"].cast("integer"))

#Subscriptions
Subscriptions = Subscriptions.withColumn("NbrMeals_EXCEP",Subscriptions["NbrMeals_EXCEP"].cast("integer"))
Subscriptions = Subscriptions.withColumn("RenewalDate",Subscriptions["RenewalDate"].cast("integer"))
  #converting timestamps to number of days
Subscriptions = Subscriptions.withColumn("EndDate",Subscriptions["EndDate"].cast("long")/86400)
Subscriptions = Subscriptions.withColumn("StartDate",Subscriptions["StartDate"].cast("long")/86400)
Complaints = Complaints.withColumn("ComplaintDate",Complaints["ComplaintDate"].cast("long")/86400)

In [4]:
#Subscriptions
Subscriptions = Subscriptions.withColumn("SubscriptionDuration", Subscriptions.EndDate - Subscriptions.StartDate)
Subscriptions = Subscriptions.withColumn("NbrMealsPerDay", when(Subscriptions["SubscriptionDuration"] == 0, Subscriptions.NbrMeals_REG).otherwise(Subscriptions.NbrMeals_REG / Subscriptions.SubscriptionDuration))

In [5]:
display(Subscriptions)

SubscriptionID,CustomerID,StartDate,EndDate,NbrMeals_REG,NbrMeals_EXCEP,RenewalDate,PaymentType,PaymentStatus,PaymentDate,FormulaID,GrossFormulaPrice,NetFormulaPrice,NbrMealsPrice,ProductDiscount,FormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,ProductName,SubscriptionDuration,NbrMealsPerDay
627529,775138,17135.0,17225.0,76,10,1,BT,Paid,2016-12-01,919,1480.0,1480.0,19.47368,0.0,0.0,0.0,1480.0,0.0,Custom Events,90.0,0.8444444444444444
637001,194809,17039.0,17227.0,152,25,1,BT,Paid,2016-08-22,4192,2760.0,1760.0,11.57894,0.0,1000.0,1000.0,1760.0,0.0,Custom Events,188.0,0.8085106382978723
1238870,654824,17860.0,17950.0,76,10,1,BT,Paid,2018-11-09,10961,1580.0,1580.0,20.78948,0.0,0.0,0.0,1580.0,0.0,Custom Events,90.0,0.8444444444444444
315743,626815,16801.0,17165.0,304,10,1,BT,Paid,2015-12-18,896,4980.0,4980.0,16.38158,0.0,0.0,0.0,4980.0,0.0,Custom Events,364.0,0.8351648351648352
1176762,1016426,17760.0,17788.0,25,25,0,BT,Paid,2018-08-13,12867,540.0,300.0,12.0,0.0,240.0,240.0,300.0,0.0,Custom Events,28.0,0.8928571428571429
916472,871676,17491.0,17581.0,76,10,1,BT,Paid,2017-11-20,5100,1540.0,1540.0,20.26316,0.0,0.0,0.0,1540.0,0.0,Custom Events,90.0,0.8444444444444444
646275,655981,17099.0,17130.0,25,10,1,DD,Paid,2016-10-15,924,520.0,520.0,20.8,0.0,0.0,0.0,520.0,0.0,Custom Events,31.0,0.8064516129032258
752611,704300,17230.0,17258.0,25,25,1,BT,Paid,2017-03-24,5389,458.0,458.0,18.32,0.0,0.0,0.0,458.0,0.0,Custom Events,28.0,0.8928571428571429
1079202,684448,17652.0,17684.0,25,25,1,DD,Paid,2018-04-21,9466,472.0,472.0,18.88,0.0,0.0,0.0,472.0,0.0,Custom Events,32.0,0.78125
669473,276941,17122.0,17487.0,304,25,1,BT,Paid,2016-11-24,891,4980.0,4980.0,16.38158,0.0,0.0,0.0,4980.0,0.0,Custom Events,365.0,0.8328767123287671


In [6]:
Subscriptions.createOrReplaceTempView("subscriptions")

In [7]:
SubInter = spark.sql("select CustomerID, sum(NbrMeals_REG) as TotalMeal_REG, avg(NbrMeals_REG) as MeanMeal_REGPerSub, sum(NbrMeals_EXCEP) as TotalMeal_EXCEP, avg(NbrMeals_EXCEP) as MeanMeal_EXCEPPerSub, min(StartDate) as FirstSubDate, max(EndDate) as EndOfLastSub, (max(EndDate)-min(StartDate)) as HasBeenClientForXDays,count(SubscriptionID) as NbrSub, SUM(CASE WHEN PaymentStatus='Paid' THEN 1 ELSE 0 END) as SubPaid, SUM(CASE WHEN PaymentStatus='Not Paid' THEN 1 ELSE 0 END) as SubNotPaid, SUM(CASE WHEN PaymentStatus='Paid' THEN 1 ELSE 0 END)/count(SubscriptionID) as ProportionPaidSub,avg(NbrMealsPrice) as AvgPricePerMeal, sum(ProductDiscount) as TotalProductDiscount, sum(FormulaDiscount) as TotalFormulaDiscount, sum(TotalDiscount) as TotalDiscount, sum(TotalPrice) as TotalPrice, sum(TotalCredit) as TotalCredit,sum(SubscriptionDuration) as NbrDaysSub, avg(SubscriptionDuration) as AvgDurationPerSub, avg(NbrMealsPerDay) as AverageNbrMealPerDay, SUM(CASE WHEN ProductName='Custom Events' THEN 1 ELSE 0 END) as NbrCustomEventsProduct, SUM(CASE WHEN ProductName!='Custom Events' THEN 1 ELSE 0 END) as NbrGrubProduct from subscriptions group by CustomerID")
SubInter = SubInter.withColumn("FirstSubDate", SubInter.FirstSubDate*86400)
SubInter = SubInter.withColumn("FirstSubDate", SubInter.FirstSubDate.cast("timestamp"))
SubInter = SubInter.withColumn("EndOfLastSub", SubInter.EndOfLastSub*86400)
SubInter = SubInter.withColumn("EndOfLastSub", SubInter.EndOfLastSub.cast("timestamp"))
#SubInter = SubInter.withColumn("FirstSubDate", SubInter.select((unix_timestamp("FirstSubDate","yyy/MM/dd HH:mm:ss")).cast("timestamp")))
#df.select((unix_timestamp($"Date", "MM/dd/yyyy HH:mm:ss") * 1000).cast("timestamp"), $"Date")

#creating the churn dependent variable
SubInter = SubInter.withColumn("ChurnedAt03/02/2019", when(col("EndOfLastSub") > "2019-02-02 00:00:00", 0).otherwise(1))
SubInter = SubInter.withColumn("ChurnedAt03/08/2018", when(col("EndOfLastSub") > "2018-08-02 00:00:00", 0).otherwise(1))
SubInter = SubInter.withColumn("ChurnedAt03/02/2018", when(col("EndOfLastSub") > "2018-02-02 00:00:00", 0).otherwise(1))

display(SubInter)

CustomerID,TotalMeal_REG,MeanMeal_REGPerSub,TotalMeal_EXCEP,MeanMeal_EXCEPPerSub,FirstSubDate,EndOfLastSub,HasBeenClientForXDays,NbrSub,SubPaid,SubNotPaid,ProportionPaidSub,AvgPricePerMeal,TotalProductDiscount,TotalFormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,NbrDaysSub,AvgDurationPerSub,AverageNbrMealPerDay,NbrCustomEventsProduct,NbrGrubProduct,ChurnedAt03/02/2019,ChurnedAt03/08/2018,ChurnedAt03/02/2018
258487,785,196.25,55,13.75,2016-05-02T00:00:00.000+0000,2018-08-31T00:00:00.000+0000,851.0,4,3,1,0.75,17.216646666666666,0.0,409.6,409.6,11640.0,0.0,849.0,212.25,19.628179522497707,4,0,1,0,0
671995,1254,25.08,677,13.54,2015-01-02T00:00:00.000+0000,2019-02-14T00:00:00.000+0000,1504.0,50,50,0,1.0,20.5868036,0.0,0.0,0.0,25800.0,0.0,1452.0,29.04,0.8644984727298408,0,50,0,0,0
285977,1519,303.8,70,14.0,2014-01-04T00:00:00.000+0000,2019-02-07T00:00:00.000+0000,1860.0,5,5,0,1.0,15.705696,0.0,0.0,0.0,23858.2,-721.8000000000001,1856.0,371.2,0.818706760989113,0,5,0,0,0
682942,1519,303.8,85,17.0,2014-02-01T00:00:00.000+0000,2019-01-31T00:00:00.000+0000,1825.0,5,5,0,1.0,16.180996,0.0,0.0,0.0,24580.0,0.0,1821.0,364.2,0.834157760048171,0,5,1,0,0
104880,1216,76.0,235,14.6875,2015-01-02T00:00:00.000+0000,2018-12-30T00:00:00.000+0000,1458.0,16,16,0,1.0,19.6064425,0.0,0.0,0.0,23840.0,0.0,1441.0,90.0625,0.8439920836606536,0,16,1,0,0
965578,1216,304.0,55,13.75,2014-12-19T00:00:00.000+0000,2018-12-17T00:00:00.000+0000,1459.0,4,4,0,1.0,15.62171,0.0,386.6,386.6,18996.0,0.0,1456.0,364.0,0.8351648351648352,0,4,1,0,0
829912,684,114.0,65,10.833333333333334,2016-03-19T00:00:00.000+0000,2017-12-31T00:00:00.000+0000,652.0,6,5,1,0.8333333333333334,14.713593333333334,0.0,787.11578,787.11578,9108.8,0.0,647.0,107.83333333333331,2.812704198909744,0,6,1,1,1
673836,1216,304.0,55,13.75,2015-01-02T00:00:00.000+0000,2019-01-03T00:00:00.000+0000,1462.0,4,4,0,1.0,16.41447,0.0,0.0,0.0,19960.0,-40.0,1459.0,364.75,0.8334705478974254,3,1,1,0,0
659301,1218,304.5,60,15.0,2014-06-23T00:00:00.000+0000,2018-06-24T00:00:00.000+0000,1462.0,4,4,0,1.0,15.961825,0.0,0.0,0.0,19440.0,0.0,1458.0,364.5,0.8353925184404636,0,4,1,1,0
1012153,304,304.0,25,25.0,2017-03-02T00:00:00.000+0000,2018-03-01T00:00:00.000+0000,364.0,1,1,0,1.0,0.0,5180.0,0.0,5180.0,0.0,0.0,364.0,364.0,0.8351648351648352,0,1,1,1,0


In [8]:
Complaints.createOrReplaceTempView("complaints")

In [9]:
Intermediary = spark.sql("select CustomerID, count(ComplaintID) as NbrComplaints, max(ComplaintDate) as LastComplaint, min(ComplaintDate) as FirstComplaint, (CASE WHEN count(ComplaintID)>1 THEN (count(ComplaintID)/(max(ComplaintDate)-min(ComplaintDate))) ELSE 0 END) as ComplaintsPerMonth, SUM(CASE WHEN ProductID=1 THEN 1 ELSE 0 END) as NbrComplaintsProduct1, SUM(CASE WHEN ProductID=2 THEN 1 ELSE 0 END) as NbrComplaintsProduct2, SUM(CASE WHEN ProductID=3 THEN 1 ELSE 0 END) as NbrComplaintsProduct3, SUM(CASE WHEN ProductID=4 THEN 1 ELSE 0 END) as NbrComplaintsProduct4, SUM(CASE WHEN ProductID=5 THEN 1 ELSE 0 END) as NbrComplaintsProduct5, SUM(CASE WHEN ProductID=6 THEN 1 ELSE 0 END) as NbrComplaintsProduct6, SUM(CASE WHEN ProductID=7 THEN 1 ELSE 0 END) as NbrComplaintsProduct7 , SUM(CASE WHEN ProductID=8 THEN 1 ELSE 0 END) as NbrComplaintsProduct8, SUM(CASE WHEN ProductID=0 THEN 1 ELSE 0 END) as NbrComplaintsProductUnknown,SUM(CASE WHEN ComplaintTypeID=1 THEN 1 ELSE 0 END) as NbrComplaintsType1, SUM(CASE WHEN ComplaintTypeID=2 THEN 1 ELSE 0 END) as NbrComplaintsType2, SUM(CASE WHEN ComplaintTypeID=3 THEN 1 ELSE 0 END) as NbrComplaintsType3, SUM(CASE WHEN ComplaintTypeID=4 THEN 1 ELSE 0 END) as NbrComplaintsType4, SUM(CASE WHEN ComplaintTypeID=5 THEN 1 ELSE 0 END) as NbrComplaintsType5, SUM(CASE WHEN ComplaintTypeID=6 THEN 1 ELSE 0 END) as NbrComplaintsType6, SUM(CASE WHEN ComplaintTypeID=7 THEN 1 ELSE 0 END) as NbrComplaintsType7 , SUM(CASE WHEN ComplaintTypeID=8 THEN 1 ELSE 0 END) as NbrComplaintsType8, SUM(CASE WHEN ComplaintTypeID=9 THEN 1 ELSE 0 END) as NbrComplaintsType9, SUM(CASE WHEN ComplaintTypeID=0 THEN 1 ELSE 0 END) as NbrComplaintsTypeUnknown, SUM(CASE WHEN SolutionTypeID=1 THEN 1 ELSE 0 END) as NbrSolutionsType1, SUM(CASE WHEN SolutionTypeID=2 THEN 1 ELSE 0 END) as NbrSolutionsType2, SUM(CASE WHEN SolutionTypeID=3 THEN 1 ELSE 0 END) as NbrSolutionsType3, SUM(CASE WHEN SolutionTypeID=4 THEN 1 ELSE 0 END) as NbrSolutionsType4, SUM(CASE WHEN SolutionTypeID=0 THEN 1 ELSE 0 END) as NbrSolutionsTypeUnknown from complaints group by CustomerID")
Intermediary = Intermediary.withColumn("FirstComplaint", Intermediary.FirstComplaint*86400)
Intermediary = Intermediary.withColumn("FirstComplaint", Intermediary.FirstComplaint.cast("timestamp"))
Intermediary = Intermediary.withColumn("LastComplaint", Intermediary.LastComplaint*86400)
Intermediary = Intermediary.withColumn("LastComplaint", Intermediary.LastComplaint.cast("timestamp"))

display(Intermediary)

CustomerID,NbrComplaints,LastComplaint,FirstComplaint,ComplaintsPerMonth,NbrComplaintsProduct1,NbrComplaintsProduct2,NbrComplaintsProduct3,NbrComplaintsProduct4,NbrComplaintsProduct5,NbrComplaintsProduct6,NbrComplaintsProduct7,NbrComplaintsProduct8,NbrComplaintsProductUnknown,NbrComplaintsType1,NbrComplaintsType2,NbrComplaintsType3,NbrComplaintsType4,NbrComplaintsType5,NbrComplaintsType6,NbrComplaintsType7,NbrComplaintsType8,NbrComplaintsType9,NbrComplaintsTypeUnknown,NbrSolutionsType1,NbrSolutionsType2,NbrSolutionsType3,NbrSolutionsType4,NbrSolutionsTypeUnknown
285977,7,2018-11-27T00:00:00.000+0000,2014-09-18T00:00:00.000+0000,0.0045721750489875,0,0,0,0,0,7,0,0,0,4,1,1,0,0,0,0,0,1,0,1,0,1,0,5
671995,2,2014-10-14T00:00:00.000+0000,2014-03-25T00:00:00.000+0000,0.0098522167487684,0,0,0,2,0,0,0,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,2
466728,1,2016-12-15T00:00:00.000+0000,2016-12-15T00:00:00.000+0000,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0
1012153,3,2018-08-07T00:00:00.000+0000,2018-08-02T00:00:00.000+0000,0.6,0,3,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,0,1,1,1,0,0
673836,29,2018-12-24T00:00:00.000+0000,2013-11-14T00:00:00.000+0000,0.0155412647374062,0,0,0,21,0,0,0,0,8,18,2,0,0,7,0,1,0,1,0,2,2,4,0,21
865501,1,2016-02-04T00:00:00.000+0000,2016-02-04T00:00:00.000+0000,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0
802664,1,2013-06-21T00:00:00.000+0000,2013-06-21T00:00:00.000+0000,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0
204587,6,2017-02-19T00:00:00.000+0000,2015-12-14T00:00:00.000+0000,0.0138568129330254,0,0,0,0,0,0,0,6,0,6,0,0,0,0,0,0,0,0,0,6,0,0,0,0
462878,2,2017-03-02T00:00:00.000+0000,2012-04-08T00:00:00.000+0000,0.0011179429849077,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1
67278,2,2018-11-08T00:00:00.000+0000,2018-07-19T00:00:00.000+0000,0.0178571428571428,0,0,0,0,0,0,0,2,0,2,0,0,0,0,0,0,0,0,0,0,2,0,0,0


In [10]:
display(SubInter)

CustomerID,TotalMeal_REG,MeanMeal_REGPerSub,TotalMeal_EXCEP,MeanMeal_EXCEPPerSub,FirstSubDate,EndOfLastSub,HasBeenClientForXDays,NbrSub,SubPaid,SubNotPaid,ProportionPaidSub,AvgPricePerMeal,TotalProductDiscount,TotalFormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,NbrDaysSub,AvgDurationPerSub,AverageNbrMealPerDay,NbrCustomEventsProduct,NbrGrubProduct,ChurnedAt03/02/2019,ChurnedAt03/08/2018,ChurnedAt03/02/2018
258487,785,196.25,55,13.75,2016-05-02T00:00:00.000+0000,2018-08-31T00:00:00.000+0000,851.0,4,3,1,0.75,17.216646666666666,0.0,409.6,409.6,11640.0,0.0,849.0,212.25,19.628179522497707,4,0,1,0,0
671995,1254,25.08,677,13.54,2015-01-02T00:00:00.000+0000,2019-02-14T00:00:00.000+0000,1504.0,50,50,0,1.0,20.5868036,0.0,0.0,0.0,25800.0,0.0,1452.0,29.04,0.8644984727298408,0,50,0,0,0
285977,1519,303.8,70,14.0,2014-01-04T00:00:00.000+0000,2019-02-07T00:00:00.000+0000,1860.0,5,5,0,1.0,15.705696,0.0,0.0,0.0,23858.2,-721.8000000000001,1856.0,371.2,0.818706760989113,0,5,0,0,0
682942,1519,303.8,85,17.0,2014-02-01T00:00:00.000+0000,2019-01-31T00:00:00.000+0000,1825.0,5,5,0,1.0,16.180996,0.0,0.0,0.0,24580.0,0.0,1821.0,364.2,0.834157760048171,0,5,1,0,0
104880,1216,76.0,235,14.6875,2015-01-02T00:00:00.000+0000,2018-12-30T00:00:00.000+0000,1458.0,16,16,0,1.0,19.6064425,0.0,0.0,0.0,23840.0,0.0,1441.0,90.0625,0.8439920836606536,0,16,1,0,0
965578,1216,304.0,55,13.75,2014-12-19T00:00:00.000+0000,2018-12-17T00:00:00.000+0000,1459.0,4,4,0,1.0,15.62171,0.0,386.6,386.6,18996.0,0.0,1456.0,364.0,0.8351648351648352,0,4,1,0,0
829912,684,114.0,65,10.833333333333334,2016-03-19T00:00:00.000+0000,2017-12-31T00:00:00.000+0000,652.0,6,5,1,0.8333333333333334,14.713593333333334,0.0,787.11578,787.11578,9108.8,0.0,647.0,107.83333333333331,2.812704198909744,0,6,1,1,1
673836,1216,304.0,55,13.75,2015-01-02T00:00:00.000+0000,2019-01-03T00:00:00.000+0000,1462.0,4,4,0,1.0,16.41447,0.0,0.0,0.0,19960.0,-40.0,1459.0,364.75,0.8334705478974254,3,1,1,0,0
659301,1218,304.5,60,15.0,2014-06-23T00:00:00.000+0000,2018-06-24T00:00:00.000+0000,1462.0,4,4,0,1.0,15.961825,0.0,0.0,0.0,19440.0,0.0,1458.0,364.5,0.8353925184404636,0,4,1,1,0
1012153,304,304.0,25,25.0,2017-03-02T00:00:00.000+0000,2018-03-01T00:00:00.000+0000,364.0,1,1,0,1.0,0.0,5180.0,0.0,5180.0,0.0,0.0,364.0,364.0,0.8351648351648352,0,1,1,1,0


In [11]:
#Base Table
#base = Customers.join(Complaints,on=['CustomerID'],how='full')
#base = Customers.join(Subscriptions,on=['CustomerID'],how='full')
base = Customers.join(Intermediary,on=['CustomerID'],how='full')
base1 = base.join(SubInter,on=['CustomerID'],how='full')
#Replacing null in Complaints
base1 = base1.na.fill(0)

In [12]:
display(base1)

CustomerID,Region,StreetID,NbrComplaints,LastComplaint,FirstComplaint,ComplaintsPerMonth,NbrComplaintsProduct1,NbrComplaintsProduct2,NbrComplaintsProduct3,NbrComplaintsProduct4,NbrComplaintsProduct5,NbrComplaintsProduct6,NbrComplaintsProduct7,NbrComplaintsProduct8,NbrComplaintsProductUnknown,NbrComplaintsType1,NbrComplaintsType2,NbrComplaintsType3,NbrComplaintsType4,NbrComplaintsType5,NbrComplaintsType6,NbrComplaintsType7,NbrComplaintsType8,NbrComplaintsType9,NbrComplaintsTypeUnknown,NbrSolutionsType1,NbrSolutionsType2,NbrSolutionsType3,NbrSolutionsType4,NbrSolutionsTypeUnknown,TotalMeal_REG,MeanMeal_REGPerSub,TotalMeal_EXCEP,MeanMeal_EXCEPPerSub,FirstSubDate,EndOfLastSub,HasBeenClientForXDays,NbrSub,SubPaid,SubNotPaid,ProportionPaidSub,AvgPricePerMeal,TotalProductDiscount,TotalFormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,NbrDaysSub,AvgDurationPerSub,AverageNbrMealPerDay,NbrCustomEventsProduct,NbrGrubProduct,ChurnedAt03/02/2019,ChurnedAt03/08/2018,ChurnedAt03/02/2018
104880,5,45805,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1216,76.0,235,14.6875,2015-01-02T00:00:00.000+0000,2018-12-30T00:00:00.000+0000,1458.0,16,16,0,1.0,19.6064425,0.0,0.0,0.0,23840.0,0.0,1441.0,90.0625,0.8439920836606536,0,16,1,0,0
258487,1,14628,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,785,196.25,55,13.75,2016-05-02T00:00:00.000+0000,2018-08-31T00:00:00.000+0000,851.0,4,3,1,0.75,17.216646666666666,0.0,409.6,409.6,11640.0,0.0,849.0,212.25,19.628179522497707,4,0,1,0,0
285977,1,18415,7,2018-11-27T00:00:00.000+0000,2014-09-18T00:00:00.000+0000,0.0045721750489875,0,0,0,0,0,7,0,0,0,4,1,1,0,0,0,0,0,1,0,1,0,1,0,5,1519,303.8,70,14.0,2014-01-04T00:00:00.000+0000,2019-02-07T00:00:00.000+0000,1860.0,5,5,0,1.0,15.705696,0.0,0.0,0.0,23858.2,-721.8000000000001,1856.0,371.2,0.818706760989113,0,5,0,0,0
671995,1,28929,2,2014-10-14T00:00:00.000+0000,2014-03-25T00:00:00.000+0000,0.0098522167487684,0,0,0,2,0,0,0,0,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,2,1254,25.08,677,13.54,2015-01-02T00:00:00.000+0000,2019-02-14T00:00:00.000+0000,1504.0,50,50,0,1.0,20.5868036,0.0,0.0,0.0,25800.0,0.0,1452.0,29.04,0.8644984727298408,0,50,0,0,0
682942,1,18048,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1519,303.8,85,17.0,2014-02-01T00:00:00.000+0000,2019-01-31T00:00:00.000+0000,1825.0,5,5,0,1.0,16.180996,0.0,0.0,0.0,24580.0,0.0,1821.0,364.2,0.834157760048171,0,5,1,0,0
829912,5,40317,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,684,114.0,65,10.833333333333334,2016-03-19T00:00:00.000+0000,2017-12-31T00:00:00.000+0000,652.0,6,5,1,0.8333333333333334,14.713593333333334,0.0,787.11578,787.11578,9108.8,0.0,647.0,107.83333333333331,2.812704198909744,0,6,1,1,1
965578,5,45860,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1216,304.0,55,13.75,2014-12-19T00:00:00.000+0000,2018-12-17T00:00:00.000+0000,1459.0,4,4,0,1.0,15.62171,0.0,386.6,386.6,18996.0,0.0,1456.0,364.0,0.8351648351648352,0,4,1,0,0
75070,5,43993,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,912,304.0,45,15.0,2015-01-02T00:00:00.000+0000,2017-06-29T00:00:00.000+0000,909.0,3,3,0,1.0,16.29386,0.0,0.0,0.0,14860.0,0.0,907.0,302.3333333333333,1.1205064295973386,0,3,1,1,1
107896,5,46836,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1216,304.0,55,13.75,2015-01-02T00:00:00.000+0000,2018-12-30T00:00:00.000+0000,1458.0,4,4,0,1.0,16.61184,0.0,0.0,0.0,20200.0,0.0,1455.0,363.75,0.8357400175581995,0,4,1,0,0
158050,5,41138,0,,,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1216,304.0,55,13.75,2015-01-02T00:00:00.000+0000,2018-01-20T00:00:00.000+0000,1114.0,4,3,1,0.75,16.575655,0.0,0.0,0.0,14816.0,-44.0,1111.0,277.75,4.426948808766991,0,4,1,1,1


In [13]:
#Create categorical variables for Region and StreetID 
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer
from pyspark.ml import Pipeline

#Region
genderIndxr = StringIndexer().setInputCol("Region").setOutputCol("RegionInd")

#StreetID
classIndxr = StringIndexer().setInputCol("StreetID").setOutputCol("StreetIDInd")

#One-hot encoding
ohee_catv = OneHotEncoderEstimator(inputCols=["RegionInd","StreetIDInd"],outputCols=["Region_dum","StreetID_dum"])
pipe_catv = Pipeline(stages=[genderIndxr, classIndxr, ohee_catv])

basetable_final = pipe_catv.fit(base1).transform(base1)
basetable_final=basetable_final.drop("RegionInd","StreetIDInd","EndOfLastSub","LastComplaint","FirstComplaint","ChurnedAt03/02/2018","ChurnedAt03/02/2019")


basetable_final= basetable_final.withColumnRenamed("ChurnedAt03/08/2018","label")

In [14]:

#Splitting the data set into train and test
from pyspark.sql.functions import percent_rank
from pyspark.sql import Window

basetable_final = basetable_final.withColumn("rank", percent_rank().over(Window.partitionBy().orderBy("FirstSubDate")))

In [15]:
basetable_train = basetable_final.where("rank <= .736").drop("rank","FirstSubDate")
basetable_test = basetable_final.where("rank > .736").drop("rank","FirstSubDate")

In [16]:
basetable_train.where((col("label") == 1)).count()

In [17]:
#Transform the tables in a table of label, features format
from pyspark.ml.feature import RFormula

trainBig = RFormula(formula="label ~ . - CustomerID").fit(basetable_final).transform(basetable_final)
train = RFormula(formula="label ~ . - CustomerID").fit(basetable_train).transform(basetable_train)
test = RFormula(formula="label ~ . - CustomerID").fit(basetable_test).transform(basetable_test)
print("trainBig nobs: " + str(trainBig.count()))
print("train nobs: " + str(train.count()))
print("test nobs: " + str(test.count()))

In [18]:
#Train a Logistic Regression model
from pyspark.ml.classification import LogisticRegression

#Define the algorithm class
lr = LogisticRegression()

#Fit the model
lrModel = lr.fit(trainBig)

#Print coefficients
lrModel.coefficients

In [19]:
#Spark offers two options for performing hyperparameter tuning automatically:

#1. TrainValidationSplit: randomly split data in 2 groups
from pyspark.ml.tuning import TrainValidationSplit

#2. CrossValidator: k-fold cross-validation by splitting the data into k non-overlapping, randomly partitioned folds
from pyspark.ml.tuning import CrossValidator

#First method is good for quick model evaluations, 2nd method is recommended for a more rigorous model evaluation.


#Hyperparameter tuning for different hyperparameter values of LR (aka model selection)
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator

#Define pipeline
lr = LogisticRegression()
pipeline = Pipeline().setStages([lr])

#Set param grid
lrparams = ParamGridBuilder()\
  .addGrid(lr.regParam, [0.1, 0.01])\
  .addGrid(lr.maxIter, [50, 100,150])\
  .build()

#Evaluator: uses max(AUC) by default to get the final model
evaluator = BinaryClassificationEvaluator()
#Check params through: evaluator.explainParams()

#Cross-validation of entire pipeline
cv = CrossValidator()\
  .setEstimator(pipeline)\
  .setEstimatorParamMaps(lrparams)\
  .setEvaluator(evaluator)\
  .setNumFolds(5) # Here: 5-fold cross validation

#Run cross-validation, and choose the best set of parameters.
#Spark automatically saves the best model in cvModel.
cvModel = cv.fit(train)


#Get best tuned parameters of pipeline
cvBestPipeline = cvModel.bestModel
cvBestLRModel = cvBestPipeline.stages[-1]._java_obj.parent() #the stages function refers to the stage in the pipelinemodel

print("Best LR model:")
print("** regParam: " + str(cvBestLRModel.getRegParam()))
print("** maxIter: " + str(cvBestLRModel.getMaxIter()))

In [20]:
preds = cvModel.transform(test)\
  .select("prediction", "label")
preds.show(10)

#Get model performance on test set
from pyspark.mllib.evaluation import BinaryClassificationMetrics

out = preds.rdd.map(lambda x: (float(x[0]), float(x[1])))
metrics = BinaryClassificationMetrics(out)

print(metrics.areaUnderPR) #area under precision/recall curve
print(metrics.areaUnderROC)#area under Receiver Operating Characteristic curve

In [21]:
#Exercise: Train a RandomForest model and tune the number of trees for values [150, 300, 500]
#Hint: analogous to buidling a LR model (see above)
from pyspark.ml.classification import RandomForestClassifier

#Define pipeline
rfc = RandomForestClassifier()
rfPipe = Pipeline().setStages([rfc])

#Set param grid
rfParams = ParamGridBuilder()\
  .addGrid(rfc.numTrees, [150, 300, 500])\
  .build()

rfCv = CrossValidator()\
  .setEstimator(rfPipe)\
  .setEstimatorParamMaps(rfParams)\
  .setEvaluator(BinaryClassificationEvaluator())\
  .setNumFolds(5) # Here: 5-fold cross validation

#Run cross-validation, and choose the best set of parameters.
rfcModel = rfCv.fit(train)

In [22]:
#Get predictions on the test set
preds = rfcModel.transform(test)
display(preds)

CustomerID,Region,StreetID,NbrComplaints,ComplaintsPerMonth,NbrComplaintsProduct1,NbrComplaintsProduct2,NbrComplaintsProduct3,NbrComplaintsProduct4,NbrComplaintsProduct5,NbrComplaintsProduct6,NbrComplaintsProduct7,NbrComplaintsProduct8,NbrComplaintsProductUnknown,NbrComplaintsType1,NbrComplaintsType2,NbrComplaintsType3,NbrComplaintsType4,NbrComplaintsType5,NbrComplaintsType6,NbrComplaintsType7,NbrComplaintsType8,NbrComplaintsType9,NbrComplaintsTypeUnknown,NbrSolutionsType1,NbrSolutionsType2,NbrSolutionsType3,NbrSolutionsType4,NbrSolutionsTypeUnknown,TotalMeal_REG,MeanMeal_REGPerSub,TotalMeal_EXCEP,MeanMeal_EXCEPPerSub,HasBeenClientForXDays,NbrSub,SubPaid,SubNotPaid,ProportionPaidSub,AvgPricePerMeal,TotalProductDiscount,TotalFormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,NbrDaysSub,AvgDurationPerSub,AverageNbrMealPerDay,NbrCustomEventsProduct,NbrGrubProduct,label,Region_dum,StreetID_dum,features,rawPrediction,probability,prediction
56451,1,14405,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,12.0,0,0.0,286.0,2,2,0,1.0,0.0,0.0,499.2,499.2,0.0,0.0,26.0,13.0,0.9230769230769232,1,1,1,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(798), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 32, 33, 34, 36, 39, 40, 43, 44, 45, 46, 47, 49, 854), List(1.0, 14405.0, 24.0, 12.0, 286.0, 2.0, 2.0, 1.0, 499.2, 499.2, 26.0, 13.0, 0.9230769230769231, 1.0, 1.0, 1.0, 1.0))","List(1, 2, List(), List(178.14628712310395, 321.85371287689657))","List(1, 2, List(), List(0.3562925742462075, 0.6437074257537925))",1.0
709042,1,16777,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,49,16.333333333333332,25,8.333333333333334,1128.0,3,3,0,1.0,3.733333333333333,0.0,779.2,779.2,280.0,0.0,53.0,17.666666666666668,0.9386446886446884,0,3,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(31), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 39, 40, 41, 43, 44, 45, 47, 49, 87), List(1.0, 16777.0, 49.0, 16.333333333333332, 25.0, 8.333333333333334, 1128.0, 3.0, 3.0, 1.0, 3.733333333333333, 779.2, 779.2, 280.0, 53.0, 17.666666666666668, 0.9386446886446885, 3.0, 1.0, 1.0))","List(1, 2, List(), List(175.75923446110272, 324.2407655388978))","List(1, 2, List(), List(0.35151846892220506, 0.6484815310777949))",1.0
60536,1,19113,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,912,304.0,45,15.0,1093.0,3,3,0,1.0,16.776313333333334,0.0,0.0,0.0,15300.0,0.0,1091.0,363.6666666666667,0.8359317450226542,0,3,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(339), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 43, 44, 45, 47, 49, 395), List(1.0, 19113.0, 912.0, 304.0, 45.0, 15.0, 1093.0, 3.0, 3.0, 1.0, 16.776313333333334, 15300.0, 1091.0, 363.6666666666667, 0.8359317450226542, 3.0, 1.0, 1.0))","List(1, 2, List(), List(165.88437856286737, 334.1156214371329))","List(1, 2, List(), List(0.33176875712573456, 0.6682312428742655))",1.0
61494,7,69155,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,734,183.5,55,13.75,678.0,4,3,1,0.75,14.873686666666666,0.0,413.6,413.6,10720.0,0.0,771.0,192.75,0.9292888097931706,0,4,1,"List(0, 8, List(2), List(1.0))","List(0, 1282, List(75), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 43, 44, 45, 47, 50, 131), List(7.0, 69155.0, 734.0, 183.5, 55.0, 13.75, 678.0, 4.0, 3.0, 1.0, 0.75, 14.873686666666666, 413.6, 413.6, 10720.0, 771.0, 192.75, 0.9292888097931706, 4.0, 1.0, 1.0))","List(1, 2, List(), List(195.34113967334736, 304.6588603266524))","List(1, 2, List(), List(0.3906822793466949, 0.6093177206533051))",1.0
55101,1,18256,1,0.0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,1216,304.0,55,13.75,1107.0,4,3,1,0.75,17.12171,0.0,0.0,0.0,15300.0,0.0,1104.0,276.0,6.959710111391306,0,4,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(1069), List(1.0))","List(0, 1338, List(0, 1, 2, 9, 21, 23, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 41, 43, 44, 45, 47, 49, 1125), List(1.0, 18256.0, 1.0, 1.0, 1.0, 1.0, 1216.0, 304.0, 55.0, 13.75, 1107.0, 4.0, 3.0, 1.0, 0.75, 17.12171, 15300.0, 1104.0, 276.0, 6.959710111391306, 4.0, 1.0, 1.0))","List(1, 2, List(), List(221.21384665704397, 278.78615334295586))","List(1, 2, List(), List(0.4424276933140881, 0.557572306685912))",1.0
57657,1,15907,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,12,12.0,0,0.0,13.0,1,1,0,1.0,0.0,0.0,249.6,249.6,0.0,0.0,13.0,13.0,0.9230769230769232,0,1,1,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(111), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 32, 33, 34, 36, 39, 40, 43, 44, 45, 47, 49, 167), List(1.0, 15907.0, 12.0, 12.0, 13.0, 1.0, 1.0, 1.0, 249.6, 249.6, 13.0, 13.0, 0.9230769230769231, 1.0, 1.0, 1.0))","List(1, 2, List(), List(177.36283369400667, 322.63716630599373))","List(1, 2, List(), List(0.35472566738801303, 0.645274332611987))",1.0
56876,1,29176,1,0.0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,574,95.66666666666669,80,13.333333333333334,963.0,6,4,2,0.6666666666666666,10.28249,0.0,2207.36842,2207.36842,4440.0,0.0,489.0,81.5,38.50755157999127,3,3,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(179), List(1.0))","List(0, 1338, List(0, 1, 2, 5, 13, 24, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 39, 40, 41, 43, 44, 45, 46, 47, 49, 235), List(1.0, 29176.0, 1.0, 1.0, 1.0, 1.0, 574.0, 95.66666666666667, 80.0, 13.333333333333334, 963.0, 6.0, 4.0, 2.0, 0.6666666666666666, 10.28249, 2207.36842, 2207.36842, 4440.0, 489.0, 81.5, 38.507551579991265, 3.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(227.61499561482697, 272.38500438517303))","List(1, 2, List(), List(0.4552299912296539, 0.5447700087703461))",1.0
62202,1,21348,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,152,76.0,35,17.5,100.0,2,1,1,0.5,19.47368,0.0,0.0,0.0,1480.0,0.0,99.0,49.5,3.886363636363636,0,2,1,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(933), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 41, 43, 44, 45, 47, 49, 989), List(1.0, 21348.0, 152.0, 76.0, 35.0, 17.5, 100.0, 2.0, 1.0, 1.0, 0.5, 19.47368, 1480.0, 99.0, 49.5, 3.8863636363636362, 2.0, 1.0, 1.0))","List(1, 2, List(), List(141.12700474613325, 358.872995253867))","List(1, 2, List(), List(0.28225400949226637, 0.7177459905077336))",1.0
298867,5,44975,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,608,304.0,35,17.5,729.0,2,2,0,1.0,16.71053,0.0,0.0,0.0,10160.0,0.0,728.0,364.0,0.8351648351648352,0,2,1,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(768), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 43, 44, 45, 47, 48, 824), List(5.0, 44975.0, 608.0, 304.0, 35.0, 17.5, 729.0, 2.0, 2.0, 1.0, 16.71053, 10160.0, 728.0, 364.0, 0.8351648351648352, 2.0, 1.0, 1.0))","List(1, 2, List(), List(170.08533559116614, 329.9146644088341))","List(1, 2, List(), List(0.34017067118233213, 0.6598293288176679))",1.0
182620,5,47769,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,784,71.27272727272727,140,12.727272727272728,1143.0,11,11,0,1.0,18.522490909090912,0.0,1004.8,1004.8,14902.4,0.0,929.0,84.45454545454545,0.8447805353598169,0,11,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(326), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 39, 40, 41, 43, 44, 45, 47, 48, 382), List(5.0, 47769.0, 784.0, 71.27272727272727, 140.0, 12.727272727272727, 1143.0, 11.0, 11.0, 1.0, 18.522490909090912, 1004.8000000000001, 1004.8000000000001, 14902.4, 929.0, 84.45454545454545, 0.8447805353598169, 11.0, 1.0, 1.0))","List(1, 2, List(), List(285.0930766295252, 214.9069233704745))","List(1, 2, List(), List(0.5701861532590508, 0.42981384674094925))",0.0


In [23]:
#Get model accuracy
print("accuracy: " + str(evaluator.evaluate(preds)))

#Get AUC
metrics = BinaryClassificationMetrics(preds.rdd.map(lambda x: (float(x[0]), float(x[1]))))
print("AUC: " + str(metrics.areaUnderROC))

In [24]:
#Get more metrics
from pyspark.mllib.evaluation import MulticlassMetrics

#Cast a DF of predictions to an RDD to access RDD methods of MulticlassMetrics
preds_labels = cvModel.transform(test)\
  .select("prediction", "label")\
  .rdd.map(lambda x: (float(x[0]), float(x[1])))

metrics = MulticlassMetrics(preds_labels)

print("accuracy = %s" % metrics.accuracy)

In [25]:
#Get more metrics
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.mllib.util import MLUtils

labels = preds.rdd.map(lambda lp: lp.label).distinct().collect()
for label in sorted(labels):
    print("Class %s precision = %s" % (label, metrics.precision(label)))
    print("Class %s recall = %s" % (label, metrics.recall(label)))
 #   print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)))

In [26]:
#Select the best RF model
rfcBestModel = rfcModel.bestModel.stages[-1] #-1 means "get last stage in the pipeline"


In [27]:
#Get tuned number of trees
rfcBestModel.getNumTrees

In [28]:
#Prettify feature importances
import pandas as pd
def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))
  
ExtractFeatureImp(rfcBestModel.featureImportances, train, "features").head(10)

Unnamed: 0,idx,name,score
41,41,TotalPrice,0.117071
32,32,HasBeenClientForXDays,0.085446
43,43,NbrDaysSub,0.082592
31,31,MeanMeal_EXCEPPerSub,0.074673
34,34,SubPaid,0.060538
37,37,AvgPricePerMeal,0.059804
33,33,NbrSub,0.05251
28,28,TotalMeal_REG,0.050055
30,30,TotalMeal_EXCEP,0.047162
47,47,NbrGrubProduct,0.038351


In [29]:
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', maxDepth = 3)
dtModel = dt.fit(train)
predictions = dtModel.transform(test)
predictions.select( 'label', 'rawPrediction', 'prediction', 'probability').show(10)

In [30]:
evaluator = BinaryClassificationEvaluator()
print("Test Area Under ROC: " + str(evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderROC"})))

In [31]:
from pyspark.ml.classification import LogisticRegression,DecisionTreeClassifier,RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder,TrainValidationSplit
from pyspark.ml.evaluation import BinaryClassificationEvaluator

#Define pipeline
lr = LogisticRegression()
dt = DecisionTreeClassifier(featuresCol = 'features', labelCol = 'label', maxDepth = 3)
rfc = RandomForestClassifier()

pipe = Pipeline().setStages([lr,rfc,dt])

In [32]:
#Set param grid of lr
params = ParamGridBuilder()\
  .addGrid(lr.regParam, [0.1, 0.01])\
  .addGrid(lr.maxIter, [50, 100,150])\
  .build()

#Evaluator: uses max(AUC) by default to get the final model
evaluator = BinaryClassificationEvaluator()
#Check params through: evaluator.explainParams()


#Set param grid
rfParams = ParamGridBuilder()\
  .addGrid(rfc.numTrees, [150, 300, 500])\
  .build()

In [33]:


#Cross-validation of entire pipeline
cv = CrossValidator()\
  .setEstimator(pipe)\
  .setEstimatorParamMaps(params)\
  .setEvaluator(evaluator)\
  .setNumFolds(5) # Here: 5-fold cross validation

#Run cross-validation, and choose the best set of parameters.
#Spark automatically saves the best model in cvModel.
cvModel = cv.fit(train)


#Get best tuned parameters of pipeline
cvBestPipeline = cvModel.bestModel
cvBestLRModel = cvBestPipeline.stages[-1]._java_obj.parent()

In [34]:
display(train)

CustomerID,Region,StreetID,NbrComplaints,ComplaintsPerMonth,NbrComplaintsProduct1,NbrComplaintsProduct2,NbrComplaintsProduct3,NbrComplaintsProduct4,NbrComplaintsProduct5,NbrComplaintsProduct6,NbrComplaintsProduct7,NbrComplaintsProduct8,NbrComplaintsProductUnknown,NbrComplaintsType1,NbrComplaintsType2,NbrComplaintsType3,NbrComplaintsType4,NbrComplaintsType5,NbrComplaintsType6,NbrComplaintsType7,NbrComplaintsType8,NbrComplaintsType9,NbrComplaintsTypeUnknown,NbrSolutionsType1,NbrSolutionsType2,NbrSolutionsType3,NbrSolutionsType4,NbrSolutionsTypeUnknown,TotalMeal_REG,MeanMeal_REGPerSub,TotalMeal_EXCEP,MeanMeal_EXCEPPerSub,HasBeenClientForXDays,NbrSub,SubPaid,SubNotPaid,ProportionPaidSub,AvgPricePerMeal,TotalProductDiscount,TotalFormulaDiscount,TotalDiscount,TotalPrice,TotalCredit,NbrDaysSub,AvgDurationPerSub,AverageNbrMealPerDay,NbrCustomEventsProduct,NbrGrubProduct,label,Region_dum,StreetID_dum,features
923623,5,43628,1,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1520,304.0,70,14.0,1824.0,5,5,0,1.0,15.912104000000005,0.0,0.0,0.0,24186.4,-44.0,1820.0,364.0,0.835167356512313,0,5,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(1256), List(1.0))","List(0, 1338, List(0, 1, 2, 11, 13, 24, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 42, 43, 44, 45, 47, 48, 1312), List(5.0, 43628.0, 1.0, 1.0, 1.0, 1.0, 1520.0, 304.0, 70.0, 14.0, 1824.0, 5.0, 5.0, 1.0, 15.912104000000003, 24186.4, -44.0, 1820.0, 364.0, 0.835167356512313, 5.0, 1.0, 1.0))"
644766,1,102093,7,0.0048442906574394,0,0,0,0,0,7,0,0,0,6,0,1,0,0,0,0,0,0,0,3,1,1,0,2,1849,264.1428571428572,105,15.0,1779.0,7,5,2,0.7142857142857143,16.2117,0.0,0.0,0.0,22029.8,-370.2,1743.0,249.0,65.71642205702143,0,7,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(1056), List(1.0))","List(0, 1338, List(0, 1, 2, 3, 9, 13, 15, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 41, 42, 43, 44, 45, 47, 49, 1112), List(1.0, 102093.0, 7.0, 0.004844290657439446, 7.0, 6.0, 1.0, 3.0, 1.0, 1.0, 2.0, 1849.0, 264.14285714285717, 105.0, 15.0, 1779.0, 7.0, 5.0, 2.0, 0.7142857142857143, 16.2117, 22029.8, -370.2, 1743.0, 249.0, 65.71642205702143, 7.0, 1.0, 1.0))"
80187,5,38209,1,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1545,309.0,85,17.0,1856.0,5,5,0,1.0,16.073668,0.0,376.6,376.6,24780.0,0.0,1852.0,370.4,0.8342216951395729,0,5,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(3), List(1.0))","List(0, 1338, List(0, 1, 2, 11, 13, 23, 28, 29, 30, 31, 32, 33, 34, 36, 37, 39, 40, 41, 43, 44, 45, 47, 48, 59), List(5.0, 38209.0, 1.0, 1.0, 1.0, 1.0, 1545.0, 309.0, 85.0, 17.0, 1856.0, 5.0, 5.0, 1.0, 16.073668, 376.6, 376.6, 24780.0, 1852.0, 370.4, 0.8342216951395729, 5.0, 1.0, 1.0))"
285977,1,18415,7,0.0045721750489875,0,0,0,0,0,7,0,0,0,4,1,1,0,0,0,0,0,1,0,1,0,1,0,5,1519,303.8,70,14.0,1860.0,5,5,0,1.0,15.705696,0.0,0.0,0.0,23858.2,-721.8000000000001,1856.0,371.2,0.818706760989113,0,5,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(1179), List(1.0))","List(0, 1338, List(0, 1, 2, 3, 9, 13, 14, 15, 21, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 42, 43, 44, 45, 47, 49, 1235), List(1.0, 18415.0, 7.0, 0.0045721750489875895, 7.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 1519.0, 303.8, 70.0, 14.0, 1860.0, 5.0, 5.0, 1.0, 15.705696, 23858.199999999997, -721.8000000000001, 1856.0, 371.2, 0.818706760989113, 5.0, 1.0, 1.0))"
925863,5,43843,1,0.0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1519,303.8,70,14.0,1826.0,5,5,0,1.0,15.920856,0.0,0.0,0.0,24186.4,-44.0,1822.0,364.4,0.8337001354809574,0,5,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(203), List(1.0))","List(0, 1338, List(0, 1, 2, 11, 13, 23, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 42, 43, 44, 45, 47, 48, 259), List(5.0, 43843.0, 1.0, 1.0, 1.0, 1.0, 1519.0, 303.8, 70.0, 14.0, 1826.0, 5.0, 5.0, 1.0, 15.920856, 24186.4, -44.0, 1822.0, 364.4, 0.8337001354809574, 5.0, 1.0, 1.0))"
757625,5,41196,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1519,303.8,70,14.0,1825.0,5,5,0,1.0,16.283628,0.0,0.0,0.0,24736.0,-44.0,1821.0,364.2,0.834157760048171,0,5,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(147), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 42, 43, 44, 45, 47, 48, 203), List(5.0, 41196.0, 1519.0, 303.8, 70.0, 14.0, 1825.0, 5.0, 5.0, 1.0, 16.283628, 24736.0, -44.0, 1821.0, 364.2, 0.834157760048171, 5.0, 1.0, 1.0))"
163600,5,43993,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1519,303.8,70,14.0,1825.0,5,5,0,1.0,16.312576,0.0,0.0,0.0,24780.0,0.0,1821.0,364.2,0.834157760048171,0,5,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(40), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 43, 44, 45, 47, 48, 96), List(5.0, 43993.0, 1519.0, 303.8, 70.0, 14.0, 1825.0, 5.0, 5.0, 1.0, 16.312576, 24780.0, 1821.0, 364.2, 0.834157760048171, 5.0, 1.0, 1.0))"
451025,5,47561,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1844,102.44444444444444,395,21.944444444444443,1863.0,18,17,1,0.9444444444444444,18.166826666666665,0.0,0.0,0.0,25576.0,0.0,1842.0,102.33333333333331,2.687061642575633,0,18,0,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(463), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 41, 43, 44, 45, 47, 48, 519), List(5.0, 47561.0, 1844.0, 102.44444444444444, 395.0, 21.944444444444443, 1863.0, 18.0, 17.0, 1.0, 0.9444444444444444, 18.166826666666665, 25576.0, 1842.0, 102.33333333333333, 2.6870616425756335, 18.0, 1.0, 1.0))"
925903,5,40014,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1545,309.0,85,17.0,1502.0,5,4,1,0.8,15.739564,0.0,0.0,0.0,18890.4,0.0,1498.0,299.6,6.194953906852641,0,5,1,"List(0, 8, List(0), List(1.0))","List(0, 1282, List(361), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 41, 43, 44, 45, 47, 48, 417), List(5.0, 40014.0, 1545.0, 309.0, 85.0, 17.0, 1502.0, 5.0, 4.0, 1.0, 0.8, 15.739564000000001, 18890.4, 1498.0, 299.6, 6.194953906852641, 5.0, 1.0, 1.0))"
253165,1,29610,0,0.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1519,303.8,70,14.0,1825.0,5,5,0,1.0,16.180996,0.0,0.0,0.0,24580.0,0.0,1820.0,364.0,0.8346153846153846,0,5,0,"List(0, 8, List(1), List(1.0))","List(0, 1282, List(323), List(1.0))","List(0, 1338, List(0, 1, 28, 29, 30, 31, 32, 33, 34, 36, 37, 41, 43, 44, 45, 47, 49, 379), List(1.0, 29610.0, 1519.0, 303.8, 70.0, 14.0, 1825.0, 5.0, 5.0, 1.0, 16.180996, 24580.0, 1820.0, 364.0, 0.8346153846153846, 5.0, 1.0, 1.0))"


In [35]:
preds = cvModel.transform(test)\
  .select("prediction", "label")
preds.show(10)

#Get model performance on test set
from pyspark.mllib.evaluation import BinaryClassificationMetrics

out = preds.rdd.map(lambda x: (float(x[0]), float(x[1])))
metrics = BinaryClassificationMetrics(out)

print(metrics.areaUnderPR) #area under precision/recall curve
print(metrics.areaUnderROC)#area under Receiver Operating Characteristic curve

In [36]:
#Get more metrics
from pyspark.mllib.evaluation import MulticlassMetrics

#Cast a DF of predictions to an RDD to access RDD methods of MulticlassMetrics
preds_labels = cvModel.transform(test)\
  .select("prediction", "label")\
  .rdd.map(lambda x: (float(x[0]), float(x[1])))

metrics = MulticlassMetrics(preds_labels)

print("accuracy = %s" % metrics.accuracy)

In [37]:
#Get more metrics
from pyspark.mllib.evaluation import MulticlassMetrics

labels = preds.rdd.map(lambda lp: lp.label).distinct().collect()
for label in sorted(labels):
    print("Class %s precision = %s" % (label, metrics.precision(label)))
    print("Class %s recall = %s" % (label, metrics.recall(label)))

In [38]:
#Select the best RF model
rfcBestModel = rfcModel.bestModel.stages[-1] #-1 means "get last stage in the pipeline"

#Get tuned number of trees
rfcBestModel.getNumTrees

In [39]:
#Prettify feature importances
import pandas as pd
def ExtractFeatureImp(featureImp, dataset, featuresCol):
    list_extract = []
    for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
        list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return(varlist.sort_values('score', ascending = False))
  
ExtractFeatureImp(rfcBestModel.featureImportances, train, "features").head(20)

Unnamed: 0,idx,name,score
41,41,TotalPrice,0.117071
32,32,HasBeenClientForXDays,0.085446
43,43,NbrDaysSub,0.082592
31,31,MeanMeal_EXCEPPerSub,0.074673
34,34,SubPaid,0.060538
37,37,AvgPricePerMeal,0.059804
33,33,NbrSub,0.05251
28,28,TotalMeal_REG,0.050055
30,30,TotalMeal_EXCEP,0.047162
47,47,NbrGrubProduct,0.038351
