In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import numpy as np
import urllib, json
import matplotlib as mpl

In [2]:
import plotly.graph_objects as go

## Sankey diagrahm

In [3]:
df_emb = pd.read_excel("../data/embeddings/all_nw500_1000_wl80_vs75_pca3.xlsx")[["Node","Embedding"]]

In [4]:
node2emb = {node:embedding for node, embedding in df_emb.values}

In [5]:
df_edgelist = pd.read_excel("../data/edgelist/UN_full.xlsx")[["origin","destination","weight_scaled"]]

In [6]:
df_edgelist["origin_class"] = [node2emb[node_] for node_ in df_edgelist["origin"]]
df_edgelist["destination_class"] = [node2emb[node_] for node_ in df_edgelist["destination"]]

In [7]:
cmap = mpl.colormaps["Set2"]

In [8]:
df_edgelist["origin_color"] = [cmap(x) for x in df_edgelist["origin_class"]]
df_edgelist["destination_color"] = [cmap(x) for x in df_edgelist["destination_class"]]

In [9]:
sankey_data = df_edgelist.groupby(["origin_class","destination_class"])["weight_scaled"].sum().reset_index()

In [20]:
sankey_data["sankey_dest"] = [3+class_ for class_ in sankey_data["destination_class"]]

In [59]:
sankey_data["origin_color"] = ["rgba" + str(cmap(x)) for x in sankey_data["origin_class"]]
sankey_data["destination_color"] = ["rgba" + str(cmap(x, alpha=0.6)) for x in sankey_data["destination_class"]]

In [61]:
sankey_data.head()

Unnamed: 0,origin_class,destination_class,weight_scaled,origin_color,destination_color,sankey_dest
0,0,0,5433.342297,"rgba(0.4, 0.7607843137254902, 0.64705882352941...","rgba(0.4, 0.7607843137254902, 0.64705882352941...",3
1,0,1,7.293687,"rgba(0.4, 0.7607843137254902, 0.64705882352941...","rgba(0.9882352941176471, 0.5529411764705883, 0...",4
2,0,2,90.588868,"rgba(0.4, 0.7607843137254902, 0.64705882352941...","rgba(0.5529411764705883, 0.6274509803921569, 0...",5
3,1,0,16202.95351,"rgba(0.9882352941176471, 0.5529411764705883, 0...","rgba(0.4, 0.7607843137254902, 0.64705882352941...",3
4,1,1,1067.199721,"rgba(0.9882352941176471, 0.5529411764705883, 0...","rgba(0.9882352941176471, 0.5529411764705883, 0...",4


Unnamed: 0,origin,destination,weight_scaled,origin_class,destination_class,origin_color,destination_color
0,AFG,AUS,3.03613,0,0,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","(0.4, 0.7607843137254902, 0.6470588235294118, ..."
1,AFG,AUT,3.537707,0,0,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","(0.4, 0.7607843137254902, 0.6470588235294118, ..."


In [72]:
fig = go.Figure(
    data = [go.Sankey(
                node=dict(
                    pad = 15,
                    thickness = 20,
                    line = dict(color = "black", width = 0.5),
                    label = ["Class 0 origin","Class 1 origin","Class 2 origin", "Class 0 destination","Class 1 destination","Class 2 destination"],
                    color = ["rgba" + str(cmap(x)) for x in range(3)]*2
                    
                    ),
                link=dict(
                    source=sankey_data["origin_class"].values,
                    target=sankey_data["sankey_dest"].values,
                    value=sankey_data["weight_scaled"].values,
                    arrowlen=15,
                    color= sankey_data["destination_color"].values

                ),
                
            )]
)
fig.update_layout(title_text="Migration by structural classification", font_size=15, title={"xanchor":"center","y":0.9,"x":0.5})
fig.write_image("../visuals/sankey_plot.jpeg")
fig.show()

In [68]:
df_edgelist["sankey_dest"] = [3+class_ for class_ in df_edgelist["destination_class"]]
df_edgelist["destination_color"] = ["rgba" + str(cmap(x, alpha=0.6)) for x in df_edgelist["destination_class"]]
df_edgelist.head(2)

Unnamed: 0,origin,destination,weight_scaled,origin_class,destination_class,origin_color,destination_color,sankey_dest
0,AFG,AUS,3.03613,0,0,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","rgba(0.4, 0.7607843137254902, 0.64705882352941...",3
1,AFG,AUT,3.537707,0,0,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","rgba(0.4, 0.7607843137254902, 0.64705882352941...",3


In [74]:
# fig = go.Figure(
#     data = [go.Sankey(
#                 node=dict(
#                     pad = 15,
#                     thickness = 20,
#                     line = dict(color = "black", width = 0.5),
#                     label = ["Class 0 origin","Class 1 origin","Class 2 origin", "Class 0 destination","Class 1 destination","Class 2 destination"],
#                     color = ["rgba" + str(cmap(x)) for x in range(3)]*2
                    
#                     ),
#                 link=dict(
#                     source=df_edgelist["origin_class"].values,
#                     target=df_edgelist["sankey_dest"].values,
#                     value=df_edgelist["weight_scaled"].values,
#                     color= df_edgelist["destination_color"].values

#                 ),
                
#             )]
# )
# fig.update_layout(title_text="Migration by structural classification", font_size=15, title={"xanchor":"center","y":0.9,"x":0.5})
# # fig.write_image("../visuals/sankey_plot.jpeg")
# # fig.show()

### For reference

In [13]:
dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
)

{'pad': 15,
 'thickness': 20,
 'line': {'color': 'black', 'width': 0.5},
 'label': ['A1', 'A2', 'B1', 'B2', 'C1', 'C2'],
 'color': 'blue'}

In [14]:
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A1", "A2", "B1", "B2", "C1", "C2"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3], # indices correspond to labels, eg A1, A2, A1, B1, ...
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

In [15]:
sankey_data.head()

Unnamed: 0,origin_class,destination_class,weight_scaled,origin_color,destination_color
0,0,0,5433.342297,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","(0.4, 0.7607843137254902, 0.6470588235294118, ..."
1,0,1,7.293687,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","(0.9882352941176471, 0.5529411764705883, 0.384..."
2,0,2,90.588868,"(0.4, 0.7607843137254902, 0.6470588235294118, ...","(0.5529411764705883, 0.6274509803921569, 0.796..."
3,1,0,16202.95351,"(0.9882352941176471, 0.5529411764705883, 0.384...","(0.4, 0.7607843137254902, 0.6470588235294118, ..."
4,1,1,1067.199721,"(0.9882352941176471, 0.5529411764705883, 0.384...","(0.9882352941176471, 0.5529411764705883, 0.384..."


In [16]:
sankey_data["destination_color"].values

array([(0.4, 0.7607843137254902, 0.6470588235294118, 1.0),
       (0.9882352941176471, 0.5529411764705883, 0.3843137254901961, 1.0),
       (0.5529411764705883, 0.6274509803921569, 0.796078431372549, 1.0),
       (0.4, 0.7607843137254902, 0.6470588235294118, 1.0),
       (0.9882352941176471, 0.5529411764705883, 0.3843137254901961, 1.0),
       (0.5529411764705883, 0.6274509803921569, 0.796078431372549, 1.0),
       (0.4, 0.7607843137254902, 0.6470588235294118, 1.0),
       (0.9882352941176471, 0.5529411764705883, 0.3843137254901961, 1.0),
       (0.5529411764705883, 0.6274509803921569, 0.796078431372549, 1.0)],
      dtype=object)

In [17]:
url = 'https://raw.githubusercontent.com/plotly/plotly.js/master/test/image/mocks/sankey_energy.json'
response = urllib.request.urlopen(url)
data = json.loads(response.read())

# override gray link colors with 'source' colors
opacity = 0.4
# change 'magenta' to its 'rgba' value to add opacity
data['data'][0]['node']['color'] = ['rgba(255,0,255, 0.8)' if color == "magenta" else color for color in data['data'][0]['node']['color']]
data['data'][0]['link']['color'] = [data['data'][0]['node']['color'][src].replace("0.8", str(opacity))
                                    for src in data['data'][0]['link']['source']]

fig = go.Figure(data=[go.Sankey(
    valueformat = ".0f",
    valuesuffix = "TWh",
    # Define nodes
    node = dict(
      pad = 15,
      thickness = 15,
      line = dict(color = "black", width = 0.5),
      label =  data['data'][0]['node']['label'],
      color =  data['data'][0]['node']['color']
    ),
    # Add links
    link = dict(
      source =  data['data'][0]['link']['source'],
      target =  data['data'][0]['link']['target'],
      value =  data['data'][0]['link']['value'],
      label =  data['data'][0]['link']['label'],
      color =  data['data'][0]['link']['color']
))])

fig.update_layout(title_text="Energy forecast for 2050<br>Source: Department of Energy & Climate Change, Tom Counsell via <a href='https://bost.ocks.org/mike/sankey/'>Mike Bostock</a>",
                  font_size=10)
fig.show()

In [26]:
data['data'][0]['link']['color']

['rgba(31, 119, 180, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(227, 119, 194, 0.4)',
 'rgba(127, 127, 127, 0.4)',
 'rgba(188, 189, 34, 0.4)',
 'rgba(31, 119, 180, 0.4)',
 'rgba(23, 190, 207, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(255, 127, 14, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(214, 39, 40, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(140, 86, 75, 0.4)',
 'rgba(127, 127, 127, 0.4)',
 'rgba(127, 127, 127, 0.4)',
 'rgba(127, 127, 127, 0.4)',
 'rgba(188, 189, 34, 0.4)',
 'rgba(23, 190, 207, 0.4)',
 'rgb