In [None]:
# 6-Model Ensemble Generator for Jordan Climate Data
# Creates ensemble from all 6 available climate models uniformly across all basins

import xarray as xr
import geopandas as gpd
import numpy as np
from shapely.geometry import Point
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt


def _find_nearest_basin(point, basin_shapes):
    """Return ``(name, distance)`` for the basin geometry closest to *point*.

    Args:
        point: Shapely point to measure from.
        basin_shapes: Iterable of ``(basin_name, geometry)`` pairs.

    Returns:
        Tuple of the closest basin's name and its distance as reported by
        ``geometry.distance`` (degrees for EPSG:4326 geometries).  When
        ``basin_shapes`` is empty the result is ``(None, inf)``.
    """
    nearest_name = None
    nearest_distance = float('inf')
    for name, geometry in basin_shapes:
        distance = geometry.distance(point)
        if distance < nearest_distance:
            nearest_distance = distance
            nearest_name = name
    return nearest_name, nearest_distance


def create_basin_mask(basins_gdf, lat_points, lon_points, jordan_gdf):
    """Create a mask array with expanded coverage to ensure all stations get data.

    Every grid point is assigned a basin name using three strategies, tried in
    order:

    1. The point lies inside a (Jordan-clipped, non-Syria) basin polygon.
    2. The point lies inside Jordan or its 0.2 degree buffer: use the nearest
       basin if it is closer than 0.5 degrees.
    3. Otherwise, if the point is inside the grid extent expanded by two cells:
       use the nearest basin if it is closer than 1.0 degree.

    Args:
        basins_gdf: GeoDataFrame of basins with a ``BASIN_NAME`` column.
        lat_points: 1-D sequence of grid latitudes (ascending or descending).
        lon_points: 1-D sequence of grid longitudes (ascending or descending).
        jordan_gdf: GeoDataFrame whose unioned geometry is the Jordan boundary.

    Returns:
        Object array of shape ``(len(lat_points), len(lon_points))`` holding
        the basin name for assigned points and ``None`` elsewhere.
    """
    # Ensure consistent CRS
    basins_gdf = basins_gdf.to_crs("EPSG:4326").copy()
    jordan_gdf = jordan_gdf.to_crs("EPSG:4326").copy()

    mask = np.full((len(lat_points), len(lon_points)), None, dtype=object)
    total_points = mask.size
    if total_points == 0:
        # Nothing to classify; also avoids min()/max() on an empty axis and a
        # division by zero in the coverage summary below.
        return mask

    # Get Jordan boundary with a buffer for edge cases
    jordan_boundary = jordan_gdf.geometry.union_all()

    # Create an expanded boundary - add ~0.2 degrees (~22km) buffer around Jordan
    jordan_boundary_buffered = jordan_boundary.buffer(0.2)

    # Filter out Syria basins and clip the rest to Jordan
    basins_filtered = basins_gdf[~basins_gdf['BASIN_NAME'].str.contains('SYRIA', case=False, na=False)].copy()
    basins_filtered['geometry'] = basins_filtered.geometry.intersection(jordan_boundary)
    basins_filtered = basins_filtered[~basins_filtered.geometry.is_empty]

    # Materialise (name, geometry) pairs once instead of calling iterrows()
    # up to three times for every grid point.
    basin_shapes = list(zip(basins_filtered['BASIN_NAME'], basins_filtered.geometry))

    # Get grid bounds to identify boundary regions
    min_lat, max_lat = np.min(lat_points), np.max(lat_points)
    min_lon, max_lon = np.min(lon_points), np.max(lon_points)

    # Calculate grid spacing.  abs() matters: climate grids often store
    # latitude north-to-south, and a negative spacing would shrink the
    # "expanded" region below instead of growing it.
    lat_spacing = abs(lat_points[1] - lat_points[0]) if len(lat_points) > 1 else 0.1
    lon_spacing = abs(lon_points[1] - lon_points[0]) if len(lon_points) > 1 else 0.1

    # Expand processing region by 2 grid cells as requested
    processing_min_lat = min_lat - 2 * lat_spacing
    processing_max_lat = max_lat + 2 * lat_spacing
    processing_min_lon = min_lon - 2 * lon_spacing
    processing_max_lon = max_lon + 2 * lon_spacing

    assigned_count = 0
    boundary_assignments = 0
    nearest_assignments = 0

    print("Processing expanded grid region:")
    print(f"  Original grid: {min_lat:.3f} to {max_lat:.3f} lat, {min_lon:.3f} to {max_lon:.3f} lon")
    print(
        f"  Expanded region: {processing_min_lat:.3f} to {processing_max_lat:.3f} lat, {processing_min_lon:.3f} to {processing_max_lon:.3f} lon")
    print(f"  Grid spacing: {lat_spacing:.3f} lat, {lon_spacing:.3f} lon")

    # Assign each point to a basin with multiple fallback strategies
    for i, current_lat in enumerate(lat_points):
        for j, current_lon in enumerate(lon_points):
            point = Point(current_lon, current_lat)

            # Strategy 1: Direct basin containment
            basin_found = False
            for name, geometry in basin_shapes:
                if geometry.contains(point):
                    mask[i, j] = name
                    basin_found = True
                    assigned_count += 1
                    break

            if basin_found:
                continue

            # Strategy 2: Point within Jordan boundary (original or buffered)
            if jordan_boundary.contains(point) or jordan_boundary_buffered.contains(point):
                nearest_basin, min_distance = _find_nearest_basin(point, basin_shapes)

                if nearest_basin and min_distance < 0.5:  # Within ~55km
                    mask[i, j] = nearest_basin
                    assigned_count += 1
                    boundary_assignments += 1
                    if min_distance > 0.1:  # Log if it's a significant distance assignment
                        print(
                            f"    Boundary assignment: Grid ({current_lat:.3f}, {current_lon:.3f}) → {nearest_basin} ({min_distance * 111:.1f} km)")
                # Points inside (buffered) Jordan never fall through to Strategy 3.
                continue

            # Strategy 3: Extended region coverage (for stations near borders)
            # With a non-negative spacing every grid point lies inside the
            # expanded box, so the distance limit is the effective filter here.
            if (processing_min_lat <= current_lat <= processing_max_lat and
                    processing_min_lon <= current_lon <= processing_max_lon):

                # Find nearest basin regardless of boundaries for border regions
                nearest_basin, min_distance = _find_nearest_basin(point, basin_shapes)

                if nearest_basin and min_distance < 1.0:  # Within ~111km for border regions
                    mask[i, j] = nearest_basin
                    assigned_count += 1
                    nearest_assignments += 1
                    print(
                        f"    Extended assignment: Grid ({current_lat:.3f}, {current_lon:.3f}) → {nearest_basin} ({min_distance * 111:.1f} km)")

    print("Basin mask creation summary:")
    print(f"  Total assignments: {assigned_count}")
    print(f"  Direct basin containment: {assigned_count - boundary_assignments - nearest_assignments}")
    print(f"  Boundary assignments: {boundary_assignments}")
    print(f"  Extended region assignments: {nearest_assignments}")
    print(
        f"  Coverage: {assigned_count}/{total_points} grid points ({assigned_count / total_points * 100:.1f}%)")

    return mask


def create_6_model_ensemble_for_period(nc_dir, basin_mask, model_names,
                                       start_date, end_date, output_path):
    """Create ensemble dataset using all 6 models for a specific time period.

    For every grid point whose ``basin_mask`` entry is not ``None``, the
    ``prAdjust`` time series of each model is read and averaged.  Points
    outside the mask keep a value of 0 and a ``model_count`` of 0.

    Args:
        nc_dir: Directory containing one ``<model>.nc`` file per model.
        basin_mask: 2-D object array (lat, lon); ``None`` marks points to skip.
        model_names: Model names; the first one supplies the output
            coordinates and the ``prAdjust`` attributes.
        start_date: First date of the period (label passed to ``slice``).
        end_date: Last date of the period (label passed to ``slice``).
        output_path: Destination NetCDF file.

    Returns:
        The in-memory ``xarray.Dataset`` that was written to ``output_path``,
        with variables ``prAdjust`` (time, lat, lon) and ``model_count``
        (lat, lon).

    Raises:
        ValueError: If ``model_names`` is empty or ``basin_mask`` does not
            match the (lat, lon) grid of the first model.
    """
    if not model_names:
        raise ValueError("model_names must contain at least one model")

    nc_dir = Path(nc_dir)

    # Select time period
    time_slice = slice(start_date, end_date)

    # Every handle opened below is registered here and closed in ``finally``,
    # so nothing leaks if reading, averaging or saving fails.
    open_datasets = []
    try:
        # Open one NC file to get coordinate reference
        sample_ds = xr.open_dataset(nc_dir / f"{model_names[0]}.nc")
        open_datasets.append(sample_ds)

        # Initialize output dataset
        ds_out = xr.Dataset(
            coords={
                'time': sample_ds.time.sel(time=time_slice),
                'lat': sample_ds.lat,
                'lon': sample_ds.lon
            }
        )
        n_time, n_lat, n_lon = len(ds_out.time), len(ds_out.lat), len(ds_out.lon)

        if np.shape(basin_mask) != (n_lat, n_lon):
            raise ValueError(
                f"basin_mask shape {np.shape(basin_mask)} does not match grid shape {(n_lat, n_lon)}")

        # Create empty array for ensemble data
        ensemble_data = np.zeros((n_time, n_lat, n_lon))
        model_count = np.zeros((n_lat, n_lon), dtype=int)

        # Open each model file once (instead of once per grid point) and keep
        # a lazily-evaluated view of the requested period.
        model_series = []
        for model in model_names:
            try:
                ds = xr.open_dataset(nc_dir / f"{model}.nc")
                open_datasets.append(ds)
                model_series.append((model, ds.prAdjust.sel(time=time_slice)))
            except Exception as e:
                print(f"Warning: Could not load {model}: {e}")

        # Process each grid point
        print(f"\nProcessing period {start_date} to {end_date}")
        total_points = n_lat * n_lon
        processed_points = 0

        for i in range(n_lat):
            for j in range(n_lon):
                if basin_mask[i, j] is not None:  # Point is within a Jordan basin
                    point_data = []

                    # Read this grid point from every model that opened
                    for model, series in model_series:
                        try:
                            point_data.append(series.isel(lat=i, lon=j).values)
                        except Exception as e:
                            print(f"Warning: Could not load {model} for point ({i},{j}): {e}")

                    if point_data:
                        # Calculate ensemble mean of all available models
                        ensemble_data[:, i, j] = np.mean(point_data, axis=0)
                        model_count[i, j] = len(point_data)

                processed_points += 1
                if processed_points % 1000 == 0:
                    print(
                        f"Processed {processed_points}/{total_points} points ({processed_points / total_points * 100:.1f}%)")

        # Add variables to output dataset
        ds_out['prAdjust'] = (('time', 'lat', 'lon'), ensemble_data)
        ds_out['model_count'] = (('lat', 'lon'), model_count)

        # Add metadata.  Copy the attrs so the additions below do not also
        # modify the sample dataset's attribute dictionary.
        ds_out.prAdjust.attrs = dict(sample_ds.prAdjust.attrs)
        ds_out.prAdjust.attrs['description'] = 'Ensemble average of precipitation from all 6 climate models'
        ds_out.prAdjust.attrs['models_used'] = ', '.join(model_names)
        ds_out.model_count.attrs['description'] = 'Number of models used in ensemble average at each point'
        ds_out.attrs['description'] = f'Ensemble average of 6 climate models ({start_date} to {end_date})'
        ds_out.attrs['models'] = ', '.join(model_names)
        ds_out.attrs['period'] = f'{start_date} to {end_date}'
        ds_out.attrs['creation_date'] = pd.Timestamp.now().strftime('%Y-%m-%d')
        ds_out.attrs['ensemble_type'] = '6-model uniform ensemble'

        # Save to file
        print(f"Saving to {output_path}")
        ds_out.to_netcdf(output_path)
        return ds_out
    finally:
        for opened in open_datasets:
            opened.close()


def create_ensemble_summary_report(output_dir, model_names):
    """Create a summary report of the ensemble creation process.

    Writes ``ensemble_summary_report.txt`` into ``output_dir``.  The report
    lists the given model names followed by fixed text describing the ensemble
    method, time periods, output files and technical details.

    Args:
        output_dir: Existing directory that receives the report file.
        model_names: Model names, listed in the report in the given order.

    Returns:
        None.
    """
    generated_on = pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
    model_lines = "\n".join(
        f"{rank}. {model}" for rank, model in enumerate(model_names, start=1)
    )

    report_content = f"""
# 6-Model Ensemble Summary Report
Generated on: {generated_on}

## Models Used:
{model_lines}

## Ensemble Method:
- Uniform application of all 6 models across all grid points
- Simple arithmetic mean of all available models at each point
- Geographic coverage: Jordan basins only (Syria basins excluded)

## Time Periods Processed:
1. 1961_1994: Historical baseline period
2. 1995_2014: Recent historical period  
3. 2015_2020: Recent period
4. 2021_2040: Near-future projections
5. 2041_2060: Mid-future projections
6. 2061_2070: Far-future projections

## Output Files:
- ensemble_precipitation_6models_1961_1994.nc
- ensemble_precipitation_6models_1995_2014.nc
- ensemble_precipitation_6models_2015_2020.nc
- ensemble_precipitation_6models_2021_2040.nc
- ensemble_precipitation_6models_2041_2060.nc
- ensemble_precipitation_6models_2061_2070.nc

## Technical Details:
- Variable: prAdjust (adjusted precipitation)
- Grid: 85x75 points
- CRS: EPSG:4326 (WGS84)
- Units: mm/day
- Missing values handled: Yes
"""

    # Explicit UTF-8 keeps the output independent of the platform's default
    # encoding, so non-ASCII model names cannot fail or be mis-encoded.
    report_path = Path(output_dir) / 'ensemble_summary_report.txt'
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_content)


def main():
    """Main execution function for creating 6-model ensemble datasets.

    Steps: verify that every model file exists, read the basin and governorate
    shapefiles, build the basin mask on the grid of the first model, create
    one ensemble NetCDF per time period, and write the summary report.

    A failure in one period is reported and the remaining periods are still
    processed.  Any other exception is printed and re-raised.
    """

    # Define paths
    nc_dir = r"D:\RICAAR\riccar-data_jordan-ssp2-4-5-daily-data_2024-07-29_0915\Merge\Precipitation"
    basin_shapefile = r"D:\RICAAR\surfacebasin\surface_basin.shp"
    gov_shapefile = r"D:\RICAAR\Governorates\JordanwithGovernorates.shp"
    output_dir = r"D:\RICAAR\Pr.New.Stations.Selection\ensemble.models.6.models"

    # Define all 6 model names
    model_names = [
        'CMCC-CM2-SR5',
        'CNRM-ESM2-1',
        'EC-Earth3-Veg',
        'IPSL-CM6A-LR',
        'MPI-ESM1-2-LR',
        'NorESM2-MM'
    ]

    # Create output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    try:
        print("=== 6-Model Ensemble Generator ===")
        print(f"Using models: {', '.join(model_names)}")
        print(f"Input directory: {nc_dir}")
        print(f"Output directory: {output_dir}")

        # Verify all model files exist
        print("\nVerifying model files...")
        missing_files = [
            f"{model}.nc"
            for model in model_names
            if not (Path(nc_dir) / f"{model}.nc").exists()
        ]

        if missing_files:
            print(f"ERROR: Missing model files: {missing_files}")
            return
        print("All model files found!")

        # Read shapefiles
        print("\nReading shapefiles...")
        basins_gdf = gpd.read_file(basin_shapefile)
        jordan_gdf = gpd.read_file(gov_shapefile)

        # Verify coordinate systems
        print("\nCoordinate Reference Systems:")
        print(f"Basins: {basins_gdf.crs}")
        print(f"Jordan: {jordan_gdf.crs}")

        # Create basin mask using sample NetCDF for coordinates.  The context
        # manager closes the file even if mask creation raises.
        print("\nCreating basin mask...")
        with xr.open_dataset(Path(nc_dir) / f"{model_names[0]}.nc") as sample_ds:
            basin_mask = create_basin_mask(basins_gdf, sample_ds.lat.values, sample_ds.lon.values, jordan_gdf)

        # Count valid grid points
        valid_points = sum(cell is not None for cell in basin_mask.flat)
        total_points = basin_mask.size
        print(
            f"Valid grid points (within Jordan basins): {valid_points}/{total_points} ({valid_points / total_points * 100:.1f}%)")

        # Define time periods
        periods = [
            ('1961_1994', '1961-01-01', '1994-12-31'),
            ('1995_2014', '1995-01-01', '2014-12-31'),
            ('2015_2020', '2015-01-01', '2020-12-31'),
            ('2021_2040', '2021-01-01', '2040-12-31'),
            ('2041_2060', '2041-01-01', '2060-12-31'),
            ('2061_2070', '2061-01-01', '2070-12-31')
        ]

        # Create ensemble datasets for each period
        successful_periods = []
        for period_name, start_date, end_date in periods:
            output_path = Path(output_dir) / f"ensemble_precipitation_6models_{period_name}.nc"
            print(f"\n{'=' * 50}")
            print(f"Processing period {period_name}...")

            try:
                # The dataset is closed when the block exits, including when
                # the statistics below raise.
                with create_6_model_ensemble_for_period(
                    nc_dir, basin_mask, model_names,
                    start_date, end_date, output_path
                ) as ds:
                    has_data = ds.model_count > 0
                    # xarray rejects 2-D boolean indexing such as
                    # model_count[model_count > 0]; mask with where() and let
                    # mean() skip the resulting NaNs instead.
                    mean_models = ds.model_count.where(has_data).mean().values.item()

                    # Print basic statistics for verification
                    print(f"Dataset summary for {period_name}:")
                    print(f"  Time steps: {len(ds.time)}")
                    print(f"  Grid points with data: {has_data.sum().values.item()}")
                    print(f"  Average models per point: {mean_models:.1f}")
                    print(f"  Non-zero precipitation points: {(ds.prAdjust != 0).sum().values.item()}")
                    print(f"Successfully saved to {output_path}")

                successful_periods.append(period_name)

            except Exception as e:
                print(f"ERROR processing period {period_name}: {str(e)}")
                continue

        # Create summary report
        print(f"\n{'=' * 50}")
        print("Creating summary report...")
        create_ensemble_summary_report(output_dir, model_names)

        print(f"\n{'=' * 50}")
        print("PROCESSING COMPLETED!")
        print(f"Successfully processed {len(successful_periods)}/{len(periods)} periods")
        print(f"Output files and summary report can be found in: {output_dir}")
        print(f"Models used: {', '.join(model_names)}")

    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise


# Run the ensemble pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()