In [None]:

import os
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import os
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

def _load_ensemble_swe(ensemble_dir, hru_ids):
    """Read ``scalarSWE`` from every ``.nc`` file in *ensemble_dir*.

    Each file is treated as one ensemble member. Files are opened with a
    context manager so the underlying handles are released as soon as the
    data has been converted to a DataFrame.

    Parameters
    ----------
    ensemble_dir : str
        Directory that is scanned (non-recursively) for ``.nc`` files.
    hru_ids : array-like
        Values of the ``hru`` column to keep; rows for other HRUs are
        dropped immediately to limit memory use.

    Returns
    -------
    list of pandas.DataFrame
        One frame per ensemble file, in sorted file-name order.
    """
    members = []
    # Sorted so the member order does not depend on the filesystem.
    for file_name in sorted(os.listdir(ensemble_dir)):
        if not file_name.endswith('.nc'):
            continue
        ensemble_file = os.path.join(ensemble_dir, file_name)
        with xr.open_dataset(ensemble_file) as ensemble_data:
            swe = ensemble_data['scalarSWE'].to_dataframe().reset_index()
        # Only HRUs referenced by a station can ever survive the inner join
        # performed later, so discard the rest up front.
        members.append(swe[swe['hru'].isin(hru_ids)])
    return members


def _count_observation_ranks(observations, ensemble_predictions):
    """Tally the rank of each observation within its ensemble.

    Parameters
    ----------
    observations : array-like, shape (n_points,)
        Observed values.
    ensemble_predictions : array-like, shape (n_members, n_points)
        Ensemble values; column ``i`` belongs to ``observations[i]``.

    Returns
    -------
    numpy.ndarray, shape (n_members + 1,)
        Number of observations that fell at each rank. The rank of an
        observation is the number of members strictly below it, so ties
        put the observation at the lowest tied rank.
    """
    observations = np.asarray(observations, dtype=float)
    ensemble_predictions = np.asarray(ensemble_predictions, dtype=float)
    n_members = ensemble_predictions.shape[0]

    # A NaN observation (or member) has no meaningful rank; sorting it would
    # silently push it to one end and bias the histogram.
    usable = ~(np.isnan(observations) | np.isnan(ensemble_predictions).any(axis=0))
    below = ensemble_predictions[:, usable] < observations[usable]
    obs_ranks = below.sum(axis=0)
    return np.bincount(obs_ranks, minlength=n_members + 1)


def calculate_rank_histogram(measurement_file, station_file, ensemble_dir):
    """Plot one rank histogram per station for an ensemble of SWE simulations.

    The observed ``snw`` variable is read from *measurement_file*, joined to
    the station table in *station_file* on ``station_id``, and compared with
    the ``scalarSWE`` variable of every ``.nc`` file in *ensemble_dir* (one
    file per ensemble member). For each station a bar chart of rank counts is
    shown with ``matplotlib``.

    Parameters
    ----------
    measurement_file : str
        Path to a dataset readable by ``xarray.open_dataset`` that contains
        a ``snw`` variable and a ``station_id`` column once converted to a
        DataFrame.
    station_file : str
        Path to a CSV file with at least ``station_id`` and ``HRU_ID``
        columns.
    ensemble_dir : str
        Directory holding the ensemble member ``.nc`` files.

    Raises
    ------
    FileNotFoundError
        If *ensemble_dir* contains no ``.nc`` files.
    ValueError
        If the ensemble members do not yield the same number of matched
        points for a station, so they cannot be stacked.
    """
    # Load the measurement data
    with xr.open_dataset(measurement_file) as measurements:
        observed_values = measurements['snw'].to_dataframe().reset_index()

    # Load the station details
    stations_df = pd.read_csv(station_file)

    # Merge observed values with station details to get HRU_ID for each observation
    merged_df = pd.merge(observed_values, stations_df, on='station_id')

    unique_stations = merged_df['station_id'].unique()
    if len(unique_stations) == 0:
        return

    # Ensemble files do not depend on the station, so read them once instead
    # of once per station.
    ensemble_members = _load_ensemble_swe(ensemble_dir, merged_df['HRU_ID'].unique())
    if not ensemble_members:
        raise FileNotFoundError(
            f"No ensemble member '.nc' files found in {ensemble_dir!r}"
        )

    for station in unique_stations:
        # Filter data for the current station
        station_data = merged_df[merged_df['station_id'] == station]

        observations = None
        member_predictions = []
        for member in ensemble_members:
            # NOTE(review): the join uses the 'hru' column of the ensemble
            # frame; confirm it holds HRU identifiers comparable to HRU_ID
            # (an earlier comment hinted that 'hruId' might be intended).
            left_keys, right_keys = ['hru'], ['HRU_ID']
            # Joining on HRU alone pairs every observation with every
            # simulated timestep. When both frames carry a 'time' column,
            # match on it too so each observation meets its own timestep.
            # NOTE(review): assumes the two 'time' columns are comparable.
            if 'time' in member.columns and 'time' in station_data.columns:
                left_keys.append('time')
                right_keys.append('time')

            matched = pd.merge(member, station_data, left_on=left_keys, right_on=right_keys)
            member_predictions.append(matched['scalarSWE'].to_numpy())
            if observations is None:
                # Take observations from the merged rows so they stay aligned
                # with the predictions row by row.
                observations = matched['snw'].to_numpy()

        if len({len(values) for values in member_predictions}) != 1:
            raise ValueError(
                f"Ensemble members yield differing numbers of matched points "
                f"for station {station}"
            )

        # Shape (n_ensemble_members, n_points)
        ensemble_predictions = np.vstack(member_predictions)
        ranks = _count_observation_ranks(observations, ensemble_predictions)

        # Plot the rank histogram for the current station
        plt.figure()
        plt.bar(range(len(ranks)), ranks, align='center', edgecolor='black')
        plt.xlabel('Rank')
        plt.ylabel('Frequency')
        plt.title(f'Rank Histogram for Station {station}')
        plt.show()


# Example usage: local input paths for the Bow-above-Banff run.
# Observed snow measurements (NetCDF).
measurement_file = '/Users/dcasson/Data/CanSWE/CanSWE-CanEEN_1928-2023_v6.nc'
# Station table (CSV) passed to the function as its station file.
station_file = '/Users/dcasson/Data/summa_snakemake/bow_above_banff/stations_with_hru_ids.csv'
# Directory containing the ensemble output files.
ensemble_dir = '/Users/dcasson/Data/summa_snakemake/bow_above_banff/summa/output/'

# Produce and display the per-station rank histograms.
calculate_rank_histogram(
    measurement_file=measurement_file,
    station_file=station_file,
    ensemble_dir=ensemble_dir,
)

: 