In [18]:
import pandas as pd
import plotly.graph_objects as go
import matplotlib.colors as mcolors

In [19]:

# Read the per-city class table; the code below reads the columns
# class_2019, class_2020 and class_2023.
file_path = 'data/city_corr_class.csv'
data = pd.read_csv(file_path)

# Every class observed in any of the three years, sorted in descending order.
# The same class list is reused for every year column of the diagram.
all_classes = sorted(
    set(data['class_2019']).union(set(data['class_2020']), set(data['class_2023'])),
    reverse=True,
)

# Node index for each '<class>_<year>' key: years vary slowest, classes fastest,
# matching the order in which node labels are built later.
node_keys = ((cls, year) for year in [2019, 2020, 2023] for cls in all_classes)
class_to_index = {f'{cls}_{year}': idx for idx, (cls, year) in enumerate(node_keys)}

# Parallel link lists for go.Sankey: one entry per (from-class, to-class) pair that
# occurs in the data, weighted by the number of rows making that transition.
# groupby sorts keys and drops NaN keys, so entries come out in sorted pair order.
source, target, value = [], [], []
for from_year, to_year in ((2019, 2020), (2020, 2023)):
    pair_counts = data.groupby([f'class_{from_year}', f'class_{to_year}']).size()
    for (from_cls, to_cls), count in pair_counts.items():
        source.append(class_to_index[f'{from_cls}_{from_year}'])
        target.append(class_to_index[f'{to_cls}_{to_year}'])
        value.append(int(count))

def get_gradient_color(start_color, end_color, n):
    """Return ``n`` CSS ``rgba(...)`` strings forming a linear gradient between two colours.

    Both endpoints are included: the first entry is ``start_color`` and, when
    ``n > 1``, the last entry is exactly ``end_color``. (The previous version sampled
    at ``i / n`` and therefore never reached ``end_color``.)

    Parameters
    ----------
    start_color, end_color :
        Any colour specification accepted by ``matplotlib.colors.to_rgba``
        (e.g. ``'red'``, ``'#0000ff'``, an RGB(A) tuple).
    n : int
        Number of colours to produce. ``n <= 0`` yields an empty list.

    Returns
    -------
    list of str
        Strings ``'rgba(R,G,B,A)'`` with integer 0-255 channels and float alpha.
    """
    if n <= 0:
        return []
    start = mcolors.to_rgba(start_color)
    end = mcolors.to_rgba(end_color)
    # Divide by n - 1 so the last sample lands on end_color; for n == 1 only
    # start_color is returned (avoid dividing by zero).
    steps = max(n - 1, 1)
    colors = []
    for i in range(n):
        t = i / steps
        # Interpolate directly instead of via a 256-entry colormap LUT, which quantizes.
        r, g, b, a = (s + (e - s) * t for s, e in zip(start, end))
        colors.append('rgba({},{},{},{})'.format(round(r * 255), round(g * 255), round(b * 255), a))
    return colors

# One label per node, in the same order as `class_to_index`:
# years in the outer loop, classes in the inner loop.
labels = [f'{cls} ({year})' for year in (2019, 2020, 2023) for cls in all_classes]

# Links are tinted along a red-to-blue ramp by their position in `source`;
# the colour reflects link order only, not any data attribute.
start_color = 'red'
end_color = 'blue'
link_colors = get_gradient_color(start_color, end_color, len(source))

node_spec = {
    'pad': 15,
    'thickness': 20,
    'line': {'color': 'black', 'width': 0.5},
    'label': labels,
}
link_spec = {
    'source': source,
    'target': target,
    'value': value,
    'color': link_colors,
}
sankey_trace = go.Sankey(node=node_spec, link=link_spec)

fig = go.Figure(data=[sankey_trace])
fig.update_layout(
    title_text='Sankey Diagram of Class Transitions (2019-2020-2023)',
    font_size=10,
)
fig.show()

In [62]:
import pandas as pd
import plotly.graph_objects as go
import matplotlib.pyplot as plt

# Load the data. Use the same path as the earlier cell, relative to the notebook's
# working directory, instead of a machine-specific absolute path
# (previously /Users/wishingtree/.../scatter-bar/data/city_corr_class.csv).
file_path = 'data/city_corr_class.csv'
data = pd.read_csv(file_path)

# Fail early with a clear message if an expected year column is absent.
missing_columns = {'class_2019', 'class_2020', 'class_2023'} - set(data.columns)
if missing_columns:
    raise KeyError(f'{file_path} is missing expected columns: {sorted(missing_columns)}')

# All classes seen in any year, sorted in descending order (e.g. 5, 4, 3, 2, 1);
# the same list is reused for every year column.
all_classes = sorted(
    set(data['class_2019']).union(set(data['class_2020']), set(data['class_2023'])),
    reverse=True,
)

# Node index for each '<class>_<year>' key: years vary slowest, classes fastest,
# matching the label order built in the next block.
class_to_index = {
    f'{cls}_{year}': idx
    for idx, (year, cls) in enumerate((y, c) for y in [2019, 2020, 2023] for c in all_classes)
}

# Parallel link lists for go.Sankey: one entry per observed (from-class, to-class)
# pair, weighted by its row count, first for 2019->2020 and then for 2020->2023.
source, target, value = [], [], []
for from_year, to_year in ((2019, 2020), (2020, 2023)):
    pair_counts = data.groupby([f'class_{from_year}', f'class_{to_year}']).size()
    for (from_cls, to_cls), count in pair_counts.items():
        source.append(class_to_index[f'{from_cls}_{from_year}'])
        target.append(class_to_index[f'{to_cls}_{to_year}'])
        value.append(int(count))

# Node labels, one per (year, class) pair, in the same order as `class_to_index`:
# years in the outer loop, classes (sorted descending) in the inner loop.
# NOTE(review): Plotly lays out Sankey nodes itself; this list order is the intended
# top-to-bottom order but is not guaranteed without explicit node x/y — confirm.
years = [2019, 2020, 2023]
labels = [f'{cls} ({year})' for year in years for cls in all_classes]

# One colour per class, sampled from the diverging 'RdBu' colormap. The lookup-table
# size follows the data (previously hard-coded to 5), so the node-colour list always
# has one entry per node. max(..., 1) avoids asking for an empty colormap.
n_classes = len(all_classes)
cmap = plt.get_cmap('RdBu', max(n_classes, 1))
class_colors = []
for i in range(n_classes):
    r, g, b, _alpha = cmap(i)
    class_colors.append(f'rgb({int(r*255)}, {int(g*255)}, {int(b*255)})')

# Repeat the palette once per year column so a class keeps its colour across years.
node_colors = class_colors * len(years)
tmp_colors = node_colors  # alias kept in case later cells still read the old name

# Create the Sankey diagram (no link colour is set here).
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=labels,
        color=node_colors,
    ),
    link=dict(
        source=source,
        target=target,
        value=value,
    ),
)])

# Update layout and show the figure
fig.update_layout(title_text='Sankey Diagram of Class Transitions (2019-2020-2023)', font_size=10)
fig.show()