## Load & Validate Data

In [0]:
%sql
-- Quick validation of commodity.silver.unified_data (two scans instead of five).
-- Key uniqueness is measured with GROUP BY rather than COUNT(DISTINCT date, commodity, region):
-- COUNT(DISTINCT ...) silently skips every row where any key column is NULL, so such rows
-- used to be reported as "Duplicates". GROUP BY keeps NULL keys as their own group.
WITH row_stats AS (
  SELECT
    COUNT(*)                                        AS total_rows,
    SUM(CASE WHEN close IS NULL THEN 1 ELSE 0 END)  AS null_close,
    SUM(CASE WHEN vix IS NULL THEN 1 ELSE 0 END)    AS null_vix
  FROM commodity.silver.unified_data
),
key_stats AS (
  SELECT
    COUNT(*)                           AS unique_keys,
    COALESCE(SUM(rows_per_key - 1), 0) AS duplicate_rows
  FROM (
    SELECT COUNT(*) AS rows_per_key
    FROM commodity.silver.unified_data
    GROUP BY date, commodity, region
  ) AS per_key
)
SELECT 'Total Rows' AS metric, total_rows AS value FROM row_stats
UNION ALL
SELECT 'Unique (date, commodity, region)', unique_keys FROM key_stats
UNION ALL
SELECT 'Duplicates', duplicate_rows FROM key_stats
UNION ALL
SELECT 'Null Close Prices', null_close FROM row_stats
UNION ALL
SELECT 'Null VIX', null_vix FROM row_stats;


In [0]:
%sql
-- Null count for every column of commodity.silver.unified_data, one output row per column.
-- All counts are aggregated in a single scan into one wide row, which stack() then
-- unpivots into (column_name, null_count) pairs.
SELECT column_name, null_count
FROM (
  SELECT stack(18,
    'date',             n_date,
    'commodity',        n_commodity,
    'close',            n_close,
    'high',             n_high,
    'low',              n_low,
    'open',             n_open,
    'volume',           n_volume,
    'vix',              n_vix,
    'region',           n_region,
    'temp_c',           n_temp_c,
    'humidity_pct',     n_humidity_pct,
    'precipitation_mm', n_precipitation_mm,
    'vnd_usd',          n_vnd_usd,
    'cop_usd',          n_cop_usd,
    'idr_usd',          n_idr_usd,
    'uah_usd',          n_uah_usd,
    'irr_usd',          n_irr_usd,
    'byn_usd',          n_byn_usd
  ) AS (column_name, null_count)
  FROM (
    SELECT
      SUM(CASE WHEN date IS NULL THEN 1 ELSE 0 END)             AS n_date,
      SUM(CASE WHEN commodity IS NULL THEN 1 ELSE 0 END)        AS n_commodity,
      SUM(CASE WHEN close IS NULL THEN 1 ELSE 0 END)            AS n_close,
      SUM(CASE WHEN high IS NULL THEN 1 ELSE 0 END)             AS n_high,
      SUM(CASE WHEN low IS NULL THEN 1 ELSE 0 END)              AS n_low,
      SUM(CASE WHEN open IS NULL THEN 1 ELSE 0 END)             AS n_open,
      SUM(CASE WHEN volume IS NULL THEN 1 ELSE 0 END)           AS n_volume,
      SUM(CASE WHEN vix IS NULL THEN 1 ELSE 0 END)              AS n_vix,
      SUM(CASE WHEN region IS NULL THEN 1 ELSE 0 END)           AS n_region,
      SUM(CASE WHEN temp_c IS NULL THEN 1 ELSE 0 END)           AS n_temp_c,
      SUM(CASE WHEN humidity_pct IS NULL THEN 1 ELSE 0 END)     AS n_humidity_pct,
      SUM(CASE WHEN precipitation_mm IS NULL THEN 1 ELSE 0 END) AS n_precipitation_mm,
      SUM(CASE WHEN vnd_usd IS NULL THEN 1 ELSE 0 END)          AS n_vnd_usd,
      SUM(CASE WHEN cop_usd IS NULL THEN 1 ELSE 0 END)          AS n_cop_usd,
      SUM(CASE WHEN idr_usd IS NULL THEN 1 ELSE 0 END)          AS n_idr_usd,
      SUM(CASE WHEN uah_usd IS NULL THEN 1 ELSE 0 END)          AS n_uah_usd,
      SUM(CASE WHEN irr_usd IS NULL THEN 1 ELSE 0 END)          AS n_irr_usd,
      SUM(CASE WHEN byn_usd IS NULL THEN 1 ELSE 0 END)          AS n_byn_usd
    FROM commodity.silver.unified_data
  ) AS null_totals
) AS unpivoted
ORDER BY column_name;

In [0]:
# Load the full unified table as a cached Spark DataFrame and peek at its first 10 rows.
# NOTE: `df_spark` is re-bound to a Coffee-only query later in the notebook; this cached
# frame then stays in the Spark cache until it is unpersisted.
df_spark = spark.table("commodity.silver.unified_data").cache()
df_spark.take(10)


In [0]:
%sql
-- Spot-check: the 15 most recent Coffee rows for two regions.
SELECT *
FROM commodity.silver.unified_data
WHERE commodity = 'Coffee'
  AND region IN ('Bugisu_Uganda', 'Chiapas_Mexico')
ORDER BY date DESC
LIMIT 15;

## Low-accuracy Forecast for Integration Testing

In [0]:
"""
Simple Coffee Forecast - Databricks PySpark Notebook
=====================================================

Purpose: Establish data contract with Risk Agent by producing backdated forecast 
distributions for backtesting.

Focus: Simple implementation first, iterate on accuracy later.

Output Contract:
- Point forecasts (14-day ahead with 95% confidence intervals)
- Distribution paths (2,000 Monte Carlo samples per forecast start date)
- Backdated with data_cutoff_date for proper backtesting
- Partitioned by model_version and commodity (e.g., sarimax_v0 / Coffee)
- Saved to Delta tables in the commodity.silver schema (see CATALOG / SCHEMA below)
"""

# ============================================================================
# 1. SETUP & IMPORTS
# ============================================================================

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from statsmodels.tsa.statespace.sarimax import SARIMAX
import warnings
warnings.filterwarnings('ignore')

# Configuration -------------------------------------------------------------
# Model identity and Monte Carlo / walk-forward backtest settings.
MODEL_VERSION = "sarimax_v0"            # written to the model_version column of both output tables
N_PATHS = 2000                          # Monte Carlo paths drawn per forecast start date
FORECAST_HORIZON = 14                   # number of steps forecast from each origin
BACKTESTING_START_DATE = '2018-01-01'   # first forecast origin of the walk-forward backtest
FORECAST_FREQUENCY = '1D'               # Daily forecasts for complete coverage

# Delta table configuration: fully-qualified names are <catalog>.<schema>.<table>.
CATALOG = "commodity"
SCHEMA = "silver"
FORECAST_TABLE, DISTRIBUTION_TABLE = (
    f"{CATALOG}.{SCHEMA}.{table}"
    for table in ("coffee_point_forecasts", "coffee_distributions")
)

print(f"""
╔══════════════════════════════════════════════════════════════╗
║           COFFEE FORECAST CONFIGURATION                      ║
╚══════════════════════════════════════════════════════════════╝
  Model Version:     {MODEL_VERSION}
  Forecast Horizon:  {FORECAST_HORIZON} days
  Frequency:         {FORECAST_FREQUENCY} (daily)
  Start Date:        {BACKTESTING_START_DATE}
  Sample Paths:      {N_PATHS:,}
  
  Output Tables:
    • {FORECAST_TABLE}
    • {DISTRIBUTION_TABLE}
  
  Estimated: ~2,500 forecasts, ~60-90 min runtime
╚══════════════════════════════════════════════════════════════╝
""")

# ============================================================================
# 2. LOAD UNIFIED DATA
# ============================================================================

print("Loading unified_data from Databricks...")
print("(Reading the commodity.silver.unified_data table - make sure it has been built first!)")

# Coffee rows on trading days only. Weather columns are still per-region here;
# they are averaged across regions in the next step.
df_spark = spark.sql("""
    SELECT 
        date,
        is_trading_day,
        commodity,
        close,
        high,
        low,
        open,
        volume,
        vix,
        region,
        temp_c,
        humidity_pct,
        precipitation_mm
    FROM commodity.silver.unified_data
    WHERE commodity = 'Coffee'
    AND is_trading_day = 1  -- Only trading days for modeling
    ORDER BY date, region
""")

# Fail fast with a clear message: an empty frame would otherwise surface much later
# as an opaque NaT error when the backtest date range is built.
n_loaded_rows = df_spark.count()
if n_loaded_rows == 0:
    raise ValueError(
        "No Coffee trading-day rows found in commodity.silver.unified_data; nothing to model."
    )

print(f"Loaded {n_loaded_rows} rows")

# ============================================================================
# 3. PREPARE DATA FOR MODELING
# ============================================================================

print("Preparing data for time series modeling...")

# Convert to pandas for SARIMAX (simpler for MVP).
# Collapse the per-region rows to one row per trading date. Weather is averaged
# across regions, while close/vix take the first NON-NULL value among that date's rows.
# ignorenulls=True matters because first() returns an arbitrary row of each group:
# without it, a single region row with a missing close/vix would blank out the date.
# NOTE(review): assumes close/vix are identical across regions for a given date - confirm.
df_agg = df_spark.groupBy("date", "commodity").agg(
    first("close", ignorenulls=True).alias("close"),
    first("vix", ignorenulls=True).alias("vix"),
    avg("temp_c").alias("avg_temp"),
    avg("humidity_pct").alias("avg_humidity"),
    avg("precipitation_mm").alias("avg_precip")
).orderBy("date")

df_pd = df_agg.toPandas()
df_pd['date'] = pd.to_datetime(df_pd['date'])
df_pd = df_pd.set_index('date').sort_index()

print(f"Prepared {len(df_pd)} daily observations from {df_pd.index.min()} to {df_pd.index.max()}")

# ============================================================================
# 4. SIMPLE FORECASTING FUNCTION
# ============================================================================

def create_simple_forecast(train_data, forecast_date, n_days=14):
    """
    Fit a SARIMAX(1,1,1) model on ``train_data`` and forecast ``n_days`` steps ahead.

    Covariates are kept minimal for the MVP (``vix`` and ``avg_temp``). Their future
    values are unknown at forecast time, so the last observed covariate row is held
    constant over the whole horizon.

    Args:
        train_data: DataFrame with a ``close`` column (target) and ``vix`` / ``avg_temp``
            covariate columns, ordered oldest to newest.
        forecast_date: Date the forecast is issued for; only used in the warning message.
        n_days: Number of steps to forecast. One step is one row of ``train_data``
            (one observation), not necessarily one calendar day.

    Returns:
        forecast_mean: np.ndarray (length ``n_days``) of point forecasts.
        forecast_std: np.ndarray (length ``n_days``) of forecast standard errors.
        success: True only when SARIMAX fitted without raising AND the optimizer
            reported convergence. False for non-converged fits (the SARIMAX forecast
            is still returned) and for the naive fallback used when fitting or
            forecasting raises.
    """
    # Prepare target and covariates
    y = train_data['close']

    # Use minimal covariates for MVP (just VIX and temperature)
    exog = train_data[['vix', 'avg_temp']]

    try:
        # Fit SARIMAX(1,1,1) - same as V1 champion model
        model = SARIMAX(
            y,
            exog=exog,
            order=(1, 1, 1),
            enforce_stationarity=False,
            enforce_invertibility=False
        )

        results = model.fit(disp=False, maxiter=100)

        # Future covariates are unknown at forecast time: repeat the last observed row.
        last_exog = exog.iloc[-1:].values
        exog_forecast = np.tile(last_exog, (n_days, 1))

        forecast_obj = results.get_forecast(steps=n_days, exog=exog_forecast)
        forecast_mean = forecast_obj.predicted_mean.values
        forecast_std = forecast_obj.se_mean.values

        # fit() does not raise when the optimizer stops at maxiter without converging
        # (and warnings are silenced in this notebook); statsmodels reports it in
        # mle_retvals['converged'] instead.
        retvals = getattr(results, 'mle_retvals', None) or {}
        converged = bool(retvals.get('converged', True))

        return forecast_mean, forecast_std, converged

    except Exception as e:
        print(f"Warning: Forecast failed for {forecast_date}: {str(e)}")
        # Naive fallback: flat forecast at the last *observed* (non-null) price.
        # Using y.iloc[-1] directly would propagate a trailing NaN into every step.
        observed = y.dropna()
        last_price = observed.iloc[-1] if len(observed) else np.nan
        forecast_mean = np.full(n_days, last_price)
        forecast_std = np.full(n_days, y.std())
        return forecast_mean, forecast_std, False

# ============================================================================
# 5. GENERATE SAMPLE PATHS (MONTE CARLO DISTRIBUTION)
# ============================================================================

def generate_sample_paths(forecast_mean, forecast_std, n_paths=2000, rng=None,
                          horizon_scaling=False):
    """
    Draw Monte Carlo price paths around a point forecast.

    Day ``d`` of every path is sampled independently from
    Normal(forecast_mean[d], forecast_std[d]). The per-day spread of the paths
    therefore matches the per-day standard errors passed in, and hence the
    lower_95/upper_95 bands built from the same ``forecast_std``.

    Args:
        forecast_mean: Sequence of length n_days with point forecasts.
        forecast_std: Sequence of length n_days with per-day standard errors.
            SARIMAX ``se_mean`` values already widen with the horizon, so they
            must not be scaled again.
        n_paths: Number of paths to draw.
        rng: Optional ``np.random.Generator`` for reproducible draws. Defaults to
            the global ``np.random`` state (so ``np.random.seed`` still applies).
        horizon_scaling: If True, also multiply day d's std by sqrt(d + 1), which
            was the previous behaviour. This is only appropriate when
            ``forecast_std`` is a flat one-step std. Applied to SARIMAX standard
            errors it double-counts the growth of uncertainty with horizon.

    Returns:
        np.ndarray of shape (n_paths, n_days).

    Raises:
        ValueError: If forecast_mean and forecast_std are not 1-D of equal length.

    Future enhancement: Add skewness, regime-specific distributions, etc.
    """
    mean = np.asarray(forecast_mean, dtype=float)
    std = np.asarray(forecast_std, dtype=float)
    if mean.ndim != 1 or mean.shape != std.shape:
        raise ValueError(
            f"forecast_mean and forecast_std must be 1-D and equal length, "
            f"got shapes {mean.shape} and {std.shape}"
        )

    if horizon_scaling:
        std = std * np.sqrt(np.arange(1, len(std) + 1))

    sampler = np.random if rng is None else rng
    # loc/scale of shape (n_days,) broadcast against size (n_paths, n_days):
    # column d is drawn with day d's mean and std.
    return sampler.normal(loc=mean, scale=std, size=(n_paths, len(mean)))



In [0]:
# ============================================================================
# 6. BACKTESTING LOOP - GENERATE BACKDATED FORECASTS
# ============================================================================
import time

print("\nGenerating backdated forecasts for backtesting...")

# Seed the global NumPy RNG so the Monte Carlo paths are reproducible on re-run.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Define backtesting period
START_DATE = pd.Timestamp(BACKTESTING_START_DATE)
END_DATE = df_pd.index.max() - timedelta(days=FORECAST_HORIZON)

# Walk-forward backtesting - one forecast origin per FORECAST_FREQUENCY step
backtest_dates = pd.date_range(start=START_DATE, end=END_DATE, freq=FORECAST_FREQUENCY)
MIN_TRAIN_DAYS = 365      # minimum number of training rows required to fit a model
TRAIN_WINDOW_DAYS = 730   # calendar-day length of the rolling training window
DAY_COLUMNS = [f'day_{d + 1}' for d in range(FORECAST_HORIZON)]


def _point_forecast_rows(forecast_date, train_end_date, generated_at,
                         forecast_mean, forecast_std, success):
    """Build the FORECAST_HORIZON point-forecast records for one forecast origin."""
    # NOTE(review): the model is fit on trading-day rows only (is_trading_day = 1 in the
    # load query), so SARIMAX step k is k *observations* ahead, while the target dates
    # below are k calendar days ahead. Confirm which labelling the Risk Agent expects.
    target_dates = pd.date_range(start=forecast_date, periods=FORECAST_HORIZON, freq='D')
    rows = []
    for day_idx, target_date in enumerate(target_dates):
        mean = float(forecast_mean[day_idx])
        std = float(forecast_std[day_idx])
        rows.append({
            'forecast_date': target_date,
            'data_cutoff_date': train_end_date,
            'generation_timestamp': generated_at,
            'day_ahead': day_idx + 1,
            'forecast_mean': mean,
            'forecast_std': std,
            'lower_95': mean - 1.96 * std,
            'upper_95': mean + 1.96 * std,
            'model_version': MODEL_VERSION,
            'commodity': 'Coffee',
            'model_success': success
        })
    return rows


def _distribution_frame(forecast_date, train_end_date, generated_at, sample_paths):
    """Turn an (n_paths, FORECAST_HORIZON) path array into one wide DataFrame (one row per path)."""
    n_paths = sample_paths.shape[0]
    columns = {
        'path_id': np.arange(1, n_paths + 1),
        'forecast_start_date': forecast_date,
        'data_cutoff_date': train_end_date,
        'generation_timestamp': generated_at,
        'model_version': MODEL_VERSION,
        'commodity': 'Coffee',
    }
    for day_idx, column in enumerate(DAY_COLUMNS):
        columns[column] = sample_paths[:, day_idx]
    return pd.DataFrame(columns)


forecasts_list = []
# One DataFrame per forecast origin. This replaces a list of ~N_PATHS Python dicts
# per origin, which needed several GB of memory over a multi-year backtest.
distribution_frames = []
forecast_times = []

print(f"Generating {len(backtest_dates)} forecasts from {START_DATE} to {END_DATE}")
print(f"Forecast frequency: {FORECAST_FREQUENCY} (daily - covers all days of week)")
print(f"Forecast horizon: {FORECAST_HORIZON} days ahead")
print(f"This will create overlapping forecasts for realistic backtesting\n")

start_time = time.time()

for i, forecast_date in enumerate(backtest_dates):

    forecast_start = time.time()

    if (i + 1) % 100 == 0:
        elapsed = time.time() - start_time
        recent_times = forecast_times[-100:]
        if recent_times:
            avg_time = np.mean(recent_times)
            remaining = (len(backtest_dates) - i - 1) * avg_time
            print(f"Progress: {i+1}/{len(backtest_dates)} forecasts | "
                  f"Elapsed: {elapsed/60:.1f} min | "
                  f"Avg: {avg_time:.2f}s/forecast | "
                  f"ETA: {remaining/60:.1f} min")
        else:
            # Every date so far was skipped for insufficient history; no timing yet.
            print(f"Progress: {i+1}/{len(backtest_dates)} dates | "
                  f"Elapsed: {elapsed/60:.1f} min | no forecasts yet (insufficient history)")

    # Training data ends the day before the forecast origin (no data leakage).
    train_end_date = forecast_date - timedelta(days=1)
    train_start_date = train_end_date - timedelta(days=TRAIN_WINDOW_DAYS)
    train_data = df_pd.loc[train_start_date:train_end_date]

    # Skip if insufficient data
    if len(train_data) < MIN_TRAIN_DAYS:
        continue

    forecast_mean, forecast_std, success = create_simple_forecast(
        train_data,
        forecast_date,
        n_days=FORECAST_HORIZON
    )

    # One generation timestamp per forecast origin, shared by all of its rows.
    generated_at = datetime.now()

    forecasts_list.extend(_point_forecast_rows(
        forecast_date, train_end_date, generated_at, forecast_mean, forecast_std, success
    ))

    sample_paths = generate_sample_paths(forecast_mean, forecast_std, N_PATHS)
    distribution_frames.append(
        _distribution_frame(forecast_date, train_end_date, generated_at, sample_paths)
    )

    forecast_times.append(time.time() - forecast_start)

total_time = time.time() - start_time
# np.sum rather than sum(): `from pyspark.sql.functions import *` shadows the builtin.
n_distribution_rows = int(np.sum([len(frame) for frame in distribution_frames]))

print(f"\nGenerated {len(forecasts_list)} point forecasts")
print(f"Generated {n_distribution_rows} distribution paths")
print(f"\nTotal Runtime: {total_time/60:.2f} minutes ({total_time/3600:.2f} hours)")
if forecast_times:
    print(f"   Average: {np.mean(forecast_times):.2f} seconds per forecast")
    print(f"   Min: {np.min(forecast_times):.2f}s | Max: {np.max(forecast_times):.2f}s")

# ============================================================================
# 7. CREATE SPARK DATAFRAMES
# ============================================================================

print("\nCreating Spark DataFrames...")

if not forecasts_list:
    raise ValueError(
        "No forecasts were generated: every backtest date had fewer than "
        f"{MIN_TRAIN_DAYS} training rows. Check BACKTESTING_START_DATE and the loaded data."
    )

# Point forecasts
df_forecasts = spark.createDataFrame(pd.DataFrame(forecasts_list))

# Distributions
df_distributions = spark.createDataFrame(pd.concat(distribution_frames, ignore_index=True))

print(f"Point forecasts shape: {df_forecasts.count()} rows")
print(f"Distributions shape: {df_distributions.count()} rows")

In [0]:
# ============================================================================
# 8. WRITE TO DELTA TABLES
# ============================================================================

def _write_partitioned_delta(frame, table_name):
    """Overwrite `table_name` with `frame` as a Delta table partitioned by model_version/commodity."""
    # NOTE(review): mode("overwrite") without replaceWhere replaces the whole table, so
    # rows of any other model_version / commodity partition are dropped as well. If
    # several model versions are meant to coexist, this needs a partition-scoped
    # overwrite - confirm the intended behavior.
    (
        frame.write
        .mode("overwrite")
        .format("delta")
        .partitionBy("model_version", "commodity")
        .option("overwriteSchema", "true")
        .saveAsTable(table_name)
    )


def _table_location(table_name):
    """Return the storage location reported by DESCRIBE DETAIL for `table_name`."""
    detail = spark.sql(f"DESCRIBE DETAIL {table_name}").select("location")
    return detail.collect()[0][0]


print("\nWriting to Delta tables...")

print(f"Writing point forecasts to {FORECAST_TABLE}...")
_write_partitioned_delta(df_forecasts, FORECAST_TABLE)
print(f"✓ Point forecasts written to: {FORECAST_TABLE}")

print(f"Writing distributions to {DISTRIBUTION_TABLE}...")
_write_partitioned_delta(df_distributions, DISTRIBUTION_TABLE)
print(f"✓ Distributions written to: {DISTRIBUTION_TABLE}")

forecast_location = _table_location(FORECAST_TABLE)
dist_location = _table_location(DISTRIBUTION_TABLE)

print(f"\nTable locations:")
print(f"  Point forecasts: {forecast_location}")
print(f"  Distributions: {dist_location}")

In [0]:
# # ============================================================================
# # 8. WRITE TO CSV FILES (While we wait on save permissions)
# # ============================================================================

# print("\nWriting to CSV files...")
# write_start = time.time()

# # Define output paths (local DBFS)
# OUTPUT_BASE_PATH = "/dbfs/FileStore/coffee_forecasts/"
# FORECAST_CSV_PATH = f"{OUTPUT_BASE_PATH}point_forecasts/"
# DISTRIBUTION_CSV_PATH = f"{OUTPUT_BASE_PATH}distributions/"

# # Create directories if they don't exist
# import os
# os.makedirs(FORECAST_CSV_PATH, exist_ok=True)
# os.makedirs(DISTRIBUTION_CSV_PATH, exist_ok=True)

# # Convert to pandas and write CSV
# print(f"Writing point forecasts to {FORECAST_CSV_PATH}...")
# df_forecasts_pd = df_forecasts.toPandas()
# df_forecasts_pd.to_csv(
#     f"{FORECAST_CSV_PATH}coffee_point_forecasts_{MODEL_VERSION}.csv",
#     index=False
# )
# print(f"✓ Point forecasts written: {len(df_forecasts_pd):,} rows")

# # Write distributions (this will be large!)
# print(f"Writing distributions to {DISTRIBUTION_CSV_PATH}...")
# df_distributions_pd = df_distributions.toPandas()
# df_distributions_pd.to_csv(
#     f"{DISTRIBUTION_CSV_PATH}coffee_distributions_{MODEL_VERSION}.csv",
#     index=False
# )
# print(f"✓ Distributions written: {len(df_distributions_pd):,} rows")

# write_time = time.time() - write_start
# print(f"\n⏱️  Write time: {write_time:.2f} seconds")

# # Show file locations and sizes
# print(f"\nFiles created:")
# print(f"  Point forecasts: /FileStore/coffee_forecasts/point_forecasts/coffee_point_forecasts_{MODEL_VERSION}.csv")
# print(f"  Distributions: /FileStore/coffee_forecasts/distributions/coffee_distributions_{MODEL_VERSION}.csv")
# print(f"\nDownload via Databricks UI: Workspace → FileStore → coffee_forecasts/")

# Coffee Forecast Data - Risk Agent Quick Start

## Data Locations

**Point Forecasts:**
`https://dbc-5474a94c-61c9.cloud.databricks.com/files/coffee_forecasts/point_forecasts/coffee_point_forecasts_sarimax_v0.csv`

**Distributions:**
`https://dbc-5474a94c-61c9.cloud.databricks.com/files/coffee_forecasts/distributions/coffee_distributions_sarimax_v0.csv`

### Production (S3) - Coming Soon, pending access
**Point Forecasts:**
`commodity.default.coffee_point_forecasts`

**Distributions:**
`commodity.default.coffee_distributions`

## Loading Data in Databricks
```python
# Point Forecasts
df_forecasts = spark.read.csv(
    "/FileStore/coffee_forecasts/point_forecasts/coffee_point_forecasts_sarimax_v0.csv",
    header=True,
    inferSchema=True
)

# Distributions
df_distributions = spark.read.csv(
    "/FileStore/coffee_forecasts/distributions/coffee_distributions_sarimax_v0.csv",
    header=True,
    inferSchema=True
)
```

### From Delta Tables (Production) - Coming Soon
```python
# Point Forecasts
df_forecasts = spark.table("commodity.default.coffee_point_forecasts")

# Distributions
df_distributions = spark.table("commodity.default.coffee_distributions")
```


## Point Forecasts Schema

| Field | Type | Description |
|-------|------|-------------|
| `forecast_date` | DATE | Target date being forecasted |
| `data_cutoff_date` | DATE | Last training date (must be < forecast_date) |
| `generation_timestamp` | TIMESTAMP | When the forecast was generated |
| `day_ahead` | INT | Horizon (1-14 days) |
| `forecast_mean` | FLOAT | Point forecast (cents/lb) |
| `forecast_std` | FLOAT | Forecast uncertainty |
| `lower_95` | FLOAT | 95% CI lower bound |
| `upper_95` | FLOAT | 95% CI upper bound |
| `model_version` | STRING | 'sarimax_v0' |
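| `commodity` | STRING | 'Coffee' |
| `model_success` | BOOLEAN | Whether the SARIMAX fit succeeded (`false` = failed/unreliable fit; flag or filter these rows) |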

**Usage:**
```sql
-- Get 7-day ahead forecasts for backtesting
SELECT forecast_date, forecast_mean, lower_95, upper_95
FROM point_forecasts
WHERE day_ahead = 7
AND data_cutoff_date < forecast_date  -- CRITICAL: No data leakage
AND forecast_date BETWEEN '2023-01-01' AND '2023-12-31'
```

## Distributions Schema

| Field | Type | Description |
|-------|------|-------------|
| `path_id` | INT | Sample path ID (1-2000) |
| `forecast_start_date` | DATE | First day of forecast |
| `data_cutoff_date` | DATE | Last training date |
| `generation_timestamp` | TIMESTAMP | When the distribution was generated |
| `day_1` to `day_14` | FLOAT | Forecasted prices for each day |
| `model_version` | STRING | 'sarimax_v0' |
| `commodity` | STRING | 'Coffee' |

**Usage:**
```sql
-- Calculate 95% VaR for day 7
SELECT 
    forecast_start_date,
    PERCENTILE(day_7, 0.05) as var_95,
    AVG(day_7) as mean_price
FROM distributions
WHERE forecast_start_date = '2024-01-15'
AND data_cutoff_date < '2024-01-15'  -- No data leakage
GROUP BY forecast_start_date
```

## Key Concepts

### 1. Data Leakage Prevention
**ALWAYS filter:** `data_cutoff_date < forecast_date` (point forecasts) or `data_cutoff_date < forecast_start_date` (distributions)

### 2. Overlapping Forecasts
Daily frequency = 14 forecasts exist for any target date, one per data cutoff date:
- Forecast with data cutoff Jan 7 (day_14) for Jan 21
- Forecast with data cutoff Jan 8 (day_13) for Jan 21
- ...
- Forecast with data cutoff Jan 20 (day_1) for Jan 21

### 3. Distribution = 2,000 Paths
Each forecast start date has 2,000 Monte Carlo sample paths (one row per path) for risk analysis


## Common Queries

### Backtest Forecast Accuracy
```sql
SELECT 
    pf.forecast_date,
    pf.forecast_mean,
    actual.close as actual_price,
    ABS(pf.forecast_mean - actual.close) as error
FROM point_forecasts pf
JOIN unified_data actual 
    ON actual.date = pf.forecast_date 
    AND actual.commodity = 'Coffee'
WHERE pf.day_ahead = 7
AND pf.data_cutoff_date < pf.forecast_date
```

### Calculate VaR/CVaR
```sql
-- 95% VaR
SELECT 
    forecast_start_date,
    PERCENTILE(day_14, 0.05) as var_95
FROM distributions
WHERE forecast_start_date = CURRENT_DATE()
GROUP BY forecast_start_date;

-- CVaR proxy: average day-14 price across the worst 5% of paths
WITH var AS (
    SELECT PERCENTILE(day_14, 0.05) as threshold
    FROM distributions
    WHERE forecast_start_date = CURRENT_DATE()
)
SELECT AVG(day_14) as cvar_95
FROM distributions, var
WHERE day_14 <= var.threshold
AND forecast_start_date = CURRENT_DATE();
```

## Data Quality Checks
```sql
-- 1. Verify no data leakage (MUST return 0)
SELECT COUNT(*) FROM point_forecasts 
WHERE forecast_date <= data_cutoff_date;

-- 2. Verify 2,000 paths per date
SELECT forecast_start_date, COUNT(*) as paths
FROM distributions
GROUP BY forecast_start_date
HAVING COUNT(*) != 2000;

-- 3. Check date coverage
SELECT 
    MIN(forecast_date) as earliest,
    MAX(forecast_date) as latest,
    COUNT(DISTINCT forecast_date) as num_dates
FROM point_forecasts;
```

## Details

**Data Period:** 2018-01-01 to present  
**Update Frequency:** Daily (production)  
**Model:** SARIMAX(1,1,1) baseline

In [0]:
# ============================================================================
# 9. VALIDATION QUERIES
# ============================================================================

print("\n" + "="*70)
print("VALIDATION CHECKS")
print("="*70)

# Check 1: Point forecast summary
print("\n1. Point Forecasts Summary:")
spark.sql(f"""
    SELECT 
        model_version,
        commodity,
        COUNT(*) as total_forecasts,
        COUNT(DISTINCT data_cutoff_date) as unique_cutoff_dates,
        MIN(forecast_date) as earliest_forecast,
        MAX(forecast_date) as latest_forecast,
        AVG(forecast_mean) as avg_forecast_price,
        AVG(forecast_std) as avg_std_error
    FROM {FORECAST_TABLE}
    GROUP BY model_version, commodity
""").show(truncate=False)

# Check 2: Distribution summary
print("\n2. Distribution Summary:")
spark.sql(f"""
    SELECT 
        model_version,
        commodity,
        COUNT(*) as total_paths,
        COUNT(DISTINCT forecast_start_date) as unique_forecast_dates,
        MIN(forecast_start_date) as earliest_forecast,
        MAX(forecast_start_date) as latest_forecast
    FROM {DISTRIBUTION_TABLE}
    GROUP BY model_version, commodity
""").show(truncate=False)

# Check 3: Sample distribution statistics (Day 1 vs Day 14) for the configured model version
# (previously hard-coded to 'sarimax_v0', which silently returned nothing for other versions).
print("\n3. Sample Distribution Statistics (Day 1 vs Day 14):")
spark.sql(f"""
    SELECT 
        'Day 1' as forecast_day,
        AVG(day_1) as mean_price,
        STDDEV(day_1) as std_price,
        PERCENTILE(day_1, 0.05) as p5,
        PERCENTILE(day_1, 0.50) as p50,
        PERCENTILE(day_1, 0.95) as p95
    FROM {DISTRIBUTION_TABLE}
    WHERE model_version = '{MODEL_VERSION}'
    
    UNION ALL
    
    SELECT 
        'Day 14' as forecast_day,
        AVG(day_14) as mean_price,
        STDDEV(day_14) as std_price,
        PERCENTILE(day_14, 0.05) as p5,
        PERCENTILE(day_14, 0.50) as p50,
        PERCENTILE(day_14, 0.95) as p95
    FROM {DISTRIBUTION_TABLE}
    WHERE model_version = '{MODEL_VERSION}'
""").show(truncate=False)

# Check 4: Data leakage validation (CRITICAL!)
print("\n4. Data Cutoff Validation (ensuring no data leakage):")
spark.sql(f"""
    SELECT 
        COUNT(*) as total_forecasts,
        SUM(CASE WHEN forecast_date > data_cutoff_date THEN 1 ELSE 0 END) as valid_forecasts,
        SUM(CASE WHEN forecast_date <= data_cutoff_date THEN 1 ELSE 0 END) as data_leakage_errors
    FROM {FORECAST_TABLE}
""").show(truncate=False)

# Check 5: Sample overlapping forecasts.
# The target date is the latest one covered by every horizon (day_ahead 1..FORECAST_HORIZON).
# MAX(forecast_date) on its own is only reached by the last origin's final step, so it has a
# single forecast and never shows any overlap.
print("\n5. Sample Overlapping Forecasts (for one target date):")
spark.sql(f"""
    SELECT 
        forecast_date as target_date,
        data_cutoff_date,
        day_ahead,
        forecast_mean,
        DATEDIFF(forecast_date, data_cutoff_date) as forecast_age_days
    FROM {FORECAST_TABLE}
    WHERE model_version = '{MODEL_VERSION}'
    AND forecast_date = (
        SELECT MAX(forecast_date)
        FROM (
            SELECT forecast_date
            FROM {FORECAST_TABLE}
            WHERE model_version = '{MODEL_VERSION}'
            GROUP BY forecast_date
            HAVING COUNT(DISTINCT day_ahead) = {FORECAST_HORIZON}
        ) AS fully_covered_dates
    )
    AND data_cutoff_date < forecast_date
    ORDER BY data_cutoff_date DESC
    LIMIT {FORECAST_HORIZON}
""").show(truncate=False)

print("\n" + "="*70)
print("DATA CONTRACT ESTABLISHED ✓")
print("="*70)
print(f"""
Risk Agent can now access backdated forecasts from Delta tables:
- Point Forecasts: {FORECAST_TABLE}
- Distributions: {DISTRIBUTION_TABLE}

Query example for overlapping forecasts:
SELECT 
    forecast_date,
    data_cutoff_date,
    day_ahead,
    forecast_mean,
    DATEDIFF(forecast_date, data_cutoff_date) as forecast_age_days
FROM {FORECAST_TABLE}
WHERE forecast_date = '2024-10-15'
AND data_cutoff_date < '2024-10-15'
ORDER BY data_cutoff_date DESC;

This returns all 14 forecasts for Oct 15 (made from Oct 1-14)

Next Steps:
1. Risk Agent queries tables directly for backtesting
2. Iterate on forecast accuracy (add more covariates, tune hyperparameters)
3. Add Sugar commodity forecasts
4. Implement skewed distributions (Phase 2)
5. Add hierarchical regional forecasts (Phase 2)
""")

print("\nForecast generation complete!")