In [13]:
# -*- coding: utf-8 -*-
"""
IMF IMTS CSV → 全球出口网络社区结构 (多年份滑块)
================================================

* 读取 IMF IMTS Timeseries-per-row CSV
* Louvain 社区划分 → 各社区着不同颜色
* Folium + 自定义年份滑块切换视图
"""

from __future__ import annotations
from pathlib import Path
from typing import Dict, Tuple

import pandas as pd
import networkx as nx
from community import community_louvain          # Louvain
import folium
from folium.plugins import AntPath
from branca.element import MacroElement, Template
from matplotlib import cm, colors
from tqdm import tqdm

# ---------- 参数 ----------
FILE_CSV   = "E:\Obsidian\Project\专业课\大三下\Intro to Networks\Final_Essay\IMF_IMTS_Exports_1948_2024.csv"
COORD_CSV  = "E:\Obsidian\Project\专业课\大三下\Intro to Networks\Final_Essay\iso3_to_latlon.csv"
YEAR_START, YEAR_END = 1948, 2024
TOP_K_EDGE = 2000          # 为减轻可视化负担，仅取前 K 条出口额边
EDGE_OPACITY = 1          # 边透明度（跨社区变淡）  
MIN_TRADE   = None        # 若想滤掉过小贸易额，可设阈值
COUNTRY = "CHN"            # 关注的国家（可选）
# ---------- 坐标 ----------
def load_coord(filepath: str | Path) -> Dict[str, Tuple[float, float]]:
    df = pd.read_csv(filepath)
    out = {}
    for _, r in df.iterrows():
        lat, lon = r.Lat, r.Lon
        if pd.isna(lat) or pd.isna(lon):
            continue  # 丢弃无效坐标
        out[r.ISO3] = (lat, lon)
    return out

coord = load_coord(COORD_CSV)

# ---------- 加载年度数据 ----------
def load_year(csv: str | Path, year: int) -> pd.DataFrame:
    cols = ["COUNTRY.ID", "COUNTERPART_COUNTRY.ID", str(year)]
    df = (pd.read_csv(csv, low_memory=False)[cols]
          .rename(columns={"COUNTRY.ID": "src",
                           "COUNTERPART_COUNTRY.ID": "dst",
                           str(year): "value"})  # ← 统一命名
          .dropna(subset=["value"]))  # 先丢掉值为空的行
    df = df[(df.value > 0) & (df.src != df.dst)]  # 排除自引用行

    # ----- 过滤 src 和 dst 是否有有效坐标 -----
    df = df[df['src'].isin(coord.keys()) & df['dst'].isin(coord.keys())]
    
    if MIN_TRADE is not None:
        df = df[df.value > MIN_TRADE]
    return df.reset_index(drop=True)

# ---------- 构建 Louvain 划分 ----------
def detect_communities(df: pd.DataFrame) -> Dict[str, int]:
    G = nx.Graph()
    for _, r in df.iterrows():
        G.add_edge(r.src, r.dst, weight=r.value)
    return community_louvain.best_partition(G, weight='weight',resolution=0.5)

def gen_palette(n_comm: int) -> list[str]:
    """
    生成不包含红色系的调色板（排除容易与跨区边红色混淆的颜色）。
    使用 matplotlib 的 tab20 / tab20b / tab20c 联合拼接。
    """
    exclude_hues = {"#ffffff"}  # 红/橙/粉, "#e31a1c", "#ff7f0e", "#f15a60", "#f28e2b", "#e15759"
    seen = set()
    palette = []

    # 联合多个 colormap，以确保足够颜色
    for cmap_name in ["tab20", "tab20b", "tab20c", "Set3"]:
        cmap = cm.get_cmap(cmap_name, 20)
        for i in range(cmap.N):
            hex_color = colors.rgb2hex(cmap(i))
            if hex_color in exclude_hues or hex_color in seen:
                continue
            seen.add(hex_color)
            palette.append(hex_color)
            if len(palette) >= n_comm:
                return palette

    raise ValueError("Too many communities – not enough safe colors without red tones.")

# ---------- 滑块 ----------
class YearSlider(MacroElement):
    _template = Template("""
    {% macro script(this,kwargs) %}
        var slider = L.DomUtil.create('div', 'year-slider');
        slider.innerHTML = '<input type="range" min="{{this.min}}" max="{{this.max}}" value="{{this.val}}" id="ys"> <span id="yl">{{this.val}}</span>';
        slider.style.position = 'fixed'; slider.style.bottom = '50px'; slider.style.left = '50px';
        slider.style.background='white'; slider.style.padding='8px'; slider.style.zIndex='9999';
        slider.style.borderRadius='4px'; slider.style.boxShadow='0 0 10px rgba(0,0,0,.3)';
        L.DomEvent.disableClickPropagation(slider);
        L.DomEvent.disableScrollPropagation(slider);
        {{this._parent.get_name()}}.getContainer().appendChild(slider);

        var layers = {
        {% for y,n in this.names.items() %}
            "{{y}}": {{n}},
        {% endfor %}
        };

        function show(y){
            Object.keys(layers).forEach(k=>{
                if(k===y){ if(!{{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.addLayer(layers[k]); }
                else { if({{this._parent.get_name()}}.hasLayer(layers[k])) {{this._parent.get_name()}}.removeLayer(layers[k]); }
            });
        }
        document.getElementById('ys').addEventListener('input',e=>{
            var y=e.target.value; document.getElementById('yl').innerHTML=y; show(y);
        });
        show(String({{this.val}}));
    {% endmacro %}
    """)

    def __init__(self, names: Dict[int,str], start:int, end:int):
        super().__init__()
        self.names, self.val, self.min, self.max = names, start, start, end

# ------------------------------------------------------------
# ★★ 新增：仅绘制「focus_iso 所在社区 + 与 focus_iso 相连的社区外节点」的图层
# ------------------------------------------------------------
def build_focus_layer(df: pd.DataFrame,
                      year: int,
                      focus_iso: str = "CHN",
                      top_k: int = TOP_K_EDGE) -> folium.FeatureGroup:
    """
    只绘制 focus_iso 所在 Louvain 社区中的节点，以及社区外与 focus_iso 有边的节点，
    并绘制这些节点相关的边（社区内全连 + focus_iso 的跨社区连边）。
    """
    df = df.nlargest(top_k, "value")          # 取前 K 条边
    part = detect_communities(df)             # Louvain 划分
    if focus_iso not in part:
        raise ValueError(f"{focus_iso} not found in partition for year {year}")

    comm_ids      = sorted(set(part.values()))
    palette       = gen_palette(len(comm_ids))
    comm_color    = {cid: palette[i] for i, cid in enumerate(comm_ids)}

    focus_comm_id = part[focus_iso]

    # ---- 1) 要绘制的节点 --------------------------------------------------
    nodes_comm    = {n for n, c in part.items() if c == focus_comm_id}
    #nodes_external= {r.dst for _, r in df.iterrows() if r.src == focus_iso and part[r.dst] != focus_comm_id} | \
    #                {r.src for _, r in df.iterrows() if r.dst == focus_iso and part[r.src] != focus_comm_id}
    nodes_draw    = nodes_comm #| nodes_external

    # ---- 2) 要绘制的边 ----------------------------------------------------
    def keep_edge(row) -> bool:
        a, b = row.src, row.dst
        # 社区内边 or focus_iso 跨社区边
        if a in nodes_comm and b in nodes_comm:
            return True
        #if (a == focus_iso and b in nodes_external) or (b == focus_iso and a in nodes_external):
        #    return True
        return False

    edges_df = df[df.apply(keep_edge, axis=1)]

    # ---- 3) 通用映射函数 ---------------------------------------------------
    vmin, vmax = edges_df["value"].min(), edges_df["value"].max()
    norm       = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight= lambda val: 0.6 + 3 * norm(val)
    out_sum    = df.groupby("src")["value"].sum()

    fg = folium.FeatureGroup(name=f"{year} Focus ({focus_iso})", show=True)

    # ---- 3a. 绘节点 -------------------------------------------------------
    world_total = out_sum.sum()  # 所有国家总出口额
    scale_factor = 100  # 可调系数：影响整体节点大小，建议在 30~100 之间微调

    for iso in nodes_draw:
        if iso not in coord:         # 理论上已过滤，但再保险
            continue
        lat, lon = coord[iso]
        if pd.isna(lat):             # 跳过无坐标
            continue

        total = out_sum.get(iso, 0)
        ratio = total / world_total
        radius = max(1, ratio * scale_factor)

        color  = comm_color[part[iso]]
        folium.CircleMarker(
            [lat, lon], radius=radius,
            color=color, weight=2,
            fill=True, fill_color=color, fill_opacity=0.7,
            tooltip=f"{iso} (comm {part[iso]})\n{out_sum.get(iso,0):,.0f} USD exports"
        ).add_to(fg)

    # ---- 3b. 绘边 ---------------------------------------------------------
    for _, r in edges_df.iterrows():
        lat1, lon1 = coord[r.src]
        lat2, lon2 = coord[r.dst]
        same_comm  = part[r.src] == part[r.dst]
        edge_col   = comm_color[part[r.src]] if same_comm else "red"
        opacity    = 1.0 if same_comm else 0.5
        folium.PolyLine(
            [[lat1, lon1], [lat2, lon2]],
            weight=edge_weight(r.value),
            color=edge_col,
            opacity=opacity
        ).add_to(fg)

    return fg

def main_focus_slider(focus_iso: str = "CHN",
                      year_start: int = 2020,
                      year_end: int = 2024,
                      top_k: int = TOP_K_EDGE):
    """
    多年份滑动视图：每年只绘制 focus_iso 所在社区 + 与其相连的跨社区节点及边
    """
    m = folium.Map(location=[20, 0], zoom_start=2, tiles='cartodbpositron')
    names = {}

    for y in tqdm(range(year_start, year_end + 1), desc=f"Building focus layers for {focus_iso}"):
        df = load_year(FILE_CSV, y)
        try:
            fg = build_focus_layer(df, year=y, focus_iso=focus_iso, top_k=top_k)
        except ValueError:
            print(f"⚠️ {focus_iso} not found in year {y}, skipping...")
            continue
        fg.add_to(m)
        names[y] = fg.get_name()

    # 添加滑块
    m.add_child(YearSlider(names, year_start, year_end))

    # 添加标题和图例
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>Trade Focus: {focus_iso} ({year_start}–{year_end})</h4>"
    ))

    legend = f"""
        <div style="position: fixed; top: 50px; right: 80px; z-index: 9999;
                    background: white; padding: 10px 14px; border-radius: 6px;
                    font-size: 13px; box-shadow: 0 0 8px rgba(0,0,0,.3);">
            <b>Legend</b><br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:red;margin-right:4px;"></span>
            Cross-community links from {focus_iso}<br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:#666;margin-right:4px;border:1px solid #666;"></span>
            Same-community links
        </div>
    """
    m.get_root().html.add_child(folium.Element(legend))

    out = f"no_outsider_focus_slider_{focus_iso}_{year_start}_{year_end}.html"
    m.save(out)
    print("✅ Saved →", out)


  FILE_CSV   = "E:\Obsidian\Project\专业课\大三下\Intro to Networks\Final_Essay\IMF_IMTS_Exports_1948_2024.csv"
  COORD_CSV  = "E:\Obsidian\Project\专业课\大三下\Intro to Networks\Final_Essay\iso3_to_latlon.csv"


In [18]:
# ------------------------------------------------------------
# ★★ 只绘制 “focus_iso 所在社区 + 跨社区连接节点” 的年度图层（进出口总额版）
# ------------------------------------------------------------
def build_focus_layer(df: pd.DataFrame,
                      year: int,
                      focus_iso: str = "CHN",
                      top_k: int = TOP_K_EDGE) -> folium.FeatureGroup:
    """
    - 先将单向出口表 → 无向双向总额 (src↔dst) 表
    - Louvain 社区基于该无向表
    - 仅绘制 focus_iso 所在社区节点 + 与 focus_iso 相连的跨社区节点
    - 节点大小 ∝ 进出口总额，边粗细 ∝ 双向贸易额
    """
    # 1️⃣ 取前 top_k 单向出口额，再合并为无向“进出口总额”
    df = df.nlargest(top_k, "value").copy()
    df["pair"] = df.apply(lambda r: tuple(sorted((r.src, r.dst))), axis=1)
    undirected = (df.groupby("pair", as_index=False)["value"]
                    .sum()                                   # 双向进出口额
                    .rename(columns={"value": "trade"}))
    undirected[["src", "dst"]] = undirected["pair"].apply(lambda p: pd.Series(p))

    # 2️⃣ 社区划分
    part = detect_communities(
        undirected[["src", "dst", "trade"]].rename(columns={"trade": "value"})
    )
    if focus_iso not in part:
        raise ValueError(f"{focus_iso} not found in partition for year {year}")

    comm_ids   = sorted(set(part.values()))
    palette    = gen_palette(len(comm_ids))
    comm_color = {cid: palette[i] for i, cid in enumerate(comm_ids)}
    focus_comm = part[focus_iso]

    # 3️⃣ 需要绘制的节点集合
    nodes_comm = {n for n, c in part.items() if c == focus_comm}
    # nodes_ext  = {r.dst for _, r in undirected.iterrows() if r.src == focus_iso and part[r.dst] != focus_comm} | \
    #              {r.src for _, r in undirected.iterrows() if r.dst == focus_iso and part[r.src] != focus_comm}
    nodes_draw = nodes_comm #| nodes_ext

    # 4️⃣ 需要绘制的边
    def keep_edge(row) -> bool:
        a, b = row.src, row.dst
        # 社区内边
        if (a in nodes_comm) and (b in nodes_comm):
            return True
        # focus_iso 跨社区边
        # if focus_iso in (a, b):
        #     other = b if a == focus_iso else a
        #     if other in nodes_ext:
        #         return True
        return False


    edges_df = undirected[undirected.apply(keep_edge, axis=1)]

    # 5️⃣ 映射函数
    vmin, vmax  = edges_df["trade"].min(), edges_df["trade"].max()
    norm        = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight = lambda v: 0.6 + 3 * norm(v)

    # 6️⃣ 节点进出口总额（用于半径）
    out_sum = df.groupby("src")["value"].sum()
    in_sum  = df.groupby("dst")["value"].sum()
    total_trade = out_sum.add(in_sum, fill_value=0)
    world_total = total_trade.sum()
    scale_factor = 100

    # 7️⃣ Folium 图层
    fg = folium.FeatureGroup(name=f"{year} Focus ({focus_iso})", show=True)

    # 7a. 绘制节点
    for iso in nodes_draw:
        if iso not in coord: continue
        lat, lon = coord[iso]
        if pd.isna(lat): continue

        radius = max(1, (total_trade.get(iso, 0) / world_total) * scale_factor)
        color = "red"  # 默认red
        #color  = comm_color[part[iso]] if iso in part else "#666666"

        folium.CircleMarker(
            [lat, lon], radius=radius,
            color=color, weight=2,
            fill=True, fill_color=color, fill_opacity=0.7,
            tooltip=f"{iso} (comm {part.get(iso,'?')})\n{total_trade.get(iso,0):,.0f} USD trade"
        ).add_to(fg)

    # 7b. 绘制边
    for _, r in edges_df.iterrows():
        lat1, lon1 = coord[r.src]
        lat2, lon2 = coord[r.dst]
        same_comm  = part[r.src] == part[r.dst]
        edge_col = "red"
        #edge_col   = comm_color[part[r.src]] if same_comm else "red"
        opacity    = 1.0 if same_comm else 0.5

        folium.PolyLine(
            [[lat1, lon1], [lat2, lon2]],
            weight=edge_weight(r.trade),
            color=edge_col,
            opacity=opacity
        ).add_to(fg)

    return fg


# ------------------------------------------------------------
# ★★ 多年份滑块视图（进出口总额版）
# ------------------------------------------------------------
def main_focus_slider(focus_iso: str = "CHN",
                      year_start: int = 2020,
                      year_end: int = 2024,
                      top_k: int = TOP_K_EDGE):
    """
    为 focus_iso 生成多年份滑块视图：仅展示其所在社区及关键跨社区连接
    """
    m = folium.Map(location=[20, 0], zoom_start=2, tiles='cartodbpositron')
    names = {}

    for y in tqdm(range(year_start, year_end + 1),
                  desc=f"Building focus layers for {focus_iso}"):
        df = load_year(FILE_CSV, y)           # 单向出口表
        try:
            fg = build_focus_layer(df, year=y,
                                   focus_iso=focus_iso,
                                   top_k=top_k)
        except ValueError:
            print(f"⚠️ {focus_iso} not found in year {y}, skipping...")
            continue
        fg.add_to(m)
        names[y] = fg.get_name()

    # 滑块
    m.add_child(YearSlider(names, year_start, year_end))

    # 标题
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>"
        f"Trade Focus: {focus_iso} ({year_start}–{year_end})</h4>"
    ))

    # 图例
    legend = f"""
        <div style="position: fixed; top: 50px; right: 80px; z-index: 9999;
                    background: white; padding: 10px 14px; border-radius: 6px;
                    font-size: 13px; box-shadow: 0 0 8px rgba(0,0,0,.3);">
            <b>Legend</b><br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:red;margin-right:4px;"></span>
            Cross-community links from {focus_iso}
        </div>
    """
    #m.get_root().html.add_child(folium.Element(legend))

    out = f"no_outsider_focus_slider_{focus_iso}_{year_start}_{year_end}.html"
    m.save(out)
    print("✅ Saved →", out)


In [19]:

if __name__ == "__main__":
    countries = ["CHN","USA"] #, "USA", "DEU", "JPN", "KOR"
    year_start = 1948
    year_end = 2024
    top_k = 2000  # 可调节显示边数

    for country in countries:
        print(f"Generating map for {country}...")
        main_focus_slider(focus_iso=country, year_start=year_start, year_end=year_end, top_k=top_k)


Generating map for CHN...


  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_na

✅ Saved → no_outsider_focus_slider_CHN_1948_2024.html
Generating map for USA...


  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_name, 20)
  cmap = cm.get_cmap(cmap_na

✅ Saved → no_outsider_focus_slider_USA_1948_2024.html


# 某国前N大贸易伙伴

In [2]:
def build_topn_partners_layer(
    df_year: pd.DataFrame,          # ← 当年“完整”数据
    year: int,
    focus_iso: str = "CHN",
    top_n: int = 10,
    global_top_k: int | None = None # 若担心分区太慢，可只用最粗 k 条边做社区
) -> folium.FeatureGroup:
    """
    先对 *全局* 贸易网 Louvain 划分 → 得到每个国家的社区，
    然后仅绘制 focus_iso 的前 n 大出口伙伴及相应边。
    """
    # ---------- 1) 全局社区划分 ----------
    df_part = df_year if global_top_k is None else df_year.nlargest(global_top_k, "value")
    part    = detect_communities(df_part)      # {ISO: comm_id}

    # ---------- 2) 抽取 focus_iso → 前 n 边 ----------
    df_focus = df_year[df_year.src == focus_iso].nlargest(top_n, "value")
    if df_focus.empty:
        raise ValueError(f"{focus_iso} has no exports in year {year}")

    # ---------- 3) 调色板 ----------
    comm_ids   = sorted(set(part.values()))
    palette    = gen_palette(len(comm_ids))
    comm_color = {cid: palette[i] for i, cid in enumerate(comm_ids)}

    # ---------- 4) 节点半径依据全局进/出口总额 ----------
    out_sum = df_year.groupby("src")["value"].sum()
    in_sum  = df_year.groupby("dst")["value"].sum()

    # ---------- 5) 连边粗细 ----------
    vmin, vmax   = df_focus["value"].min(), df_focus["value"].max()
    norm         = colors.Normalize(vmin=max(vmin, 1), vmax=vmax)
    edge_weight  = lambda val: 0.6 + 3 * norm(val)

    fg = folium.FeatureGroup(name=f"{year} Top {top_n} partners of {focus_iso}", show=True)

    # ---------- 6) 绘节点 ----------
    draw_nodes = set(df_focus["dst"]).union({focus_iso})
    for iso in draw_nodes:
        if iso not in coord or iso not in part:
            continue
        lat, lon = coord[iso]
        if pd.isna(lat):                         # 无坐标
            continue
        total   = out_sum.get(iso, 0) + in_sum.get(iso, 0)
        radius  = max(3, (total ** 0.5) / 150)
        color   = comm_color[part[iso]]

        folium.CircleMarker(
            [lat, lon], radius=radius,
            color=color, weight=2,
            fill=True, fill_color=color, fill_opacity=0.75,
            tooltip=f"{iso} (comm {part[iso]})"
        ).add_to(fg)

    # ---------- 7) 绘边 ----------
    for _, r in df_focus.iterrows():
        if r.dst not in coord or r.src not in coord:
            continue
        lat1, lon1 = coord[r.src]
        lat2, lon2 = coord[r.dst]
        color      = comm_color[part[r.dst]]    # ← 伙伴国的社区色

        folium.PolyLine(
            [[lat1, lon1], [lat2, lon2]],
            weight=edge_weight(r.value),
            color=color,
            opacity=1.0
        ).add_to(fg)

    return fg


def main_topn_partners_slider(
    focus_iso: str = "CHN",
    year_start: int = 2010,
    year_end: int = 2024,
    top_n: int = 10
):
    """
    多年份滑动视图：每年显示 focus_iso 的前 n 个贸易伙伴，
    连边颜色为对方国家的社区色，节点颜色为各自社区色。
    """
    m = folium.Map(location=[20, 0], zoom_start=2, tiles='cartodbpositron')
    names = {}

    for y in tqdm(range(year_start, year_end + 1), desc=f"Building top {top_n} partners for {focus_iso}"):
        df = load_year(FILE_CSV, y)
        try:
            fg = build_topn_partners_layer(df, year=y, focus_iso=focus_iso, top_n=top_n)
        except ValueError:
            print(f"⚠️ {focus_iso} not found in year {y}, skipping...")
            continue
        fg.add_to(m)
        names[y] = fg.get_name()

    # 添加年份滑块
    m.add_child(YearSlider(names, year_start, year_end))

    # 添加标题和图例
    m.get_root().html.add_child(folium.Element(
        f"<h4 style='text-align:center;font-size:20px'>Top {top_n} Trade Partners of {focus_iso} ({year_start}–{year_end})</h4>"
    ))

    legend = f"""
        <div style="position: fixed; top: 50px; right: 80px; z-index: 9999;
                    background: white; padding: 10px 14px; border-radius: 6px;
                    font-size: 13px; box-shadow: 0 0 8px rgba(0,0,0,.3);">
            <b>Legend</b><br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:#666;margin-right:4px;border:1px solid #666;"></span>
            Nodes: community color<br>
            <span style="display:inline-block;width:12px;height:12px;
                         background:blue;margin-right:4px;"></span>
            Edge: partner’s community color
        </div>
    """
    m.get_root().html.add_child(folium.Element(legend))

    out = f"topn_slider_{focus_iso}_{year_start}_{year_end}.html"
    m.save(out)
    print(f"✅ Saved → {out}")


NameError: name 'pd' is not defined

In [1]:
if __name__ == "__main__":
    main_topn_partners_slider(focus_iso="CHN", year_start=1948, year_end=2024, top_n=25)


NameError: name 'main_topn_partners_slider' is not defined