# Survival Analysis Lab

Complete the following exercises to solidify your knowledge of survival analysis.

In [131]:
import pandas as pd
import numpy as np
import chart_studio.plotly as py
import plotly
import matplotlib.pyplot as plt
import seaborn as sns
import cufflinks as cf
from lifelines import KaplanMeierFitter

# Route DataFrame.plot / Series.plot through plotly instead of matplotlib.
pd.set_option('plotting.backend', 'plotly')

# Put cufflinks into offline mode so .iplot() renders locally in the notebook.
cf.go_offline()

In [10]:
# Relative path: expects the notebook to sit in a directory next to ../data/.
data = pd.read_csv(filepath_or_buffer='../data/attrition.csv')

In [12]:
# Summarise each column's dtype and non-null count.
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1470 entries, 0 to 1469
Data columns (total 35 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   Age                       1470 non-null   int64 
 1   Attrition                 1470 non-null   int64 
 2   BusinessTravel            1470 non-null   object
 3   DailyRate                 1470 non-null   int64 
 4   Department                1470 non-null   object
 5   DistanceFromHome          1470 non-null   int64 
 6   Education                 1470 non-null   int64 
 7   EducationField            1470 non-null   object
 8   EmployeeCount             1470 non-null   int64 
 9   EmployeeNumber            1470 non-null   int64 
 10  EnvironmentSatisfaction   1470 non-null   int64 
 11  Gender                    1470 non-null   object
 12  HourlyRate                1470 non-null   int64 
 13  JobInvolvement            1470 non-null   int64 
 14  JobLevel                

In [16]:
# Column labels of the attrition frame (DataFrame.keys() is the columns Index).
data.keys()

Index(['Age', 'Attrition', 'BusinessTravel', 'DailyRate', 'Department',
       'DistanceFromHome', 'Education', 'EducationField', 'EmployeeCount',
       'EmployeeNumber', 'EnvironmentSatisfaction', 'Gender', 'HourlyRate',
       'JobInvolvement', 'JobLevel', 'JobRole', 'JobSatisfaction',
       'MaritalStatus', 'MonthlyIncome', 'MonthlyRate', 'NumCompaniesWorked',
       'Over18', 'OverTime', 'PercentSalaryHike', 'PerformanceRating',
       'RelationshipSatisfaction', 'StandardHours', 'StockOptionLevel',
       'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance',
       'YearsAtCompany', 'YearsInCurrentRole', 'YearsSinceLastPromotion',
       'YearsWithCurrManager'],
      dtype='object')

## 1. Generate and plot a survival function that shows how employee retention rates vary by gender and employee age.

*Tip: If your lines have gaps in them, you can fill them in by forward-filling (`ffill()`) and back-filling (`bfill()`) the frame and then taking the average. Older code spells these `fillna(method='ffill')` and `fillna(method='bfill')`, which pandas 2.1 deprecated. We have provided you with a revised survival function below that you can use for the exercises in this lab.*

In [13]:
def survival(data, group_field, time_field, event_field):
    """Fit a Kaplan-Meier survival curve for every group and return them side by side.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to analyse; must contain ``group_field``, ``time_field`` and ``event_field``.
    group_field : str
        Column whose distinct non-null values define the groups (one curve each).
    time_field : str
        Column passed to ``KaplanMeierFitter.fit`` as the durations.
    event_field : str
        Column passed to ``KaplanMeierFitter.fit`` as the event indicator.

    Returns
    -------
    pandas.DataFrame
        One column per group, labelled ``str(group)``, on the sorted union of every
        group's time points. Gaps introduced by that union are filled with the mean
        of the forward-filled and back-filled values; points before a group's first
        time or after its last stay NaN. Empty if there are no non-null groups.
    """
    curves = []
    # A null label never matches ``==``, so it would only produce an empty fit.
    for group_value in data[group_field].dropna().unique():
        members = data[data[group_field] == group_value]
        # A fresh fitter per group keeps each fitted curve independent.
        fitter = KaplanMeierFitter()
        fitter.fit(members[time_field], members[event_field], label=str(group_value))
        curves.append(fitter.survival_function_)

    if not curves:
        return pd.DataFrame()

    # pd.concat is not guaranteed to sort the unioned index, and the ffill/bfill
    # averaging below only makes sense along an ordered time axis.
    combined = pd.concat(curves, axis=1).sort_index()
    # ffill()/bfill() replace fillna(method=...), which pandas 2.1 deprecated.
    return (combined.ffill() + combined.bfill()) / 2

In [98]:
# Kaplan-Meier retention curves per gender, with employee age as the time axis.
retention_rates_age = survival(data, 'Gender', 'Age', 'Attrition')

# Plot through the pandas plotly backend like the other exercises, rather than cufflinks' iplot.
retention_rates_age.plot(
    kind='line',
    title='Retention Rates and Age by Gender',
    labels=dict(value='Retention Rate', timeline='Age', variable='Gender'),
).update_layout(xaxis_title='Age (yr)', yaxis_title='Retention Rate')

## 2. Compare the plot above with one that plots employee retention rates by gender over the number of years the employee has been working for the company.

In [97]:
# Same gender split, now with tenure (YearsAtCompany) as the time axis.
retention_rates_yr = survival(data, 'Gender', 'YearsAtCompany', 'Attrition')

retention_rates_yr.plot(
    kind='line',
    title='Retention Rates and Years at Company by Gender',
    labels=dict(value='Retention Rate', timeline='Years', variable='Gender'),
).update_layout(xaxis_title='Years at Company', yaxis_title='Retention Rate')

## 3. Let's look at retention rate by gender from a third perspective - the number of years since the employee's last promotion. Generate and plot a survival curve showing this.

In [95]:
# Retention per gender, measured from each employee's last promotion.
retention_rates_lastprom = survival(data, 'Gender', 'YearsSinceLastPromotion', 'Attrition')

# Draw a marker at each time point; gender colours are pinned to fixed hex codes.
retention_rates_lastprom.plot(
    kind='line',
    markers=True,
    title='Retention Rates and Years Since Last Promotion by Gender',
    labels={'value': 'Retention Rate', 'timeline': 'Years', 'variable': 'Gender'},
    color_discrete_map={'Female': '#636EFA', 'Male': '#EF553B'},
).update_layout(xaxis_title='Years since Last Promotion', yaxis_title='Retention Rate')

## 4. Let's switch to looking at retention rates from another demographic perspective: marital status. Generate and plot survival curves for the different marital statuses by number of years at the company.

In [93]:
# Retention per marital status over tenure.
retention_rates_marital = survival(data, 'MaritalStatus', 'YearsAtCompany', 'Attrition')

retention_rates_marital.plot(
    kind='line',
    title='Retention Rates and Years at Company by Marital Status',
    labels={'value': 'Retention Rate', 'timeline': 'Years', 'variable': 'Marital Status'},
).update_layout(xaxis_title='Years at Company', yaxis_title='Retention Rate')

## 5. Let's also look at the marital status curves by employee age. Generate and plot the survival curves showing retention rates by marital status and age.

In [102]:
# Retention per marital status with employee age as the time axis.
retention_rates_marital2 = survival(data, 'MaritalStatus', 'Age', 'Attrition')

retention_rates_marital2.plot(
    kind='line',
    title='Retention Rates and Age by Marital Status',
    labels={'value': 'Retention Rate', 'timeline': 'Age', 'variable': 'Marital Status'},
).update_layout(xaxis_title='Age', yaxis_title='Retention Rate')

## 6. Now that we have looked at the retention rates by gender and marital status individually, let's look at them together. 

Create a new field in the data set that concatenates marital status and gender, and then generate and plot a survival curve that shows the retention by this new field over the age of the employee.

In [101]:
# Combined label such as "Female-Single"; a missing part yields NaN for that row.
data['Gender_MaritalStatus'] = data['Gender'].str.cat(data['MaritalStatus'], sep='-')
data['Gender_MaritalStatus']

0        Female-Single
1         Male-Married
2          Male-Single
3       Female-Married
4         Male-Married
             ...      
1465      Male-Married
1466      Male-Married
1467      Male-Married
1468      Male-Married
1469      Male-Married
Name: Gender_MaritalStatus, Length: 1470, dtype: object

In [107]:
# Retention for each gender/marital-status combination over employee age.
retention_rates_gendermarital = survival(data, 'Gender_MaritalStatus', 'Age', 'Attrition')

retention_rates_gendermarital.plot(
    kind='line',
    title='Retention Rates and Age by Gender-Marital Status',
    labels={'value': 'Retention Rate', 'timeline': 'Age', 'variable': 'Gender-Marital Status'},
).update_layout(xaxis_title='Age', yaxis_title='Retention Rate')

## 6. Let's find out how job satisfaction affects retention rates. Generate and plot survival curves for each level of job satisfaction by number of years at the company.

In [110]:
# One retention curve per JobSatisfaction level, over tenure.
retention_rates_jobsatisfaction = survival(data, 'JobSatisfaction', 'YearsAtCompany', 'Attrition')

retention_rates_jobsatisfaction.plot(
    kind='line',
    title='Retention Rates and Years at Company by Job Satisfaction',
    labels={'value': 'Retention Rate', 'timeline': 'Years at Company', 'variable': 'Job Satisfaction'},
).update_layout(xaxis_title='Years at Company', yaxis_title='Retention Rate')

## 7. Let's investigate whether the department the employee works in has an impact on how long they stay with the company. Generate and plot survival curves showing retention by department and years the employee has worked at the company.

In [113]:
# Retention per department over tenure (variable name typo "departmanet" fixed).
retention_rates_department = survival(data, 'Department', 'YearsAtCompany', 'Attrition')

retention_rates_department.plot(
    kind='line',
    title='Retention Rates and Years at Company by Department',
    labels=dict(value='Retention Rate', timeline='Years at Company', variable='Department'),
).update_layout(xaxis_title='Years at Company', yaxis_title='Retention Rate')

## 8. From the previous example, it looks like the sales department has the highest attrition. Let's drill down on this and look at what the survival curves for specific job roles within that department look like.

Filter the data set for just the sales department and then generate and plot survival curves by job role and the number of years at the company.

In [114]:
# Distinct department names, in order of first appearance.
pd.unique(data['Department'])

array(['Sales', 'Research & Development', 'Human Resources'], dtype=object)

In [115]:
# Rows for the Sales department only.
sales = data.loc[data['Department'].eq('Sales')]

In [122]:
# Within Sales, one retention curve per job role over tenure.
retention_rates_sales = survival(sales, 'JobRole', 'YearsAtCompany', 'Attrition')

retention_rates_sales.plot(
    kind='line',
    markers=True,
    title='Retention Rates and Years at Company by Sales Department-Job Role',
    labels={'value': 'Retention Rate', 'timeline': 'Years at Company', 'variable': 'Job Role'},
).update_layout(xaxis_title='Years at Company', yaxis_title='Retention Rate')

## 9. Let's examine how compensation affects attrition.

- Use the `pd.qcut` method to bin the HourlyRate field into 5 different pay grade categories (Very Low, Low, Moderate, High, and Very High).
- Generate and plot survival curves showing employee retention by pay grade and age.

In [125]:
# Split HourlyRate into 5 quantile bins, labelled from lowest to highest pay.
data['Compensation'] = pd.qcut(
    data['HourlyRate'],
    5,
    labels=['Very Low', 'Low', 'Moderate', 'High', 'Very High'],
)

In [127]:
# The exercise asks for retention by pay grade and *age*, so Age is the time axis
# (the previous version used YearsAtCompany).
retention_rates_compensation = survival(data, 'Compensation', 'Age', 'Attrition')

retention_rates_compensation.plot(
    kind='line',
    title='Retention Rates and Age by Compensation',
    labels=dict(value='Retention Rate', timeline='Age', variable='Compensation'),
).update_layout(xaxis_title='Age', yaxis_title='Retention Rate')

## 10. Finally, let's take a look at how the demands of the job impact employee attrition.

- Create a new field whose values are 'Overtime' or 'Regular Hours' depending on whether there is a Yes or a No in the OverTime field.
- Create a new field that concatenates that field with the BusinessTravel field.
- Generate and plot survival curves showing employee retention based on these conditions and employee age.

In [129]:
# Distinct OverTime flags, in order of first appearance.
pd.unique(data['OverTime'])

array(['Yes', 'No'], dtype=object)

In [136]:
# 'Overtime' where OverTime == 'Yes'; every other value (including NaN) -> 'Regular Hours'.
data['OverTime_New'] = data['OverTime'].eq('Yes').map({True: 'Overtime', False: 'Regular Hours'})
data['OverTime_New']

0            Overtime
1       Regular Hours
2            Overtime
3            Overtime
4       Regular Hours
            ...      
1465    Regular Hours
1466    Regular Hours
1467         Overtime
1468    Regular Hours
1469    Regular Hours
Name: OverTime_New, Length: 1470, dtype: object

In [142]:
# Distinct BusinessTravel categories, in order of first appearance.
pd.unique(data['BusinessTravel'])

array(['Travel_Rarely', 'Travel_Frequently', 'Non-Travel'], dtype=object)

In [139]:
# Combined label such as "Overtime-Travel_Rarely"; a missing part yields NaN for that row.
data['Overtime-BusinessTravel'] = data['OverTime_New'].str.cat(data['BusinessTravel'], sep='-')
data['Overtime-BusinessTravel']

0                Overtime-Travel_Rarely
1       Regular Hours-Travel_Frequently
2                Overtime-Travel_Rarely
3            Overtime-Travel_Frequently
4           Regular Hours-Travel_Rarely
                     ...               
1465    Regular Hours-Travel_Frequently
1466        Regular Hours-Travel_Rarely
1467             Overtime-Travel_Rarely
1468    Regular Hours-Travel_Frequently
1469        Regular Hours-Travel_Rarely
Name: Overtime-BusinessTravel, Length: 1470, dtype: object

In [141]:
# Retention for each overtime/travel combination over employee age.
retention_rates_Overtime_BusinessTravel = survival(data, 'Overtime-BusinessTravel', 'Age', 'Attrition')

retention_rates_Overtime_BusinessTravel.plot(
    kind='line',
    title='Retention Rates and Age by Overtime - Business Travel',
    labels={'value': 'Retention Rate', 'timeline': 'Age', 'variable': 'Overtime - Business Travel'},
).update_layout(xaxis_title='Age', yaxis_title='Retention Rate')