In [1]:
from regimes import *
# Seaborn theme for every plot in this notebook.
# `set_theme` is the preferred spelling; `sns.set` is only a legacy alias for it.
sns.set_theme(style="whitegrid")

TODO: add more parameters to `main` to modularise it, so the same function can be used to retrieve the analysis for multiple series — for instance, next I want to plot vol-shift performance.

### Inspect data

In [3]:
# The plotting helpers used by main() (plot_regime_*) are expected to come from `regimes`, star-imported in the first cell.
# Explicit alternative: from regimes import plot_regime_distributions, plot_regime_pdf_individual, plot_regime_pdf_overlay, plot_regime_scatter_2x2, plot_regime_scatter_combined

# Plot kinds main() can draw, listed in the order they are drawn by default.
PLOT_KINDS = (
    'distributions',     # plot_regime_distributions
    'pdf_individual',    # plot_regime_pdf_individual
    'pdf_overlay',       # plot_regime_pdf_overlay
    'scatter_2x2',       # plot_regime_scatter_2x2
    'scatter_combined',  # plot_regime_scatter_combined (metric vs. a second column)
)

# Metric analysed when main() is called without performance_metrics.
DEFAULT_PERFORMANCE_METRICS = (
    {'label': 'spot_return', 'col_name': 'spot', 'type': 'return'},
)

# Keys every performance-metric spec must provide.
_REQUIRED_METRIC_KEYS = ('label', 'col_name', 'type')


def _silent(*args, **kwargs):
    """Drop-in replacement for print() used when main() runs with verbose=False."""


def _normalise_plots(plots):
    """Return `plots` as a tuple of known plot kinds.

    A single string is accepted as shorthand for a one-element sequence.
    Raises ValueError if any name is not in PLOT_KINDS.
    """
    if isinstance(plots, str):
        plots = (plots,)
    plots = tuple(plots)
    unknown = [kind for kind in plots if kind not in PLOT_KINDS]
    if unknown:
        raise ValueError(f"Unknown plot kind(s) {unknown}; expected any of {list(PLOT_KINDS)}")
    return plots


def _validate_metrics(performance_metrics):
    """Check metric specs up front so a typo fails before any data is generated or plotted.

    Raises KeyError for a spec missing a required key, and ValueError for duplicate
    labels (results are keyed by label) or unknown per-metric plot kinds.
    """
    seen_labels = set()
    for position, metric in enumerate(performance_metrics):
        missing = [key for key in _REQUIRED_METRIC_KEYS if key not in metric]
        if missing:
            raise KeyError(f"performance_metrics[{position}] is missing required key(s) {missing}")
        label = metric['label']
        if label in seen_labels:
            raise ValueError(f"Duplicate performance metric label {label!r}; labels must be unique")
        seen_labels.add(label)
        if 'plots' in metric:
            _normalise_plots(metric['plots'])


def _load_inputs(col_names, log):
    """Generate the raw data and macro regime dates, each with a datetime 'date' column.

    Returns (raw_data, macro_regime_dates). Both frames are fresh copies, so the objects
    returned by the generators are never mutated.
    """
    macro_regime_dates = generate_macro_regime_dates().reset_index()
    raw_data = generate_raw_data().reset_index()

    log("Raw Data Columns Before Renaming:", raw_data.columns.tolist())
    col_names = list(col_names)
    if len(col_names) != raw_data.shape[1]:
        raise ValueError(
            f"col_names has {len(col_names)} names but the raw data (index reset) has "
            f"{raw_data.shape[1]} columns: {raw_data.columns.tolist()}"
        )
    raw_data.columns = col_names
    log("Raw Data Columns After Renaming:", raw_data.columns.tolist())

    macro_regime_dates.columns = ['date', 'regime']
    log("Macro Regime Dates Columns After Renaming:", macro_regime_dates.columns.tolist())
    log("Macro Regime Dates Head:")
    log(macro_regime_dates.head())

    raw_data['date'] = pd.to_datetime(raw_data['date'])
    macro_regime_dates['date'] = pd.to_datetime(macro_regime_dates['date'])
    return raw_data, macro_regime_dates


def _compute_performance(raw_data, performance_metrics, log):
    """Return {label: series} produced by calc_performance for each metric spec."""
    series_by_label = {}
    for metric in performance_metrics:
        series = calc_performance(raw_data, col_name=metric['col_name'], performance_type=metric['type'])
        series_by_label[metric['label']] = series
        log(f"Calculated {metric['label']} with type '{metric['type']}': {series.dtype}")
    return series_by_label


def _prepare_metric_frame(raw_data, macro_regime_dates, performance_series, label,
                          source_return_col, drop_date, log):
    """Merge one performance series with the regimes and tidy the columns for plotting.

    Raises KeyError if the merged frame has no 'regime' column (every plot needs it).
    """
    merged = prepare_plot_data(raw_data, macro_regime_dates, performance_series, performance_label=label)
    log("\nMerged Spot DataFrame Columns Before Renaming:", merged.columns.tolist())
    log("Merged Spot DataFrame Head:")
    log(merged.head())

    if drop_date:
        if 'date' in merged.columns:
            merged = merged.drop(columns=['date'])
            log("Dropped 'date' column successfully.")
        else:
            log("'date' column not found. Available columns:", merged.columns.tolist())

    # prepare_plot_data may name the series column generically; expose it under the metric label.
    if source_return_col and source_return_col != label:
        if source_return_col in merged.columns:
            merged = merged.rename(columns={source_return_col: label})
            log(f"Renamed '{source_return_col}' to '{label}'.")
        else:
            log(f"'{source_return_col}' column not found. Current columns:", merged.columns.tolist())

    log("Merged Spot DataFrame Columns After Renaming:", merged.columns.tolist())
    log("Merged Spot DataFrame Head After Renaming:")
    log(merged.head())

    if 'regime' not in merged.columns:
        raise KeyError(f"Merged frame for {label!r} has no 'regime' column: {merged.columns.tolist()}")
    log("\nUnique Regimes:", merged['regime'].unique())
    log("Regime Counts:\n", merged['regime'].value_counts())
    return merged


def _plot_metric(merged, label, scatter_y, colorpalette, plots):
    """Draw the requested plot kinds for one prepared metric frame, in the given order."""
    draw = {
        'distributions': lambda: plot_regime_distributions(merged, performance_label=label, colorpalette=colorpalette),
        'pdf_individual': lambda: plot_regime_pdf_individual(merged, performance_label=label, colorpalette=colorpalette),
        'pdf_overlay': lambda: plot_regime_pdf_overlay(merged, performance_label=label, colorpalette=colorpalette),
        'scatter_2x2': lambda: plot_regime_scatter_2x2(merged, performance_label=label, colorpalette=colorpalette),
        'scatter_combined': lambda: plot_regime_scatter_combined(
            merged, x_label=label, y_label=scatter_y, colorpalette=colorpalette
        ),
    }
    for kind in plots:
        draw[kind]()


def main(col_names=('date', 'spot', 'vol_1m', 'rv_1m'), performance_metrics=None, colorpalette=None,
         plots=None, scatter_y='vol_1m', source_return_col='SX5E spot (1m return)',
         drop_date=True, verbose=True):
    """
    Run the regime analysis for one or more performance series and plot each one.

    Parameters:
    - col_names (sequence of str): Column names applied to the raw data after its index is
      reset (the first name labels the former index). Must include 'date'.
    - performance_metrics (list of dict, optional): One spec per series to analyse.
      Required keys:
        'label'    - name of the series in the output frame and the plots,
        'col_name' - raw-data column passed to calc_performance,
        'type'     - performance_type passed to calc_performance (e.g. 'return', 'shift').
      Optional keys override the matching main() argument for that series only:
        'scatter_y' - y-axis column for the combined scatter plot,
        'plots'     - plot kind(s) to draw for this series.
      When None, DEFAULT_PERFORMANCE_METRICS (spot return) is used.
    - colorpalette (dict, optional): Mapping of regime name to color, passed to every plot.
    - plots (str or sequence of str, optional): Plot kinds to draw, chosen from PLOT_KINDS.
      None draws all of them.
    - scatter_y (str): Default y-axis column for the combined scatter plot.
    - source_return_col (str or None): Column of the merged frame that is renamed to the metric
      label when present. None disables the rename.
    - drop_date (bool): Drop the 'date' column from each merged frame before plotting.
    - verbose (bool): Print the intermediate debugging output.

    Returns:
    - dict: {label: merged DataFrame} for every metric, in the order given, so the prepared
      data can be reused for further analysis in later cells.

    Raises:
    - KeyError: a metric spec lacks a required key, or a merged frame has no 'regime' column.
    - ValueError: duplicate metric labels, unknown plot kinds, or col_names of the wrong length.
    """
    log = print if verbose else _silent

    # Validate all configuration before any (possibly slow) data generation or plotting.
    if performance_metrics is None:
        performance_metrics = [dict(spec) for spec in DEFAULT_PERFORMANCE_METRICS]
    performance_metrics = list(performance_metrics)
    default_plots = _normalise_plots(PLOT_KINDS if plots is None else plots)
    _validate_metrics(performance_metrics)

    # 1-2. Data generation, column renaming and datetime conversion
    raw_data, macro_regime_dates = _load_inputs(col_names, log)

    # 3. Compute every performance series before plotting anything
    performance_series = _compute_performance(raw_data, performance_metrics, log)

    # 4-6. Merge, clean and plot each series
    results = {}
    for metric in performance_metrics:
        label = metric['label']
        log(f"\nProcessing performance metric: {label}")
        merged = _prepare_metric_frame(
            raw_data, macro_regime_dates, performance_series[label], label,
            source_return_col, drop_date, log,
        )
        _plot_metric(
            merged,
            label,
            metric.get('scatter_y', scatter_y),
            colorpalette,
            _normalise_plots(metric.get('plots', default_plots)),
        )
        results[label] = merged
    return results

In [None]:
# Regime -> RGB colour (components scaled to 0-1), shared by every plot.
REGIME_COLORS = {
    'Recession': (215/255, 48/255, 48/255),   # Red
    'Slowdown':  (239/255, 123/255, 90/255),  # Orange
    'Recovery':  (0/255, 145/255, 90/255),    # Dark green
    'Expansion': (86/255, 180/255, 192/255),  # Teal blue
}

# One spec per series to analyse: 'label' names the series in the output and plots,
# 'col_name' is the raw-data column, and 'type' is forwarded to calc_performance.
PERFORMANCE_METRICS = [
    {'label': 'spot_return', 'col_name': 'spot', 'type': 'return'},
    {'label': 'vol_shift', 'col_name': 'vol_1m', 'type': 'shift'},
    {'label': 'rv_shift', 'col_name': 'rv_1m', 'type': 'shift'},
]

# No `if __name__ == "__main__"` guard: in a notebook kernel it is always true.
# Keep whatever main() returns so later cells can inspect it.
regime_results = main(
    col_names=['date', 'spot', 'vol_1m', 'rv_1m'],
    performance_metrics=PERFORMANCE_METRICS,
    colorpalette=REGIME_COLORS,
)