## Description
____
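This notebook generates a plot comparing the snow water equivalent (SWE) upstream of a river segment on a given date with the observed streamflow at that segment summed between two given dates (by default April 1 – June 30).  
Each year is ranked by both quantities, and the two rankings are compared with a linear regression.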

### Import Libraries

In [None]:
import xarray as xr
import geopandas as gpd
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np
from scipy.stats import linregress

### Inputs

In [None]:
# --- Snow water equivalent (SWE) input --------------------------------------
# Directory holding the Easymore SWE NetCDF (.nc) files
directory = 'easymore_snotel/'

# --- Analysis window ---------------------------------------------------------
# SWE is sampled on (month, day); streamflow is summed from (month, day)
# through (end_month, end_day) of the same year
month, day = 4, 1
end_month, end_day = 6, 30

# --- Observed streamflow -----------------------------------------------------
# HYPE observed-flow file (tab-separated)
qobs = '../esp_hype/model/Qobs.txt'

# HYPE river segment to analyse (kept as a string; cast to int where needed)
riv_seg = '58213'

# --- Output ------------------------------------------------------------------
# Folder the figure is written to
output_dir = './plots/'

### Get Upstream Segments

In [None]:
# Load the river-segment layer (gdf) and the catchment/HRU layer (catchment)
# from the same modified shapefile set, river layer first
gdf, catchment = (
    gpd.read_file(shp_path)
    for shp_path in ('smm_tgf_modified/smm_riv.shp', 'smm_tgf_modified/smm_cat.shp')
)

In [None]:
# Map each river segment ID ('seg_nhm') to its catchment HRU ID ('hru_nhm').
# NOTE(review): if several HRUs share one segment, only the last one listed in
# the shapefile survives in this mapping - confirm the layer is one HRU per segment.
segment_dict = {
    seg_id: hru_id
    for seg_id, hru_id in zip(catchment['seg_nhm'], catchment['hru_nhm'])
}

In [None]:
# Cast both segment-ID columns to strings so they compare against `riv_seg`
# (a string) and against the '0' outlet marker used when building the graph.
# NOTE(review): if these columns load as floats, str() yields e.g. '58213.0',
# which would not match - confirm they are integer-typed in the shapefile.
for id_col in ('seg_nhm', 'ds_seg_nhm'):
    gdf[id_col] = gdf[id_col].astype(str)

In [None]:
# Build a directed river graph whose edges point downstream
# (segment -> the segment it drains into)
riv_graph = nx.DiGraph()

# Register every segment as a node first, so a segment with no connections
# still exists in the graph; nx.ancestors() raises NetworkXError for a node
# that is missing from the graph.
riv_graph.add_nodes_from(gdf['seg_nhm'])

# One edge per segment that drains into another; a downstream ID of '0'
# marks an outlet and gets no edge
has_downstream = gdf['ds_seg_nhm'] != '0'
riv_graph.add_edges_from(
    zip(gdf.loc[has_downstream, 'seg_nhm'], gdf.loc[has_downstream, 'ds_seg_nhm'])
)

# Fail early with a clear message if the target segment is not in the network
if riv_seg not in riv_graph:
    raise ValueError(f"River segment {riv_seg!r} is not in the river network")

# Every segment that can reach `riv_seg`, plus `riv_seg` itself
upstream_segments = list(nx.ancestors(riv_graph, riv_seg))
upstream_segments.append(riv_seg)

In [None]:
# Normalise IDs to int on both sides of the segment -> HRU lookup
segment_dict = {int(k): v for k, v in segment_dict.items()}
upstream_segments = [int(seg) for seg in upstream_segments]

# Segments without a catchment entry keep their own segment ID as a stand-in
# HRU ID. That ID is not guaranteed to exist in (or mean the same thing in)
# the SWE files, so report the substitution instead of making it silently.
unmapped_segments = [seg for seg in upstream_segments if seg not in segment_dict]
if unmapped_segments:
    print(f'Warning: no HRU found for segment(s) {unmapped_segments}; '
          'using the segment ID as the HRU ID')

# Convert stream segments to HRU IDs for comparison with the SWE data
upstream_segments = [segment_dict.get(seg, seg) for seg in upstream_segments]

### SWE

In [None]:
# Accumulator: one [filename, total SWE, date] row is appended per SWE file
results = list()

In [None]:
# For each SWE NetCDF file, sum the SWE of all upstream HRUs on the analysis
# date (month/day) and record it together with that date.
for filename in sorted(os.listdir(directory)):  # sorted: reproducible row order
    if not filename.endswith('.nc'):
        continue
    filepath = os.path.join(directory, filename)

    # The context manager releases the NetCDF file handle once the file is read
    with xr.open_dataset(filepath) as dataset:
        # Parse the time axis once per file rather than once per HRU
        times = pd.to_datetime(dataset['time'].values)
        on_date = np.flatnonzero((times.month == month) & (times.day == day))

        if on_date.size == 0:
            # No record on the analysis date: every HRU value would be missing
            # and summing them would fail, so skip the file and say so
            print(f'Skipping {filename}: no SWE record on {month}/{day}')
            continue

        # Use the first matching time step and add up the SWE of all upstream
        # HRUs at once; numpy's sum propagates NaN like the built-in sum.
        # NOTE(review): assumes 'swe' has dimensions (time, ID) only - confirm.
        swe_on_date = dataset['swe'].sel(ID=upstream_segments).isel(time=on_date[0])
        total_swe = float(np.sum(swe_on_date.values))

        # Earliest matching date; its year pairs this SWE with streamflow later
        first_date = times[on_date].min()

    results.append([filename, total_swe, first_date])

In [None]:
# Tabulate the per-file results: file name, total upstream SWE on the
# analysis date, and the date that SWE value was taken from
results_df = pd.DataFrame(
    data=results,
    columns=['Filename', 'Total_SWE', 'First_Date'],
)

In [None]:
# Rank 'Total_SWE' from largest (rank 1) to smallest. method='min' gives tied
# values the same integer rank; the default 'average' yields fractional ranks
# (e.g. 2.5) that astype(int) would silently truncate.
results_df['SWE_Rank'] = results_df['Total_SWE'].rank(ascending=False, method='min').astype(int)

In [None]:
# Replace each analysis date with its calendar year; the year is the key later
# used to merge the SWE table with the streamflow table (qobs_df['Year']).
results_df['First_Date'] = results_df['First_Date'].dt.year

### Qobs

In [None]:
# Read the HYPE observed-flow file: tab-separated, first column = dates,
# remaining columns = one flow series per segment ID.
# HYPE codes missing observations as -9999; parse them as NaN so they are
# not added into the flow totals.
df = pd.read_csv(qobs, sep='\t', index_col=0, na_values=[-9999])

In [None]:
# Relabel both axes: parse the date labels into datetimes, and cast the
# segment-ID headers to int so they can be compared with int(riv_seg)
df = df.set_axis(pd.to_datetime(df.index), axis='index')
df = df.set_axis(df.columns.astype(int), axis='columns')

In [None]:
# Keep only the column(s) whose header equals the analysed segment ID
target_id = int(riv_seg)
df = df.loc[:, [header == target_id for header in df.columns]]

In [None]:
# Total observed flow at `riv_seg` inside the analysis window
# (month/day .. end_month/end_day, inclusive) for every year in the record.
flow = df[int(riv_seg)]

# Encode each date as month*100 + day, so the window test is a single integer
# range check applied to all years at once. This replaces per-year string
# dates and repeated pd.concat onto a growing DataFrame.
# Whole days are compared, so for sub-daily data all of the end day is included.
month_day = flow.index.month * 100 + flow.index.day
in_window = (month_day >= month * 100 + day) & (month_day <= end_month * 100 + end_day)
window_flow = flow[in_window]

# min_count=1 turns a year with no valid observations in the window into NaN,
# and dropna() removes it. Otherwise it would get a total flow of 0 and rank
# as the driest year.
# NOTE(review): years with only some days missing are still summed over the
# days that are present.
yearly_flow = window_flow.groupby(window_flow.index.year).sum(min_count=1).dropna()

qobs_df = pd.DataFrame({
    'Year': yearly_flow.index.astype(int),
    'Total_flow': yearly_flow.to_numpy(),
})

In [None]:
# Rank 'Total_flow' from largest (rank 1) to smallest. method='min' gives tied
# values the same integer rank; the default 'average' yields fractional ranks
# (e.g. 2.5) that astype(int) would silently truncate.
qobs_df['Flow_Rank'] = qobs_df['Total_flow'].rank(ascending=False, method='min').astype(int)

In [None]:
# Pair each SWE year with the streamflow of the same year; years present in
# only one table are dropped (inner join)
merged_df = pd.merge(results_df, qobs_df, left_on='First_Date', right_on='Year', how='inner')

# 'Year' duplicates 'First_Date' after the join
merged_df.drop(columns=['Year'], inplace=True)

# Drop years lacking either total, then re-rank both quantities over the SAME
# set of years. Ranks computed before the join are relative to each table's own
# years, so a year missing from the other table would leave gaps and make the
# SWE and flow ranks incomparable in the regression.
merged_df = merged_df.dropna(subset=['Total_SWE', 'Total_flow']).copy()
merged_df['SWE_Rank'] = merged_df['Total_SWE'].rank(ascending=False, method='min').astype(int)
merged_df['Flow_Rank'] = merged_df['Total_flow'].rank(ascending=False, method='min').astype(int)

In [None]:
# Show the merged SWE / streamflow table as the notebook cell output
merged_df

### Plot

In [None]:
# Scatter plot of flow rank against SWE rank, one point per year
plt.figure(figsize=(10, 6))
plt.scatter(merged_df['Flow_Rank'], merged_df['SWE_Rank'], label='SWE vs Flow Rank', color='blue')

# Least-squares fit of SWE rank on flow rank
slope, intercept, r_value, p_value, std_err = linregress(merged_df['Flow_Rank'], merged_df['SWE_Rank'])

# Draw the fitted line across the observed flow-rank range; two sorted end
# points fully define a straight line (thin and dashed, as intended)
x_fit = np.array([merged_df['Flow_Rank'].min(), merged_df['Flow_Rank'].max()])
plt.plot(x_fit, slope * x_fit + intercept, color='red', linewidth=0.7, linestyle='--',
         label=f'Regression Line\nR-squared: {r_value**2:.2f}')

# Build the title from the configured dates instead of hard-coding "April 1st"
start_label = f'{pd.Timestamp(2000, month, day).month_name()} {day}'
end_label = f'{pd.Timestamp(2000, end_month, end_day).month_name()} {end_day}'
plt.title(f'{start_label} SWE and {start_label} - {end_label} Runoff at Segment {riv_seg}')
plt.xlabel('Flow Rank')
plt.ylabel('SWE Rank')
plt.legend()
plt.grid(True)

# Finalise the layout BEFORE saving, so the written image gets it too
plt.tight_layout()

# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Save with an explicit .png extension so the file type does not depend on
# matplotlib's savefig.format setting
output_file_path = os.path.join(output_dir, f'{riv_seg}_SWEvsFlow.png')
plt.savefig(output_file_path, dpi=300)  # dpi specifies the resolution (dots per inch)

plt.show()

# Print R-squared value
print(f'R-squared: {r_value**2:.2f}')