# Notebook to compare SWE data sources

# Step 0 - Set Up Notebook

In [None]:
import pandas as pd
import boto3
import s3fs
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import io
import requests
from io import StringIO
from snowML.datapipe import data_utils as du 
from snowML.LSTM import LSTM_metrics as met
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
import numpy as np
import itertools


# Step 1 - Define Functions to Gather Data

In [None]:
def get_UA_data(huc_id):
    """
    Loads the UA mean-SWE table for one HUC from the "snowml-gold" bucket.

    The object "mean_swe_in_{huc_id}.csv" is fetched through du.s3_to_df,
    indexed by its parsed 'day' column, and 'mean_swe' is divided by 1000
    (presumably a mm -> m unit conversion; confirm against the upstream data).

    Parameters:
    huc_id: HUC identifier interpolated into the object name.

    Returns:
    DataFrame indexed by datetime 'day' with the rescaled 'mean_swe' column.
    """
    bucket_name = "snowml-gold"
    object_key = f"mean_swe_in_{huc_id}.csv"

    swe = du.s3_to_df(object_key, bucket_name).set_index('day')
    swe.index = pd.to_datetime(swe.index)
    swe["mean_swe"] = swe["mean_swe"].div(1000)
    return swe

In [None]:
def get_Skagit(huc_id):
    """
    Loads the Skagit mean-SWE series for one HUC from a local CSV file.

    Reads "data_prior/wus-sr-skagit-{huc_id}-mean-swe.csv" (relative to the
    working directory), renames 'time' -> 'day' and 'mean' -> 'mean_swe',
    indexes the rows by the parsed 'day' values, and tags every row with
    the HUC id.

    Parameters:
    huc_id: HUC identifier interpolated into the file name.

    Returns:
    DataFrame indexed by datetime 'day' with exactly two columns,
    'mean_swe' and 'huc_id', in that order.
    """
    csv_path = f"data_prior/wus-sr-skagit-{huc_id}-mean-swe.csv"
    column_names = {"time": "day", "mean": "mean_swe"}

    swe = pd.read_csv(csv_path).rename(columns=column_names).set_index('day')
    swe.index = pd.to_datetime(swe.index)
    swe["huc_id"] = huc_id

    # Any other columns present in the CSV are discarded here.
    return swe[["mean_swe", "huc_id"]]
    

In [None]:
def get_UCLA(huc_id):
    """
    Loads the UCLA mean-SWE table for one HUC from the "snowml-gold" bucket.

    The object "mean_swe_in_{huc_id}_UCLA.csv" is fetched through
    du.s3_to_df; 'time' is renamed to 'day' and 'SWE' to 'mean_swe', and the
    rows are indexed by the parsed 'day' values. No unit rescaling is applied.

    Parameters:
    huc_id: HUC identifier interpolated into the object name.

    Returns:
    DataFrame indexed by datetime 'day' containing a 'mean_swe' column
    (plus whatever other columns the source object carries).
    """
    bucket_name = "snowml-gold"
    object_key = f"mean_swe_in_{huc_id}_UCLA.csv"
    column_names = {"time": "day", "SWE": "mean_swe"}

    swe = du.s3_to_df(object_key, bucket_name)
    swe = swe.rename(columns=column_names).set_index('day')
    swe.index = pd.to_datetime(swe.index)
    return swe

In [None]:
def filter(df, filter_date):
    """
    Keeps only the rows whose index value is on or after filter_date.

    NOTE: this name shadows Python's builtin filter() for the rest of the
    notebook; it is kept because gather_data calls it by this name.

    Parameters:
    df (DataFrame): Frame whose index is comparable with filter_date.
    filter_date: Inclusive lower bound for the index (e.g. "1984-10-01").

    Returns:
    A new DataFrame with the qualifying rows; df itself is not modified.
    """
    on_or_after = df.index >= filter_date
    return df.loc[on_or_after]

In [None]:
def gather_data(huc_id, filter_date = "1984-10-01"):
    """
    Loads all three SWE sources for one HUC and trims each to a start date.

    Parameters:
    huc_id: HUC identifier passed to every loader.
    filter_date: Inclusive start date applied to each source's index.

    Returns:
    dict mapping "UA_data", "Skagit_data" and "UCLA_data" (in that order)
    to the corresponding filtered DataFrame.
    """
    # (label, loader) pairs; the order fixes both load order and dict order.
    sources = (
        ("UA_data", get_UA_data),
        ("Skagit_data", get_Skagit),
        ("UCLA_data", get_UCLA),
    )
    return {
        label: filter(loader(huc_id), filter_date)
        for label, loader in sources
    }

# Step 2 - Define Plotting And Analysis Functions

In [None]:
def plot_swe(df_dict, huc_id, ttl= "plot"):
    """
    Plots mean_swe vs day for a dictionary of DataFrames with yearly x-axis ticks,
    rotated labels, and no grid lines. The figure is saved to
    "charts/{ttl}.png" and then shown.

    Parameters:
    df_dict (dict): Dictionary where keys are labels and values are DataFrames 
                    with 'mean_swe' column and 'day' as index.
    huc_id: Currently unused by this function; kept so existing calls keep working.
    ttl (str): Plot title; also used as the base name of the saved PNG.

    Returns:
    None
    """
    import os  # local import so this cell does not depend on the import cell

    plt.figure(figsize=(12, 6))
    
    # One line per data source, labelled with its dictionary key.
    for label, df in df_dict.items():
        plt.plot(df.index, df['mean_swe'], label=label)
    
    plt.xlabel('Year')
    plt.ylabel('Mean SWE')
    plt.title(ttl)
    plt.legend()

    # Format x-axis to show ticks yearly and rotate them
    ax = plt.gca()
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    plt.xticks(rotation=45)

    plt.tight_layout()

    # savefig does not create missing directories; without this the first run
    # in a fresh checkout fails with FileNotFoundError.
    os.makedirs("charts", exist_ok=True)
    plt.savefig(f"charts/{ttl}.png", bbox_inches='tight')
    
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate figures.
    plt.close()




In [None]:
def similarity(df_dict):
    """
    Computes R-squared, Pearson correlation, alpha, and beta between 'mean_swe' columns 
    for each pair of DataFrames in df_dict, prints the table rounded to 3
    decimals, and returns the unrounded table.

    For a pair (Dataset 1, Dataset 2), with a = Dataset 1 and b = Dataset 2
    restricted to index values present in both and non-missing in both:
      - 'Pearson Corr' : pearsonr(a, b)
      - 'Ratio_var'    : alpha = std(a) / std(b). Despite the column name this
                         is a ratio of standard deviations, not variances.
      - 'Ratio_means'  : beta = mean(a) / mean(b)
      - 'R-squared'    : r2_score(a, b), i.e. Dataset 1 is passed as the
                         reference and Dataset 2 as the prediction.
    Ratios with a zero denominator are NaN. If a pair has fewer than two
    usable overlapping points, all four metrics are NaN for that pair.

    Parameters:
    df_dict (dict): Dictionary where keys are identifiers and values are DataFrames
                    with a 'mean_swe' column.

    Returns:
    DataFrame with one row per pair and the columns 'Dataset 1', 'Dataset 2',
    'Pearson Corr', 'Ratio_var', 'Ratio_means', 'R-squared'.
    """
    result_columns = [
        'Dataset 1', 'Dataset 2', 'Pearson Corr', 'Ratio_var', 'Ratio_means', 'R-squared'
    ]
    results = []

    for key1, key2 in itertools.combinations(df_dict, 2):
        # Align only the two series on their shared index (e.g., day); other
        # columns of the source frames are irrelevant here.
        a, b = df_dict[key1]['mean_swe'].align(df_dict[key2]['mean_swe'], join='inner')

        # Drop positions where either source is missing; NaN would otherwise
        # make r2_score raise and poison the other statistics.
        usable = (a.notna() & b.notna()).to_numpy()
        a = a[usable].to_numpy(dtype=float)
        b = b[usable].to_numpy(dtype=float)

        # Correlation is undefined for fewer than two points (pearsonr raises).
        if a.size < 2:
            results.append((key1, key2, np.nan, np.nan, np.nan, np.nan))
            continue

        r2 = r2_score(a, b)
        pearson_corr, _ = pearsonr(a, b)

        std_b = np.std(b)
        mean_b = np.mean(b)
        alpha = np.std(a) / std_b if std_b != 0 else np.nan
        beta = np.mean(a) / mean_b if mean_b != 0 else np.nan

        results.append((key1, key2, pearson_corr, alpha, beta, r2))

    # Create and print a clean DataFrame of results
    results_df = pd.DataFrame(results, columns=result_columns)
    print(results_df.round(3))
    return results_df




# Step 3 - Plot and Analyze Data

In [None]:
# 10-digit HUC ids that the cells below plot and compare.
hucs = [
    1711000504,
    1711000505,
    1711000506,
    1711000507,
    1711000508,
    1711000509,
    1711000511,
]

In [None]:
# Full period (gather_data's default start date): plot and score every HUC.
for huc_id in hucs:
    # Load once per HUC; plot_swe only reads the frames, so the same dict is
    # reused for the similarity table instead of fetching everything again.
    df_dict = gather_data(huc_id)
    ttl = f'Mean_SWE_over_Time_for_Huc_{huc_id}'
    plot_swe(df_dict, huc_id, ttl = ttl)
    similarity(df_dict)

In [None]:
# Same comparison restricted to data on or after 2005-10-01.
filter_date = "2005-10-01"
for huc_id in hucs:
    df_dict = gather_data(huc_id, filter_date = filter_date)
    # Title doubles as the PNG file name; prefix spelled "Mean_SWE" to match
    # the full-period charts (it previously read "MeanS_WE").
    ttl = f'Mean_SWE_over_Time_for_Huc_{huc_id}_2005_through_2022'
    plot_swe(df_dict, huc_id, ttl = ttl)
    similarity(df_dict)
    