In [None]:
class Switch_Model(Model):
    """Streamflow model that switches between three LSTMs using a switch score.

    A static ensemble (``Ensemble_Static``) built on ``window_switch`` produces a
    score for every predicted step. For each step the output is taken from:

    * the q95 LSTM (``self.q95``, trained with ``CustomLoss.qloss_95``) if
      score > ``high_threshold``;
    * the q70 LSTM (``self.q70``, trained with ``CustomLoss.qloss_70``) if
      ``threshold`` < score <= ``high_threshold``;
    * the regular LSTM (``self.regular``) otherwise.

    ``summary`` relies on ``get_NSE`` and ``average_model_error``, which are not
    defined here -- presumably inherited from ``Model`` (not visible).
    """

    # Switch score above which the q70 model replaces the regular model.
    threshold = 0.7
    # Switch score above which the q95 model is used instead of q70.
    high_threshold = 0.95

    def __init__(self, window_switch, window_regular, CONV_WIDTH):
        """Build the switch ensemble and the three multi-LSTM base models.

        Args:
            window_switch: window used by the ``Ensemble_Static`` switch model.
            window_regular: window shared by the regular, q70 and q95 LSTMs;
                also stored as ``self.window`` for the evaluation helpers.
            CONV_WIDTH: forwarded unchanged to every ``Base_Model``.

        Raises:
            AssertionError: if the two windows differ in ``input_width``.
        """
        self.window_switch = window_switch
        self.window = window_regular

        # Presumably required so the switch scores and the LSTM predictions
        # cover the same steps and can be combined element-wise.
        assert window_switch.input_width == self.window.input_width, \
            'window_switch and window_regular must share input_width'

        self.switch = Ensemble_Static(window_switch)

        self.regular = Base_Model(model_name='multi-LSTM', window=window_regular, CONV_WIDTH=CONV_WIDTH)
        self.q70 = Base_Model(model_name='multi-LSTM', window=window_regular, CONV_WIDTH=CONV_WIDTH, loss_func=CustomLoss.qloss_70)
        self.q95 = Base_Model(model_name='multi-LSTM', window=window_regular, CONV_WIDTH=CONV_WIDTH, loss_func=CustomLoss.qloss_95)

    def predictions(self, station):
        """Return the switched predictions for ``station``.

        Each row of the switch output selects, element-wise, between the q95,
        q70 and regular predictions of the same row (see the class docstring
        for the threshold rules).

        Returns:
            np.ndarray with one selected row per input row. ``zip`` stops at the
            shortest of the four prediction sequences, so that length wins.
        """
        preds_switch = self.switch.predictions(station)
        preds_regular = self.regular.predictions(station)
        preds_q70 = self.q70.predictions(station)
        preds_q95 = self.q95.predictions(station)

        new_pred = []
        for pred_switch, pred_regular, pred_q70, pred_q95 in zip(preds_switch, preds_regular, preds_q70, preds_q95):
            use_q95 = pred_switch > self.high_threshold
            use_q70 = pred_switch > self.threshold
            # Priority order: q95 beats q70, which beats the regular model.
            new_pred.append(np.where(use_q95, pred_q95, np.where(use_q70, pred_q70, pred_regular)))

        return np.array(new_pred)

    def test_MSE(self, station=None):
        """Mean squared error of the switched predictions on the test split.

        NOTE(review): assumes ``self.window`` exposes ``test_array`` (raw test
        series) and that dropping the first ``input_width`` values aligns it
        with ``predictions`` -- confirm shapes match.
        """
        preds = self.predictions(station)
        test_array = self.window.test_array(station)[self.window.input_width:]

        return mean_squared_error(test_array, preds)

    def test_ROCAUC(self, station, level=0.05):
        """ROC AUC of the switched predictions against ``actual < level`` labels.

        NOTE(review): positives are observations *below* ``level`` while the
        scores are the predictions themselves; if higher predictions mean
        "less likely below level", the AUC is inverted -- confirm intent.
        """
        preds = self.predictions(station)
        test_array = (self.window.test_array(station)[self.window.input_width:] < level).astype(int)

        return roc_auc_score(test_array, preds)

    def summary(self, station=None):
        """Collect window sizes and error metrics for ``station`` into a dict.

        Keys, in order: input_width, label_width, station, NSE, SER_1% ...
        SER_75% (``average_model_error`` at each cut) and RMSE
        (``average_model_error`` at cut=100).
        """
        summary_dict = {
            'input_width': self.window.input_width,
            'label_width': self.window.label_width,
            'station': station,
            'NSE': self.get_NSE(station),
        }

        for cut in (1, 2, 5, 10, 25, 50, 75):
            summary_dict['SER_{}%'.format(cut)] = self.average_model_error(station, cut=cut)
        summary_dict['RMSE'] = self.average_model_error(station, cut=100)

        return summary_dict

### Stage 3: Quantile LSTM Model for All States

In [18]:
# Pick every station in camels_data.summary_data whose 'state_outlet' is 'SA';
# change the state code to run the quantile LSTM model on another state.
station_summary = camels_data.summary_data
in_state = station_summary['state_outlet'] == 'SA'
selected_stations = station_summary.loc[in_state].index.tolist()

In [None]:
# Train and evaluate the switch model N_RUNS times. Each run's per-station
# summaries are averaged into one dict and appended to `combined`.
N_RUNS = 30

# Time-series inputs for the regular / q70 / q95 LSTM window.
variable_ts = ['streamflow_MLd_inclInfilled', 'precipitation_deficit', 'year_sin', 'year_cos', 'tmax_AWAP', 'tmin_AWAP']
# Same as variable_ts, but with 'flood_probabilities' in place of the streamflow column (switch window).
variable_ts_switch = ['flood_probabilities', 'precipitation_deficit', 'year_sin', 'year_cos', 'tmax_AWAP', 'tmin_AWAP']
# Per-station attributes passed to the switch window as summary_source.
variable_static = ['q_mean', 'stream_elas', 'runoff_ratio', 'high_q_freq', 'high_q_dur', 'low_q_freq', 'zero_q_freq']

combined = []
for i in range(N_RUNS):
    print('RUN', i)

    train_df, test_df = camels_data.get_train_val_test(source=variable_ts, stations=selected_stations)

    multi_window = MultiWindow(input_width=5,
                               label_width=5,
                               shift=5,
                               train_df=train_df,
                               test_df=test_df,
                               stations=selected_stations,
                               label_columns=['streamflow_MLd_inclInfilled'])

    np_window = MultiNumpyWindow(input_width=5,
                                 label_width=5,
                                 shift=5,
                                 timeseries_source=variable_ts_switch,
                                 summary_source=variable_static,
                                 summary_data=camels_data.summary_data,
                                 stations=selected_stations,
                                 label_columns=['flood_probabilities'])

    # A new Switch_Model (and its Base_Model members) is built every run.
    model_switch = Switch_Model(window_switch=np_window, window_regular=multi_window, CONV_WIDTH=5)

    results_switch = [model_switch.summary(station) for station in selected_stations]

    # The 'station' column holds non-numeric IDs. numeric_only=True excludes it
    # (pandas >= 2.0 raises TypeError otherwise; older pandas dropped it silently).
    Switch_SA = pd.DataFrame(results_switch).mean(numeric_only=True).to_dict()
    combined.append(Switch_SA)

In [23]:
# One row per run (metrics averaged over stations), persisted for later comparison.
output_path = 'Quantilelstm_SA.csv'
Switch_SA_bdlstm = pd.DataFrame(combined)
Switch_SA_bdlstm.to_csv(output_path)

### Visualizations of Quantile LSTM Model

In [None]:
# Plot predicted vs actual values of the quantile LSTM switch model for one
# station. Edit STATION and the [PLOT_START, PLOT_END) index range as needed.

import matplotlib.patches as mpatches

STATION = 'A5030502'
PLOT_START, PLOT_END = 800, 1600

# Set the font size *before* creating the figure so the first run of this cell
# renders the same as later re-runs (rcParams are read when artists are made).
plt.rcParams.update({'font.size': 15})

fig, ax = plt.subplots(figsize=(20, 5))
ax.set_title('SA-{}'.format(STATION), fontsize=20)
# NOTE(review): the plotted series come from the streamflow windows; confirm
# the 'Flood Probability' label and the 0-1.1 y-range match their scale.
ax.set_ylabel('Flood Probability', fontsize=18)
ax.set_ylim(0, 1.1)

# NOTE(review): dates come from np_window.test_df while the values come from
# windowed outputs; confirm there is no input_width offset between them.
date_values = np_window.test_df[STATION].reset_index()['date']
plot_range = slice(PLOT_START, PLOT_END)

predicted = model_switch.predictions(STATION)[:, 0]
actual = multi_window.test_windows(STATION)[:, 0]

ax.plot(date_values[plot_range], predicted[plot_range], color='blue')
ax.plot(date_values[plot_range], actual[plot_range], color='red')

red_patch = mpatches.Patch(color='red', label='Actual')
blue_patch = mpatches.Patch(color='blue', label='Predicted')

ax.legend(handles=[red_patch, blue_patch])
plt.show()