# In [None]:
"""
SWAT+ Evapotranspiration Spatial Performance Evaluation Tool

This tool evaluates SWAT+ ET performance against reference data (e.g., WaPOR)
and identifies soil-landuse combinations associated with poor model performance.

"""

import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import uniform_filter, label
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.mask import mask
import geopandas as gpd
import warnings
warnings.filterwarnings("ignore")

# ===========================
# USER-DEFINED INPUTS
# ===========================
# Edit the constants in this section before running the script.

# Main ET comparison files
OBSERVED_ET_PATH = r"path/to/your/ObservedET.tif"  # Reference ET (e.g., WaPOR)
SWAT_ET_PATH = r"path/to/your/SWAT_modeled_ET.tif"   # SWAT+ modeled ET
# NOTE(review): WATERSHED_BOUNDARY is not referenced anywhere else in the
# visible code - confirm whether clipping to the watershed is still intended.
WATERSHED_BOUNDARY = r"path/to/your/watershed_boundary.shp"  # Watershed shapefile

# Water bodies mask (optional - set to None if not needed)
# Pixels covered by these shapes are excluded from the analysis.
WATERBODIES_MASK = r"path/to/your/waterbodies.shp"  # Areas to exclude from analysis

# Soil and landuse rasters for diagnostic analysis
# Both hold class IDs that are looked up in SOIL_NAMES / LANDUSE_NAMES below.
SOIL_RASTER = r"path/to/your/soil_raster.tif"
LANDUSE_RASTER = r"path/to/your/landuse_raster.tif"

# Output directory
OUTPUT_DIR = r"path/to/output/directory"

# Performance thresholds (%)
# Upper bounds of the error classes; a value strictly greater than a threshold
# moves to the next class. The reporting code assumes exactly three thresholds
# (four class labels).
PERFORMANCE_THRESHOLDS = [10, 25, 50]  # Good: 0-10%, Moderate: 10-25%, Poor: 25-50%, Very Poor: >50%

# Minimum region size (pixels) for poor performance analysis
# Connected regions smaller than this are ignored.
MIN_REGION_SIZE = 20  # ~20 km² at 1km resolution

# Spatial smoothing window size
SMOOTHING_WINDOW = 7 # 7x7 pixel moving average (adjust as needed)

# Soil type names (Soil ID: Soil Type)
# IDs missing from this mapping are reported as 'Unknown'.
SOIL_NAMES = {
    0: 'Clay',
    1: 'Sandy Clay',
    2: 'Clay Loam',
    # Add your soil types here
}

# Landuse names (Landuse ID: Landuse Type)
# IDs missing from this mapping are reported as 'Unknown'.
LANDUSE_NAMES = {
    20: 'Rangeland - Brush',
    30: 'Rangeland - Grasses',
    41: 'Agricultural Land - Row Crops',
    42: 'Agricultural Land - Sugarcane',
    # Add your landuse types here
}

# ===========================
# FUNCTIONS
# ===========================

def clip_raster(raster_path, shape_path=None):
    """
    Load band 1 of a raster, optionally blanking out the pixels covered by a shapefile

    Despite the name, no cropping is done: the raster keeps its full extent
    (crop=False) and pixels *inside* the shapes (invert=True), including any
    pixel merely touched by them (all_touched=True), are replaced with NaN.

    Parameters:
    raster_path (str): Path of the raster to read
    shape_path (str or None): Path of a vector file with the areas to blank out;
        when falsy, or when the file holds no geometries, the raster is
        returned unmodified

    Returns:
    tuple: (2-D array, raster bounds, raster CRS, affine transform)
    """
    with rasterio.open(raster_path) as dataset:
        # Nothing to exclude: hand back the raw band.
        if not shape_path:
            return dataset.read(1), dataset.bounds, dataset.crs, dataset.transform

        vector_layer = gpd.read_file(shape_path)
        # Shapes must be expressed in the raster CRS before rasterising them.
        if vector_layer.crs != dataset.crs:
            vector_layer = vector_layer.to_crs(dataset.crs)

        geometries = []
        for geometry in vector_layer.geometry:
            if geometry is not None:
                geometries.append(geometry)

        # A vector file without usable geometries behaves like "no mask".
        if not geometries:
            return dataset.read(1), dataset.bounds, dataset.crs, dataset.transform

        masked_bands, masked_transform = mask(
            dataset,
            geometries,
            invert=True,
            crop=False,
            nodata=np.nan,
            all_touched=True,
        )
        return masked_bands[0], dataset.bounds, dataset.crs, masked_transform

def resample_raster(src_path, reference_raster, reference_transform, reference_crs,
                    resampling=None):
    """
    Resample band 1 of a raster onto the grid of a reference array

    Parameters:
    src_path (str): Path of the raster to resample
    reference_raster (numpy.ndarray): Array whose shape defines the output grid
    reference_transform (affine.Affine): Affine transform of the reference grid
    reference_crs: Coordinate reference system of the reference grid
    resampling (rasterio.enums.Resampling or None): Resampling method. When
        None, nearest-neighbour is used for integer rasters (typically class
        IDs such as soil or landuse, which must not be interpolated into
        non-existent classes) and bilinear for all other rasters.

    Returns:
    numpy.ndarray: Resampled array with the shape of reference_raster and the
        dtype of the source band
    """
    with rasterio.open(src_path) as src:
        src_dtype = np.dtype(src.dtypes[0])

        if resampling is None:
            # Averaging class IDs (e.g. 20 and 42 -> 31) invents categories,
            # so integer sources default to nearest-neighbour.
            if np.issubdtype(src_dtype, np.integer):
                resampling = Resampling.nearest
            else:
                resampling = Resampling.bilinear

        # Cells not covered by the source keep this initial value of 0.
        resampled = np.zeros_like(reference_raster, dtype=src_dtype)

        # NOTE(review): the source nodata value is not passed to reproject(),
        # so with an interpolating method nodata cells can bleed into their
        # neighbours - confirm whether src_nodata/dst_nodata should be set.
        reproject(
            source=src.read(1),
            destination=resampled,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=reference_transform,
            dst_crs=reference_crs,
            resampling=resampling,
        )
        return resampled

def calculate_spatial_mare(observed, predicted):
    """
    Compute the per-pixel absolute relative error, in percent

    Parameters:
    observed (numpy.ndarray): Reference values
    predicted (numpy.ndarray): Modeled values, same shape as observed

    Returns:
    numpy.ndarray: Float array holding |observed - predicted| / observed * 100
        where both inputs are non-NaN and strictly positive, NaN elsewhere
    """
    # A pixel is usable only when both values are present and strictly positive.
    usable = (
        ~np.isnan(observed)
        & ~np.isnan(predicted)
        & (observed > 0)
        & (predicted > 0)
    )

    error_pct = np.empty_like(observed, dtype=float)
    error_pct.fill(np.nan)

    obs_values = observed[usable]
    pred_values = predicted[usable]
    error_pct[usable] = np.abs((obs_values - pred_values) / obs_values) * 100

    return error_pct

def spatial_average(array, window_size):
    """
    Smooth an array with a NaN-aware moving-window mean

    Parameters:
    array (numpy.ndarray): Input data; NaN marks missing cells
    window_size (int): Side length of the moving window

    Returns:
    numpy.ndarray: Mean of the non-NaN cells inside each window; cells that
        are NaN in the input stay NaN
    """
    is_valid = ~np.isnan(array)
    filter_options = {'size': window_size, 'mode': 'constant', 'cval': 0.0}

    # Both filters divide by the full window area, so their ratio is the mean
    # over the valid cells only (missing cells contribute 0 to the numerator).
    window_mean = uniform_filter(np.nan_to_num(array, nan=0.0), **filter_options)
    valid_fraction = uniform_filter(is_valid.astype(float), **filter_options)

    # Windows without any valid cell would divide by zero; use 1 instead.
    safe_fraction = np.where(valid_fraction == 0, 1, valid_fraction)

    return np.where(is_valid, window_mean / safe_fraction, np.nan)

def classify_performance(mare_array, thresholds):
    """
    Classify error values into performance classes

    Parameters:
    mare_array (numpy.ndarray): Error values; NaN marks missing cells
    thresholds (sequence of float): Class boundaries, in any order

    Returns:
    numpy.ndarray: Float array of class indices; 0 for values not exceeding
        the smallest threshold, k for values strictly greater than the k-th
        smallest threshold, NaN where the input is NaN
    """
    mare_array = np.asarray(mare_array)

    # NaN needs a floating-point container; keep the input precision when it
    # already is one.
    if np.issubdtype(mare_array.dtype, np.floating):
        out_dtype = mare_array.dtype
    else:
        out_dtype = float

    classified = np.full(mare_array.shape, np.nan, dtype=out_dtype)
    valid_mask = ~np.isnan(mare_array)

    classified[valid_mask] = 0  # Good
    # Each pass overwrites the previous one, so thresholds must be visited in
    # ascending order; sorting makes the result independent of input order.
    for class_id, threshold in enumerate(sorted(thresholds), start=1):
        classified[valid_mask & (mare_array > threshold)] = class_id

    return classified



def find_poor_regions(performance_classes, min_size=20):
    """
    Find connected regions in the poor and very poor performance classes

    Regions are connected components of equal class as found by
    scipy.ndimage.label with its default structuring element.

    Parameters:
    performance_classes (numpy.ndarray): 2-D classification array
    min_size (int): Minimum region size (pixels) to keep

    Returns:
    dict: Maps class 2 and class 3 to a list of region dictionaries with the
        keys 'id' (component label), 'size' (pixel count), 'centroid'
        (row, col truncated to int), 'bbox' (min_row, min_col, max_row,
        max_col) and 'mask' (boolean array of the full raster shape)
    """
    regions = {}

    # Only interested in poor (2) and very poor (3) classes
    for cls in (2, 3):
        labeled_array, num_features = label(performance_classes == cls)

        # One pass gives the size of every component (index 0 is background),
        # so full-raster masks are only built for the regions that are kept.
        component_sizes = np.bincount(labeled_array.ravel(),
                                      minlength=num_features + 1)

        class_regions = []
        for region_id in range(1, num_features + 1):
            region_size = int(component_sizes[region_id])
            if region_size < min_size:
                continue

            region_mask = labeled_array == region_id
            rows, cols = np.nonzero(region_mask)

            class_regions.append({
                'id': region_id,
                'size': region_size,
                'centroid': (int(np.mean(rows)), int(np.mean(cols))),
                'bbox': (int(rows.min()), int(cols.min()),
                         int(rows.max()), int(cols.max())),
                'mask': region_mask
            })

        regions[cls] = class_regions

    return regions

def analyze_regions_soil_landuse(regions, soil_data, landuse_data):
    """
    Summarise the dominant soil and landuse class of each poor region

    Parameters:
    regions (dict): Maps a performance class to a list of region dictionaries
        providing 'id', 'size' and a boolean 'mask'
    soil_data (numpy.ndarray): Soil class IDs on the same grid as the masks
    landuse_data (numpy.ndarray): Landuse class IDs on the same grid as the masks

    Returns:
    dict: Maps each performance class to a list of dictionaries with the keys
        'region_id', 'size', 'dominant_soil', 'dominant_soil_pct',
        'dominant_landuse' and 'dominant_landuse_pct'; the dominant value is
        None (share 0) when a region has no non-NaN cells
    """
    def dominant_class(values):
        # Most frequent non-NaN value and its share (%) among non-NaN cells.
        present = values[~np.isnan(values)]
        if len(present) == 0:
            return None, 0
        class_ids, frequencies = np.unique(present, return_counts=True)
        winner = np.argmax(frequencies)
        return int(class_ids[winner]), (frequencies[winner] / len(present)) * 100

    results = {}

    for cls, region_list in regions.items():
        summaries = []

        for region in region_list:
            inside = region['mask']
            soil_id, soil_share = dominant_class(soil_data[inside])
            landuse_id, landuse_share = dominant_class(landuse_data[inside])

            summaries.append({
                'region_id': region['id'],
                'size': region['size'],
                'dominant_soil': soil_id,
                'dominant_soil_pct': soil_share,
                'dominant_landuse': landuse_id,
                'dominant_landuse_pct': landuse_share
            })

        results[cls] = summaries

    return results

def save_raster(data, reference_path, output_path, transform):
    """
    Save a 2-D array as a single-band float32 georeferenced raster

    Parameters:
    data (numpy.ndarray): 2-D array to write
    reference_path (str): Raster whose profile (driver, CRS, ...) is reused
    output_path (str): Destination file
    transform (affine.Affine): Affine transform of the output grid
    """
    # Only the profile is needed from the reference; close it before writing.
    with rasterio.open(reference_path) as src:
        profile = src.profile.copy()

    height, width = data.shape
    profile.update({
        'transform': transform,
        'dtype': 'float32',
        'nodata': np.nan,
        # Describe what is actually written instead of inheriting the band
        # count and size of the reference (which may be multi-band).
        'count': 1,
        'height': height,
        'width': width,
    })

    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(data.astype('float32'), 1)




# ===========================
# MAIN ANALYSIS
# ===========================

def main():
    """
    Run the full ET performance evaluation using the user-defined inputs

    Steps: load the reference ET (excluding water bodies), resample the SWAT+
    ET onto the reference grid, compute and smooth the per-pixel relative
    error, classify it, locate connected poor / very poor regions and report
    their dominant soil and landuse.

    Outputs written to OUTPUT_DIR: spatial_mare.tif, smoothed_mare.tif,
    performance_classes.tif, poor_regions_map.png and, when at least one
    region was found, poor_regions_analysis.csv.
    """
    import os  # local import: only needed to create the output directory

    print("SWAT+ ET Spatial Performance Evaluation")
    print("=" * 50)

    # Make sure every later write has somewhere to go.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Load and preprocess data
    print("Loading reference ET data...")
    ref_et, _, ref_crs, ref_transform = clip_raster(OBSERVED_ET_PATH, WATERBODIES_MASK)
    # Non-positive ET is treated as missing.
    ref_et = np.where(ref_et <= 0, np.nan, ref_et)

    # The SWAT+ raster is read directly by the resampling step. Water bodies
    # need no separate masking here: the error is only computed where the
    # (already masked) reference is valid.
    print("Loading SWAT+ ET data...")

    # Resample SWAT+ to match reference
    print("Resampling SWAT+ ET to match reference resolution...")
    swat_et_resampled = resample_raster(SWAT_ET_PATH, ref_et, ref_transform, ref_crs)
    swat_et_resampled = np.where(swat_et_resampled <= 0, np.nan, swat_et_resampled)

    # Calculate MARE
    print("Calculating spatial MARE...")
    spatial_mare = calculate_spatial_mare(ref_et, swat_et_resampled)

    # Apply smoothing
    print(f"Applying spatial smoothing (window size: {SMOOTHING_WINDOW})...")
    smoothed_mare = spatial_average(spatial_mare, SMOOTHING_WINDOW)

    # Classify performance
    print("Classifying performance...")
    performance_classes = classify_performance(smoothed_mare, PERFORMANCE_THRESHOLDS)

    # Load soil and landuse data
    # NOTE(review): both rasters hold class IDs; confirm that resample_raster
    # does not interpolate between them.
    print("Loading soil and landuse data...")
    soil_data = resample_raster(SOIL_RASTER, ref_et, ref_transform, ref_crs)
    landuse_data = resample_raster(LANDUSE_RASTER, ref_et, ref_transform, ref_crs)

    # Find poor regions
    print(f"Identifying poor performance regions (min size: {MIN_REGION_SIZE} pixels)...")
    poor_regions = find_poor_regions(performance_classes, MIN_REGION_SIZE)

    # Analyze regions
    print("Analyzing soil-landuse composition of poor regions...")
    region_analysis = analyze_regions_soil_landuse(poor_regions, soil_data, landuse_data)

    # Save outputs
    print("\nSaving outputs...")
    save_raster(spatial_mare, OBSERVED_ET_PATH, f"{OUTPUT_DIR}/spatial_mare.tif", ref_transform)
    save_raster(smoothed_mare, OBSERVED_ET_PATH, f"{OUTPUT_DIR}/smoothed_mare.tif", ref_transform)
    save_raster(performance_classes, OBSERVED_ET_PATH, f"{OUTPUT_DIR}/performance_classes.tif", ref_transform)

    # Create summary report
    # One label per class; this assumes three thresholds (four classes).
    performance_labels = ['Good', 'Moderate', 'Poor', 'Very Poor']

    # Calculate area statistics
    unique_classes, counts = np.unique(performance_classes[~np.isnan(performance_classes)],
                                       return_counts=True)
    total_pixels = np.sum(counts)

    print("\n" + "=" * 50)
    print("PERFORMANCE SUMMARY")
    print("=" * 50)

    for cls, count in zip(unique_classes.astype(int), counts):
        percentage = (count / total_pixels) * 100
        print(f"{performance_labels[cls]}: {percentage:.1f}% ({count} pixels)")

    # Collect the per-region rows for the CSV report
    all_regions = []
    for cls, regions in region_analysis.items():
        for region in regions:
            region['performance_class'] = performance_labels[cls]
            region['dominant_soil_name'] = SOIL_NAMES.get(region['dominant_soil'], 'Unknown')
            region['dominant_landuse_name'] = LANDUSE_NAMES.get(region['dominant_landuse'], 'Unknown')
            all_regions.append(region)

    # Create spatial visualization of regions
    plt.figure(figsize=(15, 10))

    # Plot performance classification
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    cmap = plt.get_cmap('viridis', len(PERFORMANCE_THRESHOLDS) + 1)
    plt.imshow(performance_classes, cmap=cmap, vmin=-0.5, vmax=len(PERFORMANCE_THRESHOLDS) + 0.5)

    # Overlay region boundaries for poor and very poor regions
    for cls in [2, 3]:  # Poor and Very Poor
        if cls not in poor_regions:
            continue

        for region in poor_regions[cls]:
            # Get bounding box
            min_row, min_col, max_row, max_col = region['bbox']

            # Plot rectangle (shifted by half a pixel so it encloses whole cells)
            rect = plt.Rectangle((min_col-0.5, min_row-0.5), max_col-min_col+1, max_row-min_row+1,
                                 edgecolor='red' if cls == 3 else 'orange',
                                 facecolor='none', linewidth=1.5, alpha=0.7)
            plt.gca().add_patch(rect)

            # Add region ID text
            plt.text(region['centroid'][1], region['centroid'][0], str(region['id']),
                     color='white', fontsize=8, ha='center', va='center',
                     bbox=dict(facecolor='black', alpha=0.5, boxstyle='round,pad=0.1'))

    cbar = plt.colorbar(ticks=range(len(PERFORMANCE_THRESHOLDS) + 1))
    cbar.set_label('Performance Class')
    cbar.set_ticklabels(performance_labels)
    plt.title('Poor Performance Regions with Region IDs')
    plt.axis('on')
    plt.savefig(f"{OUTPUT_DIR}/poor_regions_map.png", dpi=300)
    plt.show()

    # Save detailed results to CSV
    if all_regions:
        df = pd.DataFrame(all_regions)
        df.to_csv(f"{OUTPUT_DIR}/poor_regions_analysis.csv", index=False)
        print(f"\nDetailed analysis saved to: {OUTPUT_DIR}/poor_regions_analysis.csv")

    print("\nAnalysis complete!")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()