In [1]:
import pandas as pd
import numpy as np
from sklearn.utils import shuffle
import plotly
import plotly.graph_objs as go

In [2]:
def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    """Build a Plotly Sankey figure dictionary from a flat DataFrame.

    Each column in ``cat_cols`` is one level of the diagram. Every pair of
    adjacent columns contributes links from the left column's values to the
    right column's values, weighted by the sum of ``value_cols``.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table containing the columns named in ``cat_cols`` and
        ``value_cols``.
    cat_cols : sequence of str
        Column names, ordered left to right. At least two are required.
        (The list default is never mutated.)
    value_cols : str
        Name of the single numeric column that provides the link weights.
    title : str
        Title placed in the figure layout.

    Returns
    -------
    dict
        ``{'data': [sankey_trace], 'layout': layout}``. The trace's
        ``link.source``, ``link.target`` and ``link.value`` are pandas Series.

    Raises
    ------
    ValueError
        If fewer than two category columns are given, since no
        source-target pair can be formed.

    Notes
    -----
    A label that occurs in more than one column is a single node; it keeps
    the colour of the first column it appears in. Missing values (NaN) are
    not turned into nodes.
    """
    if len(cat_cols) < 2:
        raise ValueError(
            "genSankey needs at least two columns in cat_cols to build source-target links"
        )

    # One colour per level; the palette is reused cyclically when there are
    # more levels than colours.
    colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464','#307998']

    # Build labels and colours in the same pass so both lists always have the
    # same length, even when a label repeats across columns. Labels are kept
    # in order of first appearance so the node order is reproducible.
    labelList = []
    colorList = []
    labelIndex = {}
    for level, catCol in enumerate(cat_cols):
        levelColor = colorPalette[level % len(colorPalette)]
        for label in df[catCol].dropna().unique():
            if label not in labelIndex:
                labelIndex[label] = len(labelList)
                labelList.append(label)
                colorList.append(levelColor)

    # Stack every adjacent (source, target) column pair, then aggregate once.
    pairFrames = []
    for sourceCol, targetCol in zip(cat_cols[:-1], cat_cols[1:]):
        pairDf = df[[sourceCol, targetCol, value_cols]].copy()
        pairDf.columns = ['source','target','count']
        pairFrames.append(pairDf)
    sourceTargetDf = (
        pd.concat(pairFrames)
        .groupby(['source','target'])
        .agg({'count':'sum'})
        .reset_index()
    )

    # Translate labels to node positions with a dictionary lookup.
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].map(labelIndex.get)
    sourceTargetDf['targetID'] = sourceTargetDf['target'].map(labelIndex.get)

    # Sankey trace definition.
    data = dict(
        type='sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = labelList,
          color = colorList
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count']
        )
      )

    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

In [3]:
# Extra demo columns to attach to the CSV rows. The lists differ in length
# (13 counts, 11 materials, 11 tags); the column-wise concat below aligns on
# the row index, so positions missing from a shorter frame come out as NaN.
vals = pd.DataFrame(
    {"count": [2, 3, 4, 7, 8, 9, 3, 4, 7, 8, 5, 9, 4]}
)
cat_ex = pd.DataFrame(
    {"material": ['x', 'x', 'z', 'x', 'x', 'y', 'z', 'y', 'y', 'z', 'y']}
)
cat_tag = pd.DataFrame(
    {"tag": ['tag1', 'tag2', 'tag1', 'tag4', 'tag1', 'tag3',
             'tag1', 'tag1', 'tag3', 'tag2', 'tag4']}
)

# Read the semicolon-separated source file and append the extra columns.
# NOTE(review): the row count of the CSV is not visible here - confirm it
# matches the lists above if NaN rows are not intended.
data = pd.concat(
    [pd.read_csv('sankytest.csv', sep=";"), cat_ex, cat_tag, vals],
    axis=1,
)
data

Unnamed: 0,name,catagory,color,region,material,tag,count
0,ball,toy,red,africa,x,tag1,2.0
1,ball,toy,red,africa,x,tag2,3.0
2,ball,toy,red,america,z,tag1,4.0
3,ball,toy,black,africa,x,tag4,7.0
4,ball,toy,blue,africa,x,tag1,8.0
5,pen,writing tool,black,asia,y,tag3,9.0
6,pen,writing tool,red,asia,z,tag1,3.0
7,pen,writing tool,black,africa,y,tag1,4.0
8,pen,writing tool,green,asia,y,tag3,7.0
9,knfie,utensil,silver,america,z,tag2,8.0


In [16]:
# Show the column names available for use as Sankey levels and link weights.
print(data.columns)

Index(['name', 'catagory', 'color', 'region', 'material', 'tag', 'count'], dtype='object')


In [4]:
# Build the Sankey figure: levels run left to right in the order listed, and
# links are weighted by the 'count' column. 'catagory' is spelled as in the
# data's column name and must stay that way.
fig = genSankey(
    data,
    cat_cols=['name', 'catagory', 'color', 'material', 'tag', 'region'],
    value_cols='count',
    title='Sankey test',
)

# Render to a standalone HTML file; the call returns the output file name.
plotly.offline.plot(fig, validate=False)

'temp-plot.html'