# Setup notebook

In [1]:
# Import libraries (standard library first, then third-party)
from collections import defaultdict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio

# Render Plotly figures inside VS Code; set this to "browser" to open them in a web browser instead
pio.renderers.default = 'vscode'

## Functions

In [2]:
def plot_pie(df, col_to_plot, title, hole_size, color_sequence, legend_dict,
             width=800, height=800, font_size=18):
    """
    Plots a pie (or donut) chart of category counts using Plotly.

    No `values` column is passed to Plotly, so every row of `df` counts once
    towards the slice named by its value in `col_to_plot`.

    Args:
        df (pd.DataFrame): The DataFrame containing the data.
        col_to_plot (str): The column name to be used for pie chart categories.
        title (str): The title of the pie chart.
        hole_size (float): Fraction of the radius cut out of the centre, in [0, 1)
            (0 gives a full pie, e.g. 0.5 gives a donut).
        color_sequence (list): A list of colors for the pie chart slices.
        legend_dict (dict or None): A dictionary mapping original category values to
            new legend labels. Pass an empty dict (or None) to keep the original values.
            Values that are missing from the mapping keep their original label.
        width (int): Figure width in pixels (default 800).
        height (int): Figure height in pixels (default 800).
        font_size (int): Base font size of the figure (default 18).

    Raises:
        ValueError: If a non-empty `legend_dict` and `color_sequence` do not have the same length.

    Returns:
        None: Displays the pie chart.

    Example:
        >>> df = pd.DataFrame({'Category': ['A', 'B', 'C'], 'Values': [10, 20, 30]})
        >>> legend_dict = {'A': 'Alpha', 'B': 'Beta', 'C': 'Gamma'}
        >>> plot_pie(df, 'Category', 'Sample Pie Chart', 0.2, ['red', 'blue', 'green'], legend_dict)
    """
    if legend_dict and len(legend_dict) != len(color_sequence):
        raise ValueError('legend_dict and color_sequence must have the same length')

    names = df[col_to_plot]
    if legend_dict:
        # `Series.map(dict)` would turn unmapped values into NaN and Plotly would drop
        # those rows silently; fall back to the original value instead.
        names = names.map(lambda value: legend_dict.get(value, value))

    fig = px.pie(
        df,
        names=names,
        title=title,
        hole=hole_size,
        color_discrete_sequence=color_sequence,
    )

    fig.update_layout(
        font=dict(size=font_size),
        autosize=False,
        width=width,
        height=height,
    )
    fig.show()


In [3]:
def plot_sunburst(df, hierarchy, title, color, color_mapping,
                  width=800, height=800, font_size=18):
    """
    Plots a sunburst chart using Plotly, coloring every sector from `color_mapping`.

    Args:
        df (pd.DataFrame): The DataFrame containing hierarchical data.
        hierarchy (list): A list of column names defining the hierarchy for the sunburst
            chart (passed to `px.sunburst` as `path`).
        title (str): The title of the sunburst chart.
        color (str): The column passed to `px.sunburst` as `color`; the final sector
            colors are then overridden from `color_mapping`.
        color_mapping (dict): A dictionary mapping every sector label (at any level)
            to a specific color.
        width (int): Figure width in pixels (default 800).
        height (int): Figure height in pixels (default 800).
        font_size (int): Base font size of the figure (default 18).

    Raises:
        KeyError: If any sector label in the chart does not exist in `color_mapping`;
            the message lists every missing label.

    Returns:
        None: Displays the sunburst chart.

    Example:
        >>> df = pd.DataFrame({
        ...     'Category': ['A', 'A', 'B', 'B'],
        ...     'Subcategory': ['X', 'Y', 'X', 'Z'],
        ...     'Value': [10, 20, 30, 40]
        ... })
        >>> hierarchy = ['Category', 'Subcategory']
        >>> color_mapping = {'A': 'red', 'B': 'blue', 'X': 'green', 'Y': 'purple', 'Z': 'orange'}
        >>> plot_sunburst(df, hierarchy, 'Sunburst Chart', 'Category', color_mapping)
    """

    fig = px.sunburst(
        df,
        path=hierarchy,
        title=title,
        color=df[color],
    )

    # Colors are assigned per sector label, so every label in the chart needs an entry.
    labels = list(fig.data[-1].labels)
    missing = sorted({str(label) for label in labels if label not in color_mapping})
    if missing:
        raise KeyError(f'color_mapping has no color for: {", ".join(missing)}')

    fig.update_traces(
        textinfo="label+percent parent",
        insidetextorientation='horizontal',
        marker_colors=[color_mapping[label] for label in labels],
    )

    fig.update_layout(
        font=dict(size=font_size),
        autosize=False,
        width=width,
        height=height,
    )

    fig.show()


# Process data

### Variable Notes

pclass: A proxy for socio-economic status (SES)
- 1st = Upper
- 2nd = Middle
- 3rd = Lower

age: Age in years. It is fractional if less than 1; if the age is estimated, it is in the form xx.5.

sibsp: Number of siblings/spouses aboard. The dataset defines these family relations as follows:
- Sibling = brother, sister, stepbrother, stepsister
- Spouse = husband, wife (mistresses and fiancés were ignored)

parch: Number of parents/children aboard. The dataset defines these family relations as follows:
- Parent = mother, father
- Child = daughter, son, stepdaughter, stepson
- Some children travelled only with a nanny, therefore parch=0 for them.

In [4]:
# Load raw dataframes.
# All files live in one data directory (relative to the notebook's working directory).
DATA_DIR = '../Data'
train_raw_df = pd.read_csv(f'{DATA_DIR}/train.csv')
test_raw_df = pd.read_csv(f'{DATA_DIR}/test.csv')
total_raw_df = pd.read_excel(f'{DATA_DIR}/Complete_dataset.xls')

In [5]:
# Replace numeric codes with readable labels:
#   'survived': 0 -> Deceased, 1 -> Survived
#   'pclass':   1/2/3 -> 1st/2nd/3rd class
# `Series.replace` leaves values that are already labels untouched, so re-running this
# cell is safe (`Series.map` would turn already-converted values into NaN).
survived_dict = {0: 'Deceased', 1: 'Survived'}
class_dict = {1: '1st class', 2: '2nd class', 3: '3rd class'}
total_raw_df['survived'] = total_raw_df['survived'].replace(survived_dict)
total_raw_df['pclass'] = total_raw_df['pclass'].replace(class_dict)

### Exploratory Data Analysis

In [6]:
# Count records (rows = passengers) and variables (columns)
num_pass, num_var = total_raw_df.shape
print(f'Number of passengers = {num_pass}\nNumber of variables = {num_var}')

Number of passengers = 1309
Number of variables = 14


In [7]:
# List the variable (column) names
total_raw_df.keys()

Index(['pclass', 'survived', 'name', 'sex', 'age', 'sibsp', 'parch', 'ticket',
       'fare', 'cabin', 'embarked', 'boat', 'body', 'home.dest'],
      dtype='object')

In [8]:
# Preview the first ten passengers
total_raw_df.iloc[:10]

Unnamed: 0,pclass,survived,name,sex,age,sibsp,parch,ticket,fare,cabin,embarked,boat,body,home.dest
0,1st class,Survived,"Allen, Miss. Elisabeth Walton",female,29.0,0,0,24160,211.3375,B5,S,2,,"St Louis, MO"
1,1st class,Survived,"Allison, Master. Hudson Trevor",male,0.9167,1,2,113781,151.55,C22 C26,S,11,,"Montreal, PQ / Chesterville, ON"
2,1st class,Deceased,"Allison, Miss. Helen Loraine",female,2.0,1,2,113781,151.55,C22 C26,S,,,"Montreal, PQ / Chesterville, ON"
3,1st class,Deceased,"Allison, Mr. Hudson Joshua Creighton",male,30.0,1,2,113781,151.55,C22 C26,S,,135.0,"Montreal, PQ / Chesterville, ON"
4,1st class,Deceased,"Allison, Mrs. Hudson J C (Bessie Waldo Daniels)",female,25.0,1,2,113781,151.55,C22 C26,S,,,"Montreal, PQ / Chesterville, ON"
5,1st class,Survived,"Anderson, Mr. Harry",male,48.0,0,0,19952,26.55,E12,S,3,,"New York, NY"
6,1st class,Survived,"Andrews, Miss. Kornelia Theodosia",female,63.0,1,0,13502,77.9583,D7,S,10,,"Hudson, NY"
7,1st class,Deceased,"Andrews, Mr. Thomas Jr",male,39.0,0,0,112050,0.0,A36,S,,,"Belfast, NI"
8,1st class,Survived,"Appleton, Mrs. Edward Dale (Charlotte Lamson)",female,53.0,2,0,11769,51.4792,C101,S,D,,"Bayside, Queens, NY"
9,1st class,Deceased,"Artagaveytia, Mr. Ramon",male,71.0,0,0,PC 17609,49.5042,,C,,22.0,"Montevideo, Uruguay"


In [9]:
# Get a global look at the dataset: summary statistics of the numeric columns
total_raw_df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,age,sibsp,parch,fare,body
count,1046.0,1309.0,1309.0,1308.0,121.0
mean,29.881135,0.498854,0.385027,33.295479,160.809917
std,14.4135,1.041658,0.86556,51.758668,97.696922
min,0.1667,0.0,0.0,0.0,1.0
25%,21.0,0.0,0.0,7.8958,72.0
50%,28.0,0.0,0.0,14.4542,155.0
75%,39.0,1.0,0.0,31.275,256.0
max,80.0,8.0,9.0,512.3292,328.0


In [10]:
# Count missing values in every column (vectorized), then report them one per line
null_counts = total_raw_df.isnull().sum()
for el, num_null in null_counts.items():
    print(f'Column {el} has {num_null} nulls')

Column pclass has 0 nulls
Column survived has 0 nulls
Column name has 0 nulls
Column sex has 0 nulls
Column age has 263 nulls
Column sibsp has 0 nulls
Column parch has 0 nulls
Column ticket has 0 nulls
Column fare has 1 nulls
Column cabin has 1014 nulls
Column embarked has 2 nulls
Column boat has 823 nulls
Column body has 1188 nulls
Column home.dest has 564 nulls


### Visualize some data

#### Passenger composition

In [11]:
# Survival by sex: hierarchy is sex first, then survival outcome
sex_survival_colors = {
    'Deceased': "#79857e",
    'Survived': "#07f01b",
    'female': "#c71585",
    'male': "#0000cd",
}

plot_sunburst(
    df=total_raw_df,
    hierarchy=['sex', 'survived'],
    title='Passenger survival by sex',
    color='sex',
    color_mapping=sex_survival_colors,
)

In [12]:
 # Survival by class
# Build a plotting copy with string class labels instead of mutating total_raw_df:
# `df[col].replace(..., inplace=True)` is chained assignment, which warns under pandas
# Copy-on-Write and no longer modifies the frame in pandas 3. The replace is a no-op when
# 'pclass' already holds the '1st/2nd/3rd class' labels (as after the mapping cell above).
pclass_df = total_raw_df.assign(
    pclass=total_raw_df['pclass']
    .replace({1: '1st class', 2: '2nd class', 3: '3rd class'})
    .astype(str)
)

plot_sunburst(
    df=pclass_df,
    hierarchy=['pclass', 'survived'],
    title='Passenger survival by class',
    color='pclass',
    color_mapping = {
        'Deceased': "#79857e", 
        'Survived': "#07f01b", 
        '1st class': "#e89910", 
        '2nd class': "#e81710",
        '3rd class': "#10e8e1"
    }
)

##### Age ranges
Ages are binned with `pd.cut` (right-inclusive bins):
- up to 14 = Kid
- over 14, up to 40 = Adult
- over 40 = Older adult

In [13]:
# Survival by age range
# Convert numeric ages to age ranges with right-inclusive bins:
#   (0, 14] -> 'Kid', (14, 40] -> 'Adult', (40, inf] -> 'Older adult'
# The last bin is open-ended so no valid age falls outside it (the old edge was 99).
bins = [0, 14, 40, np.inf]
labels = ['Kid', 'Adult', 'Older adult']
# 'age' is overwritten in place, so only bin while it is still numeric: re-running
# pd.cut on the already-binned categorical column would break this cell.
if pd.api.types.is_numeric_dtype(total_raw_df['age']):
    total_raw_df['age'] = pd.cut(x=total_raw_df['age'], bins=bins, labels=labels)

In [14]:
# Drop passengers whose age (range) is unknown before plotting
aged_df = total_raw_df.dropna(subset=['age'])

age_survival_colors = {
    'Deceased': "#79857e",
    'Survived': "#07f01b",
    'Kid': "#00FFFF",
    'Adult': "#0096FF",
    'Older adult': '#00008B',
}

plot_sunburst(
    df=aged_df,
    hierarchy=['age', 'survived'],
    title='Passenger survival by age range',
    color='age',
    color_mapping=age_survival_colors,
)

#### Passenger survival

Treating sex, ticket class, and age range as macro-categories, we define the following subcategories:
- Sex: (2 categories)
    - Male
    - Female
- Class: (3 categories)
    - 1st
    - 2nd
    - 3rd
- Age range: (3 categories)
    - up to 14 = kid
    - over 14, up to 40 = adult
    - over 40 = older adult
    
Combining subcategories, we have 2 * 3 * 3 = 18 clusters:
- Cluster 1: Male, 1st class, kid
- Cluster 2: Male, 1st class, adult
- Cluster 3: Male, 1st class, older adult
- Cluster 4: Male, 2nd class, kid
- Cluster 5: Male, 2nd class, adult
- Cluster 6: Male, 2nd class, older adult
- Cluster 7: Male, 3rd class, kid
- Cluster 8: Male, 3rd class, adult
- Cluster 9: Male, 3rd class, older adult
- Cluster 10: Female, 1st class, kid
- Cluster 11: Female, 1st class, adult
- Cluster 12: Female, 1st class, older adult
- Cluster 13: Female, 2nd class, kid
- Cluster 14: Female, 2nd class, adult
- Cluster 15: Female, 2nd class, older adult
- Cluster 16: Female, 3rd class, kid
- Cluster 17: Female, 3rd class, adult
- Cluster 18: Female, 3rd class, older adult

In [15]:
# Keep only the variables used for clustering; copy so later edits cannot touch total_raw_df
cluster_cols = ['pclass', 'survived', 'sex', 'age']
sub_raw_df = total_raw_df.loc[:, cluster_cols].copy()
sub_raw_df.head(10)

Unnamed: 0,pclass,survived,sex,age
0,1st class,Survived,female,Adult
1,1st class,Survived,male,Kid
2,1st class,Deceased,female,Kid
3,1st class,Deceased,male,Adult
4,1st class,Deceased,female,Adult
5,1st class,Survived,male,Older adult
6,1st class,Survived,female,Older adult
7,1st class,Deceased,male,Adult
8,1st class,Survived,female,Older adult
9,1st class,Deceased,male,Older adult


In [16]:
# Subcategory values for each macro-category; this order drives the cluster loops below
sub_cat = dict(
    sex=['male', 'female'],
    pclass=['1st class', '2nd class', '3rd class'],
    age=['Kid', 'Adult', 'Older adult'],
)

In [17]:
# Split the passengers into one DataFrame per (sex, class, age range) cluster.
# Access a cluster as class_dfs[sex][pclass][age]; for example
# class_dfs['male']['2nd class']['Kid'] is the df of male, 2nd class, kid passengers.
# Every combination is present, but some DataFrames can be empty.
# Plain nested dicts (instead of defaultdict(dict)) make a mistyped key raise KeyError
# rather than silently returning an empty dict.
class_dfs = {
    sex: {
        pclass: {
            age: sub_raw_df[
                (sub_raw_df['sex'] == sex)
                & (sub_raw_df['pclass'] == pclass)
                & (sub_raw_df['age'] == age)
            ]
            for age in sub_cat['age']
        }
        for pclass in sub_cat['pclass']
    }
    for sex in sub_cat['sex']
}

In [18]:
# Survival rate (percent, 2 decimals) and head count for every non-empty cluster,
# keyed as '<sex>_<pclass>_<age>'.
# Empty clusters are skipped explicitly instead of catching the division error with a
# bare `except` and filtering 'N/A' placeholders out afterwards.
cluster_survival_rates = {}
for sex in sub_cat['sex']:
    for pclass in sub_cat['pclass']:
        for age in sub_cat['age']:
            single_df = class_dfs[sex][pclass][age]
            if single_df.empty:
                continue  # no passengers in this cluster: survival rate is undefined
            survived = len(single_df[single_df['survived'] == 'Survived'])
            surv_rate = round(survived / len(single_df) * 100, 2)
            cluster_survival_rates[f'{sex}_{pclass}_{age}'] = {
                'Survival rate': surv_rate,
                'Number': len(single_df),
            }

# Order by decreasing survival rate (sorted() is stable, so ties keep loop order)
cluster_survival_rates = dict(
    sorted(cluster_survival_rates.items(), key=lambda item: item[1]['Survival rate'], reverse=True)
)

In [19]:
# Bar colors for the cluster plots, one per cluster in cluster_survival_rates order.
# They are derived from the cluster keys ('<sex>_<pclass>_<age>') instead of being
# hard-coded, so they stay aligned with the bars if the data, bins or sort order change.
# The palettes are the same ones used in the sunburst charts above.
sex_palette = {'male': '#0000cd', 'female': '#c71585'}
class_palette = {'1st class': '#e89910', '2nd class': '#e81710', '3rd class': '#10e8e1'}
age_palette = {'Kid': '#00FFFF', 'Adult': '#0096FF', 'Older adult': '#00008B'}

cluster_parts = [key.split('_') for key in cluster_survival_rates]
color_by_sex = [sex_palette[sex_part] for sex_part, _, _ in cluster_parts]
color_by_class = [class_palette[class_part] for _, class_part, _ in cluster_parts]
color_by_age = [age_palette[age_part] for _, _, age_part in cluster_parts]

In [20]:
# Plot survival rate by cluster, color by sex.
# Hovering a bar also shows how many passengers the rate is based on.
fig = px.bar(
    pd.DataFrame(cluster_survival_rates, index=['Survival rate', 'Number']).T,
    y='Survival rate',
    hover_data=['Number'],
    title='Passenger survival by cluster, colored by sex',
)
fig.update_traces(marker_color=color_by_sex)
fig.update_layout(xaxis_title='Cluster (sex_class_age)', yaxis_title='Survival rate (%)')
fig.show()

In [None]:
# Plot survival rate by cluster, color by class.
# Hovering a bar also shows how many passengers the rate is based on.
fig = px.bar(
    pd.DataFrame(cluster_survival_rates, index=['Survival rate', 'Number']).T,
    y='Survival rate',
    hover_data=['Number'],
    title='Passenger survival by cluster, colored by class',
)
fig.update_traces(marker_color=color_by_class)
fig.update_layout(xaxis_title='Cluster (sex_class_age)', yaxis_title='Survival rate (%)')
fig.show()

In [None]:
# Plot survival rate by cluster, color by age.
# Hovering a bar also shows how many passengers the rate is based on.
fig = px.bar(
    pd.DataFrame(cluster_survival_rates, index=['Survival rate', 'Number']).T,
    y='Survival rate',
    hover_data=['Number'],
    title='Passenger survival by cluster, colored by age',
)
fig.update_traces(marker_color=color_by_age)
fig.update_layout(xaxis_title='Cluster (sex_class_age)', yaxis_title='Survival rate (%)')
fig.show()