# In [None]:
import pandas as pd
import numpy as np
import talib as ta
import warnings

# Session-wide display / warning setup:
#   1. turn off pandas' chained-assignment (SettingWithCopy) check,
#   2. print every column of a DataFrame instead of truncating,
#   3. silence all Python warnings for the rest of the process.
pd.set_option('mode.chained_assignment', None)
pd.set_option('display.max_columns', None)
warnings.filterwarnings('ignore')

# =========================
# 1. 原始数据加载 + 因子计算（只做一次）
# =========================
def _compute_factors(frame):
    """Add the raw factor columns to *frame* in place and return it.

    Rolling and shifted factors are computed over the frame's rows in their
    current order. NOTE(review): this assumes *frame* holds a single series
    sorted chronologically -- confirm for the input file.
    """
    # 14-period normalized ATR from TA-Lib.
    frame['natr_14'] = ta.NATR(frame['high'], frame['low'], frame['close'], timeperiod=14)
    # 20-period simple moving average of close.
    frame['ma_20'] = ta.SMA(frame['close'], timeperiod=20)
    # Ratio of close to the close 20 rows earlier (1.0 == unchanged).
    frame['momentum_20'] = frame['close'] / frame['close'].shift(20)
    # Rolling 20-row standard deviation of the close price level (not of returns).
    frame['volatility_20'] = frame['close'].rolling(20).std()
    return frame


def load_and_prepare_data(filepath, cache_path='full_data.parquet', group_col=None):
    """Load raw OHLC data from Excel and precompute all raw factors once.

    Args:
        filepath: Path of the Excel workbook; it must contain ``high``,
            ``low`` and ``close`` columns.
        cache_path: Where to save the enriched frame as parquet so later runs
            can reload it. Defaults to ``'full_data.parquet'``; pass ``None``
            to skip writing the cache.
        group_col: Optional column identifying separate series (for example a
            security code). When given, factors are computed independently
            per group so rolling windows never span two series; the original
            row order is preserved. ``None`` (default) treats the whole file
            as one series.

    Returns:
        The loaded DataFrame with ``natr_14``, ``ma_20``, ``momentum_20`` and
        ``volatility_20`` columns added.
    """
    df = pd.read_excel(filepath)
    for col in ('high', 'low', 'close'):
        # Cast to float: TA-Lib expects double-precision price inputs.
        df[col] = df[col].astype(float)

    if group_col is None or df.empty:
        df = _compute_factors(df)
    else:
        pieces = [
            _compute_factors(group.copy())
            for _, group in df.groupby(group_col, sort=False, dropna=False)
        ]
        # Reassemble in the original row order; relies on a unique index,
        # which read_excel's default RangeIndex provides.
        df = pd.concat(pieces).reindex(df.index)

    if cache_path is not None:
        df.to_parquet(cache_path)
    return df

# =========================
# 2. 过滤逻辑（灵活组合）
# =========================
def apply_filter(df, config):
    """Return the rows of *df* that pass every filter enabled in *config*.

    Supported config keys:
        exclude_high_natr: keep only rows whose ``natr_14`` is strictly below
            this value. A falsy value (missing, None, 0) disables the filter.
        above_ma_20: when truthy, keep only rows whose ``close`` is strictly
            above ``ma_20``.

    A row whose compared column is NaN fails that comparison and is dropped.

    Args:
        df: Frame containing the columns used by the enabled filters.
        config: Filter configuration dict (keys above; others are ignored).

    Returns:
        A new, independent DataFrame with the original index preserved.
        *df* itself is never modified.
    """
    # Combine all enabled conditions into one positional mask so the data is
    # copied once at the end, rather than up front and again per filter step.
    keep = np.ones(len(df), dtype=bool)

    natr_cap = config.get('exclude_high_natr')
    if natr_cap:
        keep &= (df['natr_14'] < natr_cap).to_numpy()
    if config.get('above_ma_20'):
        keep &= (df['close'] > df['ma_20']).to_numpy()

    # .copy() detaches the result from df, so callers can add columns to it
    # without pandas' chained-assignment (view vs. copy) ambiguity.
    return df.loc[keep].copy()

# =========================
# 3. 局部打分逻辑（池内打分）
# =========================
def calculate_scores(filtered_df, config):
    """Rank each configured factor inside the pool and average the ranks.

    ``config['score_factors']`` maps a factor column name to a direction
    (case-insensitive):
        'desc' -- larger values are better (the largest value gets rank 1);
        'asc'  -- smaller values are better (the smallest value gets rank 1).
    Each factor gets a ``<factor>_score`` column holding its rank (ties share
    the average rank; NaN factor values get a NaN rank). ``total_score`` is
    the mean of those rank columns, skipping NaNs, so a LOWER total_score is
    better.

    Optional ``config['rank_by']`` names a column (e.g. a trade-date column)
    to rank within each group of rows instead of across the whole frame.

    Args:
        filtered_df: The candidate pool to score.
        config: Scoring configuration dict.

    Returns:
        ``filtered_df`` itself when ``config`` has no ``'score_factors'``;
        otherwise a new DataFrame with the score columns added. The input
        frame is not modified.

    Raises:
        ValueError: if a direction is neither 'asc' nor 'desc'.
    """
    if 'score_factors' not in config:
        return filtered_df

    group_col = config.get('rank_by')
    # Work on a copy so scoring never mutates the caller's frame.
    scored = filtered_df.copy()

    score_cols = []
    for factor, direction in config['score_factors'].items():
        key = direction.lower() if isinstance(direction, str) else direction
        if key not in ('asc', 'desc'):
            raise ValueError(
                f"score direction for {factor!r} must be 'asc' or 'desc', got {direction!r}"
            )
        ascending = key == 'asc'
        if group_col is None:
            ranks = scored[factor].rank(ascending=ascending)
        else:
            # dropna=False keeps rows with a NaN group key in their own group.
            ranks = scored.groupby(group_col, dropna=False)[factor].rank(ascending=ascending)
        score_name = f'{factor}_score'
        scored[score_name] = ranks
        score_cols.append(score_name)

    scored['total_score'] = scored[score_cols].mean(axis=1)
    return scored

# =========================
# 4. 主流程：加载 → 过滤 → 打分
# =========================
if __name__ == '__main__':
    from pathlib import Path

    cache_file = Path('full_data.parquet')
    if cache_file.exists():
        # Later runs: reuse the factor-enriched dataset cached earlier.
        df_all = pd.read_parquet(cache_file)
    else:
        # First run: build the full dataset; load_and_prepare_data also writes
        # it to 'full_data.parquet' for subsequent runs.
        # NOTE(review): 'your_data.xlsx' is a placeholder -- point it at the real file.
        df_all = load_and_prepare_data('your_data.xlsx')

    config = {
        'exclude_high_natr': 20,   # drop rows with natr_14 >= 20
        'above_ma_20': True,       # keep only rows trading above ma_20
        'score_factors': {
            'momentum_20': 'desc',   # higher momentum ranks better
            'volatility_20': 'asc',  # lower volatility ranks better
        },
    }

    df_filtered = apply_filter(df_all, config)
    df_scored = calculate_scores(df_filtered, config)

    print(df_scored[['close', 'natr_14', 'momentum_20', 'volatility_20', 'total_score']].tail())