# Supply Chain Intelligence Agent (v2)

**Executive summary:** Production-grade conversational AI agent with 17 tools spanning
forecasting, risk analysis, supplier intelligence, contract search, macroeconomic
context, and executive briefing. Backed by Databricks Foundation Models with
conversation memory and MLflow tracing.

**Depends on:** All gold/silver/bronze tables from the v2 pipeline. Run ingestion,
transformation, and at least one forecasting notebook first.

### Capabilities
| Category | Tools | Data Sources |
|----------|-------|--------------|
| **Forecasting** | Demand forecast, model comparison, confidence scoring | Prophet, ARIMA, RF forecasts |
| **Analysis** | Anomaly detection, trend analysis, demand drivers | Demand signals, feature importance |
| **Scenarios** | Geopolitical, tariff, weather, custom what-if | Risk indices, commodity prices |
| **Intelligence** | Supplier lookup, contract search, DoD metrics, commodity prices | SAM entities, FPDS contracts |
| **Macro Context** | Economic indicators, trade barometer, supply chain pressure | World Bank, NY Fed, WTO |
| **Executive** | Supply chain health dashboard, executive briefing | All tables |

### Enhancements over notebooks 01 & 02
- All 17 tools consolidated (no duplication)
- **5 new tools**: supplier intelligence, contract search, macro context, supply chain health, executive briefing
- **Conversation memory**: multi-turn dialogue with context retention
- **Rich system prompt**: defense supply chain domain expertise
- **Graceful degradation**: missing tables handled without crashing
- **Interactive chat loop**: run as a conversational session in Databricks


## Setup


In [None]:
%pip install --upgrade "typing_extensions>=4.1" "langchain>=0.2,<0.4" "langchain-core>=0.2" langgraph databricks-langchain mlflow pandas numpy scikit-learn scipy


In [None]:
try:
    from typing_extensions import Sentinel
except ImportError:
    dbutils.library.restartPython()


In [None]:
from databricks_langchain import ChatDatabricks
_create_tool_calling_agent = None
_AgentExecutor = None
try:
    from langchain.agents import create_tool_calling_agent
    _create_tool_calling_agent = create_tool_calling_agent
except ImportError:
    try:
        from langchain.agents.tool_calling_agent.base import create_tool_calling_agent
        _create_tool_calling_agent = create_tool_calling_agent
    except (ModuleNotFoundError, ImportError):
        try:
            from langchain.agents import create_react_agent as create_tool_calling_agent
            _create_tool_calling_agent = create_tool_calling_agent
        except ImportError:
            pass
try:
    from langchain_core.agents import AgentExecutor
    _AgentExecutor = AgentExecutor
except ImportError:
    try:
        from langchain.agents import AgentExecutor
        _AgentExecutor = AgentExecutor
    except ImportError:
        pass

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from pyspark.sql import functions as F
from scipy import stats
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import mlflow
import json
import textwrap


## Configuration


In [None]:
CATALOG = "supply_chain"

# ── Table registry ───────────────────────────────────────────────────────────
TABLES = {
    # Gold
    "demand_signals":       f"{CATALOG}.gold.oshkosh_monthly_demand_signals",
    "dod_metrics":          f"{CATALOG}.gold.dod_metrics_inputs_monthly",
    "trade_risk":           f"{CATALOG}.gold.trade_tariff_risk_monthly",
    "prophet_forecasts":    f"{CATALOG}.gold.prophet_forecasts",
    "arima_forecasts":      f"{CATALOG}.gold.arima_forecasts",
    "rf_forecasts":         f"{CATALOG}.gold.random_forest_forecasts",
    "rf_feature_importance": f"{CATALOG}.gold.random_forest_feature_importance",
    # Silver
    "suppliers":            f"{CATALOG}.silver.supplier_geolocations",
    "commodity":            f"{CATALOG}.silver.commodity_prices_monthly",
    "weather":              f"{CATALOG}.silver.weather_risk_monthly",
    "tariff_events":        f"{CATALOG}.silver.trade_tariff_risk_events",
    # Bronze
    "fpds_contracts":       f"{CATALOG}.bronze.fpds_contracts",
    "prime_awards":         f"{CATALOG}.bronze.oshkosh_prime_award_actions",
    "subawards":            f"{CATALOG}.bronze.oshkosh_subawards",
    "wdi":                  f"{CATALOG}.bronze.worldbank_wdi",
    "wgi":                  f"{CATALOG}.bronze.worldbank_wgi",
    "gscpi":                f"{CATALOG}.bronze.nyfed_gscpi",
    "wto":                  f"{CATALOG}.bronze.wto_trade_barometer",
}


def _safe_load(table_key: str) -> pd.DataFrame:
    """Load a table by key, returning empty DataFrame if it doesn't exist."""
    table_name = TABLES.get(table_key, table_key)
    try:
        return spark.table(table_name).toPandas()
    except Exception:
        return pd.DataFrame()


print(f"Table registry: {len(TABLES)} tables configured")


## Initialize LLM


In [None]:
llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-3-70b-instruct",
    temperature=0.1,
    max_tokens=2000,
)

mlflow.langchain.autolog()
print("LLM initialized: databricks-meta-llama-3-3-70b-instruct")


---
# Tools: Forecasting


In [None]:
@tool
def get_demand_forecast(months_ahead: int = 3, include_confidence: bool = True) -> str:
    """
    Retrieve demand forecast for Oshkosh Defense contracts.

    Args:
        months_ahead: Number of months to forecast (1-12)
        include_confidence: Whether to include confidence intervals
    """
    try:
        df = _safe_load("prophet_forecasts")
        if df.empty:
            return "No forecast data available. Run the forecasting notebooks first."
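        df['month'] = pd.to_datetime(df['month'])
        # Sort chronologically so head() returns the nearest future months,
        # regardless of the row order returned by Spark.
        future = df[df['month'] > datetime.now()].sort_values('month').head(months_ahead)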
        if future.empty:
            return "No future forecast rows found. The forecast horizon may need extending."

        lines = [f"DEMAND FORECAST - Next {months_ahead} Months", "=" * 50, ""]
        total = 0.0
        for _, r in future.iterrows():
            m = r['month'].strftime('%B %Y')
            f_val = r['forecast_demand_usd']
            total += f_val
            lines.append(f"  {m}")
            lines.append(f"    Forecast: ${f_val:,.0f}")
            if include_confidence:
                lo = r.get('forecast_lower')
                hi = r.get('forecast_upper')
                if pd.notna(lo) and pd.notna(hi):
                    lines.append(f"    95% CI:   ${lo:,.0f} - ${hi:,.0f}")
            lines.append("")
        lines.append(f"TOTAL FORECAST:   ${total:,.0f}")
        lines.append(f"MONTHLY AVERAGE:  ${total / len(future):,.0f}")
        return "\n".join(lines)
    except Exception as e:
        return f"Error retrieving forecast: {e}"


In [None]:
@tool
def compare_forecast_models(months_ahead: int = 3) -> str:
    """
    Compare forecasts from Prophet, ARIMA, and Random Forest models side-by-side.

    Args:
        months_ahead: Number of months to compare (default 3)
    """
    try:
        now = datetime.now()
        models_data = {}
        for key, label in [("prophet_forecasts", "Prophet"), ("arima_forecasts", "ARIMA"), ("rf_forecasts", "Random Forest")]:
            df = _safe_load(key)
            if not df.empty:
                df['month'] = pd.to_datetime(df['month'])
                # Sort so row i refers to the same calendar month in every model.
                models_data[label] = df[df['month'] > now].sort_values('month').head(months_ahead)

        if not models_data:
            return "No forecast tables found. Run at least one forecasting notebook first."

        lines = [f"MULTI-MODEL FORECAST COMPARISON (next {months_ahead} months)", "=" * 60, ""]

        # Get months from first available model
        ref = list(models_data.values())[0]
        for i in range(min(months_ahead, len(ref))):
            month_str = ref.iloc[i]['month'].strftime('%B %Y')
            lines.append(f"  {month_str}")
            vals = []
            for label, mdf in models_data.items():
                if i < len(mdf):
                    v = mdf.iloc[i]['forecast_demand_usd']
                    vals.append(v)
                    lines.append(f"    {label:<15s} ${v:>12,.0f}")
            if len(vals) > 1:
                avg = np.mean(vals)
                spread = np.std(vals)
                lines.append(f"    {'Ensemble Avg':<15s} ${avg:>12,.0f}  (spread: {spread/avg*100:.1f}%)")
            lines.append("")

        # Agreement summary
        if len(models_data) > 1:
            spreads = []
            for i in range(min(months_ahead, len(ref))):
                vals = [mdf.iloc[i]['forecast_demand_usd'] for mdf in models_data.values() if i < len(mdf)]
                if len(vals) > 1:
                    spreads.append(np.std(vals) / np.mean(vals) * 100)
            if spreads:
                avg_sp = np.mean(spreads)
                level = "HIGH" if avg_sp < 5 else "MODERATE" if avg_sp < 15 else "LOW"
                lines.append(f"Model agreement: {level} (avg spread {avg_sp:.1f}%)")

        return "\n".join(lines)
    except Exception as e:
        return f"Error comparing models: {e}"
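
The agreement rating in `compare_forecast_models` is the coefficient of variation across models, averaged over the forecast months. The sketch below reproduces that arithmetic with the standard library only, on made-up forecast values. The helper name `agreement_level` is hypothetical and, unlike the tool, it truncates all models to the shortest horizon.

```python
from statistics import fmean, pstdev


def agreement_level(forecasts_by_model: dict[str, list[float]]) -> tuple[str, float]:
    """Return (level, average spread in %) across models, month by month."""
    if not forecasts_by_model:
        return "UNKNOWN", float("nan")
    horizon = min(len(series) for series in forecasts_by_model.values())
    spreads = []
    for i in range(horizon):
        vals = [series[i] for series in forecasts_by_model.values()]
        if len(vals) > 1:
            # pstdev is the population standard deviation, matching np.std's default.
            spreads.append(pstdev(vals) / fmean(vals) * 100)
    if not spreads:
        return "UNKNOWN", float("nan")
    avg = fmean(spreads)
    # Same thresholds as the tool: <5% HIGH, <15% MODERATE, otherwise LOW.
    level = "HIGH" if avg < 5 else "MODERATE" if avg < 15 else "LOW"
    return level, avg


# Illustrative numbers only, not real forecasts.
example = {
    "Prophet": [100.0, 110.0],
    "ARIMA": [104.0, 90.0],
    "Random Forest": [96.0, 100.0],
}
level, avg_spread = agreement_level(example)
print(f"Model agreement: {level} (avg spread {avg_spread:.1f}%)")
```

With these values the first month has a spread of about 3.3% and the second about 8.2%, so the average lands in the MODERATE band.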


In [None]:
@tool
def assess_forecast_confidence(months_ahead: int = 3) -> str:
    """
    Score forecast confidence based on volatility, trend strength, risk stability, and data recency.

    Args:
        months_ahead: Forecast horizon to assess (default 3)
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        df = df.sort_values('month')
        recent = df.tail(12)
        y = recent['total_obligations_usd'].values

        # Factor 1 - Volatility
        cv = y.std() / y.mean()
        vol_score = max(0, 100 - cv * 200)

        # Factor 2 - Trend consistency
        slope, _, r_value, p_value, _ = stats.linregress(np.arange(len(y)), y)
        trend_score = (r_value ** 2) * 100

        # Factor 3 - Risk environment stability
        risk_scores = []
        for col in ['geo_risk_index', 'tariff_risk_index', 'weather_disruption_index']:
            if col in recent.columns and recent[col].mean() > 0:
                risk_scores.append(max(0, 100 - (recent[col].std() / (recent[col].mean() + 1e-9)) * 100))
        risk_score = np.mean(risk_scores) if risk_scores else 50

        # Factor 4 - Data recency
        days_since = (datetime.now() - recent['month'].max()).days
        recency_score = max(0, 100 - days_since / 30 * 50)

        overall = vol_score * 0.35 + trend_score * 0.30 + risk_score * 0.25 + recency_score * 0.10

        label = "HIGH" if overall >= 80 else "MODERATE" if overall >= 60 else "LOW"

        lines = [
            f"FORECAST CONFIDENCE ASSESSMENT (next {months_ahead} months)",
            "=" * 60, "",
            f"  1. Historical Stability   {vol_score:5.0f}/100  (CV={cv:.3f})",
            f"  2. Trend Consistency       {trend_score:5.0f}/100  (R²={r_value**2:.3f}, p={p_value:.4f})",
            f"  3. Risk Env. Stability     {risk_score:5.0f}/100",
            f"  4. Data Recency            {recency_score:5.0f}/100  ({days_since} days since last)",
            "", "=" * 60,
            f"  OVERALL CONFIDENCE: {overall:.0f}/100 ({label})", "",
        ]
        if overall >= 80:
            lines.append("  Forecasts are reliable for planning. Standard safety stock appropriate.")
        elif overall >= 60:
            lines.append("  Forecasts are reasonable. Consider 15-20% safety buffer.")
        else:
            lines.append("  High uncertainty. Increase safety stock 30-50% and review model inputs.")
        return "\n".join(lines)
    except Exception as e:
        return f"Error assessing confidence: {e}"
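
The overall confidence score is a weighted sum of the four factor scores, with weights 0.35, 0.30, 0.25, and 0.10. As a rough worked example, the sketch below recomputes it from hand-picked inputs. The helper names and input values are illustrative assumptions; only the formulas and thresholds come from the tool above.

```python
# Weights copied from assess_forecast_confidence; they sum to 1.0.
WEIGHTS = {"volatility": 0.35, "trend": 0.30, "risk": 0.25, "recency": 0.10}


def volatility_score(cv: float) -> float:
    """Coefficient of variation mapped to 0-100 (cv >= 0.5 scores 0)."""
    return max(0.0, 100 - cv * 200)


def recency_score(days_since: int) -> float:
    """Loses 50 points per 30 days, so data older than 60 days scores 0."""
    return max(0.0, 100 - days_since / 30 * 50)


def overall_confidence(scores: dict[str, float]) -> tuple[float, str]:
    overall = sum(scores[name] * weight for name, weight in WEIGHTS.items())
    label = "HIGH" if overall >= 80 else "MODERATE" if overall >= 60 else "LOW"
    return overall, label


# Hypothetical inputs: CV of 0.15, R² of 0.50, risk stability 80, data 45 days old.
scores = {
    "volatility": volatility_score(0.15),  # 70
    "trend": 0.50 * 100,                   # 50
    "risk": 80.0,
    "recency": recency_score(45),          # 25
}
overall, label = overall_confidence(scores)
print(f"OVERALL CONFIDENCE: {overall:.0f}/100 ({label})")
```

Because recency carries only 10% of the weight, stale data alone moves the overall score by at most 10 points.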


---
# Tools: Analysis


In [None]:
@tool
def detect_anomalies(threshold_pct: float = 20.0, lookback_months: int = 6) -> str:
    """
    Detect demand anomalies by comparing recent months to the historical baseline.

    Args:
        threshold_pct: Percentage deviation to flag as anomaly (default 20%)
        lookback_months: Months to analyze (default 6)
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        df = df.sort_values('month')
        recent = df.tail(lookback_months)
        historical = df.iloc[:-lookback_months]
        if historical.empty:
            return "Not enough history for anomaly detection."
        baseline = historical['total_obligations_usd'].mean()

        lines = [
            "ANOMALY DETECTION REPORT",
            f"  Threshold: +/-{threshold_pct}%  |  Baseline: ${baseline:,.0f}",
            "=" * 50, "",
        ]
        count = 0
        for _, r in recent.iterrows():
            actual = r['total_obligations_usd']
            dev = (actual - baseline) / baseline * 100
            if abs(dev) > threshold_pct:
                count += 1
                sev = "CRITICAL" if abs(dev) > 50 else "HIGH" if abs(dev) > 30 else "MODERATE"
                lines.append(f"  [{sev}] {r['month'].strftime('%B %Y')}")
                lines.append(f"    Actual: ${actual:,.0f}  ({dev:+.1f}% vs baseline)")
                risk = r.get('combined_risk_index', 'N/A')
                lines.append(f"    Combined Risk Index: {risk}")
                lines.append("")

        if count == 0:
            lines.append("  No anomalies detected within the threshold.")
        else:
            lines.append(f"  TOTAL ANOMALIES: {count}")
        return "\n".join(lines)
    except Exception as e:
        return f"Error detecting anomalies: {e}"
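
The severity bands used by `detect_anomalies` can be checked in isolation. The following is a minimal sketch of the deviation and classification step, using a made-up baseline; `classify_deviation` is a hypothetical helper, not part of the notebook.

```python
from typing import Optional


def classify_deviation(
    actual: float, baseline: float, threshold_pct: float = 20.0
) -> tuple[float, Optional[str]]:
    """Return (deviation in %, severity), with severity None when within threshold."""
    dev = (actual - baseline) / baseline * 100
    if abs(dev) <= threshold_pct:
        return dev, None
    # Same bands as the tool: >50% CRITICAL, >30% HIGH, otherwise MODERATE.
    sev = "CRITICAL" if abs(dev) > 50 else "HIGH" if abs(dev) > 30 else "MODERATE"
    return dev, sev


baseline = 1_000_000.0  # illustrative historical mean, not real data
for actual in (1_100_000.0, 1_250_000.0, 600_000.0, 1_600_000.0):
    dev, sev = classify_deviation(actual, baseline)
    print(f"${actual:,.0f}: {dev:+.1f}% -> {sev or 'within threshold'}")
```

Note that the baseline is a plain mean of all months before the lookback window, so a long-run trend in the data will itself show up as deviation.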


In [None]:
@tool
def detect_trends(lookback_months: int = 12, trend_type: str = "ALL") -> str:
    """
    Detect growth trends, seasonality, volatility, and risk correlations in demand.

    Args:
        lookback_months: Months to analyze (default 12)
        trend_type: GROWTH, SEASONAL, VOLATILITY, CORRELATION, or ALL
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        df = df.sort_values('month')
        recent = df.tail(lookback_months)
        y = recent['total_obligations_usd'].values
        tt = trend_type.upper()

        lines = [f"TREND ANALYSIS (last {lookback_months} months)", "=" * 60, ""]

        if tt in ("GROWTH", "ALL"):
            slope, _, r_val, p_val, _ = stats.linregress(np.arange(len(y)), y)
            mg = slope / y.mean() * 100
            lines += [
                "GROWTH TREND",
                f"  Direction: {'Increasing' if slope > 0 else 'Decreasing'}",
                f"  Monthly growth: {mg:+.2f}%  |  Annualized: {mg*12:+.1f}%",
                f"  R²={r_val**2:.3f}, p={p_val:.4f} ({'significant' if p_val < 0.05 else 'not significant'})",
                "",
            ]

        if tt in ("SEASONAL", "ALL"):
            recent_copy = recent.copy()
            recent_copy['moy'] = recent_copy['month'].dt.month
            mavg = recent_copy.groupby('moy')['total_obligations_usd'].mean()
            amp = (mavg.max() - mavg.min()) / mavg.mean() * 100
            mn = ['Jan','Feb','Mar','Apr','May','Jun','Jul','Aug','Sep','Oct','Nov','Dec']
            lines += [
                "SEASONALITY",
                f"  Peak: {mn[mavg.idxmax()-1]} (${mavg.max():,.0f})",
                f"  Trough: {mn[mavg.idxmin()-1]} (${mavg.min():,.0f})",
                f"  Amplitude: {amp:.1f}%  ({'HIGH' if amp > 30 else 'MODERATE' if amp > 15 else 'LOW'})",
                "",
            ]

        if tt in ("VOLATILITY", "ALL"):
            cv = y.std() / y.mean()
            swings = (np.abs(np.diff(y) / y[:-1]) > 0.20).sum()
            lines += [
                "VOLATILITY",
                f"  Std Dev: ${y.std():,.0f}  |  CV: {cv:.3f}",
                f"  Large swings (>20%): {swings}",
                f"  Assessment: {'HIGH' if cv > 0.5 else 'MODERATE' if cv > 0.3 else 'LOW'}",
                "",
            ]

        if tt in ("CORRELATION", "ALL"):
            lines.append("RISK CORRELATIONS")
            for col, label in [('geo_risk_index','Geopolitical'), ('tariff_risk_index','Tariff'), ('commodity_cost_pressure','Commodity')]:
                if col in recent.columns:
                    corr = recent['total_obligations_usd'].corr(recent[col])
                    if pd.notna(corr):
                        strength = "Strong" if abs(corr) > 0.7 else "Moderate" if abs(corr) > 0.4 else "Weak"
                        lines.append(f"  {label}: {corr:+.3f} ({strength})")
            lines.append("")

        return "\n".join(lines)
    except Exception as e:
        return f"Error detecting trends: {e}"
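
The growth figures in `detect_trends` come from a least-squares slope divided by the mean, and the annualized figure is simply that monthly rate times 12. Below is a small standard-library sketch of the same calculation on made-up data; `monthly_growth_pct` is a hypothetical name, and the tool itself uses `scipy.stats.linregress`.

```python
def monthly_growth_pct(y: list[float]) -> float:
    """Least-squares slope of y against 0..n-1, as a percentage of the mean of y."""
    n = len(y)
    x_mean = (n - 1) / 2
    y_mean = sum(y) / n
    sxy = sum((i - x_mean) * (v - y_mean) for i, v in enumerate(y))
    sxx = sum((i - x_mean) ** 2 for i in range(n))
    return (sxy / sxx) / y_mean * 100


series = [100.0, 102.0, 104.0, 106.0, 108.0, 110.0]  # illustrative monthly values
mg = monthly_growth_pct(series)
print(f"Monthly growth: {mg:+.2f}%  |  Annualized: {mg * 12:+.1f}%")
```

The annualized value is a linear extrapolation; a compounded rate would come out somewhat higher for positive growth.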


In [None]:
@tool
def explain_demand_drivers(top_n: int = 10) -> str:
    """
    Explain which features drive demand using Random Forest feature importance.

    Args:
        top_n: Number of top features to show (default 10)
    """
    try:
        df = _safe_load("rf_feature_importance")
        if df.empty:
            return "Feature importance table not available. Run the Random Forest forecasting notebook first."
        df = df.sort_values('importance', ascending=False).head(top_n)

        interp = {
            'demand_lag_1m': 'Previous month demand (momentum)',
            'demand_rolling_mean_12m': 'Annual average demand (baseline)',
            'demand_rolling_mean_6m': '6-month average (recent trend)',
            'demand_rolling_mean_3m': '3-month average (short-term)',
            'geo_risk_index': 'Geopolitical risk level',
            'tariff_risk_index': 'Trade/tariff risk level',
            'commodity_cost_pressure': 'Material cost pressure',
            'weather_disruption_index': 'Weather disruptions',
            'is_q4': 'Fiscal year-end spending effect',
            'months_since_start': 'Long-term time trend',
            'demand_trend_3m': '3-month demand momentum',
            'demand_pct_change_1m': 'Month-over-month growth rate',
            'combined_risk_interaction': 'Combined risk multiplier',
        }

        lines = [f"DEMAND DRIVER ANALYSIS - Top {top_n} Features", "=" * 60, ""]
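        for pos, (_, r) in enumerate(df.iterrows(), start=1):
            f_name = r['feature']
            imp = r['importance']
            # Fall back to the sorted position if the table has no usable 'rank' column.
            rank = r['rank'] if 'rank' in r.index and pd.notna(r['rank']) else pos
            lines.append(f"  {int(rank):>2}. {f_name}")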
            lines.append(f"      Score: {imp:.4f}  |  {interp.get(f_name, 'Model feature')}")
            lines.append("")

        # Summary insight
        top_feat = df.iloc[0]['feature']
        if 'lag' in top_feat or 'rolling' in top_feat:
            lines.append("KEY INSIGHT: Historical patterns dominate -- demand shows strong momentum/seasonality.")
        elif 'risk' in top_feat:
            lines.append("KEY INSIGHT: Risk signals dominate -- demand is highly sensitive to geopolitical/trade factors.")
        elif 'commodity' in top_feat:
            lines.append("KEY INSIGHT: Commodity prices dominate -- material costs significantly influence demand.")

        return "\n".join(lines)
    except Exception as e:
        return f"Error analyzing demand drivers: {e}"


---
# Tools: Scenarios


In [None]:
@tool
def scenario_geopolitical_risk(risk_level: str = "HIGH", region: str = "ALL") -> str:
    """
    Analyze impact of a geopolitical risk scenario on defense demand.

    Args:
        risk_level: MODERATE, ELEVATED, HIGH, or CRITICAL
        region: EUROPE, MIDEAST, INDO_PACIFIC, AMERICAS, or ALL
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        # Sort by month first so tail(12) really is the most recent 12 months.
        baseline = df.sort_values('month').tail(12)['total_obligations_usd'].mean()

        mult = {"MODERATE": 1.0, "ELEVATED": 1.15, "HIGH": 1.35, "CRITICAL": 1.75}.get(risk_level.upper(), 1.0)
        proj = baseline * mult
        inc = proj - baseline

        lines = [
            "GEOPOLITICAL RISK SCENARIO",
            "=" * 50,
            f"  Level: {risk_level.upper()}  |  Region: {region.upper()}",
            f"  Demand multiplier: {mult:.2f}x", "",
            f"  Baseline monthly: ${baseline:,.0f}",
            f"  Projected monthly: ${proj:,.0f}  ({(mult-1)*100:+.0f}%)",
            f"  Annual impact: ${inc * 12:,.0f}", "",
            "RECOMMENDED ACTIONS:",
        ]
        if risk_level.upper() == "CRITICAL":
            lines += ["  - Activate surge capacity", "  - Increase safety stock 75%", "  - Expedite critical orders", "  - Qualify backup suppliers"]
        elif risk_level.upper() == "HIGH":
            lines += ["  - Increase safety stock 35%", "  - Accelerate procurement", "  - Monitor supplier capacity"]
        elif risk_level.upper() == "ELEVATED":
            lines += ["  - Increase safety stock 15%", "  - Review contingency plans"]
        else:
            lines += ["  - Continue normal operations"]
        return "\n".join(lines)
    except Exception as e:
        return f"Error: {e}"


In [None]:
@tool
def scenario_tariff_increase(tariff_increase_pct: float = 25.0, product_category: str = "ALL") -> str:
    """
    Analyze cost impact of tariff increases on supply chain.

    Args:
        tariff_increase_pct: Percentage tariff increase (default 25%)
        product_category: VEHICLES, ELECTRONICS, STEEL, ALUMINUM, or ALL
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        # Sort by month first so tail(12) really is the most recent 12 months.
        baseline = df.sort_values('month').tail(12)['total_obligations_usd'].mean()

        import_pct = 0.30
        imported = baseline * import_pct
        tariff_cost = imported * tariff_increase_pct / 100
        total_pct = tariff_cost / baseline * 100

        lines = [
            "TARIFF INCREASE SCENARIO",
            "=" * 50,
            f"  Tariff increase: {tariff_increase_pct}%  |  Category: {product_category.upper()}", "",
            f"  Monthly spend: ${baseline:,.0f}",
            f"  Import content: {import_pct*100:.0f}% (${imported:,.0f})",
            f"  Additional tariff cost: ${tariff_cost:,.0f}/mo ({total_pct:.1f}% of total)",
            f"  Annual impact: ${tariff_cost * 12:,.0f}", "",
            "MITIGATION:",
        ]
        if tariff_increase_pct >= 50:
            lines += ["  - Evaluate domestic sourcing", "  - Lock long-term contracts now", "  - Apply for tariff exclusions"]
        elif tariff_increase_pct >= 25:
            lines += ["  - Accelerate supplier diversification", "  - Review make-vs-buy", "  - Explore bonded warehouses"]
        else:
            lines += ["  - Monitor developments", "  - Update cost models"]
        return "\n".join(lines)
    except Exception as e:
        return f"Error: {e}"
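
The tariff scenario is plain arithmetic on the baseline spend and a fixed 30% import share. As a worked example, assuming a made-up monthly spend of $10M, the sketch below walks through the same steps; `tariff_impact` is a hypothetical helper.

```python
def tariff_impact(
    monthly_spend: float, tariff_increase_pct: float, import_share: float = 0.30
) -> dict[str, float]:
    """Mirror the scenario arithmetic: cost applies only to the imported share."""
    imported = monthly_spend * import_share
    monthly_cost = imported * tariff_increase_pct / 100
    return {
        "imported": imported,
        "monthly_cost": monthly_cost,
        "pct_of_spend": monthly_cost / monthly_spend * 100,
        "annual_cost": monthly_cost * 12,
    }


result = tariff_impact(monthly_spend=10_000_000.0, tariff_increase_pct=25.0)
print(f"Import content: ${result['imported']:,.0f}")
print(f"Additional tariff cost: ${result['monthly_cost']:,.0f}/mo "
      f"({result['pct_of_spend']:.1f}% of total)")
print(f"Annual impact: ${result['annual_cost']:,.0f}")
```

The percentage of total spend always equals `import_share * tariff_increase_pct` (7.5% here), whatever the baseline. In the tool, `product_category` only appears in the report label and does not change the numbers.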


In [None]:
@tool
def scenario_weather_disruption(disruption_type: str = "SEVERE_WINTER", affected_region: str = "MIDWEST") -> str:
    """
    Analyze supply chain impact of weather disruptions.

    Args:
        disruption_type: SEVERE_WINTER, HURRICANE, FLOODING, or EXTREME_HEAT
        affected_region: MIDWEST, SOUTHEAST, GULF_COAST, or WEST_COAST
    """
    try:
        impacts = {
            "SEVERE_WINTER": {"delay": 7, "prod_impact": 15, "subsystems": ["POWERTRAIN", "SUSPENSION", "MATERIALS"]},
            "HURRICANE":     {"delay": 14, "prod_impact": 25, "subsystems": ["ELECTRONICS", "TIRES", "ARMOR"]},
            "FLOODING":      {"delay": 10, "prod_impact": 20, "subsystems": ["MATERIALS", "HYDRAULICS", "ELECTRICAL"]},
            "EXTREME_HEAT":  {"delay": 3,  "prod_impact": 10, "subsystems": ["RUBBER", "ELECTRONICS"]},
        }
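        # Unknown disruption types fall back to the SEVERE_WINTER profile; report
        # the profile actually used so the label matches the numbers.
        event = disruption_type.upper()
        if event not in impacts:
            event = "SEVERE_WINTER"
        imp = impacts[event]

        lines = [
            "WEATHER DISRUPTION SCENARIO",
            "=" * 50,
            f"  Event: {event.replace('_', ' ')}  |  Region: {affected_region.upper().replace('_', ' ')}", "",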
            f"  Transport delays: {imp['delay']} days",
            f"  Production impact: {imp['prod_impact']}% reduction",
            f"  Affected subsystems: {', '.join(imp['subsystems'])}", "",
            "CONTINGENCY ACTIONS:",
        ]
        if imp['delay'] >= 10:
            lines += ["  - Activate emergency logistics", "  - Pre-position critical inventory", "  - Engage backup carriers"]
        else:
            lines += ["  - Monitor forecasts", "  - Communicate with suppliers", "  - Review safety stock"]
        lines += ["", f"  Expected recovery: {imp['delay'] + 7} days"]
        return "\n".join(lines)
    except Exception as e:
        return f"Error: {e}"


In [None]:
@tool
def build_whatif_scenario(
    geo_risk_change_pct: float = 0.0,
    tariff_change_pct: float = 0.0,
    commodity_price_change_pct: float = 0.0,
    demand_shock_pct: float = 0.0,
) -> str:
    """
    Build a custom what-if scenario by adjusting multiple risk factors simultaneously.

    Args:
        geo_risk_change_pct: % change in geopolitical risk (e.g. 50 for +50%)
        tariff_change_pct: % change in tariffs
        commodity_price_change_pct: % change in commodity prices
        demand_shock_pct: Direct demand shock (e.g. -10 for -10%)
    """
    try:
        df = _safe_load("demand_signals")
        if df.empty:
            return "Demand signals table not available."
        df['month'] = pd.to_datetime(df['month'])
        # Sort by month first so tail(12) really is the most recent 12 months.
        recent = df.sort_values('month').tail(12)
        baseline = recent['total_obligations_usd'].mean()

        # Simplified elasticity model
        geo_eff    = geo_risk_change_pct / 100 * 0.30      # risk up -> defense spend up
        tariff_eff = tariff_change_pct / 100 * -0.05       # tariff up -> cost pressure
        comm_eff   = commodity_price_change_pct / 100 * -0.02
        direct     = demand_shock_pct / 100
        total_pct  = (geo_eff + tariff_eff + comm_eff + direct) * 100
        projected  = baseline * (1 + total_pct / 100)
        change     = projected - baseline

        lines = [
            "CUSTOM WHAT-IF SCENARIO",
            "=" * 60, "",
            f"  Baseline monthly demand: ${baseline:,.0f}", "",
            "  Adjustments:",
        ]
        for lbl, val in [("Geopolitical risk", geo_risk_change_pct), ("Tariff risk", tariff_change_pct),
                         ("Commodity prices", commodity_price_change_pct), ("Direct demand shock", demand_shock_pct)]:
            if val != 0:
                lines.append(f"    {lbl}: {val:+.0f}%")
        lines += [
            "", "  Impact breakdown:",
            f"    Geopolitical effect: {geo_eff*100:+.2f}%",
            f"    Tariff effect:      {tariff_eff*100:+.2f}%",
            f"    Commodity effect:    {comm_eff*100:+.2f}%",
            f"    Direct shock:        {direct*100:+.2f}%",
            f"    {'─'*35}",
            f"    NET DEMAND IMPACT:   {total_pct:+.2f}%", "",
            f"  Projected monthly: ${projected:,.0f}  (${change:+,.0f})",
            f"  Annual impact:     ${change * 12:,.0f}",
        ]
        return "\n".join(lines)
    except Exception as e:
        return f"Error: {e}"
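
The what-if tool combines its inputs additively, using fixed elasticities of 0.30 (geopolitical), -0.05 (tariff), and -0.02 (commodity) plus the direct shock. Here is a minimal sketch of that calculation with illustrative inputs; the function name and the $10M baseline are assumptions for the example.

```python
# Elasticities copied from build_whatif_scenario.
ELASTICITY = {"geo": 0.30, "tariff": -0.05, "commodity": -0.02}


def net_demand_impact_pct(
    geo_pct: float = 0.0,
    tariff_pct: float = 0.0,
    commodity_pct: float = 0.0,
    shock_pct: float = 0.0,
) -> float:
    """Net demand change in %, summing each factor's effect linearly."""
    return (
        geo_pct * ELASTICITY["geo"]
        + tariff_pct * ELASTICITY["tariff"]
        + commodity_pct * ELASTICITY["commodity"]
        + shock_pct
    )


baseline = 10_000_000.0  # illustrative monthly demand, not real data
net = net_demand_impact_pct(geo_pct=50, tariff_pct=25, commodity_pct=10, shock_pct=-5)
projected = baseline * (1 + net / 100)
print(f"NET DEMAND IMPACT: {net:+.2f}%")   # 15.00 - 1.25 - 0.20 - 5.00
print(f"Projected monthly: ${projected:,.0f}")
```

Because the effects are summed rather than multiplied, interactions between factors are ignored; that keeps the breakdown easy to read but makes large combined shocks approximate.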


---
# Tools: Intelligence (NEW)


In [None]:
@tool
def query_suppliers(
    name: str = "",
    state: str = "",
    subsystem: str = "",
    company_size: str = "",
    top_n: int = 15,
) -> str:
    """
    Search the defense supplier base from SAM.gov entity data.

    Args:
        name: Partial supplier name to search (case-insensitive)
        state: US state abbreviation to filter (e.g. WI, TX, CA)
        subsystem: Subsystem category (e.g. ARMOR, POWERTRAIN, ELECTRONICS)
        company_size: Company size filter (e.g. SMALL, LARGE)
        top_n: Max results to return (default 15)
    """
    try:
        df = _safe_load("suppliers")
        if df.empty:
            return "Supplier table not available. Run 04_sam_entity_ingestion_v2 first."

        mask = pd.Series(True, index=df.index)
        if name:
            mask &= df['supplier_name'].str.contains(name, case=False, na=False)
        if state:
            mask &= df['state'].str.upper() == state.upper()
        if subsystem:
            mask &= df['subsystem_category'].str.contains(subsystem, case=False, na=False)
        if company_size:
            mask &= df['company_size'].str.contains(company_size, case=False, na=False)

        filtered = df[mask].head(top_n)

        if filtered.empty:
            return f"No suppliers found matching criteria (name={name!r}, state={state!r}, subsystem={subsystem!r}, size={company_size!r})."

        lines = [f"SUPPLIER SEARCH RESULTS ({len(filtered)} of {mask.sum()} matches)", "=" * 60, ""]
        for _, r in filtered.iterrows():
            lines.append(f"  {r['supplier_name']}")
            lines.append(f"    UEI: {r.get('uei','N/A')}  |  CAGE: {r.get('cage_code','N/A')}")
            lines.append(f"    Location: {r.get('city','')}, {r.get('state','')}, {r.get('country','')}")
            lines.append(f"    Subsystem: {r.get('subsystem_category','N/A')}  |  Size: {r.get('company_size','N/A')}")
            dist = r.get('distance_to_nearest_oshkosh_facility_km')
            if pd.notna(dist):
                lines.append(f"    Distance to nearest Oshkosh facility: {dist:,.0f} km")
            lines.append(f"    NAICS: {r.get('naics_code_primary','N/A')}")
            lines.append("")

        # Summary stats
        if mask.sum() > 0:
            full = df[mask]
            lines.append("  SUMMARY:")
            lines.append(f"    Total matching: {mask.sum()}")
            if 'company_size' in full.columns:
                sizes = full['company_size'].value_counts()
                lines.append(f"    By size: {dict(sizes)}")
            if 'region_group' in full.columns:
                regions = full['region_group'].value_counts().head(5)
                lines.append(f"    Top regions: {dict(regions)}")

        return "\n".join(lines)
    except Exception as e:
        return f"Error querying suppliers: {e}"


In [None]:
@tool
def search_contracts(
    vendor_name: str = "",
    psc_code: str = "",
    fiscal_year: int = 0,
    min_amount: float = 0,
    top_n: int = 15,
) -> str:
    """
    Search FPDS contract data for defense contracts.

    Args:
        vendor_name: Partial vendor name (case-insensitive)
        psc_code: Product Service Code prefix (e.g. '23' for vehicles)
        fiscal_year: Filter by fiscal year (e.g. 2024). 0 = all years
        min_amount: Minimum obligated amount in USD
        top_n: Max results to return (default 15)
    """
    try:
        df = _safe_load("fpds_contracts")
        if df.empty:
            return "FPDS contracts table not available. Run 02_fpds_ingestion_v2 first."

        mask = pd.Series(True, index=df.index)
        if vendor_name:
            mask &= df['vendor_name'].str.contains(vendor_name, case=False, na=False)
        if psc_code:
            mask &= df['psc_code'].astype(str).str.startswith(psc_code)
        if fiscal_year > 0 and 'fiscal_year' in df.columns:
            mask &= df['fiscal_year'] == fiscal_year
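        # Coerce amounts to numeric up front so sorting works even when no
        # min_amount filter is requested, and skip the sort if the column is missing.
        has_amount = 'obligated_amount' in df.columns
        if has_amount:
            df['obligated_amount'] = pd.to_numeric(df['obligated_amount'], errors='coerce')
            if min_amount > 0:
                mask &= df['obligated_amount'] >= min_amount

        filtered = df[mask]
        if has_amount:
            filtered = filtered.sort_values('obligated_amount', ascending=False)
        filtered = filtered.head(top_n)

        if filtered.empty:
            return "No contracts found matching criteria."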

        lines = [f"CONTRACT SEARCH RESULTS ({len(filtered)} of {mask.sum()} matches)", "=" * 60, ""]
        total_val = 0
        for _, r in filtered.iterrows():
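            # NaN is truthy, so `value or 0` would let NaN poison the running total.
            raw = r.get('obligated_amount')
            amt = 0.0 if pd.isna(raw) else float(raw)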
            total_val += amt
            lines.append(f"  Contract: {r.get('contract_id', 'N/A')}")
            lines.append(f"    Vendor: {r.get('vendor_name', 'N/A')}")
            lines.append(f"    Amount: ${amt:,.0f}  |  FY: {r.get('fiscal_year', 'N/A')}")
            lines.append(f"    PSC: {r.get('psc_code', 'N/A')}  |  NAICS: {r.get('naics_code', 'N/A')}")
            desc = str(r.get('description', ''))[:100]
            if desc and desc != 'None':
                lines.append(f"    Desc: {desc}")
            lines.append(f"    Signed: {r.get('signed_date', 'N/A')}  |  Completion: {r.get('ultimate_completion_date', 'N/A')}")
            lines.append("")

        lines.append(f"  TOTAL VALUE (shown): ${total_val:,.0f}")
        lines.append(f"  ALL MATCHES: {mask.sum()} contracts")
        return "\n".join(lines)
    except Exception as e:
        return f"Error searching contracts: {e}"


In [None]:
@tool
def compare_dod_metrics(metric_type: str = "ALL") -> str:
    """
    Compare current performance against DoD supply chain metrics (RO, AAO, Days of Supply, NMCS Risk).

    Args:
        metric_type: RO, AAO, DAYS_OF_SUPPLY, NMCS_RISK, or ALL
    """
    try:
        df = _safe_load("dod_metrics")
        if df.empty:
            return "DoD metrics table not available. Run transformation notebooks first."
        df['month'] = pd.to_datetime(df['month'])
        latest = df.sort_values('month').iloc[-1]
        mt = metric_type.upper()

        lines = [
            "DoD SUPPLY CHAIN METRICS",
            "=" * 50,
            f"  As of: {latest['month'].strftime('%B %Y')}", "",
        ]

        if mt in ("RO", "ALL"):
            ro = latest['requirements_objective_proxy']
            rro = latest['risk_adjusted_ro']
            lines += [f"  REQUIREMENTS OBJECTIVE (RO)", f"    Current: ${ro:,.0f}  |  Risk-Adjusted: ${rro:,.0f}", ""]

        if mt in ("AAO", "ALL"):
            aao = latest['approved_acquisition_objective_proxy']
            lines += [f"  APPROVED ACQUISITION OBJECTIVE (AAO)", f"    Current: ${aao:,.0f}  (includes 2-yr forecast)", ""]

        if mt in ("DAYS_OF_SUPPLY", "ALL"):
            dos = latest['days_of_supply_proxy']
            status = "Healthy" if dos >= 60 else "Monitor" if dos >= 30 else "CRITICAL"
            lines += [f"  DAYS OF SUPPLY", f"    Current: {dos:.0f} days  |  Target: 60+  |  Status: {status}", ""]

        if mt in ("NMCS_RISK", "ALL"):
            nmcs = latest['nmcs_risk_indicator']
            lines += [f"  NMCS RISK", f"    Level: {nmcs}", ""]

        if mt == "ALL":
            vol = latest.get('demand_volatility_category', 'N/A')
            cv = latest.get('coefficient_of_variation', 0)
            rec = latest.get('forecast_method_recommendation', 'N/A')
            lines += [f"  DEMAND VOLATILITY", f"    Category: {vol}  |  CV: {cv:.2f}  |  Recommended model: {rec}"]

        return "\n".join(lines)
    except Exception as e:
        return f"Error retrieving DoD metrics: {e}"


In [None]:
@tool
def get_commodity_prices(category: str = "ALL") -> str:
    """
    Get current defense-critical commodity prices with trends.

    Args:
        category: ENERGY, PRECIOUS_METALS, INDUSTRIAL_METALS, BATTERY_MATERIALS, or ALL
    """
    try:
        df = _safe_load("commodity")
        if df.empty:
            return "Commodity prices table not available."
        df['month'] = pd.to_datetime(df['month'])
        latest_month = df['month'].max()
        latest = df[df['month'] == latest_month]

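        if category.upper() != "ALL":
            latest = latest[latest['category'].str.upper() == category.upper()]
            if latest.empty:
                # Without this guard the summary below would report a NaN score as FAVORABLE
                return f"No commodity prices found for category '{category}'."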

        lines = [
            "DEFENSE MATERIALS PRICE MONITOR",
            "=" * 50,
            f"  As of: {latest_month.strftime('%B %Y')}", "",
        ]

        for cat in latest['category'].unique():
            cat_data = latest[latest['category'] == cat]
            lines.append(f"  {cat.upper()}")
            for _, r in cat_data.iterrows():
                lines.append(f"    {r['commodity_name']}: ${r['close_price']:,.2f}")
                lines.append(f"      1mo: {r.get('pct_change_1mo', 0):+.1f}%  |  3mo: {r.get('pct_change_3mo', 0):+.1f}%  |  Use: {r.get('defense_use', 'N/A')}")
            lines.append("")

        avg_pressure = latest['cost_pressure_score'].mean()
        level = "HIGH" if avg_pressure > 10 else "MODERATE" if avg_pressure > 0 else "FAVORABLE"
        lines.append(f"  COST PRESSURE: {level} (score: {avg_pressure:.1f})")
        return "\n".join(lines)
    except Exception as e:
        return f"Error retrieving commodity prices: {e}"


In [None]:
@tool
def get_macro_context(indicator: str = "ALL") -> str:
    """
    Get macroeconomic context from World Bank, NY Fed GSCPI, and WTO data.

    Args:
        indicator: GDP, INFLATION, TRADE, GSCPI, WTO, or ALL
    """
    try:
        ind = indicator.upper()
        lines = ["MACROECONOMIC CONTEXT", "=" * 60, ""]

        # NY Fed Global Supply Chain Pressure Index
        if ind in ("GSCPI", "ALL"):
            gscpi = _safe_load("gscpi")
            if not gscpi.empty:
                if 'as_of_date' in gscpi.columns:
                    gscpi['as_of_date'] = pd.to_datetime(gscpi['as_of_date'])
                    gscpi = gscpi.sort_values('as_of_date')
                latest = gscpi.tail(1).iloc[0] if len(gscpi) > 0 else None
                if latest is not None:
                    val = latest.get('value', 0)
                    date = latest.get('as_of_date', 'N/A')
                    if hasattr(date, 'strftime'):
                        date = date.strftime('%Y-%m')
                    level = "ELEVATED" if float(val) > 1 else "NORMAL" if float(val) > -1 else "LOW"
                    lines += [
                        "  NY FED GLOBAL SUPPLY CHAIN PRESSURE INDEX (GSCPI)",
                        f"    Latest value: {float(val):.2f}  ({date})",
                        f"    Level: {level}",
                        f"    (0 = historical avg; >1 = elevated pressure; <-1 = low pressure)", "",
                    ]
                    # Trend: compare the latest reading with the one two observations earlier
                    if len(gscpi) >= 3:
                        recent_vals = gscpi.tail(3)['value'].astype(float).values
                        delta = recent_vals[-1] - recent_vals[0]
                        trend = "INCREASING" if delta > 0 else "DECREASING" if delta < 0 else "FLAT"
                        lines.append(f"    Trend: {trend} (3-month)")
                        lines.append("")

        # WTO Trade Barometer
        if ind in ("WTO", "TRADE", "ALL"):
            wto = _safe_load("wto")
            if not wto.empty:
                if 'as_of_date' in wto.columns:
                    wto['as_of_date'] = pd.to_datetime(wto['as_of_date'])
                    wto = wto.sort_values('as_of_date')
                latest_wto = wto.tail(5)
                if len(latest_wto) > 0:
                    lines.append("  WTO GOODS TRADE BAROMETER")
                    for _, r in latest_wto.iterrows():
                        name = r.get('indicator_name', r.get('indicator_code', 'N/A'))
                        val = r.get('value', 'N/A')
                        date = r.get('as_of_date', 'N/A')
                        if hasattr(date, 'strftime'):
                            date = date.strftime('%Y-%m')
                        lines.append(f"    {name}: {val}  ({date})")
                    lines.append(f"    (100 = trend; >100 = above-trend growth)")
                    lines.append("")

        # World Bank indicators
        if ind in ("GDP", "INFLATION", "ALL"):
            wdi = _safe_load("wdi")
            if not wdi.empty and 'indicator_code' in wdi.columns:
                lines.append("  WORLD BANK DEVELOPMENT INDICATORS")
                for code, label in [("NY.GDP.MKTP.KD.ZG", "GDP Growth"), ("FP.CPI.TOTL.ZG", "Inflation (CPI)")]:
                    if ind != "ALL" and label.split()[0].upper() != ind:
                        continue
                    subset = wdi[wdi['indicator_code'] == code]
                    if not subset.empty:
                        if 'as_of_date' in subset.columns:
                            subset = subset.copy()
                            subset['as_of_date'] = pd.to_datetime(subset['as_of_date'])
                            subset = subset.sort_values('as_of_date')
                        latest_row = subset.tail(1).iloc[0]
                        val = latest_row.get('value', 'N/A')
                        date = latest_row.get('as_of_date', 'N/A')
                        if hasattr(date, 'strftime'):
                            date = date.strftime('%Y')
                        lines.append(f"    {label}: {val}%  ({date})")
                lines.append("")

        if len(lines) <= 3:
            lines.append("  No macroeconomic data available. Run ingestion notebooks 09-13 first.")

        return "\n".join(lines)
    except Exception as e:
        return f"Error retrieving macro context: {e}"
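
The GSCPI is expressed in standard deviations from its historical average, which is why `get_macro_context` uses fixed cut-offs at +1 and -1. The sketch below isolates that classification and a three-point trend check so the thresholds can be tested without loading any table. The function names `gscpi_level` and `three_point_trend` are hypothetical and not part of the notebook, and the "FLAT" and "INSUFFICIENT DATA" labels are assumptions added for illustration.

```python
from typing import Sequence


def gscpi_level(value: float) -> str:
    """Classify a GSCPI reading using the +/-1 standard deviation cut-offs."""
    return "ELEVATED" if value > 1 else "NORMAL" if value > -1 else "LOW"


def three_point_trend(values: Sequence[float]) -> str:
    """Compare the latest reading with the one two observations earlier."""
    if len(values) < 3:
        return "INSUFFICIENT DATA"  # assumed label, not used by the notebook
    first, last = values[-3], values[-1]
    if last > first:
        return "INCREASING"
    if last < first:
        return "DECREASING"
    return "FLAT"  # assumed label for an unchanged reading


readings = [0.2, 0.1, 0.5]  # made-up values for illustration
print(gscpi_level(readings[-1]), three_point_trend(readings))
```

Note that a reading of exactly -1 falls into "LOW" and exactly +1 into "NORMAL", because both comparisons are strict.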


---
# Tools: Dashboard (NEW)


In [None]:
@tool
def get_supply_chain_health() -> str:
    """
    Generate a holistic supply chain health dashboard scoring six dimensions (0-100 each):
    demand stability, risk environment, supplier base, material costs, forecast
    reliability, and DoD readiness. No arguments required.
    """
    try:
        scores = {}

        # Dimension 1: Demand stability
        demand = _safe_load("demand_signals")
        if not demand.empty:
            demand['month'] = pd.to_datetime(demand['month'])
            recent = demand.sort_values('month').tail(6)
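            mean_demand = recent['total_obligations_usd'].mean()
            cv = recent['total_obligations_usd'].std() / mean_demand if mean_demand > 0 else float('nan')
            # A NaN CV (zero mean or a single row) must not be clamped into a perfect score
            scores['Demand Stability'] = max(0, min(100, 100 - cv * 200)) if pd.notna(cv) else None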
        else:
            scores['Demand Stability'] = None

        # Dimension 2: Risk environment
        if not demand.empty:
            recent = demand.sort_values('month').tail(3)
            risk_cols = [c for c in ['geo_risk_index', 'tariff_risk_index', 'weather_disruption_index', 'commodity_cost_pressure']
                         if c in recent.columns]
            if risk_cols:
                avg_risk = recent[risk_cols].mean().mean()
                scores['Risk Environment'] = max(0, min(100, 100 - avg_risk * 10))
            else:
                scores['Risk Environment'] = 50
        else:
            scores['Risk Environment'] = None

        # Dimension 3: Supplier base
        suppliers = _safe_load("suppliers")
        if not suppliers.empty:
            n_suppliers = len(suppliers)
            n_regions = suppliers['region_group'].nunique() if 'region_group' in suppliers.columns else 1
            # More suppliers & regions = healthier
            supplier_score = min(100, n_suppliers / 2 + n_regions * 10)
            scores['Supplier Base'] = supplier_score
        else:
            scores['Supplier Base'] = None

        # Dimension 4: Commodity costs
        commodity = _safe_load("commodity")
        if not commodity.empty:
            commodity['month'] = pd.to_datetime(commodity['month'])
            latest = commodity[commodity['month'] == commodity['month'].max()]
            avg_pressure = latest['cost_pressure_score'].mean()
            scores['Material Costs'] = max(0, min(100, 80 - avg_pressure * 2))
        else:
            scores['Material Costs'] = None

        # Dimension 5: Forecast confidence
        if not demand.empty:
            recent12 = demand.sort_values('month').tail(12)
            y = recent12['total_obligations_usd'].values
            cv12 = y.std() / y.mean()
            _, _, r_val, _, _ = stats.linregress(np.arange(len(y)), y)
            scores['Forecast Reliability'] = (max(0, 100 - cv12 * 200) * 0.5 + (r_val ** 2) * 100 * 0.5)
        else:
            scores['Forecast Reliability'] = None

        # Dimension 6: DoD readiness
        dod = _safe_load("dod_metrics")
        if not dod.empty:
            dod['month'] = pd.to_datetime(dod['month'])
            latest_dod = dod.sort_values('month').iloc[-1]
            dos = latest_dod.get('days_of_supply_proxy', 0)
            nmcs = str(latest_dod.get('nmcs_risk_indicator', ''))
            dos_score = min(100, dos / 60 * 100)
            nmcs_penalty = 30 if 'HIGH' in nmcs else 15 if 'ELEVATED' in nmcs else 0
            scores['DoD Readiness'] = max(0, dos_score - nmcs_penalty)
        else:
            scores['DoD Readiness'] = None

        # Build output
        lines = ["SUPPLY CHAIN HEALTH DASHBOARD", "=" * 60, ""]

        valid_scores = {k: v for k, v in scores.items() if v is not None}
        overall = np.mean(list(valid_scores.values())) if valid_scores else 0

        overall_label = "HEALTHY" if overall >= 75 else "CAUTION" if overall >= 50 else "AT RISK" if overall >= 25 else "CRITICAL"
        lines.append(f"  OVERALL HEALTH: {overall:.0f}/100 ({overall_label})")
        lines.append("")

        for dim, score in scores.items():
            if score is not None:
                bar_len = int(score / 5)
                bar = "#" * bar_len + "." * (20 - bar_len)
                label = "Good" if score >= 75 else "Fair" if score >= 50 else "Poor" if score >= 25 else "Critical"
                lines.append(f"  {dim:<22s} [{bar}] {score:5.0f}  ({label})")
            else:
                lines.append(f"  {dim:<22s} [   data unavailable   ]")

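        lines += [
            "", "ATTENTION AREAS:",
        ]
        flagged = [(dim, score) for dim, score in sorted(valid_scores.items(), key=lambda x: x[1]) if score < 50]
        for dim, score in flagged:
            lines.append(f"  - {dim}: {score:.0f}/100 -- needs attention")

        if not valid_scores:
            # all() over an empty collection is True, so handle the no-data case explicitly
            lines.append("  No data available to score. Run the pipeline notebooks first.")
        elif all(s >= 75 for s in valid_scores.values()):
            lines.append("  All dimensions healthy.")
        elif not flagged:
            lines.append("  None below 50. Monitor dimensions rated Fair.")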

        return "\n".join(lines)
    except Exception as e:
        return f"Error generating health dashboard: {e}"
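
The Demand Stability dimension is a clamped linear function of the coefficient of variation (CV): a CV of 0 scores 100, and a CV of 0.5 or more scores 0. The following is a minimal standalone sketch of that arithmetic using only the standard library. `demand_stability_score` is a hypothetical helper name, the input values are made up, and the sketch assumes the sample standard deviation, which is what pandas `Series.std()` computes by default.

```python
from statistics import mean, stdev
from typing import Optional, Sequence


def demand_stability_score(values: Sequence[float]) -> Optional[float]:
    """Map the coefficient of variation of `values` to a 0-100 score."""
    if len(values) < 2:
        return None  # the sample standard deviation needs at least two points
    avg = mean(values)
    if avg == 0:
        return None  # CV is undefined when the mean is zero
    cv = stdev(values) / avg  # sample standard deviation (ddof=1)
    return max(0.0, min(100.0, 100.0 - cv * 200.0))


# Illustrative monthly obligations (made-up numbers)
monthly = [100, 110, 90, 105, 95, 100]
score = demand_stability_score(monthly)
print(f"Demand Stability: {score:.1f}/100")
```

Returning `None` for undefined cases keeps the dimension out of the overall average instead of letting a NaN be clamped to a misleading score.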


In [None]:
@tool
def generate_executive_briefing() -> str:
    """
    Generate a concise executive briefing covering demand outlook, risk environment,
    supply chain health, and recommended actions. No arguments required.
    """
    try:
        lines = [
            "EXECUTIVE SUPPLY CHAIN BRIEFING",
            f"  Date: {datetime.now().strftime('%B %d, %Y')}",
            "=" * 60, "",
        ]

        # 1. Demand outlook
        demand = _safe_load("demand_signals")
        if not demand.empty:
            demand['month'] = pd.to_datetime(demand['month'])
            demand = demand.sort_values('month')
            recent3 = demand.tail(3)['total_obligations_usd']
            recent12 = demand.tail(12)['total_obligations_usd']
            prior3 = demand.tail(6).head(3)['total_obligations_usd']

            qoq = (recent3.mean() - prior3.mean()) / prior3.mean() * 100 if prior3.mean() > 0 else 0

            lines += [
                "1. DEMAND OUTLOOK",
                f"   Last quarter avg:  ${recent3.mean():,.0f}",
                f"   12-month avg:      ${recent12.mean():,.0f}",
                f"   QoQ change:        {qoq:+.1f}%",
                f"   Trend:             {'Growing' if qoq > 2 else 'Declining' if qoq < -2 else 'Stable'}",
                "",
            ]

        # 2. Risk snapshot
        if not demand.empty:
            latest = demand.iloc[-1]
            lines.append("2. RISK SNAPSHOT")
            for col, label in [('geo_risk_index','Geopolitical'), ('tariff_risk_index','Tariff'),
                               ('commodity_cost_pressure','Commodity'), ('weather_disruption_index','Weather')]:
                val = latest.get(col)
                if pd.notna(val):
                    level = "High" if float(val) > 7 else "Moderate" if float(val) > 3 else "Low"
                    lines.append(f"   {label:<14s} {float(val):5.1f}  ({level})")
            lines.append("")

        # 3. Forecast
        forecast = _safe_load("prophet_forecasts")
        if not forecast.empty:
            forecast['month'] = pd.to_datetime(forecast['month'])
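            # Sort first so head(3) returns the next three months, not three arbitrary rows
            future = forecast[forecast['month'] > datetime.now()].sort_values('month').head(3)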
            if not future.empty:
                lines.append("3. FORECAST (Next Quarter)")
                for _, r in future.iterrows():
                    lines.append(f"   {r['month'].strftime('%b %Y')}: ${r['forecast_demand_usd']:,.0f}")
                lines.append("")

        # 4. DoD readiness
        dod = _safe_load("dod_metrics")
        if not dod.empty:
            dod['month'] = pd.to_datetime(dod['month'])
            latest_dod = dod.sort_values('month').iloc[-1]
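            dos = latest_dod.get('days_of_supply_proxy')
            # Format only when a numeric value is present; ':.0f' raises on an 'N/A' fallback
            dos_text = f"{float(dos):.0f}" if pd.notna(dos) else "N/A"
            lines += [
                "4. DoD READINESS",
                f"   Days of Supply:  {dos_text}  (target: 60+)",
                f"   NMCS Risk:       {latest_dod.get('nmcs_risk_indicator', 'N/A')}",
                "",
            ]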

        # 5. Supply chain pressure
        gscpi = _safe_load("gscpi")
        if not gscpi.empty:
            if 'as_of_date' in gscpi.columns:
                gscpi['as_of_date'] = pd.to_datetime(gscpi['as_of_date'])
                gscpi = gscpi.sort_values('as_of_date')
            gscpi_val = float(gscpi.tail(1).iloc[0].get('value', 0))
            gscpi_level = "Elevated" if gscpi_val > 1 else "Normal" if gscpi_val > -1 else "Low"
            lines += [
                "5. GLOBAL SUPPLY CHAIN PRESSURE",
                f"   NY Fed GSCPI:    {gscpi_val:.2f}  ({gscpi_level})",
                "",
            ]

        # 6. Key actions
        lines.append("6. RECOMMENDED ACTIONS")
        actions = []
        if not demand.empty:
            latest_risks = demand.iloc[-1]
            if float(latest_risks.get('geo_risk_index', 0) or 0) > 7:
                actions.append("Increase safety stock -- elevated geopolitical risk")
            if float(latest_risks.get('tariff_risk_index', 0) or 0) > 7:
                actions.append("Review tariff exposure and diversify sourcing")
            if float(latest_risks.get('commodity_cost_pressure', 0) or 0) > 7:
                actions.append("Lock in commodity contracts at current prices")
        if not dod.empty and str(latest_dod.get('nmcs_risk_indicator', '')) == 'HIGH_RISK':
            actions.append("Address NMCS risk -- expedite critical part orders")
        if not actions:
            actions.append("No urgent actions. Continue standard operations.")
        for a in actions:
            lines.append(f"   - {a}")

        return "\n".join(lines)
    except Exception as e:
        return f"Error generating briefing: {e}"
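
The quarter-over-quarter figure in the briefing compares the mean of the latest three months with the mean of the three months before them, then applies a 2% dead band to label the trend. The sketch below reproduces that calculation on a plain list. `qoq_change_pct` and `trend_label` are hypothetical names, and the six-value minimum is an added assumption: with fewer than six months the two windows would overlap, so the sketch raises instead.

```python
from statistics import mean
from typing import Sequence


def qoq_change_pct(monthly: Sequence[float]) -> float:
    """Percent change of the latest 3-month mean vs. the 3 months before it."""
    if len(monthly) < 6:
        raise ValueError("need at least 6 monthly values")
    prior = mean(monthly[-6:-3])
    recent = mean(monthly[-3:])
    if prior <= 0:
        return 0.0  # mirrors the briefing's guard against division by zero
    return (recent - prior) / prior * 100


def trend_label(qoq: float) -> str:
    """Apply the +/-2% dead band used in the briefing."""
    return "Growing" if qoq > 2 else "Declining" if qoq < -2 else "Stable"


values = [100, 100, 100, 104, 106, 108]  # made-up monthly totals
change = qoq_change_pct(values)
print(f"{change:+.1f}% -> {trend_label(change)}")
```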


---
# Agent Assembly


In [None]:
# ── Register all tools ───────────────────────────────────────────────────────
all_tools = [
    # Forecasting
    get_demand_forecast,
    compare_forecast_models,
    assess_forecast_confidence,
    # Analysis
    detect_anomalies,
    detect_trends,
    explain_demand_drivers,
    # Scenarios
    scenario_geopolitical_risk,
    scenario_tariff_increase,
    scenario_weather_disruption,
    build_whatif_scenario,
    # Intelligence
    query_suppliers,
    search_contracts,
    compare_dod_metrics,
    get_commodity_prices,
    get_macro_context,
    # Dashboard
    get_supply_chain_health,
    generate_executive_briefing,
]

print(f"Registered {len(all_tools)} tools:")
for t in all_tools:
    print(f"  - {t.name}")


## System Prompt


In [None]:
SYSTEM_PROMPT = """You are a senior supply chain intelligence analyst for Oshkosh Defense, a major U.S. defense contractor producing tactical and armored vehicles (JLTV, FHTV, FMTV, M-ATV). You have deep expertise in:

- DoD acquisition lifecycle (RO, AAO, NMCS, Days of Supply)
- Defense supply chain risk management (geopolitical, tariff, weather, commodity)
- Federal procurement (FAR, DFARS, USAspending, FPDS, SAM.gov)
- Demand forecasting (Prophet, ARIMA, Random Forest, ensemble methods)
- Material cost dynamics (steel, aluminum, lithium, rare earths)

You have access to 17 specialized tools that query live pipeline data from Unity Catalog tables covering:
- Prime award and subaward contract data from USAspending/FPDS
- SAM.gov defense supplier base (geolocation, size, NAICS, subsystem)
- Trade/tariff risk events from the Federal Register
- Commodity prices from Yahoo Finance
- Weather disruption indices from Meteostat
- World Bank WDI/WGI macroeconomic indicators
- NY Fed Global Supply Chain Pressure Index
- WTO Goods Trade Barometer
- Prophet, ARIMA, and Random Forest demand forecasts
- DoD metrics (Requirements Objective, Days of Supply, NMCS Risk)

IMPORTANT GUIDELINES:
1. Always use tools to fetch real data -- never fabricate numbers
2. When asked about suppliers, contracts, or specific data, use the appropriate search tool
3. Provide actionable insights, not just raw data
4. Reference DoD terminology and metrics where relevant
5. Quantify impacts in dollars and percentages when possible
6. For complex questions, use multiple tools to build a comprehensive answer
7. If a tool returns no data, explain what pipeline step needs to run
8. When comparing scenarios, be explicit about assumptions
9. Flag high-risk situations prominently with recommended actions
10. Tailor detail level to the question -- executive summary for broad questions, deep dive for specific ones"""


## Create Agent with Memory


In [None]:
# ── Conversation memory ──────────────────────────────────────────────────────
conversation_history = []

def _get_history():
    """Return the last N turns of conversation for context."""
    max_turns = 10
    return conversation_history[-(max_turns * 2):]

# ── Build agent ──────────────────────────────────────────────────────────────
prompt = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])

if _create_tool_calling_agent is not None and _AgentExecutor is not None:
    agent = _create_tool_calling_agent(llm, all_tools, prompt)
    agent_executor = _AgentExecutor(
        agent=agent,
        tools=all_tools,
        verbose=True,
        handle_parsing_errors=True,
        max_iterations=10,
    )
    print(f"Agent ready: {len(all_tools)} tools, AgentExecutor, conversation memory")
else:
    # Fallback executor with memory support
    _tools_by_name = {t.name: t for t in all_tools}
    class _MemoryToolCallingExecutor:
        def __init__(self, llm, tools, verbose=True):
            self.llm = llm
            self.tools = tools
            self.verbose = verbose

        def invoke(self, inputs):
            user_input = inputs.get("input", "")
            chat_history = inputs.get("chat_history", [])

            messages = [SystemMessage(content=SYSTEM_PROMPT)]
            messages.extend(chat_history)
            messages.append(HumanMessage(content=user_input))

            max_rounds = 10
            llm_with_tools = self.llm.bind_tools(self.tools)  # bind once, reuse each round
            for _ in range(max_rounds):
                response = llm_with_tools.invoke(messages)
                if self.verbose and response.content:
                    print(response.content[:300], "..." if len(response.content) > 300 else "")
                if not getattr(response, "tool_calls", None):
                    return {"output": response.content or ""}

                messages.append(response)
                for tc in response.tool_calls:
                    name = tc.get("name", None) if isinstance(tc, dict) else getattr(tc, "name", None)
                    args = tc.get("args", {}) if isinstance(tc, dict) else getattr(tc, "args", {}) or {}
                    tid = tc.get("id", "") if isinstance(tc, dict) else getattr(tc, "id", "")
                    found_tool = _tools_by_name.get(name)
                    if found_tool:
                        if self.verbose:
                            print(f"  [tool] {name}({args})")
                        try:
                            result = found_tool.invoke(args)
                        except Exception as e:
                            result = f"Error running tool {name}: {e}"
                    else:
                        result = f"Unknown tool: {name}"
                    # Every tool call needs a matching ToolMessage, or the next LLM call may be rejected
                    messages.append(ToolMessage(content=str(result), tool_call_id=tid))

            return {"output": (response.content or "") + "\n[Max rounds reached.]"}

    agent_executor = _MemoryToolCallingExecutor(llm, all_tools, verbose=True)
    print(f"Agent ready: {len(all_tools)} tools, bind_tools fallback, conversation memory")


## Helper: ask()


In [None]:
def ask(question: str) -> str:
    """
    Send a question to the agent with conversation memory.
    
    Usage:
        answer = ask("What is the demand forecast for next quarter?")
        print(answer)
    """
    # Run the agent with prior turns only; the current question is passed as "input"
    response = agent_executor.invoke({
        "input": question,
        "chat_history": _get_history(),
    })

    answer = response["output"]

    # Record the turn only after a successful call, so a failed invoke
    # does not leave an unanswered question in the history
    conversation_history.append(HumanMessage(content=question))
    conversation_history.append(AIMessage(content=answer))

    return answer


def reset_memory():
    """Clear conversation history."""
    conversation_history.clear()
    print("Conversation memory cleared.")


print("Use ask('your question') to chat with the agent.")
print("Use reset_memory() to start a new conversation.")
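
Conversation memory in this notebook is a sliding window over a flat message list: each turn contributes two entries (question and answer), and `_get_history()` keeps the last `max_turns * 2` of them. The sketch below shows the same slicing on plain tuples so the window behavior is easy to see. `window_history` and the `(role, text)` tuples are stand-ins for illustration, not the notebook's `HumanMessage`/`AIMessage` objects.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, text); stand-in for HumanMessage/AIMessage


def window_history(history: List[Message], max_turns: int = 10) -> List[Message]:
    """Return the last `max_turns` question/answer pairs."""
    return history[-(max_turns * 2):]


history: List[Message] = []
for i in range(12):
    history.append(("human", f"question {i}"))
    history.append(("ai", f"answer {i}"))

recent = window_history(history, max_turns=10)
print(len(recent), recent[0], recent[-1])
```

Because entries are always appended in pairs, the window starts on a question rather than on an orphaned answer. One caveat of the slice: `max_turns=0` would return the whole list, since `history[-0:]` is the same as `history[0:]`.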


---
# Interactive Chat


## Example Queries

Run any of the cells below, or type your own question in a new cell:
```python
print(ask("your question here"))
```

**Sample questions:**
- "Give me an executive briefing on the supply chain"
- "What is the demand forecast for next quarter?"
- "Show me suppliers in Wisconsin"
- "Search for contracts with PSC code 23"
- "What are the top demand drivers?"
- "How is global supply chain pressure trending?"
- "What if geopolitical risk increases 50% and tariffs go up 25%?"
- "Generate a full supply chain health dashboard"
- "Compare Prophet vs ARIMA vs Random Forest forecasts"
- "How confident should we be in the 3-month forecast?"


In [None]:
# Example 1: Executive briefing
print(ask("Give me a concise executive briefing on the current state of our supply chain."))


In [None]:
# Example 2: Supplier search (multi-turn)
print(ask("Show me our defense suppliers in Wisconsin and Texas. How many are small businesses?"))


In [None]:
# Example 3: Follow-up with memory
print(ask("Based on those suppliers, which subsystem categories have the least diversification?"))


In [None]:
# Example 4: Multi-tool scenario
print(ask("What if geopolitical risk spikes 75%, tariffs rise 30%, and steel prices jump 20%? What's the combined impact and what should we do?"))


In [None]:
# Example 5: Health dashboard
print(ask("Generate a supply chain health dashboard and highlight any dimension scoring below 50."))


---
## Tool Reference

| # | Tool | Category | Description |
|---|------|----------|-------------|
| 1 | `get_demand_forecast` | Forecasting | Retrieve demand forecast with confidence intervals |
| 2 | `compare_forecast_models` | Forecasting | Compare Prophet, ARIMA, and Random Forest side-by-side |
| 3 | `assess_forecast_confidence` | Forecasting | Score forecast reliability (0-100) |
| 4 | `detect_anomalies` | Analysis | Find demand anomalies vs historical baseline |
| 5 | `detect_trends` | Analysis | Detect growth, seasonality, volatility, correlations |
| 6 | `explain_demand_drivers` | Analysis | Feature importance from Random Forest model |
| 7 | `scenario_geopolitical_risk` | Scenarios | Model geopolitical risk impact on demand |
| 8 | `scenario_tariff_increase` | Scenarios | Analyze tariff cost impact |
| 9 | `scenario_weather_disruption` | Scenarios | Model weather disruption scenarios |
| 10 | `build_whatif_scenario` | Scenarios | Custom multi-factor what-if builder |
| 11 | `query_suppliers` | Intelligence | Search SAM.gov supplier base |
| 12 | `search_contracts` | Intelligence | Search FPDS contract data |
| 13 | `compare_dod_metrics` | Intelligence | Check DoD supply chain metrics (RO, AAO, NMCS) |
| 14 | `get_commodity_prices` | Intelligence | Monitor defense-critical material prices |
| 15 | `get_macro_context` | Intelligence | World Bank, NY Fed GSCPI, WTO indicators |
| 16 | `get_supply_chain_health` | Dashboard | Holistic health score across 6 dimensions |
| 17 | `generate_executive_briefing` | Dashboard | One-page exec briefing with actions |


---
## Next Steps

1. **Deploy as Model Serving endpoint** — register the agent with MLflow and serve via Databricks Model Serving for API access
2. **Connect to Databricks Genie** — enable natural language queries over the gold tables
3. **Add Slack/Teams integration** — pipe `ask()` through a webhook for chat-based supply chain queries
4. **Schedule executive briefings** — run `generate_executive_briefing` on a daily/weekly Databricks Job
5. **Add alerting** — trigger notifications when `get_supply_chain_health` detects scores below threshold
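
As a starting point for the alerting step, the sketch below extracts low-scoring dimensions from the text that `get_supply_chain_health` returns. It assumes the row layout produced by that tool (name, bracketed bar, integer score, label) and uses a hypothetical helper, `scores_below`; the sample text is made up. Parsing formatted text is brittle, so a production version would more likely have the tool expose structured scores. The notification call itself is left out.

```python
import re
from typing import Dict

# Matches dimension rows such as:
#   "  Demand Stability       [#################...]    86  (Good)"
_ROW = re.compile(r"^\s*(?P<dim>.+?)\s+\[[#.]+\]\s+(?P<score>\d+)\s+\(")


def scores_below(dashboard_text: str, threshold: float = 50) -> Dict[str, int]:
    """Extract dimension scores under `threshold` from the dashboard text."""
    low: Dict[str, int] = {}
    for line in dashboard_text.splitlines():
        m = _ROW.match(line)
        if m and int(m.group("score")) < threshold:
            low[m.group("dim")] = int(m.group("score"))
    return low


sample = "\n".join([
    "SUPPLY CHAIN HEALTH DASHBOARD",
    "  OVERALL HEALTH: 61/100 (CAUTION)",
    "  Demand Stability       [#################...]    86  (Good)",
    "  Material Costs         [#######.............]    36  (Poor)",
    "  Supplier Base          [   data unavailable   ]",
])
alerts = scores_below(sample)
print(alerts)
```

Rows without a score bar (the overall line and "data unavailable" rows) do not match the pattern and are skipped.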
