In [None]:
#RF Analysis for CWR, IWR and CDIs
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split, cross_val_score
import shap
import seaborn as sns

# Analysis window shared by every dataset (inclusive bounds)
analysis_period = slice('2000-01-01', '2020-12-31')


def _mask_invalid(ds, var_name):
    """Return *ds* with every cell set to NaN where *var_name* is not a usable value.

    Values >= 1e30 (placeholder/fill values) and non-finite values (+/-inf)
    are treated as missing. The mask is applied to every variable of *ds*.
    """
    values = ds[var_name]
    return ds.where(np.isfinite(values) & (values < 1e30))


# Load the NetCDF files for maize, rice, and wheat CWR, plus the drought index (CDI).
# NOTE(review): these paths look like redacted placeholders — set the real locations before running.
cwr_maize_ds = xr.open_dataset(r'\CWR_maize_global.nc')
cwr_rice_ds = xr.open_dataset(r'\CWR_rice_global.nc')
cwr_wheat_ds = xr.open_dataset(r'\CWR_wheat_global.nc')
cmdi_ds = xr.open_dataset(r'\CDI.nc')

# Convert the numeric CWR time axis to a DatetimeIndex.
# NOTE(review): assumes the raw values are day offsets from 2000-01-01 — confirm against the files' time units.
for ds in (cwr_maize_ds, cwr_rice_ds, cwr_wheat_ds):
    ds['time'] = pd.to_datetime(ds['time'].values, origin='2000-01-01', unit='D')

# Restrict to the common period and mask placeholder/inf values BEFORE aggregating.
# Masking after the monthly mean would let one placeholder inflate the whole month's
# mean past the threshold, discarding that month's valid days as well.
cwr_maize_ds = _mask_invalid(cwr_maize_ds.sel(time=analysis_period), 'CWR_maize')
cwr_rice_ds = _mask_invalid(cwr_rice_ds.sel(time=analysis_period), 'CWR_rice')
cwr_wheat_ds = _mask_invalid(cwr_wheat_ds.sel(time=analysis_period), 'CWR_wheat')
cmdi_ds = cmdi_ds.sel(time=analysis_period)

# Resample to monthly resolution; mean() skips the NaNs introduced by the mask.
# NOTE(review): newer pandas versions deprecate the 'M' alias in favour of 'ME'.
cwr_maize_ds = cwr_maize_ds.resample(time='M').mean()
cwr_rice_ds = cwr_rice_ds.resample(time='M').mean()
cwr_wheat_ds = cwr_wheat_ds.resample(time='M').mean()
cmdi_ds = cmdi_ds.resample(time='M').mean()

# Basin bounding boxes as [(lat_min, lon_min), (lat_max, lon_max)] in decimal degrees
# (negative longitudes lie west of Greenwich).
basin_coords = {
    "Yellow River": [(34.5, 113.5), (41, 116)],
    "Yangtze": [(25, 90), (35, 120)],
    "Sutlej": [(28, 73), (34, 83)],
    "Indus": [(23, 67), (35, 80)],
    "Mississippi": [(29, -91), (47, -88)],
    "Mediterranean": [(30, -5), (45, 36)],
    "Iranian Plateau": [(24, 44), (40, 65)],
    "Ganges": [(22, 80), (30, 88)],
    "Aegean": [(36, 20), (42, 30)],
    "Amazon": [(-5, -75), (5, -50)]
}

# Container for per-model evaluation results
results = []

# Set the Seaborn theme for better aesthetics
sns.set_theme(style="whitegrid")

# One subplot per basin, laid out on a grid with `num_cols` columns.
num_basins = len(basin_coords)
num_cols = 2
# Ceiling division, so every basin gets a cell whatever num_cols is
# ((num_basins + 1) // num_cols only rounds up correctly for two columns).
num_rows = (num_basins + num_cols - 1) // num_cols
# squeeze=False keeps `axes` 2-D even for a 1x1 grid, so flatten() always yields an array of Axes.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(18, 6 * num_rows), squeeze=False)
axes = axes.flatten()

# Growing-season months (1 = January) for each crop; these are examples, adjust as necessary.
# Wheat wraps around the year end, which is fine because only the month number is compared.
growing_seasons = {
    'maize': [5, 6, 7, 8],  # May to August
    'rice': [6, 7, 8, 9],   # June to September
    'wheat': [10, 11, 12, 1, 2, 3]  # October to March
}

def _select_box(ds, coords):
    """Return the part of *ds* inside a basin bounding box.

    *coords* is ``[(lat_min, lon_min), (lat_max, lon_max)]``. A label slice
    passed to ``sel`` returns nothing when it runs against the coordinate's
    storage order, so the bounds are swapped for any axis stored in
    descending order (e.g. latitude running 90 -> -90).
    NOTE(review): longitudes are assumed to use the same -180..180 convention
    as ``basin_coords``; a 0..360 grid would need converting first.
    """
    (lat_min, lon_min), (lat_max, lon_max) = coords
    bounds = {}
    for dim, lo, hi in (('lat', lat_min, lat_max), ('lon', lon_min, lon_max)):
        values = ds[dim].values
        descending = values.size > 1 and values[0] > values[-1]
        bounds[dim] = slice(hi, lo) if descending else slice(lo, hi)
    return ds.sel(**bounds)


def _basin_series(basin_ds, var_name):
    """Average *var_name* over lat/lon into a one-column DataFrame indexed by time.

    Months where every cell is NaN (so the spatial mean is NaN) are dropped.
    NOTE(review): this is an unweighted mean; cos(lat) area weighting would be
    more faithful for tall boxes.
    """
    spatial_mean = basin_ds[var_name].mean(dim=['lat', 'lon'])
    # Keep only the variable itself (to_dataframe also emits any non-index coordinates).
    return spatial_mean.to_dataframe()[[var_name]].dropna()


# Per-crop plotting style: month colour map and marker shape
crop_cmaps = {'maize': plt.cm.RdYlBu, 'rice': plt.cm.viridis, 'wheat': plt.cm.plasma}
markers = {'maize': 'o', 'rice': 's', 'wheat': '^'}
crop_datasets = {'maize': cwr_maize_ds, 'rice': cwr_rice_ds, 'wheat': cwr_wheat_ds}
crop_order = ['maize', 'rice', 'wheat']
feature_cols = ['CMDI', 'month', 'year']
# Fixed 1-12 scale, so a month gets the same colour in every scatter and on the colour bars
month_norm = plt.Normalize(vmin=1, vmax=12)
# Smallest usable series: the 80/20 split must leave at least 5 training rows for 5-fold CV
min_samples = 10

# Loop over each basin: fit one Random Forest per crop (CWR ~ CMDI + month + year),
# then plot the SHAP contribution of CMDI for every test month on the basin's panel.
for idx, (basin_name, coords) in enumerate(basin_coords.items()):
    print(f"\nAnalyzing basin: {basin_name}")
    ax = axes[idx]

    # Drought-index series for this basin (the CDI file's variable is named CMDI)
    cmdi_ts = _basin_series(_select_box(cmdi_ds, coords), 'CMDI')
    plotted_crops = []

    for crop_slot, crop_name in enumerate(crop_order):
        target = f'CWR_{crop_name}'
        cwr_ts = _basin_series(_select_box(crop_datasets[crop_name], coords), target)

        # Keep months present in both series, then add calendar features
        data = pd.merge(cwr_ts, cmdi_ts, left_index=True, right_index=True)
        data['month'] = data.index.month
        data['year'] = data.index.year

        if len(data) < min_samples:
            print(f"Insufficient data for {basin_name} ({crop_name}): {len(data)} months")
            continue
        # Written as `not (... > 0)` so NaN variances are rejected too
        if not (data[target].var() > 0 and data['CMDI'].var() > 0):
            print(f"Insufficient variance in data for {basin_name} ({crop_name})")
            continue

        X = data[feature_cols]
        y = data[target]
        # NOTE(review): a shuffled split of a monthly series (with 'year' as a feature) lets the
        # forest interpolate between neighbouring months, so test/CV R² are likely optimistic;
        # a chronological hold-out (shuffle=False / TimeSeriesSplit) would be stricter.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        rf = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)
        # cross_val_score fits clones, so `rf` is still unfitted afterwards
        mean_cv_score = cross_val_score(rf, X_train, y_train, cv=5, scoring='r2').mean()
        rf.fit(X_train, y_train)

        y_pred = rf.predict(X_test)
        mse = mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)

        print(f"Model evaluation for {basin_name} ({crop_name}):")
        print(f"Mean Squared Error: {mse:.4f}")
        print(f"R-squared: {r2:.4f}")
        results.append({
            'basin': basin_name,
            'crop': crop_name,
            'n_samples': len(data),
            'mean_cv_r2': mean_cv_score,
            'test_r2': r2,
            'mse': mse,
        })

        # SHAP contribution of CMDI to each test prediction
        shap_values = shap.Explainer(rf, X_train)(X_test)
        cmdi_shap = shap_values.values[:, X_test.columns.get_loc('CMDI')]
        cmdi_test = X_test['CMDI'].to_numpy()
        test_months = X_test['month'].to_numpy()
        in_season = X_test['month'].isin(growing_seasons[crop_name]).to_numpy()

        # Growing-season points are coloured by month on the fixed 1-12 scale; off-season points are grey
        ax.scatter(cmdi_test[in_season], cmdi_shap[in_season],
                   c=test_months[in_season], cmap=crop_cmaps[crop_name], norm=month_norm,
                   s=100, marker=markers[crop_name],
                   label=f'{crop_name.capitalize()} (Growing Season)', edgecolor='k', linewidth=0.5)
        ax.scatter(cmdi_test[~in_season], cmdi_shap[~in_season],
                   c='lightgray', s=50, marker=markers[crop_name],
                   label=f'{crop_name.capitalize()} (Off-Season)', edgecolor='k', linewidth=0.5)

        # Metrics box, stacked by crop position so the three crops do not overlap
        textstr = (f'{crop_name.capitalize()}:\n'
                   f'Mean CV R²: {mean_cv_score:.2f}\n'
                   f'Test R²: {r2:.2f}\n'
                   f'MSE: {mse:.2f}')
        ax.text(0.05, 0.95 - 0.2 * crop_slot, textstr, transform=ax.transAxes, fontsize=14,
                verticalalignment='top', bbox=dict(facecolor='white', alpha=0.8))
        plotted_crops.append(crop_name)

    ax.set_title(basin_name, fontsize=18, fontweight='bold')
    ax.set_xlabel('CMDI', fontsize=16)
    # All crops share this panel (distinguished by marker), so the label names the feature, not one crop
    ax.set_ylabel('SHAP value of CMDI', fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.grid(True, linestyle='--', linewidth=0.6, color='gray')

    if plotted_crops:
        # Each crop uses its own colour map, so each plotted crop gets its own month colour bar
        for crop_name in plotted_crops:
            month_mappable = plt.cm.ScalarMappable(cmap=crop_cmaps[crop_name], norm=month_norm)
            month_mappable.set_array([])
            cbar = fig.colorbar(month_mappable, ax=ax, shrink=0.5, pad=0.02)
            cbar.set_label(f'Month ({crop_name.capitalize()})', fontsize=14)
            cbar.ax.tick_params(labelsize=12)

        # The legend shows which marker belongs to which crop and season
        ax.legend(loc='best', fontsize=14)

# Hide grid cells left over when the basins do not fill the last row. This uses the
# basin count rather than the loop's `idx`, which is undefined if no basin was iterated.
for spare_ax in axes[num_basins:]:
    spare_ax.set_visible(False)

# Adjust layout and save the figure
fig.tight_layout()
# NOTE(review): '/.svg' looks like a redacted placeholder path — set a real output file.
fig.savefig('/.svg', dpi=300, format='svg')
plt.show()

# Tabulate the collected per-model evaluation results
results_df = pd.DataFrame(results)

# Print the results DataFrame to console
print("\nModel Evaluation Results for Basins:")
print(results_df)

