## test sankey_flow visualization with Titanic dataset

In [1]:
import pandas as pd
import numpy as np

## read data

In [2]:
raw = pd.read_csv('data/titanic_train.csv')

In [3]:
raw.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [4]:
raw.columns.values

array(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'], dtype=object)

In [5]:
# use these columns as layers of nodes
use_cols = ['Survived', 'Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']
for col in use_cols:
    print '%s: %s' %(col, str(raw[col].unique()))

Survived: [0 1]
Pclass: [3 1 2]
Sex: ['male' 'female']
SibSp: [1 0 3 4 2 5 8]
Parch: [0 1 2 5 3 4 6]
Embarked: ['S' 'C' 'Q' nan]


In [6]:
raw['Embarked'] = raw['Embarked'].fillna('unknown')

In [7]:
raw[use_cols].head()

Unnamed: 0,Survived,Pclass,Sex,SibSp,Parch,Embarked
0,0,3,male,1,0,S
1,1,1,female,1,0,C
2,1,3,female,0,0,S
3,1,1,female,1,0,S
4,0,3,male,0,0,S


## color strategy
**node_color_type**, default="col_val"  
Can be ['col', 'val', 'col_val', 'cus']
- 'col': each column has different color
- 'val': each unique value has different color (unique values through all columns)
- 'col_val': each unique value in each column has different color
- 'cus': customer provide node color mapping

**node_color_mapping**, default=None  
Customized color mapping.
- node_color_mapping = '#fff'  
All nodes have the same color
- node_color_mapping = dict()  
```python
node_color_mapping = {
    # nodes in same column will have the same color
    column_name1: '#fff'
    
    # each unique value in a column has its own color
    column_name2: {value1: '#fff', value2: '#ccc', ...}
}
```

**link_color_type**, default='source'  
Can be ['source', 'target', 'both', 'same']
- 'source': same color as the source node
- 'target': same color as the target node
- 'both': color from both target and source
- 'same': all links have same color

**link_color**, default=None  
Only required when `link_color_type="same"`

In [8]:
import generate_sankey_flow

In [16]:
node_color_mapping = {
    'Survived': {0: '#111', 1: '#222'}, 
    'Pclass': '#333', 
    'Sex': {'female': '#444', 'male': '#555'}, 
    'SibSp': '#666', 
    'Parch': '#777', 
    'Embarked': '#888'
}

In [19]:
generate_sankey_flow.draw_sankey_flow(df=raw[use_cols], node_color_type='col', link_color_type='source', 
                                      width=1600, height=900, graph_name='Titanic', 
                                      node_color_mapping=None, color_map=None, link_color=None)

In [13]:
df = raw[use_cols]
node_color_type = 'col_val'
link_color_type = 'source'
width = 1600
height = 1000
graph_name='Titanic'
node_color_mapping=None
color_map=None
link_color=None


# get node infos
node_infos = _get_node_names(df)

# get node colors
try:
    cm = plt.cm.get_cmap(color_map)
except:
    cm = plt.cm.get_cmap('Vega20')

node_colors = _get_node_colors(node_infos, node_color_type, node_color_mapping, cm)

# prepare data for sankey
sankey_data = _prepare_sankey_data(df, node_colors)

# render the output
temp = open('sankey_flow_template.html').read()
template = jinja2.Template(temp)

# generate output html
if graph_name is None:
    output_path = 'sankey_flow_output.html'
else:
    output_path = 'sankey_flow_output_%s.html' %(graph_name)
with open(output_path, 'wb') as fh:
    fh.write(template.render({'data': json.dumps(sankey_data), 'link_color_type': link_color_type, 'link_color': link_color, 
                              'width': width, 'height': height}))
    print('The output is in %s. Enjoy!' %(output_path))

### get node names

In [14]:
node_infos = {}
use_cols = ['Survived', 'Pclass', 'Sex', 'SibSp', 'Parch', 'Embarked']
for col in use_cols:
    col_values = sorted(raw[col].unique())
    col_names = []
    for col_value in col_values:
        node_name = '%s: %s' %(col, str(col_value))
        col_names.append(node_name)
    node_infos[col] = {'col_names': col_names, 'col_values': col_values}

In [15]:
node_infos

{'Embarked': {'col_names': ['Embarked: C',
   'Embarked: Q',
   'Embarked: S',
   'Embarked: unknown'],
  'col_values': ['C', 'Q', 'S', 'unknown']},
 'Parch': {'col_names': ['Parch: 0',
   'Parch: 1',
   'Parch: 2',
   'Parch: 3',
   'Parch: 4',
   'Parch: 5',
   'Parch: 6'],
  'col_values': [0, 1, 2, 3, 4, 5, 6]},
 'Pclass': {'col_names': ['Pclass: 1', 'Pclass: 2', 'Pclass: 3'],
  'col_values': [1, 2, 3]},
 'Sex': {'col_names': ['Sex: female', 'Sex: male'],
  'col_values': ['female', 'male']},
 'SibSp': {'col_names': ['SibSp: 0',
   'SibSp: 1',
   'SibSp: 2',
   'SibSp: 3',
   'SibSp: 4',
   'SibSp: 5',
   'SibSp: 8'],
  'col_values': [0, 1, 2, 3, 4, 5, 8]},
 'Survived': {'col_names': ['Survived: 0', 'Survived: 1'],
  'col_values': [0, 1]}}

### create node_colors dictionary

In [12]:
import matplotlib.pyplot as plt
import matplotlib

In [10]:
node_color_type = 'col_val'

node_color_mapping = {
    'Survived': {0: '#111', 1: '#222'}, 
    'Pclass': '#333', 
    'Sex': {'female': '#444', 'male': '#555'}, 
    'SibSp': '#666', 
    'Parch': '#777', 
    'Embarked': '#888'
}

In [16]:
node_colors = {}
cm = plt.cm.get_cmap('Vega20')

if node_color_type == 'col':
    for col_idx, col in enumerate(node_infos.keys()):
        col_color = cm(col_idx % 20)
        for col_name in node_infos[col]['col_names']:
            node_colors[col_name] = matplotlib.colors.rgb2hex(col_color)
            
if node_color_type == 'val':
    unique_values = []
    for col in node_infos.keys():
        unique_values += list(node_infos[col]['col_values'])
    unique_values = list(set(unique_values))
    for col in node_infos.keys():
        for col_value in node_infos[col]['col_values']:
            val_idx = unique_values.index(col_value)
            node_colors['%s: %s' %(col, str(col_value))] = matplotlib.colors.rgb2hex(cm(val_idx % 20))
            
if node_color_type == 'col_val':
    node_names = []
    for col in node_infos.keys():
        node_names += node_infos[col]['col_names']
    for node_name_idx, node_name in enumerate(node_names):
        node_colors[node_name] = matplotlib.colors.rgb2hex(cm(node_name_idx % 20))
        
if node_color_type == 'cus':
    for col in node_color_mapping.keys():
        if type(node_color_mapping[col]) == dict:
            for col_val in node_color_mapping[col].keys():
                node_colors['%s: %s' %(col, str(col_val))] = node_color_mapping[col][col_val]
        else:
            for col_name in node_infos[col]['col_names']:
                node_colors[col_name] = node_color_mapping[col]

In [17]:
node_colors

{'Embarked: C': u'#1f77b4',
 'Embarked: Q': u'#aec7e8',
 'Embarked: S': u'#ff7f0e',
 'Embarked: unknown': u'#ffbb78',
 'Parch: 0': u'#2ca02c',
 'Parch: 1': u'#98df8a',
 'Parch: 2': u'#d62728',
 'Parch: 3': u'#ff9896',
 'Parch: 4': u'#9467bd',
 'Parch: 5': u'#c5b0d5',
 'Parch: 6': u'#8c564b',
 'Pclass: 1': u'#c49c94',
 'Pclass: 2': u'#e377c2',
 'Pclass: 3': u'#f7b6d2',
 'Sex: female': u'#7f7f7f',
 'Sex: male': u'#c7c7c7',
 'SibSp: 0': u'#17becf',
 'SibSp: 1': u'#9edae5',
 'SibSp: 2': u'#1f77b4',
 'SibSp: 3': u'#aec7e8',
 'SibSp: 4': u'#ff7f0e',
 'SibSp: 5': u'#ffbb78',
 'SibSp: 8': u'#2ca02c',
 'Survived: 0': u'#bcbd22',
 'Survived: 1': u'#dbdb8d'}

## get links and nodes

In [None]:
links = []
nodes = []
for i in range(len(use_cols)-1):
    source_col, target_col = use_cols[i], use_cols[i+1]
    temp_df = raw[[source_col, target_col]]
    temp_df['count'] = 1
    temp_df = temp_df.rename(columns={source_col: 'source', target_col: 'target'})
    temp_df_gp = temp_df.groupby(['source', 'target'], as_index=False).count()
    
    temp_df_gp['source'] = temp_df_gp['source'].apply(lambda x : '%s: %s' %(source_col, str(x)))
    temp_df_gp['target'] = temp_df_gp['target'].apply(lambda x : '%s: %s' %(target_col, str(x)))
    
    temp_df_gp['color_source'] = temp_df_gp['source'].apply(lambda x : node_colors[x] if x in node_colors.keys() else '#000')
    temp_df_gp['color_target'] = temp_df_gp['target'].apply(lambda x : node_colors[x] if x in node_colors.keys() else '#000')
    temp_df_gp['value'] = temp_df_gp['count'].map(str)
    
    links+= temp_df_gp[['source', 'target', 'value']].to_dict('records')

nodes = [{'name': n, 'color': c} for (n, c) in node_colors.items()]

## render output

In [None]:
data = {
    'links': links,
    'nodes': nodes
}

link_color_type = "same"
link_color = '#ccc'

In [None]:
import json
import jinja2

import sys
reload(sys)
sys.setdefaultencoding('utf-8')

In [None]:
temp = open('sankey_path_template.html').read()
template = jinja2.Template(temp)

# generate output html
with open('sankey_path_test.html', 'wb') as fh:
    fh.write(template.render({'data': json.dumps(data), 'link_color_type': link_color_type, 'link_color': link_color, 
                              'width': 1600, 'height': 900}))