#### 查找本机可支持中文的字体

In [1]:
# import matplotlib.font_manager

# # 获取所有可用字体
# font_list = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf')

# # 打印字体名称
# # print("可用的字体:")
# for font in font_list:
#     try:
#         font_name = matplotlib.font_manager.FontProperties(fname=font).get_name()
#         # print(font_name)
#     except:
#         pass  # 忽略无法读取的字体

# # 打印支持中文的字体
# print("\n可能支持中文的字体:")
# chinese_fonts = ['SimHei', 'Microsoft YaHei', 'SimSun', 'NSimSun', 'FangSong', 'KaiTi', 'STSong', 'STKaiti', 'STFangsong', 'STXihei', 'STZhongsong', 'STHupo', 'STLiti', 'STXinwei', 'STXingkai', 'Arial Unicode MS', 'Songti SC']

# for font in chinese_fonts:
#     if any(font.lower() in f.lower() for f in font_list):
#         print(font)

# 设置中文字体
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Songti SC']
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

#### Toy Sankey

Plotly的graph_objects模块中的Sankey图是基于桑基图（Sankey Diagram）的原理来绘制的。桑基图是一种特殊类型的流程图，用于可视化从一组值到另一组值的流量。以下是Plotly绘制Sankey图的基本原理：

1. 数据结构：
   - Plotly使用字典结构来定义Sankey图的数据。
   - 主要包含node和link两个键，分别定义节点和链接的属性。

2. 链接（Links）：
   - 节点之间的连接表示流量或数据的传输。
   - 链接的宽度通常与流量的大小成正比。
   - 在Plotly中，链接通过source、target和value属性来定义。

3. 节点（Nodes）：
   - 桑基图中的每个矩形块代表一个节点。
   - 节点可以表示不同的类别、阶段或实体。
   - 在Plotly中，节点通过label属性来命名。

4. 布局算法：
   - Plotly内部使用布局算法来自动排列节点和链接。
   - 这个算法会尝试最小化链接的交叉，并优化整体布局。

5. 颜色编码：
   - 节点和链接可以使用不同的颜色来区分不同的类别或强调特定的流量。

6. 交互性：
   - Plotly的Sankey图支持交互式操作，如悬停显示详细信息、点击等。

7. 自定义选项：
   - 提供了多种自定义选项，如节点的排序、链接的颜色、标签的位置等。

##### Sankey基本使用示例

In [1]:
import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = ["A", "B", "C", "D", "E", "F"],
      color = "blue"
    ),
    link = dict(
      source = [0, 1, 0, 2, 3, 3], # 使用索引
      target = [2, 3, 3, 4, 4, 5],
      value = [8, 4, 2, 8, 4, 2]
  ))])

fig.show()

##### 统一 source 和 target
- source 和 target 在y轴的布局可以根据value自动计算
- link 的颜色根据target的节点颜色来定义
- 左右两边的节点y值相同

In [37]:
import plotly.graph_objects as go
import numpy as np
from matplotlib.colors import to_rgba
# 定义节点
# label = ["A1", "A2", "A3", "B1", "B2", "B3", "C1", "C2"]
label = ["A1", "A2", "A3", "B1", "B2", "B3", "C1", "C2",
         "A1", "A2", "A3", "B1", "B2", "B3", "C1", "C2"]

# 定义连接
source = [2,  3,  4,  5,  6,  7,  0,  1,  2,  3, 4, 5]
target = [10, 11, 12, 13, 12, 14, 11, 12, 13, 8, 9, 10]
value =  [1,  1,  1,  1,  1,  1,  1,  1,  1,  1, 1, 1]

# 定义节点颜色
color = ["red", "red", "red", "green", "green", "green", "blue", "blue",
         "red", "red", "red", "green", "green", "green", "blue", "blue"]

# 计算每个节点的值
node_values = np.zeros(len(label))
for s, v in zip(source, value):
    node_values[s] += v
for t, v in zip(target, value):
    node_values[t] += v

# 计算累积和
cumsum_left = np.cumsum([0] + list(node_values[:8]))
cumsum_right = np.cumsum([0] + list(node_values[8:]))

# 计算 y 坐标
y = []
for i in range(16):
    if i < 8:
        # 使左右两边的节点y值相同
        tmp = max(cumsum_left[i] + cumsum_left[i+1], cumsum_right[i] + cumsum_right[i+1])
        y.append(tmp / (2 * cumsum_left[-1]))
        # y.append((cumsum_left[i] + cumsum_left[i+1]) / (2 * cumsum_left[-1]))
    else:
        tmp = max(cumsum_left[i-8] + cumsum_left[i-7], cumsum_right[i-8] + cumsum_right[i-7])
        y.append(tmp / (2 * cumsum_left[-1]))
        # y.append((cumsum_right[i-8] + cumsum_right[i-7]) / (2 * cumsum_right[-1]))
    # print(y[-1])

# 定义 x 坐标
x = [0.01] * 8 + [0.99] * 8

# 根据 target 定义节点颜色
def add_alpha_2_color(color, alpha):
    rgba = to_rgba(color)
    return f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {alpha})"
link_color = [add_alpha_2_color(color[t], 0.1) for t in target]

# 创建 Sankey 图
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = label,
      color = color,
      x = x,
      y = y
    ),
    link = dict(
      source = source,
      target = target,
      value = value,
      color = link_color
  ))])

# 更新布局
fig.update_layout(
    title_text="Buyer-Seller Sankey Diagram",
    font_size=10,
    width=600,
    height=300
)

# 添加左右标签
fig.add_annotation(x=-0.05, y=0.5, text="buyer", showarrow=False, textangle=-90, xref="paper", yref="paper")
fig.add_annotation(x=1.05, y=0.5, text="seller", showarrow=False, textangle=90, xref="paper", yref="paper")

# 显示图表
fig.show()

##### 多级Sankey

In [37]:
'''根据日期生成的数据，需要调整'''

import plotly.graph_objects as go
from matplotlib.colors import to_rgba
import pandas as pd
import numpy as np

# 假设我们有以下数据
dates = ['2024.05.01', '2024.04.28', '2024.04.20']
institutions = ['A1', 'A2', 'A3', 'B1', 'B2', 'B3', 'C1', 'C2']

# 创建随机数据
np.random.seed(0) # 设置随机种子，确保每次运行结果一致
data = []
for date in dates:
    for inst in institutions:
        if date == '2024.05.01':
            data.append({'date': date, 'institution': inst, 'buyer_volume': np.random.randint(50, 100), 'seller_volume': 0})
        elif date == '2024.04.20':
            data.append({'date': date, 'institution': inst, 'buyer_volume': 0, 'seller_volume': np.random.randint(50, 100)})
        else:
            data.append({'date': date, 'institution': inst, 'buyer_volume': np.random.randint(0, 50), 'seller_volume': np.random.randint(0, 50)})

df = pd.DataFrame(data)
print(df.head())

# 创建 Sankey 图数据
node_labels = []
node_colors = []
link_source = []
link_target = []
link_value = []
link_color = []

color_map = {'A': 'red', 'B': 'green', 'C': 'blue'}
def add_alpha_2_color(color, alpha):
    print(color)
    rgba = to_rgba(color)
    return f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {alpha})"

for i, date in enumerate(dates):
    offset = i * len(institutions)
    
    # 添加节点
    node_labels.extend([f"{inst}" for inst in institutions])
    node_colors.extend([color_map[inst[0]] for inst in institutions])
    
    date_data = df[df['date'] == date]
    
    # 添加买方链接
    if i < len(dates) - 1:  # 不是最后一个日期
        for j, inst in enumerate(institutions):
            volume = date_data[date_data['institution'] == inst]['buyer_volume'].values[0]
            if volume > 0:
                link_source.append(offset + j)
                link_target.append(offset + len(institutions) + j)
                link_value.append(volume)
                link_color.append(add_alpha_2_color(color_map[inst[0]], 0.1))
    
    # 添加卖方链接
    if i > 0:  # 不是第一个日期
        for j, inst in enumerate(institutions):
            volume = date_data[date_data['institution'] == inst]['seller_volume'].values[0]
            if volume > 0:
                link_source.append(offset + j - len(institutions))
                link_target.append(offset + j)
                link_value.append(volume)
                link_color.append(color_map[inst[0]])

# 创建 Sankey 图
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = node_labels,
      color = node_colors
    ),
    link = dict(
      source = link_source,
      target = link_target,
      value = link_value,
      color = link_color
  ))])

# 更新布局
fig.update_layout(
    title_text="Multi-stage Sankey Diagram",
    font_size=10,
    autosize=False,
    width=600,
    height=400
)

# 添加日期标签
for i, date in enumerate(dates):
    fig.add_annotation(
        x=i/2, y=1.05,
        xref="paper", yref="paper",
        text=date,
        showarrow=False,
        font=dict(size=14)
    )

# 添加买方和卖方标签
fig.add_annotation(x=-0.05, y=0.5, text="buyer", showarrow=False, textangle=-90, xref="paper", yref="paper")
fig.add_annotation(x=1.05, y=0.5, text="seller", showarrow=False, textangle=90, xref="paper", yref="paper")

fig.show()

         date institution  buyer_volume  seller_volume
0  2024.05.01          A1            94              0
1  2024.05.01          A2            97              0
2  2024.05.01          A3            50              0
3  2024.05.01          B1            53              0
4  2024.05.01          B2            53              0
red
red
red
green
green
green
blue
blue
red
red
red
green
green
green
blue
blue


##### 根据交易自动生成

In [49]:
'''多加了一个最近日期，保证link的流动能符合卖方x为交易时间，买方x为下次卖出时间/最近日期'''

import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from collections import defaultdict
# 假设我们有以下交易数据
transactions = [
    {'date': '2024.05.01', 'buyer': 'A1', 'seller': 'B2', 'volume': 100},
    {'date': '2024.05.01', 'buyer': 'C1', 'seller': 'A3', 'volume': 150},
    {'date': '2024.04.28', 'buyer': 'B1', 'seller': 'C2', 'volume': 80},
    {'date': '2024.04.28', 'buyer': 'A2', 'seller': 'B3', 'volume': 120},
    {'date': '2024.04.20', 'buyer': 'C2', 'seller': 'A1', 'volume': 90},
    {'date': '2024.04.20', 'buyer': 'B3', 'seller': 'C1', 'volume': 110},
]

df = pd.DataFrame(transactions)
dates = sorted(df['date'].unique(), reverse=True)

# 添加一个新的日期，为最大日期的后一天
max_date = datetime.strptime(max(dates), '%Y.%m.%d')
new_max_date = (max_date + timedelta(days=1)).strftime('%Y.%m.%d')
dates = [new_max_date] + dates

# 提取所有唯一的机构
all_institutions = sorted(set(df['buyer'].unique()) | set(df['seller'].unique()))

# 为每个机构分配颜色
color_map = {inst: f'rgb({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)})' 
             for inst in all_institutions}

# 创建节点和链接
nodes = []
links = []
node_indices = {} # node_index 和 node_label 的映射

# 找出每个机构的下一次卖出日期
next_sell_date = defaultdict(lambda: new_max_date)
for date in reversed(dates[1:]):  # 跳过新添加的日期
    date_df = df[df['date'] == date]
    for _, row in date_df.iterrows():
        next_sell_date[row['seller']] = date

# 为每个交易创建节点
for i, date in enumerate(dates[1:], 1):  # 跳过新添加的日期
    date_df = df[df['date'] == date]
    for _, row in date_df.iterrows():
        # 卖方节点
        if (date, row['seller']) not in node_indices:
            node_indices[(date, row['seller'])] = len(nodes)
            new_node = dict(
                label=f"{row['seller']} ({date})",
                color=color_map[row['seller']],
                x=i / len(dates),
                institution=row['seller']
            )
            nodes.append(new_node)
        
        # 买方节点
        buyer_date = next_sell_date[row['buyer']]
        if (buyer_date, row['buyer']) not in node_indices:
            node_indices[(buyer_date, row['buyer'])] = len(nodes)
            new_node = dict(
                label=f"{row['buyer']} ({buyer_date})",
                color=color_map[row['buyer']],
                x=dates.index(buyer_date) / len(dates),
                institution=row['buyer']
            )
            nodes.append(new_node)

        # 添加链接
        source_index = node_indices[(buyer_date, row['buyer'])]
        target_index = node_indices[(date, row['seller'])]
        links.append(dict(
            source=source_index,
            target=target_index,
            value=row['volume'],
            color=color_map[row['buyer']]
        ))

# 全局计算y坐标
institution_positions = {inst: i for i, inst in enumerate(all_institutions)}
for node in nodes:
    node['y'] = institution_positions[node['institution']] / (len(all_institutions) - 1)

# 创建 Sankey 图
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 10,
      thickness = 5,
      line = dict(color = "black", width = 0.5),
      label = [node['label'] for node in nodes],
      color = [node['color'] for node in nodes],
      x = [node['x'] for node in nodes],
      y = [node['y'] for node in nodes]
    ),
    link = dict(
      source = [link['source'] for link in links],
      target = [link['target'] for link in links],
      value = [link['value'] for link in links],
      color = [link['color'] for link in links]
  ))])
# print(fig)

# 更新布局
fig.update_layout(
    title_text="Multi-stage Sankey Diagram",
    font_size=10,
    autosize=False,
    width=700,
    height=400,
    margin=dict(t=50, l=50, r=50, b=50)
)

# 添加日期标签
for i, date in enumerate(dates):
    fig.add_annotation(
        x=i/(len(dates)), y=1.05,
        xref="paper", yref="paper",
        text=date,
        showarrow=False,
        font=dict(size=14)
    )

fig.show()

In [44]:
'''为每天都增加所有机构的node。但是link会出错'''

import plotly.graph_objects as go
import pandas as pd
import numpy as np
from collections import defaultdict
from datetime import datetime, timedelta

# 假设我们有以下交易数据
transactions = [
    {'date': '2024.05.01', 'buyer': 'A1', 'seller': 'B2', 'volume': 100},
    {'date': '2024.05.01', 'buyer': 'C1', 'seller': 'A3', 'volume': 150},
    {'date': '2024.04.28', 'buyer': 'B1', 'seller': 'C2', 'volume': 80},
    {'date': '2024.04.28', 'buyer': 'A2', 'seller': 'B3', 'volume': 120},
    {'date': '2024.04.20', 'buyer': 'C2', 'seller': 'A1', 'volume': 90},
    {'date': '2024.04.20', 'buyer': 'B3', 'seller': 'C1', 'volume': 110},
]

df = pd.DataFrame(transactions)
dates = sorted(df['date'].unique(), reverse=True)

# 添加一个新的日期，为最大日期的后一天
max_date = datetime.strptime(max(dates), '%Y.%m.%d')
new_max_date = (max_date + timedelta(days=1)).strftime('%Y.%m.%d')
dates = [new_max_date] + dates

# 提取所有唯一的机构
all_institutions = sorted(set(df['buyer'].unique()) | set(df['seller'].unique()))

# 为每个机构分配颜色
color_map = {inst: f'rgb({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)})' 
             for inst in all_institutions}

# 创建节点和链接
nodes = []
links = []
node_indices = {}

# 找出每个机构的下一次卖出日期
next_sell_date = defaultdict(lambda: new_max_date)
for date in reversed(dates[1:]):  # 跳过新添加的日期
    date_df = df[df['date'] == date]
    for _, row in date_df.iterrows():
        next_sell_date[row['seller']] = date

# 为每个日期的每个机构创建节点
for i, date in enumerate(dates):
    for institution in all_institutions:
        node_indices[(date, institution)] = len(nodes)
        nodes.append(dict(
            label=f"{institution} ({date})",
            color=color_map[institution],
            x=i / len(dates),
            institution=institution
        ))

# 添加链接
for _, row in df.iterrows():
    source_index = node_indices[(row['date'], row['buyer'])]
    target_index = node_indices[(row['date'], row['seller'])]
    links.append(dict(
        source=source_index,
        target=target_index,
        value=row['volume'],
        color=color_map[row['buyer']]
    ))

# 全局计算y坐标
for i, institution in enumerate(all_institutions):
    for node in nodes:
        if node['institution'] == institution:
            node['y'] = i / (len(all_institutions) - 1)

# 创建 Sankey 图
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = [node['label'] for node in nodes],
      color = [node['color'] for node in nodes],
      x = [node['x'] for node in nodes],
      y = [node['y'] for node in nodes]
    ),
    link = dict(
      source = [link['source'] for link in links],
      target = [link['target'] for link in links],
      value = [link['value'] for link in links],  # 进一步减小链接宽度
      color = [link['color'] for link in links]
  ))])

# 更新布局
fig.update_layout(
    title_text="Multi-stage Sankey Diagram",
    font_size=10,
    autosize=False,
    width=700,
    height=400,
    margin=dict(t=50, l=50, r=50, b=50)
)

# 添加日期标签
for i, date in enumerate(dates):
    fig.add_annotation(
        x=i/(len(dates)-1), y=1.05,
        xref="paper", yref="paper",
        text=date,
        showarrow=False,
        font=dict(size=14)
    )

fig.show()

In [70]:
'''为每天增加所有机构的node。并构建一个全链接link, 如果本身有link，则不构建，否则增加一个很小的value'''
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from matplotlib.colors import to_rgba
import re
# from collections import defaultdict

# 假设我们有以下交易数据
transactions = [
    {'date': '2024.05.01', 'buyer': 'A1', 'seller': 'B2', 'volume': 100},
    {'date': '2024.05.01', 'buyer': 'C',  'seller': 'A1', 'volume': 150},
    {'date': '2024.04.28', 'buyer': 'B1', 'seller': 'C',  'volume': 80},
    {'date': '2024.04.28', 'buyer': 'A2', 'seller': 'B2', 'volume': 120},
    {'date': '2024.04.20', 'buyer': 'C',  'seller': 'A1', 'volume': 90},
    {'date': '2024.04.20', 'buyer': 'B2', 'seller': 'C',  'volume': 110},
]
df = pd.DataFrame(transactions)
# 根据日期、买方和卖方对交易数据进行排序
df_sorted = df.sort_values(by=['date', 'seller', 'buyer'])
df = df_sorted.reset_index(drop=True)
print(df)

dates = sorted(df['date'].unique(), reverse=False)

# 添加一个新的日期，为最大日期的后一天
max_date = datetime.strptime(max(dates), '%Y.%m.%d')
new_max_date = (max_date + timedelta(days=1)).strftime('%Y.%m.%d')
dates = dates + [new_max_date]
print(f"dates: {dates}")

# 颜色增加透明度
def add_alpha_2_color(color, alpha):
    if color.startswith('rgb'):
        # 使用正则表达式提取 RGB 值
        r, g, b = map(int, re.findall(r'\d+', color))
        return f"rgba({r}, {g}, {b}, {alpha})"
    else:
        # 对于其他格式的颜色，使用 to_rgba
        rgba = to_rgba(color)
        return f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, {alpha})"

# 提取所有唯一的机构, 为每个机构分配颜色
all_institutions = sorted(set(df['buyer'].unique()) | set(df['seller'].unique()))
color_map = {inst: f'rgb({np.random.randint(0,256)},{np.random.randint(0,256)},{np.random.randint(0,256)})' 
             for inst in all_institutions}

# 创建节点和链接
nodes = []
links = []
node_indices = {} # ('2024.05.02', 'A1'): index

#为每个日期的创建所有机构节点
for i, date in enumerate(dates):
    for institution in all_institutions:
        node_indices[(date, institution)] = len(nodes)
        nodes.append(dict(
            # label=f"{institution} ({date})",
            label = (date, institution),
            color=color_map[institution],
            x=i / (len(dates) - 1),
            institution=institution
        ))

# 找出每个机构的下一次卖出日期
next_sell_date = defaultdict(lambda: new_max_date) # 可以为不存在的键提供默认值
for i, date in enumerate(dates[:-1]):  # 跳过新添加的日期
    date_df = df[df['date'] == date]
    # print(f"{date_df}")
    for _, row in date_df.iterrows():
        current_buyer = row['buyer']
        # 查找当前买家在之后日期的卖出交易
        current_buyer_future_trades = df[(df['date'] > date) & (df['seller'] == current_buyer)]
        if not current_buyer_future_trades.empty:
            # 如果存在未来卖出交易，获取最早的一次交易日期
            next_sell_date[(date, current_buyer)] = current_buyer_future_trades['date'].min()
        else:
            # 如果不存在未来卖出交易，设置为新的最大日期
            next_sell_date[(date, current_buyer)] = new_max_date
        # print(date, current_buyer, next_sell_date[(date, current_buyer)])

# 构建全链接图, 如果本身有link，则不构建，否则增加一个很小的value
for i, date in enumerate(dates):
    if date != new_max_date:
        # 遍历所有日期和机构, 从最早的日期开始
        for sell_inst in all_institutions:
            source_index = node_indices[(date, sell_inst)] # 卖方节点 index
            current_date_trades = df[(df['date'] == date) & (df['seller'] == sell_inst)]
            # 如果该机构当前日期存在卖出交易，则构建link
            if not current_date_trades.empty:
                for _, row in current_date_trades.iterrows():
                    buy_inst = row['buyer']
                    target_index = node_indices[(next_sell_date[(date, buy_inst)], buy_inst)]
                    links.append(dict(
                        source=source_index,
                        target=target_index,
                        value=row['volume'],
                        color=add_alpha_2_color(color_map[sell_inst], 0.4)
                    ))
            # 对于在当前日期不存在卖出交易的机构，构建一个到自身的link
            else:
                links.append(dict(
                    source=source_index,
                    target=source_index+len(all_institutions),
                    value=0.01,
                    # color='rgba(0,0,0,0)'  # 设置为完全透明的颜色
                    color=add_alpha_2_color(color_map[sell_inst], 0.4)
                ))
    else:
        break

# 计算每个机构的全局y坐标，考虑到link value的大小
node_buy_values = defaultdict(float)
node_sell_values = defaultdict(float)
for link in links:
    node_sell_values[link['source']] += link['value']
    node_buy_values[link['target']] += link['value']
node_values = defaultdict(float)
for node in node_buy_values.keys() | node_sell_values.keys():
    node_values[node] = max(node_buy_values[node], node_sell_values[node])

inst_values = defaultdict(float)
for inst in all_institutions:
    for (date, institution), idx in node_indices.items():
        if institution == inst:
            inst_values[inst] = max(node_values[idx], inst_values[inst])

# print(node_values)
# print(inst_values)

# 计算全局的总值
total_value = sum(inst_values.values())

# 计算每个节点的y坐标
cumulative_value = 0
for i in range(len(all_institutions)):
    inst_value = inst_values[all_institutions[i]]
    for d in range(len(dates)):
        node = nodes[d*len(all_institutions)+i]
        node['y'] = (cumulative_value + inst_value/2) / total_value
    cumulative_value += inst_value

print(nodes)



# 创建 Sankey 图
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 10,
      thickness = 10,
      line = dict(color = "black", width = 0.5),
      label = [node['label'] for node in nodes],
      # label = [node['institution'] for node in nodes],
      color = [node['color'] for node in nodes],
      x = [node['x'] for node in nodes]
      #   y = [node['y'] for node in nodes]
    ),
    link = dict(
      source = [link['source'] for link in links],
      target = [link['target'] for link in links],
      value = [link['value'] for link in links],
      color = [link['color'] for link in links]
  ))])
# print(fig)

# 更新布局
fig.update_layout(
    title_text="Multi-stage Sankey Diagram",
    font_size=10,
    autosize=False,
    width=700,
    height=400,
    margin=dict(t=50, l=50, r=50, b=50)
)

# 添加日期标签
for i, date in enumerate(dates):
    fig.add_annotation(
        x=1-i/(len(dates)-1), y=1.05,
        xref="paper", yref="paper",
        text=date,
        showarrow=False,
        font=dict(size=14)
    )

fig.show()

         date buyer seller  volume
0  2024.04.20     C     A1      90
1  2024.04.20    B2      C     110
2  2024.04.28    A2     B2     120
3  2024.04.28    B1      C      80
4  2024.05.01     C     A1     150
5  2024.05.01    A1     B2     100
dates: ['2024.04.20', '2024.04.28', '2024.05.01', '2024.05.02']
[{'label': ('2024.04.20', 'A1'), 'color': 'rgb(177,106,201)', 'x': 0.0, 'institution': 'A1', 'y': 0.1209618889408577}, {'label': ('2024.04.20', 'A2'), 'color': 'rgb(44,13,121)', 'x': 0.0, 'institution': 'A2', 'y': 0.33870135316033095}, {'label': ('2024.04.20', 'B1'), 'color': 'rgb(70,38,167)', 'x': 0.0, 'institution': 'B1', 'y': 0.5}, {'label': ('2024.04.20', 'B2'), 'color': 'rgb(136,13,248)', 'x': 0.0, 'institution': 'B2', 'y': 0.6612905827137396}, {'label': ('2024.04.20', 'C'), 'color': 'rgb(135,208,248)', 'x': 0.0, 'institution': 'C', 'y': 0.879030046933213}, {'label': ('2024.04.28', 'A1'), 'color': 'rgb(177,106,201)', 'x': 0.3333333333333333, 'institution': 'A1', 'y': 0.12096188

#### 数据
生成一些交易假数据，用于绘制demo图

In [53]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

# 生成机构和交易员名称
institutions = ['机构A', '机构B', '机构C', '机构D', '机构E']
traders = ['交易员1', '交易员2', '交易员3', '交易员4', '交易员5']

# 生成100条交易数据
n_transactions = 100
data = []

for _ in range(n_transactions):
    # 生成交易时间
    transaction_time = datetime.now() - timedelta(days=np.random.randint(0, 30))
    
    # 生成交易量
    volume = np.random.randint(1000, 100000)
    
    # 生成买、卖方信息
    buyer_institution = np.random.choice(institutions)
    buyer_trader = np.random.choice(traders)
    seller_institution = np.random.choice(institutions)
    seller_trader = np.random.choice(traders)
    
    # 确保买方和卖方不是同一机构
    while seller_institution == buyer_institution:
        seller_institution = np.random.choice(institutions)
    
    # 添加到数据列表
    data.append({
        '交易时间': transaction_time,
        '交易量': volume,
        '买方机构': buyer_institution,
        '买方交易员': buyer_trader,
        '卖方机构': seller_institution,
        '卖方交易员': seller_trader
    })

# 创建DataFrame
df = pd.DataFrame(data)
# print(df.head())

##### 过滤当天对倒交易，仅保留交易净值

In [54]:
# def calculate_net_volume(group):
#     print("group\n", group)
#     pairs = group['机构交易员对'].iloc[0]
#     inst_a, trader_a = pairs[0]
#     inst_b, trader_b = pairs[1]
#     volume_a_to_b = group[(group['买方机构'] == inst_b) & (group['买方交易员'] == trader_b) & 
#                           (group['卖方机构'] == inst_a) & (group['卖方交易员'] == trader_a)]['交易量'].sum()
#     volume_b_to_a = group[(group['买方机构'] == inst_a) & (group['买方交易员'] == trader_a) & 
#                           (group['卖方机构'] == inst_b) & (group['卖方交易员'] == trader_b)]['交易量'].sum()
#     return volume_a_to_b - volume_b_to_a

def process_daily_net_transactions(df):    
    df['交易日期'] = pd.to_datetime(df['交易时间']).dt.date
    df['机构交易员对'] = df.apply(lambda row: tuple(sorted([(row['买方机构'], row['买方交易员']), 
                                                      (row['卖方机构'], row['卖方交易员'])])), axis=1)
    
    direct_df = df.copy()
    direct_df['direction'] = 1
    direct_df.loc[ (direct_df['买方机构']==direct_df['机构交易员对'].map(lambda x: x[1][0])) & 
        (direct_df['买方交易员']==direct_df['机构交易员对'].map(lambda x: x[1][1])), 'direction'] = -1
    
    direct_df['净交易量'] = direct_df['交易量'] * direct_df['direction']
    net_df = direct_df.groupby(['交易日期', '机构交易员对'])['净交易量'].sum().reset_index()
    print(net_df.head())


    # # 新版本的 pandas 不支持 groupby 和 apply 一起用了
    # net_transactions = df.groupby(['交易日期', '机构交易员对']).apply(calculate_net_volume).reset_index()
    # print(net_transactions.head())
    
    result_df = []
    for _, row in net_df.iterrows():
        if row['净交易量'] != 0:
            (inst_a, trader_a), (inst_b, trader_b) = row['机构交易员对']
            if row['净交易量'] > 0:
                seller, seller_trader, buyer, buyer_trader = inst_a, trader_a, inst_b, trader_b
            else:
                seller, seller_trader, buyer, buyer_trader = inst_b, trader_b, inst_a, trader_a
            result_df.append({
                '买方机构': buyer,
                '买方交易员': buyer_trader,
                '卖方机构': seller,
                '卖方交易员': seller_trader,
                '交易日期': row['交易日期'],
                '净交易量': abs(row['净交易量'])
            })
    
    return pd.DataFrame(result_df)

net_df = process_daily_net_transactions(df)
print(net_df.head())

         交易日期                      机构交易员对   净交易量
0  2024-08-13  ((机构A, 交易员1), (机构D, 交易员3)) -46283
1  2024-08-13  ((机构A, 交易员1), (机构E, 交易员3))   3761
2  2024-08-13  ((机构A, 交易员2), (机构E, 交易员4))  48327
3  2024-08-13  ((机构A, 交易员4), (机构D, 交易员2)) -28858
4  2024-08-13  ((机构B, 交易员5), (机构E, 交易员1))  49105
  买方机构 买方交易员 卖方机构 卖方交易员        交易日期   净交易量
0  机构A  交易员1  机构D  交易员3  2024-08-13  46283
1  机构E  交易员3  机构A  交易员1  2024-08-13   3761
2  机构E  交易员4  机构A  交易员2  2024-08-13  48327
3  机构A  交易员4  机构D  交易员2  2024-08-13  28858
4  机构E  交易员1  机构B  交易员5  2024-08-13  49105


##### 展示数据

In [38]:
# 按机构和交易员分组，并显示买入交易
grouped_df = net_df.groupby(['买方机构', '买方交易员'])

print("\n按机构和交易员分组的买入交易：")
for (institution, trader), group in grouped_df:
    print(f"\n{institution} - {trader} 的买入交易：")
    sorted_group = group.sort_values('交易时间')
    print(sorted_group[['交易时间', '交易量', '卖方机构', '卖方交易员']])

In [39]:
# 按机构和交易员分组，并显示卖出交易
grouped_df = net_df.groupby(['卖方机构', '卖方交易员'])

print("\n按机构和交易员分组的卖出交易：")
for (institution, trader), group in grouped_df:
    print(f"\n{institution} - {trader} 的卖出交易：")
    sorted_group = group.sort_values('交易时间')
    print(sorted_group[['交易时间', '交易量', '买方机构', '买方交易员']])

#### 绘制 Demo 设计
##### 设置常量

In [59]:
from datetime import datetime, timedelta
import plotly.express as px

# 获取唯一的机构
institutions = list(set(net_df['买方机构'].unique().tolist() + net_df['卖方机构'].unique().tolist()))

# 获取每个机构的唯一交易员
institution_traders = {}
for instn in institutions:
    buy_traders = net_df[net_df['买方机构'] == instn]['买方交易员'].unique().tolist()
    sell_traders = net_df[net_df['卖方机构'] == instn]['卖方交易员'].unique().tolist()
    institution_traders[instn] = list(set(buy_traders + sell_traders))

# 设置颜色映射
color_map = px.colors.qualitative.Plotly
institution_colors = {inst: color_map[i % len(color_map)] for i, inst in enumerate(institutions)}

# 设置时间范围
t_earliest = net_df['交易日期'].min()
t_latest = net_df['交易日期'].max()
t_duration = (t_latest - t_earliest).days
# print(t_earliest, t_latest, t_duration)

# 定义时间到长度的映射函数
def time_to_length(t):
    return CONSTRANTS["x_start"] + (t_latest - t).days / t_duration * CONSTRANTS["x_range"]

# 绘制图表的常量
CONSTRANTS = {
    # 计算y轴位置
    "y_loc" : {},
    "y0" : 0,
    "instn_gap" : 2,
    "trader_gap" : 0.5,
    "trader_h" : 1,
    "trader_w" : 1,
    # 添加x轴的时间和长度的映射
    "x_start" : 0,
    "x_end" : 200,
    "x_range" : 0,
    "axis_slot" : 3
}
CONSTRANTS["x_range"] = CONSTRANTS["x_end"] - CONSTRANTS["x_start"]

##### 计算纵坐标位置

In [60]:
for instn, trader_list in institution_traders.items():
    # y_loc[instn] = y0
    for trader in trader_list:
        CONSTRANTS["y_loc"][f"{instn}-{trader}"] = CONSTRANTS["y0"]
        CONSTRANTS["y0"] += CONSTRANTS["trader_h"] + CONSTRANTS["trader_gap"]
    CONSTRANTS["y0"] += CONSTRANTS["instn_gap"]
# print(y_loc)

"""需要获取每个交易员持有的最大数额"""
# 获取每个机构的每个交易员在当前所有交易中的最大流量
# trader_max_volume = {}

# for instn in institutions:
#     for trader in institution_traders[instn]:
#         # 获取买方交易
#         buy_volume = net_df[(net_df['买方机构'] == instn) & (df['买方交易员'] == trader)]['交易量'].max()
#         # 获取卖方交易
#         sell_volume = df[(df['卖方机构'] == instn) & (df['卖方交易员'] == trader)]['交易量'].max()
#         # 取买卖方交易中的最大值
#         max_volume = max(buy_volume, sell_volume) if not pd.isna(buy_volume) and not pd.isna(sell_volume) else (buy_volume if not pd.isna(buy_volume) else sell_volume)
        
#         trader_max_volume[f"{instn}-{trader}"] = max_volume if not pd.isna(max_volume) else 0






'需要获取每个交易员持有的最大数额'

##### 设置坐标轴

In [62]:
import plotly.graph_objects as go

# 创建图形
fig = go.Figure()

# 设置布局
fig.update_layout(
    title="机构和交易员时间分布图",
    xaxis_title="交易时间",
    yaxis_title="机构和交易员",
    # xaxis=dict(autorange="reversed"),
    xaxis=dict(
        # autorange="reversed",
        range=[CONSTRANTS["x_start"], CONSTRANTS["x_end"]],
        tickmode='array', # 设置刻度模式为"数组"模式，手动指定刻度位置和标签
        tickvals=[
            CONSTRANTS["x_start"] + i * CONSTRANTS["x_range"] / ((t_duration // CONSTRANTS["axis_slot"] + 1) - 1)
            for i in range(t_duration // CONSTRANTS["axis_slot"] + 1)
        ], # 根据t_duration计算刻度位置
        ticktext=[
            (t_latest - timedelta(days=i * t_duration / ((t_duration // CONSTRANTS["axis_slot"] + 1) - 1))).strftime('%Y-%m-%d')
            for i in range(t_duration // CONSTRANTS["axis_slot"] + 1)
        ] # 根据计算的刻度位置生成对应的日期标签
    ),
    yaxis=dict(
        # autorange="reversed",
        range=[max(CONSTRANTS["y_loc"].values()) + CONSTRANTS["trader_h"], min(CONSTRANTS["y_loc"].values()) - CONSTRANTS["trader_h"]]
    ),
    showlegend=False,
    height=600,
    width=800,
)

# 为每个机构和交易员添加长方形
for instn, trader_list in institution_traders.items():
    clr = institution_colors[instn]
    print(instn, clr)
    # 添加交易员长方形
    for trader in trader_list:
        y_loc = CONSTRANTS["y_loc"][f"{instn}-{trader}"]
        fig.add_shape(
            type="rect",
            x0=CONSTRANTS["x_start"], x1=CONSTRANTS["trader_w"],
            y0 = y_loc, y1 = y_loc + CONSTRANTS["trader_h"],
            fillcolor=clr,
            opacity=0.7,
            line=dict(width=0),
        )

# 显示图表
fig.show()

机构D #636EFA
机构A #EF553B
机构E #00CC96
机构B #AB63FA
机构C #FFA15A


##### 简单展示边

In [58]:
# 创建交易边
for _, trade in net_df.iterrows():
    sell_institution = trade['卖方机构']
    sell_trader = trade['卖方交易员']
    buy_institution = trade['买方机构']
    buy_trader = trade['买方交易员']
    trade_time = trade['交易日期']
    
    # 获取卖方和买方的y坐标
    sell_y = CONSTRANTS["y_loc"][f"{sell_institution}-{sell_trader}"] + CONSTRANTS["trader_h"] / 2
    buy_y = CONSTRANTS["y_loc"][f"{buy_institution}-{buy_trader}"] + CONSTRANTS["trader_h"] / 2
    
    # 获取交易时间对应的x坐标
    sell_x = time_to_length(trade_time)
    prevSell4buy = net_df[(net_df['卖方机构'] == buy_institution) & (net_df['卖方交易员'] == buy_trader) & (net_df['交易日期'] < trade_time)]
    # buy_x  = 






    trade_x = time_to_length(trade_time)
    
    # 获取买方上一次交易的时间 #TODO:这里买卖有问题
    prev_trades = net_df[(net_df['买方机构'] == buy_institution) & (net_df['买方交易员'] == buy_trader) & (net_df['交易日期'] < trade_time)]
    print("\n", prev_trades)
    if not prev_trades.empty:
        try:
            prev_trade_time = prev_trades['交易日期'].max()
            prev_x = time_to_length(prev_trade_time)
        except:
            print(f"//**//**// error {prev_trades['交易日期']}")
            prev_x = CONSTRANTS["x_end"]
    else:
        prev_x = CONSTRANTS["x_end"]
    
    # 添加交易边
    fig.add_trace(go.Scatter(
        x=[trade_x, prev_x],
        y=[sell_y, buy_y],
        mode='lines',
        line=dict(color='rgba(0,0,0,0.1)', width=1),
        hoverinfo='none'
    ))

# 显示图表
fig.show()


KeyError: '交易时间'

##### 使用 Sankey 图表示边