# In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score
import os

# ==========================================
# 1. SETUP & CONFIGURATION
# ==========================================
# NOTE: adjust the directory paths below to match your local checkout.
print(">> PHASE 5: GENERATING REPORT VISUALS...")

# CONFIGURATION
DATA_DIR = r'data/raw'
OUTPUT_DIR = r'data/processed'
IMG_DIR = r'data/visuals'

# Ensure the image directory exists.
# Creating it directly and handling FileExistsError avoids the race between an
# existence check and the creation call; the message is printed only when this
# run actually created the directory.
try:
    os.makedirs(IMG_DIR)
except FileExistsError:
    pass
else:
    print(f">> Created directory: {IMG_DIR}")

# Style settings (applied globally to every figure drawn afterwards)
sns.set_style("whitegrid")
plt.rcParams['font.family'] = 'sans-serif'

# ==========================================
# VISUAL A: THE J-CURVE (Conceptual)
# ==========================================
def plot_j_curve():
    """Render, save, and display the conceptual J-curve bar chart (Visual A).

    The three bar heights are hardcoded, illustrative coefficients rather than
    values estimated from data. Each bar is annotated with its signed value.
    The figure is written to ``IMG_DIR/11_VisualA_JCurve.png`` at 300 dpi,
    shown, and then closed.
    """
    print("\n>> Generating Visual A: J-Curve...")

    # Hardcoded coefficients for conceptual illustration
    periods = ['T=0\n(Disruption)', 'T=1\n(Adjustment)', 'T=2+\n(Payoff)']
    coefficients = [-0.18, -0.05, 0.22]
    # Red for negative impact, green for positive.
    colors = ['#e74c3c' if x < 0 else '#2ecc71' for x in coefficients]

    fig = plt.figure(figsize=(10, 6))

    # Mapping x to hue (with the legend suppressed) avoids seaborn's
    # deprecation warning about passing a palette without hue.
    ax = sns.barplot(x=periods, y=coefficients, hue=periods, palette=colors, legend=False)

    ax.tick_params(axis='x', pad=20)
    plt.axhline(0, color='black', linewidth=1.5)

    plt.title('The J-Curve Effect: The Hidden Cost of Transformation', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('Impact on Resilience Score (Beta Coefficient)', fontsize=12)
    plt.xlabel('Time Since Technology Implementation', fontsize=12)

    # Value labels: negative bars are labelled just below the bar end,
    # non-negative bars just above it.
    for i, v in enumerate(coefficients):
        if v < 0:
            offset, valign = -0.02, 'top'
        else:
            offset, valign = 0.01, 'bottom'
        ax.text(i, v + offset, f"{v:+.2f}", ha='center', va=valign,
                fontweight='bold', fontsize=11, color='black')

    plt.tight_layout()
    save_path = os.path.join(IMG_DIR, '11_VisualA_JCurve.png')
    plt.savefig(save_path, dpi=300)
    plt.show()
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)

# ==========================================
# VISUAL B: THE TWIN TEST (Bias Check)
# ==========================================
def plot_twin_test():
    """Render, save, and display the "Twin Test" bar chart (Visual B).

    Compares two hardcoded, conceptual correlation values (one per index
    version) and overlays an explanatory text box on each bar. The figure is
    written to ``IMG_DIR/12_VisualB_TwinTest.png`` at 300 dpi, shown, and then
    closed.
    """
    print("\n>> Generating Visual B: Twin Test...")

    # Conceptual data based on the report narrative (not computed here).
    data = {
        'Index_Version': ['V1: Policy Index\n(Stability Bias)', 'V2: Robust Index\n(Dynamic Reality)'],
        'Correlation': [0.45, -0.15]
    }
    df = pd.DataFrame(data)
    colors = ['#95a5a6', '#c0392b']  # Grey vs Red

    fig = plt.figure(figsize=(8, 6))
    sns.barplot(x='Index_Version', y='Correlation', data=df, hue='Index_Version', palette=colors, legend=False)

    plt.axhline(0, color='black', linewidth=1.5)
    plt.title('The Twin Test: Measurement Bias Hides Risk', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('Correlation with Digital Efficiency', fontsize=12)
    # Without an explicit label seaborn prints the raw column name
    # ('Index_Version') under the axis.
    plt.xlabel('Index Version', fontsize=12)

    # Annotations: x is the bar position (0 = V1, 1 = V2), y is in data units.
    plt.text(0, 0.2, "FALSE SECURITY\nV1 ignores transition costs",
             ha='center', color='white', weight='bold', bbox=dict(facecolor='black', alpha=0.5))
    plt.text(1, -0.08, "REALITY CHECK\nV2 captures the 'Dip'",
             ha='center', color='white', weight='bold', bbox=dict(facecolor='black', alpha=0.5))

    plt.tight_layout()
    save_path = os.path.join(IMG_DIR, '12_VisualB_TwinTest.png')
    plt.savefig(save_path, dpi=300)
    plt.show()
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)

# ==========================================
# VISUAL C: TRANSITION RISK SIMULATION
# ==========================================
def plot_transition_simulation():
    """Render, save, and display the transition-risk line chart (Visual C).

    Plots a hardcoded, illustrative ERI trajectory over consecutive quarters,
    shades quarters 1 through 5, draws a dashed baseline at 50, and adds two
    arrow callouts. The figure is written to
    ``IMG_DIR/13_VisualC_TransitionRisk.png`` at 300 dpi, shown, and then
    closed.
    """
    print("\n>> Generating Visual C: Transition Simulation...")

    eri_scores = [50, 48, 42, 40, 43, 48, 52, 58, 62]
    # Derive the x positions from the scores so the two cannot get out of sync.
    quarters = np.arange(len(eri_scores))

    fig = plt.figure(figsize=(12, 6))

    plt.plot(quarters, eri_scores, marker='o', linewidth=4, color='#2c3e50', label='Predicted Resilience (ERI)')
    plt.axvspan(1, 5, color='#e74c3c', alpha=0.15, label='The Transition Trap (Max Vulnerability)')
    plt.axhline(50, color='grey', linestyle='--', alpha=0.7, label='Pre-Shock Baseline')

    plt.title('Simulation: The "Transition Trap" in Distressed Sectors', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('Economic Resilience Index (ERI)', fontsize=12)
    plt.xlabel('Quarters Since Digital Implementation', fontsize=12)

    # Callouts: xy is the point the arrow targets, xytext is where the label sits.
    plt.annotate('Tech Implementation\n(Cash Outflow)', xy=(1, 48), xytext=(0.5, 55),
                 arrowprops=dict(facecolor='black', shrink=0.05), fontsize=10)
    plt.annotate('Structural Recovery\n(Efficiency Gains)', xy=(6, 52), xytext=(6, 45),
                 arrowprops=dict(facecolor='green', shrink=0.05), fontsize=10)

    plt.legend(loc='lower right')
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    save_path = os.path.join(IMG_DIR, '13_VisualC_TransitionRisk.png')
    plt.savefig(save_path, dpi=300)
    plt.show()
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)

# ==========================================
# VISUAL D: DIAGNOSTIC SCATTER (Updated for ALL 4 ARCHETYPES)
# ==========================================
def plot_diagnostic_scatter():
    """Render, save, and display the diagnostic scatter plot (Visual D).

    Builds a synthetic sample of 30 points for each of four sector archetypes.
    Within an archetype, predicted and actual scores share the same centre;
    predictions get a small normal noise term and actuals a larger one. The
    R² of predicted vs. actual is computed and embedded in the title, and a
    dashed y = x reference line is drawn on square, equal-scaled axes. The
    figure is written to ``IMG_DIR/14_VisualD_Diagnostic.png`` at 300 dpi,
    shown, and then closed.
    """
    print("\n>> Generating Visual D: Diagnostic Scatter (Conceptual V2)...")

    # 1. Generate synthetic data mimicking the "stubborn" V2 behaviour.
    # Logic: predictions are stable (low variance), actuals are volatile
    # (high variance).
    #
    # A private legacy RandomState seeded with 42 produces the same draw
    # sequence as np.random.seed(42) followed by np.random.normal, but does not
    # reseed the process-wide generator used by other code.
    rng = np.random.RandomState(42)
    n_per_group = 30

    # (archetype, shared centre, std of prediction noise, std of actual noise)
    archetype_specs = [
        ('All-Weather Star', 60, 1.5, 8),    # High Res, Low Vol
        ('Volatile Grower', 55, 1.5, 15),    # High Res, High Vol
        ('Safe Stagnator', 40, 1.0, 5),      # Low Res, Low Vol
        ('Distressed', 35, 1.5, 12),         # Low Res, High Vol
    ]

    data = []
    for archetype, centre, pred_std, actual_std in archetype_specs:
        for _ in range(n_per_group):
            # Keep the draw order (prediction first, then actual) so the
            # seeded sample stays the same.
            pred = centre + rng.normal(0, pred_std)
            actual = centre + rng.normal(0, actual_std)
            data.append([archetype, actual, pred])

    df = pd.DataFrame(data, columns=['Sector_Archetype', 'ERI_Actual', 'ERI_Predicted'])

    # R² of the flat predictions against the volatile actuals (shown in title).
    # NOTE(review): an earlier comment expected ~0.10-0.20 here; that range does
    # not obviously follow from these parameters, because the spread between
    # the group centres is itself explained by the predictions. Confirm against
    # the rendered value before quoting a number in the report.
    r2 = r2_score(df['ERI_Actual'], df['ERI_Predicted'])

    # 2. Plotting
    fig = plt.figure(figsize=(10, 8))

    # One colour per archetype
    palette = {
        'All-Weather Star': '#2ecc71',   # Green
        'Volatile Grower': '#f39c12',    # Orange
        'Safe Stagnator': '#3498db',     # Blue
        'Distressed': '#e74c3c'          # Red
    }

    sns.scatterplot(
        data=df, x='ERI_Actual', y='ERI_Predicted', hue='Sector_Archetype',
        palette=palette, s=150, alpha=0.8, edgecolor='black'
    )

    # Perfect-fit line (for reference). Both axes share one range, taken from
    # the extremes of the two score columns with a margin of 2.
    score_extremes = df[['ERI_Actual', 'ERI_Predicted']]
    limit_min = score_extremes.min().min() - 2
    limit_max = score_extremes.max().max() + 2
    plt.plot([limit_min, limit_max], [limit_min, limit_max], color='gray', linestyle='--', linewidth=2, label='Perfect Prediction (Ideal)')

    plt.title(f'CONCEPTUAL: Why V2 is "Stubborn"\n(Stable Predictions vs. Volatile Reality, $R^2={r2:.2f}$)', fontsize=14, fontweight='bold', pad=20)
    plt.xlabel('Actual Resilience Score (High Variance)', fontsize=12)
    plt.ylabel('Predicted Resilience Score (Stable/Flat)', fontsize=12)
    plt.legend(title='Sector Archetype', loc='upper left')

    # Force square aspect ratio so the y = x line sits at 45 degrees.
    plt.gca().set_aspect('equal', adjustable='box')
    plt.xlim(limit_min, limit_max)
    plt.ylim(limit_min, limit_max)

    plt.tight_layout()
    save_path = os.path.join(IMG_DIR, '14_VisualD_Diagnostic.png')
    plt.savefig(save_path, dpi=300)
    plt.show()
    # pyplot keeps figures alive until they are closed explicitly.
    plt.close(fig)

# ==========================================
# EXECUTION
# ==========================================
if __name__ == "__main__":
    # Script entry point: draw the four report visuals in order (A, B, C, D),
    # then print the completion banner.
    for render_visual in (
        plot_j_curve,
        plot_twin_test,
        plot_transition_simulation,
        plot_diagnostic_scatter,
    ):
        render_visual()
    print("\n>> PHASE 5 COMPLETE. All visuals saved.")