In [4]:
import pandas as pd
import numpy as np
# import quantstats as qs # Not directly used for IC/IR calculation
import warnings
from tqdm.auto import tqdm
from IPython.display import display
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr # Needed for IC calculation

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
warnings.filterwarnings('ignore', category=UserWarning) # Quantstats might issue user warnings

# --- Function Definitions ---

# Helper for safe division (if needed, otherwise simple division is fine)
def safe_division(numerator, denominator, default=np.nan):
    """Performs division, returning default if denominator is zero or NaN."""
    mask = (denominator == 0) | denominator.isna() | numerator.isna()
    # Ensure denominator is float for np.where compatibility if it might be int 0
    denominator_float = denominator.astype(float)
    result = np.where(mask, default, numerator / denominator_float)
    return result

# Step 1: Load Data (Mostly Unchanged - Ensure required columns exist)
def load_data(cb_path, index_path=None): # Made index_path optional
    """Loads CB data, ensures DatetimeIndex."""
    print("--- Step 1: Loading Data ---")
    try:
        df = pd.read_parquet(cb_path)

        # Ensure df has correct MultiIndex with DatetimeIndex for trade_date
        required_levels = ['code', 'trade_date']
        if not (isinstance(df.index, pd.MultiIndex) and
                all(level in df.index.names for level in required_levels)):
            print("Attempting to set MultiIndex ['code', 'trade_date']...")
            if all(col in df.columns for col in required_levels):
                df['trade_date'] = pd.to_datetime(df['trade_date'])
                df = df.set_index(required_levels)
                print("MultiIndex set successfully.")
            else:
                raise ValueError("CB data missing 'code' or 'trade_date' columns for index.")

        if not isinstance(df.index.levels[df.index.names.index('trade_date')], pd.DatetimeIndex):
             print("Converting 'trade_date' level to DatetimeIndex...")
             df.index = df.index.set_levels(pd.to_datetime(df.index.levels[df.index.names.index('trade_date')]), level='trade_date')
             print("'trade_date' level converted.")

        # --- Add check for required return calculation columns ---
        # Need T close for pulse logic, T+1 open/high/close for fwd returns
        required_cols = ['open', 'high', 'close']
        # pct_chg is useful if defined as T-1 close to T close for easy fwd ret calc
        if 'pct_chg' not in df.columns:
             print("Warning: 'pct_chg' column missing. Forward close-to-close return will be calculated manually if possible.")
        else:
             required_cols.append('pct_chg')

        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
             # Allow proceeding if pct_chg is missing but others are present
             if not ('open' in df.columns and 'high' in df.columns and 'close' in df.columns):
                  raise ValueError(f"Required columns for return calculation missing: {missing_cols}")

        print(f"Loaded CB data shape: {df.shape}")

        # Load index data if path provided (not strictly needed for core analysis)
        index_df = None
        if index_path:
            try:
                index_df = pd.read_parquet(index_path)
                if not isinstance(index_df.index, pd.DatetimeIndex):
                    index_df.index = pd.to_datetime(index_df.index)
                print(f"Loaded Index data shape: {index_df.shape}")
            except Exception as e_idx:
                print(f"Warning: Could not load or process index data from {index_path}: {e_idx}")

        return df, index_df
    except Exception as e:
        print(f"Error loading data: {e}")
        return None, None

# Step 2: Filter Data (Ensure factors_to_analyze exist before filtering)
def filter_data(df, start_date, end_date, filter_rules, factors_to_analyze):
    """Applies date range and custom filters."""
    print("--- Step 2: Filtering Data ---")
    if df is None: return None

    # Check if factor columns exist BEFORE filtering
    all_cols_needed = factors_to_analyze + ['open', 'high', 'close'] # Need OHLC for returns
    if 'pct_chg' in df.columns: all_cols_needed.append('pct_chg')
    # Add columns used in filter_rules if they aren't factors already
    for rule in filter_rules:
        # Basic parsing to find potential column names in rules
        import re
        potential_cols = re.findall(r'`([^`]*)`|(\b[a-zA-Z_][a-zA-Z0-9_]*\b)', rule)
        # Flatten list of tuples and remove empty strings/keywords if any
        rule_cols = {item for sublist in potential_cols for item in sublist if item and item not in ['and', 'or', 'not', 'isin']}
        all_cols_needed.extend(list(rule_cols))
    all_cols_needed = list(set(all_cols_needed)) # Unique columns

    missing_factors = [f for f in all_cols_needed if f not in df.columns]
    if missing_factors:
        print(f"Error: Required columns missing from data: {missing_factors}")
        return None

    # Date filtering
    try:
        if 'trade_date' not in df.index.names:
             raise KeyError("'trade_date' not found in DataFrame index levels.")
        trade_date_level = df.index.get_level_values('trade_date')
        date_mask = (trade_date_level >= pd.to_datetime(start_date)) & (trade_date_level <= pd.to_datetime(end_date))
        df_filtered = df[date_mask].copy()
        if df_filtered.empty: raise ValueError(f"No data remaining after date filtering ({start_date} to {end_date}).")
        print(f"Filtered by date: {start_date} to {end_date}. Shape: {df_filtered.shape}")
    except Exception as e:
        print(f"Error during date filtering: {e}")
        return None

    # Apply standard filters (Optional - Add if needed, e.g., Redemption, Listing days)
    df_filtered['filter_out'] = False
    # Example:
    # redeem_statuses = ['已公告强赎', '公告到期赎回', '公告实施强赎', '公告提示强赎', '已满足强赎条件']
    # if 'is_call' in df_filtered.columns: df_filtered.loc[df_filtered['is_call'].isin(redeem_statuses), 'filter_out'] = True
    # if 'list_days' in df_filtered.columns: df_filtered.loc[df_filtered['list_days'] <= 3, 'filter_out'] = True

    # Apply custom filters
    print("Applying custom filters...")
    initial_eligible = len(df_filtered[~df_filtered['filter_out']])
    for rule in filter_rules:
        try:
            # Apply filter to non-filtered-out rows only to avoid errors on already filtered data
            current_eligible_indices = df_filtered[~df_filtered['filter_out']].index
            if not current_eligible_indices.empty:
                 # Query on the subset of eligible rows
                 filtered_indices = df_filtered.loc[current_eligible_indices].query(rule).index
                 # Mark these as filtered out in the original df_filtered
                 df_filtered.loc[filtered_indices, 'filter_out'] = True
                 print(f" - Applied: {rule}. Marked {len(filtered_indices)} additional rows.")
            else:
                 print(f" - Skipping rule (no eligible rows left): {rule}")

        except Exception as e:
            print(f"  - Warning: Could not apply filter rule '{rule}'. Error: {e}")
            # Consider whether a failing filter should halt the process or just be skipped

    # --- IMPORTANT: Filter out rows where any needed factor is NaN ---
    # We need valid factor values *today* (day T) to correlate with *tomorrow's* return (T+1)
    print("Filtering rows with NaN in factor values...")
    nan_mask = df_filtered[factors_to_analyze].isna().any(axis=1)
    df_filtered.loc[nan_mask, 'filter_out'] = True

    final_eligible_count = len(df_filtered[~df_filtered['filter_out']])
    filtered_by_nan = initial_eligible - final_eligible_count # Rough count, might double count if filters also hit NaNs
    print(f"NaN factor check complete. Marked approx {filtered_by_nan} additional rows due to NaN factors.")
    print(f"Filtering complete. Eligible bond-days for return calculation: {final_eligible_count}")

    if final_eligible_count == 0:
        print("Warning: No bonds eligible after applying all filters and NaN checks.")
    return df_filtered


# --- NEW Step 3: Calculate Multiple Forward Returns ---
def calculate_multiple_fwd_returns(df, pulse_percentages):
    """
    Calculates various next-day return metrics for each bond.
    - 'fwd_ret_close': Raw percentage change from current close to next close.
    - 'fwd_ret_pulse_X': Return based on pulse stop-profit logic at X%.
    """
    print("--- Step 3: Calculating Multiple Forward Returns ---")
    if df is None: return None, []
    required_cols = ['open', 'high', 'close'] # Minimum required for pulse logic
    if not all(col in df.columns for col in required_cols):
        print(f"Error: Missing required columns for return calc: {required_cols}")
        return df, [] # Return original df and empty list
    if not isinstance(df.index, pd.MultiIndex) or 'code' not in df.index.names:
        print("Error: DataFrame needs MultiIndex with 'code' level for forward returns.")
        return df, []

    df_with_fwd = df.copy()
    grouped = df_with_fwd.groupby(level='code', group_keys=False) # group_keys=False avoids adding group keys to index

    # Get next day's data by shifting within each group
    print("  - Shifting data to get next day's OHLC...")
    df_with_fwd['next_open'] = grouped['open'].shift(-1)
    df_with_fwd['next_high'] = grouped['high'].shift(-1)
    df_with_fwd['next_close'] = grouped['close'].shift(-1)

    # Calculate Forward Close-to-Close Return
    # Option 1: Use shifted pct_chg if available and reliable
    if 'pct_chg' in df_with_fwd.columns:
        print("  - Calculating fwd_ret_close using shifted 'pct_chg'...")
        # Assuming pct_chg is (close_T / close_{T-1}) - 1
        # Then shifted pct_chg is (close_{T+1} / close_T) - 1 which is what we want
        df_with_fwd['fwd_ret_close'] = grouped['pct_chg'].shift(-1)
    # Option 2: Calculate manually if pct_chg is missing or unreliable
    else:
        print("  - Calculating fwd_ret_close manually from close and next_close...")
        df_with_fwd['fwd_ret_close'] = safe_division(df_with_fwd['next_close'], df_with_fwd['close']) - 1

    # Calculate pulse returns
    print(f"  - Calculating forward pulse returns for thresholds: {pulse_percentages}%...")
    current_close = df_with_fwd['close']
    next_open = df_with_fwd['next_open']
    next_high = df_with_fwd['next_high']
    # next_close = df_with_fwd['next_close'] # Needed only if manual fwd_ret_close calc is used below
    raw_next_day_ret = df_with_fwd['fwd_ret_close']

    # Handle cases where next day data is missing (last day for a bond) or current close is invalid
    valid_next_day = next_open.notna() & next_high.notna() & current_close.notna() & (current_close > 0) & raw_next_day_ret.notna()

    return_cols = ['fwd_ret_close'] # Start with the base return column name

    for pct in pulse_percentages:
        ret_col_name = f'fwd_ret_pulse_{pct:.1f}' # e.g., fwd_ret_pulse_2.5
        return_cols.append(ret_col_name)
        stop_profit_pct = pct / 100.0
        # Calculate threshold price based on current close
        threshold_price = current_close * (1 + stop_profit_pct)

        # Initialize return column with NaN
        df_with_fwd[ret_col_name] = np.nan

        # --- Vectorized Calculation for Pulse Return ---
        # Condition 1: Triggered at open (next_open >= threshold) -> Return is based on next_open
        # Use safe_division in case current_close is somehow zero despite filter
        open_return = safe_division(next_open, current_close) - 1

        # Condition 2: Triggered intraday (next_high >= threshold but next_open < threshold) -> Return is stop_profit_pct
        intra_return = stop_profit_pct

        # Condition 3: Not triggered -> Return is raw_next_day_ret (close-to-close)

        # Apply conditions using np.select or chained .loc
        conditions = [
            valid_next_day & (next_open >= threshold_price),                     # Triggered at open
            valid_next_day & (next_open < threshold_price) & (next_high >= threshold_price), # Triggered intraday
            valid_next_day & (next_high < threshold_price)                       # Not triggered
        ]
        choices = [
            open_return,        # Use calculated open_return
            intra_return,       # Use the fixed stop-profit percentage
            raw_next_day_ret    # Use the close-to-close return
        ]

        # Apply the logic using np.select for efficiency
        df_with_fwd[ret_col_name] = np.select(conditions, choices, default=np.nan)

    # Clean up intermediate columns if desired
    df_with_fwd = df_with_fwd.drop(columns=['next_open', 'next_high', 'next_close'], errors='ignore')

    # Report NaN counts for the new return columns
    nan_counts = df_with_fwd[return_cols].isna().sum()
    total_rows = len(df_with_fwd)
    print(f"Calculated forward returns ({len(return_cols)} types). Example NaN counts (out of {total_rows}):\n{nan_counts}")

    # Check if ALL return columns are completely NaN for the eligible rows
    eligible_returns = df_with_fwd.loc[~df_with_fwd['filter_out'], return_cols]
    if eligible_returns.isna().all().all():
         print("Warning: All calculated forward returns are NaN for the eligible data points. IC calculation will yield no results.")

    return df_with_fwd, return_cols


# --- MODIFIED Step 4: Analyze Factor vs. Each Return Type Relationship (IC/IR) ---
def analyze_factor_return_relationships(df, factors, return_cols):
    """
    Calculates Information Coefficient (IC) and Information Ratio (IR)
    for each factor against each specified forward return column.
    Uses Spearman rank correlation for IC.
    """
    print(f"--- Step 4: Analyzing Factor Relationships with {len(return_cols)} Return Types ---")
    if df is None:
        print("Error: DataFrame is missing.")
        return None
    if 'filter_out' not in df.columns:
        print("Error: 'filter_out' column missing. Cannot select eligible rows.")
        return None
    if not return_cols:
        print("Error: No forward return columns provided for analysis.")
        return None
    if not isinstance(df.index, pd.MultiIndex) or 'trade_date' not in df.index.names:
        print("Error: DataFrame needs MultiIndex with 'trade_date' level for IC calc.")
        return None

    all_ic_results = {}
    # daily_ic_data = {} # Optional: Store all daily ICs for plotting if needed

    # --- Use eligible rows only based on initial filtering ---
    # Forward returns were calculated for all rows, but we only analyze rows that were eligible *before* return calc.
    df_eligible_base = df[~df['filter_out']].copy() # Select rows eligible on day T

    if df_eligible_base.empty:
        print("Warning: No eligible bond-days found based on initial filters. Cannot calculate IC.")
        return pd.DataFrame()

    # Iterate through each type of forward return calculated in Step 3
    for return_col in tqdm(return_cols, desc="Analyzing Return Types"):
        if return_col not in df_eligible_base.columns:
            print(f"Warning: Return column '{return_col}' not found in eligible data. Skipping.")
            continue

        print(f"\n-- Analyzing Factors vs. Return: '{return_col}' --")

        # Prepare data for this specific return type:
        # Need factor values from day T and return value for T+1 (which is stored in return_col on day T's row)
        # Drop rows where *this specific* return is NaN or the factor is NaN (factor NaNs should be handled by filter_out, but double check)
        analysis_subset = df_eligible_base[factors + [return_col]].dropna(subset=[return_col] + factors)

        if analysis_subset.empty:
            print(f"Warning: No valid (non-NaN) Factor/Return pairs found for '{return_col}' in eligible data. Skipping IC calculation for this return type.")
            continue

        # Group by date to calculate daily IC
        grouped = analysis_subset.groupby(level='trade_date')
        num_days = len(grouped)
        if num_days == 0 :
             print(f"Warning: No trade dates found after filtering NaNs for {return_col}. Skipping.")
             continue
        print(f"Analyzing {num_days} days for '{return_col}' with {len(analysis_subset)} total valid pairs...")

        # Iterate through each factor for the current return type
        for factor in factors:
            if factor not in analysis_subset.columns: # Should exist, but check
                continue

            # Function to safely calculate Spearman correlation per day
            # It receives a DataFrame group for a specific day
            def safe_spearman(group):
                # Check if factor and return columns exist in the group
                if factor not in group.columns or return_col not in group.columns:
                    return np.nan
                # Double check for NaNs within the group for this pair (should be removed already, but safety first)
                group_cleaned = group[[factor, return_col]].dropna()
                if len(group_cleaned) < 5: # Need sufficient pairs for meaningful correlation (adjust threshold if needed)
                    return np.nan
                # Check for zero variance in either column within the group
                if group_cleaned[factor].nunique() <= 1 or group_cleaned[return_col].nunique() <= 1:
                     # print(f"Warning: Zero variance detected for {factor} or {return_col} on date {group.name}. Returning NaN for IC.")
                     return np.nan
                try:
                    # Use rank correlation (Spearman) for IC
                    corr, p_val = spearmanr(group_cleaned[factor], group_cleaned[return_col])
                    # Handle potential NaN result from spearmanr itself
                    return corr if not np.isnan(corr) else np.nan
                except Exception as e: # Catch any other unexpected errors
                    # print(f"Warning: spearmanr failed for {factor}/{return_col} on {group.name}. Error: {e}. Returning NaN.")
                    return np.nan

            # Apply the function to each day's group using .apply()
            try:
                # This applies safe_spearman to each group (each day's subset of data)
                # The index of daily_ic will be the 'trade_date'
                daily_ic = grouped.apply(safe_spearman)
                daily_ic_clean = daily_ic.dropna() # Remove days where IC calculation failed or wasn't possible

                # Store daily IC series if needed later
                # daily_ic_data[(factor, return_col)] = daily_ic_clean

                # Calculate summary statistics for this factor/return pair
                if daily_ic_clean.empty:
                    # print(f" - Factor '{factor}': No valid daily ICs calculated.")
                    mean_ic, std_ic, ir, ic_positive_ratio, num_obs = np.nan, np.nan, np.nan, np.nan, 0
                elif len(daily_ic_clean) < 2:
                    # print(f" - Factor '{factor}': Only 1 valid daily IC. Cannot calculate Std Dev/IR.")
                    mean_ic = daily_ic_clean.mean()
                    std_ic = np.nan # Cannot calculate std dev with 1 point
                    ir = np.nan
                    ic_positive_ratio = (daily_ic_clean > 0).mean()
                    num_obs = len(daily_ic_clean)
                else:
                    mean_ic = daily_ic_clean.mean()
                    std_ic = daily_ic_clean.std()
                    # Calculate IR, handle std_dev being zero or NaN
                    if pd.notna(std_ic) and std_ic != 0:
                        ir = mean_ic / std_ic
                    else:
                        ir = np.nan
                    ic_positive_ratio = (daily_ic_clean > 0).mean()
                    num_obs = len(daily_ic_clean) # Number of days with valid IC

                # Store the results
                all_ic_results[(factor, return_col)] = {
                    'Mean IC': mean_ic,
                    'IC Std Dev': std_ic,
                    'IR (IC Mean/Std)': ir,
                    'IC > 0 Ratio': ic_positive_ratio,
                    'Num Observations (Days)': num_obs
                }
                # Optional: Print summary per factor immediately
                # print(f"   - Factor '{factor}': Mean IC={mean_ic:.4f}, IR={ir:.4f}, Obs={num_obs}")

            except Exception as e:
                print(f"Error during IC calculation loop for factor '{factor}' vs '{return_col}': {e}")
                # Ensure a placeholder entry exists if the loop fails catastrophically
                if (factor, return_col) not in all_ic_results:
                     all_ic_results[(factor, return_col)] = {
                         'Mean IC': np.nan, 'IC Std Dev': np.nan, 'IR (IC Mean/Std)': np.nan,
                         'IC > 0 Ratio': np.nan, 'Num Observations (Days)': 0
                     }

    print("\nIC/IR calculation complete for all factor/return pairs.")
    if not all_ic_results:
        print("Warning: No IC results were generated.")
        return pd.DataFrame()

    # Format results into a DataFrame
    try:
        results_df = pd.DataFrame.from_dict(all_ic_results, orient='index')
        results_df.index = pd.MultiIndex.from_tuples(results_df.index, names=['Factor', 'Return Type'])
        results_df = results_df.sort_index()
    except Exception as e_format:
        print(f"Error formatting IC results into DataFrame: {e_format}")
        return None

    return results_df


# Step 5: Analyze Factor Correlation (Optional, Unchanged logic, ensure input data is correct)
def analyze_factor_correlation(df, factors):
    """Calculates and displays the correlation matrix for the selected factors."""
    print("--- Step 5: Analyzing Factor Correlation (Optional) ---")
    if df is None: print("Error: DataFrame missing."); return None
    if 'filter_out' not in df.columns: print("Error: 'filter_out' column missing."); return None
    if not factors: print("Error: No factors provided for correlation analysis."); return None

    # Use data *before* forward returns were added, but after filtering
    # Make sure to use only eligible rows
    df_eligible = df.loc[~df['filter_out']].copy()

    if df_eligible.empty: print("Warning: No eligible bonds for correlation."); return None

    missing_factors = [f for f in factors if f not in df_eligible.columns]
    if missing_factors: print(f"Warning: Factors missing for correlation: {missing_factors}");
    present_factors = [f for f in factors if f in df_eligible.columns]
    if len(present_factors) < 2: print("Warning: Need at least 2 present factors for correlation."); return None

    factor_data = df_eligible[present_factors]

    # Handle potential infinite values before calculating correlation
    factor_data = factor_data.replace([np.inf, -np.inf], np.nan)
    # Drop rows where *any* of the present factors are NaN for the correlation calculation
    factor_data = factor_data.dropna()

    if len(factor_data) < 5: # Need a few points for correlation
        print(f"Warning: Less than 5 valid data points ({len(factor_data)}) after NaN drop for factor correlation."); return None

    # Optional Sampling for very large datasets
    if len(factor_data) > 100000: # Adjust sample size as needed
        print(f"Sampling {100000} rows from {len(factor_data)} for factor correlation calculation...")
        factor_data = factor_data.sample(100000, random_state=42)

    print(f"Calculating Spearman rank correlation matrix on {len(factor_data)} observations...")
    try:
        correlation_matrix = factor_data.corr(method='spearman')
        print("Factor Correlation Matrix:")
        plt.figure(figsize=(max(6, len(present_factors)*0.8), max(5, len(present_factors)*0.6))) # Adjust size
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", linewidths=.5, annot_kws={"size": 8})
        plt.title('Factor Spearman Rank Correlation Heatmap')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()
        return correlation_matrix
    except Exception as corr_e:
         print(f"Error calculating or plotting correlation: {corr_e}")
         # Try calculating without plotting if plotting fails
         try:
             correlation_matrix = factor_data.corr(method='spearman')
             print("Correlation matrix calculated but plotting failed.")
             print(correlation_matrix)
             return correlation_matrix
         except Exception as corr_e2:
             print(f"Error calculating correlation even without plotting: {corr_e2}")
             return None


# --- Main Execution Function ---
def run_simplified_factor_analysis(config):
    """Orchestrates the simplified factor vs. return analysis."""

    # Step 1: Load Data
    df_cb_raw, _ = load_data(config['cb_data_path'], config.get('index_data_path')) # Index not strictly needed now
    if df_cb_raw is None:
        print("Stopping analysis: Data loading failed.")
        return None

    # Extract config parameters
    start_date = config['start_date']
    end_date = config['end_date']
    filters = config['filters']
    factors_to_analyze = config['factors_to_analyze'] # List of factor names
    pulse_percentages = config['pulse_percentages']

    # Validate factors exist
    missing_factors_init = [f for f in factors_to_analyze if f not in df_cb_raw.columns]
    if missing_factors_init:
        print(f"Error: Factors specified in config are missing from the loaded data: {missing_factors_init}")
        print("Stopping analysis.")
        return None
    print(f"Factors to analyze: {factors_to_analyze}")
    print(f"Pulse percentages for forward returns: {pulse_percentages}%")


    # Step 2: Filter Data
    # Pass factors_to_analyze to ensure NaNs in these are checked
    df_filtered = filter_data(df_cb_raw, start_date, end_date, filters, factors_to_analyze)
    if df_filtered is None:
        print("Stopping analysis: Data filtering failed.")
        return None
    if df_filtered[~df_filtered['filter_out']].empty:
        print("Stopping analysis: No eligible data remaining after filtering.")
        return None

    # Step 3: Calculate Multiple Forward Returns
    df_with_returns, return_cols = calculate_multiple_fwd_returns(df_filtered, pulse_percentages)
    if df_with_returns is None or not return_cols:
         print("Stopping analysis: Failed to calculate forward returns or no return columns generated.")
         # Return the filtered data for inspection if returns failed
         return {"final_data_before_returns": df_filtered}

    # Step 4: Analyze Factor vs. Each Return Type Relationship
    ic_results_df = analyze_factor_return_relationships(df_with_returns, factors_to_analyze, return_cols)

    # Step 5: Analyze Factor Correlation (Optional)
    factor_correlation_matrix = None
    if config.get('analyze_factor_correlation', False): # Add a flag in config
        # Pass df_filtered (before returns were added) is safer, as return calc might have failed partially
        # or df_with_returns if you want to correlate factors on the dataset that was used for IC
        # Using df_filtered ensures we use the data exactly as it was before return calc attempts
        print("\nRunning Optional Factor Correlation Analysis...")
        factor_correlation_matrix = analyze_factor_correlation(df_filtered, factors_to_analyze)


    # --- Step 6: Report Results ---
    print("\n" + "="*30 + " Factor Analysis Report " + "="*30)

    # --- IC / IR Results ---
    print("\n--- Factor vs. Forward Return Relationship Analysis (IC/IR) ---")
    if ic_results_df is not None and not ic_results_df.empty:
        # Display the full IC/IR table
        print("Full IC/IR Results Table:")
        with pd.option_context('display.max_rows', None, 'display.max_columns', None): # Show all rows/cols
            display(ic_results_df.style.format({
                'Mean IC': '{:.4f}',
                'IC Std Dev': '{:.4f}',
                'IR (IC Mean/Std)': '{:.3f}',
                'IC > 0 Ratio': '{:.1%}',
                'Num Observations (Days)': '{:,.0f}'
            }).background_gradient(cmap='coolwarm', subset=['Mean IC', 'IR (IC Mean/Std)'], vmin=-0.1, vmax=0.1)) # Added gradient


        # --- Highlight Strongest Relationships based on Mean IC ---
        print("\n--- Strongest Relationships (Highest Absolute Mean IC per Return Type) ---")
        try:
            # Find index of max absolute IC for each return type
            idx = ic_results_df.loc[ic_results_df.groupby(level='Return Type')['Mean IC'].apply(lambda x: x.abs().idxmax())]
            display(idx.style.format({
                'Mean IC': '{:.4f}', 'IC Std Dev': '{:.4f}', 'IR (IC Mean/Std)': '{:.3f}',
                'IC > 0 Ratio': '{:.1%}', 'Num Observations (Days)': '{:,.0f}'
            }).set_caption("Factors with Highest Absolute Mean IC for each Return Type"))
        except Exception as e_report_ic:
             print(f"Could not determine strongest relationships based on Mean IC: {e_report_ic}")

        # --- Highlight Strongest Relationships based on IR ---
        print("\n--- Strongest Relationships (Highest Absolute IR per Return Type) ---")
        try:
            # Create Abs IR column for sorting, handle NaNs
            ic_results_df_ir = ic_results_df.copy()
            ic_results_df_ir['Abs IR'] = ic_results_df_ir['IR (IC Mean/Std)'].abs()
            # Find index of max absolute IR for each return type, ignoring NaNs in Abs IR
            idx_ir = ic_results_df_ir.loc[ic_results_df_ir.dropna(subset=['Abs IR']).groupby(level='Return Type')['Abs IR'].idxmax()]

            if not idx_ir.empty:
                 display(idx_ir.drop(columns=['Abs IR']).style.format({ # Display original columns
                     'Mean IC': '{:.4f}', 'IC Std Dev': '{:.4f}', 'IR (IC Mean/Std)': '{:.3f}',
                     'IC > 0 Ratio': '{:.1%}', 'Num Observations (Days)': '{:,.0f}'
                 }).set_caption("Factors with Highest Absolute IR for each Return Type"))
            else:
                 print("Could not determine strongest relationships based on IR (perhaps all IR values were NaN?).")
        except Exception as e_report_ir:
             print(f"Could not determine strongest relationships based on IR: {e_report_ir}")


    else:
        print("Factor vs. Return relationship results (IC/IR) are not available or empty.")

    # --- Factor Correlation Results ---
    if config.get('analyze_factor_correlation', False):
        print("\n--- Factor Correlation Matrix ---")
        if factor_correlation_matrix is not None:
            print("(See heatmap plot above if generated, or matrix output)")
            # Optionally display the matrix again if plotting failed but calculation succeeded
            # print(factor_correlation_matrix)
        else:
            print("Factor correlation matrix could not be calculated or plotted.")

    print("\n" + "="*30 + " Analysis Complete " + "="*30)

    # Return key results
    return {
        "factor_return_ic_ir": ic_results_df,
        "factor_correlation": factor_correlation_matrix,
        # Include data for inspection (might be large!)
        # "final_data_with_returns": df_with_returns
    }




In [8]:
# --- Example Configuration ---
# Use the structure relevant to *this* analysis (factor vs. return correlation)
# The second CONFIG example you provided seemed designed for a different task (composite scoring).
CONFIG = {
    # File Paths
    'cb_data_path': '/Users/yiwei/Desktop/git/cb_data_with_factors_enhanced_with_junxian.pq', # *** CHANGE TO YOUR PATH ***
    'index_data_path': '/Users/yiwei/Desktop/git/index.pq', # Optional, not used in core IC calc

    # Analysis Time Period
    'start_date': '2022-08-01', # Use YYYY-MM-DD format
    'end_date': '2024-12-31',   # Use YYYY-MM-DD format

    # Data Filtering Rules (applied before return calc and IC)
    # Use pandas query syntax (backticks for special characters)
    'filters': [
        # "`转股溢价率` < 0.5",   # Example: Premium < 50%
        # "`剩余规模` > 0.1",     # Example: Remaining size > 0.1 Billion
        "`close` < 150",       # Example: Price < 150
        "`close` > 105",       # Example: Price > 105
        # Add more relevant filters based on your strategy/universe
        # Example: Exclude bonds near redemption
        # "`is_call`.isin(['已公告强赎', '公告到期赎回', '公告实施强赎', '公告提示强赎', '已满足强赎条件']) == False",
        # Example: Exclude newly listed bonds
        # "`list_days` > 3"
    ],

    # Factors to Analyze (Column names from your Parquet file)
    'factors_to_analyze': [
        # 'ytm',              # Yield to maturity
        'conv_prem',        # Conversion premium % (e.g., 0.2 for 20%)
        # '剩余规模',         # Remaining size (e.g., in Billion Yuan)
        # '成交量比',          # Example: Volume ratio (make sure column exists)
        # '双低值',            # Example: Double low value (make sure column exists)
        # '纯债价值',          # Example: Pure bond value
        'close',            # Closing price itself can be a factor
        # 'pb_stk',           # Underlying stock PB ratio
        # '总市值',           # Market cap of underlying stock
        # '波动率_10d',       # 10-day volatility (ensure exists)
        # Add ALL factor column names you want to analyze
    ],

    # Pulse Percentages for Forward Return Calculation
    # Defines the stop-profit thresholds (e.g., 2.0 means 2%)
    # 'pulse_percentages': [2.0, 2.5, 2.8, 3.0, 3.5, 3.8, 4.0, 5.0],
    'pulse_percentages': [3.0],

    # Optional: Run Factor Correlation Analysis
    'analyze_factor_correlation': True # Set to True to calculate and plot factor correlations
}

analysis_results = run_simplified_factor_analysis(CONFIG)

--- Step 1: Loading Data ---
Attempting to set MultiIndex ['code', 'trade_date']...
MultiIndex set successfully.
Loaded CB data shape: (593654, 513)
Loaded Index data shape: (1765, 8)
Factors to analyze: ['conv_prem', 'close']
Pulse percentages for forward returns: [3.0]%
--- Step 2: Filtering Data ---
Filtered by date: 2022-08-01 to 2024-12-31. Shape: (297663, 513)
Applying custom filters...
 - Applied: `close` < 150. Marked 272022 additional rows.
 - Applied: `close` > 105. Marked 25641 additional rows.
Filtering rows with NaN in factor values...
NaN factor check complete. Marked approx 297663 additional rows due to NaN factors.
Filtering complete. Eligible bond-days for return calculation: 0
Stopping analysis: No eligible data remaining after filtering.


In [None]:
# # --- Run the Analysis ---
# if __name__ == "__main__": # Ensures code runs only when script is executed directly
#     # Make sure the file paths in CONFIG are correct before running!
#     print("Starting Factor Analysis...")
#     analysis_results = run_simplified_factor_analysis(CONFIG)
#     print("Analysis finished.")

#     # You can access results for further processing if needed:
#     if analysis_results and "factor_return_ic_ir" in analysis_results:
#         ic_ir_df = analysis_results["factor_return_ic_ir"]
#         if ic_ir_df is not None:
#              print("\n --- IC/IR Results DataFrame Head (for verification) ---")
#              display(ic_ir_df.head())
#         else:
#              print("IC/IR DataFrame was not generated.")
#     else:
#         print("Analysis did not complete successfully or did not return IC/IR results.")