# Table of Contents
1. [Context and goals](#Context-and-goals)
2. [Dataset description](#Dataset-description)
3. [Data cleaning](#Data-cleaning)
4. [Exploratory Data Analysis](#Exploratory-Data-Analysis)
    1.   [Sociodemographic variables](#Sociodemographic-variables)
    2.   [Service usage variables](#Service-usage-variables)
    3.   [Conditioned Analysis](#Conditioned-Analysis)
    4.   [EDA summary](#EDA-summary)
5. [Predictive Analysis](#Predictive-Analysis)
    1. [Evaluation metrics](#Evaluation-metrics)
    2. [Decision Trees](#Decision-Trees)
    3. [Gradient Boosting](#Gradient-Boosting)
    4. [Comparison and Conclusion](#Comparison)

# Context and goals
The case study concerns a bank manager who is worried about the increasing number of customers churning from the credit card service. For a business, acquiring a new customer usually costs much **more** than retaining an existing one. For this reason, the main goal is **predicting** as many **potential churners** as possible, so that the manager can proactively offer them better terms.

This notebook consists of a preliminary exploratory data analysis followed by a predictive analysis using **decision trees** and an ensemble method (**Gradient Boosting**).

# Dataset description
The dataset consists of 10,127 observations, one for each customer account. For each account, the following information is provided:
* **Attrition_Flag** : Whether the customer left the service (account closed or not)
* **Customer_Age** : Customer's age in years
* **Gender** : Customer's gender (male or female)
* **Dependent_count** : Number of people who depend upon the customer for their support and welfare.
* **Education_Level** : Customer's educational qualification (high school, graduate, etc.)
* **Marital_Status** : Customer's marital status (married, single, etc.)
* **Income_Category** : Customer's income bracket in dollars (less than 40K, 40K-60K, etc.)
* **Card_Category** : Credit card category (Blue, Silver, etc.)
* **Months_on_book** : Period of relationship with bank in months
* **Total_Relationship_Count** : Total number of products held by the customer
* **Months_Inactive_12_mon** : Number of months inactive in the last 12 months
* **Contacts_Count_12_mon** : Number of contacts (phone calls) in the last 12 months
* **Credit_Limit** : Credit limit on the credit card
* **Total_Revolving_Bal** : Total revolving balance on the credit card
* **Avg_Open_To_Buy** : Open to buy credit line (average of last 12 months). This also turns out to be the difference between the credit limit (Credit_Limit) assigned to a cardholder account and the present balance on the account (Total_Revolving_Bal).
* **Total_Amt_Chng_Q4_Q1** : Change in total transactions amount (Q4 over Q1) 
* **Total_Trans_Amt** : Total transactions amount (last 12 months)
* **Total_Trans_Ct** : Total number of transactions (last 12 months)
* **Total_Ct_Chng_Q4_Q1** : Change in total number of transactions (Q4 over Q1)
* **Avg_Utilization_Ratio** : Average card utilization ratio
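Since *Avg_Open_To_Buy* is described as the difference between *Credit_Limit* and *Total_Revolving_Bal*, the identity can be checked directly. A minimal sketch follows; the helper `share_open_to_buy_identity()` and the toy rows are hypothetical, and on the real data it would be applied to `bank_data` after loading. Because the variable is described as a 12-month average, an exact match on every row is not guaranteed.

```r
# Hypothetical helper: share of rows where
# Avg_Open_To_Buy == Credit_Limit - Total_Revolving_Bal (within a tolerance,
# since the columns are stored as doubles).
share_open_to_buy_identity <- function(df, tol = 1e-6) {
  diff <- df$Credit_Limit - df$Total_Revolving_Bal - df$Avg_Open_To_Buy
  mean(abs(diff) < tol)
}

# Made-up rows for illustration only; the third row breaks the identity.
toy <- data.frame(
  Credit_Limit        = c(10000, 5000, 3000),
  Total_Revolving_Bal = c(1500, 0, 500),
  Avg_Open_To_Buy     = c(8500, 5000, 2000)
)
share_open_to_buy_identity(toy)   # 2 of 3 rows satisfy the identity
```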

# Data cleaning

In [None]:
# Import libraries
library(tidyverse)
library(forcats)
library(psych)
library(gridExtra)
library(rlang) 
install.packages("gghalves")
library(gghalves)
library(ggrepel)

In [None]:
# load data
bank_data_raw <- read_csv("../input/credit-card-customers/BankChurners.csv")

# --- Drop unused columns (account id and the two precomputed Naive Bayes columns) ---
bank_data_raw <- bank_data_raw %>% 
select(-c(
  CLIENTNUM,
  Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1,
  Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2
))

bank_data_raw_2 <- bank_data_raw

# --- Attrition_Flag encoding: 0 Existing Customer, 1 Attrited Customer ---
bank_data_raw_2 <- bank_data_raw_2 %>% mutate(Attrition_Flag = recode(Attrition_Flag, "Attrited Customer" = 1, "Existing Customer" = 0))

bank_data <- bank_data_raw_2

# --- Factor conversion and levels reordering ---

bank_data$Attrition_Flag <- as_factor(bank_data$Attrition_Flag)

bank_data$Gender <- as_factor(bank_data$Gender)

bank_data$Education_Level <- as_factor(bank_data$Education_Level)
bank_data$Education_Level <- fct_relevel(bank_data$Education_Level, "Unknown", "Uneducated", "High School", "College", "Graduate", "Post-Graduate", "Doctorate")

bank_data$Marital_Status <- as_factor(bank_data$Marital_Status)
bank_data$Marital_Status <- fct_relevel(bank_data$Marital_Status, "Unknown", "Single", "Married", "Divorced")

bank_data$Income_Category <- as_factor(bank_data$Income_Category)
bank_data$Income_Category <- fct_relevel(bank_data$Income_Category, "Unknown", "Less than $40K", "$40K - $60K", "$60K - $80K", "$80K - $120K", "$120K +")

bank_data$Card_Category <- as_factor(bank_data$Card_Category)
bank_data$Card_Category <- fct_relevel(bank_data$Card_Category, "Blue", "Silver", "Gold", "Platinum")

## Missing data

The dataset contains some **missing data**, encoded as "Unknown", in *Education_Level*, *Marital_Status* and *Income_Category*.

These values have been treated as **not imputable**, because they are unlikely to depend strongly on the other variables. Moreover, given the relatively low frequency of this level, imputation would not be expected to improve the results drastically.
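How frequent the "Unknown" level is in each column can be measured directly. A minimal sketch with a hypothetical helper `unknown_share()` on made-up rows; on the real data it would be called as `unknown_share(bank_data, c("Education_Level", "Marital_Status", "Income_Category"))`.

```r
# Hypothetical helper: percentage of "Unknown" values in each selected column.
# as.character() lets it work on both character and factor columns.
unknown_share <- function(df, cols) {
  sapply(df[cols], function(x) round(100 * mean(as.character(x) == "Unknown"), 2))
}

# Made-up rows for illustration only.
toy <- data.frame(
  Education_Level = c("Graduate", "Unknown", "High School", "Unknown"),
  Marital_Status  = c("Married", "Single", "Unknown", "Married"),
  Income_Category = c("Less than $40K", "$40K - $60K", "$60K - $80K", "$120K +")
)
unknown_share(toy, c("Education_Level", "Marital_Status", "Income_Category"))
```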

# Exploratory Data Analysis

Sociodemographic and service usage variables have been analyzed separately. In addition, a conditioned analysis has been conducted to better understand which variables are most relevant for separating churners from non-churners.

In [None]:
options(repr.plot.width = 14, repr.plot.height = 8)

gender_fill_scale <- scale_fill_manual(
  values = c("#4292c6", "#fb6a4a"), 
  breaks = c("M", "F"), 
  labels = c("M", "F")
) 

card_fill_scale <- scale_fill_manual(
  values = c("#08519c", "#bbbbbb", "#d3a826", "#f6f9fb"), 
  breaks = c("Blue", "Silver", "Gold", "Platinum"), 
  labels = c("Blue", "Silver", "Gold", "Platinum")
) 

attrition_fill_scale <- scale_fill_manual(
  values = c("#7FC97F", "#FDC086"), 
  breaks = c(0, 1), 
  labels = c("0", "1")
)

attrition_color_scale <- scale_color_manual(
  values = c("#7FC97F", "#FDC086"), 
  breaks = c(0, 1), 
  labels = c("0", "1")
)

# --- Histogram utility function ---
get_histogram <- function(data, var, binwidth = 30) {
  p <- data %>%
    ggplot() +
    geom_histogram(aes(x = get(var)), color = "black", fill = "#f6f9fb", binwidth = binwidth) +
    geom_vline(aes(xintercept = summary(get(var))[4], linetype="mean"), size=1.3) +
    geom_vline(aes(xintercept = summary(get(var))[3], linetype="median"), size=1.3) +
    theme_classic() +
    labs(
      x = var,
      y = "frequency",
      title = paste(var, "distribution"),
      linetype = "measure"
    )
  
  return(p)
}

# --- Box plot utility function ---
get_boxplot <- function(data, var) {
  data %>%
    ggplot() +
    geom_boxplot(aes(x = get(var)), alpha = 0.7, color = "black", fill = "#f6f9fb") +
    theme_classic() +
    labs(
      x = var,
      title = paste(var, "distribution")
    )
}

# --- Barplot with absolute and percentage frequencies utility function ---
get_var_freq_mixed <- function(data, var, palette) {
    
  p <- data %>%
    count(get(var)) %>%
    mutate(pct = n / sum(n),
           pctlabel = paste0(n, " (", round(pct*100, 2), "%)")) %>%
    ggplot(aes(x = factor(`get(var)`), y = n, fill = factor(`get(var)`))) + 
    geom_bar(stat = "identity", color = "black") +
    geom_text(aes(label = pctlabel), vjust = -0.25, size = 4) +
    theme_classic() +
    labs(
      x = var,
      y = "Absolute frequency",
      fill = var,
      title = paste(var, "distribution")
    )
  
  if(hasArg(palette)) {
    p <- p + scale_fill_brewer(palette=palette)
  }
  
  return(p)
}

## Sociodemographic variables

Six variables provide sociodemographic information: 
* *Customer_Age*
* *Gender*
* *Dependent_count*
* *Education_Level*
* *Marital_Status*
* *Income_Category*


### Customer_Age

In [None]:
summary(bank_data$Customer_Age)

In [None]:
sd(bank_data$Customer_Age) 

In [None]:
get_histogram(bank_data, var = "Customer_Age", binwidth = 1)

### Gender

In [None]:
get_var_freq_mixed(bank_data, var = "Gender") + gender_fill_scale

### Dependent_count

In [None]:
get_var_freq_mixed(bank_data, var = "Dependent_count", palette = "Blues")

### Education_Level


In [None]:
get_var_freq_mixed(bank_data, var = "Education_Level", palette = "Reds")

### Marital_Status

In [None]:
get_var_freq_mixed(bank_data, var = "Marital_Status", palette = "Spectral")

### Income_Category

In [None]:
get_var_freq_mixed(bank_data, var = "Income_Category", palette = "Greens")

## Service usage variables
Thirteen variables provide service usage information:

* *Card_Category*
* *Months_on_book*
* *Total_Relationship_Count*
* *Months_Inactive_12_mon*
* *Contacts_Count_12_mon*
* *Credit_Limit*
* *Total_Revolving_Bal*
* *Avg_Open_To_Buy*
* *Total_Amt_Chng_Q4_Q1*
* *Total_Trans_Amt*
* *Total_Trans_Ct*
* *Total_Ct_Chng_Q4_Q1*
* *Avg_Utilization_Ratio*

### Card_Category

In [None]:
get_var_freq_mixed(bank_data, var = "Card_Category") + card_fill_scale

### Months_on_book

In [None]:
summary(bank_data$Months_on_book)

In [None]:
sd(bank_data$Months_on_book) 

In [None]:
get_histogram(bank_data, var = "Months_on_book", binwidth = 1)

### Total_Relationship_Count

In [None]:
get_var_freq_mixed(bank_data, var = "Total_Relationship_Count", palette = "Blues")

### Months_Inactive_12_mon

In [None]:
get_var_freq_mixed(bank_data, var = "Months_Inactive_12_mon", palette = "Reds")

### Contacts_Count_12_mon

In [None]:
get_var_freq_mixed(bank_data, var = "Contacts_Count_12_mon", palette = "Blues")

### Credit_Limit

In [None]:
summary(bank_data$Credit_Limit)

In [None]:
IQR(bank_data$Credit_Limit)

In [None]:
get_histogram(bank_data, var = "Credit_Limit", binwidth = 1000)

In [None]:
get_boxplot(bank_data, var = "Credit_Limit")

### Total_Revolving_Bal

In [None]:
summary(bank_data$Total_Revolving_Bal)

In [None]:
get_histogram(bank_data, var = "Total_Revolving_Bal", binwidth = 50)

In [None]:
# 1st mode
bank_data %>%
    filter(Total_Revolving_Bal %in% c(0:50)) %>%
    count(Total_Revolving_Bal) %>%
    arrange(desc(n)) %>%
    head(1)

# 2nd mode
bank_data %>%
    filter(Total_Revolving_Bal %in% c(2000:2517)) %>%
    count(Total_Revolving_Bal) %>%
    arrange(desc(n)) %>%
    head(1)

### Avg_Open_To_Buy

In [None]:
summary(bank_data$Avg_Open_To_Buy)

In [None]:
IQR(bank_data$Avg_Open_To_Buy)

In [None]:
get_histogram(bank_data, var = "Avg_Open_To_Buy", binwidth = 1000)

In [None]:
get_boxplot(bank_data, var = "Avg_Open_To_Buy")

### Total_Amt_Chng_Q4_Q1

In [None]:
summary(bank_data$Total_Amt_Chng_Q4_Q1)

In [None]:
IQR(bank_data$Total_Amt_Chng_Q4_Q1)

In [None]:
get_histogram(bank_data, var = "Total_Amt_Chng_Q4_Q1", binwidth = 0.05)

In [None]:
get_boxplot(bank_data, var = "Total_Amt_Chng_Q4_Q1")

### Total_Trans_Amt

In [None]:
summary(bank_data$Total_Trans_Amt)

In [None]:
IQR(bank_data$Total_Trans_Amt)

In [None]:
get_histogram(bank_data, var = "Total_Trans_Amt", binwidth = 500)

In [None]:
get_boxplot(bank_data, var = "Total_Trans_Amt")

### Total_Trans_Ct

In [None]:
summary(bank_data$Total_Trans_Ct)

In [None]:
grid.arrange(
    get_histogram(bank_data, var = "Total_Trans_Ct", binwidth = 5),
    get_histogram(bank_data %>% filter(Total_Trans_Ct %in% c(30:90)), var = "Total_Trans_Ct", binwidth = 1)
)

In [None]:
# 1st mode
bank_data %>%
    filter(Total_Trans_Ct %in% c(0:50)) %>%
    count(Total_Trans_Ct) %>%
    arrange(desc(n)) %>%
    head(1)

# 2nd mode
bank_data %>%
    filter(Total_Trans_Ct %in% c(50:100)) %>%
    count(Total_Trans_Ct) %>%
    arrange(desc(n)) %>%
    head(1)

### Total_Ct_Chng_Q4_Q1

In [None]:
summary(bank_data$Total_Ct_Chng_Q4_Q1)

In [None]:
IQR(bank_data$Total_Ct_Chng_Q4_Q1)

In [None]:
get_histogram(bank_data, var = "Total_Ct_Chng_Q4_Q1", binwidth = 0.05)

In [None]:
get_boxplot(bank_data, var = "Total_Ct_Chng_Q4_Q1")

### Avg_Utilization_Ratio

In [None]:
summary(bank_data$Avg_Utilization_Ratio)

In [None]:
get_histogram(bank_data, var = "Avg_Utilization_Ratio", binwidth = 0.03)

## Conditioned Analysis

To better understand which variables are most relevant for separating churners from non-churners, each variable has been conditioned on the response variable *Attrition_Flag* (1 = churn, 0 = no churn).
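Conditioning on *Attrition_Flag* amounts to computing the distribution of each variable separately among churners and non-churners, i.e. column-wise percentages of a contingency table. The `get_var_freq_by_attrition()` helper does this with dplyr; an equivalent base R sketch on made-up vectors uses `prop.table()` with `margin = 2`:

```r
# Made-up vectors for illustration only.
gender    <- c("F", "F", "M", "M", "M", "F", "F", "M")
attrition <- c(0, 0, 0, 0, 1, 1, 1, 0)

# Rows: Gender; columns: Attrition_Flag.
counts <- table(Gender = gender, Attrition_Flag = attrition)

# margin = 2 divides each column by its own total, so each column sums to 100%:
# the result is the Gender distribution within each Attrition_Flag class.
cond_pct <- round(100 * prop.table(counts, margin = 2), 2)
cond_pct
```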

In [None]:
options(repr.plot.width = 20, repr.plot.height = 10)

# Percentage bar plot conditioned on Attrition_Flag helper function
get_var_freq_by_attrition <- function(data, var, palette) {
    p <- data %>%
      count(Attrition_Flag, get(var)) %>%
      group_by(Attrition_Flag) %>%
      mutate(pct = n / sum(n),
             pctlabel = paste0(round(pct*100, 2), "%")) %>%
      ggplot(aes(x = factor(`get(var)`), y = pct, fill = factor(`get(var)`))) + 
      geom_bar(stat = "identity", color = "black") +
      geom_text(aes(label = pctlabel), vjust = -0.20) +
      theme_classic() +
      facet_wrap(~ Attrition_Flag, nrow = 2) +
      labs(
        x = var,
        y = "Percentage frequency",
        fill = var,
        title = paste("Percentage frequencies of", var, "by Attrition_Flag")
      ) 

  if(hasArg(palette)) {
    p <- p + scale_fill_brewer(palette=palette)
  }
  
  return(p)
}

# Box plot conditioned on Attrition_Flag helper function
get_boxplot_by_attrition <- function(data, var) {
  data %>%
    ggplot() +
    geom_boxplot(aes(y = Attrition_Flag, x = get(var), fill = Attrition_Flag), outlier.size = 0.5) +
    attrition_fill_scale +
    theme_classic() +
    labs(
      x = var,
      title = paste(var, "by Attrition_Flag")
    )
}

get_boxviolinplot_by_attrition <- function(data, var) {
  p <- data %>%
    ggplot() +
    geom_half_violin(aes(x = Attrition_Flag, y = get(var), fill = Attrition_Flag), alpha = 0.3, colour = "grey", side = "r", width = 1.3) +
    geom_half_boxplot(aes(x = Attrition_Flag, y = get(var)), width = 0.3, outlier.size = 0.) +
    attrition_fill_scale +
    attrition_color_scale +
    theme_classic() +
    coord_flip() +
    labs(
      y = var,
      title = paste(var, "by Attrition_Flag")
    )
  
  return(p)
}

### Discrete variables

In [None]:
grid.arrange(
    ncol = 3, 
    get_var_freq_by_attrition(bank_data, var = "Gender") + gender_fill_scale,
    get_var_freq_by_attrition(bank_data, var = "Dependent_count", palette = "Blues"),
    get_var_freq_by_attrition(bank_data, var = "Education_Level", palette = "Reds")
)

In [None]:
grid.arrange(
    ncol = 3, 
    get_var_freq_by_attrition(bank_data, var = "Marital_Status", palette = "Spectral"),
    get_var_freq_by_attrition(bank_data, var = "Income_Category", palette = "Greens"),
    get_var_freq_by_attrition(bank_data, var = "Card_Category") + card_fill_scale
)

In [None]:
grid.arrange(
    ncol = 3, 
    get_var_freq_by_attrition(bank_data, var = "Total_Relationship_Count", palette = "Blues"),
    get_var_freq_by_attrition(bank_data, var = "Months_Inactive_12_mon", palette = "Reds"),
    get_var_freq_by_attrition(bank_data, var = "Contacts_Count_12_mon", palette = "Blues")
)

### Continuous variables

In [None]:
grid.arrange(
    nrow = 2,
    get_boxplot_by_attrition(bank_data, "Customer_Age"),
    get_boxplot_by_attrition(bank_data, "Months_on_book"),
    get_boxplot_by_attrition(bank_data, "Credit_Limit"),
    get_boxplot_by_attrition(bank_data, "Total_Revolving_Bal"),
    get_boxplot_by_attrition(bank_data, "Avg_Open_To_Buy"),
    get_boxplot_by_attrition(bank_data, "Total_Amt_Chng_Q4_Q1"),
    get_boxplot_by_attrition(bank_data, "Total_Trans_Amt"),
    get_boxplot_by_attrition(bank_data, "Total_Trans_Ct"),
    get_boxplot_by_attrition(bank_data, "Total_Ct_Chng_Q4_Q1"),
    get_boxplot_by_attrition(bank_data, "Avg_Utilization_Ratio")
)

In [None]:
suppressWarnings(
    grid.arrange(
        nrow = 2,
        get_boxviolinplot_by_attrition(bank_data, "Customer_Age"),
        get_boxviolinplot_by_attrition(bank_data, "Months_on_book"),
        get_boxviolinplot_by_attrition(bank_data, "Credit_Limit"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Revolving_Bal"),
        get_boxviolinplot_by_attrition(bank_data, "Avg_Open_To_Buy"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Amt_Chng_Q4_Q1"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Trans_Amt"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Trans_Ct"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Ct_Chng_Q4_Q1"),
        get_boxviolinplot_by_attrition(bank_data, "Avg_Utilization_Ratio")
    )
)

In [None]:
# continuous variables subset
grid.arrange(
    nrow = 2,
    get_boxplot_by_attrition(bank_data, "Total_Trans_Ct"),
    get_boxplot_by_attrition(bank_data, "Total_Trans_Amt"),
    get_boxplot_by_attrition(bank_data, "Total_Ct_Chng_Q4_Q1"),
    get_boxplot_by_attrition(bank_data, "Total_Revolving_Bal")
)

In [None]:
# continuous variables subset
suppressWarnings(
    grid.arrange(
        nrow = 2,
        get_boxviolinplot_by_attrition(bank_data, "Total_Trans_Ct"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Trans_Amt"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Ct_Chng_Q4_Q1"),
        get_boxviolinplot_by_attrition(bank_data, "Total_Revolving_Bal")
    )
)

## EDA summary

**Sociodemographic** variables:
* Customers' age ranges from 26 to 73 years (mean 46.33, standard deviation 8.017).
* 47.09% of customers are male, whereas 52.91% are female.
* More than 91% of customers have at least one dependent (a person who depends on them for support and welfare), and more than 50% have two or three.
* The most frequent qualifications are "Graduate" (30.89%) and "High School" (19.88%); 15% of observations have an "Unknown" education level.
* The most frequent marital status is "Married" (46.28%), followed by "Single" (38.94%). Only 7.4% of customers are divorced.
* More than 50% of customers have an income higher than 40K dollars, although the most frequent single income bracket is "Less than $40K".

Most relevant **service usage** variables:
* The period of relationship with the bank ranges from 13 to 56 months (mean 35.93, standard deviation 7.99).
* More than 92% of customers have been inactive for between 1 and 3 months in the last 12 months. Only 29 customers (0.29%) have had no inactive month in the last 12 months.
* More than 78% of customers hold 3 or more products. Only 8.99% hold a single product.
* The total number of transactions in the last 12 months ranges from 10 to 139, with two modes at 43 and 81 (the more frequent one).
* The total transaction amount in the last 12 months ranges from 510 to 18484 dollars, with a median of about 3900 dollars and first and third quartiles of 2156 and 4741 dollars.
* The total revolving balance on the credit card ranges from 0 to 2517 dollars; these two extreme values are also the modes.

Variables with an observable **separation** for the response variable:
* *Total_Relationship_Count*, *Months_Inactive_12_mon* and *Contacts_Count_12_mon* show some percentage differences when conditioned on Attrition_Flag
* *Total_Trans_Ct*, *Total_Trans_Amt*, *Total_Ct_Chng_Q4_Q1* and *Total_Revolving_Bal* distributions show a separation when conditioned to Attrition_Flag

# Predictive Analysis

The problem has been framed as a **binary classification** task, since the response (*Attrition_Flag*) is dichotomous. A 75/25 train-test split has been used, and the [other 19 variables](#Dataset-description) have been used as explanatory variables. K-fold cross-validation has been used to evaluate the models during training.

In [None]:
set.seed(123)
sel <- sample(1: nrow(bank_data), size = (nrow(bank_data)*75)/100, replace = FALSE)

# train - test split (75% - 25%)
train <- bank_data[sel, ]
test <- bank_data[-sel, ]

In [None]:
table(bank_data$Attrition_Flag)
table(train$Attrition_Flag)
table(test$Attrition_Flag)
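Because the split above is purely random (not stratified), it is worth checking that the churner share is similar in the training and test sets; the `table()` calls above show this as counts. A minimal sketch of the same kind of split reporting percentages, using a hypothetical helper `split_churn_rates()` and a made-up response; `floor()` makes the integer training-set size explicit.

```r
# Hypothetical helper: random 75/25 split of row indices, reporting the
# churner share (level "1") in the full data and in each part, as percentages.
split_churn_rates <- function(y, train_frac = 0.75, seed = 123) {
  set.seed(seed)
  n <- length(y)
  sel <- sample(seq_len(n), size = floor(n * train_frac), replace = FALSE)
  c(
    n_train   = length(sel),
    n_test    = n - length(sel),
    full_pct  = 100 * mean(y == "1"),
    train_pct = 100 * mean(y[sel] == "1"),
    test_pct  = 100 * mean(y[-sel] == "1")
  )
}

# Made-up response with 16% churners, for illustration only.
y_toy <- factor(rep(c(0, 1), times = c(84, 16)))
split_churn_rates(y_toy)
```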

## Evaluation metrics
When facing a classification problem, a **confusion matrix** is generally used to derive evaluation metrics such as [accuracy](https://en.wikipedia.org/wiki/Precision_and_recall). 

However, it is important to underline that the present dataset is heavily **unbalanced** (only 16.07% of observations are churners). 
In such a case, accuracy can be a misleading metric: for example, a trivial model that always predicts "not churner" would reach 83.93% accuracy.

A more robust approach is to use the [recall](https://en.wikipedia.org/wiki/Precision_and_recall) (or sensitivity) and [precision](https://en.wikipedia.org/wiki/Precision_and_recall) (or positive predictive value, PPV) metrics. However, these two metrics usually pull against each other (increasing one tends to decrease the other), which forces a **trade-off** that depends on the context. In the present case, a low recall means that only a few actual churners are identified as churners, whereas a low precision means that only a few customers predicted as churners are actually churners. Not all trade-offs are equally good (e.g. a small increase in recall could lead to a much larger decrease in precision), so a single measure that balances the two is useful.

To evaluate both metrics at once, they can be aggregated through their harmonic mean, the [F1-score](https://en.wikipedia.org/wiki/F-score). To achieve the main goal while getting the best of both metrics, the proposed approach is **maximizing F1**. However, for equal F1 values, **higher recall is preferable** to higher precision: the bank is likely more interested in identifying as many churners as possible, at the risk of wrongly classifying some non-churners as churners and making them unnecessary offers.
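The definitions above can be made concrete by computing the metrics by hand from the four cells of a confusion matrix, with churn ("1") as the positive class. The sketch uses a hypothetical helper `metrics_from_counts()` and made-up counts (1000 customers, 120 churners); the second call shows why accuracy is misleading on unbalanced data: the trivial "always non-churner" model reaches 88% accuracy on these counts with zero recall.

```r
# Hypothetical helper: classification metrics from the cells of a 2x2 confusion
# matrix (tp = true positives, fp = false positives, fn = false negatives,
# tn = true negatives), with churn as the positive class.
metrics_from_counts <- function(tp, fp, fn, tn) {
  recall    <- tp / (tp + fn)   # share of actual churners that are detected
  precision <- tp / (tp + fp)   # share of predicted churners that really churn
  f1        <- 2 * precision * recall / (precision + recall)   # harmonic mean
  c(
    Specificity = tn / (tn + fp),
    Recall      = recall,
    Precision   = precision,
    F1          = f1,
    Accuracy    = (tp + tn) / (tp + fp + fn + tn)
  )
}

# Made-up counts: a reasonable classifier...
metrics_from_counts(tp = 80, fp = 20, fn = 40, tn = 860)

# ...and "always predict non-churner": high accuracy, zero recall, and an
# undefined (NaN) precision because no churner is ever predicted.
metrics_from_counts(tp = 0, fp = 0, fn = 120, tn = 880)
```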

In [None]:
options(repr.plot.width = 14, repr.plot.height = 8)
get_var_freq_mixed(bank_data, "Attrition_Flag") + attrition_fill_scale

In [None]:
# Classification metrics helper function
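classification_metrics <- function(confusion_matrix) {
  # caret stores Specificity, Recall, Precision and F1 in byClass and Accuracy
  # in overall; selecting them by name avoids relying on element positions
  by_class <- confusion_matrix$byClass[c("Specificity", "Recall", "Precision", "F1")]
  accuracy <- confusion_matrix$overall["Accuracy"]

  metrics <- data.frame(t(c(by_class, accuracy)))
  names(metrics) <- c("Specificity", "Recall", "Precision", "F1", "Accuracy")

  return(metrics)
}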

## Decision Trees

In [None]:
library(rpart)
library(rpart.plot)
library(caret)

### Full Tree
A classification tree has been grown with complexity parameter cp = 0.001. The resulting tree has 63 leaves.

In [None]:
set.seed(123)

# tree training
tree_0 <- rpart(Attrition_Flag ~ ., data = train, method = "class", cp=0.001)

# predict on test set
test_pred_tree_0<- predict(tree_0, newdata = test, type = "class") 

# get confusion matrix
conf_matrix_tree_0 <- caret::confusionMatrix(test_pred_tree_0, test$Attrition_Flag, positive = "1", mode = "everything")
conf_matrix_tree_0$table

# get classification metrics
(metrics_tree_0 <- classification_metrics(conf_matrix_tree_0))

F1, recall and precision are already good, although precision is higher than recall. As expected from the unbalanced nature of the dataset, accuracy is quite high.
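The number of leaves quoted for the full tree can be read off a fitted `rpart` object: in its `frame` component, terminal nodes have `var == "<leaf>"`. A minimal sketch, using rpart's built-in `kyphosis` data only so that the example is self-contained; for the model above the call would be `count_leaves(tree_0)`.

```r
library(rpart)

# Hypothetical helper: count terminal nodes of an rpart tree, i.e. the rows of
# fit$frame whose splitting variable is "<leaf>".
count_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

# Illustration on rpart's built-in kyphosis data with a small cp,
# mirroring the cp = 0.001 used for the full tree above.
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             method = "class", cp = 0.001)
count_leaves(fit)
```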

### Pruned Tree
The previous tree can be pruned using the "one-standard-error" rule. The selected tree has 23 leaves (cp = 0.0046).

In [None]:
options(repr.plot.width = 16, repr.plot.height = 8)
plotcp(tree_0)
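The cp value used below was read off the `plotcp()` chart, whose horizontal line is drawn one standard error above the minimum cross-validated error. The same choice can be sketched programmatically from the `cptable` of a fitted tree. The helper `one_se_cp()` below is hypothetical and illustrated on the built-in `kyphosis` data; applied here it would be `prune(tree_0, cp = one_se_cp(tree_0))`. Since `plotcp()` labels its x axis with geometric means of adjacent CP values, and cross-validation folds are random, the exact cp number may differ from one read off the plot.

```r
library(rpart)

# Hypothetical helper for the one-standard-error rule: take the minimum
# cross-validated error (xerror), add its standard error (xstd), and return the
# CP of the first row (rows are ordered by increasing tree size, so this is the
# smallest tree) whose xerror is within that threshold.
one_se_cp <- function(fit) {
  cpt <- fit$cptable
  best <- which.min(cpt[, "xerror"])
  threshold <- cpt[best, "xerror"] + cpt[best, "xstd"]
  cpt[which(cpt[, "xerror"] <= threshold)[1], "CP"]
}

# xerror depends on the random cross-validation folds, hence the seed.
set.seed(123)
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             method = "class", cp = 0.001)
cp_1se <- one_se_cp(fit)
pruned <- prune(fit, cp = cp_1se)
```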

In [None]:
# tree pruning
tree_1 <- prune(tree_0, cp=0.0046)

# predict on test set
test_pred_tree_1<- predict(tree_1, newdata = test, type = "class") 

# get confusion matrix
conf_matrix_tree_1 <- caret::confusionMatrix(test_pred_tree_1, test$Attrition_Flag, positive = "1", mode = "everything")
conf_matrix_tree_1$table

# get classification metrics
(metrics_tree_1 <- classification_metrics(conf_matrix_tree_1))

The resulting tree is simpler than the original one and also performs slightly better.

### Prior Tree
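One way to counter class imbalance is to specify the class priors explicitly. Here the observed class frequencies (about 84% non-churners and 16% churners) are swapped: `prior` follows the factor level order ("0", "1"), so `c(0.16, 0.84)` gives churners a prior of 0.84.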

In [None]:
set.seed(123)

# tree training with swapped class priors (levels "0" and "1" get 0.16 and 0.84)
tree_2 <- rpart(Attrition_Flag ~ ., data = train, method = "class", parms=list(prior=c(0.16, 0.84)))

# predict on test set
test_pred_tree_2 <- predict(tree_2, newdata = test, type = "class") 

# get confusion matrix
conf_matrix_tree_2 <- confusionMatrix(data = test_pred_tree_2, reference = test$Attrition_Flag, positive = "1")
conf_matrix_tree_2$table

# get classification metrics
(metrics_tree_2 <- classification_metrics(conf_matrix_tree_2))

The performance shows a drastic increase in recall and a correspondingly large decrease in precision (and F1). Despite the very high recall, the model suffers from **very low precision**, which would often lead to wrongly classifying a non-churner as a churner (65% of the customers predicted as churners are actually non-churners). 

Cost-sensitive techniques allow a finer tuning of the precision-recall trade-off.

### Cost-sensitivity optimization
To achieve a recall higher than precision, the cost of misclassification errors can be taken into account with [**Cost-Sensitive Learning**](https://machinelearningmastery.com/cost-sensitive-learning-for-imbalanced-classification/):
> In cost-sensitive learning instead of each instance being either correctly or incorrectly classified, each class (or instance) is given a misclassification cost. Thus, instead of trying to optimize the accuracy, the problem is then to minimize the total misclassification cost.
Most classifiers assume that the misclassification costs (false negative and false positive cost) are the same. In most real-world applications, this assumption is not true.

In the present case, a higher misclassification cost should be given to the minority class (1 = churn). This can be achieved by specifying an asymmetric loss matrix that assigns a higher cost to False Negatives. This cost can then be tuned to select the best precision-recall trade-off:

In [None]:
# misclassification costs (False Negatives) from 1 to 4 in steps of 0.1
costs_tree<-seq(1,4,0.1)

# optimization results
optimization_results_tree<-data.frame(matrix(0, nrow=length(costs_tree), ncol=6))
names(optimization_results_tree)<-c("Cost", "Specificity", "Recall", "Precision", "F1", "Accuracy")

# create a tree for each i-th cost 
for(i in seq_along(costs_tree)) {
  
    set.seed(123)

    # i-th tree with i-th cost
    tree_opt<-rpart(Attrition_Flag ~ ., data = train, method = "class", parms=list(loss=c(0,costs_tree[i],1,0)))

    # predict on test set
    tree_opt_pred<- predict(tree_opt, newdata = test, type = "class") 

    # get confusion matrix
    conf_matrix_tree_opt <- caret::confusionMatrix(tree_opt_pred, test$Attrition_Flag, positive = "1", mode = "everything")

    # classification metrics
    optimization_results_tree[i,1]<-costs_tree[i]
    optimization_results_tree[i, 2:6] <- classification_metrics(conf_matrix_tree_opt)
  
}

# Recall optimization results
optimization_results_tree
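In the loop above, the loss is passed as the vector `c(0, cost, 1, 0)`, which is turned into a 2x2 matrix filled column by column. Assuming rpart's convention that rows are the true class and columns the predicted class (in factor level order "0", "1"), the cost lands on the False Negative cell. A small sketch with a hypothetical helper makes the layout explicit:

```r
# Hypothetical helper: the 2x2 loss matrix built from c(0, cost, 1, 0).
# matrix() fills column by column, so `cost` ends up in row 2, column 1.
make_loss_matrix <- function(cost) {
  matrix(c(0, cost, 1, 0), nrow = 2,
         dimnames = list(true = c("0", "1"), predicted = c("0", "1")))
}

L <- make_loss_matrix(1.3)
L
# Assuming rows = true class and columns = predicted class:
# L["1", "0"] = 1.3 is the cost of a False Negative (churner predicted as "0"),
# L["0", "1"] = 1   is the cost of a False Positive.
```

Passing such a matrix directly, e.g. `parms = list(loss = make_loss_matrix(1.3))`, should be equivalent to the vector form used in the notebook.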

In [None]:
optimization_results_tree_plot_data <- optimization_results_tree %>%
    pivot_longer(cols = c(-Cost), names_to = "metric", values_to = "value")

optimization_results_tree_plot_data %>%
    ggplot() +
    geom_line(aes(x = Cost, y = value, color = metric), size = 1, alpha = 0.4) +
    scale_x_continuous(breaks = scales::pretty_breaks(n = 20)) +
    scale_y_continuous(breaks = scales::pretty_breaks(n = 20)) +
    geom_text_repel(
      # round() guards against floating-point drift in the values produced by seq(1, 4, 0.1)
      data = optimization_results_tree_plot_data %>% filter(round(Cost, 1) %in% c(1.0, 1.3, 1.6, 2.0, 2.7, 3.8)) %>% mutate(value = round(value, 3)), 
      aes(x = Cost, y = value, label = value, color = metric),
      direction = "y",
      show.legend = F
    ) +  
    theme_classic() +
    scale_color_brewer(palette="Set1") +
    labs(
      x = "Cost",
      y = "Value",
      color = "Metric",
      title = "Metrics trends by varying the misclassification cost - classification tree"
    )

In [None]:
  optimization_results_tree_indexes_f1 <- c(7,11,18,29)
  optimization_results_tree_indexes_f1_selected <- c(4)
  
  optimization_results_tree_results_pr <- optimization_results_tree
  optimization_results_tree_results_pr <- optimization_results_tree_results_pr %>% 
    mutate(selected = if_else(row_number() %in% optimization_results_tree_indexes_f1_selected, 1, 0))
  
  optimization_results_tree_results_pr %>%
    ggplot() +
    geom_point(aes(x = Recall, y = Precision, color = F1), size = 4) +
    geom_line(aes(x = Recall, y = Precision), size = 0.8, alpha = 0.15) +
    geom_label_repel(
      data = optimization_results_tree_results_pr %>% 
        filter(row_number() %in% c(optimization_results_tree_indexes_f1, optimization_results_tree_indexes_f1_selected)), 
      aes(
        x = Recall, 
        y = Precision, 
        label = paste0("c = ", Cost), 
        fill = factor(selected)
      ), 
      size = 4
    ) +
    theme_classic() +
    scale_colour_gradientn(colours = c("#ff0000","#ff9c00","#fffca6"),
                           values = c(1.0,0.8,0))  +
    scale_x_continuous(breaks = scales::pretty_breaks(n = 10)) +
    scale_y_continuous(breaks = scales::pretty_breaks(n = 10)) +
    scale_fill_manual(
      values = c("grey", "white"), 
      breaks = c(1, 0), 
      labels = c(1, 0)
    ) +
    labs(
      x = "Recall",
      y = "Precision",
      title = "Recall, Precision and F1 trends by varying the misclassification cost - classification tree"
    ) +
    guides(fill = "none")

In [None]:
# Selected procedures
optimization_results_tree[c(4,11,18,29),]

The configurations shown above are the most relevant ones obtained from the cost-sensitivity optimization.
The configuration with cost 1.3 has the highest F1 while keeping recall higher than precision. The configurations with costs 2.0, 2.7 and 3.8 offer higher recall but lower precision (and F1). 

In [None]:
set.seed(123)

# tree with cost 1.3
tree_opt_1 <- rpart(Attrition_Flag ~ ., data = train, method = "class", parms=list(loss=c(0,1.3,1,0)))
tree_opt_1_pred <- predict(tree_opt_1, newdata = test, type = "class") 
conf_matrix_tree_opt_1 <- caret::confusionMatrix(tree_opt_1_pred, test$Attrition_Flag, positive = "1", mode = "everything")
metrics_tree_opt_1 <- classification_metrics(conf_matrix_tree_opt_1)

A quick comparison with the **Pruned Tree** shows that a cost equal to 1.3 gives the **best precision-recall trade-off**, with comparable F1 and **higher recall**.

In [None]:
# Pruned Tree
conf_matrix_tree_1$table
metrics_tree_1

In [None]:
# Tree with cost 1.3
conf_matrix_tree_opt_1$table
metrics_tree_opt_1

### Variables importance

In [None]:
options(repr.plot.width = 18, repr.plot.height = 8)

tree_opt_1_varimp <- data.frame(tree_opt_1$variable.importance)
tree_opt_1_varimp$variable <- rownames(tree_opt_1_varimp)
names(tree_opt_1_varimp) <- c("importance", "variable")

tree_opt_1_varimp %>%
    ggplot(aes(x = importance, y = reorder(variable, importance))) +
    geom_bar(aes(fill = importance), stat = "identity", color = "black") +
    geom_text(aes(label = round(importance, 3)), vjust = 0.20, hjust = -0.20) +
    scale_fill_gradient(low = "#c7e9c0", high = "#00441b") +
    theme_classic() +
    labs(
      title = "Variables importance - Classification tree",
      x = "Importance",
      y = "Variable",
      fill = "Importance"
    ) + guides(fill = "none")

The **most important variables** obtained from the model **confirm the EDA**: *Total_Trans_Ct*, *Total_Trans_Amt* and *Total_Revolving_Bal* are the most important ones. *Total_Ct_Chng_Q4_Q1* and *Total_Relationship_Count* also appear in the top 6.
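The raw `variable.importance` values are on an arbitrary scale. `summary()` on an rpart fit reports them rescaled to sum to 100, which also makes trees fitted with different costs easier to compare. A minimal sketch of that rescaling, with a hypothetical helper shown on the built-in `kyphosis` data; for the selected tree it would be `scaled_importance(tree_opt_1)`.

```r
library(rpart)

# Hypothetical helper: rescale rpart's variable.importance so the values sum
# to 100 (up to rounding).
scaled_importance <- function(fit) {
  imp <- fit$variable.importance
  round(100 * imp / sum(imp), 1)
}

fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")
scaled_importance(fit)
```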

### Decision trees summary
The most relevant decision tree performances are set out below. 

The tree with a cost equal to 1.3 gives the best performance (high F1 and recall higher than precision).

In [None]:
comparison_tree_opt <- data.frame(matrix(nrow = 7, ncol = 6))
names(comparison_tree_opt) <- c("Cost", "Specificity", "Recall", "Precision", "F1", "Accuracy")

comparison_tree_opt[1,1] <- "Full Tree"
comparison_tree_opt[1,2:6] <- metrics_tree_0

comparison_tree_opt[2,1] <- "Pruned Tree"
comparison_tree_opt[2,2:6] <- metrics_tree_1

comparison_tree_opt[3,1] <- "Prior Tree"
comparison_tree_opt[3,2:6] <- metrics_tree_2

comparison_tree_opt[4,] <- optimization_results_tree[4, ]
comparison_tree_opt[5,] <- optimization_results_tree[11, ]
comparison_tree_opt[6,] <- optimization_results_tree[18, ]
comparison_tree_opt[7,] <- optimization_results_tree[29, ]

comparison_tree_opt

## Gradient Boosting
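A gradient boosting model can improve overall performance by building trees sequentially. Each new tree is fitted to the errors (the negative gradient of the loss) of the current ensemble, so later trees concentrate on the observations that are still poorly predicted. This can help predictions for the minority class in an imbalanced dataset such as the present one. 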
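
The following base-R sketch makes the idea concrete. It is illustrative only and is not how gbm is implemented internally. It boosts depth-1 trees on a single feature with the Bernoulli log-loss: each tree is fitted to the residuals y - p of the current model and added with a shrinkage factor. gbm additionally subsamples the training data (`bag.fraction`) and uses a Newton step for the leaf values. All function names below are hypothetical:

```r
# Simplified gradient boosting for a 0/1 response (Bernoulli log-loss) with
# depth-1 trees (stumps) on one numeric feature. Leaf values are the mean
# residual in each leaf (least-squares fit to the negative gradient).
sigmoid <- function(f) 1 / (1 + exp(-f))

# Best single split of x, minimising the squared error of the residuals r
fit_stump <- function(x, r) {
  cuts <- sort(unique(x))
  cuts <- cuts[-length(cuts)]            # each side of a split must be non-empty
  best <- list(sse = Inf)
  for (s in cuts) {
    left <- x <= s
    m_left <- mean(r[left])
    m_right <- mean(r[!left])
    sse <- sum((r[left] - m_left)^2) + sum((r[!left] - m_right)^2)
    if (sse < best$sse) best <- list(sse = sse, cut = s, left = m_left, right = m_right)
  }
  best
}

predict_stump <- function(stump, x) ifelse(x <= stump$cut, stump$left, stump$right)

boost_bernoulli <- function(x, y, n_trees = 100, shrinkage = 0.1) {
  f0 <- qlogis(mean(y))                  # start from the overall log-odds
  f <- rep(f0, length(y))
  stumps <- vector("list", n_trees)
  deviance <- numeric(n_trees)
  for (b in seq_len(n_trees)) {
    r <- y - sigmoid(f)                  # negative gradient of the log-loss w.r.t. f
    stumps[[b]] <- fit_stump(x, r)       # new tree fitted to the current errors
    f <- f + shrinkage * predict_stump(stumps[[b]], x)
    p <- sigmoid(f)
    deviance[b] <- -2 * mean(y * log(p) + (1 - y) * log(1 - p))
  }
  list(f0 = f0, stumps = stumps, shrinkage = shrinkage, deviance = deviance)
}

predict_boost <- function(fit, x) {
  contributions <- lapply(fit$stumps, function(s) fit$shrinkage * predict_stump(s, x))
  sigmoid(fit$f0 + Reduce(`+`, contributions))
}

# Toy usage on simulated data (not the bank dataset)
set.seed(1)
x <- runif(200)
y <- rbinom(200, 1, sigmoid(6 * (x - 0.5)))
fit <- boost_bernoulli(x, y, n_trees = 50, shrinkage = 0.1)
```

The training deviance in `fit$deviance` decreases with every added tree. The shrinkage factor controls how large each step is.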

The model takes three tuning parameters as input: **𝐵** (n.trees), the number of trees; **𝜆** (shrinkage), the learning rate; and **𝑑** (interaction.depth), the maximum depth of each tree. To favour recall, a fourth parameter **𝑐** (cost) has been introduced. Churners receive an observation weight equal to **𝑐** (non-churners keep weight 1), which makes False Negatives more expensive during training. A grid of values has been defined for **𝜆**, **𝑑** and **𝑐**, and a boosting procedure has been trained for every combination (3 × 4 × 6 = 72 procedures). **𝐵** has been fixed at 5000 for every procedure and then tuned afterwards, by choosing the number of trees that minimizes the 5-fold cross-validation error. Stratified cross-validation has been used so that every fold preserves the class proportions of the training set, which matters in an imbalanced dataset.
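
The idea behind stratified folds can be sketched in a few lines of base R. This is only an illustration of the principle: gbm's `class.stratify.cv = TRUE` performs its own internal fold assignment, and the helper below is hypothetical. Observations are assigned to folds separately within each class:

```r
# Assign each observation to one of k folds *within its own class*, so every
# fold keeps approximately the class proportions of y. Note: calls set.seed().
stratified_folds <- function(y, k = 5, seed = 1) {
  set.seed(seed)
  folds <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    # spread the class members evenly over 1..k, then shuffle within the class
    folds[idx] <- sample(rep_len(seq_len(k), length(idx)))
  }
  folds
}

# Toy usage: 850 non-churners and 150 churners (illustrative counts)
y_toy <- c(rep(0, 850), rep(1, 150))
folds <- stratified_folds(y_toy, k = 5)
table(folds, y_toy)   # every fold: 170 zeros and 30 ones
```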

In [None]:
library(gbm)

In [None]:
# gbm's "bernoulli" distribution requires a numeric 0/1 response, not a factor
train_boost <- train
test_boost <- test

# convert factor labels to numeric 0/1 (via as.character, so the result does not depend on the level order)
train_boost$Attrition_Flag <- as.numeric(as.character(train_boost$Attrition_Flag))
test_boost$Attrition_Flag <- as.numeric(as.character(test_boost$Attrition_Flag))

table(train_boost$Attrition_Flag)
table(test_boost$Attrition_Flag)
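
Converting a factor response to 0/1 has a common pitfall: `as.numeric()` on a factor returns the internal level codes, not the labels. The toy vectors below, which are not the bank data, show why going through `as.character()` is the safer choice:

```r
# as.numeric() on a factor returns level codes (1, 2, ...), not the labels
flag <- factor(c("0", "1", "1", "0"), levels = c("0", "1"))
as.numeric(flag)                   # 1 2 2 1  (level codes)
as.numeric(flag) - 1               # 0 1 1 0  (correct only because levels are "0", "1")
as.numeric(as.character(flag))     # 0 1 1 0  (label-based)

# With the levels in the opposite order, subtracting 1 silently flips the classes
flag_rev <- factor(c("0", "1", "1", "0"), levels = c("1", "0"))
as.numeric(flag_rev) - 1           # 1 0 0 1  (wrong)
as.numeric(as.character(flag_rev)) # 0 1 1 0  (still correct)
```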

In [None]:
# metrics names
metrics_names_boost <- c(
  "cost",
  "best.iteration", 
  "shrinkage",  
  "interaction.depth",
  "Specificity",
  "Recall",
  "Precision",
  "F1",
  "Accuracy"
)

# gbm training helper. It takes the training set, the test set and the tuning parameters as input
# and returns a list containing the trained model, its test-set predictions and its performance.
train_boosting_with_params <- function(train, test, shrinkage, interaction.depth, cost, n.trees) {
  
  boost_model_opt_proc <- list()
  
  set.seed(1)
  
  # gbm training
  boost_model_opt_proc[["model"]] <-gbm(
      
    Attrition_Flag~., 
    distribution = "bernoulli",
    
    data=train,                                # use the training set
    train.fraction = 1,                        # use the whole training set
    cv.folds = 5,                              # k-fold cross validation with 5 folds
    class.stratify.cv = TRUE,                  # stratified cross validation
    
    # parameters ---------------------------------------------------------
    shrinkage = shrinkage,                    
    interaction.depth = interaction.depth,     
    n.trees = n.trees,                        
    
    weights = ifelse(train$Attrition_Flag == 1, cost, 1) # misclassification cost: churners get weight `cost`
    
  )
  
  # get best iteration (n.trees) on cross-validation
  boost_model_opt_proc[["best.iter"]] <- gbm.perf(boost_model_opt_proc[["model"]], method = "cv")[1]
  
  # predict on test set (probs)
  boost_model_opt_proc[["pred.prob"]] <- predict(
      boost_model_opt_proc[["model"]], 
      test, 
      n.trees = boost_model_opt_proc[["best.iter"]], 
      type = "response"
  )
  
  # set the prob cut-off at 0.5
  boost_model_opt_proc[["cutoff"]] <- 0.5
  
  # convert probs to classes
  boost_model_opt_proc[["pred.class"]] <- ifelse(boost_model_opt_proc[["pred.prob"]] > boost_model_opt_proc[["cutoff"]], 1, 0)
  # fix the levels so that confusionMatrix() also works if only one class is predicted
  boost_model_opt_proc[["pred.class"]] <- factor(boost_model_opt_proc[["pred.class"]], levels = c(0, 1))
  
  # get the confusion matrix
  conf_mat_opt_proc <- caret::confusionMatrix(
      boost_model_opt_proc[["pred.class"]], 
      factor(test$Attrition_Flag, levels = c(0, 1)), 
      positive = "1", 
      mode = "everything"
  )
  
  # results
  boost_model_opt_proc[["results"]] <- data.frame(matrix(0, nrow=1, ncol=9)) 
  names(boost_model_opt_proc[["results"]]) <- metrics_names_boost
  
  boost_model_opt_proc[["results"]][1,1]<-cost 
  boost_model_opt_proc[["results"]][1,2]<-boost_model_opt_proc[["best.iter"]][1]
  boost_model_opt_proc[["results"]][1,3]<-shrinkage
  boost_model_opt_proc[["results"]][1,4]<-interaction.depth
  boost_model_opt_proc[["results"]][1,5:9] <- classification_metrics(conf_mat_opt_proc)
  
  return(boost_model_opt_proc)
}
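
The `weights` argument is what turns the cost into a training signal. Each observation's contribution to the loss is multiplied by its weight. The sketch below uses the textbook weighted Bernoulli deviance, normalised by the total weight, and assumes that gbm minimises an equivalent weighted loss. The function names are hypothetical:

```r
# Per-observation weighted Bernoulli log-loss and the weighted deviance.
# With w = cost for churners (y = 1) and 1 otherwise, a missed churner
# costs `cost` times as much as an equally wrong non-churner.
logloss_terms <- function(y, p, w = rep(1, length(y))) {
  p <- pmin(pmax(p, 1e-15), 1 - 1e-15)   # keep p away from 0 and 1 to avoid log(0)
  -w * (y * log(p) + (1 - y) * log(1 - p))
}

weighted_bernoulli_deviance <- function(y, p, w = rep(1, length(y))) {
  2 * sum(logloss_terms(y, p, w)) / sum(w)
}

cost <- 7
y_obs <- c(1, 0)                       # one churner, one non-churner
p_hat <- c(0.2, 0.8)                   # both predicted wrongly by the same margin
w_obs <- ifelse(y_obs == 1, cost, 1)   # same weighting rule as in the helper above
logloss_terms(y_obs, p_hat, w_obs)     # first term is 7 times the second
```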

### Optimization procedures

In [None]:
options(repr.plot.width = 8, repr.plot.height = 6)

# shrinkage tuning values
shrinkage_parms <- c(0.05, 0.1, 0.2)

# interaction.depth tuning values
interaction_depth_params <- c(1, 3, 4, 6)

# cost tuning values
costs_boost_params <- c(1, 1.5, 2, 5, 7, 10)

# boosting procedures list
boost_optimization_models <- list()

boost_opt_mod_index = 1

# for each i-th shrinkage
for(i in c(1:length(shrinkage_parms))) {

  # for each j-th interaction.depth
  for(j in c(1:length(interaction_depth_params))) {

    # for each k-th cost
    for(k in c(1:length(costs_boost_params))) {

      # boosting with i-th shrinkage, j-th interaction.depth and k-th cost.
      # n.trees is fixed at 5000; the best iteration is then selected by CV inside the helper
      boost_optimization_models[[boost_opt_mod_index]] <- train_boosting_with_params(
          train_boost, 
          test_boost, 
          shrinkage_parms[i], 
          interaction_depth_params[j], 
          costs_boost_params[k], 
          5000
      )
      boost_opt_mod_index <- boost_opt_mod_index + 1
    }

  }

} 

# matrix with boosting procedures
boost_optimization_results <- data.frame(matrix(0, nrow=72, ncol=9)) 
names(boost_optimization_results)<- metrics_names_boost
boost_optimization_results$procedure <- c(1:72)

# extract results for each procedure 
for(i in c(1:72)) {
  boost_optimization_results[i, 1:9] <- boost_optimization_models[[i]]$results
}
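
The procedure numbers used below (e.g. 23 and 56) are positions in the nested loops, where the cost varies fastest, then interaction.depth, then shrinkage. The following lookup table is reconstructed from the parameter vectors above. It is a convenience sketch, and the names are hypothetical. `expand.grid()` varies its first argument fastest, so the argument order reproduces the loop order:

```r
# Rebuild the 72-row grid in the same order as the nested loops above
boost_grid <- expand.grid(
  cost = c(1, 1.5, 2, 5, 7, 10),
  interaction.depth = c(1, 3, 4, 6),
  shrinkage = c(0.05, 0.1, 0.2)
)
boost_grid$procedure <- seq_len(nrow(boost_grid))

# Closed form of the same mapping (1-based loop indices i, j, k)
procedure_index <- function(i, j, k, n_depth = 4, n_cost = 6) {
  (i - 1) * n_depth * n_cost + (j - 1) * n_cost + k
}

boost_grid[c(23, 56), ]
```

This confirms that procedure 23 corresponds to shrinkage 0.05, interaction.depth 6 and cost 7, and procedure 56 to shrinkage 0.2, interaction.depth 3 and cost 1.5. These are the values used in the refits below.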

In [None]:
# results
boost_optimization_results

In [None]:
boost_optimization_results %>% arrange(desc(F1))

In [None]:
options(repr.plot.width = 18, repr.plot.height = 8)

boost_optimization_results_indexes_f1 <- c(7,8,13,14,31,32,33,37,39,44,45,61,62,63,65,69,71)
boost_optimization_results_indexes_f1_selected <- c(17,23,46,55,56)

boost_optimization_results_pr <- boost_optimization_results
boost_optimization_results_pr <- boost_optimization_results_pr %>% 
mutate(selected = if_else(row_number() %in% boost_optimization_results_indexes_f1_selected, 1, 0))

boost_optimization_results_pr %>%
    ggplot() +
    geom_point(aes(x = Recall, y = Precision, color = F1), size = 4) +
    geom_line(aes(x = Recall, y = Precision), size = 0.8, alpha = 0.15) +
    geom_label_repel(
      data = boost_optimization_results_pr %>% 
        filter(
            row_number() %in% c(
                boost_optimization_results_indexes_f1, 
                boost_optimization_results_indexes_f1_selected
            )
        ), 
      aes(
        x = Recall, 
        y = Precision, 
        label = paste0(shrinkage, ", ", interaction.depth, ", ", cost), 
        fill = factor(selected)
      ), 
      size = 2.7
    ) +
    theme_classic() +
    scale_colour_gradientn(colours = c("#ff0000","#ff9c00","#fffca6"), values = c(1.0,0.87,0))  +
    scale_x_continuous(breaks = scales::pretty_breaks(n = 10)) +
    scale_y_continuous(breaks = scales::pretty_breaks(n = 10)) +
    scale_fill_manual(
      values = c("grey", "white"), 
      breaks = c(1, 0), 
      labels = c(1, 0)
    ) +
    labs(
      x = "Recall",
      y = "Precision",
      title = "Recall, Precision and F1 across boosting procedures (labels: shrinkage, interaction.depth, cost)"
    ) +
    guides(fill = "none")

In [None]:
boost_optimization_results_pr %>%
    ggplot() +
    geom_point(aes(x = procedure, y = F1, color = factor(selected))) +
    geom_line(aes(x = procedure, y = F1), alpha = 0.25) +
    geom_label_repel(
        data = boost_optimization_results_pr %>% 
        filter(
            row_number() %in% c(
                boost_optimization_results_indexes_f1, 
                boost_optimization_results_indexes_f1_selected
            )
        ), 
        aes(
            x = procedure, 
            y = F1, 
            label = paste0(shrinkage, ", ", interaction.depth, ", ", cost),
            fill = factor(selected)
        ), 
        size = 2.5
    ) +
    theme_classic() +
    scale_x_continuous(breaks = scales::pretty_breaks(n = 36)) +
    scale_fill_manual(
        values = c("grey", "white"), 
        breaks = c(1, 0), 
        labels = c(1, 0)
    ) +
    scale_color_manual(
        values = c("black", "grey"), 
        breaks = c(1, 0), 
        labels = c(1, 0)
    ) +
    guides(fill = "none", color = "none") +
    labs(
        x = "Procedure",
        title = "F1 trend by varying procedures"
    )

In [None]:
# most relevant boosting procedures
boost_optimization_results[c(6,17,23,46,55,56),] %>% arrange(desc(F1))

The table above lists the most relevant boosting procedures in terms of F1, Recall, Precision and model complexity (number of trees, *best.iteration*).
Procedures 56 and 55 have the **highest F1**. Procedure 55 is equivalent to boosting without cost sensitivity (cost = 1.0). Procedure 56 has the highest F1 and the lowest number of trees (404), so it has been considered the **best F1 choice**.

Procedures 23, 46 and 17 have an F1 comparable to procedure 56, but with recall higher than precision. Among these three, procedure 23 has the highest F1. Procedure 6 has the highest recall among all 72 procedures, but with a much sharper drop in precision and F1, which makes it an unbalanced trade-off. 

Compared with procedure 56, procedure 23 has a **comparable F1** but **recall higher than precision**. Since missing a churner (False Negative) is more costly for the bank than contacting a customer who would have stayed, procedure 23 has been considered the **best trade-off**.
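
The manual choice above can be written down as an explicit rule, which makes it repeatable if the grid changes. The rule and its tolerance are assumptions, not the procedure actually used here, and the result is not guaranteed to reproduce the choice of procedure 23. The sketch keeps the procedures whose F1 is close to the best one, prefers those with recall at least equal to precision, and breaks ties by the number of trees:

```r
# Hypothetical selection rule over a results table with the columns of
# boost_optimization_results (F1, Recall, Precision, best.iteration).
# `tol = 0.01` is an arbitrary assumption.
select_tradeoff <- function(results, tol = 0.01) {
  near_best <- results[results$F1 >= max(results$F1) - tol, ]
  candidates <- near_best[near_best$Recall >= near_best$Precision, ]
  if (nrow(candidates) == 0) candidates <- near_best   # fall back to the best-F1 set
  candidates[order(-candidates$F1, candidates$best.iteration), ][1, ]
}

# usage (assumption): select_tradeoff(boost_optimization_results)
```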

In [None]:
set.seed(1)

# boosting with procedure 56 parameters
best_boost_model_proc_56 <- gbm(
  Attrition_Flag~., 
  distribution = "bernoulli",
  data=train_boost, 
  train.fraction = 1, 
  cv.folds = 5,
  class.stratify.cv = TRUE,
  
  shrinkage = 0.2,
  interaction.depth = 3, 
  n.trees = 404,
  weights = ifelse(train_boost$Attrition_Flag == 1, 1.5, 1)
)

# predict on test set with the best iteration (n.trees)
best_boost_model_proc_56_pred <- predict(best_boost_model_proc_56, test_boost, n.trees = 404, type = "response")
best_boost_model_proc_56_pred_factor <- ifelse(best_boost_model_proc_56_pred > 0.5, 1, 0)
best_boost_model_proc_56_pred_factor <- factor(best_boost_model_proc_56_pred_factor, levels = c(0, 1))
test_boost_factor <- factor(test_boost$Attrition_Flag, levels = c(0, 1))

# get confusion matrix
conf_matrix_best_boost_model_proc_56 <- caret::confusionMatrix(best_boost_model_proc_56_pred_factor, test_boost_factor, positive = "1", mode = "everything")
metrics_best_boost_model_proc_56 <- classification_metrics(conf_matrix_best_boost_model_proc_56)

gbm.perf(boost_optimization_models[[56]]$model, method = "cv")
title(main = "Boosting 56")

In [None]:
set.seed(1)

# boosting with procedure 23 parameters
best_boost_model_proc_23 <- gbm(
  Attrition_Flag~., 
  distribution = "bernoulli",
  data=train_boost, 
  train.fraction = 1, 
  cv.folds = 5,
  class.stratify.cv = TRUE,
  
  shrinkage = 0.05,
  interaction.depth = 6, 
  n.trees = 1130,
  weights = ifelse(train_boost$Attrition_Flag == 1, 7.0, 1)
)

# predict on test set with the best iteration (n.trees)
best_boost_model_proc_23_pred <- predict(best_boost_model_proc_23, test_boost, n.trees = 1130, type = "response")

best_boost_model_proc_23_pred_factor <- ifelse(best_boost_model_proc_23_pred > 0.5, 1, 0)
best_boost_model_proc_23_pred_factor <- factor(best_boost_model_proc_23_pred_factor, levels = c(0, 1))
test_boost_factor <- factor(test_boost$Attrition_Flag, levels = c(0, 1))

# get confusion matrix
conf_matrix_best_boost_model_proc_23 <- caret::confusionMatrix(best_boost_model_proc_23_pred_factor, test_boost_factor, positive = "1", mode = "everything")
metrics_best_boost_model_proc_23 <- classification_metrics(conf_matrix_best_boost_model_proc_23)

gbm.perf(boost_optimization_models[[23]]$model, method = "cv")
title(main = "Boosting 23")

In [None]:
# Boosting procedure 56
conf_matrix_best_boost_model_proc_56$table
metrics_best_boost_model_proc_56

In [None]:
# Boosting procedure 23
conf_matrix_best_boost_model_proc_23$table
metrics_best_boost_model_proc_23

### Variables importance

In [None]:
# Variables importance for boosting procedure 23
best_boost_model_proc_23_varimp <- data.frame(summary(best_boost_model_proc_23, n.trees = 1130))

In [None]:
best_boost_model_proc_23_varimp %>%
    ggplot(aes(x = rel.inf, y = reorder(var, rel.inf))) +
    geom_bar(aes(fill = rel.inf), stat = "identity", color = "black") +
    geom_text(aes(label = round(rel.inf, 3)), vjust = 0.20, hjust = -0.20) +
    scale_fill_gradient(low = "#c7e9c0", high = "#00441b") +
    theme_classic() +
    labs(
      title = "Variables importance - Boosting procedure 23",
      x = "Importance",
      y = "Variable",
      fill = "Importance"
    ) + guides(fill = "none")

As for the decision tree, the variable importance of the selected boosting procedure confirms the EDA, with higher importance on *Total_Trans_Ct*, *Months_Inactive_12_mon* and *Contacts_Count_12_mon*. 

The plots below show the marginal (partial dependence) effect of the most important variables on the predicted probability of *Attrition_Flag*:
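
A partial dependence curve answers a simple question: what is the average prediction if one variable is set to a given value for every customer? The brute-force version is sketched below. gbm's `plot()` uses its own internal computation, so values may differ slightly. The helper name and the usage line are assumptions:

```r
# Partial dependence of a model's prediction on one variable: for each grid
# value v, set that column to v for every row, predict, and average.
partial_dependence <- function(predict_fun, data, var, grid = NULL, n = 20) {
  if (is.null(grid)) {
    grid <- seq(min(data[[var]]), max(data[[var]]), length.out = n)
  }
  yhat <- vapply(grid, function(v) {
    data_v <- data
    data_v[[var]] <- v                 # overwrite the variable for all rows
    mean(predict_fun(data_v))          # average prediction over the dataset
  }, numeric(1))
  data.frame(value = grid, yhat = yhat)
}

# usage (assumption):
# partial_dependence(
#   function(d) predict(best_boost_model_proc_23, d, n.trees = 1130, type = "response"),
#   train_boost, "Total_Trans_Ct"
# )
```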

In [None]:
effect_1 <- plot(best_boost_model_proc_23,i="Total_Trans_Ct", type = "response", main = "Marginal effect of Total_Trans_Ct on Attrition_Flag")
effect_2 <- plot(best_boost_model_proc_23,i="Total_Trans_Amt", type = "response", main = "Marginal effect of Total_Trans_Amt on Attrition_Flag")
effect_3 <- plot(best_boost_model_proc_23,i="Total_Revolving_Bal", type = "response", main = "Marginal effect of Total_Revolving_Bal on Attrition_Flag")
effect_4 <- plot(best_boost_model_proc_23,i="Total_Ct_Chng_Q4_Q1", type = "response", main = "Marginal effect of Total_Ct_Chng_Q4_Q1 on Attrition_Flag")
effect_5 <- plot(best_boost_model_proc_23,i="Total_Relationship_Count", type = "response", main = "Marginal effect of Total_Relationship_Count on Attrition_Flag")
effect_6 <- plot(best_boost_model_proc_23,i="Total_Amt_Chng_Q4_Q1", type = "response", main = "Marginal effect of Total_Amt_Chng_Q4_Q1 on Attrition_Flag")
grid.arrange(effect_1, effect_2, effect_3, effect_4, effect_5, effect_6, nrow = 2, ncol = 3)

### Gradient Boosting summary
The performance of the most relevant boosting procedures is summarized below. 

The boosting procedure 23 gives the best performance (high F1 and recall higher than precision).

In [None]:
boost_optimization_results[c(6,17,23,46,55,56),] %>% arrange(desc(F1))

## Comparison and Conclusion
Comparing the selected decision tree with the selected boosting procedure, the boosting procedure performs better on every metric:

In [None]:
# Decision tree with cost 1.3
conf_matrix_tree_opt_1$table
metrics_tree_opt_1

In [None]:
# Boosting procedure 23 with shrinkage = 0.05, interaction.depth = 6, n.trees = 1130 and cost = 7.0
conf_matrix_best_boost_model_proc_23$table
metrics_best_boost_model_proc_23

Despite the better performance, model interpretability must be taken into account. A gradient boosting model is **less interpretable** than a single decision tree because of its ensemble nature. The selected decision tree already performs well, so the performance gain may not be worth the loss of interpretability. Moreover, more information about the economic impact of the offers proposed to customers could lead to a better-grounded precision-recall trade-off.
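
Once such economic information is available, one way to use it is to compare models by expected misclassification cost rather than by F1. The sketch below assumes a 2x2 table laid out like caret's `confusionMatrix()$table`, with rows as predictions and columns as the reference. The unit costs are hypothetical inputs, not figures from this analysis:

```r
# Expected misclassification cost of a 2x2 confusion table (rows = Prediction,
# columns = Reference), given a cost per False Negative (missed churner) and
# per False Positive (offer made to a customer who would have stayed).
expected_cost <- function(tab, cost_fn, cost_fp, positive = "1") {
  negative <- setdiff(rownames(tab), positive)
  fn <- tab[negative, positive]        # predicted non-churn, actually churned
  fp <- tab[positive, negative]        # predicted churn, actually stayed
  fn * cost_fn + fp * cost_fp
}

# usage (assumption, unit costs to be provided by the business):
# expected_cost(conf_matrix_tree_opt_1$table, cost_fn = ..., cost_fp = ...)
# expected_cost(conf_matrix_best_boost_model_proc_23$table, cost_fn = ..., cost_fp = ...)
```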