# Precipitation Data Analysis and Comparison

This notebook processes and compares precipitation data from multiple sources:
- High-resolution precipitation data
- ERA5 reanalysis data
- GSMaP satellite data
- IMERG satellite data
- Observed precipitation data

We'll process the high-resolution data to match the DHM-style daily precipitation format, then compare all datasets visually and statistically.

In [None]:
# Import required libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob
import ipywidgets as widgets
from ipywidgets import interact
from sklearn.metrics import mean_squared_error, r2_score

# Configure visualization settings for publication quality
sns.set_theme(style="whitegrid")
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 12
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 16
plt.rcParams['figure.dpi'] = 300

## Find Available Floods

First, let's identify all the available flood directories in the results folder.

In [None]:
# Find all flood directories
flood_dirs = sorted(glob("../results/floods/flood_*"))
flood_names = [os.path.basename(d) for d in flood_dirs]

print(f"Found {len(flood_names)} flood events: {flood_names}")

## Process High-Resolution Data to Daily Format

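This function converts high-resolution precipitation data to the DHM-style daily format by:
1. Shifting timestamps back by 8 hours and 45 minutes
2. Labeling each shifted timestamp with the following calendar day, so that each daily total covers 08:45 on the previous day to 08:45 on the labeled day
3. Summing to daily values
4. Saving the processed data to a new CSV file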
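
As a minimal sketch of the steps above, the example below applies the same shift and day labeling to a synthetic hourly series. The column name `Station_1001` and the rainfall values are made up for illustration; only the 8 h 45 min offset and the one-day relabeling are taken from the function in this section.

```python
import pandas as pd

# Synthetic hourly series: 1 mm every hour for two days (illustrative values only)
times = pd.date_range("2021-06-01 00:00", periods=48, freq=pd.Timedelta(hours=1))
hourly = pd.DataFrame({"datetime": times, "Station_1001": 1.0})

# Shift back by 8 h 45 min so each 08:45-to-08:45 window falls inside one calendar day
shifted = hourly["datetime"] - pd.Timedelta(hours=8, minutes=45)

# Label each window with the calendar day on which it ends
hourly["date"] = shifted.dt.normalize() + pd.Timedelta(days=1)

# Sum the hourly values within each labeled day
daily = hourly.groupby("date")["Station_1001"].sum()
print(daily)
```

Only the middle day receives a full 24 hours here, because the synthetic series starts and ends at midnight rather than at 08:45. The first and last daily totals of a real event file may be partial in the same way.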

In [None]:
def process_highres_data(flood_name):
    """Process high-resolution precipitation data to daily format with DHM-style shifting"""
    # File paths
    highres_csv = f"../results/floods/{flood_name}/highres/high_resolution_precipitation.csv"
    output_csv = f"../results/floods/{flood_name}/highres/high_resolution_precipitation_daily.csv"
    
    # Check if highres data exists
    if not os.path.exists(highres_csv):
        print(f"⚠️ High-resolution data not found for {flood_name}")
        return None
    
    print(f"Processing high-resolution data for {flood_name}...")
    
    # Load high-resolution precipitation data
    highres = pd.read_csv(highres_csv)
    
    # Convert datetime column to datetime format
    highres["datetime"] = pd.to_datetime(highres["datetime"])
    
    # Apply DHM-style daily shifting (shift timestamp by -8 hours and 45 minutes)
    highres["shifted_time"] = highres["datetime"] - pd.Timedelta(hours=8, minutes=45)
    
    # Create date column based on shifted timestamp
    highres["date"] = (highres["shifted_time"].dt.normalize() + pd.Timedelta(days=1))
    
    # Group by date and sum numeric columns
    numeric_cols = highres.select_dtypes(include='number').columns
    highres_daily = highres.groupby("date")[numeric_cols].sum()
    
    # Sort by date and reset index
    highres_daily = highres_daily.sort_index()
    highres_daily = highres_daily.reset_index()
    
    # Save the processed data
    highres_daily.to_csv(output_csv, index=False)
    
    print(f"✅ High-resolution data converted to DHM-style daily precipitation and saved to: {output_csv}")
    return output_csv

## Compare Precipitation Data Sources

This function loads and compares precipitation data from different sources:
- High-resolution precipitation data
- ERA5 reanalysis
- GSMaP satellite data
- IMERG satellite data
- Observed precipitation data

It generates publication-quality comparison plots for each common station and includes statistical metrics (RMSE and R²).
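
The two metrics can be sketched with plain NumPy as follows. The helper `rmse_and_r2` and the sample values are hypothetical and only illustrate the definitions; the notebook itself calls `mean_squared_error` and `r2_score` from scikit-learn. R² here is the coefficient of determination, 1 - SS_res / SS_tot, so it can be negative when a product fits the observations worse than their mean does.

```python
import numpy as np

def rmse_and_r2(observed, estimated):
    """Hypothetical helper: RMSE and R² over days where both values are present."""
    observed = np.asarray(observed, dtype=float)
    estimated = np.asarray(estimated, dtype=float)

    # Keep only the days on which neither series is missing
    valid = ~(np.isnan(observed) | np.isnan(estimated))
    obs, est = observed[valid], estimated[valid]

    residual_ss = np.sum((obs - est) ** 2)       # SS_res
    total_ss = np.sum((obs - obs.mean()) ** 2)   # SS_tot (zero if observations are constant)

    rmse = float(np.sqrt(residual_ss / obs.size))
    r2 = float(1.0 - residual_ss / total_ss)
    return rmse, r2

# Made-up daily totals in mm; the fourth observation is missing
observed = [0.0, 10.0, 20.0, np.nan, 5.0]
close_product = [1.0, 8.0, 25.0, 3.0, 5.0]
poor_product = [30.0, 0.0, 0.0, 3.0, 40.0]

print(rmse_and_r2(observed, close_product))
print(rmse_and_r2(observed, poor_product))
```

The second call returns a negative R², which is worth keeping in mind when reading the per-station tables and the averaged statistics.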

In [None]:
def compare_precipitation_data(flood_name):
    """Compare precipitation data from different sources and generate publication-quality plots"""
    # Process high-resolution data if it hasn't been processed yet
    highres_path = f"../results/floods/{flood_name}/highres/high_resolution_precipitation_daily.csv"
    if not os.path.exists(highres_path):
        highres_path = process_highres_data(flood_name)
        if highres_path is None:
            return
    
    # Define file paths
    era5_path = f"../results/floods/{flood_name}/gee_era5/daily_precipitation.csv"
    gsmap_path = f"../results/floods/{flood_name}/gee_gsmap/daily_precipitation.csv"
    imerg_path = f"../results/floods/{flood_name}/gee_imerg/daily_precipitation.csv"
    observed_path = "../data/filtered_observed_precip.xlsx"
    
    # Create output directory
    output_dir = f"../results/floods/{flood_name}/Final Comparison"
    os.makedirs(output_dir, exist_ok=True)
    
    # Check if required files exist
    missing_files = []
    for path, name in [(era5_path, "ERA5"), (gsmap_path, "GSMaP"), 
                       (imerg_path, "IMERG"), (observed_path, "Observed")]:
        if not os.path.exists(path):
            missing_files.append(name)
    
    if missing_files:
        print(f"⚠️ Missing data files for {flood_name}: {', '.join(missing_files)}")
        if "Observed" in missing_files:
            print("⚠️ Observed data is required for comparison. Aborting.")
            return
    
    print(f"Comparing precipitation data for {flood_name}...")
    
    # Load data
    datasets = {}
    try:
        datasets["highres"] = pd.read_csv(highres_path)
        if os.path.exists(era5_path):
            datasets["era5"] = pd.read_csv(era5_path)
        if os.path.exists(gsmap_path):
            datasets["gsmap"] = pd.read_csv(gsmap_path)
        if os.path.exists(imerg_path):
            datasets["imerg"] = pd.read_csv(imerg_path)
        datasets["observed"] = pd.read_excel(observed_path)
    except Exception as e:
        print(f"⚠️ Error loading data: {e}")
        return
    
    # Parse and set date columns
    try:
        datasets["highres"]["date"] = pd.to_datetime(datasets["highres"]["date"])
        datasets["highres"] = datasets["highres"].set_index("date")
        
        if "era5" in datasets:
            datasets["era5"]["Date"] = pd.to_datetime(datasets["era5"]["Date"])
            datasets["era5"] = datasets["era5"].set_index("Date")
        
        if "gsmap" in datasets:
            datasets["gsmap"]["Date"] = pd.to_datetime(datasets["gsmap"]["Date"])
            datasets["gsmap"] = datasets["gsmap"].set_index("Date")
        
        if "imerg" in datasets:
            datasets["imerg"]["Date"] = pd.to_datetime(datasets["imerg"]["Date"])
            datasets["imerg"] = datasets["imerg"].set_index("Date")
        
        # Process observed data
        datasets["observed"]["date"] = pd.to_datetime(datasets["observed"][["year", "month", "days"]]).dt.normalize()
        datasets["observed"] = datasets["observed"].set_index("date").drop(columns=["year", "month", "days"])
        datasets["observed"].columns = datasets["observed"].columns.astype(str)
    except Exception as e:
        print(f"⚠️ Error processing dates: {e}")
        return
    
    # Align dates based on highres dates
    daily_dates = datasets["highres"].index.normalize()
    
    def align_dates(df, valid_dates):
        df.index = pd.to_datetime(df.index).normalize()
        return df.loc[df.index.isin(valid_dates)]
    
    aligned_data = {}
    aligned_data["highres"] = align_dates(datasets["highres"], daily_dates)
    
    if "era5" in datasets:
        aligned_data["era5"] = align_dates(datasets["era5"], daily_dates)
    if "gsmap" in datasets:
        aligned_data["gsmap"] = align_dates(datasets["gsmap"], daily_dates)
    if "imerg" in datasets:
        aligned_data["imerg"] = align_dates(datasets["imerg"], daily_dates)
    
    aligned_data["observed"] = datasets["observed"].loc[datasets["observed"].index.isin(daily_dates)]
    
    # Find common stations
    station_sets = []
    
    # Extract station IDs from column names for each dataset
    hr_stations = [col.split("_")[-1] for col in aligned_data["highres"].columns]
    station_sets.append(set(hr_stations))
    
    if "era5" in aligned_data:
        era5_stations = [col.split("_")[-1] for col in aligned_data["era5"].columns]
        station_sets.append(set(era5_stations))
    
    if "gsmap" in aligned_data:
        gsmap_stations = [col.split("_")[-1] for col in aligned_data["gsmap"].columns]
        station_sets.append(set(gsmap_stations))
        
    if "imerg" in aligned_data:
        imerg_stations = [col.split("_")[-1] for col in aligned_data["imerg"].columns]
        station_sets.append(set(imerg_stations))
    
    observed_stations = list(aligned_data["observed"].columns)
    station_sets.append(set(observed_stations))
    
    # Find common stations across all datasets
    common_stations = sorted(set.intersection(*station_sets))
    
    print(f"✅ Found {len(common_stations)} common stations for comparison: {common_stations}")
    
    # Prepare statistics table
    stats_table = pd.DataFrame(
        index=["ERA5", "GSMaP", "HighRes", "IMERG"],
        columns=["RMSE", "R2"]
    )
    
    # Initialize aggregated statistics
    aggregated_stats = {
        "ERA5": {"rmse": [], "r2": []},
        "GSMaP": {"rmse": [], "r2": []},
        "HighRes": {"rmse": [], "r2": []},
        "IMERG": {"rmse": [], "r2": []}
    }
    
    # Create plots for each station
    for sid in common_stations:
        colname = f"Station_{sid}"
        
        try:
            # Create figure and axes with seaborn style
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.set_style("whitegrid")
            
            # Dictionary to collect data for statistics
            station_data = {}
            
            # Get observed data for this station
            if sid in aligned_data["observed"].columns and aligned_data["observed"][sid].notna().any():
                observed_data = aligned_data["observed"][sid]
                sns.lineplot(x=aligned_data["observed"].index, y=observed_data, 
                           label="Observed", linestyle='--', linewidth=2.5, color='black', ax=ax)
                station_data["Observed"] = observed_data
            else:
                print(f"⚠️ No observed data for station {sid}, skipping")
                plt.close(fig)  # Release the unused figure before skipping
                continue
            
            # Plot each dataset if available with enhanced styling
            datasets_to_plot = [
                ("highres", "HighRes", "#1f77b4"),  # blue
                ("era5", "ERA5", "#d62728"),        # red
                ("gsmap", "GSMaP", "#2ca02c"),      # green
                ("imerg", "IMERG", "#9467bd")       # purple
            ]
            
            for data_key, label, color in datasets_to_plot:
                if data_key in aligned_data and colname in aligned_data[data_key].columns and aligned_data[data_key][colname].notna().any():
                    sns.lineplot(x=aligned_data[data_key].index, y=aligned_data[data_key][colname], 
                               label=label, linewidth=2, color=color, ax=ax)
                    station_data[label] = aligned_data[data_key][colname]
            
            # Calculate statistics for this station
            station_stats = {}
            for label, data in station_data.items():
                if label != "Observed" and "Observed" in station_data:
                    # Align the data
                    model_data = data.reindex(station_data["Observed"].index)
                    observed_aligned = station_data["Observed"].reindex(model_data.index)
                    
                    # Remove NaN values
                    valid_idx = ~(model_data.isna() | observed_aligned.isna())
                    if valid_idx.sum() > 0:
                        model_clean = model_data[valid_idx]
                        observed_clean = observed_aligned[valid_idx]
                        
                        # Calculate statistics
                        rmse = np.sqrt(mean_squared_error(observed_clean, model_clean))
                        r2 = r2_score(observed_clean, model_clean)
                        
                        station_stats[label] = {"RMSE": rmse, "R2": r2}
                        
                        # Add to aggregated statistics
                        aggregated_stats[label]["rmse"].append(rmse)
                        aggregated_stats[label]["r2"].append(r2)
            
            # Add the statistics table to the plot as an inset
            if station_stats:
                # Create a styled statistics table
                table_data = []
                for model, metrics in station_stats.items():
                    table_data.append([model, f"{metrics['RMSE']:.2f}", f"{metrics['R2']:.2f}"])
                
                # Create the table as an inset
                table_ax = fig.add_axes([0.68, 0.68, 0.27, 0.20], frame_on=True)
                table_ax.axis('off')
                table = table_ax.table(
                    cellText=table_data,
                    colLabels=["Model", "RMSE", "R²"],
                    loc='center',
                    cellLoc='center',
                    bbox=[0, 0, 1, 1]
                )
                table.auto_set_font_size(False)
                table.set_fontsize(10)
                table.scale(1, 1.5)
                
                # Style the table
                for (row, col), cell in table.get_celld().items():
                    if row == 0:  # Header row
                        cell.set_facecolor('#4472C4')
                        cell.set_text_props(color='white', fontweight='bold')
                    elif col == 0:  # Model names column
                        cell.set_text_props(fontweight='bold')
                    
                    # Add light alternating colors
                    if row > 0 and row % 2 == 0:
                        cell.set_facecolor('#E6F0FF')
            
            # Add title and labels with enhanced styling
            ax.set_title(f"Daily Precipitation Comparison - Station {sid} ({flood_name})", 
                       fontweight='bold', pad=20)
            ax.set_xlabel("Date", fontweight='bold')
            ax.set_ylabel("Precipitation (mm)", fontweight='bold')
            
            # Enhance legend
            leg = ax.legend(
                title="Data Sources",
                loc='upper left',           # top-left inside the axes
                frameon=True,
                fancybox=True,
                framealpha=0.9,
                facecolor='white',
                edgecolor='gray'
            )
            leg.get_title().set_fontweight('bold')
            
            # Enhance grid
            ax.grid(True, linestyle='--', alpha=0.7)
            
            # Improve date formatting
            fig.autofmt_xdate()
            
            # Add a subtle border around the plot
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_color('gray')
                spine.set_linewidth(0.5)
                
            # Set the figure margins manually rather than calling tight_layout
            fig.subplots_adjust(left=0.1, right=0.95, top=0.9, bottom=0.1)

            # Save this station's figure and close it to free memory
            outpath = os.path.join(output_dir, f"station_{sid}_comparison.png")
            fig.savefig(outpath, dpi=300, bbox_inches='tight')
            plt.close(fig)
            
            print(f"✅ Created comparison plot for station {sid}")
            
        except Exception as e:
            print(f"⚠️ Skipped station {sid} due to error: {e}")
    
    # Calculate aggregated statistics
    for model in aggregated_stats:
        if aggregated_stats[model]["rmse"]:
            stats_table.loc[model, "RMSE"] = np.mean(aggregated_stats[model]["rmse"])
            stats_table.loc[model, "R2"] = np.mean(aggregated_stats[model]["r2"])
    
    # Save aggregated statistics
    stats_path = os.path.join(output_dir, "aggregated_statistics.csv")
    stats_table.to_csv(stats_path)
    
    print(f"\n✅ Aggregated statistics across all stations:")
    print(stats_table)
    
    return stats_table

## Interactive Flood Selection

Select which flood event to analyze using the dropdown menu below.

In [None]:
# Callback run by the dropdown widget for the selected flood
def analyze_flood(flood_name):
    print(f"\n===== Analyzing {flood_name} =====\n")
    stats = compare_precipitation_data(flood_name)
    print(f"\n===== Completed analysis for {flood_name} =====\n")
    return stats

# Interactive dropdown
if flood_names:
    interact(analyze_flood, flood_name=widgets.Dropdown(options=flood_names, description='Flood:'))
else:
    print("No flood directories found in ../results/floods/")

## Analyze All Floods

Run this cell to analyze all available flood events and generate a summary of statistics.
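
One way to picture the summary is as per-flood statistics tables stacked under a two-level (Flood, Model) index. The sketch below builds such a table with `pd.concat`; the flood names and metric values are invented for illustration and do not come from any real run.

```python
import pandas as pd

products = ["ERA5", "GSMaP", "HighRes", "IMERG"]

# Hypothetical per-flood tables, shaped like the output of compare_precipitation_data
per_flood = {
    "flood_A": pd.DataFrame(
        {"RMSE": [12.0, 9.5, 7.0, 10.0], "R2": [0.40, 0.55, 0.70, 0.50]},
        index=products,
    ),
    "flood_B": pd.DataFrame(
        {"RMSE": [15.0, 11.0, 8.5, 12.5], "R2": [0.30, 0.45, 0.65, 0.40]},
        index=products,
    ),
}

# Stack the tables; the dictionary keys become the outer "Flood" index level
summary = pd.concat(per_flood, names=["Flood", "Model"]).astype(float)

# One row per flood and one column per product, ready for a grouped bar plot
rmse_by_flood = summary["RMSE"].unstack("Model")
print(rmse_by_flood)
```

Casting to `float` matters for plotting: pandas refuses to draw bar charts from columns stored with `object` dtype.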

In [None]:
def analyze_all_floods():
    """Run analysis on all flood events and compile statistics"""
    all_stats = {}
    
    for flood_name in flood_names:
        print(f"\n===== Analyzing {flood_name} =====\n")
        stats = compare_precipitation_data(flood_name)
        if stats is not None:
            all_stats[flood_name] = stats
    
    if all_stats:
        # Create summary table
        summary = pd.DataFrame(
            index=pd.MultiIndex.from_product(
                [list(all_stats.keys()), ["ERA5", "GSMaP", "HighRes", "IMERG"]],
                names=["Flood", "Model"]
            ),
            columns=["RMSE", "R2"]
        )
        
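        # Fill summary table
        for flood, stats in all_stats.items():
            for model in stats.index:
                summary.loc[(flood, model), "RMSE"] = stats.loc[model, "RMSE"]
                summary.loc[(flood, model), "R2"] = stats.loc[model, "R2"]
        
        # The empty table starts with object dtype; cast to float so the bar plots can be drawn
        summary = summary.astype(float)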
        
        # Save summary
        summary.to_csv("precipitation_comparison_summary.csv")
        
        print("\n===== Summary of All Floods =====\n")
        print(summary)
        
        # Create comparison plots with publication quality
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 12))
        
        # RMSE comparison
        summary_pivot = summary.reset_index().pivot(index="Flood", columns="Model", values="RMSE")
        
        # Use seaborn's color palette
        colors = sns.color_palette("muted", n_colors=len(summary_pivot.columns))
        
        # Plot with enhanced styling
        summary_pivot.plot(kind="bar", ax=ax1, color=colors, width=0.8, edgecolor='black', linewidth=0.6)
        ax1.set_title("RMSE Comparison Across Flood Events", fontweight='bold', fontsize=16)
        ax1.set_ylabel("RMSE (mm)", fontweight='bold', fontsize=14)
        ax1.set_xlabel("")  # Remove x-label on top plot
        ax1.grid(axis="y", linestyle='--', alpha=0.7)
        
        # Add value labels on bars
        for container in ax1.containers:
            ax1.bar_label(container, fmt='%.2f', fontsize=8, padding=3)
        
        # Enhance legend
        leg1 = ax1.legend(title="Precipitation Products", frameon=True, 
                        fancybox=True, framealpha=0.9, facecolor='white', 
                        edgecolor='gray', loc='upper right')
        leg1.get_title().set_fontweight('bold')
        
        # R2 comparison with enhanced styling
        r2_pivot = summary.reset_index().pivot(index="Flood", columns="Model", values="R2")
        r2_pivot.plot(kind="bar", ax=ax2, color=colors, width=0.8, edgecolor='black', linewidth=0.6)
        ax2.set_title("R² Comparison Across Flood Events", fontweight='bold', fontsize=16)
        ax2.set_ylabel("R² Value", fontweight='bold', fontsize=14)
        ax2.set_xlabel("Flood Event", fontweight='bold', fontsize=14)
        ax2.grid(axis="y", linestyle='--', alpha=0.7)
        
        # Add value labels on bars
        for container in ax2.containers:
            ax2.bar_label(container, fmt='%.2f', fontsize=8, padding=3)
        
        # Enhance legend
        leg2 = ax2.legend(title="Precipitation Products", frameon=True, 
                        fancybox=True, framealpha=0.9, facecolor='white', 
                        edgecolor='gray', loc='upper right')
        leg2.get_title().set_fontweight('bold')
        
        # Add subtle border around the plots
        for ax in [ax1, ax2]:
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_color('gray')
                spine.set_linewidth(0.5)
        
        # Add a title for the entire figure
        fig.suptitle("Comparison of Precipitation Products Across Flood Events", 
                   fontsize=18, fontweight='bold', y=0.98)
        
        plt.tight_layout(rect=[0, 0, 1, 0.96])  # Adjust for the suptitle
        plt.savefig("precipitation_model_comparison.png", dpi=300, bbox_inches='tight')
        plt.close()
        
        return summary
    else:
        print("No statistics available for summary.")
        return None

# Run analysis on all floods
# analyze_all_floods()  # Uncomment to run

## Conclusions

This notebook compares precipitation estimates from several sources against observed station data for one or more flood events.

In summary:
1. RMSE and R² quantify how closely each precipitation product matches the observations at each station
2. The per-station plots show temporal patterns and differences between precipitation sources
3. The analysis can be run for an individual flood event or across all available flood events

To expand this analysis, you could consider:
- Adding more statistical measures (bias, correlation coefficient)
- Creating spatial visualizations of precipitation patterns
- Analyzing lag effects between precipitation and flood events
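
As a starting point for the first suggestion, mean bias and the Pearson correlation coefficient could be computed along these lines. The helper `bias_and_correlation` and the sample values are hypothetical and not part of the notebook above.

```python
import numpy as np

def bias_and_correlation(observed, estimated):
    """Hypothetical helper: mean bias and Pearson r over days with both values present."""
    observed = np.asarray(observed, dtype=float)
    estimated = np.asarray(estimated, dtype=float)

    # Keep only the days on which neither series is missing
    valid = ~(np.isnan(observed) | np.isnan(estimated))
    obs, est = observed[valid], estimated[valid]

    mean_bias = float(np.mean(est - obs))           # positive means the product overestimates
    pearson_r = float(np.corrcoef(obs, est)[0, 1])  # linear correlation, between -1 and 1
    return mean_bias, pearson_r

# Made-up daily totals in mm: the product is 2 mm too wet on every valid day
observed = [0.0, 10.0, 20.0, np.nan, 5.0]
product = [2.0, 12.0, 22.0, 9.0, 7.0]

bias, r = bias_and_correlation(observed, product)
print(bias, r)
```

A constant offset like this one leaves the correlation at its maximum while the bias is nonzero, which is why the two measures complement RMSE and R² rather than duplicate them.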