In [2]:
# ----------------------
# Imports and Setup
# ----------------------

"""
Inflation Mechanisms Analysis
==============================
This module analyzes transmission mechanisms through which inflation affects
asset returns, including real rate channels and lagged effects.

Integrated into Week 1 Project
Author: Enhanced from Jupyter Notebook
Date: 2026-01-27
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Tuple
from scipy import stats
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

In [4]:
# ----------------------
# Inflation Mechanism Analyzer (correlation / OLS channels; no DoWhy or EconML used here)
# ----------------------

class InflationMechanismAnalyzer:
    """
    Analyzer for understanding the mechanisms through which inflation
    affects asset returns.

    Each channel is summarised with a Pearson correlation plus a simple
    bivariate OLS fit from ``scipy.stats.linregress``:

    * direct channel     -- inflation vs. returns
    * lagged channel     -- shifted inflation vs. returns, one fit per lag
    * volatility channel -- inflation vs. rolling standard deviation of returns

    Notes
    -----
    The constructor copies the input frame, but analysis methods add derived
    columns to that copy (``<column>_Lag_<k>`` and ``Return_Volatility``).
    Lags use ``Series.shift`` and are therefore counted in rows; they equal
    months only if the frame is a gap-free monthly series.
    """

    # p-value threshold used by ``run_full_analysis`` when it labels a
    # relationship as significant in its printed summary.
    SIGNIFICANCE_LEVEL = 0.05

    # Column order of the lag table; fixed so an empty table still has a schema.
    _LAG_COLUMNS = ['lag_months', 'correlation', 'slope', 'r_squared', 'p_value', 'observations']

    def __init__(self, df: pd.DataFrame, figures_dir: Optional[Path] = None):
        """
        Initialize the analyzer with financial data.

        Parameters
        ----------
        df : pd.DataFrame
            DataFrame containing inflation and asset return data. A copy is
            stored, so the caller's frame is not modified.
        figures_dir : Path or str, optional
            Directory to save figures. If None, uses ``<current working
            directory>/figures``. The directory is created if it is missing.
        """
        self.df = df.copy()
        # target column -> lag table produced by analyze_lagged_effects
        self.lag_correlations: Dict[str, pd.DataFrame] = {}
        # channel key ('inflation_return', 'volatility') -> statistics dict
        self.channel_analysis: Dict[str, Dict[str, float]] = {}

        self.figures_dir = Path.cwd() / "figures" if figures_dir is None else Path(figures_dir)
        self.figures_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _lag_column_name(column: str, lag: int) -> str:
        """Name of the column holding ``column`` shifted by ``lag`` rows."""
        return f"{column}_Lag_{lag}"

    @staticmethod
    def _fit_line(x: pd.Series, y: pd.Series) -> Dict[str, float]:
        """
        Pearson correlation and OLS statistics for regressing ``y`` on ``x``.

        Returns
        -------
        Dict[str, float]
            Keys, in order: correlation, slope, intercept, r_squared,
            p_value, std_error.

        Raises
        ------
        ValueError
            If fewer than two paired observations are supplied. ``linregress``
            itself also raises ValueError when every ``x`` value is identical.
        """
        if len(x) < 2:
            raise ValueError(
                f"Need at least 2 paired observations for a regression, got {len(x)}."
            )
        fit = stats.linregress(x, y)
        return {
            'correlation': x.corr(y),
            'slope': fit.slope,
            'intercept': fit.intercept,
            'r_squared': fit.rvalue ** 2,
            'p_value': fit.pvalue,
            'std_error': fit.stderr,
        }

    def _save_figure(self, fig, filename: str) -> Path:
        """Lay out ``fig``, write it to ``figures_dir`` at 300 dpi, report the path, and close it."""
        fig.tight_layout()
        save_path = self.figures_dir / filename
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved: {save_path}")
        # Figures are saved rather than shown; close this specific figure so
        # repeated runs do not accumulate open figures.
        plt.close(fig)
        return save_path

    def _plot_scatter_fit(
        self,
        clean_data: pd.DataFrame,
        x_col: str,
        y_col: str,
        fit: Dict[str, float],
        *,
        line_color: str,
        ylabel: str,
        title: str,
        filename: str,
        figsize: Tuple[int, int],
        zero_reference_lines: bool = False,
    ) -> None:
        """Scatter ``y_col`` against ``x_col`` with a fitted line and fit annotations, then save."""
        fig, ax = plt.subplots(figsize=figsize)

        sns.regplot(
            x=x_col,
            y=y_col,
            data=clean_data,
            ax=ax,
            scatter_kws={'alpha': 0.6, 's': 50, 'edgecolors': 'white', 'linewidths': 0.5},
            line_kws={'color': line_color, 'linewidth': 2.5}
        )

        if zero_reference_lines:
            ax.axhline(y=0, color='black', linestyle='-', linewidth=0.8, alpha=0.3)
            ax.axvline(x=0, color='black', linestyle='-', linewidth=0.8, alpha=0.3)

        ax.set_xlabel('Inflation Rate (YoY %)', fontsize=13, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=13, fontweight='bold')
        ax.set_title(
            f'{title}\n'
            f'Correlation: {fit["correlation"]:.3f} | R²: {fit["r_squared"]:.3f} | '
            f'P-value: {fit["p_value"]:.4f}',
            fontsize=14,
            fontweight='bold',
            pad=15
        )

        # Regression equation in the top-left corner of the axes.
        eq_text = f'y = {fit["slope"]:.4f}x + {fit["intercept"]:.4f}'
        ax.text(0.05, 0.95, eq_text, transform=ax.transAxes,
                fontsize=11, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        ax.grid(True, alpha=0.3)
        self._save_figure(fig, filename)

    # ------------------------------------------------------------------
    # Channels
    # ------------------------------------------------------------------

    def analyze_real_rate_channel(
        self,
        inflation_col: str = 'Inflation_Rate',
        return_col: str = 'Real_Return',
        plot: bool = True,
        figsize: Tuple[int, int] = (12, 7)
    ) -> Dict[str, float]:
        """
        Analyze the relationship between inflation and real returns.

        Since we don't have separate real interest rate data, we analyze
        the direct relationship between inflation and real returns.

        Parameters
        ----------
        inflation_col : str, default='Inflation_Rate'
            Column name for inflation data
        return_col : str, default='Real_Return'
            Column name for real returns
        plot : bool, default=True
            Whether to create and save a scatter plot with the fitted line
        figsize : Tuple[int, int], default=(12, 7)
            Figure size for plot

        Returns
        -------
        Dict[str, float]
            correlation, slope, intercept, r_squared, p_value, std_error and
            observations (count of rows with both columns present). The dict
            is also stored in ``channel_analysis['inflation_return']``.

        Raises
        ------
        ValueError
            If fewer than two rows have both columns present.
        """
        clean_data = self.df[[inflation_col, return_col]].dropna()
        fit = self._fit_line(clean_data[inflation_col], clean_data[return_col])

        results = {**fit, 'observations': len(clean_data)}
        self.channel_analysis['inflation_return'] = results

        if plot:
            self._plot_scatter_fit(
                clean_data,
                inflation_col,
                return_col,
                results,
                line_color='#e74c3c',
                ylabel='Real Return (Monthly %)',
                title='Inflation vs Real Returns: Direct Relationship',
                filename='mechanism_inflation_return_channel.png',
                figsize=figsize,
                zero_reference_lines=True,
            )

        return results

    def create_lagged_features(
        self,
        column: str = 'Inflation_Rate',
        lags: Optional[List[int]] = None
    ) -> pd.DataFrame:
        """
        Create lagged versions of a column for analyzing delayed effects.

        Adds one ``<column>_Lag_<k>`` column per lag to ``self.df``,
        overwriting any existing column of the same name.

        Parameters
        ----------
        column : str, default='Inflation_Rate'
            Column to create lags for
        lags : List[int], optional
            Lag periods in rows (months for a monthly series). If None, uses [3, 6, 12]

        Returns
        -------
        pd.DataFrame
            The analyzer's internal DataFrame (not a copy), now including the lagged columns
        """
        if lags is None:
            lags = [3, 6, 12]

        for lag in lags:
            self.df[self._lag_column_name(column, lag)] = self.df[column].shift(lag)

        return self.df

    def analyze_lagged_effects(
        self,
        inflation_col: str = 'Inflation_Rate',
        target_col: str = 'Real_Return',
        lags: Optional[List[int]] = None,
        plot: bool = True,
        min_observations: int = 11
    ) -> pd.DataFrame:
        """
        Analyze how past inflation affects current returns.

        Parameters
        ----------
        inflation_col : str, default='Inflation_Rate'
            Column name for inflation data
        target_col : str, default='Real_Return'
            Column name for the target variable
        lags : List[int], optional
            Lag periods to analyze. If None, uses [0, 3, 6, 9, 12, 18, 24]
        plot : bool, default=True
            Whether to create and save the correlation / R² panels
        min_observations : int, default=11
            Lags with fewer overlapping non-missing rows are skipped (at least
            2 are always required for a regression).

        Returns
        -------
        pd.DataFrame
            One row per retained lag with columns lag_months, correlation,
            slope, r_squared, p_value, observations. The frame keeps these
            columns even when no lag qualifies. It is also stored in
            ``lag_correlations[target_col]``.
        """
        if lags is None:
            lags = [0, 3, 6, 9, 12, 18, 24]
        required = max(min_observations, 2)

        rows = []
        for lag in lags:
            if lag == 0:
                lag_col = inflation_col
            else:
                lag_col = self._lag_column_name(inflation_col, lag)
                # Reuse a column built earlier (e.g. by create_lagged_features).
                if lag_col not in self.df.columns:
                    self.df[lag_col] = self.df[inflation_col].shift(lag)

            clean_data = self.df[[lag_col, target_col]].dropna()
            if len(clean_data) < required:
                continue  # too few overlapping rows once the lag is applied

            fit = self._fit_line(clean_data[lag_col], clean_data[target_col])
            rows.append({
                'lag_months': lag,
                'correlation': fit['correlation'],
                'slope': fit['slope'],
                'r_squared': fit['r_squared'],
                'p_value': fit['p_value'],
                'observations': len(clean_data)
            })

        results_df = pd.DataFrame(rows, columns=self._LAG_COLUMNS)
        self.lag_correlations[target_col] = results_df

        if plot and len(results_df) > 0:
            fig, axes = plt.subplots(1, 2, figsize=(15, 6))

            # Left panel: correlation at each lag, shaded toward zero.
            axes[0].plot(
                results_df['lag_months'],
                results_df['correlation'],
                marker='o',
                linewidth=2.5,
                markersize=10,
                color='#3498db',
                markerfacecolor='white',
                markeredgewidth=2
            )
            axes[0].axhline(y=0, color='red', linestyle='--', alpha=0.6, linewidth=1.5)
            axes[0].fill_between(
                results_df['lag_months'],
                0,
                results_df['correlation'],
                alpha=0.2,
                color='#3498db'
            )
            axes[0].set_xlabel('Lag (Months)', fontsize=12, fontweight='bold')
            axes[0].set_ylabel('Correlation with Real Returns', fontsize=12, fontweight='bold')
            axes[0].set_title('Correlation vs Lag Period', fontsize=13, fontweight='bold')
            axes[0].grid(True, alpha=0.3)

            # Right panel: R-squared at each lag.
            axes[1].plot(
                results_df['lag_months'],
                results_df['r_squared'],
                marker='s',
                linewidth=2.5,
                markersize=10,
                color='#2ecc71',
                markerfacecolor='white',
                markeredgewidth=2
            )
            axes[1].fill_between(
                results_df['lag_months'],
                0,
                results_df['r_squared'],
                alpha=0.2,
                color='#2ecc71'
            )
            axes[1].set_xlabel('Lag (Months)', fontsize=12, fontweight='bold')
            axes[1].set_ylabel('R-Squared', fontsize=12, fontweight='bold')
            axes[1].set_title('Explanatory Power vs Lag Period', fontsize=13, fontweight='bold')
            axes[1].grid(True, alpha=0.3)

            fig.suptitle('Lagged Inflation Effects on Real Returns',
                         fontsize=15, fontweight='bold', y=1.02)
            self._save_figure(fig, 'mechanism_lagged_effects_analysis.png')

        return results_df

    def analyze_volatility_channel(
        self,
        inflation_col: str = 'Inflation_Rate',
        return_col: str = 'Nominal_Return',
        window: int = 12,
        plot: bool = True
    ) -> Dict[str, float]:
        """
        Analyze whether inflation affects return volatility.

        Writes the rolling standard deviation of ``return_col`` to
        ``self.df['Return_Volatility']`` (overwriting any earlier value).

        Parameters
        ----------
        inflation_col : str, default='Inflation_Rate'
            Column name for inflation data
        return_col : str, default='Nominal_Return'
            Column name for returns
        window : int, default=12
            Rolling window (rows) for the volatility calculation
        plot : bool, default=True
            Whether to create and save a scatter plot with the fitted line

        Returns
        -------
        Dict[str, float]
            correlation, slope, intercept, r_squared, p_value and
            observations. Also stored in ``channel_analysis['volatility']``.

        Raises
        ------
        ValueError
            If fewer than two rows have both inflation and volatility present.
        """
        self.df['Return_Volatility'] = self.df[return_col].rolling(window).std()

        clean_data = self.df[[inflation_col, 'Return_Volatility']].dropna()
        fit = self._fit_line(clean_data[inflation_col], clean_data['Return_Volatility'])

        results = {
            key: fit[key]
            for key in ('correlation', 'slope', 'intercept', 'r_squared', 'p_value')
        }
        results['observations'] = len(clean_data)
        self.channel_analysis['volatility'] = results

        if plot:
            self._plot_scatter_fit(
                clean_data,
                inflation_col,
                'Return_Volatility',
                results,
                line_color='#9b59b6',
                ylabel=f'{window}-Month Rolling Volatility',
                title='Inflation vs Return Volatility Channel',
                filename='mechanism_inflation_volatility_channel.png',
                figsize=(12, 7),
            )

        return results

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    def run_full_analysis(
        self,
        inflation_col: str = 'Inflation_Rate',
        return_col: str = 'Real_Return',
        nominal_return_col: str = 'Nominal_Return',
        lags: Optional[List[int]] = None,
        verbose: bool = True,
        plot: bool = True,
        volatility_window: int = 12
    ) -> Dict:
        """
        Run complete inflation mechanisms analysis.

        Parameters
        ----------
        inflation_col : str, default='Inflation_Rate'
            Column name for inflation data
        return_col : str, default='Real_Return'
            Column name for real returns
        nominal_return_col : str, default='Nominal_Return'
            Column name for nominal returns (used for the volatility channel)
        lags : List[int], optional
            Lag periods to analyze
        verbose : bool, default=True
            Whether to print results
        plot : bool, default=True
            Whether each channel saves its figure to ``figures_dir``
        volatility_window : int, default=12
            Rolling window passed to ``analyze_volatility_channel``

        Returns
        -------
        Dict
            Keys 'direct_channel', 'lagged_effects' and 'volatility_channel'
        """
        results = {}

        if verbose:
            print("=" * 70)
            print("INFLATION MECHANISMS ANALYSIS")
            print("=" * 70)

        # 1. Direct inflation-return channel
        if verbose:
            print("\n1. Direct Inflation-Return Channel")
            print("-" * 70)

        direct_results = self.analyze_real_rate_channel(inflation_col, return_col, plot=plot)
        results['direct_channel'] = direct_results

        if verbose:
            print(f"  Correlation: {direct_results['correlation']:.4f}")
            print(f"  Slope: {direct_results['slope']:.6f}")
            print(f"  R-squared: {direct_results['r_squared']:.4f}")
            print(f"  P-value: {direct_results['p_value']:.4e}")
            print(f"  Observations: {direct_results['observations']}")

        # 2. Lagged effects
        if verbose:
            print("\n2. Lagged Effects Analysis")
            print("-" * 70)

        lag_results = self.analyze_lagged_effects(inflation_col, return_col, lags=lags, plot=plot)
        results['lagged_effects'] = lag_results

        if verbose:
            print("\n  Lag Structure:")
            for _, row in lag_results.iterrows():
                print(f"    {int(row['lag_months']):2d} months: "
                      f"corr={row['correlation']:7.4f}, "
                      f"R²={row['r_squared']:.4f}, "
                      f"p={row['p_value']:.4e}")

        # 3. Volatility channel
        if verbose:
            print("\n3. Volatility Channel")
            print("-" * 70)

        vol_results = self.analyze_volatility_channel(
            inflation_col,
            nominal_return_col,
            window=volatility_window,
            plot=plot
        )
        results['volatility_channel'] = vol_results

        if verbose:
            print(f"  Correlation: {vol_results['correlation']:.4f}")
            print(f"  Slope: {vol_results['slope']:.6f}")
            print(f"  R-squared: {vol_results['r_squared']:.4f}")
            print(f"  P-value: {vol_results['p_value']:.4e}")

        if verbose:
            self._print_insights(return_col, lag_results, direct_results, vol_results)

        return results

    def _print_insights(
        self,
        return_col: str,
        lag_results: pd.DataFrame,
        direct_results: Dict[str, float],
        vol_results: Dict[str, float]
    ) -> None:
        """Print the 'KEY INSIGHTS' summary for run_full_analysis."""
        print("\n" + "=" * 70)
        print("KEY INSIGHTS")
        print("=" * 70)

        # Strongest lag: skipped when no lag has a defined correlation.
        if len(lag_results) > 0:
            try:
                strongest = self.get_strongest_lag(return_col)
            except ValueError:
                strongest = None
            if strongest is not None:
                print(f"\n• Strongest lagged effect: {int(strongest['lag_months'])} months")
                print(f"  (correlation: {strongest['correlation']:.4f})")

        if direct_results['p_value'] < self.SIGNIFICANCE_LEVEL:
            direction = "negative" if direct_results['correlation'] < 0 else "positive"
            print(f"\n• Significant {direction} direct relationship")
        else:
            print("\n• No significant direct relationship")
        print(f"  (p-value: {direct_results['p_value']:.4f})")

        if vol_results['p_value'] < self.SIGNIFICANCE_LEVEL:
            direction = "increases" if vol_results['correlation'] > 0 else "decreases"
            print(f"\n• Inflation {direction} return volatility")
        else:
            print("\n• No significant volatility relationship")
        print(f"  (p-value: {vol_results['p_value']:.4f})")

    def get_strongest_lag(self, target_col: str = 'Real_Return') -> Dict:
        """
        Identify the lag period with the strongest relationship.

        "Strongest" means the largest absolute correlation; lags whose
        correlation is NaN are ignored.

        Parameters
        ----------
        target_col : str, default='Real_Return'
            Target variable to analyze

        Returns
        -------
        Dict
            lag_months (int), correlation, r_squared, p_value of that lag

        Raises
        ------
        ValueError
            If analyze_lagged_effects has not been run for ``target_col``, or
            if it produced no lag with a defined correlation.
        """
        if target_col not in self.lag_correlations:
            raise ValueError(
                f"No lag analysis found for {target_col}. "
                "Run analyze_lagged_effects first."
            )

        lag_df = self.lag_correlations[target_col]
        abs_corr = lag_df['correlation'].abs().dropna() if len(lag_df) else lag_df['correlation']
        if len(abs_corr) == 0:
            raise ValueError(
                f"No usable lag correlations for {target_col}: every lag had "
                "too few observations or an undefined correlation."
            )

        strongest_lag = lag_df.loc[abs_corr.idxmax()]

        return {
            'lag_months': int(strongest_lag['lag_months']),
            'correlation': strongest_lag['correlation'],
            'r_squared': strongest_lag['r_squared'],
            'p_value': strongest_lag['p_value']
        }


def load_project_data(
    data_path: Optional[Path] = None,
    project_dir_name: str = "Inflation vs Market Returns Analysis",
) -> pd.DataFrame:
    """
    Load the combined analysis data from the Week 1 project.

    Parameters
    ----------
    data_path : Path, optional
        Path to combined_analysis.csv. If None, the nearest directory named
        ``project_dir_name`` at or above the current working directory is
        used as the project root, and the file is read from
        ``<root>/data/processed/combined_analysis.csv``.
    project_dir_name : str, default="Inflation vs Market Returns Analysis"
        Folder name that identifies the project root when ``data_path`` is None.

    Returns
    -------
    pd.DataFrame
        The CSV with its ``Date`` column parsed as dates and used as the index.

    Raises
    ------
    FileNotFoundError
        If ``data_path`` is None and no ``project_dir_name`` directory exists
        at or above the current working directory (or if the CSV is missing).
    """
    if data_path is None:
        cwd = Path.cwd()
        # ``Path.parents`` stops at the filesystem root, so this search always
        # terminates; looping on ``.parent`` would not, since the root's
        # parent is the root itself.
        project_root = next(
            (candidate for candidate in (cwd, *cwd.parents)
             if candidate.name == project_dir_name),
            None,
        )
        if project_root is None:
            raise FileNotFoundError(
                f"Could not find a '{project_dir_name}' directory at or above "
                f"{cwd}; pass data_path explicitly."
            )
        data_path = project_root / "data" / "processed" / "combined_analysis.csv"

    return pd.read_csv(data_path, parse_dates=["Date"], index_col="Date")


def main():
    """
    Script entry point: load the Week 1 dataset, run every mechanism
    analysis, and report where the figures were written.

    Returns
    -------
    tuple
        ``(analyzer, results)`` -- the InflationMechanismAnalyzer instance
        and the dictionary returned by its ``run_full_analysis``.
    """
    thin_rule = "-" * 70
    thick_rule = "=" * 70

    print("\nLoading Week 1 Project Data...")
    print(thin_rule)

    data = load_project_data()
    first_date, last_date = data.index.min(), data.index.max()
    print(f"✓ Loaded {len(data)} observations")
    print(f"  Date range: {first_date} to {last_date}")
    print("  Columns: " + ", ".join(data.columns))

    mechanism_analyzer = InflationMechanismAnalyzer(data)
    analysis_results = mechanism_analyzer.run_full_analysis(verbose=True)

    for line in ("\n" + thick_rule, "✓ ANALYSIS COMPLETE", thick_rule):
        print(line)
    print(f"\nFigures saved to: {mechanism_analyzer.figures_dir}")

    return mechanism_analyzer, analysis_results


# Inside a Jupyter kernel __name__ is "__main__", so running this cell executes
# the analysis and binds `analyzer` and `results` as notebook globals.
if __name__ == "__main__":
    analyzer, results = main()


Loading Week 1 Project Data...
----------------------------------------------------------------------
✓ Loaded 130 observations
  Date range: 2012-05-31 00:00:00 to 2023-02-28 00:00:00
  Columns: Inflation_Rate, Nominal_Return, Real_Return, Cumulative_Nominal, Cumulative_Real, Inflation_Regime
INFLATION MECHANISMS ANALYSIS

1. Direct Inflation-Return Channel
----------------------------------------------------------------------
✓ Saved: C:\Users\rfull\Building Data Together Weeklies\Finance February\Inflation vs Market Returns Analysis\figures\mechanism_inflation_return_channel.png
  Correlation: -0.5614
  Slope: -1.269203
  R-squared: 0.3152
  P-value: 3.6967e-12
  Observations: 130

2. Lagged Effects Analysis
----------------------------------------------------------------------
✓ Saved: C:\Users\rfull\Building Data Together Weeklies\Finance February\Inflation vs Market Returns Analysis\figures\mechanism_lagged_effects_analysis.png

  Lag Structure:
     0 months: corr=-0.5614, R²=0