# In [None]:
#!/usr/bin/env python3
"""
Extract precipitation data for specific grid points from CMIP6 NEX-GDDP data.

This script processes already-downloaded NetCDF files and extracts data for
22 specific grid points in the Pyrenees region, aggregating daily data to
monthly sums.

Grid points organized in 4 elevation groups:
- high: 3 points at lat 42.875
- uppermid: 3 points at lat 42.875
- lowermid: 8 points at lat 43.125
- low: 8 points at lat 43.375
"""

import xarray as xr
import numpy as np
import pandas as pd
from pathlib import Path


# Target grid points as (lat, lon) pairs, keyed by elevation group.
# All points of a group sit on one latitude row, so each group is spelled
# as that latitude combined with the longitudes sampled along the row.
# Negative longitudes are shifted by +360 at lookup time when a file
# stores its longitudes in the 0-360 convention.
GRID_POINTS = {
    'high': [(42.875, lon) for lon in (-0.375, -0.125, 0.125)],
    'uppermid': [(42.875, lon) for lon in (-0.625, 0.375, 0.625)],
    'lowermid': [
        (43.125, lon)
        for lon in (-1.125, -0.875, -0.625, -0.375,
                    -0.125, 0.125, 0.375, 0.625)
    ],
    'low': [
        (43.375, lon)
        for lon in (-1.125, -0.875, -0.625, -0.375,
                    -0.125, 0.125, 0.375, 0.625)
    ],
}


def find_nearest_grid_index(array, value):
    """Locate the entry of ``array`` that lies closest to ``value``.

    Returns an ``(index, entry)`` pair. When several entries are equally
    close, the first one wins (``argmin`` semantics).
    """
    coords = np.asarray(array)
    distances = np.abs(coords - value)
    nearest = np.argmin(distances)
    return nearest, coords[nearest]


def _parse_year_and_version(path):
    """Return ``(year, version)`` parsed from a data file name, or ``None``.

    ``version`` is one of ``'subset'``, ``'v2.0'``, ``'v1.2'``, ``'v1.1'``
    or ``'v1.0'`` (the fallback when the name carries no version marker).
    The year is the last underscore-separated token made of exactly four
    digits, so it is found regardless of what follows it (``_v1.1``,
    ``_v1_2``, ``_subset`` ...). ``None`` means no such token exists.
    """
    stem = path.stem
    tokens = stem.split('_')

    if 'subset' in stem:
        version = 'subset'
    elif 'v2' in stem:
        version = 'v2.0'
    elif '_v1.2' in stem or '_v1_2' in stem:
        version = 'v1.2'
    elif '_v1.1' in stem or '_v1_1' in stem:
        version = 'v1.1'
    else:
        version = 'v1.0'

    # Scan from the end: the year sits near the tail of the name, after
    # the model/scenario/variant tokens.
    year_token = next(
        (tok for tok in reversed(tokens) if len(tok) == 4 and tok.isdecimal()),
        None,
    )
    if year_token is None:
        return None
    return int(year_token), version


def extract_precipitation_monthly(
    data_dir,
    model='CNRM-CM6-1',
    scenario='historical',
    start_year=1950,
    end_year=2014
):
    """
    Extract monthly precipitation sums for the points in ``GRID_POINTS``.

    For every year in the requested range the best available file is
    opened (priority: subset > v2.0 > v1.2 > v1.1 > v1.0). For each target
    point the nearest grid cell is selected, its daily series is summed per
    calendar month, and the monthly series of all points in an elevation
    group are averaged. A month in which a cell has no valid daily value
    yields NaN rather than 0.

    Parameters
    ----------
    data_dir : str or Path
        Directory containing the NetCDF files.
    model : str
        Model name used in the file names (default: 'CNRM-CM6-1').
    scenario : str
        Scenario used in the file names (default: 'historical').
    start_year, end_year : int
        Inclusive year range to process.

    Returns
    -------
    tuple : (output_dfs, matched_cells)
        - output_dfs: dict mapping group name to a DataFrame with a
          ``pr_monthly_sum`` column; metadata is stored in ``df.attrs``.
        - matched_cells: dict mapping group name to a list of
          ``((lat_target, lon_target), (lat_actual, lon_actual))`` pairs
          taken from the first processed file.

    Raises
    ------
    FileNotFoundError
        If no file matches the model/scenario, or none of the matching
        files falls inside ``start_year``..``end_year``.
    """

    data_dir = Path(data_dir)
    variable = 'pr'

    # Find matching files (both full files and regional subsets). The second
    # pattern also matches subset files, hence the set-based de-duplication.
    file_patterns = [
        f"{variable}_day_{model}_{scenario}_*_gr_*subset.nc",  # Regional subsets
        f"{variable}_day_{model}_{scenario}_*_gr_*.nc"         # Full files
    ]
    all_files = sorted(
        {path for pattern in file_patterns for path in data_dir.glob(pattern)}
    )

    if not all_files:
        raise FileNotFoundError(f"No files found in {data_dir}")

    print(f"Found {len(all_files)} precipitation files for {model} {scenario}")

    # Group files by year and keep only the highest-priority version
    files_by_year = {}
    version_priority = {'subset': 4, 'v2.0': 3, 'v1.2': 2, 'v1.1': 1, 'v1.0': 0}

    for f in all_files:
        parsed = _parse_year_and_version(f)
        if parsed is None:
            print(f"  Warning: Could not parse year from {f.name}, skipping")
            continue
        year, version = parsed

        current = files_by_year.get(year)
        if current is None:
            files_by_year[year] = (f, version)
        elif version_priority[version] > version_priority[current[1]]:
            print(f"  Upgrading {year}: {current[1]} -> {version}")
            files_by_year[year] = (f, version)

    print("\nVersion summary:")
    version_counts = {}
    for _, version in files_by_year.values():
        version_counts[version] = version_counts.get(version, 0) + 1
    for version, count in sorted(version_counts.items()):
        print(f"  {version}: {count} files")

    # Filter by year range. The list stays in year order (not path order) so
    # the per-year pieces are concatenated chronologically further down.
    files = [
        files_by_year[year][0]
        for year in sorted(files_by_year)
        if start_year <= year <= end_year
    ]
    if not files:
        # Fail early with a clear message instead of an obscure error from
        # concatenating an empty list.
        raise FileNotFoundError(
            f"No {model} {scenario} files in {data_dir} "
            f"for years {start_year}-{end_year}"
        )
    print(f"\nProcessing {len(files)} files from {start_year} to {end_year}")

    # Per-group accumulators: one monthly series per processed file
    results = {group: [] for group in GRID_POINTS}
    matched_cells = {group: [] for group in GRID_POINTS}

    for file_idx, file_path in enumerate(files):
        print(f"Processing: {file_path.name}")

        # The context manager closes the file even if extraction fails.
        with xr.open_dataset(file_path, engine='netcdf4') as ds:
            lats = ds.lat.values
            lons = ds.lon.values

            # Longitudes stored as 0-360 rather than -180..180?
            use_360 = bool(lons.min() >= 0 and lons.max() > 180)

            for group_name, points in GRID_POINTS.items():
                group_monthly_data = []

                for lat_target, lon_target in points:
                    # Express the target in the file's longitude convention
                    if use_360 and lon_target < 0:
                        lon_search = lon_target + 360
                    else:
                        lon_search = lon_target

                    lat_idx, lat_actual = find_nearest_grid_index(lats, lat_target)
                    lon_idx, lon_actual = find_nearest_grid_index(lons, lon_search)

                    # Record matches once (first file) to avoid duplicates
                    if file_idx == 0:
                        # Convert back to -180..180 for display if needed
                        lon_display = lon_actual if lon_actual <= 180 else lon_actual - 360
                        matched_cells[group_name].append(
                            ((lat_target, lon_target), (lat_actual, lon_display))
                        )

                    # Warn when the nearest cell is more than 0.25 degrees away
                    if abs(lat_actual - lat_target) > 0.25:
                        print(f"  Warning: Lat {lat_target} matched to {lat_actual:.3f}")
                    if abs(lon_actual - lon_search) > 0.25:
                        print(f"  Warning: Lon {lon_target} (searching {lon_search:.3f}) matched to {lon_actual:.3f}")

                    point_data = ds[variable].isel(lat=lat_idx, lon=lon_idx)

                    # Monthly sum of the daily values. min_count=1 keeps a
                    # month with no valid data as NaN instead of a
                    # misleading 0.0.
                    monthly = point_data.resample(time='1MS').sum(min_count=1)
                    group_monthly_data.append(monthly)

                # Average across all points in this elevation group
                group_mean = sum(group_monthly_data) / len(group_monthly_data)
                results[group_name].append(group_mean)

    # Combine all years for each group and convert to DataFrame
    output_dfs = {}

    for group_name in GRID_POINTS:
        combined = xr.concat(results[group_name], dim='time')
        df = combined.to_dataframe(name='pr_monthly_sum').reset_index()

        # Metadata consumed by the CSV writers
        df.attrs['group'] = group_name
        df.attrs['num_points'] = len(GRID_POINTS[group_name])
        df.attrs['model'] = model
        df.attrs['scenario'] = scenario
        df.attrs['variable'] = 'pr'
        df.attrs['units'] = 'kg m-2 s-1 (summed over days in month)'

        output_dfs[group_name] = df

        print(f"\n{group_name}: {len(df)} months extracted")

    return output_dfs, matched_cells


def save_results_by_group(results_dict, output_dir='.', matched_cells=None):
    """Save each elevation group to its own CSV file with a '#' header.

    One file named ``pr_monthly_<group>.csv`` is written per entry of
    ``results_dict``. Each file starts with comment lines describing the
    data, followed by the DataFrame as CSV (no index column).

    Parameters
    ----------
    results_dict : dict
        Maps group name to a DataFrame whose ``attrs`` hold the keys
        'model', 'scenario', 'variable', 'units' and 'num_points'.
    output_dir : str or Path
        Destination directory; created (including parents) if missing.
    matched_cells : dict, optional
        Maps group name to ``(target, actual)`` coordinate pairs. When a
        group is absent (or this is None) the header lists the group's
        points from ``GRID_POINTS`` instead.

    Raises
    ------
    KeyError
        If a DataFrame lacks one of the required ``attrs`` keys.
    """

    output_dir = Path(output_dir)
    # parents=True so a nested output directory can be created in one go
    output_dir.mkdir(parents=True, exist_ok=True)

    for group_name, df in results_dict.items():
        output_file = output_dir / f"pr_monthly_{group_name}.csv"

        # Build the full header before touching the file so a missing attrs
        # key cannot leave a truncated, half-written CSV behind.
        header = [
            f"# NEX-GDDP-CMIP6 Monthly Precipitation - {group_name.upper()} elevation group\n",
            f"# Model: {df.attrs['model']}\n",
            f"# Scenario: {df.attrs['scenario']}\n",
            f"# Variable: {df.attrs['variable']}\n",
            f"# Units: {df.attrs['units']}\n",
            "# NOTE: To convert to mm/month, multiply by 86400\n",
            "#       Example: 0.0005 kg m-2 s-1 * 86400 = 43.2 mm/month\n",
            f"# Number of grid points averaged: {df.attrs['num_points']}\n",
            "#\n",
            "# Target grid points and matched cells:\n",
        ]

        if matched_cells and group_name in matched_cells:
            header.extend(
                f"#   Target: ({target[0]}, {target[1]}) -> Matched: ({actual[0]:.3f}, {actual[1]:.3f})\n"
                for target, actual in matched_cells[group_name]
            )
        else:
            header.extend(
                f"#   ({lat}, {lon})\n" for lat, lon in GRID_POINTS[group_name]
            )
        header.append("#\n")

        # Explicit encoding so the header matches what to_csv appends.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(header)

        df.to_csv(output_file, mode='a', index=False)
        print(f"Saved: {output_file}")


def save_combined_results(results_dict, output_file, add_mm_month=True):
    """Save all elevation groups to a single CSV with a group identifier.

    The per-group DataFrames are stacked, tagged with an
    ``elevation_group`` column and written after a '#'-prefixed header.

    Parameters
    ----------
    results_dict : dict
        Maps group name to a DataFrame with 'time' and 'pr_monthly_sum'
        columns and 'model' / 'scenario' entries in ``attrs``.
    output_file : str or Path
        Path of the CSV file to write (overwritten if it exists).
    add_mm_month : bool
        When True, 'time' is converted with ``pd.to_datetime`` and a
        ``pr_mm_month`` column (``pr_monthly_sum * 86400``) is added.

    Raises
    ------
    ValueError
        If ``results_dict`` is empty.
    """

    if not results_dict:
        # Fail before touching the output file.
        raise ValueError("results_dict is empty; nothing to save")

    all_data = []

    for group_name, df in results_dict.items():
        df_copy = df.copy()
        df_copy['elevation_group'] = group_name

        if add_mm_month:
            df_copy['time'] = pd.to_datetime(df_copy['time'])
            # Each value is already a monthly sum of daily rates (per
            # second), so multiplying by seconds per day gives the total.
            df_copy['pr_mm_month'] = df_copy['pr_monthly_sum'] * 86400

        all_data.append(df_copy)

    combined_df = pd.concat(all_data, ignore_index=True)

    cols = ['time', 'elevation_group', 'pr_monthly_sum']
    if add_mm_month:
        cols.append('pr_mm_month')
    combined_df = combined_df[cols]

    # Read model/scenario from the first group rather than a hard-coded
    # 'high' key, so any set of groups can be saved.
    first_df = next(iter(results_dict.values()))

    header = [
        "# NEX-GDDP-CMIP6 Monthly Precipitation - All Elevation Groups\n",
        f"# Model: {first_df.attrs['model']}\n",
        f"# Scenario: {first_df.attrs['scenario']}\n",
        "# Variable: pr (precipitation)\n",
        "#\n",
        "# Units:\n",
        "#   pr_monthly_sum: kg m-2 s-1 (summed over days in month)\n",
    ]
    if add_mm_month:
        header.append("#   pr_mm_month: millimeters per month (converted)\n")
        header.append("#   Conversion: pr_mm_month = pr_monthly_sum * 86400\n")
    header.append("#\n")
    # List the groups actually present in the file
    header.append(f"# Elevation groups: {', '.join(str(g) for g in results_dict)}\n")
    header.append("#\n")
    header.append("# Grid points by group:\n")
    for group_name, points in GRID_POINTS.items():
        header.append(f"# {group_name}: {len(points)} points\n")
        header.extend(f"#   ({lat}, {lon})\n" for lat, lon in points)
    header.append("#\n")

    # Explicit encoding so the header matches what to_csv appends.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(header)

    combined_df.to_csv(output_file, mode='a', index=False)
    print(f"\nSaved combined file: {output_file}")


# ===== MAIN EXECUTION =====
if __name__ == "__main__":

    # ----- Run configuration -----
    DATA_DIR = "C:/Users/jet58062/Desktop/french_pr_hist"  # folder with the NetCDF files
    MODEL = "CNRM-CM6-1"
    SCENARIO = "historical"  # or 'ssp126', 'ssp245', 'ssp370', 'ssp585'
    START_YEAR = 1950
    END_YEAR = 2014
    OUTPUT_DIR = "."  # CSV files land in the current directory

    rule = "=" * 70
    combined_csv = f"{OUTPUT_DIR}/pr_monthly_all_groups.csv"

    # ----- Banner -----
    for line in (
        "\n" + rule,
        "CMIP6 Precipitation Extraction - Specific Grid Points",
        rule,
        f"Model: {MODEL}",
        f"Scenario: {SCENARIO}",
        f"Period: {START_YEAR}-{END_YEAR}",
        f"Data directory: {DATA_DIR}",
        "\nElevation groups:",
    ):
        print(line)
    for group, points in GRID_POINTS.items():
        print(f"  {group}: {len(points)} points")
    print(rule)

    # Top-level boundary: any failure is reported with its traceback rather
    # than propagated.
    try:
        results, matched_cells = extract_precipitation_monthly(
            data_dir=DATA_DIR,
            model=MODEL,
            scenario=SCENARIO,
            start_year=START_YEAR,
            end_year=END_YEAR
        )

        print("\n" + rule)
        print("EXTRACTION COMPLETE")
        print(rule)

        # Per-group summary; each statistic is shown in raw units and
        # scaled by 86400 (seconds per day).
        print("\nSummary Statistics:")
        print("-" * 70)
        for group_name, df in results.items():
            series = df['pr_monthly_sum']
            print(f"\n{group_name.upper()}:")
            print(f"  Months: {len(df)}")
            print(f"  Date range: {df['time'].min()} to {df['time'].max()}")
            stats = (
                ("Mean", series.mean()),
                ("Min", series.min()),
                ("Max", series.max()),
            )
            for label, stat in stats:
                print(f"  {label}: {stat:.6e} kg m-2 s-1 ({stat * 86400:.1f} mm/month)")

        # Write the per-group files and the combined file
        print("\n" + rule)
        print("SAVING RESULTS")
        print(rule)

        save_results_by_group(results, OUTPUT_DIR, matched_cells)
        save_combined_results(results, combined_csv)

        print("\n✓ All files saved successfully!")

        # Read the combined file back as a quick sanity check
        print("\nFirst 10 rows of combined data:")
        combined_data = pd.read_csv(combined_csv, comment='#')
        print(combined_data.head(10))

    except Exception as e:
        print(f"\n✗ Error: {e}")
        import traceback
        traceback.print_exc()