In [5]:
import pandas as pd
import swe2hs as jopack


In [None]:
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

# 1. Read the manual-station table. Besides the measurements it carries the
#    'month_category' and 'site_id' columns that the analysis relies on.
df = pd.read_csv('updated_manual_stations_data_months.csv')

# Parse the timestamps, index the table by them and switch to the short
# column names ('swe', 'hsm') used by the rest of the analysis.
df = (
    df.assign(date=pd.to_datetime(df['date']))
    .set_index('date')
    .rename(columns={'SWE_[m]': 'swe', 'HS_[m]': 'hsm'})
)

# One entry per station, in order of first appearance in the file.
sites = df['site_id'].unique()

# Per-site R2 containers, keyed by site_id:
#   r2_snow_all   -> Period 1 (Oct-Jan)
#   r2_nosnow_all -> Period 2 (Feb-Sep)
r2_snow_all, r2_nosnow_all = {}, {}
results_list = []

def _period_r2(frame, category):
    """Return R2 of predicted vs. measured snow depth for one month category.

    Args:
        frame: Per-site DataFrame with 'month_category', 'hsm' (measured)
            and 'hs_pred' (predicted) columns.
        category: Value of 'month_category' selecting the period to score.

    Returns:
        The r2_score of 'hs_pred' against 'hsm' over the selected rows that
        have both values present, or NaN when fewer than two such rows exist.
    """
    subset = frame.loc[frame['month_category'] == category]
    subset = subset.dropna(subset=['hsm', 'hs_pred'])
    if len(subset) < 2:
        # float('nan') instead of np.nan: numpy is not imported by this
        # script, so the previous fallback raised NameError as soon as a
        # site had too few valid samples in a period.
        return float('nan')
    return r2_score(subset['hsm'], subset['hs_pred'])


for site_id in sites:
    site_df = df[df['site_id'] == site_id].sort_index()

    # --- Step 1: Generate predictions for the whole site series ---
    # The conversion is run once on the complete, time-ordered series (not
    # per period) so the model works on one continuous record.
    hs_pred = jopack.convert_1d(
        site_df['swe'],
        swe_input_unit='m',
        hs_output_unit='m'
    )
    site_df['hs_pred'] = hs_pred

    # --- Step 2: Calculate R2 for each period ---
    # Snow period: month_category == 1
    r2_snow = _period_r2(site_df, 1)
    # No-snow period: month_category == 2
    r2_nosnow = _period_r2(site_df, 2)

    r2_snow_all[site_id] = r2_snow
    r2_nosnow_all[site_id] = r2_nosnow
    

# --- Combine the per-site scores into a single table ---
# Both dictionaries are filled in the same loop iteration, so their keys and
# values line up position by position.
r2_columns = {
    'site_id': [*r2_snow_all],
    'R2_Snow': [*r2_snow_all.values()],
    'R2_No_Snow': [*r2_nosnow_all.values()],
}
r2_df = pd.DataFrame(r2_columns)

# Show the first rows as a quick check, then persist the full table.
print(r2_df.head())
r2_df.to_csv("r2_comparison_month_wise.csv", index=False)

  site_id   R2_Snow  R2_No_Snow
0     1AD  0.926136    0.956271
1     1GD  0.973809    0.962445
2     1GS  0.952419    0.960398
3     1GT  0.961341    0.962768
4     1LS  0.914702    0.959346
