The purpose of this notebook is to access & prepare the data required for churn prediction.  It should be run on a cluster leveraging **Databricks ML 7.1+** and **CPU-based** nodes.

###Step 1: Load the Data

In 2018, [KKBox](https://www.kkbox.com/), a popular music streaming service based in Taiwan, released a [dataset](https://www.kaggle.com/c/kkbox-churn-prediction-challenge/data) consisting of a little over two years of (anonymized) customer transaction and activity data with the goal of challenging the Data & AI community to predict which customers would churn in a future period.  

**NOTE** Due to the terms and conditions by which these data are made available, anyone interested in recreating this work will need to download the files from Kaggle that make up this dataset and create a similar folder structure as described below in their environment.

The primary data files available for download are organized as follows under a pre-defined [mount point](https://docs.databricks.com/data/databricks-file-system.html#mount-object-storage-to-dbfs):

<img src='https://brysmiwasb.blob.core.windows.net/demos/images/kkbox_filedownloads.png' width=250>
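Because the files must be downloaded and placed by hand, a quick pre-flight check can save a failed run. The sketch below assumes the relative paths that the load cells read (`members/members_v3.csv`, `transactions`, `user_logs`) under the `/dbfs/mnt/churning` FUSE path used by the cleanup calls. The names `missing_inputs`, `DEFAULT_ROOT` and `EXPECTED_PATHS` are hypothetical and not part of the original notebook.

```python
import os

# Assumed layout, inferred from the paths read later in this notebook
DEFAULT_ROOT = "/dbfs/mnt/churning"
EXPECTED_PATHS = ["members/members_v3.csv", "transactions", "user_logs"]

def missing_inputs(root=DEFAULT_ROOT, paths=EXPECTED_PATHS):
    """Return the expected input paths that do not exist under root."""
    return [p for p in paths if not os.path.exists(os.path.join(root, p))]

missing = missing_inputs()
if missing:
    print("Missing KKBox input files/folders:", missing)
```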



Read into dataframes, these files form the following data model:

<img src='https://brysmiwasb.blob.core.windows.net/demos/images/kkbox_schema.png' width=300>

Each service subscriber is uniquely identified by a value in the *msno* field of the members table. Data in the transactions and user logs tables provide a record of subscription management and streaming activities, respectively.  Not every member has a complete set of data in this schema.  In addition, the transaction and streaming logs are quite verbose with multiple records being recorded for a subscriber on a given date.  On dates where there is no activity, no entries are found for a subscriber in these tables.

In order to protect data privacy, many values in these tables have been ordinal-encoded, limiting their interpretability.  In addition, timestamp information has been truncated to a daily level, making the sequencing of records on a given date dependent on business logic addressed in later steps in this notebook.
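In the raw CSV files these daily dates are stored as `yyyyMMdd` strings (hence the `dateFormat='yyyyMMdd'` option in the load cells), and the Scala labeling script later compares them directly as strings. A minimal sketch of why that works, and of its limit, assuming zero-padded 8-digit values; `parse_kkbox_date` is a hypothetical helper name:

```python
from datetime import datetime

def parse_kkbox_date(value):
    """Parse a raw yyyyMMdd string (e.g. '20170131') into a datetime.date."""
    return datetime.strptime(value, "%Y%m%d").date()

raw_dates = ["20170131", "20161231", "20170201", "20170131"]

# For zero-padded yyyyMMdd strings, plain string order matches calendar order
assert sorted(raw_dates) == sorted(raw_dates, key=parse_kkbox_date)

# Two records from the same day are indistinguishable by date alone,
# so any within-day ordering has to come from other fields
same_day = [d for d in raw_dates if parse_kkbox_date(d) == parse_kkbox_date("20170131")]
print(same_day)  # ['20170131', '20170131']
```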
 
Let's load this data now:

In [0]:
import shutil
from datetime import date

from pyspark.sql.types import *
from pyspark.sql.functions import lit

In [0]:
# this has been added for scenarios where you might
# wish to alter some of the churn label prediction
# logic but do not wish to rerun the whole notebook
skip_reload = False

In [0]:
# create database to house SQL tables
_ = spark.sql('CREATE DATABASE IF NOT EXISTS kkbox')

In [0]:
if not skip_reload:
  
  # delete the old table if needed
  _ = spark.sql('DROP TABLE IF EXISTS kkbox.members')

  # drop any old delta lake files that might have been created
  shutil.rmtree('/dbfs/mnt/churning/silver/members', ignore_errors=True)

  # members dataset schema
  member_schema = StructType([
    StructField('msno', StringType()),
    StructField('city', IntegerType()),
    StructField('bd', IntegerType()),
    StructField('gender', StringType()),
    StructField('registered_via', IntegerType()),
    StructField('registration_init_time', DateType())
    ])

  # read data from csv
  members = (
    spark
      .read
      .csv(
        '/mnt/churning/members/members_v3.csv',
        schema=member_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  (
    members
      .write
      .format('delta')
      .mode('overwrite')
      .save('/mnt/churning/silver/members')
    )

  # create table object to make delta lake queryable
  _ = spark.sql('''
      CREATE TABLE kkbox.members 
      USING DELTA 
      LOCATION '/mnt/churning/silver/members'
      ''')

In [0]:
if not skip_reload:
  
  # delete the old table if needed
  _ = spark.sql('DROP TABLE IF EXISTS kkbox.transactions')

  # drop any old delta lake files that might have been created
  shutil.rmtree('/dbfs/mnt/churning/silver/transactions', ignore_errors=True)

  # transaction dataset schema
  transaction_schema = StructType([
    StructField('msno', StringType()),
    StructField('payment_method_id', IntegerType()),
    StructField('payment_plan_days', IntegerType()),
    StructField('plan_list_price', IntegerType()),
    StructField('actual_amount_paid', IntegerType()),
    StructField('is_auto_renew', IntegerType()),
    StructField('transaction_date', DateType()),
    StructField('membership_expire_date', DateType()),
    StructField('is_cancel', IntegerType())  
    ])

  # read data from csv
  transactions = (
    spark
      .read
      .csv(
        '/mnt/churning/transactions',
        schema=transaction_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  ( transactions
      .write
      .format('delta')
      .partitionBy('transaction_date')
      .mode('overwrite')
      .save('/mnt/churning/silver/transactions')
    )

  # create table object to make delta lake queryable
  _ = spark.sql('''
      CREATE TABLE kkbox.transactions
      USING DELTA 
      LOCATION '/mnt/churning/silver/transactions'
      ''')

In [0]:
if not skip_reload:
  # delete the old table if needed
  _ = spark.sql('DROP TABLE IF EXISTS kkbox.user_logs')

  # drop any old delta lake files that might have been created
  shutil.rmtree('/dbfs/mnt/churning/silver/user_logs', ignore_errors=True)

  # user logs dataset schema
  user_logs_schema = StructType([ 
    StructField('msno', StringType()),
    StructField('date', DateType()),
    StructField('num_25', IntegerType()),
    StructField('num_50', IntegerType()),
    StructField('num_75', IntegerType()),
    StructField('num_985', IntegerType()),
    StructField('num_100', IntegerType()),
    StructField('num_uniq', IntegerType()),
    StructField('total_secs', FloatType())  
    ])

  # read data from csv
  user_logs = (
    spark
      .read
      .csv(
        '/mnt/churning/user_logs',
        schema=user_logs_schema,
        header=True,
        dateFormat='yyyyMMdd'
        )
      )

  # persist in delta lake format
  ( user_logs
      .write
      .format('delta')
      .partitionBy('date')
      .mode('overwrite')
      .save('/mnt/churning/silver/user_logs')
    )

  # create table object to make delta lake queryable
  _ = spark.sql('''
    CREATE TABLE IF NOT EXISTS kkbox.user_logs
    USING DELTA 
    LOCATION '/mnt/churning/silver/user_logs'
    ''')

###Step 2: Acquire Churn Labels

To build our model, we will need to identify which customers have churned within two periods of interest.  These periods are February 2017 and March 2017.  We will train our model to predict churn in February 2017 and then evaluate our model's ability to predict churn in March 2017, making these our training and testing datasets, respectively.

Per instructions provided in the Kaggle competition, a KKBox subscriber is not identified as churned until he or she fails to renew within 30 days following the subscription's expiration.  Most subscriptions are themselves on a 30-day renewal schedule (though some renew on significantly longer cycles). This means that identifying churn involves a sequential walk through each customer's transactions, looking for renewal gaps that indicate the customer churned at a prior expiration date.
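To make the walk concrete, here is a minimal Python sketch of the gap calculation performed by the `calculateRenewalGap` function in the Scala cells that follow. It assumes the subscriber's post-history transactions are already sorted with the same-day ordering rules and carry dates as `yyyyMMdd` strings; the function names are hypothetical, while the 9999 sentinel and the 30-day threshold come from the Scala code.

```python
from datetime import datetime

NO_RENEWAL = 9999  # sentinel used by the Scala script when no renewal is found
GRACE_DAYS = 30

def _parse(value):
    return datetime.strptime(value, "%Y%m%d").date()

def renewal_gap(ordered_records, last_expiration):
    """Days from the (cancellation-adjusted) expiration to the first renewal.

    ordered_records: post-history transactions for one subscriber, already
    sorted with the same-day ordering rules; dates are yyyyMMdd strings.
    """
    last_expire = _parse(last_expiration)
    for rec in ordered_records:
        if rec["is_cancel"] == "1":
            # a cancellation can only pull the expiration date earlier
            expire = _parse(rec["membership_expire_date"])
            if expire < last_expire:
                last_expire = expire
        else:
            # the first non-cancellation is treated as the renewal
            return (_parse(rec["transaction_date"]) - last_expire).days
    return NO_RENEWAL

def is_churn(ordered_records, last_expiration):
    """1 if no renewal occurs within the grace period (or at all), else 0."""
    return int(renewal_gap(ordered_records, last_expiration) >= GRACE_DAYS)
```

An empty record list yields the sentinel and therefore a churn label, which matches the Scala script's treatment of subscribers with no later activity.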

While the competition makes available pre-labeled training and testing datasets, *train.csv* and *train_v2.csv*, respectively, several past participants have noted that these datasets should be regenerated.  A Scala script for doing so is provided by KKBox.  Modifying the script for this environment, we might regenerate our training and test datasets as follows:

In [0]:
_ = spark.sql('DROP TABLE IF EXISTS kkbox.train')

shutil.rmtree('/dbfs/mnt/churning/silver/train', ignore_errors=True)

In [0]:
%scala

import java.time.{LocalDate}
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
import scala.collection.mutable

def calculateLastday(wrappedArray: mutable.WrappedArray[Row]) :String ={
  val orderedList = wrappedArray.sortWith((x:Row, y:Row) => {
    if(x.getAs[String]("transaction_date") != y.getAs[String]("transaction_date")) {
      x.getAs[String]("transaction_date") < y.getAs[String]("transaction_date")
    } else {
      
      val x_sig = x.getAs[String]("plan_list_price") +
        x.getAs[String]("payment_plan_days") +
        x.getAs[String]("payment_method_id")

      val y_sig = y.getAs[String]("plan_list_price") +
        y.getAs[String]("payment_plan_days") +
        y.getAs[String]("payment_method_id")

      //same plan, always subscribe then unsubscribe
      if(x_sig != y_sig) {
        x_sig > y_sig
      } else {
        if(x.getAs[String]("is_cancel")== "1" && y.getAs[String]("is_cancel") == "1") {
          //multiple cancel, consecutive cancels should only put the expiration date earlier
          x.getAs[String]("membership_expire_date") > y.getAs[String]("membership_expire_date")
        } else if(x.getAs[String]("is_cancel")== "0" && y.getAs[String]("is_cancel") == "0") {
          //multiple renewal, expiration date keeps extending
          x.getAs[String]("membership_expire_date") < y.getAs[String]("membership_expire_date")
        } else {
          //same day same plan transaction: subscription precedes cancellation
          x.getAs[String]("is_cancel") < y.getAs[String]("is_cancel")
        }
      }
    }
  })
  orderedList.last.getAs[String]("membership_expire_date")
}

def calculateRenewalGap(log:mutable.WrappedArray[Row], lastExpiration: String): Int = {
  val orderedDates = log.sortWith((x:Row, y:Row) => {
    if(x.getAs[String]("transaction_date") != y.getAs[String]("transaction_date")) {
      x.getAs[String]("transaction_date") < y.getAs[String]("transaction_date")
    } else {
      
      val x_sig = x.getAs[String]("plan_list_price") +
        x.getAs[String]("payment_plan_days") +
        x.getAs[String]("payment_method_id")

      val y_sig = y.getAs[String]("plan_list_price") +
        y.getAs[String]("payment_plan_days") +
        y.getAs[String]("payment_method_id")

      //same date same plan transaction, assumption: subscribe then unsubscribe
      if(x_sig != y_sig) {
        x_sig > y_sig
      } else {
        if(x.getAs[String]("is_cancel")== "1" && y.getAs[String]("is_cancel") == "1") {
          //multiple cancel of same plan, consecutive cancels should only put the expiration date earlier
          x.getAs[String]("membership_expire_date") > y.getAs[String]("membership_expire_date")
        } else if(x.getAs[String]("is_cancel")== "0" && y.getAs[String]("is_cancel") == "0") {
          //multiple renewal, expiration date keeps extending
          x.getAs[String]("membership_expire_date") < y.getAs[String]("membership_expire_date")
        } else {
          //same date cancel should follow subscription
          x.getAs[String]("is_cancel") < y.getAs[String]("is_cancel")
        }
      }
    }
  })

  //Search for the first subscription after expiration
  //If active cancel is the first action, find the gap between the cancellation and renewal
  val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
  var lastExpireDate = LocalDate.parse(s"${lastExpiration.substring(0,4)}-${lastExpiration.substring(4,6)}-${lastExpiration.substring(6,8)}", formatter)
  var gap = 9999
  for(
    date <- orderedDates
    if gap == 9999
  ) {
    val transString = date.getAs[String]("transaction_date")
    val transDate = LocalDate.parse(s"${transString.substring(0,4)}-${transString.substring(4,6)}-${transString.substring(6,8)}", formatter)
    val expireString = date.getAs[String]("membership_expire_date")
    val expireDate = LocalDate.parse(s"${expireString.substring(0,4)}-${expireString.substring(4,6)}-${expireString.substring(6,8)}", formatter)
    val isCancel = date.getAs[String]("is_cancel")

    if(isCancel == "1") {
      if(expireDate.isBefore(lastExpireDate)) {
        lastExpireDate = expireDate
      }
    } else {
      gap = ChronoUnit.DAYS.between(lastExpireDate, transDate).toInt
    }
  }
  gap
}

val data = spark
  .read
  .option("header", value = true)
  .csv("/mnt/churning/transactions/")

val historyCutoff = "20170131"

val historyData = data.filter(col("transaction_date")>="20170101" and col("transaction_date")<=lit(historyCutoff))
val futureData = data.filter(col("transaction_date") > lit(historyCutoff))

val calculateLastdayUDF = udf(calculateLastday _)
val userExpire = historyData
  .groupBy("msno")
  .agg(
    calculateLastdayUDF(
      collect_list(
        struct(
          col("payment_method_id"),
          col("payment_plan_days"),
          col("plan_list_price"),
          col("transaction_date"),
          col("membership_expire_date"),
          col("is_cancel")
        )
      )
    ).alias("last_expire")
  )

val predictionCandidates = userExpire
  .filter(
    col("last_expire") >= "20170201" and col("last_expire") <= "20170228"
  )
  .select("msno", "last_expire")


val joinedData = predictionCandidates
  .join(futureData,Seq("msno"), "left_outer")

val noActivity = joinedData
  .filter(col("payment_method_id").isNull)
  .withColumn("is_churn", lit(1))


val calculateRenewalGapUDF = udf(calculateRenewalGap _)
val renewals = joinedData
  .filter(col("payment_method_id").isNotNull)
  .groupBy("msno", "last_expire")
  .agg(
    calculateRenewalGapUDF(
      collect_list(
        struct(
          col("payment_method_id"),
          col("payment_plan_days"),
          col("plan_list_price"),
          col("transaction_date"),
          col("membership_expire_date"),
          col("is_cancel")
        )
      ),
      col("last_expire")
    ).alias("gap")
  )

val validRenewals = renewals.filter(col("gap") < 30)
  .withColumn("is_churn", lit(0))
val lateRenewals = renewals.filter(col("gap") >= 30)
  .withColumn("is_churn", lit(1))

val resultSet = validRenewals
  .select("msno","is_churn")
  .union(
    lateRenewals
      .select("msno","is_churn")
      .union(
        noActivity.select("msno","is_churn")
      )
  )

resultSet.write.format("delta").mode("overwrite").save("/mnt/churning/silver/train/")

In [0]:
%sql

CREATE TABLE kkbox.train
USING DELTA
LOCATION '/mnt/churning/silver/train/';

SELECT *
FROM kkbox.train;

msno,is_churn
++4RuqBw0Ss6bQU4oMxaRlbBPoWzoEiIZaxPM04Y4+U=,0
++5Z7z4xXBhCjID+BYk/RLkqtTTAULRXhvjfOc88aEw=,0
++f8snpQnmR06b0a8bkOJE1bOryfFpqa3yXhPons9e0=,0
++orpnUqSevh2M5A97pRRiONA58g5m9DwaNrhD44HY0=,0
++pSqgOqSB8laOm+RTW6NLTqsMVQ0egh4Rs5+GOSJrQ=,0
++ywLqSa3Ts36aYwZ2FUpf8ruOCf4f/OgVfiiZ0Qnt4=,0
+/3AlIwWITFtPuXyKpWTkiBXzHOVW+BqLl9TuJXnwCA=,0
+/HS8LzrRGXolKbxRzDLqrmwuXqPOYixBIPXkyNcKNI=,0
+/eU+cOIqC35MGrmnHMXd/7D1Hm7WbgBAihN9mW9y+8=,0
+/jSYWhQxVdB3T4XCYeuYfVzcQCAbjX2aupePWsCYGc=,0


In [0]:
_ = spark.sql('DROP TABLE IF EXISTS kkbox.test')

shutil.rmtree('/dbfs/mnt/churning/silver/test', ignore_errors=True)

In [0]:
%scala

import java.time.{LocalDate}
import java.time.format.DateTimeFormatter
import java.time.temporal.ChronoUnit

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
import scala.collection.mutable

def calculateLastday(wrappedArray: mutable.WrappedArray[Row]) :String ={
  val orderedList = wrappedArray.sortWith((x:Row, y:Row) => {
    if(x.getAs[String]("transaction_date") != y.getAs[String]("transaction_date")) {
      x.getAs[String]("transaction_date") < y.getAs[String]("transaction_date")
    } else {
      val x_sig = x.getAs[String]("plan_list_price") +
        x.getAs[String]("payment_plan_days") +
        x.getAs[String]("payment_method_id")


      val y_sig = y.getAs[String]("plan_list_price") +
        y.getAs[String]("payment_plan_days") +
        y.getAs[String]("payment_method_id")

      //same plan, always subscribe then unsubscribe
      if(x_sig != y_sig) {
        x_sig > y_sig
      } else {
        if(x.getAs[String]("is_cancel")== "1" && y.getAs[String]("is_cancel") == "1") {
          //multiple cancel, consecutive cancels should only put the expiration date earlier
          x.getAs[String]("membership_expire_date") > y.getAs[String]("membership_expire_date")
        } else if(x.getAs[String]("is_cancel")== "0" && y.getAs[String]("is_cancel") == "0") {
          //multiple renewal, expiration date keeps extending
          x.getAs[String]("membership_expire_date") < y.getAs[String]("membership_expire_date")
        } else {
          //same day same plan transaction: subscription precedes cancellation
          x.getAs[String]("is_cancel") < y.getAs[String]("is_cancel")
        }
      }
    }
  })
  orderedList.last.getAs[String]("membership_expire_date")
}

def calculateRenewalGap(log:mutable.WrappedArray[Row], lastExpiration: String): Int = {
  val orderedDates = log.sortWith((x:Row, y:Row) => {
    if(x.getAs[String]("transaction_date") != y.getAs[String]("transaction_date")) {
      x.getAs[String]("transaction_date") < y.getAs[String]("transaction_date")
    } else {
      
      val x_sig = x.getAs[String]("plan_list_price") +
        x.getAs[String]("payment_plan_days") +
        x.getAs[String]("payment_method_id")

      val y_sig = y.getAs[String]("plan_list_price") +
        y.getAs[String]("payment_plan_days") +
        y.getAs[String]("payment_method_id")

      //same date same plan transaction, assumption: subscribe then unsubscribe
      if(x_sig != y_sig) {
        x_sig > y_sig
      } else {
        if(x.getAs[String]("is_cancel")== "1" && y.getAs[String]("is_cancel") == "1") {
          //multiple cancel of same plan, consecutive cancels should only put the expiration date earlier
          x.getAs[String]("membership_expire_date") > y.getAs[String]("membership_expire_date")
        } else if(x.getAs[String]("is_cancel")== "0" && y.getAs[String]("is_cancel") == "0") {
          //multiple renewal, expiration date keeps extending
          x.getAs[String]("membership_expire_date") < y.getAs[String]("membership_expire_date")
        } else {
          //same date cancel should follow subscription
          x.getAs[String]("is_cancel") < y.getAs[String]("is_cancel")
        }
      }
    }
  })

  //Search for the first subscription after expiration
  //If active cancel is the first action, find the gap between the cancellation and renewal
  val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd")
  var lastExpireDate = LocalDate.parse(s"${lastExpiration.substring(0,4)}-${lastExpiration.substring(4,6)}-${lastExpiration.substring(6,8)}", formatter)
  var gap = 9999
  for(
    date <- orderedDates
    if gap == 9999
  ) {
    val transString = date.getAs[String]("transaction_date")
    val transDate = LocalDate.parse(s"${transString.substring(0,4)}-${transString.substring(4,6)}-${transString.substring(6,8)}", formatter)
    val expireString = date.getAs[String]("membership_expire_date")
    val expireDate = LocalDate.parse(s"${expireString.substring(0,4)}-${expireString.substring(4,6)}-${expireString.substring(6,8)}", formatter)
    val isCancel = date.getAs[String]("is_cancel")

    if(isCancel == "1") {
      if(expireDate.isBefore(lastExpireDate)) {
        lastExpireDate = expireDate
      }
    } else {
      gap = ChronoUnit.DAYS.between(lastExpireDate, transDate).toInt
    }
  }
  gap
}

val data = spark
  .read
  .option("header", value = true)
  .csv("/mnt/churning/transactions/")

val historyCutoff = "20170228"

val historyData = data.filter(col("transaction_date")>="20170201" and col("transaction_date")<=lit(historyCutoff))
val futureData = data.filter(col("transaction_date") > lit(historyCutoff))

val calculateLastdayUDF = udf(calculateLastday _)
val userExpire = historyData
  .groupBy("msno")
  .agg(
    calculateLastdayUDF(
      collect_list(
        struct(
          col("payment_method_id"),
          col("payment_plan_days"),
          col("plan_list_price"),
          col("transaction_date"),
          col("membership_expire_date"),
          col("is_cancel")
        )
      )
    ).alias("last_expire")
  )

val predictionCandidates = userExpire
  .filter(
    col("last_expire") >= "20170301" and col("last_expire") <= "20170331"
  )
  .select("msno", "last_expire")


val joinedData = predictionCandidates
  .join(futureData,Seq("msno"), "left_outer")

val noActivity = joinedData
  .filter(col("payment_method_id").isNull)
  .withColumn("is_churn", lit(1))


val calculateRenewalGapUDF = udf(calculateRenewalGap _)
val renewals = joinedData
  .filter(col("payment_method_id").isNotNull)
  .groupBy("msno", "last_expire")
  .agg(
    calculateRenewalGapUDF(
      collect_list(
        struct(
          col("payment_method_id"),
          col("payment_plan_days"),
          col("plan_list_price"),
          col("transaction_date"),
          col("membership_expire_date"),
          col("is_cancel")
        )
      ),
      col("last_expire")
    ).alias("gap")
  )

val validRenewals = renewals.filter(col("gap") < 30)
  .withColumn("is_churn", lit(0))
val lateRenewals = renewals.filter(col("gap") >= 30)
  .withColumn("is_churn", lit(1))

val resultSet = validRenewals
  .select("msno","is_churn")
  .union(
    lateRenewals
      .select("msno","is_churn")
      .union(
        noActivity.select("msno","is_churn")
      )
  )

resultSet.write.format("delta").mode("overwrite").save("/mnt/churning/silver/test/")

In [0]:
%sql

CREATE TABLE kkbox.test
USING DELTA
LOCATION '/mnt/churning/silver/test/';

SELECT *
FROM kkbox.test;

msno,is_churn
++4RuqBw0Ss6bQU4oMxaRlbBPoWzoEiIZaxPM04Y4+U=,0
++5Z7z4xXBhCjID+BYk/RLkqtTTAULRXhvjfOc88aEw=,0
++JHjhFuSV7upQUju29UmOHStYHmNiW5th0xEyUGW8s=,0
++f8snpQnmR06b0a8bkOJE1bOryfFpqa3yXhPons9e0=,0
++orpnUqSevh2M5A97pRRiONA58g5m9DwaNrhD44HY0=,0
++pSqgOqSB8laOm+RTW6NLTqsMVQ0egh4Rs5+GOSJrQ=,0
++ywLqSa3Ts36aYwZ2FUpf8ruOCf4f/OgVfiiZ0Qnt4=,0
+/3AlIwWITFtPuXyKpWTkiBXzHOVW+BqLl9TuJXnwCA=,0
+/HS8LzrRGXolKbxRzDLqrmwuXqPOYixBIPXkyNcKNI=,0
+/eU+cOIqC35MGrmnHMXd/7D1Hm7WbgBAihN9mW9y+8=,0


###Step 3: Cleanse & Enhance Transaction Logs

In the churn script provided by KKBox (and used in the last step), the time between transaction events is used to determine churn status. In situations where multiple transactions are recorded on a given date, complex logic is used to determine which transaction represents the final state of the account on that date.  This logic states that when we have multiple transactions for a given subscriber on a given date, we should:

1. Concatenate the plan_list_price, payment_plan_days, and payment_method_id values as strings and treat the lexicographically "bigger" concatenation as preceding the others
2. If the concatenated value (defined in the last step) is the same across records for this date, cancellations, *i.e.* records where is_cancel=1, should follow other transactions
3. If there are multiple cancellations in this sequence, the record with the earliest expiration date is the last record for this transaction date
4. If there are no cancellations but multiple non-cancellations in this sequence, the non-cancellation record with the latest expiration date is the last record on the transaction date
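These four rules can be sketched in Python as a comparator that mirrors the `sortWith` lambda in the Scala script, followed by taking the last record for each date. This is an illustration under the assumption that fields are the raw CSV strings (as in the Scala cells, which read the CSV without a schema); `compare_transactions` and `last_record_per_date` are hypothetical names.

```python
from functools import cmp_to_key

def _cmp(a, b):
    """Three-way comparison: negative if a < b, zero if equal, positive if a > b."""
    return (a > b) - (a < b)

def compare_transactions(x, y):
    """Negative if x should sort before y, mirroring the Scala comparator."""
    if x["transaction_date"] != y["transaction_date"]:
        return _cmp(x["transaction_date"], y["transaction_date"])
    x_sig = x["plan_list_price"] + x["payment_plan_days"] + x["payment_method_id"]
    y_sig = y["plan_list_price"] + y["payment_plan_days"] + y["payment_method_id"]
    if x_sig != y_sig:
        # rule 1: the lexicographically bigger signature comes first
        return _cmp(y_sig, x_sig)
    if x["is_cancel"] == "1" and y["is_cancel"] == "1":
        # rule 3: later expirations first, so the earliest expiration ends up last
        return _cmp(y["membership_expire_date"], x["membership_expire_date"])
    if x["is_cancel"] == "0" and y["is_cancel"] == "0":
        # rule 4: earlier expirations first, so the latest expiration ends up last
        return _cmp(x["membership_expire_date"], y["membership_expire_date"])
    # rule 2: same plan, mixed records: non-cancellation ("0") before cancellation ("1")
    return _cmp(x["is_cancel"], y["is_cancel"])

def last_record_per_date(records):
    """Map each transaction_date to the record that sorts last on that date."""
    result = {}
    for rec in sorted(records, key=cmp_to_key(compare_transactions)):
        result[rec["transaction_date"]] = rec  # later records overwrite earlier ones
    return result
```

Note that the string comparison means a 99-priced plan (`"993041"`) sorts ahead of a 149-priced plan (`"1493041"`) on the same date, so the 149 plan would be treated as the final state.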

Rewriting this logic in SQL allows us to generate a cleansed version of the transaction log with the final record for each date:

In [0]:
%sql
DROP TABLE IF EXISTS kkbox.transactions_clean;

CREATE TABLE kkbox.transactions_clean
USING DELTA
AS
  WITH 
    transaction_sequenced (
      SELECT
        msno,
        transaction_date,
        plan_list_price,
        payment_plan_days,
        payment_method_id,
        is_cancel,
        membership_expire_date,
        RANK() OVER (PARTITION BY msno, transaction_date ORDER BY plan_sort DESC, is_cancel) as sort_id  -- calc rank on price, days & method sort followed by cancel sort
      FROM (
        SELECT
          msno,
          transaction_date,
          plan_list_price,
          payment_plan_days,
          payment_method_id,
          CONCAT(CAST(plan_list_price as string), CAST(payment_plan_days as string), CAST(payment_method_id as string)) as plan_sort,
          is_cancel,
          membership_expire_date
        FROM kkbox.transactions
        )
      )
  SELECT
    p.msno,
    p.transaction_date,
    p.plan_list_price,
    p.actual_amount_paid,
    p.plan_list_price - p.actual_amount_paid as discount,
    p.payment_plan_days,
    p.payment_method_id,
    p.is_cancel,
    p.is_auto_renew,
    p.membership_expire_date
  FROM kkbox.transactions p
  INNER JOIN (
    SELECT
      x.msno,
      x.transaction_date,
      x.plan_list_price,
      x.payment_plan_days,
      x.payment_method_id,
      x.is_cancel,
      CASE   -- if is_cancel is 0 in last record then go with max membership date identified, otherwise go with lowest membership date
        WHEN x.is_cancel=0 THEN MAX(x.membership_expire_date)
        ELSE MIN(x.membership_expire_date)
        END as membership_expire_date
    FROM (
      SELECT
        a.msno,
        a.transaction_date,
        a.plan_list_price,
        a.payment_plan_days,
        a.payment_method_id,
        a.is_cancel,
        a.membership_expire_date
      FROM transaction_sequenced a
      INNER JOIN (
        SELECT msno, transaction_date, MAX(sort_id) as max_sort_id -- find last entries on a given date
        FROM transaction_sequenced 
        GROUP BY msno, transaction_date
        ) b
        ON a.msno=b.msno AND a.transaction_date=b.transaction_date AND a.sort_id=b.max_sort_id
        ) x
    GROUP BY 
      x.msno, 
      x.transaction_date, 
      x.plan_list_price,
      x.payment_plan_days,
      x.payment_method_id,
      x.is_cancel
   ) q
   ON 
     p.msno=q.msno AND 
     p.transaction_date=q.transaction_date AND 
     p.plan_list_price=q.plan_list_price AND 
     p.payment_plan_days=q.payment_plan_days AND 
     p.payment_method_id=q.payment_method_id AND 
     p.is_cancel=q.is_cancel AND 
     p.membership_expire_date=q.membership_expire_date;
     
SELECT * 
FROM kkbox.transactions_clean
ORDER BY msno, transaction_date;

msno,transaction_date,plan_list_price,actual_amount_paid,discount,payment_plan_days,payment_method_id,is_cancel,is_auto_renew,membership_expire_date
+++FOrTS7ab3tIgIh8eWwX4FqRv8w/FoiOuyXsFvphY=,2016-09-09,0,0,0,7,35,0,0,2016-09-14
+++IZseRRiQS9aaSkH6cMYU6bGDcxUieAi/tH67sC5s=,2015-11-21,1788,1788,0,410,38,0,0,2017-01-04
+++IZseRRiQS9aaSkH6cMYU6bGDcxUieAi/tH67sC5s=,2016-10-23,1599,1599,0,395,22,0,0,2018-02-06
+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2016-11-16,99,99,0,30,41,0,1,2016-12-15
+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2016-12-15,99,99,0,30,41,0,1,2017-01-15
+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-01-15,99,99,0,30,41,0,1,2017-02-15
+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-02-15,99,99,0,30,41,0,1,2017-03-15
+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-03-15,99,99,0,30,41,0,1,2017-04-15
+++l/EXNMLTijfLBa8p2TUVVVp2aFGSuUI/h7mLmthw=,2015-01-31,149,149,0,31,39,0,1,2015-03-19
+++l/EXNMLTijfLBa8p2TUVVVp2aFGSuUI/h7mLmthw=,2015-02-28,149,149,0,31,39,0,1,2015-04-19


Using this *cleansed* transaction data, we can now more easily identify the start and end of subscriptions using the 30-day gap logic found in the Scala code.  It's important to note that over the 2+ year period represented by the dataset, many subscribers will churn and many of those that do churn will re-subscribe.  With this in mind, we will generate a subscription ID to identify the different subscriptions, each of which will have a non-overlapping starting and ending date for a given subscriber:

In [0]:
%sql
DROP TABLE IF EXISTS kkbox.subscription_windows;

CREATE TABLE kkbox.subscription_windows 
USING delta
AS
  WITH end_dates (
      SELECT p.*
      FROM (
        SELECT
          m.msno,
          m.transaction_date,
          m.membership_expire_date,
          m.next_transaction_date,
          CASE
            WHEN m.next_transaction_date IS NULL THEN 1
            WHEN DATEDIFF(m.next_transaction_date, m.membership_expire_date) > 30 THEN 1
            ELSE 0
            END as end_flag,
          CASE
            WHEN m.next_transaction_date IS NULL THEN m.membership_expire_date
            WHEN DATEDIFF(m.next_transaction_date, m.membership_expire_date) > 30 THEN m.membership_expire_date
            ELSE DATE_ADD(m.next_transaction_date, -1)  -- otherwise end the day before the next transaction
            END as end_date
        FROM (
          SELECT
            x.msno,
            x.transaction_date,
            CASE  -- correct backdated expirations for subscription end calculations
              WHEN x.membership_expire_date < x.transaction_date THEN x.transaction_date
              ELSE x.membership_expire_date
              END as membership_expire_date,
            LEAD(x.transaction_date, 1) OVER (PARTITION BY x.msno ORDER BY x.transaction_date) as next_transaction_date
          FROM kkbox.transactions_clean x
          ) m
        ) p
      WHERE p.end_flag=1
    )
  SELECT
    ROW_NUMBER() OVER (ORDER BY subscription_start, msno) as subscription_id,
    msno,
    subscription_start,
    subscription_end
  FROM (
    SELECT
      x.msno,
      MIN(x.transaction_date) as subscription_start,
      y.window_end as subscription_end
    FROM kkbox.transactions_clean x
    INNER JOIN (
      SELECT
        a.msno,
        COALESCE( MAX(b.end_date), '2015-01-01') as window_start,
        a.end_date as window_end
      FROM end_dates a
      LEFT OUTER JOIN end_dates b
        ON a.msno=b.msno AND a.end_date > b.end_date
      GROUP BY a.msno, a.end_date
      ) y
      ON x.msno=y.msno AND x.transaction_date BETWEEN y.window_start AND y.window_end
    GROUP BY x.msno, y.window_end
    )
  ORDER BY subscription_id;
  
SELECT *
FROM kkbox.subscription_windows
ORDER BY subscription_id;

subscription_id,msno,subscription_start,subscription_end
1,++KbErD/TJoTzWzoMQzvaHHnRPE5GZLjXR2YbfbJ+ow=,2015-01-01,2017-04-09
2,++X86/NhBH23Ord6wyYnjHUiBIY/OUwpyT2lPIPOSY8=,2015-01-01,2015-11-01
3,++bSthEZMtWTJ+QnupXtVNyqq8g5Xwmj9anbxgw2AOQ=,2015-01-01,2016-01-01
4,++gZjzjkC8lAqAtcOxAz0677Ygp0CE1hNqd3v9lRKG0=,2015-01-01,2017-05-21
5,++ny9gX4wM/1tNrLg/FsJPh0bLPPuY5mscc3zhnRYP8=,2015-01-01,2015-06-02
6,++yV3f1YM3HSXUnY0V6KkJORX2e7VMaE0jGP4Xo3prI=,2015-01-01,2016-01-01
7,+/AkQcvJ26pYXlnwl88w/yjUEsyJpYEXRtrenOArEAk=,2015-01-01,2015-01-31
8,+/OXTCS/xccwbuw/IBoOiO80bmJDJBDECRueoVmTgxs=,2015-01-01,2017-04-02
9,+/QE/tkXQ2rHm5fOWSfhQOBETxZKgIt4r6rF4GGfX64=,2015-01-01,2017-04-01
10,+/yURAZt6uqLEfJlosdxW3IXNgypQv8F0hoo6CkicPY=,2015-01-01,2016-12-08
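For readers who find the nested SQL hard to follow, the following is a simplified Python sketch of the core windowing idea for a single subscriber: correct backdated expirations, then close a window whenever there is no later transaction or the next one arrives more than 30 days after the current expiration. It assumes one cleansed transaction per date and is not a drop-in replacement for the SQL (for example, it ignores the inclusive `BETWEEN` join); `subscription_windows` is a hypothetical name. Note also that the SQL uses a strict `> 30` comparison while the Scala labeling treats a gap of 30 or more days as churn, so a gap of exactly 30 days is handled differently by the two.

```python
from datetime import date

def subscription_windows(transactions, grace_days=30):
    """Split one subscriber's cleansed transactions into (start, end) windows.

    transactions: iterable of (transaction_date, membership_expire_date) pairs
    as datetime.date objects, at most one per transaction date.
    """
    rows = sorted(transactions)
    windows = []
    start = None
    for i, (txn_date, expire_date) in enumerate(rows):
        if start is None:
            start = txn_date  # first transaction of a new subscription
        # correct backdated expirations, as the SQL CASE expression does
        expire_date = max(expire_date, txn_date)
        next_txn = rows[i + 1][0] if i + 1 < len(rows) else None
        # the subscription ends if there is no later transaction, or the
        # next one arrives more than grace_days after this expiration
        if next_txn is None or (next_txn - expire_date).days > grace_days:
            windows.append((start, expire_date))
            start = None
    return windows

example = [
    (date(2016, 1, 1), date(2016, 1, 31)),
    (date(2016, 1, 31), date(2016, 3, 1)),
    (date(2016, 5, 15), date(2016, 6, 14)),
]
# two windows: 2016-01-01 to 2016-03-01 and 2016-05-15 to 2016-06-14
print(subscription_windows(example))
```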


To verify that our subscription windows align with the script used to identify customers at risk of churn in February and March 2017, let's perform a quick test.  The script identifies an at-risk subscription as one where the last transaction recorded in the historical period, *i.e.* the time period leading up to the start of the month of interest, has an expiration date falling between 30 days before the start of the period of interest and the end of that period.  For example, to identify at-risk customers for February 2017, we should look for customers with active subscriptions set to expire between 30 days before February 1, 2017 and February 28, 2017.  This shifted window allows time for the 30-day grace period to expire within the period of interest. 

**NOTE** Better logic would limit our assessment to those subscriptions with an expiration date between 30 days prior to the start of the period AND 30 days prior to the end of the period.  (Such logic would exclude subscriptions that expire within the period of interest but do not exit the 30-day grace period until after the period is over.) When we use this logic, we find numerous subscriptions that the provided script identifies as at risk but which we would not.  We will align our logic with that of the competition for this exercise.
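The two candidate windows differ only in their upper bound. The sketch below computes both for a month of interest, assuming a 30-day grace period; the default reproduces the bounds used in the validation cells (for February 2017, `DATE_ADD('2017-02-01',-30)` is 2017-01-02), and `stricter=True` gives the alternative described in the NOTE. `at_risk_window` is a hypothetical helper, not part of the notebook.

```python
import calendar
from datetime import date, timedelta

def at_risk_window(year, month, grace_days=30, stricter=False):
    """Return the (lower, upper) expiration-date bounds for a month of interest.

    Default: [month start - grace_days, month end], matching the
    DATE_ADD(..., -30) ... month-end bounds in the validation cells.
    stricter=True: [month start - grace_days, month end - grace_days].
    """
    start = date(year, month, 1)
    end = date(year, month, calendar.monthrange(year, month)[1])
    lower = start - timedelta(days=grace_days)
    upper = end - timedelta(days=grace_days) if stricter else end
    return lower, upper

print(at_risk_window(2017, 2))                 # 2017-01-02 to 2017-02-28
print(at_risk_window(2017, 2, stricter=True))  # 2017-01-02 to 2017-01-29
```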

With this logic in mind, let's see if all our labeled at-risk customers adhere to this logic:

**NOTE** The next two cells should return NO RESULTS if our logic is valid.

In [0]:
%sql

SELECT
  x.msno
FROM kkbox.train x
LEFT OUTER JOIN (
  SELECT DISTINCT -- subscriptions that had risk in Feb 2017
    a.msno
  FROM kkbox.subscription_windows a
  INNER JOIN kkbox.transactions_clean b
    ON a.msno=b.msno AND b.transaction_date BETWEEN a.subscription_start AND a.subscription_end
  WHERE 
        a.subscription_start < '2017-02-01' AND
        (b.membership_expire_date BETWEEN DATE_ADD('2017-02-01',-30) AND '2017-02-28')
  ) y
  ON x.msno=y.msno
WHERE y.msno IS NULL

msno


In [0]:
%sql

SELECT
  x.msno
FROM kkbox.test x
LEFT OUTER JOIN (
  SELECT DISTINCT -- subscriptions that had risk in Mar 2017
    a.msno
  FROM kkbox.subscription_windows a
  INNER JOIN kkbox.transactions_clean b
    ON a.msno=b.msno AND b.transaction_date BETWEEN a.subscription_start AND a.subscription_end
  WHERE 
        a.subscription_start < '2017-03-01' AND
        (b.membership_expire_date BETWEEN DATE_ADD('2017-03-01',-30) AND '2017-03-31')
  ) y
  ON x.msno=y.msno
WHERE y.msno IS NULL

msno


Our logic identifies every at-risk subscription that the provided script identifies.  If we were to reverse the direction of the comparison above, however, we would find a few subscriptions that we identify as at-risk but which the Scala script does not.  While it might be useful to examine why this is, so long as there are no members that the Scala script identifies as at-risk that we do not, we should be able to use this dataset to derive features for subscriptions in our training and testing datasets.
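The two-directional check amounts to a pair of set differences. The sketch below is a pure-Python illustration with made-up ids and a hypothetical function name, not the notebook's Spark code; in the notebook, the `msno` values in `kkbox.train` / `kkbox.test` play the role of `script_ids`, and the `msno` values returned by the subquery play the role of `our_ids`.

```python
def compare_at_risk(script_ids, our_ids):
    """Return (missed, extra): ids only the script flags, and ids only we flag."""
    script_ids, our_ids = set(script_ids), set(our_ids)
    missed = script_ids - our_ids  # what the LEFT OUTER JOIN ... IS NULL cells test
    extra = our_ids - script_ids   # the reverse comparison described above
    return missed, extra

# hypothetical member ids, for illustration only
missed, extra = compare_at_risk({'a', 'b'}, {'a', 'b', 'c'})
print(missed, extra)  # set() {'c'}
```

An empty `missed` set corresponds to the empty result sets returned by the two SQL cells; a non-empty `extra` set is tolerated, as discussed above.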

Using the subscription windows derived in the preceding cells, we can now enhance our transaction log to detect account-level changes between consecutive transactions within each subscription.  This information will form the basis for transaction feature generation in the next notebook:

In [0]:
%sql
DROP TABLE IF EXISTS kkbox.transactions_enhanced;

CREATE TABLE kkbox.transactions_enhanced
USING DELTA
AS
  SELECT
    b.subscription_id,
    a.*,
    -- each LAG(..., 1) looks back to the previous transaction in the same subscription;
    -- COALESCE(..., 0) assigns 0 to a subscription's first transaction, where LAG returns NULL
    COALESCE( DATEDIFF(a.transaction_date, LAG(a.transaction_date, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date)), 0) as days_since_last_transaction,
    COALESCE( a.plan_list_price - LAG(a.plan_list_price, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date), 0) as change_in_list_price,
    COALESCE(a.actual_amount_paid - LAG(a.actual_amount_paid, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date), 0) as change_in_actual_amount_paid,
    COALESCE(a.discount - LAG(a.discount, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date), 0) as change_in_discount,
    COALESCE(a.payment_plan_days - LAG(a.payment_plan_days, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date), 0) as change_in_payment_plan_days,
    -- 1 if the payment method differs from the previous transaction (the NULL comparison on the first transaction yields 0)
    CASE WHEN (a.payment_method_id != LAG(a.payment_method_id, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date)) THEN 1 ELSE 0 END as change_in_payment_method_id,
    -- 0 if unchanged, -1 if the flag is now 0, 1 if it is now 1;
    -- on the first transaction LAG is NULL, so the first WHEN is not TRUE and the row receives -1 or 1
    CASE
      WHEN a.is_cancel = LAG(a.is_cancel, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date) THEN 0
      WHEN a.is_cancel = 0 THEN -1
      ELSE 1
      END as change_in_cancellation,
    CASE
      WHEN a.is_auto_renew = LAG(a.is_auto_renew, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date) THEN 0
      WHEN a.is_auto_renew = 0 THEN -1
      ELSE 1
      END as change_in_auto_renew,
    COALESCE( DATEDIFF(a.membership_expire_date, LAG(a.membership_expire_date, 1) OVER(PARTITION BY b.subscription_id ORDER BY a.transaction_date)), 0) as days_change_in_membership_expire_date

  FROM kkbox.transactions_clean a
  INNER JOIN kkbox.subscription_windows b
    ON a.msno=b.msno AND 
       a.transaction_date BETWEEN b.subscription_start AND b.subscription_end
  ORDER BY 
    a.msno,
    a.transaction_date;
    
SELECT * FROM kkbox.transactions_enhanced;

subscription_id,msno,transaction_date,plan_list_price,actual_amount_paid,discount,payment_plan_days,payment_method_id,is_cancel,is_auto_renew,membership_expire_date,days_since_last_transaction,change_in_list_price,change_in_actual_amount_paid,change_in_discount,change_in_payment_plan_days,change_in_payment_method_id,change_in_cancellation,change_in_auto_renew,days_change_in_membership_expire_date
2473078,+++FOrTS7ab3tIgIh8eWwX4FqRv8w/FoiOuyXsFvphY=,2016-09-09,0,0,0,7,35,0,0,2016-09-14,0,0,0,0,0,0,-1,-1,0
1514713,+++IZseRRiQS9aaSkH6cMYU6bGDcxUieAi/tH67sC5s=,2015-11-21,1788,1788,0,410,38,0,0,2017-01-04,0,0,0,0,0,0,-1,-1,0
1514713,+++IZseRRiQS9aaSkH6cMYU6bGDcxUieAi/tH67sC5s=,2016-10-23,1599,1599,0,395,22,0,0,2018-02-06,337,-189,-189,0,-15,1,0,0,398
2745735,+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2016-11-16,99,99,0,30,41,0,1,2016-12-15,0,0,0,0,0,0,-1,1,0
2745735,+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2016-12-15,99,99,0,30,41,0,1,2017-01-15,29,0,0,0,0,0,0,0,31
2745735,+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-01-15,99,99,0,30,41,0,1,2017-02-15,31,0,0,0,0,0,0,0,31
2745735,+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-02-15,99,99,0,30,41,0,1,2017-03-15,31,0,0,0,0,0,0,0,28
2745735,+++hVY1rZox/33YtvDgmKA2Frg/2qhkz12B9ylCvh8o=,2017-03-15,99,99,0,30,41,0,1,2017-04-15,28,0,0,0,0,0,0,0,31
447259,+++l/EXNMLTijfLBa8p2TUVVVp2aFGSuUI/h7mLmthw=,2015-01-31,149,149,0,31,39,0,1,2015-03-19,0,0,0,0,0,0,-1,1,0
447259,+++l/EXNMLTijfLBa8p2TUVVVp2aFGSuUI/h7mLmthw=,2015-02-28,149,149,0,31,39,0,1,2015-04-19,28,0,0,0,0,0,0,0,31
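The windowed `LAG` logic can be mimicked in plain Python to see how each column is derived. This is a sketch under stated assumptions, not the notebook's implementation: it covers only a subset of the columns, uses hypothetical helper names (`case_change`, `add_change_features`), and reproduces the SQL's NULL behavior. Because `LAG` returns NULL on a subscription's first transaction, the `CASE` expressions for `is_cancel` and `is_auto_renew` fall through to -1 or 1 on that row rather than 0. The sample rows are subscription 2745735 from the output above.

```python
from datetime import date
from itertools import groupby
from operator import itemgetter

def case_change(curr, prev):
    """Mimic the CASE expressions; prev is None where LAG would return NULL."""
    if prev is not None and curr == prev:
        return 0
    return -1 if curr == 0 else 1

def add_change_features(transactions):
    """Add a subset of the LAG-based columns, one subscription at a time."""
    enhanced = []
    # sort so each subscription forms one contiguous run ordered by transaction_date
    ordered = sorted(transactions, key=itemgetter('subscription_id', 'transaction_date'))
    for _, rows in groupby(ordered, key=itemgetter('subscription_id')):
        prev = None  # stands in for LAG(..., 1) OVER (PARTITION BY subscription_id ...)
        for row in rows:
            out = dict(row)
            has_prev = prev is not None
            # COALESCE(DATEDIFF(...), 0): 0 when there is no previous transaction
            out['days_since_last_transaction'] = (
                (row['transaction_date'] - prev['transaction_date']).days if has_prev else 0)
            out['days_change_in_membership_expire_date'] = (
                (row['membership_expire_date'] - prev['membership_expire_date']).days if has_prev else 0)
            out['change_in_cancellation'] = case_change(
                row['is_cancel'], prev['is_cancel'] if has_prev else None)
            out['change_in_auto_renew'] = case_change(
                row['is_auto_renew'], prev['is_auto_renew'] if has_prev else None)
            enhanced.append(out)
            prev = row
    return enhanced

# subscription 2745735, copied from the table output above
tx_dates = [date(2016, 11, 16), date(2016, 12, 15), date(2017, 1, 15), date(2017, 2, 15), date(2017, 3, 15)]
exp_dates = [date(2016, 12, 15), date(2017, 1, 15), date(2017, 2, 15), date(2017, 3, 15), date(2017, 4, 15)]
sample = [
    {'subscription_id': 2745735, 'transaction_date': t, 'membership_expire_date': e,
     'is_cancel': 0, 'is_auto_renew': 1}
    for t, e in zip(tx_dates, exp_dates)
]

for r in add_change_features(sample):
    print(r['transaction_date'], r['days_since_last_transaction'],
          r['days_change_in_membership_expire_date'], r['change_in_auto_renew'])
```

The computed values match the corresponding columns in the table output (for example, 0, 29, 31, 31, 28 for `days_since_last_transaction`).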


###Step 4: Generate Dates Table

Finally, we will likely want to derive features from both the transaction log and the user activity data that examine days without activity.  Because no records exist for a subscriber on days without activity, it is helpful to generate a table containing one record for each date from the first date to the last date in our dataset.  These data span January 1, 2015 through March 31, 2017.  With that in mind, we can generate such a table as follows:

In [0]:
from datetime import date
from pyspark.sql.functions import lit

# calculate days in range
start_date = date(2015, 1, 1)
end_date = date(2017, 3, 31)
days = end_date - start_date

# generate temp view of dates in range
# (spark.range excludes its end value, so add 1 to include end_date itself)
( spark
    .range(0, days.days + 1)  
    .withColumn('start_date', lit(start_date.strftime('%Y-%m-%d')))  # first date in activity dataset
    .selectExpr('date_add(start_date, CAST(id as int)) as date')
    .createOrReplaceTempView('dates')
  )

# persist data to SQL table
_ = spark.sql('DROP TABLE IF EXISTS kkbox.dates') 
_ = spark.sql('CREATE TABLE kkbox.dates USING DELTA AS SELECT * FROM dates')

# display SQL table content
display(spark.table('kkbox.dates').orderBy('date'))

date
2015-01-01
2015-01-02
2015-01-03
2015-01-04
2015-01-05
2015-01-06
2015-01-07
2015-01-08
2015-01-09
2015-01-10
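Both Python's `range` and `spark.range(start, end)` exclude the end value, so a table that includes both January 1, 2015 and March 31, 2017 needs `(end - start).days + 1` entries. The plain-Python sketch below (hypothetical helper name `date_spine`) builds the same inclusive sequence and can serve as a quick check on the row count of `kkbox.dates`.

```python
from datetime import date, timedelta

def date_spine(start, end):
    """Return every calendar date from start through end, inclusive."""
    n_days = (end - start).days + 1  # +1 because both endpoints are included
    return [start + timedelta(days=i) for i in range(n_days)]

dates = date_spine(date(2015, 1, 1), date(2017, 3, 31))
print(len(dates), dates[0], dates[-1])  # 821 2015-01-01 2017-03-31
```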
