# Q1 - 描述性分析（贸易数据）
本笔记本完成：读取 `raw` 下的贸易与参考文件，生成近三年（2016-2018）的描述性统计，匹配产品与国家代码，计算中美的主要贸易伙伴，并基于国家形状文件计算中国重心与其它国家重心之间的距离并画出对数散点图。

注：如果运行环境缺少包，下面的单元会尝试安装 `geopandas`、`pyproj` 和绘图所需库。

In [1]:
# 导入常用库并读取三年贸易数据
import sys
import importlib
import os
import glob
from pathlib import Path
import pandas as pd
import numpy as np
import geopandas as gpd
from pyproj import Geod
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

sns.set(style='whitegrid', context='notebook')

# 自动解析仓库根目录，兼容从不同工作目录运行
NOTEBOOK_DIR = Path.cwd().resolve()
PROJECT_ROOT_CANDIDATES = [
    NOTEBOOK_DIR,
    NOTEBOOK_DIR.parent,
    NOTEBOOK_DIR.parent.parent,
]
PROJECT_ROOT = None
for candidate in PROJECT_ROOT_CANDIDATES:
    if (candidate / 'raw' / 'trade_data').exists():
        PROJECT_ROOT = candidate
        break
if PROJECT_ROOT is None:
    raise FileNotFoundError('无法定位 raw/trade_data 目录，请确认当前工作目录位于仓库内。')

DATA_DIR = PROJECT_ROOT / 'raw' / 'trade_data'
SHAPE_DIR = PROJECT_ROOT / 'raw' / 'countries_shapefile'
print('Using project root:', PROJECT_ROOT)


def normalize_numeric_code(value: object) -> str | None:
    """将 ISO 数字代码统一为 3 位字符串（保留字母代码以便后续匹配）。"""
    if value is None or (isinstance(value, float) and np.isnan(value)):
        return None
    text = str(value).strip()
    if text == '':
        return None
    cleaned = text.replace(',', '')
    try:
        num = int(float(cleaned))
        return f"{num:03d}"
    except ValueError:
        return cleaned.upper()


# 定位 2016-2018 的 BACI CSV
candidates = [
    DATA_DIR / 'baci_hs12_y2016_v202001.csv',
    DATA_DIR / 'baci_HS12_y2017_v202001.csv',
    DATA_DIR / 'baci_hs12_y2018_v202001.csv',
]
files = [f for f in candidates if f.exists()]
if not files:
    files = [
        Path(p) for p in glob.glob(str(DATA_DIR / '*.csv'))
        if 'baci' in Path(p).name.lower()
        and 'country_codes' not in Path(p).name.lower()
        and 'product_codes' not in Path(p).name.lower()
    ]
if not files:
    raise FileNotFoundError(f'未在 {DATA_DIR} 下找到 BACI 贸易数据 CSV，请确认文件已解压。')
print('Using trade files:', files)

# 读取并清洗
frames: list[pd.DataFrame] = []
for fpath in tqdm(files, desc='Reading trade CSVs'):
    print('Reading', fpath)
    d = pd.read_csv(fpath, dtype=str)
    d.columns = [c.strip().lower() for c in d.columns]
    colmap = {}
    for c in d.columns:
        if c in ['t', 'year', 'y']:
            colmap[c] = 't'
        if c in ['i', 'exp', 'exporter', 'iso_o', 'iso_o3']:
            colmap[c] = 'i'
        if c in ['j', 'imp', 'importer', 'iso_d', 'iso_d3']:
            colmap[c] = 'j'
        if c in ['k', 'hs6', 'prod', 'product']:
            colmap[c] = 'k'
        if c in ['v', 'value', 'trade_value', 'trade_value_usd']:
            colmap[c] = 'v'
        if c in ['q', 'qty', 'quantity']:
            colmap[c] = 'q'
    d = d.rename(columns=colmap)
    for required in ['t', 'i', 'j', 'k', 'v']:
        if required not in d.columns:
            d[required] = np.nan
    d['k'] = d['k'].astype(str).str.zfill(6)
    for numeric_col in ['v', 'q']:
        if numeric_col in d.columns:
            d[numeric_col] = (
                d[numeric_col]
                .astype(str)
                .str.replace(',', '')
                .replace({'': '0', 'nan': '0'})
            )
            d[numeric_col] = pd.to_numeric(d[numeric_col], errors='coerce').fillna(0.0)
        else:
            d[numeric_col] = 0.0
    frames.append(d[['t', 'i', 'j', 'k', 'v', 'q']])

trade = pd.concat(frames, ignore_index=True, sort=False)
trade['t'] = pd.to_numeric(trade['t'], errors='coerce')
trade = trade[trade['t'].isin([2016, 2017, 2018])].copy()
trade['exporter_code'] = trade['i'].apply(normalize_numeric_code)
trade['importer_code'] = trade['j'].apply(normalize_numeric_code)
trade['hs6'] = trade['k'].astype(str).str.zfill(6)
trade['hs2'] = trade['hs6'].str[:2]
trade = trade.dropna(subset=['t', 'exporter_code', 'importer_code', 'hs6'])
trade = trade.drop_duplicates()
trade['q'] = trade['q'].fillna(0.0)
trade = trade.reset_index(drop=True)
trade.shape

Using project root: /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject
Using trade files: [PosixPath('/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_hs12_y2016_v202001.csv'), PosixPath('/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_HS12_y2017_v202001.csv'), PosixPath('/Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_hs12_y2018_v202001.csv')]


Reading trade CSVs:   0%|          | 0/3 [00:00<?, ?it/s]

Reading /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_hs12_y2016_v202001.csv
Reading /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_HS12_y2017_v202001.csv
Reading /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_HS12_y2017_v202001.csv
Reading /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_hs12_y2018_v202001.csv
Reading /Users/kaibiaozhu/Documents/GitHub/course5020-finalproject/raw/trade_data/baci_hs12_y2018_v202001.csv


(24025656, 10)

In [2]:
# 基本数据质量检查与清洗记录
print(f'Total observations after filtering: {len(trade):,}')
print('Rows per year:')
print(trade['t'].value_counts().sort_index())

key_cols = ['t', 'exporter_code', 'importer_code', 'hs6', 'v', 'q']
missing_summary = trade[key_cols].isna().sum()
print('\nMissing values in key columns:')
print(missing_summary)

neg_value_rows = (trade['v'] < 0).sum()
if neg_value_rows:
    print(f'Warning: {neg_value_rows} rows have negative trade values; please investigate.')
else:
    print('No negative trade values detected.')

dup_keys = ['t', 'exporter_code', 'importer_code', 'hs6']
dup_count = trade.duplicated(subset=dup_keys).sum()
print(f'Duplicated rows on (year, exporter, importer, hs6): {dup_count}')

Total observations after filtering: 24,025,656
Rows per year:
t
2016    7892508
2017    8132873
2018    8000275
Name: count, dtype: int64

Missing values in key columns:
t                0
exporter_code    0
importer_code    0
hs6              0
v                0
q                0
dtype: int64
No negative trade values detected.

Missing values in key columns:
t                0
exporter_code    0
importer_code    0
hs6              0
v                0
q                0
dtype: int64
No negative trade values detected.
Duplicated rows on (year, exporter, importer, hs6): 0
Duplicated rows on (year, exporter, importer, hs6): 0


In [3]:
# 1) 近三年每个国家的贸易伙伴数量（出口与进口合并）
pairs = trade[['exporter_code', 'importer_code']].dropna()
from collections import defaultdict
partners = defaultdict(set)
for _, row in pairs.iterrows():
    a = row['exporter_code']
    b = row['importer_code']
    partners[a].add(b)
    partners[b].add(a)

partner_counts = pd.DataFrame(
    [{'country_code': code, 'n_partners': len(counter)} for code, counter in partners.items()]
).sort_values('n_partners', ascending=False).reset_index(drop=True)

lookup = globals().get('country_lookup')
if lookup is not None:
    partner_counts = partner_counts.merge(
        lookup[['code', 'name']], left_on='country_code', right_on='code', how='left'
    )
    partner_counts['display'] = partner_counts.apply(
        lambda r: f"{r['country_code']} ({r['name']})" if pd.notna(r['name']) else r['country_code'], axis=1
    )
else:
    partner_counts['display'] = partner_counts['country_code']

print('Top 10 - 伙伴数量最多的国家（2016-2018 合计）:')
print(partner_counts.head(10)[['display', 'n_partners']].to_string(index=False))
print('\nBottom 10 - 伙伴数量最少的国家（2016-2018 合计）:')
bottom = partner_counts.tail(10).sort_values('n_partners')
print(bottom[['display', 'n_partners']].to_string(index=False))

Top 10 - 伙伴数量最多的国家（2016-2018 合计）:
display  n_partners
    764         220
    251         220
    616         220
    528         220
    276         220
    381         220
    724         220
    826         219
    203         219
    058         219

Bottom 10 - 伙伴数量最少的国家（2016-2018 合计）:
display  n_partners
    530           1
    535          36
    666          43
    652          46
    534          50
    162          50
    574          51
    260          52
    876          54
    583          56


In [4]:
# 小规模测试：先检测国家/产品映射文件的编码并预览少量数据
ctry_map_path = DATA_DIR / 'country_codes_v202001.csv'
prod_map_path = DATA_DIR / 'product_codes_hs12_v202001.csv'
COUNTRY_ENCODING = None
PRODUCT_ENCODING = None


def detect_encoding(path: Path, label: str, test_rows: int = 100):
    if not path.exists():
        print(f'{label} 文件缺失: {path}')
        return None, None
    encodings = ['utf-8', 'utf-8-sig', 'latin1', 'ISO-8859-1']
    last_error = None
    for enc in encodings:
        try:
            sample = pd.read_csv(path, dtype=str, nrows=test_rows, encoding=enc)
            print(f'{label} 试探读取成功，encoding={enc}, preview rows={len(sample)}')
            return enc, sample
        except UnicodeDecodeError as exc:
            last_error = exc
            print(f'{label} 读取失败（encoding={enc}）：{exc}')
    raise RuntimeError(f'{label} 无法用候选编码解析，最后错误: {last_error}')


COUNTRY_ENCODING, country_sample = detect_encoding(ctry_map_path, 'Country codes')
PRODUCT_ENCODING, product_sample = detect_encoding(prod_map_path, 'Product codes')

print('\nCountry sample preview:')
if country_sample is not None:
    print(country_sample.head())

print('\nProduct sample preview:')
if product_sample is not None:
    print(product_sample.head())

print('\n编码测试完成，如上预览无报错即可继续执行下一单元。')

Country codes 读取失败（encoding=utf-8）：'utf-8' codec can't decode byte 0xf4 in position 1: invalid continuation byte
Country codes 读取失败（encoding=utf-8-sig）：'utf-8' codec can't decode byte 0xf4 in position 4141: invalid continuation byte
Country codes 试探读取成功，encoding=latin1, preview rows=100
Product codes 试探读取成功，encoding=utf-8, preview rows=100

Country sample preview:
  country_code country_name_abbreviation country_name_full iso_2digit_alpha  \
0            4               Afghanistan       Afghanistan               AF   
1            8                   Albania           Albania               AL   
2           12                   Algeria           Algeria               DZ   
3           16            American Samoa    American Samoa               AS   
4           20                   Andorra           Andorra               AD   

  iso_3digit_alpha  
0              AFG  
1              ALB  
2              DZA  
3              ASM  
4              AND  

Product sample preview:
    cod

In [None]:
# 2) 匹配国家 / 产品代码并生成中美的伙伴统计
if 'COUNTRY_ENCODING' not in globals():
    COUNTRY_ENCODING = 'utf-8'
else:
    COUNTRY_ENCODING = COUNTRY_ENCODING or 'utf-8'
if 'PRODUCT_ENCODING' not in globals():
    PRODUCT_ENCODING = 'utf-8'
else:
    PRODUCT_ENCODING = PRODUCT_ENCODING or 'utf-8'

ctry_df = pd.read_csv(ctry_map_path, dtype=str, encoding=COUNTRY_ENCODING) if ctry_map_path.exists() else None
prod_df = pd.read_csv(prod_map_path, dtype=str, encoding=PRODUCT_ENCODING) if prod_map_path.exists() else None
print('Country codes loaded:', ctry_df.shape if ctry_df is not None else 'Missing file')
print('Product codes loaded:', prod_df.shape if prod_df is not None else 'Missing file')


def build_country_lookup(df: pd.DataFrame | None) -> pd.DataFrame | None:
    if df is None:
        return None
    df = df.copy()
    df.columns = [c.strip().lower() for c in df.columns]
    name_cols = [c for c in df.columns if any(k in c for k in ['name', 'label'])]
    id_cols = [c for c in df.columns if any(k in c for k in ['iso', 'code', 'num'])]
    if not name_cols or not id_cols:
        return None
    name_col = name_cols[0]
    id_col = id_cols[0]
    df['code'] = df[id_col].apply(normalize_numeric_code)
    df['name'] = df[name_col].str.strip()
    df = df.dropna(subset=['code']).drop_duplicates('code')
    df['name_key'] = df['name'].str.upper().str.strip()
    return df[['code', 'name', 'name_key']]


def build_product_lookup(df: pd.DataFrame | None) -> tuple[dict, pd.DataFrame] | tuple[None, None]:
    if df is None:
        return None, None
    df = df.copy()
    df.columns = [c.strip().lower() for c in df.columns]
    code_cols = [c for c in df.columns if 'code' in c or 'hs' in c]
    desc_cols = [c for c in df.columns if any(k in c for k in ['desc', 'product', 'label'])]
    if not code_cols or not desc_cols:
        return None, None
    code_col = code_cols[0]
    desc_col = desc_cols[0]
    df['hs6'] = df[code_col].astype(str).str.zfill(6)
    df['hs2'] = df['hs6'].str[:2]
    df = df.rename(columns={desc_col: 'product_desc'})
    hs2_desc_map = df[['hs2', 'product_desc']].dropna().drop_duplicates('hs2').set_index('hs2')['product_desc'].to_dict()
    return hs2_desc_map, df[['hs6', 'hs2', 'product_desc']]


country_lookup = build_country_lookup(ctry_df)
hs2_desc_map, product_lookup_df = build_product_lookup(prod_df)

if country_lookup is not None:
    exporter_lookup = country_lookup.rename(columns={'code': 'exporter_code', 'name': 'exporter_name', 'name_key': 'exporter_name_key'})
    importer_lookup = country_lookup.rename(columns={'code': 'importer_code', 'name': 'importer_name', 'name_key': 'importer_name_key'})
    trade = trade.merge(exporter_lookup[['exporter_code', 'exporter_name']], on='exporter_code', how='left')
    trade = trade.merge(importer_lookup[['importer_code', 'importer_name']], on='importer_code', how='left')
else:
    trade['exporter_name'] = trade['exporter_code']
    trade['importer_name'] = trade['importer_code']

if product_lookup_df is not None:
    trade = trade.merge(
        product_lookup_df[['hs6', 'product_desc']], on='hs6', how='left'
    )
    trade['hs2_desc'] = trade['hs2'].map(hs2_desc_map)
else:
    trade['product_desc'] = trade['hs6']
    trade['hs2_desc'] = trade['hs2']

# 辅助函数：根据名称查找国家代码
def find_code_by_name(keyword: str, default: str) -> str:
    if country_lookup is None:
        return default
    mask = country_lookup['name'].str.contains(keyword, case=False, na=False)
    if mask.any():
        return country_lookup.loc[mask, 'code'].iloc[0]
    return default

china_code = find_code_by_name('China', '156')
usa_code = find_code_by_name('United States', '840')
japan_code = find_code_by_name('Japan', '392')
print('Detected codes -> China:', china_code, 'USA:', usa_code, 'Japan:', japan_code)

# 整体贸易额（价值与数量）
total_trade_value = trade['v'].sum()
total_trade_qty = trade['q'].sum()
print(f"Total trade value 2016-2018: {total_trade_value:,.2f}")
print(f"Total trade quantity 2016-2018: {total_trade_qty:,.2f}")


def summarize_top_partners(exporter_code: str, label: str, top_n: int = 10) -> pd.DataFrame:
    subset = trade[trade['exporter_code'] == exporter_code]
    grouped = (
        subset.groupby(['importer_code', 'importer_name'], dropna=False)
        .agg(export_value=('v', 'sum'), export_quantity=('q', 'sum'))
        .sort_values('export_value', ascending=False)
        .head(top_n)
        .reset_index()
    )
    print(f"\n{label} - Top {top_n} export partners by trade value:")
    print(grouped.to_string(index=False))
    return grouped


china_top_partners = summarize_top_partners(china_code, 'China')
usa_top_partners = summarize_top_partners(usa_code, 'USA')

china_exports = trade[trade['exporter_code'] == china_code]
china_flow_cols = ['exporter_name', 'importer_name', 'importer_code', 'k', 'product_desc', 't']
china_flow_summary = (
    china_exports.groupby(china_flow_cols, dropna=False)
    .agg(export_value=('v', 'sum'), export_quantity=('q', 'sum'))
    .sort_values('export_value', ascending=False)
    .head(5)
    .reset_index()
)
print('\nChina - Top 5 product-year bilateral flows (exports only):')
print(china_flow_summary.to_string(index=False))

In [None]:
# 3) 各国出口额 Top10 产品（HS2）以及整体按价值/数量 Top10 商品
exports_hs2 = (
    trade.groupby(['exporter_code', 'exporter_name', 'hs2', 'hs2_desc'], dropna=False)
    .agg(total_value=('v', 'sum'), total_quantity=('q', 'sum'))
    .reset_index()
)


def top_products_for_country(code: str, label: str, n: int = 10) -> pd.DataFrame:
    subset = exports_hs2[exports_hs2['exporter_code'] == code]
    ranked = subset.sort_values('total_value', ascending=False).head(n)
    print(f"\n{label} - Top {n} HS2 export categories (by value):")
    display_cols = ['hs2', 'hs2_desc', 'total_value', 'total_quantity']
    print(ranked[display_cols].to_string(index=False))
    return ranked


top_products_for_country(china_code, 'China')
top_products_for_country(japan_code, 'Japan')
top_products_for_country(usa_code, 'USA')

overall_by_value = (
    trade.groupby(['hs2', 'hs2_desc'], dropna=False)
    .agg(total_value=('v', 'sum'))
    .sort_values('total_value', ascending=False)
    .head(10)
)
overall_by_quantity = (
    trade.groupby(['hs2', 'hs2_desc'], dropna=False)
    .agg(total_quantity=('q', 'sum'))
    .sort_values('total_quantity', ascending=False)
    .head(10)
)

print('\nOverall top 10 HS2 categories by trade value:')
print(overall_by_value.reset_index().to_string(index=False))
print('\nOverall top 10 HS2 categories by trade quantity:')
print(overall_by_quantity.reset_index().to_string(index=False))

In [None]:
# 4) 计算中国与所有国家重心的距离，并绘制 log-log 散点图
shp_files = list((SHAPE_DIR).glob('*.shp'))
if not shp_files:
    print('找不到 shapefile，请确认路径:', SHAPE_DIR)
else:
    shp_path = shp_files[0]
    gdf = gpd.read_file(shp_path)
    print('Loaded shapefile columns:', gdf.columns.tolist())
    name_cols = [c for c in gdf.columns if 'name' in c.lower() or 'country' in c.lower()]
    name_col = name_cols[0] if name_cols else gdf.columns[0]
    iso_cols = [c for c in gdf.columns if any(k in c.lower() for k in ['iso', 'code', 'num'])]
    gdf_cent = gdf.to_crs(epsg=4326).copy()
    gdf_cent['lon'] = gdf_cent.geometry.centroid.x
    gdf_cent['lat'] = gdf_cent.geometry.centroid.y
    gdf_cent['name_key'] = gdf_cent[name_col].astype(str).str.upper().str.strip()
    if iso_cols:
        gdf_cent['shp_code'] = gdf_cent[iso_cols[0]].apply(normalize_numeric_code)
    else:
        gdf_cent['shp_code'] = None

    # 锁定中国重心
    china_mask = gdf_cent['name_key'].str.contains('CHINA', case=False, na=False)
    if not china_mask.any():
        print('警告：在 shapefile 中未找到 China 关键字，使用第一条记录作为近似。')
    china_row = gdf_cent[china_mask].iloc[0] if china_mask.any() else gdf_cent.iloc[0]
    china_lon, china_lat = float(china_row['lon']), float(china_row['lat'])
    geod = Geod(ellps='WGS84')

    def gc_distance(row):
        _, _, meters = geod.inv(china_lon, china_lat, float(row['lon']), float(row['lat']))
        return meters / 1000.0

    gdf_cent['dist_km_to_china'] = gdf_cent.apply(gc_distance, axis=1)
    shp_lookup = gdf_cent[[name_col, 'name_key', 'lon', 'lat', 'dist_km_to_china', 'shp_code']].rename(columns={name_col: 'shp_name'})

    if country_lookup is not None:
        shp_lookup = shp_lookup.merge(
            country_lookup[['code', 'name', 'name_key']], on='name_key', how='left', suffixes=('_shp', '')
        )
        shp_lookup['iso_code'] = shp_lookup['code'].fillna(shp_lookup['shp_code'])
        shp_lookup['country_name'] = shp_lookup['name'].fillna(shp_lookup['shp_name'])
    else:
        shp_lookup['iso_code'] = shp_lookup['shp_code']
        shp_lookup['country_name'] = shp_lookup['shp_name']

    print('\nSample distances from China to other centroids (km):')
    print(shp_lookup[['country_name', 'dist_km_to_china']].dropna().head(10).to_string(index=False))

    china_exports_full = (
        trade[trade['exporter_code'] == china_code]
        .groupby(['importer_code', 'importer_name'], dropna=False)
        .agg(export_value=('v', 'sum'), export_quantity=('q', 'sum'))
        .reset_index()
    )

    china_with_dist = china_exports_full.merge(
        shp_lookup[['iso_code', 'dist_km_to_china']], left_on='importer_code', right_on='iso_code', how='left'
    )
    matched = china_with_dist['dist_km_to_china'].notna().sum()
    print(f"Matched distances for {matched} China export partners out of {len(china_with_dist)}.")

    plot_df = china_with_dist[(china_with_dist['dist_km_to_china'] > 0) & (china_with_dist['export_value'] > 0)].copy()
    plot_df['log_dist'] = np.log(plot_df['dist_km_to_china'])
    plot_df['log_value'] = np.log(plot_df['export_value'])

    plt.figure(figsize=(8, 5))
    sns.scatterplot(data=plot_df, x='log_dist', y='log_value', hue='importer_name', legend=False)
    plt.xlabel('log(distance km)')
    plt.ylabel('log(export value)')
    plt.title('China exports: log(distance) vs log(value)')
    plt.show()

    qty_df = plot_df[plot_df['export_quantity'] > 0].copy()
    qty_df['log_quantity'] = np.log(qty_df['export_quantity'])
    if not qty_df.empty:
        plt.figure(figsize=(8, 5))
        sns.scatterplot(data=qty_df, x='log_dist', y='log_quantity', hue='importer_name', legend=False)
        plt.xlabel('log(distance km)')
        plt.ylabel('log(export quantity)')
        plt.title('China exports: log(distance) vs log(quantity)')
        plt.show()
    else:
        print('No positive export quantities available for log quantity scatter plot.')