# Setup notebook

In [1]:
# Import libraries
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.io as pio
pio.renderers.default = 'iframe'

## Functions

In [2]:
def plot_pie(df, col_to_plot, title, hole_size, color_sequence, legend_dict):
    '''
    Plot a pie chart of a Dataframe column, optionally relabelling its values
    Inputs:
    - df: Dataframe with data
    - col_to_plot: name of the column of the dataframe to plot
    - title: title of chart
    - hole_size: size of hole (passed to px.pie as `hole`)
    - color_sequence: list of colors; when legend_dict is not empty it must
      have the same length as legend_dict
    - legend_dict: dictionary mapping the values of col_to_plot to the labels
      shown in the legend; pass an empty dict (or None) to plot the raw values
    Raises:
    - ValueError: if legend_dict is not empty and its length differs from
      the length of color_sequence
    '''
    # Check that legend_dict and color_sequence have same size
    if legend_dict and len(legend_dict) != len(color_sequence):
        raise ValueError('legend_dict and color_sequence must have the same length')

    # Without a legend mapping the raw column values are used as slice names
    names = df[col_to_plot].map(legend_dict) if legend_dict else df[col_to_plot]

    # Plot the dataframe received as argument (not a notebook-level global)
    fig = px.pie(
        df,
        names=names,
        title=title,
        hole=hole_size,
        color_discrete_sequence=color_sequence,
    )
    fig.update_layout(
        font=dict(size=18),
        autosize=False,
        width=800,
        height=800,
    )
    fig.show()

In [3]:
def plot_sunburst(df, hierarchy, title, color, color_mapping):
    '''
    Plot a sunburst chart of Dataframe columns, coloring each sector by its label
    Inputs:
    - df: Dataframe with data
    - hierarchy: list of column names defining the sunburst rings, from inner to outer
    - title: title of chart
    - color: name of the column passed to plotly as the color feature
    - color_mapping: dict mapping every sector label to its color
    Raises:
    - KeyError: if a sector label has no entry in color_mapping
    '''
    fig = px.sunburst(
        df,
        path=hierarchy,
        title=title,
        color=df[color],
    )

    # Sector colors are assigned explicitly from the labels of the last trace,
    # so every label needs an entry in color_mapping
    sector_labels = fig.data[-1].labels
    missing_labels = [label for label in sector_labels if label not in color_mapping]
    if missing_labels:
        # Report each missing label once, in order of appearance
        unique_missing = list(dict.fromkeys(missing_labels))
        raise KeyError(f'color_mapping has no color for labels: {unique_missing}')

    fig.update_traces(
        textinfo="label+percent parent",
        insidetextorientation='horizontal',
        marker_colors=[color_mapping[label] for label in sector_labels],
    )
    fig.update_layout(
        font=dict(size=18),
        autosize=False,
        width=800,
        height=800,
    )
    fig.show()

# Process data

### Variable Notes

pclass: A proxy for socio-economic status (SES)
- 1st = Upper
- 2nd = Middle
- 3rd = Lower

age: Age is fractional if less than 1. If the age is estimated, it is in the form xx.5

sibsp: The dataset defines family relations in this way...
- Sibling = brother, sister, stepbrother, stepsister
- Spouse = husband, wife (mistresses and fiancés were ignored)

parch: The dataset defines family relations in this way...
- Parent = mother, father
- Child = daughter, son, stepdaughter, stepson
- Some children travelled only with a nanny, therefore parch=0 for them.

In [4]:
# Read the train/test CSV splits and the complete Excel dataset
train_raw_df, test_raw_df = (
    pd.read_csv(csv_path)
    for csv_path in ('../Data/train.csv', '../Data/test.csv')
)
total_raw_df = pd.read_excel('../Data/Complete_dataset.xls')

In [5]:
# Convert coded columns to readable labels:
# - 'survived': 0 -> Deceased, 1 -> Survived
# - 'pclass': 1 -> 1st class, 2 -> 2nd class, 3 -> 3rd class
survived_dict = {0: 'Deceased', 1: 'Survived'}
class_dict = {1: '1st class', 2: '2nd class', 3: '3rd class'}
# replace() leaves values that are not keys of the dict untouched, so re-running
# this cell keeps the labels; map() would turn already converted labels into NaN
total_raw_df['survived'] = total_raw_df['survived'].replace(survived_dict)
total_raw_df['pclass'] = total_raw_df['pclass'].replace(class_dict)

### Exploratory Data Analysis

In [6]:
# Count the records (rows) and the variables (columns) of the dataset
num_pass, num_var = total_raw_df.shape
print(f'Number of passengers = {num_pass}\nNumber of variables = {num_var}')

Number of passengers = 1309
Number of variables = 14


In [7]:
# List the column names of the dataset
total_raw_df.columns

Index(['pclass', 'survived', 'name', 'sex', 'age', 'sibsp', 'parch', 'ticket',
       'fare', 'cabin', 'embarked', 'boat', 'body', 'home.dest'],
      dtype='object')

In [8]:
# Preview the first 10 rows of the dataset
total_raw_df.head(10)

Unnamed: 0,pclass,survived,name,sex,age,sibsp,parch,ticket,fare,cabin,embarked,boat,body,home.dest
0,1st class,Survived,"Allen, Miss. Elisabeth Walton",female,29.0,0,0,24160,211.3375,B5,S,2,,"St Louis, MO"
1,1st class,Survived,"Allison, Master. Hudson Trevor",male,0.9167,1,2,113781,151.55,C22 C26,S,11,,"Montreal, PQ / Chesterville, ON"
2,1st class,Deceased,"Allison, Miss. Helen Loraine",female,2.0,1,2,113781,151.55,C22 C26,S,,,"Montreal, PQ / Chesterville, ON"
3,1st class,Deceased,"Allison, Mr. Hudson Joshua Creighton",male,30.0,1,2,113781,151.55,C22 C26,S,,135.0,"Montreal, PQ / Chesterville, ON"
4,1st class,Deceased,"Allison, Mrs. Hudson J C (Bessie Waldo Daniels)",female,25.0,1,2,113781,151.55,C22 C26,S,,,"Montreal, PQ / Chesterville, ON"
5,1st class,Survived,"Anderson, Mr. Harry",male,48.0,0,0,19952,26.55,E12,S,3,,"New York, NY"
6,1st class,Survived,"Andrews, Miss. Kornelia Theodosia",female,63.0,1,0,13502,77.9583,D7,S,10,,"Hudson, NY"
7,1st class,Deceased,"Andrews, Mr. Thomas Jr",male,39.0,0,0,112050,0.0,A36,S,,,"Belfast, NI"
8,1st class,Survived,"Appleton, Mrs. Edward Dale (Charlotte Lamson)",female,53.0,2,0,11769,51.4792,C101,S,D,,"Bayside, Queens, NY"
9,1st class,Deceased,"Artagaveytia, Mr. Ramon",male,71.0,0,0,PC 17609,49.5042,,C,,22.0,"Montevideo, Uruguay"


In [9]:
# Take a global look at the dataset: summary statistics
# (count, mean, std, min, quartiles, max) as returned by describe()
total_raw_df.describe()

Unnamed: 0,age,sibsp,parch,fare,body
count,1046.0,1309.0,1309.0,1308.0,121.0
mean,29.881135,0.498854,0.385027,33.295479,160.809917
std,14.4135,1.041658,0.86556,51.758668,97.696922
min,0.1667,0.0,0.0,0.0,1.0
25%,21.0,0.0,0.0,7.8958,72.0
50%,28.0,0.0,0.0,14.4542,155.0
75%,39.0,1.0,0.0,31.275,256.0
max,80.0,8.0,9.0,512.3292,328.0


In [10]:
# Report how many missing values each column contains
null_counts = total_raw_df.isnull().sum()
for col_name, num_null in null_counts.items():
    print(f'Column "{col_name}" has {num_null} nulls')

Column pclass has 0 nulls
Column survived has 0 nulls
Column name has 0 nulls
Column sex has 0 nulls
Column age has 263 nulls
Column sibsp has 0 nulls
Column parch has 0 nulls
Column ticket has 0 nulls
Column fare has 1 nulls
Column cabin has 1014 nulls
Column embarked has 2 nulls
Column boat has 823 nulls
Column body has 1188 nulls
Column home.dest has 564 nulls


### Visualize some data

#### Passenger composition

In [11]:
# Survival by sex: inner ring is the sex, outer ring the survival outcome
sex_survival_colors = {
    'Deceased': "#79857e",
    'Survived': "#07f01b",
    'female': "#c71585",
    'male': "#0000cd",
}
plot_sunburst(
    df=total_raw_df,
    hierarchy=['sex', 'survived'],
    title='Passenger survival by sex - interactive',
    color='sex',
    color_mapping=sex_survival_colors,
)

In [12]:
 # Survival by class
# Convert any remaining numeric class codes to labels (a no-op when 'pclass'
# already holds the labels). The result is assigned back to the column instead
# of calling replace(..., inplace=True) on the selected column, so the update
# does not depend on whether the selection is a view or a copy of the dataframe
total_raw_df['pclass'] = total_raw_df['pclass'].replace(
    to_replace=[1, 2, 3],
    value=['1st class', '2nd class', '3rd class'],
)

# Inner ring is the passenger class, outer ring the survival outcome
plot_sunburst(
    df=total_raw_df.astype({'pclass': str}),
    hierarchy=['pclass', 'survived'],
    title='Passenger survival by class - interactive',
    color='pclass',
    color_mapping={
        'Deceased': "#79857e",
        'Survived': "#07f01b",
        '1st class': "#e89910",
        '2nd class': "#e81710",
        '3rd class': "#10e8e1",
    },
)

##### Age ranges (upper bounds are inclusive, matching the bins used below):
- up to 3 = baby
- over 3, up to 17 = kid
- over 17, up to 40 = adult
- over 40, up to 60 = middle aged
- over 60 = elderly

In [13]:
# Survival by age range
# Convert the numeric age to an age range label
bins = [0, 3, 17, 40, 60, 99]
labels = ['Baby', 'Kid', 'Adult', 'Middle aged', 'Elderly']
# This cell overwrites the numeric 'age' column with its range label, so the
# conversion is applied only while the column is still numeric: re-running the
# cell would otherwise call pd.cut on the already converted labels
if pd.api.types.is_numeric_dtype(total_raw_df['age']):
    total_raw_df['age'] = pd.cut(x=total_raw_df['age'], bins=bins, labels=labels)

In [14]:
# Survival by age range: inner ring is the age range, outer ring the survival outcome
age_survival_colors = {
    'Deceased': "#79857e",
    'Survived': "#07f01b",
    'Baby': "#e7feff",
    'Kid': "#add8e6",
    'Adult': "#6495ed",
    'Middle aged': '#0000cd',
    'Elderly': '#000080',
}
# Passengers without an age cannot be placed in a range, so they are left out
known_age_df = total_raw_df.dropna(subset=['age'])
plot_sunburst(
    df=known_age_df,
    hierarchy=['age', 'survived'],
    title='Passenger survival by age range - interactive',
    color='age',
    color_mapping=age_survival_colors,
)