In [48]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [49]:
# Location of the CSV extract to analyse (resolved against the notebook's
# working directory).
csv_file = "QITS-2021-1-EN-20220624T010716.csv"

# Parse the file with pandas' default settings. The preview below shows
# reporter/partner country columns, a FLOW code (EXP / IMP), a TIME column
# and the amount in `Value`.
df = pd.read_csv(filepath_or_buffer=csv_file)

# A bare trailing expression renders the notebook's truncated rich preview.
df

Unnamed: 0,LOCATION,Reporter country,FLOW,Flow,PARTNER,Partner country,FREQUENCY,Frequency,TIME,Time,Unit Code,Unit,Value
0,AUS,Australia,EXP,Exports,AUT,Austria,A,Annual,2011,2011,USD,US Dollar,7.314492e+07
1,AUS,Australia,EXP,Exports,AUT,Austria,A,Annual,2012,2012,USD,US Dollar,7.148306e+07
2,AUS,Australia,EXP,Exports,AUT,Austria,A,Annual,2013,2013,USD,US Dollar,6.789742e+07
3,AUS,Australia,EXP,Exports,AUT,Austria,A,Annual,2014,2014,USD,US Dollar,6.111475e+07
4,AUS,Australia,EXP,Exports,AUT,Austria,A,Annual,2015,2015,USD,US Dollar,5.473024e+07
...,...,...,...,...,...,...,...,...,...,...,...,...,...
42952,LVA,Latvia,IMP,Imports,SAU,Saudi Arabia,A,Annual,2019,2019,USD,US Dollar,1.693470e+05
42953,LVA,Latvia,IMP,Imports,SAU,Saudi Arabia,A,Annual,2020,2020,USD,US Dollar,7.626500e+04
42954,LVA,Latvia,IMP,Imports,G-20,G20,A,Annual,2019,2019,USD,US Dollar,1.810599e+10
42955,LVA,Latvia,IMP,Imports,G-20,G20,A,Annual,2020,2020,USD,US Dollar,1.600864e+10


In [50]:
# Split the raw table by trade direction and drop rows that involve an
# aggregate grouping instead of an individual country.

# Labels that denote groupings rather than single countries. They are screened
# on both the reporter and the partner side of every record.
AGGREGATE_GROUPS = ['World', 'OECD - Total', 'G20', 'G7']


def _drop_aggregate_rows(frame):
    """Return the rows of ``frame`` in which neither party is an aggregate.

    Parameters
    ----------
    frame : pandas.DataFrame
        Must contain the 'Reporter country' and 'Partner country' columns.

    Returns
    -------
    pandas.DataFrame
        Rows whose reporter and partner are both absent from
        ``AGGREGATE_GROUPS``. Original index labels are kept.
    """
    involves_aggregate = (
        frame['Partner country'].isin(AGGREGATE_GROUPS)
        | frame['Reporter country'].isin(AGGREGATE_GROUPS)
    )
    return frame[~involves_aggregate]


# Separate import and export data using the FLOW code column, then filter
# each subset with the same rule.
import_df = _drop_aggregate_rows(df[df['FLOW'] == 'IMP'])  # Import data
export_df = _drop_aggregate_rows(df[df['FLOW'] == 'EXP'])  # Export data

In [51]:
# Drop every record in which the reporter or the partner is an aggregate
# grouping rather than an individual country. The filter is applied to the
# combined table (both FLOW values), and `df` is rebound to the result.
AGGREGATE_GROUPS = ['World', 'OECD - Total', 'G20', 'G7']

# `~` binds tighter than `&`, so each membership test is negated first.
df = df[
    ~df['Partner country'].isin(AGGREGATE_GROUPS)
    & ~df['Reporter country'].isin(AGGREGATE_GROUPS)
]

In [52]:
# `df` is the table left after the aggregate groupings were filtered out.
# NOTE(review): the sum below runs over every remaining row, i.e. both FLOW
# values (IMP and EXP) and every year in the extract. Later cells label
# reporters as "Exporters" - confirm that mixing both flows is intended.

# Total `Value` per (reporter, partner) pair, flattened back into ordinary
# columns with a fresh integer index.
pair_keys = ['Reporter country', 'Partner country']
pair_totals = df.groupby(pair_keys)['Value'].sum()
trade_values = pair_totals.reset_index()

# Spot check: show the pairs whose reporter name contains "China".
china_rows = trade_values['Reporter country'].str.contains('China')
trade_values.loc[china_rows]



Unnamed: 0,Reporter country,Partner country,Value
274,China (People's Republic of),Argentina,53842400000.0
275,China (People's Republic of),Australia,448786000000.0
276,China (People's Republic of),Austria,28746100000.0
277,China (People's Republic of),Belgium,96144720000.0
278,China (People's Republic of),Brazil,302655000000.0
279,China (People's Republic of),Canada,181948500000.0
280,China (People's Republic of),Chile,124197500000.0
281,China (People's Republic of),China (People's Republic of),489746000000.0
282,China (People's Republic of),Colombia,34762710000.0
283,China (People's Republic of),Costa Rica,13001740000.0


In [53]:
# Keep the 20 reporter/partner pairs with the largest summed value.
top_trade_values = (
    trade_values
    .sort_values(by='Value', ascending=False)
    .head(20)
)

# One node per distinct country, numbered in order of first appearance
# (reporter column first, then partner column). A country that shows up on
# both sides shares a single node in this diagram.
all_countries = pd.concat(
    [top_trade_values['Reporter country'], top_trade_values['Partner country']]
).unique()
country_idx = dict(zip(all_countries, range(len(all_countries))))

# Attach the node index of each end of every corridor.
top_trade_values = top_trade_values.assign(
    source=top_trade_values['Reporter country'].map(country_idx),
    target=top_trade_values['Partner country'].map(country_idx),
)

# Assemble the Sankey trace from separately named node and link settings.
node_style = dict(
    pad=15,
    thickness=20,
    line=dict(color="black", width=0.5),
    label=all_countries,
)
flow_links = dict(
    source=top_trade_values['source'],
    target=top_trade_values['target'],
    value=top_trade_values['Value'],
)
fig = go.Figure(data=[go.Sankey(node=node_style, link=flow_links)])

fig.update_layout(title_text="Top Trading Corridors", font_size=10)
fig.show()

In [54]:
# Coloured version of the single-node-set Sankey. Relies on `all_countries`
# and on the `source` / `target` columns of `top_trade_values` as built by the
# cell that draws the first diagram.

# Subdued palette for the nodes.
node_colors = ['#B0C4DE', '#AFEEEE', '#E0FFFF', '#F0F8FF', '#F5F5F5']

# Light-blue, semi-transparent shades for the links.
link_colors = ['rgba(173, 216, 230, 0.5)', 'rgba(191, 239, 255, 0.5)']

# Both palettes are shorter than the number of nodes / links they have to
# cover, so cycle through them. Slicing the node palette instead would hand
# out at most len(node_colors) colours and leave the remaining nodes without
# a palette colour.
node_fill = [node_colors[i % len(node_colors)] for i in range(len(all_countries))]
link_fill = [link_colors[i % len(link_colors)] for i in range(len(top_trade_values))]

# Create the Sankey diagram
fig = go.Figure(data=[go.Sankey(
    node=dict(
      pad=15,
      thickness=20,
      line=dict(color="black", width=0.5),
      label=all_countries,
      color=node_fill  # One colour per node
    ),
    link=dict(
      source=top_trade_values['source'],
      target=top_trade_values['target'],
      value=top_trade_values['Value'],
      color=link_fill  # One colour per link
    )
)])

# Add headings above the left and right edges of the diagram.
# NOTE(review): in this layout a country uses the same node whether it is the
# reporter or the partner, so the two headings may not line up with strictly
# separate exporter / importer columns - confirm the intended reading.
fig.update_layout(
    title_text="Top Trading Corridors",
    annotations=[
        dict(text="Exporters", x=0.1, y=1.1, xref="paper", yref="paper", showarrow=False, font=dict(size=20)),
        dict(text="Importers", x=0.9, y=1.1, xref="paper", yref="paper", showarrow=False, font=dict(size=20))
    ],
    font_size=10
)

fig.show()


In [55]:
# Two-sided Sankey: every country gets one node on the reporter (left) side
# and a separate node on the partner (right) side, so a country that appears
# in both roles is drawn twice instead of sharing one node.

# Every distinct country in the top corridors, regardless of side.
unique_exporters = top_trade_values['Reporter country'].unique()
unique_importers = top_trade_values['Partner country'].unique()
# Sorted so that node order - and therefore colour assignment - is identical
# on every run; iterating a bare set of strings is not order-stable across
# interpreter sessions.
all_countries = sorted(set(unique_exporters).union(unique_importers))

# Give each country a unique index.
country_idx = {country: i for i, country in enumerate(all_countries)}

# Reporter-side nodes occupy indices 0..n-1, partner-side nodes n..2n-1.
# Both columns are recomputed from the country names, so re-running the cell
# does not stack the offset.
top_trade_values = top_trade_values.assign(
    source=top_trade_values['Reporter country'].map(country_idx),
    target=top_trade_values['Partner country'].map(country_idx) + len(all_countries),
)

# Colour palette; entries are reused cyclically when there are more
# countries than colours.
country_colors = [
    '#95a2ff', '#fa8080', '#ffc076', '#fae768', '#87e885', 
    '#3cb9fc', '#73abf5', '#cb9bff', '#434348', '#90ed7d', 
    '#f7a35c', '#8085e9', '#E0FFFF', '#F0F8FF', '#F5F5F5',
    '#FFEFD5', '#FFF0F5', '#FAEBD7', '#E6E6FA', '#FFF5EE',
    '#FFEFD5', '#FFF0F5', '#FAEBD7','#FFEFD5'
]

# One colour per country, then the same sequence again for the partner-side
# copies, so a country has the same colour on both sides and the list always
# has exactly one entry per node.
per_country_color = [
    country_colors[i % len(country_colors)] for i in range(len(all_countries))
]
sankey_node_colors = per_country_color + per_country_color

# Create the Sankey diagram
fig = go.Figure(data=[go.Sankey(
    node=dict(
      pad=15,
      thickness=20,
      line=dict(color='black', width=0),  # Width 0 removes the node outline
      label=all_countries + all_countries,  # Same labels for the reporter-side and partner-side copies
      color=sankey_node_colors  # One colour per node
    ),
    link=dict(
      source=top_trade_values['source'],
      target=top_trade_values['target'],
      value=top_trade_values['Value'],
      color='rgba(173, 216, 230, 0.5)'  # Light-blue, semi-transparent links
    )
)])

fig.update_layout(title_text="Top Trading Corridors", font_size=10)
fig.show()