In [9]:
### Examples ###
# 1) https://gist.github.com/praful-dodda/c98d9fd5dab6e6a9e68bf96ee73630e9
# 2) https://chart-studio.plotly.com/~alishobeiri/1591/plotly-sankey-diagrams/#/
# 3) https://medium.com/plotly/4-interactive-sankey-diagram-made-in-python-3057b9ee8616

In [1]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import plotly
# import plotly.plotly as py
import chart_studio.plotly as py
import plotly.graph_objects as go

In [5]:
# df = pd.read_csv('other_data/data1_alluvial.csv',sep=';')
df = pd.read_csv('./data/df_build_alluvial.csv',sep=';')

In [7]:
# fig = genSankey(df,cat_cols=['lvl3','lvl4','lvl5','lvl6','lvl7'],value_cols='count',title='Word Etymology')
fig = genSankey_2(df,cat_cols=['lvl1','lvl3','lvl4','lvl5','lvl6','lvl7'],value_cols='count',title='SS-OCoClus result')
# plotly.offline.plot(fig, validate=False)
# py.iplot(fig, validate=False)
# fig.update_layout(title_text="Basic Sankey Diagram", font_size=12)
# dpi=600
# board = plt.figure(figsize=(3, 2),dpi=dpi)
# fig.update_layout(font_size=12)
fig.show()


Sorting because non-concatenation axis is not aligned. A future version
of pandas will change to not sort by default.

To accept the future behavior, pass 'sort=False'.





In [4]:
def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # Source: https://medium.com/kenlok/how-to-create-sankey-diagrams-from-dataframes-in-python-e221c1b4d6b0
    # maximum of 6 value cols -> 6 colors
#     colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']
    colorPalette = ['#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
                    '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
                    '#bcbd22', '#dbdb8d','#17becf', '#9edae5']
    labelList = []
    colorNumList = []
    for catCol in cat_cols:
        labelListTemp =  list(set(df[catCol].values))
        colorNumList.append(len(labelListTemp))
        labelList = labelList + labelListTemp
        
    # remove duplicates from labelList
    labelList = list(dict.fromkeys(labelList))
    
    # define colors based on number of levels
    colorList = []
    for idx, colorNum in enumerate(colorNumList):
        colorList = colorList + [colorPalette[idx]]*colorNum
        
    # transform df into a source-target pair
    for i in range(len(cat_cols)-1):
        if i==0:
            sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            sourceTargetDf.columns = ['source','target','count']
        else:
            tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            tempDf.columns = ['source','target','count']
            sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
        sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
        
    # add index for source-target pair
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
    sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))
    
    # creating the sankey diagram
    data = dict(
        type='sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = labelList,
          color = colorList
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count']
        )
      )
    
    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )
    
#     fig = go.Figure(data = [go.Sankey(data,layout)])
    fig = go.Figure(data = [go.Sankey(data)])
#     fig = dict(data=[data], layout=layout)
    return fig

In [3]:
def genSankey_2(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # Source: https://gist.github.com/praful-dodda/c98d9fd5dab6e6a9e68bf96ee73630e9
    # maximum of 6 value cols -> 6 colors
#     colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']
    colorPalette = ['#1f77b4', '#aec7e8', '#ff7f0e', '#ffbb78', '#2ca02c', '#98df8a', '#d62728', '#ff9896',
                    '#9467bd', '#c5b0d5', '#8c564b', '#c49c94', '#e377c2', '#f7b6d2', '#7f7f7f', '#c7c7c7',
                    '#bcbd22', '#dbdb8d','#17becf', '#9edae5']
    colorNumList = []
    specials = [". ","> ",">> ",".. ","- "]
    labelList = []
    lableDict = {}
    for i,catCol in enumerate(cat_cols):
        n = len(labelList)
        labelListTemp =  list(enumerate(set(df[catCol].values),start=n))
        colorNumList.append(len(labelListTemp))
        labelList = labelList + labelListTemp
        lableDict[catCol] = dict(labelListTemp)

    # remove duplicates from labelList
    labelList = list(dict.fromkeys(labelList))
    # revese the dict 
    rDict = {}
    for k,v in lableDict.items():
        rDict[k] = {str(v) : k for k,v in v.items()}
    
    # define colors based on number of levels
    colorList = []
    for idx, colorNum in enumerate(colorNumList):
        colorList = colorList + [colorPalette[idx]]*colorNum
        
    # transform df into a source-target pair
    sourceTargetDf = pd.DataFrame()
    for i in range(len(cat_cols)-1):
  
            tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]].copy()
            tempDf[cat_cols[i]] =  tempDf[cat_cols[i]].astype(str)
            tempDf[cat_cols[i+1]] =  tempDf[cat_cols[i+1]].astype(str)
#             tempDf['s'],tempDf['d'] = cat_cols[i],cat_cols[i+1]
            tempDf['s'],tempDf['t'] = cat_cols[i],cat_cols[i+1]
            tempDf.columns = ['source','target','count','s','t']
#             tempDf['label'] = df['label']
#             tempDf['label'] = df['lvl3']
            tempDf['label'] = df['cluster']
            sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
#             sourceTargetDf = sourceTargetDf.groupby(['source','target','s','t','label']).agg({'count':'sum'}).reset_index()
            sourceTargetDf = sourceTargetDf.groupby(['source','target','s','t','label']).agg({'count':'sum'}).reset_index()
        
    # add index for source-target pair
#     sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
#     sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))
    tempDf = tempDf[(tempDf['source']!='nan') & (tempDf['target']!='nan')]
#     display(tempDf)
#     display(rDict)
    reset_level = 0
    
    ### Combining codes for placing elements at their corresponding vertical axis.
    unique_list = []
    for k,v in rDict.items():
#         print(k,reset_level)
        v_keys = [x+'_'+str(reset_level) for x in list(v.keys())]
        reset_level += 1
        if v_keys[0][:3] == 'nan':
            v_keys.pop(0)        
#         print(v_keys)
        [unique_list.append(x) for x in v_keys]
#     print(unique_list)
    nodified = nodify(unique_list)
    
#     display(sourceTargetDf)
    sourceTargetDf = sourceTargetDf[(sourceTargetDf['source']!='nan') & (sourceTargetDf['target']!='nan')]
    sourceTargetDf['sourceID'] = sourceTargetDf.apply(lambda x: rDict[x['s']][x['source']],axis=1)
    sourceTargetDf['targetID'] = sourceTargetDf.apply(lambda x: rDict[x['t']][x['target']],axis=1)
#     display(sourceTargetDf)

    
    # creating the sankey diagram
    data = dict(
        type='sankey',
        arrangement = "snap",
        orientation = 'h',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = [x[1] for x in labelList],
          color = colorList,
          x=nodified[0],
          y=nodified[1]
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count'],
          label = sourceTargetDf['label'],
#           color = colorList
        )
      )
    
    layout =  dict(
        title = title,
        height = 772,
        width = 950,
        font = dict(
          size = 12
        )
    )
       
# #     fig = dict(data=[data], layout=layout)
    fig = go.Figure(data = [go.Sankey(data)], layout=layout)
# #     fig.update_layout(title_text="Basic Sankey Diagram", font_size=12)
    return fig

In [2]:
def nodify(node_names):
#     node_names = unique_list
    # uniqe name endings
    ends = sorted(list(set([e[-1] for e in node_names])))
    
    # intervals
    steps = 1.3/len(ends)

    # x-values for each unique name ending
    # for input as node position
    nodes_x = {}
    xVal = 0
    for e in ends:
        nodes_x[str(e)] = xVal
        xVal += steps

    # x and y values in list form
    x_values = [nodes_x[n[-1]] for n in node_names]
    y_values = [.1]*len(x_values)
    
    return x_values, y_values