In [1]:
import sys
import os

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(os.path.join(module_path, 'scripts'))

In [2]:
import importlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sns.set_theme()

import utils

In [3]:
utils.set_project_dir()

In [4]:
os.getcwd()

'\\\\export.hpc.ut.ee\\gis\\flow_swat_ml_paper'

In [5]:
country_codes = ['ESP', 'EST', 'ETH', 'USA']
time_interval = 'd'
target = f'Q_{time_interval}+1'
feat_set = f'FS3_{time_interval}'
test_size = 0.5
test_size_int = int(test_size * 100)

# Read RF metrics
rf_metrics = pd.read_csv(f'ml/{target}_rf_metrics.csv')

# Output Excel file
out_fp = f'ml/{target}_{feat_set}_rf_vs_swat.xlsx'
if os.path.exists(out_fp):
    os.remove(out_fp)

for i in range(len(country_codes)):

    country_code = country_codes[i]

    # Read SWAT results
    excel_file = 'swat/SWAT_results.xlsx'
    sheet_name = f'{country_code}_{time_interval.upper()}'
    swat_results = pd.read_excel(excel_file, sheet_name=sheet_name)
    swat_results = swat_results.rename(
        columns={swat_results.columns[0]: 'Date', swat_results.columns[1]: 'Observed', swat_results.columns[2]: 'SWAT'}
    )

    # Shift by one day to match with RF time series
    swat_results['SWAT'] = swat_results['SWAT'].shift(-1)
    swat_results['Observed'] = swat_results['Observed'].shift(-1)

    model_dir = utils.get_model_dir(country_code, target, feat_set)

    # Get indices of training samples
    train_indices = pd.read_csv(
        f'{model_dir}/{country_code}_{target}_{feat_set}_feat_train_{test_size_int}.csv', usecols=['Index']
    )['Index'].values

    # Get indices of test samples
    test_indices = pd.read_csv(
        f'{model_dir}/{country_code}_{target}_{feat_set}_feat_test_{test_size_int}.csv', usecols=['Index']
    )['Index'].values

    # Read RF results
    obs_vs_pred = pd.read_csv(f'{model_dir}/{country_code}_{target}_{feat_set}_obs_vs_pred_{test_size_int}.csv', parse_dates=['Date'])
    obs_vs_pred['Index'] = obs_vs_pred.index

    # Training set
    start_train, end_train = utils.get_train_period(country_code)
    obs_vs_pred_train = obs_vs_pred.loc[(start_train <= obs_vs_pred['Date']) & (obs_vs_pred['Date'] <= end_train)]
    obs_vs_pred_train = obs_vs_pred_train.merge(swat_results, how='inner', on='Date')
    obs_vs_pred_train = obs_vs_pred_train.set_index('Date')

    # Test set
    start_test, end_test = utils.get_test_period(country_code)
    obs_vs_pred_test = obs_vs_pred.loc[(start_test <= obs_vs_pred['Date']) & (obs_vs_pred['Date'] <= end_test)]
    obs_vs_pred_test = obs_vs_pred_test.merge(swat_results, how='inner', on='Date')
    obs_vs_pred_test = obs_vs_pred_test.set_index('Date')
    
    # Write to Excel file
    if not os.path.exists(out_fp):
        with pd.ExcelWriter(out_fp) as writer:
            obs_vs_pred_train.to_excel(writer, sheet_name=f'{country_code}_train', encoding='utf-8')
            obs_vs_pred_test.to_excel(writer, sheet_name=f'{country_code}_test', encoding='utf-8')
    else:
        with pd.ExcelWriter(out_fp, mode='a') as writer:
            obs_vs_pred_train.to_excel(writer, sheet_name=f'{country_code}_train', encoding='utf-8')
            obs_vs_pred_test.to_excel(writer, sheet_name=f'{country_code}_test', encoding='utf-8')

the 'encoding' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'encoding'
the 'encoding' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'encoding'
the 'encoding' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'encoding'
the 'encoding' keyword is deprecated and will be removed in a future version. Please take steps to stop the use of 'encoding'
