In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re
from datetime import datetime
import warnings
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
warnings.filterwarnings('ignore')

# Add for new models
from statsmodels.tsa.statespace.sarimax import SARIMAX

# --- Crop Recommendation Logic ---
def recommend_crop(avg_temp, avg_ndvi):
    """
    Recommend a crop (Rice, Cassava, Maize) from average forecasted
    Temperature (LST) and NDVI.

    Logic based on patterns observed in combined_with_plants.csv. Rules are
    checked in this order:

    - avg_temp > 31                              -> Cassava (hot; drought tolerant)
    - avg_ndvi > 0.55 and 29 < avg_temp <= 31    -> Rice (warm, high vegetation
                                                    implies sufficient water)
    - avg_ndvi > 0.55 and avg_temp <= 29         -> Maize
    - avg_ndvi <= 0.55 and avg_temp > 27.5       -> Cassava
    - otherwise                                  -> Maize

    Args:
        avg_temp: average land-surface temperature (°C).
        avg_ndvi: average NDVI.

    Returns:
        (crop, reason) tuple of strings.
    """
    warm_temp_lower = 29.0        # lower edge of the warm band where Rice is possible
    hot_temp_threshold = 31.0     # above this, Cassava regardless of NDVI
    moderate_ndvi_threshold = 0.55
    cassava_margin = 1.5          # low-NDVI rows this close to the warm band favour Cassava

    # The "hot" cut-off must be the top of the warm band. The previous version
    # compared against 29 here and again in the Rice test (29 < t <= 31),
    # which made the Rice branch unreachable.
    if avg_temp > hot_temp_threshold:
        if avg_ndvi > moderate_ndvi_threshold:
            return "Cassava", "High temperature and high NDVI are ideal for Cassava."
        return "Cassava", "High temperature; Cassava is the most drought-tolerant option."

    if avg_ndvi > moderate_ndvi_threshold:
        if avg_temp > warm_temp_lower:
            return "Rice", "Moderate temperature with high NDVI suggests good water availability, suitable for Rice."
        return "Maize", "Moderate/lower temperature with high NDVI, suitable for Maize."

    if avg_temp > warm_temp_lower - cassava_margin:
        return "Cassava", "Warmer conditions, even with moderate NDVI, favor drought-tolerant Cassava."
    return "Maize", "Moderate/lower temperature with moderate/low NDVI suits Maize."

# --- Model Evaluation ---
def evaluate_crop_recommendation(combined_df, preview=10):
    """
    Evaluate the rule-based crop recommendation against historical labels.

    Each row's LST_VALUE and NDVI_VALUE are passed to ``recommend_crop`` and the
    predicted crop is compared with the row's RECOMMENDED label. Prints accuracy
    and weighted precision/recall/F1, and shows a confusion-matrix heatmap.

    Args:
        combined_df: DataFrame with LST_VALUE, NDVI_VALUE and RECOMMENDED columns.
        preview: number of predictions/actuals echoed as a sanity check. The full
            lists are not printed because they can run to thousands of entries
            and swamp notebook output.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1' and 'labels'.

    Raises:
        ValueError: if there are no labels to evaluate.
    """
    print("\n" + "=" * 60)
    print("📈 MODEL EVALUATION")
    print("=" * 60)

    actuals = combined_df['RECOMMENDED'].tolist()
    # Column-wise zip avoids the per-row Series construction of iterrows().
    predictions = [
        recommend_crop(temp, ndvi)[0]
        for temp, ndvi in zip(combined_df['LST_VALUE'], combined_df['NDVI_VALUE'])
    ]

    print(f"Predictions (first {preview} of {len(predictions)}):", predictions[:preview])
    print(f"Actuals (first {preview} of {len(actuals)}):", actuals[:preview])

    # Sorted so the confusion-matrix axes have a stable, reproducible order
    # (set iteration order is arbitrary).
    unique_labels = sorted(set(actuals) | set(predictions), key=str)
    if not unique_labels:
        raise ValueError("No valid labels found in actuals or predictions.")

    accuracy = accuracy_score(actuals, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        actuals, predictions, average='weighted', zero_division=0
    )

    print(f"Accuracy: {accuracy:.3f}")
    print(f"Precision (Weighted): {precision:.3f}")
    print(f"Recall (Weighted): {recall:.3f}")
    print(f"F1-Score (Weighted): {f1:.3f}")

    cm = confusion_matrix(actuals, predictions, labels=unique_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=unique_labels, yticklabels=unique_labels)
    plt.title('Confusion Matrix for Crop Recommendation')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'labels': unique_labels,
    }

# --- Forecasting System ---
def create_region_mapping(combined_df):
    """
    Map every GID_2 code present in ``combined_df`` to a display name.

    Codes found in the built-in Malawi table get their district label; any
    other code falls back to ``"Region <code>"``.

    Args:
        combined_df: DataFrame with a ``GID_2`` column.

    Returns:
        dict of gid -> display name, in order of first appearance.
    """
    known_names = {
        'MWI.1.1_1': 'Northern Region - Chitipa',
        'MWI.1.2_1': 'Northern Region - Karonga',
        'MWI.1.3_1': 'Northern Region - Rumphi',
        'MWI.2.1_1': 'Central Region - Kasungu',
        'MWI.2.2_1': 'Central Region - Nkhotakota',
        'MWI.2.3_1': 'Central Region - Ntchisi',
        'MWI.2.4_1': 'Central Region - Dowa',
        'MWI.2.5_1': 'Central Region - Salima',
        'MWI.2.6_1': 'Central Region - Lilongwe',
        'MWI.2.7_1': 'Central Region - Mchinji',
        'MWI.3.1_1': 'Southern Region - Mangochi',
        'MWI.3.2_1': 'Southern Region - Machinga',
        'MWI.3.3_1': 'Southern Region - Zomba',
        'MWI.3.4_1': 'Southern Region - Chiradzulu',
        'MWI.3.5_1': 'Southern Region - Blantyre',
        'MWI.3.6_1': 'Southern Region - Mwanza',
        'MWI.3.7_1': 'Southern Region - Thyolo',
        'MWI.3.8_1': 'Southern Region - Mulanje',
        'MWI.3.9_1': 'Southern Region - Phalombe',
        'MWI.4.1_1': 'Southern Region - Chikwawa',
        'MWI.4.2_1': 'Southern Region - Nsanje',
        'MWI.4.3_1': 'Southern Region - Balaka',
        'MWI.4.4_1': 'Southern Region - Neno',
    }
    mapping = {}
    for code in combined_df['GID_2'].unique():
        mapping[code] = known_names.get(code, f"Region {code}")
    return mapping

def get_available_options(combined_df):
    """
    Get available countries and regions for user selection.

    The country code is the first three characters of each GID_2 code
    (e.g. 'MWI' from 'MWI.1.1_1'). The input DataFrame is not modified.

    Args:
        combined_df: DataFrame with a ``GID_2`` column.

    Returns:
        (available_countries, region_mapping): country code -> display name,
        and the result of ``create_region_mapping(combined_df)``.
    """
    country_mapping = {
        'MWI': 'Malawi',
        'GHA': 'Ghana',
        'UGA': 'Uganda',
        'SEN': 'Senegal',
        'BEN': 'Benin'
    }
    # Derive codes locally rather than writing a helper column into the
    # caller's DataFrame. Missing GIDs are dropped so they cannot surface as a
    # bogus "Country_nan" option.
    country_codes = combined_df['GID_2'].dropna().astype(str).str.slice(0, 3).unique()
    available_countries = {
        code: country_mapping.get(code, f"Country_{code}") for code in country_codes
    }
    return available_countries, create_region_mapping(combined_df)

def display_available_options(available_countries, region_mapping):
    """
    Print the selectable countries, then a preview of the selectable regions.

    At most the first 15 regions are listed; any remainder is summarised as a
    count, followed by a usage tip.
    """
    preview_limit = 15

    print("🌍 AVAILABLE COUNTRIES:")
    print("=" * 40)
    for line in (f"{code}: {label}" for code, label in available_countries.items()):
        print(line)

    print("\n📍 AVAILABLE REGIONS (for detailed analysis):")
    print("=" * 60)
    for gid, label in list(region_mapping.items())[:preview_limit]:
        print(f"{gid}: {label}")

    hidden = len(region_mapping) - preview_limit
    if hidden > 0:
        print(f"... and {hidden} more regions")
    print("\n💡 Tip: Type 'Malawi' for country-level forecast or a GID code for regional analysis")

def get_user_selection(available_countries, region_mapping):
    """
    Prompt repeatedly until the user names a valid country or region.

    Input is matched, in order, against: country display names
    (case-insensitive), exact GID_2 region codes, and finally the literal
    'malawi'. Invalid input re-displays the available options.

    Returns:
        dict with 'type' ('country' or 'region'), 'value' (country code or
        GID_2) and 'name' (display name).
    """
    while True:
        print("\n" + "=" * 60)
        raw = input("Enter country name (e.g., 'Malawi') or GID code (e.g., 'MWI.1.1_1'): ").strip()
        wanted = raw.lower()

        country_hit = next(
            ((code, label) for code, label in available_countries.items() if label.lower() == wanted),
            None,
        )
        if country_hit is not None:
            code, label = country_hit
            return {'type': 'country', 'value': code, 'name': label}

        if raw in region_mapping:
            return {'type': 'region', 'value': raw, 'name': region_mapping[raw]}

        if wanted == 'malawi':
            return {'type': 'country', 'value': 'MWI', 'name': 'Malawi'}

        print("❌ Invalid selection. Please choose from the available options.")
        display_available_options(available_countries, region_mapping)

def prepare_data_for_forecast(combined_df, selection, forecast_years=5):
    """
    Prepare per-date RAINFALL_MM / NDVI_VALUE / LST_VALUE data for a selection.

    Country selections average all rows whose GID_2 starts with the country
    code. Region selections use the single matching GID_2. In both cases rows
    that share a DATE are averaged, so the result has one row per date, sorted
    by date.

    Args:
        combined_df: DataFrame with GID_2, DATE, RAINFALL_MM, NDVI_VALUE and
            LST_VALUE columns.
        selection: dict with 'type' ('country' or region), 'value' (country
            code or GID_2) and 'name'.
        forecast_years: accepted for backward compatibility; unused here.

    Returns:
        DataFrame indexed by DATE, or None if nothing matches the selection.
    """
    value_columns = ['RAINFALL_MM', 'NDVI_VALUE', 'LST_VALUE']

    if selection['type'] == 'country':
        country_code = selection['value']
        # na=False: rows with a missing GID_2 simply don't match instead of
        # producing a non-boolean mask that pandas refuses to index with.
        mask = combined_df['GID_2'].str.startswith(country_code, na=False)
        subset = combined_df[mask]
        if subset.empty:
            print(f"❌ No data found for country code: {country_code}")
            return None
        level = 'country'
    else:
        region_gid = selection['value']
        subset = combined_df[combined_df['GID_2'] == region_gid]
        if subset.empty:
            print(f"❌ No data found for region: {region_gid}")
            return None
        level = 'region'

    # Average repeated dates in both branches (previously only the country
    # branch did), so the forecasters never receive duplicate timestamps.
    prepared = subset.groupby('DATE')[value_columns].mean().sort_index()

    print(f"✅ Prepared {level}-level data for {selection['name']}")
    print(f"   Time range: {prepared.index.min()} to {prepared.index.max()}")
    print(f"   Total records: {len(prepared)}")
    return prepared

def prophet_forecast(series, periods=115, yearly_seasonality=True, interval_width=0.95):
    """
    Forecast ``series`` with Facebook's Prophet on a 16-day grid.

    Args:
        series: pandas Series indexed by datetime.
        periods: number of 16-day steps to forecast past the last observation.
            0 returns empty series; a negative value raises ValueError.
        yearly_seasonality: passed through to Prophet.
        interval_width: width of the uncertainty interval. Prophet's own
            default is 0.80; 0.95 is used here so the interval matches the 95%
            interval produced by ``sarima_forecast`` (alpha=0.05).

    Returns:
        (forecast, lower, upper) Series covering only the future periods.

    Raises:
        ValueError: if ``periods`` is negative.
        ImportError: if the ``prophet`` package is not installed.
    """
    if periods < 0:
        raise ValueError(f"periods must be non-negative, got {periods}")
    if periods == 0:
        # Slicing with [-0:] would return the whole in-sample fit, not nothing.
        empty = pd.Series(dtype=float, index=pd.DatetimeIndex([], name='ds'))
        return empty, empty.copy(), empty.copy()

    try:
        from prophet import Prophet
    except ImportError as err:
        raise ImportError("The 'prophet' library is required for forecasting. Please install it using 'pip install prophet'.") from err

    prophet_df = pd.DataFrame({
        'ds': series.index,
        'y': series.values
    })
    model = Prophet(
        yearly_seasonality=yearly_seasonality,
        weekly_seasonality=False,
        daily_seasonality=False,
        changepoint_prior_scale=0.05,
        interval_width=interval_width,
    )
    model.fit(prophet_df)
    future = model.make_future_dataframe(periods=periods, freq='16D')
    # predict() covers history + future; keep only the trailing future rows.
    predicted = model.predict(future).set_index('ds').iloc[-periods:]
    return predicted['yhat'], predicted['yhat_lower'], predicted['yhat_upper']

def sarima_forecast(series, periods=115, order=(5, 1, 0), seasonal_order=(1, 1, 1, 23), alpha=0.05):
    """
    Forecast ``series`` with a SARIMA model.

    The default seasonal period of 23 approximates one year of 16-day data
    (365 / 16 ≈ 23). Forecast values are placed on a 16-day date grid that
    starts 16 days after the last observation.

    Args:
        series: pandas Series indexed by datetime.
        periods: number of 16-day steps to forecast.
        order, seasonal_order: SARIMAX orders (tune as needed).
        alpha: significance level of the returned interval (0.05 -> 95%).

    Returns:
        (forecast, lower, upper) Series indexed by the future dates.
    """
    model = SARIMAX(series, order=order, seasonal_order=seasonal_order)
    result = model.fit(disp=False)
    prediction = result.get_forecast(steps=periods)
    conf = np.asarray(prediction.conf_int(alpha=alpha))

    last_date = series.index[-1]
    future_dates = pd.date_range(last_date + pd.Timedelta(days=16), periods=periods, freq='16D')

    # Re-label by position via np.asarray. Wrapping the statsmodels Series in
    # pd.Series(..., index=future_dates) *reindexes* on labels, which yields
    # all-NaN whenever statsmodels built its own (e.g. integer) forecast index
    # because the input dates have no regular frequency.
    forecast = pd.Series(np.asarray(prediction.predicted_mean), index=future_dates)
    lower = pd.Series(conf[:, 0], index=future_dates)
    upper = pd.Series(conf[:, 1], index=future_dates)
    return forecast, lower, upper

def evaluate_forecast(model_type, series, test_size=0.2):
    """
    Evaluate a forecasting model by mean absolute error on a hold-out tail.

    The last ``test_size`` fraction of ``series`` is held out. The model is fit
    on the rest and asked to forecast as many steps as were held out; forecasts
    and actuals are compared by position.

    Args:
        model_type: 'prophet' or 'sarima'.
        series: pandas Series indexed by datetime.
        test_size: fraction of observations held out (0 < test_size < 1).

    Returns:
        MAE as a float, or nan when the series is too short or the model fails
        to fit (a warning is printed in that case).

    Raises:
        ValueError: for an unknown ``model_type`` or out-of-range ``test_size``.
    """
    if not 0 < test_size < 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size}")
    if len(series) < 5:  # Too small to hold out a meaningful test set
        return float('nan')
    train_len = int(len(series) * (1 - test_size))
    if train_len < 3 or (len(series) - train_len) < 1:
        return float('nan')

    if model_type == 'prophet':
        forecaster = prophet_forecast
    elif model_type == 'sarima':
        forecaster = sarima_forecast
    else:
        raise ValueError("Unknown model type")

    train = series.iloc[:train_len]
    test = series.iloc[train_len:]
    try:
        forecast, _, _ = forecaster(train, len(test))
    except Exception as err:
        # This is a diagnostic score; a model that cannot be fit on this
        # series (too little data, missing library, ...) should not abort the
        # interactive session.
        print(f"⚠ {model_type} evaluation failed: {err}")
        return float('nan')
    return float(np.mean(np.abs(np.asarray(forecast, dtype=float) - test.to_numpy(dtype=float))))

def run_forecast(data, selection, forecast_years=5):
    """
    Forecast rainfall, NDVI and LST for a selection using Prophet.

    The horizon is ``forecast_years`` expressed in 16-day steps; if that comes
    out non-positive, 60 steps are used instead.

    Returns:
        (forecasts, confidence_intervals): a DataFrame with RAINFALL_MM,
        NDVI_VALUE and LST_VALUE columns indexed by forecast date, and a dict
        {'rainfall'|'ndvi'|'lst': {'lower': Series, 'upper': Series}}.
    """
    horizon = int(365 * forecast_years / 16)
    horizon = horizon if horizon > 0 else 60
    print(f"\n🔮 Forecasting for {selection['name']} ({forecast_years} years)...")

    targets = (
        ('RAINFALL_MM', 'rainfall'),
        ('NDVI_VALUE', 'ndvi'),
        ('LST_VALUE', 'lst'),
    )
    point_forecasts = {}
    confidence_intervals = {}
    for column, key in targets:
        mean, lower, upper = prophet_forecast(data[column], horizon)
        point_forecasts[column] = mean
        confidence_intervals[key] = {'lower': lower, 'upper': upper}

    forecast_index = point_forecasts['RAINFALL_MM'].index
    forecasts = pd.DataFrame(
        {column: series.values for column, series in point_forecasts.items()},
        index=forecast_index,
    )
    return forecasts, confidence_intervals

def plot_results(historical_data, forecasts, ci_dict, selection):
    """
    Plot history, forecast and forecast interval for rainfall, NDVI and LST.

    Args:
        historical_data: DataFrame with RAINFALL_MM, NDVI_VALUE and LST_VALUE.
        forecasts: DataFrame with the same columns, indexed by forecast date.
        ci_dict: {'rainfall'|'ndvi'|'lst': {'lower': Series, 'upper': Series}},
            aligned with ``forecasts.index``.
        selection: dict whose 'name' is used in the title.
    """
    panels = (
        ('RAINFALL_MM', 'rainfall', 'Rainfall (mm)', '#2E86AB'),
        ('NDVI_VALUE', 'ndvi', 'NDVI Value', '#A23B72'),
        ('LST_VALUE', 'lst', 'Temperature (°C)', '#F18F01'),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(16, 12))
    for ax, (column, ci_key, ylabel, color) in zip(axes, panels):
        ax.plot(historical_data.index, historical_data[column],
                label='Historical', linewidth=2, color=color, alpha=0.8)
        ax.plot(forecasts.index, forecasts[column],
                label='Forecast', linewidth=3, color=color)
        # The interval level depends on the model that produced ci_dict
        # (Prophet's own default is 80%), so it is not labelled "95%".
        ax.fill_between(forecasts.index, ci_dict[ci_key]['lower'], ci_dict[ci_key]['upper'],
                        color=color, alpha=0.2, label='Forecast interval')
        ax.set_ylabel(ylabel, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
    axes[-1].set_xlabel('Year', fontweight='bold')

    # Derive the horizon from the forecast dates instead of hard-coding "5-Year".
    span_days = (forecasts.index[-1] - forecasts.index[0]).days if len(forecasts) else 0
    years = max(1, round(span_days / 365.25))
    fig.suptitle(f'{years}-Year Forecast for {selection["name"]}\n', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

def generate_summary(forecasts, selection):
    """
    Print min/mean/max of each forecast variable and a crop recommendation.

    The recommendation is ``recommend_crop`` applied to the mean forecast
    LST_VALUE and NDVI_VALUE.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(f"📊 FORECAST SUMMARY: {selection['name']}")
    print(banner)

    mean_temp = forecasts['LST_VALUE'].mean()
    mean_ndvi = forecasts['NDVI_VALUE'].mean()

    first_day = forecasts.index[0].strftime('%Y-%m-%d')
    last_day = forecasts.index[-1].strftime('%Y-%m-%d')
    print(f"\n📅 Forecast Period: {first_day} to {last_day}")

    # (heading, column, number format, unit suffix)
    sections = (
        ("\n🌧  Rainfall Forecast:", 'RAINFALL_MM', '.1f', ' mm'),
        ("\n🌿 Vegetation Health (NDVI):", 'NDVI_VALUE', '.3f', ''),
        ("\n🌡  Temperature Forecast:", 'LST_VALUE', '.1f', '°C'),
    )
    for heading, column, fmt, unit in sections:
        values = forecasts[column]
        print(heading)
        print(f"   Average: {values.mean():{fmt}}{unit}")
        print(f"   Range: {values.min():{fmt}} - {values.max():{fmt}}{unit}")

    print("\n" + banner)
    print(f"🌾 CROP RECOMMENDATION FOR {selection['name']}")
    print(banner)
    crop, why = recommend_crop(mean_temp, mean_ndvi)
    print("Based on the average forecasted conditions:")
    print(f"  🌾 Recommended Crop: {crop}")
    print(f"  🧠 Reason: {why}")
    print(banner)

# One record = seven comma-separated fields on a single line, the third being
# an M/D/YYYY date. The character classes exclude '\n': with plain [^,] the
# last field ran across the line break into the next row's first field,
# corrupting the label and causing the following row to be skipped.
_RECORD_PATTERN = re.compile(
    r'([^,\n]+),([^,\n]+),(\d{1,2}/\d{1,2}/\d{4}),([^,\n]+),([^,\n]+),([^,\n]+),([^,\n]+)'
)
_RECORD_COLUMNS = ['UID', 'GID_2', 'DATE', 'RAINFALL_MM', 'NDVI_VALUE', 'LST_VALUE', 'RECOMMENDED']


def _parse_combined_records(content):
    """Parse the raw text of combined_with_plants.csv into a cleaned DataFrame.

    Lines that do not fit the record shape (e.g. a header) are ignored. Records
    with an unparseable date or number are skipped with a warning.

    Raises:
        ValueError: if no record matches the pattern, or none survives cleaning.
    """
    matches = _RECORD_PATTERN.findall(content)
    if not matches:
        raise ValueError("No data records could be parsed from the file using the regex pattern.")
    rows = []
    for uid, gid, date_text, rain_text, ndvi_text, lst_text, plant in matches:
        try:
            date_obj = datetime.strptime(date_text, '%m/%d/%Y')
        except ValueError:
            print(f"Warning: Could not parse date '{date_text}'. Skipping record.")
            continue
        try:
            values = (float(rain_text), float(ndvi_text), float(lst_text))
        except ValueError:
            print(f"Warning: Could not parse numerical values in record UID '{uid}'. Skipping record.")
            continue
        # A whole record is appended at once, so a record skipped part-way
        # through can no longer leave the columns with different lengths.
        rows.append((uid, gid, date_obj, *values, plant.strip()))
    combined_df = pd.DataFrame(rows, columns=_RECORD_COLUMNS)
    combined_df = combined_df.dropna(subset=['DATE', 'RAINFALL_MM', 'NDVI_VALUE', 'LST_VALUE'])
    combined_df = combined_df.reset_index(drop=True)
    if combined_df.empty:
        raise ValueError("No valid data found after parsing and cleaning.")
    return combined_df


def _report_forecast_errors(data):
    """Print hold-out MAE of Prophet and SARIMA for each forecast variable.

    A model that fails to evaluate is reported as nan instead of ending the session.
    """
    print("\n📊 Forecasting Model Evaluation (MAE on 20% test set):")
    for var in ['RAINFALL_MM', 'NDVI_VALUE', 'LST_VALUE']:
        scores = {}
        for model_type in ('prophet', 'sarima'):
            try:
                scores[model_type] = evaluate_forecast(model_type, data[var])
            except Exception as e:
                print(f"⚠ {model_type} evaluation failed for {var}: {e}")
                scores[model_type] = float('nan')
        print(f"{var}: Prophet MAE = {scores['prophet']:.3f}, SARIMA MAE = {scores['sarima']:.3f}")


def main():
    """
    Main interactive forecasting system.

    Loads combined_with_plants.csv from the working directory and evaluates the
    rule-based crop recommender on it. It then repeatedly asks for a country or
    region, forecasts rainfall/NDVI/LST, plots and summarises the result, and
    saves the forecast to a CSV, until the user declines to continue.
    """
    print("🌍 WELCOME TO THE CLIMATE FORECASTING SYSTEM")
    print("=" * 50)
    print("This system predicts Rainfall, Vegetation Health, and Temperature")
    print("for any available country or region for the next 5 years!")
    print("It also provides a crop recommendation (Rice, Cassava, Maize)!")
    print("=" * 50)
    try:
        with open('combined_with_plants.csv', 'r') as f:
            content = f.read()
        combined_df = _parse_combined_records(content)
    except FileNotFoundError:
        print("❌ Error: File 'combined_with_plants.csv' not found. Please ensure the file is in the correct directory.")
        return
    except Exception as e:
        print(f"❌ Error loading data: {e}")
        return
    print(f"✅ Data loaded and parsed successfully. Total records: {len(combined_df)}")

    # Evaluate crop recommendation model. This is diagnostic only, so a failure
    # here should not prevent forecasting.
    try:
        evaluate_crop_recommendation(combined_df)
    except Exception as e:
        print(f"⚠ Could not evaluate crop recommendation model: {e}")

    available_countries, region_mapping = get_available_options(combined_df)
    while True:
        display_available_options(available_countries, region_mapping)
        selection = get_user_selection(available_countries, region_mapping)
        data = prepare_data_for_forecast(combined_df, selection)
        if data is None or data.empty:
            print("⚠ Unable to prepare data for forecasting. Please try another selection.")
            continue
        _report_forecast_errors(data)
        try:
            forecasts, ci_dict = run_forecast(data, selection)
        except Exception as e:
            print(f"❌ Error during forecasting: {e}")
            continue
        try:
            plot_results(data, forecasts, ci_dict, selection)
        except Exception as e:
            print(f"⚠ Could not display plot: {e}")
        generate_summary(forecasts, selection)
        filename = f"forecast_{selection['name'].replace(' ', '_').lower()}.csv"
        try:
            forecasts.to_csv(filename)
            print(f"\n💾 Forecast saved to: {filename}")
        except Exception as e:
            print(f"⚠ Could not save forecast to file: {e}")
        print("\n" + "=" * 50)
        continue_choice = input("Would you like to forecast another region? (yes/no): ").strip().lower()
        if continue_choice not in ['yes', 'y']:
            print("Thank you for using the Climate Forecasting System! 👋")
            break

# Entry point: start the interactive forecasting session when this module is
# executed directly (rather than imported).
if __name__ == "__main__":
    main()

🌍 WELCOME TO THE CLIMATE FORECASTING SYSTEM
This system predicts Rainfall, Vegetation Health, and Temperature
for any available country or region for the next 5 years!
It also provides a crop recommendation (Rice, Cassava, Maize)!
✅ Data loaded and parsed successfully. Total records: 8587

📈 MODEL EVALUATION
Predictions: ['Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Maize', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Maize', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Maize', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 'Cassava', 

KeyboardInterrupt: 