In [1]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
import glob
import logging
from pathlib import Path
from typing import Optional, List # Added List

# --- Configuration ---
PLOTLY_TEMPLATE = "plotly_dark"  # Plotly layout template passed to every figure built below
BASE_DATA_DIR = os.path.join("data_infra", "data") # Relative directory scanned for '*_backtest_*' result folders
# Figures are displayed with fig.show() rather than written to disk, so no output directory is configured.

# Configure the root logger for this script: INFO level, "<time> - <level> - <message>" lines
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Helper Functions ---

def find_latest_backtest_dir(base_dir: str) -> Optional[str]:
    """Return the most recently modified backtest directory under ``base_dir``.

    Only immediate sub-directories whose name contains ``_backtest_`` are
    considered. The winner is chosen by filesystem modification time
    (``st_mtime``), not by the timestamp embedded in the directory name.

    Returns the directory path as a string, or ``None`` (after logging) when
    ``base_dir`` does not exist, contains no matching directory, or any other
    error occurs while scanning.
    """
    def _modified_at(directory: Path) -> float:
        return directory.stat().st_mtime

    try:
        candidates = []
        for entry in Path(base_dir).iterdir():
            if entry.is_dir() and '_backtest_' in entry.name:
                candidates.append(entry)

        if len(candidates) == 0:
            logging.warning(f"No directories matching '*_backtest_*' found in {base_dir}")
            return None

        newest = max(candidates, key=_modified_at)
        return str(newest)
    except FileNotFoundError:
        logging.error(f"Base data directory not found: {base_dir}")
        return None
    except Exception as e:
        # Best-effort lookup: report anything unexpected and let the caller handle None.
        logging.error(f"Error finding latest backtest directory: {e}", exc_info=True)
        return None

def _looks_like_year_month(values: pd.Series) -> bool:
    """Return True when the first non-null entry is a 'YYYY-MM'-shaped string."""
    non_null = values.dropna()
    if non_null.empty:
        return False
    first_val = non_null.iloc[0]
    # Inspect the value itself rather than the column dtype: text columns may be
    # stored as ``object`` or as a dedicated string dtype depending on pandas version.
    return isinstance(first_val, str) and len(first_val) == 7 and first_val[4] == '-'


def load_csv(file_path: str, index_col=None) -> Optional[pd.DataFrame]:
    """Load a CSV file into a DataFrame, logging instead of raising on failure.

    A ``timestamp`` column (or an index named ``timestamp``) is converted to
    datetimes with unparseable entries coerced to ``NaT``. A ``timestamp``
    column whose first non-null value looks like ``YYYY-MM`` is left as text,
    so month labels stay usable as category labels.

    Args:
        file_path: Path of the CSV file to read.
        index_col: Forwarded unchanged to ``pandas.read_csv``.

    Returns:
        The loaded DataFrame; an empty DataFrame when the file is empty; or
        ``None`` when the path is not a file or reading fails.
    """
    path_obj = Path(file_path)
    if not path_obj.is_file():
        logging.warning(f"CSV file not found or is not a file: {file_path}")
        return None
    if path_obj.stat().st_size == 0:
        logging.warning(f"CSV file is empty: {file_path}")
        return pd.DataFrame()

    try:
        df = pd.read_csv(file_path, index_col=index_col)

        if 'timestamp' in df.columns:
            # Only full timestamps are parsed; monthly 'YYYY-MM' labels are kept as-is.
            if not _looks_like_year_month(df['timestamp']):
                try:
                    df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
                except Exception as date_e:
                    logging.warning(f"Could not parse 'timestamp' column in {file_path}: {date_e}")
        elif df.index.name == 'timestamp':
            # The timestamp was consumed as the index (via index_col); parse it there.
            try:
                df.index = pd.to_datetime(df.index, errors='coerce')
            except Exception as date_e:
                logging.warning(f"Could not parse 'timestamp' index in {file_path}: {date_e}")

        logging.debug(f"Successfully loaded {os.path.basename(file_path)} ({len(df)} rows)")
        return df
    except pd.errors.EmptyDataError:
        # Non-zero size but no parseable content (e.g. whitespace only).
        logging.warning(f"CSV file is empty (EmptyDataError): {file_path}")
        return pd.DataFrame()
    except Exception as e:
        logging.error(f"Error loading CSV file {file_path}: {e}", exc_info=True)
        return None

# --- Plotting Functions ---

def plot_performance(df_abs: pd.DataFrame, df_pct: pd.DataFrame, title_suffix="") -> go.Figure:
    """Build the dual-axis performance figure.

    The left axis shows ``df_abs['portfolio_value']``; the right axis shows
    ``df_pct['portfolio_pct_ret']`` and every other ``*_pct_ret`` column, each
    multiplied by 100 (values are presumably fractional returns; confirm
    against the producer of the CSV). Per-ticker traces start hidden
    (``legendonly``).

    Args:
        df_abs: Frame expected to hold ``timestamp`` and ``portfolio_value``.
        df_pct: Frame expected to hold ``timestamp`` and ``*_pct_ret`` columns.
        title_suffix: Text appended to the figure title.

    Returns:
        The figure. A group of traces is skipped, with a warning, when its
        frame is ``None`` or lacks the required columns.
    """
    grid_color = 'rgba(128,128,128,0.2)'

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.update_layout(
        title=f"Portfolio Performance{title_suffix}",
        xaxis_title="Timestamp",
        yaxis_title="Absolute Value ($)",
        yaxis2_title="Percentage Return (%)",
        template=PLOTLY_TEMPLATE,
        hovermode="x unified",
        legend_title_text='Metrics'
    )

    # Left Y-Axis: Absolute Portfolio Value
    if df_abs is not None and 'timestamp' in df_abs.columns and 'portfolio_value' in df_abs.columns:
        fig.add_trace(
            go.Scatter(x=df_abs['timestamp'], y=df_abs['portfolio_value'], mode='lines', name='Portfolio Value ($)', line=dict(color='lightblue', width=2.5), yaxis='y1'),
            secondary_y=False,
        )
        logging.debug("Added absolute portfolio value trace.")
    else:
        logging.warning("Could not plot absolute portfolio value - required columns/data missing.")

    # Right Y-Axis: Percentage Returns
    portfolio_pct_col = 'portfolio_pct_ret'
    if df_pct is not None and 'timestamp' in df_pct.columns and portfolio_pct_col in df_pct.columns:
        fig.add_trace(
            go.Scatter(x=df_pct['timestamp'], y=df_pct[portfolio_pct_col] * 100, mode='lines', name='Portfolio Return (%)', line=dict(color='gold', width=2, dash='dash'), yaxis='y2'),
            secondary_y=True,
        )
        logging.debug("Added portfolio percentage return trace.")

        # Every other '*_pct_ret' column is treated as a per-ticker series.
        ticker_pct_cols = sorted(
            c for c in df_pct.columns
            if isinstance(c, str) and c.endswith('_pct_ret') and c != portfolio_pct_col
        )
        for col in ticker_pct_cols:
            ticker_name = col.replace('_pct_ret', '')
            fig.add_trace(
                go.Scatter(x=df_pct['timestamp'], y=df_pct[col] * 100, mode='lines', name=f'{ticker_name} Return (%)', line=dict(width=1, dash='dot'), yaxis='y2', visible='legendonly'),
                secondary_y=True,
            )
        logging.debug(f"Added {len(ticker_pct_cols)} individual ticker return traces (legendonly).")
    else:
        logging.warning("Could not plot percentage returns - required columns/data missing.")

    fig.update_yaxes(title_text="Absolute Value ($)", secondary_y=False, gridcolor=grid_color)
    # The plotted values are already scaled to percent, so format them as plain
    # two-decimal numbers and append the sign via ticksuffix. A d3-format
    # specifier cannot carry a trailing literal ('.2f%' is not valid), and the
    # '%' format type would multiply the values by 100 a second time.
    fig.update_yaxes(title_text="Percentage Return (%)", secondary_y=True, tickformat=".2f", ticksuffix="%", gridcolor=grid_color)
    fig.update_xaxes(gridcolor=grid_color)

    return fig


def _format_trade_number(value, prefix: str = "") -> str:
    """Format a trade quantity to four decimals, or 'N/A' when missing or non-numeric."""
    try:
        if pd.isna(value):
            return "N/A"
        return f"{prefix}{float(value):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def add_trade_markers(fig: go.Figure, df_trades: pd.DataFrame, performance_series: pd.Series):
    """Add BUY/SELL markers to ``fig`` on the secondary (percentage) axis.

    Each trade is placed at its timestamp, at the most recent value of
    ``performance_series`` at or before that time (``Series.asof``) multiplied
    by 100. Rows whose ``signal_type`` is not BUY or SELL (case-insensitive)
    are ignored, as are trades with no performance value available.

    Neither ``df_trades`` nor ``performance_series`` is modified; timestamp
    coercion happens on copies.

    Args:
        fig: Figure to add the marker traces to (modified in place).
        df_trades: Trade log with a ``timestamp`` column and, optionally,
            ``signal_type``, ``ticker``, ``shares`` and ``fill_price``.
        performance_series: Return series indexed by time.

    Returns:
        None. Problems are logged and the affected markers skipped.
    """
    if df_trades is None or df_trades.empty:
        logging.info("Skipping trade markers (no trade data).")
        return
    if performance_series is None or performance_series.empty:
        logging.info("Skipping trade markers (no performance series).")
        return

    if 'timestamp' not in df_trades.columns:
        logging.error("Failed to parse trade timestamps: 'timestamp' column missing. Skipping markers.")
        return
    if not pd.api.types.is_datetime64_any_dtype(df_trades['timestamp']):
        try:
            # assign() returns a new frame, leaving the caller's trade log untouched.
            df_trades = df_trades.assign(timestamp=pd.to_datetime(df_trades['timestamp'], errors='coerce'))
        except Exception as e:
            logging.error(f"Failed to parse trade timestamps: {e}. Skipping markers.")
            return
    df_trades = df_trades.dropna(subset=['timestamp'])

    if not isinstance(performance_series.index, pd.DatetimeIndex):
        try:
            performance_series = performance_series.copy()
            performance_series.index = pd.to_datetime(performance_series.index, errors='coerce')
        except Exception as e:
            logging.error(f"Failed to convert performance index to DatetimeIndex: {e}. Skipping markers.")
            return
    # Drop rows whose label could not be parsed (NaT) as well as missing values;
    # Series.dropna() alone only looks at the values, not the index.
    performance_series = performance_series[performance_series.index.notna()].dropna()
    if performance_series.empty:
        logging.info("Skipping trade markers (no performance series).")
        return
    # asof() requires a sorted index.
    if not performance_series.index.is_monotonic_increasing:
        performance_series = performance_series.sort_index()

    aligned_data = {'BUY': {'x': [], 'y': [], 'text': []}, 'SELL': {'x': [], 'y': [], 'text': []}}

    for _, trade in df_trades.iterrows():
        # str() guards against NaN / non-string signal values.
        signal = str(trade.get('signal_type', 'UNKNOWN')).upper()
        if signal not in aligned_data:
            continue
        trade_time = trade['timestamp']
        try:
            y_val_pct_return = performance_series.asof(trade_time)
            if pd.isna(y_val_pct_return):
                logging.debug(f"Could not find performance value for trade at {trade_time}")
                continue
            hover_text = (f"<b>{signal} {trade.get('ticker', 'N/A')}</b><br>"
                          f"Shares: {_format_trade_number(trade.get('shares'))}<br>"
                          f"Price: {_format_trade_number(trade.get('fill_price'), prefix='$')}<br>"
                          f"Time: {trade_time.strftime('%H:%M:%S')}")
        except Exception as e:
            logging.warning(f"Error aligning trade at {trade_time}: {e}")
            continue
        # Append x, y and text together only after everything succeeded, so the
        # three lists can never get out of step.
        bucket = aligned_data[signal]
        bucket['x'].append(trade_time)
        bucket['y'].append(y_val_pct_return * 100)
        bucket['text'].append(hover_text)

    marker_styles = (
        ('BUY', 'Buy Signal', 'lime', 'triangle-up', 'buy'),
        ('SELL', 'Sell Signal', 'red', 'triangle-down', 'sell'),
    )
    for signal, trace_name, color, symbol, label in marker_styles:
        points = aligned_data[signal]
        if not points['x']:
            continue
        fig.add_trace(
            go.Scatter(x=points['x'], y=points['y'], mode='markers', name=trace_name, marker=dict(color=color, size=8, symbol=symbol), hoverinfo='text', hovertext=points['text'], yaxis='y2'),
            secondary_y=True,
        )
        logging.debug(f"Added {len(points['x'])} {label} markers.")

    if aligned_data['BUY']['x'] or aligned_data['SELL']['x']:
        logging.info("Added trade markers to performance plot.")


def plot_rolling_stats(results_dir: str, tickers: list[str], portfolio_col_name: str):
    """Build a rolling-volatility figure from the ``*D_Rolling_stats.csv`` files.

    For each stats file in ``results_dir`` (processed in sorted order) the
    window label is taken from the filename prefix before ``_Rolling``. The
    portfolio column ``<portfolio_col_name>_vol_<window>`` and each ticker
    column ``<ticker>_pct_ret_vol_<window>`` (window lower-cased) are plotted
    when present; ticker traces start hidden (``legendonly``).

    Returns the figure, or ``None`` when there are no stats files or none of
    them contains a matching column.
    """
    stats_paths = glob.glob(os.path.join(results_dir, "*D_Rolling_stats.csv"))
    if len(stats_paths) == 0:
        logging.warning("No rolling stats files found (*D_Rolling_stats.csv). Skipping plot.")
        return None

    fig = make_subplots(rows=1, cols=1)
    fig.update_layout(
        title="Rolling Volatility (Std Dev of Daily % Return Changes)",
        xaxis_title="Timestamp", yaxis_title="Volatility",
        template=PLOTLY_TEMPLATE, hovermode="x unified", legend_title_text='Windows'
    )

    traces_added = 0
    for stats_path in sorted(stats_paths):
        stats = load_csv(stats_path)
        if stats is None or stats.empty or 'timestamp' not in stats.columns:
            continue

        # Derive the window label (e.g. "30D") from the filename prefix.
        filename = os.path.basename(stats_path)
        window = filename.split('_Rolling')[0]
        if not (window.endswith('D') and window[:-1].isdigit()):
            window = "UnknownWindow"
            logging.warning(f"Could not parse window size from filename: {filename}")
        suffix = window.lower()

        portfolio_column = f'{portfolio_col_name}_vol_{suffix}'
        if portfolio_column in stats.columns:
            fig.add_trace(go.Scatter(x=stats['timestamp'], y=stats[portfolio_column], mode='lines', name=f'Portfolio Vol ({window})', line=dict(width=2)))
            traces_added += 1
            logging.debug(f"Added portfolio rolling vol trace for window {window}.")

        for ticker in tickers:
            ticker_column = f'{ticker}_pct_ret_vol_{suffix}'
            if ticker_column not in stats.columns:
                continue
            fig.add_trace(go.Scatter(x=stats['timestamp'], y=stats[ticker_column], mode='lines', name=f'{ticker} Vol ({window})', line=dict(width=1, dash='dash'), visible='legendonly'))
            traces_added += 1
            logging.debug(f"Added {ticker} rolling vol trace for window {window} (legendonly).")

    if traces_added == 0:
        logging.warning("Could not generate rolling volatility plot - no relevant data columns found.")
        return None

    fig.update_yaxes(gridcolor='rgba(128,128,128,0.2)')
    fig.update_xaxes(gridcolor='rgba(128,128,128,0.2)')
    logging.info("Generated rolling volatility figure.")
    return fig


def plot_monthly_heatmap(df_monthly: pd.DataFrame, title="Monthly Returns Heatmap") -> go.Figure:
    """Build a heatmap of monthly returns, one row per ``*_pct_ret`` column.

    The month labels come from a ``Month`` column if present, otherwise from a
    ``timestamp`` column. Return columns are coerced to numeric (invalid
    entries become NaN) and multiplied by 100 before plotting.

    Returns the heatmap figure, or an empty ``go.Figure()`` when the input is
    missing/empty, has no month column, or has no ``*_pct_ret`` column.
    """
    if df_monthly is None or df_monthly.empty:
        logging.warning("Skipping monthly returns heatmap (no data).")
        return go.Figure()  # Empty figure signals "nothing to show"

    # Decide which column carries the month labels.
    available_columns = df_monthly.columns
    if 'Month' in available_columns:
        month_col = 'Month'
    elif 'timestamp' in available_columns:
        month_col = 'timestamp'
        logging.debug(f"Using '{month_col}' column for heatmap months.")
    else:
        logging.error("Monthly returns CSV missing 'Month' or 'timestamp' column.")
        return go.Figure()

    # Work on a copy so the caller's frame keeps its original index.
    df_plot = df_monthly.copy()
    if df_plot.index.name != month_col:
        try:
            if month_col not in df_plot.columns:
                raise KeyError(f"Column '{month_col}' designated as month column not found.")
            df_plot = df_plot.set_index(month_col)
        except KeyError as e:
            logging.error(f"Error setting index for monthly heatmap: {e}")
            return go.Figure()

    return_cols = []
    for col in df_plot.columns:
        if isinstance(col, str) and col.endswith('_pct_ret'):
            return_cols.append(col)
    if not return_cols:
        logging.error("No percentage return columns ('*_pct_ret') found in monthly data.")
        return go.Figure()

    # Restrict to the return columns first, then coerce and scale to percent.
    numeric_returns = df_plot[return_cols].apply(pd.to_numeric, errors='coerce')
    df_numeric = numeric_returns * 100

    # Row labels are the column names without the '_pct_ret' suffix.
    clean_y_labels = [name.replace('_pct_ret','') for name in df_numeric.columns]
    label_positions = list(range(len(clean_y_labels)))

    heatmap = go.Heatmap(
        z=df_numeric.values.T,
        x=df_numeric.index,
        y=clean_y_labels,
        colorscale='RdYlGn', zmid=0,
        text=df_numeric.values.T, texttemplate="%{text:.2f}%",
        hoverongaps=False,
        hovertemplate="<b>Month:</b> %{x}<br><b>Asset:</b> %{y}<br><b>Return:</b> %{z:.2f}%<extra></extra>"
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=title,
        xaxis_title="Month", yaxis_title="Asset / Portfolio",
        template=PLOTLY_TEMPLATE,
        yaxis=dict(tickmode='array', tickvals=label_positions, ticktext=clean_y_labels),
        xaxis=dict(type='category')  # Month labels are shown as discrete categories
    )
    logging.info("Created monthly returns heatmap figure.")
    return fig


def plot_correlation_heatmap(df_corr: pd.DataFrame, title="Ticker Return Correlation Heatmap") -> go.Figure:
    """Build a heatmap of a correlation matrix.

    Both axes are labelled with the column names of ``df_corr`` (converted to
    strings, ``_pct_ret`` removed). The colour range is fixed to [-1, 1] and
    the y axis is reversed.

    Returns the heatmap figure, or an empty ``go.Figure()`` when ``df_corr``
    is ``None`` or empty.
    """
    if df_corr is None or df_corr.empty:
        logging.warning("Skipping correlation heatmap (no data).")
        return go.Figure()

    # str() keeps this working when the column labels are not strings.
    axis_labels = []
    for column in df_corr.columns:
        axis_labels.append(str(column).replace('_pct_ret',''))

    heatmap = go.Heatmap(
        z=df_corr.values, x=axis_labels, y=axis_labels,
        colorscale='viridis', zmin=-1, zmax=1,
        text=df_corr.values, texttemplate="%{text:.2f}",
        hoverongaps=False,
        hovertemplate="<b>Corr(%{x}, %{y})</b> = %{z:.3f}<extra></extra>"
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=title, xaxis_title="Ticker", yaxis_title="Ticker",
        template=PLOTLY_TEMPLATE, xaxis={'side': 'bottom'},
        yaxis={'autorange': 'reversed'}
    )
    logging.info("Created correlation heatmap figure.")
    return fig


# --- Main Execution ---
if __name__ == "__main__":
    # Script entry point: locate the latest backtest folder, load its CSVs and
    # display four figures (performance, rolling volatility, monthly returns,
    # correlations). Each figure is built in its own try block so one failure
    # does not prevent the remaining plots from being shown.
    logging.info("--- Starting Backtest Visualization Script ---")

    # --- Find Target Directory ---
    target_dir = find_latest_backtest_dir(BASE_DATA_DIR)
    # To hardcode: comment above line and uncomment below, replacing path
    # target_dir = os.path.join(BASE_DATA_DIR, "YYYYMMDD_HHMMSS_backtest_ID")

    if not target_dir or not os.path.isdir(target_dir):
        logging.error(f"Target backtest directory not found or invalid: '{target_dir}'. Exiting.")
        # SystemExit instead of exit(): the latter is an interactive helper
        # installed by the site module and is not guaranteed to be defined.
        raise SystemExit(1)

    logging.info(f"Generating visualizations for backtest results in: {target_dir}")

    # --- Load necessary files ---
    logging.info("Loading CSV files...")
    df_abs = load_csv(os.path.join(target_dir, "performance_timeseries_absolute.csv"))
    df_pct = load_csv(os.path.join(target_dir, "performance_timeseries_percentage.csv"))
    df_trades = load_csv(os.path.join(target_dir, "trade_log.csv"))
    df_monthly = load_csv(os.path.join(target_dir, "monthly_returns.csv"))
    df_corr = load_csv(os.path.join(target_dir, "ticker_return_correlations.csv"), index_col=0)

    # --- Basic Checks ---
    if df_abs is None or df_pct is None:
        logging.error("Essential performance timeseries CSV not found (absolute or percentage). Cannot generate core plots. Exiting.")
        raise SystemExit(1)

    # --- Infer Portfolio Info ---
    # The portfolio ID is whatever follows '_backtest_' in the directory name;
    # tickers are derived from the '*_pct_ret' columns of the percentage file.
    portfolio_id = "Unknown"
    tickers: List[str] = []
    try:
        dir_name = os.path.basename(target_dir)
        if '_backtest_' in dir_name:
            portfolio_id = dir_name.split('_backtest_')[-1]
    except Exception as e:
        logging.warning(f"Could not infer portfolio ID: {e}")
    portfolio_pct_col = 'portfolio_pct_ret'
    if df_pct is not None and not df_pct.empty:
        potential_ticker_pct_cols = [c for c in df_pct.columns if isinstance(c, str) and c.endswith('_pct_ret') and c != portfolio_pct_col]
        tickers = sorted([c.replace('_pct_ret','') for c in potential_ticker_pct_cols])
        if portfolio_pct_col not in df_pct.columns:
            logging.warning(f"Column '{portfolio_pct_col}' not found.")
    logging.info(f"Portfolio ID: {portfolio_id}, Tickers: {tickers}")


    # --- Generate and Show Performance Plot ---
    logging.info("Generating performance plot...")
    try:
        fig_perf = plot_performance(df_abs, df_pct, title_suffix=f" (Portfolio {portfolio_id})")
        if df_trades is not None and df_pct is not None and portfolio_pct_col in df_pct.columns:
            # Trade markers are aligned against the portfolio return series,
            # which must be time-indexed and sorted.
            perf_series = df_pct.set_index('timestamp')[portfolio_pct_col]
            if not isinstance(perf_series.index, pd.DatetimeIndex):
                perf_series.index = pd.to_datetime(perf_series.index, errors='coerce')
                perf_series = perf_series.dropna()
            if not perf_series.index.is_monotonic_increasing:
                perf_series = perf_series.sort_index()
            if not perf_series.empty:
                add_trade_markers(fig_perf, df_trades, perf_series)
            else:
                logging.warning("Could not create valid performance series for trade markers.")
        else:
            logging.warning("Skipping trade markers - data missing.")

        if fig_perf.data: # Only display figures that contain at least one trace
            fig_perf.show()
            logging.info("Displayed performance plot.")
        else:
            logging.warning("Performance plot figure was empty.")
    except Exception as e:
        logging.error(f"Failed to generate/show performance plot: {e}", exc_info=True)


    # --- Generate and Show Rolling Stats Plot ---
    logging.info("Generating rolling stats plot...")
    try:
        if df_pct is not None and tickers:
            fig_rolling = plot_rolling_stats(target_dir, tickers, portfolio_pct_col)
            if fig_rolling and fig_rolling.data: # plot_rolling_stats returns None when nothing could be plotted
                fig_rolling.show()
                logging.info("Displayed rolling stats plot.")
            else:
                logging.warning("Rolling stats plot figure was empty or not generated.")
        else:
            logging.warning("Skipping rolling stats plot - data missing.")
    except Exception as e:
        logging.error(f"Failed to generate/show rolling stats plot: {e}", exc_info=True)


    # --- Generate and Show Monthly Returns Heatmap ---
    logging.info("Generating monthly returns heatmap...")
    if df_monthly is not None:
        try:
            fig_monthly = plot_monthly_heatmap(df_monthly, title=f"Monthly Returns (%) - Portfolio {portfolio_id}")
            if fig_monthly.data:
                fig_monthly.show()
                logging.info("Displayed monthly returns heatmap.")
            else:
                logging.warning("Monthly returns heatmap figure was empty.")
        except Exception as e:
            logging.error(f"Failed to generate/show monthly returns heatmap: {e}", exc_info=True)
    else:
        logging.warning("Skipping monthly returns heatmap - CSV not loaded.")


    # --- Generate and Show Correlation Heatmap ---
    logging.info("Generating correlation heatmap...")
    if df_corr is not None:
        try:
            fig_corr = plot_correlation_heatmap(df_corr, title=f"Daily Ticker Return Correlations - Portfolio {portfolio_id}")
            if fig_corr.data:
                fig_corr.show()
                logging.info("Displayed correlation heatmap.")
            else:
                logging.warning("Correlation heatmap figure was empty.")
        except Exception as e:
            logging.error(f"Failed to generate/show correlation heatmap: {e}", exc_info=True)
    else:
        logging.warning("Skipping correlation heatmap - CSV not loaded.")

    logging.info("--- Visualization Script Finished ---")
    logging.info("NOTE: Plots displayed in separate browser windows/tabs.")

2025-06-12 23:26:35,410 - INFO - --- Starting Backtest Visualization Script ---
2025-06-12 23:26:35,412 - INFO - Generating visualizations for backtest results in: data_infra/data/20250612_232546_backtest_02
2025-06-12 23:26:35,413 - INFO - Loading CSV files...
2025-06-12 23:26:35,532 - INFO - Portfolio ID: 02, Tickers: ['NVDA', 'TSLA']
2025-06-12 23:26:35,533 - INFO - Generating performance plot...
2025-06-12 23:26:36,602 - INFO - Added trade markers to performance plot.


2025-06-12 23:26:37,141 - INFO - Displayed performance plot.
2025-06-12 23:26:37,143 - INFO - Generating rolling stats plot...
2025-06-12 23:26:37,587 - INFO - Generated rolling volatility figure.


2025-06-12 23:26:38,361 - INFO - Displayed rolling stats plot.
2025-06-12 23:26:38,361 - INFO - Generating monthly returns heatmap...
2025-06-12 23:26:38,385 - INFO - Created monthly returns heatmap figure.


2025-06-12 23:26:38,387 - INFO - Displayed monthly returns heatmap.
2025-06-12 23:26:38,388 - INFO - Generating correlation heatmap...
2025-06-12 23:26:38,401 - INFO - Created correlation heatmap figure.


2025-06-12 23:26:38,403 - INFO - Displayed correlation heatmap.
2025-06-12 23:26:38,403 - INFO - --- Visualization Script Finished ---
2025-06-12 23:26:38,404 - INFO - NOTE: Plots displayed in separate browser windows/tabs.
