In [1]:
import pandas as pd
# import plotly.graph_objects as go

In [2]:
df = pd.read_csv('train.csv')

In [3]:
df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [4]:
# Not all columns are relevant so we will keep those we intend to analyse
df_trimmed = df[['Pclass', 'Sex', 'Age', 'Survived']]

# Most of the columns are categorical which is ideal for a sankey visualisation except for Age. Hence, we will consider binning them into age groups which will be defined below
df_trimmed

Unnamed: 0,Pclass,Sex,Age,Survived
0,3,male,22.0,0
1,1,female,38.0,1
2,3,female,26.0,1
3,1,female,35.0,1
4,3,male,35.0,0
...,...,...,...,...
886,2,male,27.0,0
887,1,female,19.0,1
888,3,female,,0
889,1,male,26.0,1


In [5]:
# Check for NA 
df_trimmed.isnull().any()

Pclass      False
Sex         False
Age          True
Survived    False
dtype: bool

In [6]:
# From here, we can see that there are 177 passengers with NULL as their age. We will clean the NULL at the binning level.

df_trimmed.groupby(['Age'], dropna = False).size()

Age
0.42       1
0.67       1
0.75       2
0.83       2
0.92       1
        ... 
70.50      1
71.00      2
74.00      1
80.00      1
NaN      177
Length: 89, dtype: int64

In [7]:
# Set up binning of different age group here using the categorisation here as a guideline: https://www.nih.gov/nih-style-guide/age
# But we will merge 'Newborns' and 'Infants' and rename 'Older Adults' as 'Senior'

# Infants (1 month to 1 year)
# Children (1 year through 12 years)
# Adolescents (13 years through 17 years. They may also be referred to as teenagers depending on the context.)
# Adults (18 years or older)
# Senior (65 and older)

age_bins = [0,1,12,18,65,999]
df_trimmed = df_trimmed.assign(AgeGroup = pd.cut(x = df_trimmed['Age'], bins = age_bins, right = True, include_lowest = False, labels = ['Infants', 'Children', 'Adolescents', 'Adults', 'Senior']))

In [8]:
# Add a new categori - 'Unknown' to group all passengers without NULL as their AgeGroup
df_trimmed['AgeGroup'] = df_trimmed['AgeGroup'].cat.add_categories('Unknown')
df_trimmed['AgeGroup'] = df_trimmed['AgeGroup'].fillna('Unknown')

# Check if there is anymore NULL in the AgeGrou column
df_trimmed[df_trimmed.AgeGroup.isna()]

Unnamed: 0,Pclass,Sex,Age,Survived,AgeGroup


In [9]:
df_trimmed_sankey_1 = df_trimmed.groupby(by = ['Pclass', 'Sex'])['Survived'].count().reset_index().rename(columns={'Pclass': 'Source', 'Sex': 'Target', 'Survived': 'Value'}, inplace = False)

# This forms the first sankey dataframe
df_trimmed_sankey_1

Unnamed: 0,Source,Target,Value
0,1,female,94
1,1,male,122
2,2,female,76
3,2,male,108
4,3,female,144
5,3,male,347


In [10]:
# Map the values within the Source column to their respective Pclasses
df_trimmed_sankey_1['Source'] = df_trimmed_sankey_1.Source.map({1: 'Pclass1', 2: 'Pclass2', 3: 'Pclass3'})

df_trimmed_sankey_1

Unnamed: 0,Source,Target,Value
0,Pclass1,female,94
1,Pclass1,male,122
2,Pclass2,female,76
3,Pclass2,male,108
4,Pclass3,female,144
5,Pclass3,male,347


In [11]:
df_trimmed_sankey_2 = df_trimmed.groupby(by = ['Sex', 'AgeGroup'])['Survived'].count().reset_index().rename(columns={'Sex': 'Source', 'AgeGroup': 'Target', 'Survived': 'Value'}, inplace = False)

# This forms the second sankey dataframe
df_trimmed_sankey_2

Unnamed: 0,Source,Target,Value
0,female,Infants,4
1,female,Children,28
2,female,Adolescents,36
3,female,Adults,193
4,female,Senior,0
5,female,Unknown,53
6,male,Infants,10
7,male,Children,27
8,male,Adolescents,34
9,male,Adults,374


In [12]:
df_trimmed_sankey_3 = df_trimmed.groupby(by = ['AgeGroup', 'Survived'])['Pclass'].count().reset_index().rename(columns={'AgeGroup': 'Source', 'Survived': 'Target', 'Pclass': 'Value'}, inplace = False)

# This forms the third sankey dataframe
df_trimmed_sankey_3

Unnamed: 0,Source,Target,Value
0,Infants,0,2
1,Infants,1,12
2,Children,0,27
3,Children,1,28
4,Adolescents,0,40
5,Adolescents,1,30
6,Adults,0,348
7,Adults,1,219
8,Senior,0,7
9,Senior,1,1


In [13]:
# Map the values within the Target column to their respective Survival Status
df_trimmed_sankey_3['Target'] = df_trimmed_sankey_3.Target.map({0: 'Died', 1: 'Survived'})

df_trimmed_sankey_3

Unnamed: 0,Source,Target,Value
0,Infants,Died,2
1,Infants,Survived,12
2,Children,Died,27
3,Children,Survived,28
4,Adolescents,Died,40
5,Adolescents,Survived,30
6,Adults,Died,348
7,Adults,Survived,219
8,Senior,Died,7
9,Senior,Survived,1


In [29]:
# Combine all sankey dataframe together
sankey_main = pd.concat([df_trimmed_sankey_1, df_trimmed_sankey_2, df_trimmed_sankey_3], axis = 0)

# We now need to manipulate the dataset in a way that plotly is able to understand in order to generate the Alluvial diagram
sankey_main

Unnamed: 0,Source,Target,Value
0,Pclass1,female,94
1,Pclass1,male,122
2,Pclass2,female,76
3,Pclass2,male,108
4,Pclass3,female,144
5,Pclass3,male,347
0,female,Infants,4
1,female,Children,28
2,female,Adolescents,36
3,female,Adults,193


In [30]:
# This allows us to retrieve the unique categorical values under the Source and Target columns
unique_source_target = list(pd.unique(list(sankey_main[['Source', 'Target']].values.ravel('K'))))

# We will use this for the mapping dictionary later
unique_source_target

['Pclass1',
 'Pclass2',
 'Pclass3',
 'female',
 'male',
 'Infants',
 'Children',
 'Adolescents',
 'Adults',
 'Senior',
 'Unknown',
 'Died',
 'Survived']

In [31]:
# This allows us to assign a unique value to each categorical value
mapping_dict = {k:v for v, k in enumerate(unique_source_target)}

mapping_dict

{'Pclass1': 0,
 'Pclass2': 1,
 'Pclass3': 2,
 'female': 3,
 'male': 4,
 'Infants': 5,
 'Children': 6,
 'Adolescents': 7,
 'Adults': 8,
 'Senior': 9,
 'Unknown': 10,
 'Died': 11,
 'Survived': 12}

In [32]:
# We will now apply the mapping dictionary back to the Source and Target column
sankey_main['Source'] = sankey_main['Source'].map(mapping_dict)
sankey_main['Target'] = sankey_main['Target'].map(mapping_dict)

sankey_main

Unnamed: 0,Source,Target,Value
0,0,3,94
1,0,4,122
2,1,3,76
3,1,4,108
4,2,3,144
5,2,4,347
0,3,5,4
1,3,6,28
2,3,7,36
3,3,8,193


In [35]:
sankey_main = sankey_main.to_dict(orient = 'list')

sankey_main

{'Source': [0,
  0,
  1,
  1,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  4,
  4,
  4,
  4,
  4,
  4,
  5,
  5,
  6,
  6,
  7,
  7,
  8,
  8,
  9,
  9,
  10,
  10],
 'Target': [3,
  4,
  3,
  4,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  11,
  12,
  11,
  12,
  11,
  12,
  11,
  12,
  11,
  12],
 'Value': [94,
  122,
  76,
  108,
  144,
  347,
  4,
  28,
  36,
  193,
  0,
  53,
  10,
  27,
  34,
  374,
  8,
  124,
  2,
  12,
  27,
  28,
  40,
  30,
  348,
  219,
  7,
  1,
  125,
  52]}