# 坐标提取

In [6]:
import pandas as pd
import requests
from time import sleep
from tqdm import tqdm

# === 步骤 1：读取 IMF 数据，提取唯一国家 ===
csv_path = "IMF_IMTS_Exports_2016_2024.csv"  # ← 修改为你的文件路径
df = pd.read_csv(csv_path, low_memory=False)

# 提取报告国和伙伴国的 ISO3 列
iso_reporters = df["COUNTRY.ID"].dropna().unique()
iso_partners = df["COUNTERPART_COUNTRY.ID"].dropna().unique()

# 合并去重
iso3_list = sorted(set(iso_reporters).union(set(iso_partners)))

print(f"🌍 共识别出 {len(iso3_list)} 个唯一国家")

# === 步骤 2：查询经纬度 ===
records = []
for iso3 in tqdm(iso3_list, desc="查询国家坐标"):
    try:
        r = requests.get(f"https://restcountries.com/v3.1/alpha/{iso3}", timeout=10)
        r.raise_for_status()
        js = r.json()
        lat, lon = js[0]["latlng"]
        records.append((iso3, lat, lon))
    except Exception as e:
        print(f"[WARN] {iso3}: 查询失败 ({e})")
        records.append((iso3, None, None))
    sleep(0.2)  # 避免触发 API 限速

# === 步骤 3：保存为 CSV 映射表 ===
df_coords = pd.DataFrame(records, columns=["ISO3", "Lat", "Lon"])
df_coords.to_csv("iso3_to_latlon.csv", index=False)
print("✅ 已保存 → iso3_to_latlon.csv")


🌍 共识别出 244 个唯一国家


查询国家坐标:  19%|█▉        | 46/244 [01:05<04:35,  1.39s/it]

[WARN] CSK: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/CSK)


查询国家坐标:  21%|██        | 51/244 [01:12<04:44,  1.47s/it]

[WARN] DDR: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/DDR)


查询国家坐标:  29%|██▊       | 70/244 [01:39<04:06,  1.42s/it]

[WARN] G001: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G001)


查询国家坐标:  29%|██▉       | 71/244 [01:41<04:04,  1.41s/it]

[WARN] G080: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G080)


查询国家坐标:  30%|██▉       | 72/244 [01:42<03:59,  1.39s/it]

[WARN] G092: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G092)


查询国家坐标:  30%|██▉       | 73/244 [01:44<04:01,  1.41s/it]

[WARN] G110: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G110)


查询国家坐标:  30%|███       | 74/244 [01:45<04:02,  1.42s/it]

[WARN] G163: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G163)


查询国家坐标:  31%|███       | 75/244 [01:46<03:58,  1.41s/it]

[WARN] G200: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G200)


查询国家坐标:  31%|███       | 76/244 [01:48<04:05,  1.46s/it]

[WARN] G205: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G205)


查询国家坐标:  32%|███▏      | 77/244 [01:49<04:02,  1.45s/it]

[WARN] G400: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G400)


查询国家坐标:  32%|███▏      | 78/244 [01:51<03:57,  1.43s/it]

[WARN] G505: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G505)


查询国家坐标:  32%|███▏      | 79/244 [01:52<03:51,  1.40s/it]

[WARN] G603: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G603)


查询国家坐标:  33%|███▎      | 80/244 [01:54<03:53,  1.42s/it]

[WARN] G903: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G903)


查询国家坐标:  33%|███▎      | 81/244 [01:55<03:48,  1.40s/it]

[WARN] G998: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/G998)


查询国家坐标:  40%|███▉      | 97/244 [02:18<03:27,  1.41s/it]

[WARN] GX170: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/GX170)


查询国家坐标:  40%|████      | 98/244 [02:19<03:24,  1.40s/it]

[WARN] GX405: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/GX405)


查询国家坐标:  41%|████      | 99/244 [02:21<03:24,  1.41s/it]

[WARN] GX440: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/GX440)


查询国家坐标:  41%|████      | 100/244 [02:22<03:19,  1.38s/it]

[WARN] GX605: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/GX605)


查询国家坐标:  41%|████▏     | 101/244 [02:23<03:18,  1.39s/it]

[WARN] GX901: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/GX901)


查询国家坐标:  75%|███████▍  | 182/244 [04:19<01:26,  1.39s/it]

[WARN] SCG: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/SCG)


查询国家坐标:  80%|███████▉  | 194/244 [04:35<01:08,  1.37s/it]

[WARN] SUN: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/SUN)


查询国家坐标:  89%|████████▊ | 216/244 [05:06<00:39,  1.39s/it]

[WARN] TX126: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX126)


查询国家坐标:  89%|████████▉ | 217/244 [05:08<00:38,  1.42s/it]

[WARN] TX399: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX399)


查询国家坐标:  89%|████████▉ | 218/244 [05:09<00:36,  1.40s/it]

[WARN] TX489: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX489)
[WARN] TX598: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX598)


查询国家坐标:  90%|█████████ | 220/244 [05:12<00:34,  1.42s/it]

[WARN] TX799: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX799)
[WARN] TX884: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX884)


查询国家坐标:  91%|█████████ | 222/244 [05:15<00:31,  1.41s/it]

[WARN] TX898: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX898)


查询国家坐标:  91%|█████████▏| 223/244 [05:16<00:29,  1.38s/it]

[WARN] TX899: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX899)
[WARN] TX910: 查询失败 (400 Client Error: Bad Request for url: https://restcountries.com/v3.1/alpha/TX910)


查询国家坐标:  97%|█████████▋| 236/244 [05:34<00:11,  1.42s/it]

[WARN] WBG: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/WBG)


查询国家坐标:  98%|█████████▊| 238/244 [05:37<00:08,  1.39s/it]

[WARN] YAR: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/YAR)


查询国家坐标:  98%|█████████▊| 240/244 [05:40<00:05,  1.40s/it]

[WARN] YMD: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/YMD)


查询国家坐标:  99%|█████████▉| 241/244 [05:41<00:04,  1.38s/it]

[WARN] YUG: 查询失败 (404 Client Error: Not Found for url: https://restcountries.com/v3.1/alpha/YUG)


查询国家坐标: 100%|██████████| 244/244 [05:45<00:00,  1.42s/it]

✅ 已保存 → iso3_to_latlon.csv





# Global Export Network


In [None]:
"""
IMF IMTS CSV → 全球出口网络 (多年份可滑动查看)
================================================

依赖:
    pip install pandas networkx folium tqdm branca

用法:
    1. 将 IMF 下载的 Timeseries‑per‑row CSV 放到同目录
    2. 准备 iso3_to_latlon.csv (ISO3,Lat,Lon)
    3. 修改 FILE_CSV、YEAR_START、YEAR_END、TOP_K_FLOW 等参数
    4. python imf_trade_map_range.py

产物:
    · imf_trade_network_<START>_<END>.html —— 多年份可交互地图

功能概述:
    * 读取并清洗 IMF IMTS 出口数据
    * 为给定年份区间 (闭区间) 构建各年份的全球出口网络 (节点 + 边)
    * 每个年份生成一个 Folium LayerGroup
    * 注入自定义 Leaflet Slider —— 拖动即可在年份图层间切换
"""
from __future__ import annotations
import math
from pathlib import Path
from typing import Dict, Tuple, List

import pandas as pd
import networkx as nx
import folium
from folium.plugins import AntPath
from tqdm import tqdm
from matplotlib import cm, colors
from branca.element import MacroElement, Template

# ================ 参数区 ================
FILE_CSV    = "IMF_IMTS_Exports_1948_2024.csv"  # IMF CSV 文件
YEAR_START  = 1948                               # 起始年份 (含)
YEAR_END    = 2024                               # 结束年份 (含)
TOP_K_FLOW  = 100                                # 每年绘制前 K 条边
EDGE_STYLE  = "poly"                            # "poly" | "ant"
MIN_TRADE   = None                               # 最小贸易额阈值 (None 关闭)
COORD_CSV   = "iso3_to_latlon.csv"              # 国家坐标表
# =======================================


# ---------- ① 加载本地坐标 ----------

def load_coord_dict(filepath: str | Path) -> Dict[str, Tuple[float | None, float | None]]:
    """读取 ISO3 → (lat, lon) 映射"""
    df = pd.read_csv(filepath)
    mapping = {}
    for _, row in df.iterrows():
        lat, lon = row.Lat, row.Lon
        if pd.isna(lat) or pd.isna(lon):
            lat, lon = None, None
        mapping[row.ISO3] = (lat, lon)
    return mapping


coord_dict: Dict[str, Tuple[float | None, float | None]] = load_coord_dict(COORD_CSV)


# ---------- ② 读取并清洗贸易数据 ----------

def load_exports_from_csv(csv_path: str | Path, year: int) -> pd.DataFrame:
    """读取指定年份的出口数据并清洗 (去 NaN / 自环 / 负值)"""
    cols = ["COUNTRY.ID", "COUNTERPART_COUNTRY.ID", str(year)]
    df = (
        pd.read_csv(csv_path, low_memory=False)[cols]
        .rename(
            columns={
                "COUNTRY.ID": "reporter",
                "COUNTERPART_COUNTRY.ID": "partner",
                str(year): "value",
            }
        )
        .dropna(subset=["value"])
    )
    df = df[df["value"] > 0]
    df = df[df["reporter"] != df["partner"]]  # 去自环
    if MIN_TRADE is not None:
        df = df[df["value"] > MIN_TRADE]
    return df.reset_index(drop=True)


# ---------- ③ 坐标查询 ----------

def get_latlon(iso3: str) -> Tuple[float | None, float | None]:
    return coord_dict.get(iso3, (None, None))


# ---------- ④ 生成单年份图层 ----------

def build_year_layer(
    df: pd.DataFrame,
    year: int,
    top_k: int = 300,
    edge_style: str = "poly",
) -> folium.FeatureGroup:
    """将单一年份的节点 & 边绘制进 FeatureGroup"""

    # 只保留坐标完整的国家
    iso_valid = {
        iso for iso, (lat, lon) in coord_dict.items() if None not in (lat, lon)
    }
    df = df[df["reporter"].isin(iso_valid) & df["partner"].isin(iso_valid)]

    # 颜色与宽度映射 
    vmin, vmax = df["value"].min(), df["value"].max()
    norm = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    cmap = cm.get_cmap("YlOrRd")

    def edge_color(val):
        return colors.rgb2hex(cmap(norm(val)))

    def edge_weight(val):
        return 0.6 + 3 * norm(val)

    # FeatureGroup —— 之后注入到主图层
    fg = folium.FeatureGroup(name=str(year), show=False)

    # --- 节点 (以 reporter 总出口额计大小) ---------------------------
    out_sum = df.groupby("reporter")["value"].sum()
    for iso3, total in out_sum.items():
        lat, lon = get_latlon(iso3)
        radius = max(1.5, (total ** 0.5) / 150)
        folium.CircleMarker(
            [lat, lon],
            radius=radius,
            tooltip=f"{iso3}: {total:,.0f} USD exports",
            color="#0066cc",
            fill=True,
            fill_opacity=0.7,
        ).add_to(fg)

    # --- 边：取出口额最高 top_k --------------------------------------
    edges = df.nlargest(top_k, "value")
    for _, r in edges.iterrows():
        lat1, lon1 = get_latlon(r.reporter)
        lat2, lon2 = get_latlon(r.partner)
        w, c = edge_weight(r.value), edge_color(r.value)

        if edge_style == "ant":
            AntPath(
                [[lat1, lon1], [lat2, lon2]],
                weight=w,
                color=c,
                dash_array=[10, 15],
                delay=400,
                opacity=0.75,
            ).add_to(fg)
        else:
            folium.PolyLine(
                [[lat1, lon1], [lat2, lon2]],
                weight=w,
                color=c,
                opacity=0.75,
            ).add_to(fg)

    return fg


# ---------- ⑤ 自定义年份滑块控件 ----------
class YearSlider(MacroElement):
    """Leaflet slider 控件: 拖动切换年份图层"""

    _template = Template(
        """
        {% macro script(this, kwargs) %}
            // 创建容器
            var sliderContainer = L.DomUtil.create('div', 'year-slider');
            sliderContainer.innerHTML = '<input type="range" min="{{ this.min_year }}" max="{{ this.max_year }}" value="{{ this.start_year }}" step="1" id="yearSlider"> <span id="yearLabel">{{ this.start_year }}</span>';
            
            L.DomEvent.disableClickPropagation(sliderContainer);
            L.DomEvent.disableScrollPropagation(sliderContainer);
            
            sliderContainer.style.position = 'fixed';
            sliderContainer.style.bottom = '50px';
            sliderContainer.style.left = '50px';
            sliderContainer.style.zIndex = '9999';
            sliderContainer.style.background = 'white';
            sliderContainer.style.padding = '10px';
            sliderContainer.style.borderRadius = '4px';
            sliderContainer.style.boxShadow = '0 0 15px rgba(0,0,0,0.2)';
            {{ this._parent.get_name() }}.getContainer().appendChild(sliderContainer);

            // 年份 → LayerGroup 映射
            var layerGroups = {
            {% for y, name in this.year_layer_names.items() %}
                "{{ y }}": {{ name }},
            {% endfor %}
            };

            // 显示指定年份图层
            function showYearLayer(y) {
                Object.keys(layerGroups).forEach(function(k) {
                    if (k === y) {
                        if (!{{ this._parent.get_name() }}.hasLayer(layerGroups[k])) {
                            {{ this._parent.get_name() }}.addLayer(layerGroups[k]);
                        }
                    } else {
                        if ({{ this._parent.get_name() }}.hasLayer(layerGroups[k])) {
                            {{ this._parent.get_name() }}.removeLayer(layerGroups[k]);
                        }
                    }
                });
            }

            // 事件绑定
            document.getElementById('yearSlider').addEventListener('input', function(e) {
                var y = e.target.value;
                document.getElementById('yearLabel').innerHTML = y;
                showYearLayer(y);
            });

            // 初始化
            showYearLayer(String({{ this.start_year }}));
        {% endmacro %}
        """
    )

    def __init__(self, start_year: int, min_year: int, max_year: int, year_layer_names: Dict[int, str]):
        super().__init__()
        self._name = "YearSlider"
        self.start_year = start_year
        self.min_year = min_year
        self.max_year = max_year
        self.year_layer_names = year_layer_names


# ---------- ⑥ 主流程 ----------

def main():
    # ---------------- 创建底图 ----------------
    m = folium.Map(location=[20, 0], zoom_start=2, tiles="cartodbpositron")

    year_layer_varnames: Dict[int, str] = {}

    # ---------------- 按年生成图层 ----------------
    for year in range(YEAR_START, YEAR_END + 1):
        df_year = load_exports_from_csv(FILE_CSV, year)
        fg_year = build_year_layer(df_year, year, TOP_K_FLOW, EDGE_STYLE)
        fg_year.add_to(m)
        year_layer_varnames[year] = fg_year.get_name()
        print(f"Year {year}: edges≈{len(df_year):,}")


    # ---------------- 年份滑块 ----------------
    slider = YearSlider(
        start_year=YEAR_START,
        min_year=YEAR_START,
        max_year=YEAR_END,
        year_layer_names=year_layer_varnames,
    )
    m.add_child(slider)

    # ---------------- 标题 ----------------
    title_html = (
        f"<h4 style='text-align:center;font-size:20px'>IMF IMTS Global Export Network — {YEAR_START}-{YEAR_END}</h4>"
    )
    m.get_root().html.add_child(folium.Element(title_html))

    # ---------------- 保存 ----------------
    outfile = f"imf_trade_network_{YEAR_START}_{YEAR_END}.html"
    m.save(outfile)
    print(f"Map saved → {outfile}")


if __name__ == "__main__":
    main()


  cmap = cm.get_cmap("YlOrRd")


Year 1948: edges≈5,552


  cmap = cm.get_cmap("YlOrRd")


Year 1949: edges≈5,052


  cmap = cm.get_cmap("YlOrRd")


# Top N Exporters and their Partners

In [None]:
# -*- coding: utf-8 -*-
import folium
from folium.plugins import AntPath
from branca.element import MacroElement, Template
from matplotlib import cm, colors
import seaborn as sns
import pandas as pd
from tqdm import tqdm

# ---------------- 单一年份图层 ----------------
# ---------- 单年图层（前 k 出口国 × 各自前 n 边，蓝色节点版） ----------
def build_topn_layer(
    df: pd.DataFrame,
    year: int,
    k_countries: int = 8,
    n_edges: int = 6,
    edge_style: str = "poly",
    cmap_name: str = "tab10",
    partner_radius: int = 2          # ★ 伙伴国节点固定半径
) -> folium.FeatureGroup:

    # 0️⃣ 过滤无坐标节点 ------------------------------------------------
    iso_valid = {
        iso for iso in set(df["reporter"]).union(df["partner"])
        if get_latlon(iso)[0] is not None
    }
    df = df[df["reporter"].isin(iso_valid) & df["partner"].isin(iso_valid)]

    # ① 核心国家 -------------------------------------------------------
    country_sum_raw = (df.groupby("reporter")["value"]
                         .sum()
                         .sort_values(ascending=False))

    core_list = []
    for iso in country_sum_raw.index:
        core_list.append(iso)
        if len(core_list) == k_countries:
            break
    core_sum = country_sum_raw.loc[core_list]

    # —— 颜色映射：核心各异、伙伴灰 ------------------------------------
    core_palette    = sns.color_palette(cmap_name, n_colors=len(core_list)).as_hex()
    color_map_core  = {iso: core_palette[i] for i, iso in enumerate(core_list)}
    color_partner   = "#888888"

    # ② 每核心国前 n 边 ------------------------------------------------
    edges = pd.concat(
        [df[df["reporter"] == iso]
           .nlargest(n_edges, "value")
           .assign(core=iso)
         for iso in core_list],
        ignore_index=True
    )

    # ③ 伙伴国出口额近似 ----------------------------------------------
    partner_set = set(edges["partner"]) - set(core_list)

    partner_sum = (df[df["reporter"].isin(partner_set)]
                     .groupby("reporter")["value"].sum())
    partner_sum = partner_sum.combine_first(
        edges.groupby("partner")["value"].sum()
    )

    node_sum = pd.concat([core_sum, partner_sum])
    for iso in partner_set - set(node_sum.index):
        node_sum.loc[iso] = 1  # 占位

    # ④ 边宽 ----------------------------------------------------------
    vmin, vmax   = edges["value"].min(), edges["value"].max()
    norm         = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight  = lambda v: 0.6 + 3 * norm(v)

    # ⑤ Folium 图层 ---------------------------------------------------
    fg = folium.FeatureGroup(name=str(year), show=False)

    # —— 绘节点 --------------------------------------------------------
    for iso, total in node_sum.items():
        lat, lon = get_latlon(iso)
        if iso in core_list:
            radius = max(1.5, (total ** 0.5) / 150)      # 核心国：随出口额
            col    = color_map_core[iso]
        else:
            radius = partner_radius                    # 伙伴国：固定
            col    = color_partner

        folium.CircleMarker(
            [lat, lon],
            radius=radius,
            color=col,
            weight=2,
            fill=True,
            fill_color=col,
            fill_opacity=0.8,
            tooltip=f"{iso}: {total:,.0f} USD exports"
        ).add_to(fg)

    # —— 绘边 ----------------------------------------------------------
    for _, r in edges.iterrows():
        lat1, lon1 = get_latlon(r.reporter)
        lat2, lon2 = get_latlon(r.partner)
        w = edge_weight(r.value)
        c = color_map_core[r.core]     # 与核心国同色

        if edge_style == "ant":
            AntPath([[lat1, lon1], [lat2, lon2]],
                    color=c, weight=w, opacity=0.8,
                    dash_array=[10, 15], delay=400).add_to(fg)
        else:
            folium.PolyLine([[lat1, lon1], [lat2, lon2]],
                            color=c, weight=w, opacity=0.8).add_to(fg)

    return fg


# ---------------- 滑块控件 ----------------
class YearSlider(MacroElement):
    _template = Template("""
    {% macro script(this,kwargs) %}
        var s = L.DomUtil.create('div', 'year-slider');
        s.innerHTML='<input type="range" min="{{this.min}}" max="{{this.max}}" value="{{this.val}}" id="ys"> <span id="yl">{{this.val}}</span>';
        s.style.position='fixed'; s.style.bottom='50px'; s.style.left='50px';
        s.style.zIndex='9999'; s.style.background='white'; s.style.padding='8px';
        s.style.borderRadius='4px'; s.style.boxShadow='0 0 10px rgba(0,0,0,.3)';
        L.DomEvent.disableClickPropagation(s);
        L.DomEvent.disableScrollPropagation(s);
        {{this._parent.get_name()}}.getContainer().appendChild(s);

        var layers = {
        {% for y,n in this.names.items() %}
            "{{y}}": {{n}},
        {% endfor %}
        };
        function show(y){Object.keys(layers).forEach(k=>{
            if(k===y){ if(!{{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.addLayer(layers[k]); }
            else      { if({{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.removeLayer(layers[k]); }
        });}
        document.getElementById('ys').addEventListener('input',e=>{
            var y=e.target.value; document.getElementById('yl').innerHTML=y; show(y);
        });
        show(String({{this.val}}));
    {% endmacro %}
    """)
    def __init__(self, names: Dict[int,str], start: int, end: int):
        super().__init__()
        self.names, self.val, self.min, self.max = names, start, start, end

# ---------------- 主函数：生成多年份滑块地图 ----------------
def make_topn_slider_map(
    csv_path: str,
    years: range,
    k_countries: int = 8,
    n_edges: int = 6,
    edge_style: str = "poly",
    cmap_name: str = "tab10"
) -> folium.Map:
    m = folium.Map(location=[20,0], zoom_start=2, tiles="cartodbpositron")
    layer_names = {}

    for y in tqdm(years, desc="Build layers"):
        df_year = load_exports_from_csv(csv_path, y)
        layer   = build_topn_layer(df_year, y,
                                   k_countries=k_countries,
                                   n_edges=n_edges,
                                   edge_style=edge_style,
                                   cmap_name=cmap_name)
        layer.add_to(m)
        layer_names[y] = layer.get_name()

    m.add_child(YearSlider(layer_names, min(years), max(years)))
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>Top Exporters Network {min(years)}–{max(years)}</h4>"
    ))
    return m
# ---------------- 使用示例 ----------------
YEAR_START, YEAR_END = 1948, 2024   # 自行设定区间
years_range = range(YEAR_START, YEAR_END + 1)

topn_map = make_topn_slider_map(
    csv_path=FILE_CSV,
    years=years_range,
    k_countries=8,
    n_edges=6,
    edge_style="poly",      # 或 "ant"
    cmap_name="tab10"
)
topn_map.save(f"top_exporters_{YEAR_START}_{YEAR_END}.html")


Build layers:   0%|          | 0/77 [00:00<?, ?it/s]


NameError: name 'load_exports_from_csv' is not defined

# Louvain Community.

In [1]:
# -*- coding: utf-8 -*-
"""
IMF IMTS CSV → 全球出口网络社区结构 (多年份滑块)
================================================

* 读取 IMF IMTS Timeseries-per-row CSV
* Louvain 社区划分 → 各社区着不同颜色
* Folium + 自定义年份滑块切换视图
"""

from __future__ import annotations
from pathlib import Path
from typing import Dict, Tuple

import pandas as pd
import networkx as nx
from community import community_louvain          # Louvain
import folium
from folium.plugins import AntPath
from branca.element import MacroElement, Template
from matplotlib import cm, colors
from tqdm import tqdm

# ---------- 参数 ----------
FILE_CSV   = "IMF_IMTS_Exports_1948_2024.csv"
COORD_CSV  = "iso3_to_latlon.csv"
YEAR_START, YEAR_END = 1948, 2024
TOP_K_EDGE = 500          # 为减轻可视化负担，仅取前 K 条出口额边
EDGE_OPACITY = 1          # 边透明度（跨社区变淡）  
MIN_TRADE   = None        # 若想滤掉过小贸易额，可设阈值

# ---------- 坐标 ----------
def load_coord(filepath: str | Path) -> Dict[str, Tuple[float, float]]:
    df = pd.read_csv(filepath)
    out = {}
    for _, r in df.iterrows():
        lat, lon = r.Lat, r.Lon
        if pd.isna(lat) or pd.isna(lon):
            continue  # 丢弃无效坐标
        out[r.ISO3] = (lat, lon)
    return out

coord = load_coord(COORD_CSV)

# ---------- 加载年度数据 ----------
def load_year(csv: str | Path, year: int) -> pd.DataFrame:
    cols = ["COUNTRY.ID", "COUNTERPART_COUNTRY.ID", str(year)]
    df = (pd.read_csv(csv, low_memory=False)[cols]
          .rename(columns={"COUNTRY.ID": "src",
                           "COUNTERPART_COUNTRY.ID": "dst",
                           str(year): "value"})  # ← 统一命名
          .dropna(subset=["value"]))  # 先丢掉值为空的行
    df = df[(df.value > 0) & (df.src != df.dst)]  # 排除自引用行

    # ----- 过滤 src 和 dst 是否有有效坐标 -----
    df = df[df['src'].isin(coord.keys()) & df['dst'].isin(coord.keys())]
    
    if MIN_TRADE is not None:
        df = df[df.value > MIN_TRADE]
    return df.reset_index(drop=True)

# ---------- 构建 Louvain 划分 ----------
def detect_communities(df: pd.DataFrame) -> Dict[str, int]:
    G = nx.Graph()
    for _, r in df.iterrows():
        G.add_edge(r.src, r.dst, weight=r.value)
    return community_louvain.best_partition(G, weight='weight',resolution=0.5)

def gen_palette(n_comm: int) -> list[str]:
    """
    生成不包含红色系的调色板（排除容易与跨区边红色混淆的颜色）。
    使用 matplotlib 的 tab20 / tab20b / tab20c 联合拼接。
    """
    exclude_hues = {"#d62728", "#e31a1c", "#ff7f0e", "#f15a60", "#f28e2b", "#e15759"}  # 红/橙/粉
    seen = set()
    palette = []

    # 联合多个 colormap，以确保足够颜色
    for cmap_name in ["tab20", "tab20b", "tab20c", "Set3"]:
        cmap = cm.get_cmap(cmap_name, 20)
        for i in range(cmap.N):
            hex_color = colors.rgb2hex(cmap(i))
            if hex_color in exclude_hues or hex_color in seen:
                continue
            seen.add(hex_color)
            palette.append(hex_color)
            if len(palette) >= n_comm:
                return palette

    raise ValueError("Too many communities – not enough safe colors without red tones.")


# ---------- 单年图层（新版统一配色） ----------

def build_layer(df: pd.DataFrame, year: int) -> folium.FeatureGroup:
    df = df.nlargest(TOP_K_EDGE, "value")  # 统一字段名为 'value'

    # ------ Louvain 社区划分 ------
    part = detect_communities(df)
    comm_ids = sorted(set(part.values()))
    palette = gen_palette(len(comm_ids))
    comm_color = {cid: palette[i] for i, cid in enumerate(comm_ids)}

    # ------ 颜色与边宽度映射（YlOrRd，对数归一） ------
    vmin, vmax = df["value"].min(), df["value"].max()
    norm = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight = lambda val: 0.6 + 3 * norm(val)

    # ------ 节点尺寸（以 src 出口总额） ------
    out_sum = df.groupby("src")["value"].sum()
    nodes_in_year = {iso for iso in part if iso in coord}

    fg = folium.FeatureGroup(name=str(year), show=False)

    for iso in nodes_in_year:
        lat_lon = coord.get(iso)
        if lat_lon is None or pd.isna(lat_lon[0]) or pd.isna(lat_lon[1]):
            continue
        lat, lon = lat_lon
        total = out_sum.get(iso, 0)
        radius = max(3, (total ** 0.5) / 150)
        color = comm_color.get(part[iso], "#666666")

        folium.CircleMarker(
            [lat, lon],
            radius=radius,
            color=color,
            weight=2,
            fill=True,
            fill_color=color,
            fill_opacity=0.7,
            tooltip=f"{iso} (comm {part[iso]})\n{total:,.0f} USD exports"
        ).add_to(fg)

    # ------ 边：按是否跨社区设置透明度 ------
    for _, r in df.iterrows():
        if r.src not in coord or r.dst not in coord:
            continue
        lat1, lon1 = coord[r.src]
        lat2, lon2 = coord[r.dst]
        if pd.isna(lat1) or pd.isna(lat2):
            continue

        same_comm = part.get(r.src) == part.get(r.dst)
        edge_color_value = comm_color.get(part.get(r.src), "#666666") if same_comm else "red"
        opacity = 1.0 if same_comm else 0.2  # ← 跨社区变淡

        folium.PolyLine(
            [[lat1, lon1], [lat2, lon2]],
            weight=edge_weight(r.value),
            color=edge_color_value,
            opacity=opacity
        ).add_to(fg)

    return fg

# ---------- 滑块 ----------
class YearSlider(MacroElement):
    _template = Template("""
    {% macro script(this,kwargs) %}
        var slider = L.DomUtil.create('div', 'year-slider');
        slider.innerHTML = '<input type="range" min="{{this.min}}" max="{{this.max}}" value="{{this.val}}" id="ys"> <span id="yl">{{this.val}}</span>';
        slider.style.position = 'fixed'; slider.style.bottom = '50px'; slider.style.left = '50px';
        slider.style.background='white'; slider.style.padding='8px'; slider.style.zIndex='9999';
        slider.style.borderRadius='4px'; slider.style.boxShadow='0 0 10px rgba(0,0,0,.3)';
        L.DomEvent.disableClickPropagation(slider);
        L.DomEvent.disableScrollPropagation(slider);
        {{this._parent.get_name()}}.getContainer().appendChild(slider);

        var layers = {
        {% for y,n in this.names.items() %}
            "{{y}}": {{n}},
        {% endfor %}
        };

        function show(y){
            Object.keys(layers).forEach(k=>{
                if(k===y){ if(!{{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.addLayer(layers[k]); }
                else { if({{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.removeLayer(layers[k]); }
            });
        }
        document.getElementById('ys').addEventListener('input',e=>{
            var y=e.target.value; document.getElementById('yl').innerHTML=y; show(y);
        });
        show(String({{this.val}}));
    {% endmacro %}
    """)

    def __init__(self, names: Dict[int,str], start:int, end:int):
        super().__init__()
        self.names, self.val, self.min, self.max = names, start, start, end

# ---------- 主 ----------
def main():
    m = folium.Map(location=[20,0], zoom_start=2, tiles='cartodbpositron')
    names = {}
    for y in tqdm(range(YEAR_START, YEAR_END+1), desc="Build layers"):
        df = load_year(FILE_CSV, y)
        layer = build_layer(df, y)
        layer.add_to(m)
        names[y] = layer.get_name()
        print(f"{y}: communities = {len(set(detect_communities(df).values()))}")

    m.add_child(YearSlider(names, YEAR_START, YEAR_END))
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>Trade‐Community Map {YEAR_START}–{YEAR_END}</h4>"))

    # Adding Legend
    legend = """
        <div style="position: fixed; top: 50px; right: 100px; z-index: 9999; font-size: 14px; background: white; padding: 10px; border-radius: 5px;">
            <strong>Legend:</strong><br>
            <span style="color:red;">Cross-region connections</span>
        </div>
    """
    m.get_root().html.add_child(folium.Element(legend))

    out = f"imf_trade_community_{YEAR_START}_{YEAR_END}_E{TOP_K_EDGE}_R0.5.html"
    m.save(out)
    print("Saved →", out)

if __name__ == "__main__":
    main()


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   1%|▏         | 1/77 [00:00<00:44,  1.70it/s]

1948: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   3%|▎         | 2/77 [00:01<00:41,  1.82it/s]

1949: communities = 10


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   4%|▍         | 3/77 [00:01<00:39,  1.86it/s]

1950: communities = 8


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   5%|▌         | 4/77 [00:02<00:38,  1.89it/s]

1951: communities = 10


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   6%|▋         | 5/77 [00:02<00:38,  1.88it/s]

1952: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   8%|▊         | 6/77 [00:03<00:38,  1.86it/s]

1953: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:   9%|▉         | 7/77 [00:03<00:37,  1.87it/s]

1954: communities = 4


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  10%|█         | 8/77 [00:04<00:37,  1.86it/s]

1955: communities = 17


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  12%|█▏        | 9/77 [00:04<00:36,  1.87it/s]

1956: communities = 3


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  13%|█▎        | 10/77 [00:05<00:36,  1.83it/s]

1957: communities = 3


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  14%|█▍        | 11/77 [00:05<00:36,  1.82it/s]

1958: communities = 13


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  16%|█▌        | 12/77 [00:06<00:35,  1.81it/s]

1959: communities = 16


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  17%|█▋        | 13/77 [00:07<00:35,  1.80it/s]

1960: communities = 3


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  18%|█▊        | 14/77 [00:07<00:35,  1.79it/s]

1961: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  19%|█▉        | 15/77 [00:08<00:34,  1.80it/s]

1962: communities = 15


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  21%|██        | 16/77 [00:08<00:34,  1.79it/s]

1963: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  22%|██▏       | 17/77 [00:09<00:33,  1.77it/s]

1964: communities = 3


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  23%|██▎       | 18/77 [00:09<00:33,  1.76it/s]

1965: communities = 19


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  25%|██▍       | 19/77 [00:10<00:33,  1.72it/s]

1966: communities = 22


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  26%|██▌       | 20/77 [00:11<00:34,  1.65it/s]

1967: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  27%|██▋       | 21/77 [00:11<00:34,  1.64it/s]

1968: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  29%|██▊       | 22/77 [00:12<00:35,  1.54it/s]

1969: communities = 21


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  30%|██▉       | 23/77 [00:13<00:34,  1.56it/s]

1970: communities = 21


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  31%|███       | 24/77 [00:13<00:33,  1.57it/s]

1971: communities = 20


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  32%|███▏      | 25/77 [00:14<00:32,  1.61it/s]

1972: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  34%|███▍      | 26/77 [00:15<00:31,  1.60it/s]

1973: communities = 23


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  35%|███▌      | 27/77 [00:15<00:32,  1.54it/s]

1974: communities = 22


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  36%|███▋      | 28/77 [00:16<00:32,  1.53it/s]

1975: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  38%|███▊      | 29/77 [00:17<00:32,  1.47it/s]

1976: communities = 18


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  39%|███▉      | 30/77 [00:18<00:41,  1.12it/s]

1977: communities = 16


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  40%|████      | 31/77 [00:19<00:48,  1.05s/it]

1978: communities = 21


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  42%|████▏     | 32/77 [00:21<00:51,  1.15s/it]

1979: communities = 27


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  43%|████▎     | 33/77 [00:22<00:51,  1.17s/it]

1980: communities = 25


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  44%|████▍     | 34/77 [00:23<00:53,  1.25s/it]

1981: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  45%|████▌     | 35/77 [00:25<00:54,  1.29s/it]

1982: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  47%|████▋     | 36/77 [00:26<00:54,  1.32s/it]

1983: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  48%|████▊     | 37/77 [00:28<00:53,  1.34s/it]

1984: communities = 22


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  49%|████▉     | 38/77 [00:29<00:53,  1.37s/it]

1985: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  51%|█████     | 39/77 [00:30<00:52,  1.39s/it]

1986: communities = 22


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  52%|█████▏    | 40/77 [00:32<00:51,  1.40s/it]

1987: communities = 23


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  53%|█████▎    | 41/77 [00:34<00:52,  1.46s/it]

1988: communities = 16


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  55%|█████▍    | 42/77 [00:35<00:51,  1.46s/it]

1989: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  56%|█████▌    | 43/77 [00:37<00:50,  1.48s/it]

1990: communities = 25


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  57%|█████▋    | 44/77 [00:38<00:49,  1.50s/it]

1991: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  58%|█████▊    | 45/77 [00:40<00:48,  1.53s/it]

1992: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  60%|█████▉    | 46/77 [00:41<00:49,  1.61s/it]

1993: communities = 4


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  61%|██████    | 47/77 [00:43<00:50,  1.68s/it]

1994: communities = 29


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  62%|██████▏   | 48/77 [00:45<00:50,  1.73s/it]

1995: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  64%|██████▎   | 49/77 [00:47<00:49,  1.76s/it]

1996: communities = 10


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  65%|██████▍   | 50/77 [00:49<00:48,  1.81s/it]

1997: communities = 29


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  66%|██████▌   | 51/77 [00:51<00:48,  1.86s/it]

1998: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  68%|██████▊   | 52/77 [00:53<00:47,  1.92s/it]

1999: communities = 3


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  69%|██████▉   | 53/77 [00:55<00:48,  2.00s/it]

2000: communities = 8


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  70%|███████   | 54/77 [00:57<00:47,  2.05s/it]

2001: communities = 9


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  71%|███████▏  | 55/77 [01:00<00:46,  2.11s/it]

2002: communities = 11


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  73%|███████▎  | 56/77 [01:02<00:45,  2.15s/it]

2003: communities = 14


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  74%|███████▍  | 57/77 [01:04<00:43,  2.19s/it]

2004: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  75%|███████▌  | 58/77 [01:06<00:42,  2.23s/it]

2005: communities = 4


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  77%|███████▋  | 59/77 [01:09<00:40,  2.27s/it]

2006: communities = 10


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  78%|███████▊  | 60/77 [01:11<00:39,  2.31s/it]

2007: communities = 10


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  79%|███████▉  | 61/77 [01:14<00:37,  2.34s/it]

2008: communities = 34


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  81%|████████  | 62/77 [01:16<00:35,  2.37s/it]

2009: communities = 8


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  82%|████████▏ | 63/77 [01:18<00:33,  2.37s/it]

2010: communities = 36


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  83%|████████▎ | 64/77 [01:21<00:31,  2.45s/it]

2011: communities = 9


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  84%|████████▍ | 65/77 [01:23<00:29,  2.46s/it]

2012: communities = 8


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  86%|████████▌ | 66/77 [01:26<00:27,  2.48s/it]

2013: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  87%|████████▋ | 67/77 [01:29<00:24,  2.49s/it]

2014: communities = 7


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  88%|████████▊ | 68/77 [01:31<00:22,  2.49s/it]

2015: communities = 32


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  90%|████████▉ | 69/77 [01:34<00:20,  2.52s/it]

2016: communities = 28


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  91%|█████████ | 70/77 [01:36<00:17,  2.57s/it]

2017: communities = 29


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  92%|█████████▏| 71/77 [01:39<00:15,  2.55s/it]

2018: communities = 24


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  94%|█████████▎| 72/77 [01:41<00:12,  2.57s/it]

2019: communities = 6


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  95%|█████████▍| 73/77 [01:44<00:10,  2.56s/it]

2020: communities = 29


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  96%|█████████▌| 74/77 [01:47<00:07,  2.60s/it]

2021: communities = 34


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  97%|█████████▋| 75/77 [01:49<00:05,  2.63s/it]

2022: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers:  99%|█████████▊| 76/77 [01:52<00:02,  2.64s/it]

2023: communities = 5


  cmap = cm.get_cmap(cmap_name, 20)
Build layers: 100%|██████████| 77/77 [01:55<00:00,  1.50s/it]

2024: communities = 32





Saved → imf_trade_community_1948_2024_E500_R0.5.html


## 显示以某个国家所在的社区

In [40]:
# -*- coding: utf-8 -*-
"""
IMF IMTS CSV → 全球出口网络社区结构 (多年份滑块)
================================================

* 读取 IMF IMTS Timeseries-per-row CSV
* Louvain 社区划分 → 各社区着不同颜色
* Folium + 自定义年份滑块切换视图
"""

from __future__ import annotations
from pathlib import Path
from typing import Dict, Tuple

import pandas as pd
import networkx as nx
from community import community_louvain          # Louvain
import folium
from folium.plugins import AntPath
from branca.element import MacroElement, Template
from matplotlib import cm, colors
from tqdm import tqdm

# ---------- 参数 ----------
FILE_CSV   = "IMF_IMTS_Exports_1948_2024.csv"
COORD_CSV  = "iso3_to_latlon.csv"
YEAR_START, YEAR_END = 1948, 2024
TOP_K_EDGE = 2000          # 为减轻可视化负担，仅取前 K 条出口额边
EDGE_OPACITY = 1          # 边透明度（跨社区变淡）  
MIN_TRADE   = None        # 若想滤掉过小贸易额，可设阈值
COUNTRY = "JPN"            # 关注的国家（可选）
# ---------- 坐标 ----------
def load_coord(filepath: str | Path) -> Dict[str, Tuple[float, float]]:
    df = pd.read_csv(filepath)
    out = {}
    for _, r in df.iterrows():
        lat, lon = r.Lat, r.Lon
        if pd.isna(lat) or pd.isna(lon):
            continue  # 丢弃无效坐标
        out[r.ISO3] = (lat, lon)
    return out

coord = load_coord(COORD_CSV)

# ---------- 加载年度数据 ----------
def load_year(csv: str | Path, year: int) -> pd.DataFrame:
    cols = ["COUNTRY.ID", "COUNTERPART_COUNTRY.ID", str(year)]
    df = (pd.read_csv(csv, low_memory=False)[cols]
          .rename(columns={"COUNTRY.ID": "src",
                           "COUNTERPART_COUNTRY.ID": "dst",
                           str(year): "value"})  # ← 统一命名
          .dropna(subset=["value"]))  # 先丢掉值为空的行
    df = df[(df.value > 0) & (df.src != df.dst)]  # 排除自引用行

    # ----- 过滤 src 和 dst 是否有有效坐标 -----
    df = df[df['src'].isin(coord.keys()) & df['dst'].isin(coord.keys())]
    
    if MIN_TRADE is not None:
        df = df[df.value > MIN_TRADE]
    return df.reset_index(drop=True)

# ---------- 构建 Louvain 划分 ----------
def detect_communities(df: pd.DataFrame) -> Dict[str, int]:
    G = nx.Graph()
    for _, r in df.iterrows():
        G.add_edge(r.src, r.dst, weight=r.value)
    return community_louvain.best_partition(G, weight='weight')

def gen_palette(n_comm: int) -> list[str]:
    """
    生成不包含红色系的调色板（排除容易与跨区边红色混淆的颜色）。
    使用 matplotlib 的 tab20 / tab20b / tab20c 联合拼接。
    """
    exclude_hues = {"#d62728", "#e31a1c", "#ff7f0e", "#f15a60", "#f28e2b", "#e15759"}  # 红/橙/粉
    seen = set()
    palette = []

    # 联合多个 colormap，以确保足够颜色
    for cmap_name in ["tab20", "tab20b", "tab20c", "Set3"]:
        cmap = cm.get_cmap(cmap_name, 20)
        for i in range(cmap.N):
            hex_color = colors.rgb2hex(cmap(i))
            if hex_color in exclude_hues or hex_color in seen:
                continue
            seen.add(hex_color)
            palette.append(hex_color)
            if len(palette) >= n_comm:
                return palette

    raise ValueError("Too many communities – not enough safe colors without red tones.")

# ---------- 滑块 ----------
class YearSlider(MacroElement):
    _template = Template("""
    {% macro script(this,kwargs) %}
        var slider = L.DomUtil.create('div', 'year-slider');
        slider.innerHTML = '<input type="range" min="{{this.min}}" max="{{this.max}}" value="{{this.val}}" id="ys"> <span id="yl">{{this.val}}</span>';
        slider.style.position = 'fixed'; slider.style.bottom = '50px'; slider.style.left = '50px';
        slider.style.background='white'; slider.style.padding='8px'; slider.style.zIndex='9999';
        slider.style.borderRadius='4px'; slider.style.boxShadow='0 0 10px rgba(0,0,0,.3)';
        L.DomEvent.disableClickPropagation(slider);
        L.DomEvent.disableScrollPropagation(slider);
        {{this._parent.get_name()}}.getContainer().appendChild(slider);

        var layers = {
        {% for y,n in this.names.items() %}
            "{{y}}": {{n}},
        {% endfor %}
        };

        function show(y){
            Object.keys(layers).forEach(k=>{
                if(k===y){ if(!{{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.addLayer(layers[k]); }
                else { if({{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.removeLayer(layers[k]); }
            });
        }
        document.getElementById('ys').addEventListener('input',e=>{
            var y=e.target.value; document.getElementById('yl').innerHTML=y; show(y);
        });
        show(String({{this.val}}));
    {% endmacro %}
    """)

    def __init__(self, names: Dict[int,str], start:int, end:int):
        super().__init__()
        self.names, self.val, self.min, self.max = names, start, start, end

# ------------------------------------------------------------
# ★★ 新增：仅绘制「focus_iso 所在社区 + 与 focus_iso 相连的社区外节点」的图层
# ------------------------------------------------------------
def build_focus_layer(df: pd.DataFrame,
                      year: int,
                      focus_iso: str = "CHN",
                      top_k: int = TOP_K_EDGE) -> folium.FeatureGroup:
    """
    只绘制 focus_iso 所在 Louvain 社区中的节点，以及社区外与 focus_iso 有边的节点，
    并绘制这些节点相关的边（社区内全连 + focus_iso 的跨社区连边）。
    """
    df = df.nlargest(top_k, "value")          # 取前 K 条边
    part = detect_communities(df)             # Louvain 划分
    if focus_iso not in part:
        raise ValueError(f"{focus_iso} not found in partition for year {year}")

    comm_ids      = sorted(set(part.values()))
    palette       = gen_palette(len(comm_ids))
    comm_color    = {cid: palette[i] for i, cid in enumerate(comm_ids)}

    focus_comm_id = part[focus_iso]

    # ---- 1) 要绘制的节点 --------------------------------------------------
    nodes_comm    = {n for n, c in part.items() if c == focus_comm_id}
    nodes_external= {r.dst for _, r in df.iterrows() if r.src == focus_iso and part[r.dst] != focus_comm_id} | \
                    {r.src for _, r in df.iterrows() if r.dst == focus_iso and part[r.src] != focus_comm_id}
    nodes_draw    = nodes_comm | nodes_external

    # ---- 2) 要绘制的边 ----------------------------------------------------
    def keep_edge(row) -> bool:
        a, b = row.src, row.dst
        # 社区内边 or focus_iso 跨社区边
        if a in nodes_comm and b in nodes_comm:
            return True
        if (a == focus_iso and b in nodes_external) or (b == focus_iso and a in nodes_external):
            return True
        return False

    edges_df = df[df.apply(keep_edge, axis=1)]

    # ---- 3) 通用映射函数 ---------------------------------------------------
    vmin, vmax = edges_df["value"].min(), edges_df["value"].max()
    norm       = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight= lambda val: 0.6 + 3 * norm(val)
    out_sum    = df.groupby("src")["value"].sum()

    fg = folium.FeatureGroup(name=f"{year} Focus ({focus_iso})", show=True)

    # ---- 3a. 绘节点 -------------------------------------------------------
    for iso in nodes_draw:
        if iso not in coord:         # 理论上已过滤，但再保险
            continue
        lat, lon = coord[iso]
        if pd.isna(lat):             # 跳过无坐标
            continue
        radius = max(3, (out_sum.get(iso, 0) ** 0.5) / 150)
        color  = comm_color[part[iso]]
        folium.CircleMarker(
            [lat, lon], radius=radius,
            color=color, weight=2,
            fill=True, fill_color=color, fill_opacity=0.7,
            tooltip=f"{iso} (comm {part[iso]})\n{out_sum.get(iso,0):,.0f} USD exports"
        ).add_to(fg)

    # ---- 3b. 绘边 ---------------------------------------------------------
    for _, r in edges_df.iterrows():
        lat1, lon1 = coord[r.src]
        lat2, lon2 = coord[r.dst]
        same_comm  = part[r.src] == part[r.dst]
        edge_col   = comm_color[part[r.src]] if same_comm else "red"
        opacity    = 1.0 if same_comm else 0.5
        folium.PolyLine(
            [[lat1, lon1], [lat2, lon2]],
            weight=edge_weight(r.value),
            color=edge_col,
            opacity=opacity
        ).add_to(fg)

    return fg

def main_focus_slider(focus_iso: str = "CHN",
                      year_start: int = 2020,
                      year_end: int = 2024,
                      top_k: int = TOP_K_EDGE):
    """
    多年份滑动视图：每年只绘制 focus_iso 所在社区 + 与其相连的跨社区节点及边
    """
    m = folium.Map(location=[20, 0], zoom_start=2, tiles='cartodbpositron')
    names = {}

    for y in tqdm(range(year_start, year_end + 1), desc=f"Building focus layers for {focus_iso}"):
        df = load_year(FILE_CSV, y)
        try:
            fg = build_focus_layer(df, year=y, focus_iso=focus_iso, top_k=top_k)
        except ValueError:
            print(f"⚠️ {focus_iso} not found in year {y}, skipping...")
            continue
        fg.add_to(m)
        names[y] = fg.get_name()

    # 添加滑块
    m.add_child(YearSlider(names, year_start, year_end))

    # 添加标题和图例
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>Trade Focus: {focus_iso} ({year_start}–{year_end})</h4>"
    ))

    legend = f"""
        <div style="position: fixed; top: 50px; right: 80px; z-index: 9999;
                    background: white; padding: 10px 14px; border-radius: 6px;
                    font-size: 13px; box-shadow: 0 0 8px rgba(0,0,0,.3);">
            <b>Legend</b><br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:red;margin-right:4px;"></span>
            Cross-community links from {focus_iso}<br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:#666;margin-right:4px;border:1px solid #666;"></span>
            Same-community links
        </div>
    """
    m.get_root().html.add_child(folium.Element(legend))

    out = f"focus_slider_{focus_iso}_{year_start}_{year_end}.html"
    m.save(out)
    print("✅ Saved →", out)

if __name__ == "__main__":
    main_focus_slider(focus_iso=COUNTRY, year_start=YEAR_START, year_end=YEAR_END, top_k=TOP_K_EDGE)



  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_na

✅ Saved → focus_slider_USA_1948_2024.html
