In [8]:
# libraries
import plotly
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [20]:
# read the dataset from 'dff.csv' into a DataFrame and preview its last five rows
df = pd.read_csv(filepath_or_buffer='dff.csv')
df.tail(n=5)

Unnamed: 0,Titel,Bevilliget beløb,Modtager,Institution,Virkemidler,Område,År,Beskrivelse,Region
4385,Enabling Ultra Deep Hydrodesulphurization by N...,10781874,Ib Chorkendorff,Danmarks Tekniske Universitet,Øvrige forskningsprojekter,Teknologi og Produktion,2013,Alle olieprodukter renses i dag for svovl for ...,Region Hovedstaden
4386,Acute stroke research,717359,Hanne Krarup Christensen,"Bispebjerg Hospital, Neurologisk Afdeling",Delestillinger,Sundhed og Sygdom,2013,Aktuelle ansøgning angår frikøb af overlæge Ha...,Region Hovedstaden
4387,Atherosclerotic cardiovascular disease in HIV-...,764683,Anne-Mette Lebech,"Hvidovre Hospital, Infektionsmedicinsk Afdeling",Delestillinger,Sundhed og Sygdom,2013,Behandling af HIV positive patienter med anti-...,Region Hovedstaden
4388,Epigenetic modulation of mechanisms involved i...,829294,Ole Schmeltz Søgaard,Aarhus Universitetshospital,Delestillinger,Sundhed og Sygdom,2013,HIV infektion behandles i dag med en kombinati...,Region Midtjylland
4389,Novel mechanisms of insulin resistance and mit...,665923,Kurt Højlund,"Odense Universitetshospital, Endokrinologisk A...",Delestillinger,Sundhed og Sygdom,2013,Insulinresistens (IR) i muskelvæv spiller en v...,Region Syddanmark


In [60]:
# base table for the sankey plot: sum of 'Bevilliget beløb' for every
# combination of 'År', 'Virkemidler', 'Region' and 'Område', as a flat frame
gk = df.groupby(['År', 'Virkemidler', 'Region', 'Område'], as_index=False)['Bevilliget beløb'].sum()
gk

In [None]:
# data for sankey: keep only the rows of the selected year
SANKEY_YEAR = 2021
df_sankey = gk[gk['År'] == SANKEY_YEAR]

# columns whose values become the node layers of the diagram, ordered left to right.
# At least two are required, otherwise data cannot flow from category 1 to category 2.
category_columns = ['År', 'Virkemidler']
value_column = 'Bevilliget beløb'

if len(category_columns) < 2:
    raise ValueError("Number of input categories must be at least 2")

# node labels: the unique values of each category column, in order of appearance.
# Only the columns that take part in the flow are used, so the diagram gets no
# unconnected nodes; dict.fromkeys drops a value that occurs in several columns
# while keeping its first position.
labels = list(dict.fromkeys(
    label
    for col in category_columns
    for label in df_sankey[col].unique().tolist()
))

# create data for go.Sankey: one block of (source, target, count) rows per pair of
# neighbouring category columns, e.g. 'År' -> 'Virkemidler'
link_frames = []
for source_col, target_col in zip(category_columns, category_columns[1:]):
    pair_totals = (
        df_sankey
        .groupby([source_col, target_col])
        .agg({value_column: 'sum'})
        .reset_index()
    )
    pair_totals.columns = ['source', 'target', 'count']
    link_frames.append(pair_totals)

# DataFrame.append no longer exists in pandas 2.x; concatenate all blocks in one go
df_link_input = pd.concat(link_frames, ignore_index=True)

# add index for source-target pair: position of each label in `labels`
label_ids = {label: position for position, label in enumerate(labels)}
df_link_input['sourceID'] = [label_ids[label] for label in df_link_input['source']]
df_link_input['targetID'] = [label_ids[label] for label in df_link_input['target']]

In [None]:
# node appearance; the node texts are taken from `labels`
sankey_nodes = {
    'pad': 15,
    'thickness': 20,
    'line': {'color': "black", 'width': 0.5},
    'label': labels,
    'color': "blue",
}

# links: source/target are positions in `labels`, value is the summed amount
sankey_links = {
    'source': df_link_input['sourceID'],
    'target': df_link_input['targetID'],
    'value': df_link_input['count'],
}

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

In [None]:
def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    """Build a plotly figure dictionary for a Sankey diagram.

    Each column in ``cat_cols`` is one level of the diagram. For every pair of
    neighbouring levels the values of ``value_cols`` are summed per
    (source, target) combination and become the links.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing the columns named in ``cat_cols`` and ``value_cols``.
    cat_cols : sequence of str
        Column names of the levels, ordered left to right. At least two are
        required. The default list is never mutated.
    value_cols : str
        Name of the single column whose values are summed into link sizes.
    title : str
        Title placed in the figure layout.

    Returns
    -------
    dict
        ``{'data': [sankey_trace], 'layout': layout}`` where the trace holds the
        node labels/colours and the link source/target/value columns.

    Raises
    ------
    ValueError
        If fewer than two category columns are given.
    """
    if len(cat_cols) < 2:
        raise ValueError('cat_cols must contain at least two columns to build links')

    # one colour per level; the palette is cycled when there are more levels than colours
    colorPalette = ['#4B8BBE','#306998','#FFE873','#FFD43B','#646464']

    # Map every label to the colour of the first level it appears in. Building
    # labels and colours together keeps both lists the same length even when a
    # value occurs in more than one column; unique() keeps order of appearance,
    # so the node order is reproducible between runs.
    labelColors = {}
    for level, catCol in enumerate(cat_cols):
        levelColor = colorPalette[level % len(colorPalette)]
        for label in df[catCol].unique().tolist():
            labelColors.setdefault(label, levelColor)
    labelList = list(labelColors)
    colorList = list(labelColors.values())

    # transform df into source-target pairs, one frame per pair of neighbouring levels
    pairFrames = []
    for sourceCol, targetCol in zip(cat_cols, cat_cols[1:]):
        pairDf = df[[sourceCol, targetCol, value_cols]].copy()
        pairDf.columns = ['source','target','count']
        pairFrames.append(pairDf)

    # a single aggregation over all pairs is enough
    sourceTargetDf = (
        pd.concat(pairFrames, ignore_index=True)
        .groupby(['source','target'])
        .agg({'count':'sum'})
        .reset_index()
    )

    # add index for source-target pair (position of the label in labelList)
    labelIds = {label: position for position, label in enumerate(labelList)}
    sourceTargetDf['sourceID'] = [labelIds[label] for label in sourceTargetDf['source']]
    sourceTargetDf['targetID'] = [labelIds[label] for label in sourceTargetDf['target']]

    # creating the sankey diagram
    data = dict(
        type='sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = labelList,
          color = colorList
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count']
        )
      )

    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )

    fig = dict(data=[data], layout=layout)
    return fig

# levels of the diagram, ordered left to right
sankey_levels = ['År', 'Virkemidler', 'Område']
fig = genSankey(df_sankey, cat_cols=sankey_levels, value_cols='Bevilliget beløb', title='Titel')

# render the figure dictionary with plotly's offline plotting
plotly.offline.plot(fig, validate=False)