# A/H 溢价季度策略回测分析

本 Notebook 参考 `ah_premium_quarterly_bt.py`，分步实现 A/H 溢价季度策略回测，便于交互分析和结果展示。

In [None]:
# 导入必要库
import duckdb
from qs.backtester.data import Bar, DataFeed
from qs.backtester.broker import Broker
from qs.backtester.engine import BacktestEngine
from qs.backtester.stats import (
    compute_annual_returns,
    compute_max_drawdown,
    compute_risk_metrics,
)
from qs.strategy.ah_premium_quarterly import AHPremiumQuarterlyStrategy
import matplotlib.pyplot as plt

In [None]:
# Global font setup: choose a CJK-capable font so Chinese labels render,
# keep the minus sign displayable, and fall back step by step when no such
# font is registered with matplotlib.
import matplotlib
from matplotlib import font_manager, rcParams
import os
import warnings

# Optional: silence "glyph missing" warnings if no suitable font is found.
# warnings.filterwarnings('ignore', 'Glyph .* missing from font')

# Font family names to try, in priority order.
preferred_names = [
    "Microsoft YaHei",
    "SimHei",
    "Arial Unicode MS",
    "Noto Sans CJK SC",
    "Source Han Sans CN",
    "Source Han Sans SC",
]


def _first_registered_font(names):
    """Return the first entry of *names* registered with matplotlib, else None."""
    registered = {f.name for f in font_manager.fontManager.ttflist}
    return next((n for n in names if n in registered), None)


def _try_add_font(path):
    """Register the font file at *path*; report a failure instead of hiding it."""
    try:
        font_manager.fontManager.addfont(path)
    except Exception as e:  # best effort: a broken font file must not stop the notebook
        print("[字体加载失败]", path, e)


# 1. Try fonts that are already registered.
selected = _first_registered_font(preferred_names)

# 2. Otherwise register common Windows font files by path and look again.
if selected is None:
    # Raw strings need single backslashes; doubling them inside r"..." puts
    # two literal backslashes into the path.
    candidate_files = [
        r"C:\Windows\Fonts\msyh.ttc",  # Microsoft YaHei
        r"C:\Windows\Fonts\msyhl.ttc",  # Microsoft YaHei Light
        r"C:\Windows\Fonts\simhei.ttf",  # SimHei
        r"C:\Windows\Fonts\msyh.ttf",  # older .ttf variant
    ]
    for fp in candidate_files:
        if os.path.exists(fp):
            _try_add_font(fp)
    selected = _first_registered_font(preferred_names)

# 3. Last fallback: a Noto Sans CJK file placed under fonts/ by the user.
#    Nothing is downloaded automatically; only a hint is printed.
if selected is None:
    print(
        "[字体提示] 未找到常见中文字体，可下载 Noto Sans CJK 放入 fonts/ 目录, 如: fonts/NotoSansCJKsc-Regular.otf"
    )
    local_noto = "fonts/NotoSansCJKsc-Regular.otf"
    if os.path.exists(local_noto):
        _try_add_font(local_noto)
        selected = _first_registered_font(preferred_names)

# 4. Apply rcParams.
if selected:
    rcParams["font.sans-serif"] = [selected]
else:
    # Leave the existing sans-serif list untouched, but warn the user.
    print("[警告] 仍使用默认字体，中文可能乱码。")
rcParams["axes.unicode_minus"] = False
print("Using font family:", rcParams.get("font.sans-serif", []))

In [None]:
# Load the trading calendar: one bar per distinct trade_date found in either
# the A-share or the H-share daily table, from start_date onwards.
start_date = "20180101"
db_path = "data/data.duckdb"
# The OHLC columns are MIN() aggregates over every symbol trading that day and
# pct_chg is NULL, so these bars carry no single instrument's prices.
# NOTE(review): presumably only trade_date matters to the engine here - confirm.
# start_date is bound as a query parameter instead of being pasted into the SQL.
q = """
SELECT trade_date,
       MIN(open) AS open,
       MIN(high) AS high,
       MIN(low) AS low,
       MIN(close) AS close,
       NULL AS pct_chg
FROM (
  SELECT trade_date, open, high, low, close FROM daily_a
  UNION ALL
  SELECT trade_date, open, high, low, close FROM daily_h
)
WHERE trade_date >= ?
GROUP BY 1
ORDER BY 1
"""
con = duckdb.connect(db_path, read_only=True)
try:
    rows = con.execute(q, [start_date]).fetchall()
finally:
    # Release the database handle even when the query fails.
    con.close()
bars = [Bar(*r) for r in rows]
feed = DataFeed(bars)
print(f"Loaded {len(bars)} trading days from {start_date}")

In [None]:
# Backtest parameters; the cell that builds the broker and strategy reads these.
TOP_K = 5  # passed to the strategy as top_k
BOTTOM_K = 5  # passed to the strategy as bottom_k
START_DATE = "20180101"  # strategy start date, also the lower bound for index queries
INIT_CASH = 1_000_000  # starting cash handed to the broker
CAPITAL_SPLIT = 0.5  # fraction of capital allocated to H shares

In [None]:
# Build the backtest components: broker, strategy, then the engine tying the
# data feed to both.
# NOTE(review): later cells read broker.trades; confirm the broker still
# records trades when enable_trade_log=False.
strategy_params = {
    "top_k": TOP_K,
    "bottom_k": BOTTOM_K,
    "start_date": START_DATE,
    "capital_split": CAPITAL_SPLIT,
}
broker = Broker(cash=INIT_CASH, enable_trade_log=False)
strat = AHPremiumQuarterlyStrategy(**strategy_params)
engine = BacktestEngine(feed, broker, strat)

In [None]:
# Run the backtest engine; the returned curve is what every later statistics
# and plotting cell consumes.
curve = engine.run()
print(f"回测完成，曲线长度: {len(curve)}")

In [None]:
# Annual returns computed from the equity curve, printed one line per key as a
# percentage with two decimals.
ann = compute_annual_returns(curve)
print("年度收益率:")
for year, annual_return in ann.items():
    print(f"  {year}: {annual_return:.2%}")

In [None]:
# Maximum drawdown of the equity curve. The helper's result is unpacked into
# three values: the drawdown (printed as a percentage) and two markers printed
# as the "from" and "to" points - presumably the peak and trough dates.
max_dd, dd_peak, dd_trough = compute_max_drawdown(curve)
print(f"最大回撤: {max_dd:.2%} from {dd_peak} to {dd_trough}")

In [None]:
# Risk metrics for the equity curve. Keys ending in "Rate" and the named
# ratio-style keys print with four decimals; every other key prints as-is.
risk = compute_risk_metrics(curve, INIT_CASH)
print("风险指标:")
four_decimal_keys = ("CAGR", "AnnReturn", "AnnVol", "Sharpe")
for metric, value in risk.items():
    if not (metric.endswith("Rate") or metric in four_decimal_keys):
        print(f"  {metric}: {value}")
        continue
    print(f"  {metric}: {value:.4f}")

## 策略 vs 三大指数归一化收益曲线对比

下方代码将查询 `index_daily` (000300.SH) 与 `index_global` (HSI, IXIC) 的收盘价，与策略净值一起按起始日期归一化后绘制对比曲线。

In [None]:
# Plot the strategy's equity curve against three indices, each normalized to
# 1.0 at its own first observation.
import pandas as pd

# 1. Read index closes. The connection is closed even if a query fails.
con = duckdb.connect("data/data.duckdb", read_only=True)
try:
    idx_sh300 = con.execute(
        "SELECT trade_date, close FROM index_daily WHERE ts_code='000300.SH' AND trade_date >= ? ORDER BY trade_date",
        [START_DATE],
    ).fetchdf()
    idx_hsi = con.execute(
        "SELECT trade_date, close FROM index_global WHERE ts_code='HSI' AND trade_date >= ? ORDER BY trade_date",
        [START_DATE],
    ).fetchdf()
    idx_ixic = con.execute(
        "SELECT trade_date, close FROM index_global WHERE ts_code='IXIC' AND trade_date >= ? ORDER BY trade_date",
        [START_DATE],
    ).fetchdf()
finally:
    con.close()

# 2. Convert dates and normalize each index by its first close.
for df in (idx_sh300, idx_hsi, idx_ixic):
    df["date"] = pd.to_datetime(df["trade_date"])
    # An empty frame has no first row; dividing by 1.0 keeps "norm" empty.
    base = df["close"].iloc[0] if not df.empty else 1.0
    df["norm"] = df["close"] / base

# Strategy curve as a DataFrame, ordered by date.
curve_df = pd.DataFrame(
    {
        "date": [pd.to_datetime(b.trade_date) for b in curve],
        "equity": [b.equity for b in curve],
    }
).sort_values("date")
# Same empty-input guard as for the indices: an empty curve would otherwise
# raise IndexError on .iloc[0].
equity_base = curve_df["equity"].iloc[0] if not curve_df.empty else 1.0
curve_df["norm"] = curve_df["equity"] / equity_base

# 3. No date alignment: each series keeps its own span and matplotlib places
#    them on the shared date axis.

# 4. Plot.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(
    curve_df["date"], curve_df["norm"], label="策略净值", linewidth=1.6, color="black"
)
ax.plot(idx_sh300["date"], idx_sh300["norm"], label="沪深300 指数", linewidth=1.0)
ax.plot(idx_hsi["date"], idx_hsi["norm"], label="恒生指数", linewidth=1.0)
ax.plot(idx_ixic["date"], idx_ixic["norm"], label="纳斯达克指数", linewidth=1.0)

ax.set_title(f"A/H 溢价季度策略 vs 三大指数（自 {START_DATE} 起）")
ax.set_xlabel("日期")
ax.set_ylabel("归一化收益 (起点=1)")
ax.set_ylim(bottom=0)
ax.legend(loc="upper left", fontsize=9)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Fee summary over the broker's trade history: grand total, trade count,
# per-year totals and a short sample of the first records.
from collections import defaultdict

print(f"总费用(佣金+卖出税): {broker.total_fees:.2f}")
trade_log = broker.trades
print(f"成交笔数: {len(trade_log)}")
if trade_log:
    # Accumulate fees per year; the year is the first four characters of
    # each trade's trade_date.
    fees_by_year = defaultdict(float)
    for trade in trade_log:
        fees_by_year[trade.trade_date[:4]] += trade.fees
    print("年度费用:")
    for year, fee_total in sorted(fees_by_year.items()):
        print(f"  {year}: {fee_total:.2f}")
    # Show the first few records as examples.
    print("前5条交易记录(日期, 动作, 标的, 手数, 费用):")
    for trade in trade_log[:5]:
        print(
            f"  {trade.trade_date} {trade.action} {trade.symbol} size={int(trade.size)} fees={trade.fees:.2f}"
        )

In [None]:
# Full trade history, enriched with security names and the A/H premium
# computed from the closes of the trading day before each trade.
import pandas as pd
import duckdb


def _placeholders(n):
    """Return n comma-separated '?' markers for a parameterized SQL IN list."""
    return ",".join("?" * n)


def _fetch_close_and_adj(con, daily_table, adj_table, trade_date, codes):
    """Return {ts_code: (close, adj_factor)} for *codes* on *trade_date*.

    The table names come from fixed literals at the call sites; every value is
    bound as a parameter. Rows with a NULL close or adj_factor are dropped, and
    an empty *codes* list yields an empty dict instead of an invalid `IN ()`.
    """
    if not codes:
        return {}
    rows = con.execute(
        f"SELECT d.ts_code, d.close, a.adj_factor "
        f"FROM {daily_table} d JOIN {adj_table} a USING(ts_code,trade_date) "
        f"WHERE trade_date=? AND ts_code IN ({_placeholders(len(codes))})",
        [trade_date, *codes],
    ).fetchall()
    return {
        code: (float(close), float(adj))
        for code, close, adj in rows
        if close is not None and adj is not None
    }


def _fetch_fx_mid(con, ts_code, trade_date):
    """Return the bid/ask close midpoint for an FX pair, or None if unavailable."""
    row = con.execute(
        "SELECT (bid_close+ask_close)/2 FROM fx_daily WHERE ts_code=? AND trade_date=?",
        [ts_code, trade_date],
    ).fetchone()
    if row is None or row[0] is None:
        return None
    return float(row[0])


if broker.trades:
    # ---------------- Base trade table ----------------
    trades_df = pd.DataFrame(
        [
            {
                "date": t.trade_date,
                "action": t.action,
                "symbol": t.symbol,
                "price_signal": t.price,
                "price_exec": t.exec_price,
                "size": t.size,
                "gross_amount": t.gross_amount,
                "fees": t.fees,
                "cash_after": t.cash_after,
                "position_after": t.position_after,
                "equity_after": t.equity_after,
            }
            for t in broker.trades
        ]
    )
    trades_df["date"] = pd.to_datetime(trades_df["date"])

    # ---------------- A/H pair mapping ----------------
    pair_df = pd.read_csv("data/ah_codes.csv")
    # Tolerate the alternative column name: a "c" column stands in for hk_code.
    if "hk_code" not in pair_df.columns and "c" in pair_df.columns:
        pair_df = pair_df.rename(columns={"c": "hk_code"})
    # Required columns: cn_code, hk_code.
    missing_cols = [c for c in ("cn_code", "hk_code") if c not in pair_df.columns]
    if missing_cols:
        raise ValueError(f"ah_codes.csv 缺少必要列: {missing_cols}")
    pair_df = pair_df.dropna(subset=["cn_code", "hk_code"]).copy()
    a_codes = pair_df["cn_code"].unique().tolist()
    h_codes = pair_df["hk_code"].unique().tolist()

    # Per-code divisors read from the strategy; a code without an entry uses 1.0.
    # NOTE(review): if the strategy exposes no such maps, every divisor is 1.0 -
    # confirm the resulting premium is still meaningful in that case.
    max_adj_a = getattr(strat, "_max_adj_a", None) or {}
    max_adj_h = getattr(strat, "_max_adj_h", None) or {}

    trade_dates_str = trades_df["date"].dt.strftime("%Y%m%d")
    unique_trade_dates = sorted(trade_dates_str.unique())

    # One connection serves every lookup and is closed even if a query fails.
    con = duckdb.connect("data/data.duckdb", read_only=True)
    try:
        # ---------------- Name mapping ----------------
        symbols = sorted(trades_df["symbol"].unique())
        if symbols:
            marks = _placeholders(len(symbols))
            name_rows = con.execute(
                f"""
                SELECT ts_code, name FROM stock_basic_a WHERE ts_code IN ({marks})
                UNION ALL
                SELECT ts_code, name FROM stock_basic_h WHERE ts_code IN ({marks})
                """,
                symbols + symbols,  # the IN list appears twice in the statement
            ).fetchall()
            name_map = {code: name for code, name in name_rows}
            trades_df["name"] = trades_df["symbol"].map(name_map)
        else:
            trades_df["name"] = ""

        # ---------------- Previous trading day per trade date ----------------
        prev_map = {}
        for d in unique_trade_dates:
            row = con.execute(
                """
                SELECT trade_date FROM (
                  SELECT trade_date FROM daily_a
                  UNION
                  SELECT trade_date FROM daily_h
                ) WHERE trade_date < ?
                ORDER BY trade_date DESC LIMIT 1
                """,
                [d],
            ).fetchone()
            prev_map[d] = row[0] if row else None

        needed_prev_dates = sorted({p for p in prev_map.values() if p})

        # ---------------- Premium snapshot per previous day ----------------
        premium_snapshots = {}
        for prev_d in needed_prev_dates:
            a_map = _fetch_close_and_adj(con, "daily_a", "adj_factor_a", prev_d, a_codes)
            h_map = _fetch_close_and_adj(con, "daily_h", "adj_factor_h", prev_d, h_codes)
            usd_cnh = _fetch_fx_mid(con, "USDCNH.FXCM", prev_d)
            usd_hkd = _fetch_fx_mid(con, "USDHKD.FXCM", prev_d)
            # Without both rates (or with a zero USDHKD divisor) the HKD->CNY
            # conversion is undefined, so that day gets an empty snapshot.
            if usd_cnh is None or not usd_hkd:
                premium_snapshots[prev_d] = {}
                continue
            hk_to_cny = usd_cnh / usd_hkd
            snap = {}
            for a_code, h_code in zip(pair_df["cn_code"], pair_df["hk_code"]):
                if a_code not in a_map or h_code not in h_map:
                    continue
                close_a, adj_a = a_map[a_code]
                close_h, adj_h = h_map[h_code]
                fq_a = close_a * adj_a / max_adj_a.get(a_code, 1.0)
                fq_h_cny = close_h * adj_h / max_adj_h.get(h_code, 1.0) * hk_to_cny
                if fq_h_cny == 0:
                    continue
                premium_pct = (fq_a / fq_h_cny - 1) * 100
                # Store under both legs so a trade in either one finds it.
                snap[a_code] = premium_pct
                snap[h_code] = premium_pct
            premium_snapshots[prev_d] = snap
    finally:
        con.close()

    # Attach the previous-day date and premium to each trade, row by row.
    premium_dates = [prev_map.get(d) for d in trade_dates_str]
    premium_values = [
        None if prev_d is None else premium_snapshots.get(prev_d, {}).get(symbol)
        for prev_d, symbol in zip(premium_dates, trades_df["symbol"])
    ]
    trades_df["premium_date"] = premium_dates
    trades_df["premium_pct"] = premium_values

    cols_order = [
        "date",
        "premium_date",
        "premium_pct",
        "symbol",
        "name",
        "action",
        "size",
        "price_signal",
        "price_exec",
        "gross_amount",
        "fees",
        "cash_after",
        "position_after",
        "equity_after",
    ]
    trades_df = trades_df.sort_values(["date", "symbol"])[cols_order]

    pd.set_option("display.max_rows", None)
    display(trades_df)
    print(
        f"共 {len(trades_df)} 笔交易。premium_pct 为对应上一交易日收盘计算的溢价(%)。"
    )
else:
    print("无交易记录。")