In [None]:
import os
import plotly.express as px
import pandas as pd
import numpy as np
import json

from utils import *
from train import *
from evaluation import *
import wandb
wandb.login()

import warnings
warnings.filterwarnings('ignore')

# Set seed (numpy's global RNG only).
# NOTE(review): model libraries used in the sweeps (e.g. xgboost/torch) may need
# their own seeds for fully reproducible trials — confirm.
np.random.seed(42)

# Set working directory to the git repo root.
# Only step up one level when we are not already in the repo root, so this cell
# is idempotent: a bare os.chdir("..") would climb one more directory on every
# re-run (and fail immediately if the kernel already started in the repo root).
REPO_NAME = "WattCast"
if os.path.basename(os.getcwd()) != REPO_NAME:
    os.chdir(r"..")
print("Current working directory: " + os.getcwd())
# Compare the full last path component rather than the last 8 characters,
# which would also accept e.g. ".../MyWattCast".
assert os.path.basename(os.getcwd()) == REPO_NAME, "not in the WattCast repo root"
dir_path = os.path.join(os.getcwd(), 'data', 'clean_data')
model_dir = os.path.join(os.getcwd(), 'models')



In [None]:
def train_eval_tuning(base_config=None):
    """Run a single W&B sweep trial: load data, train one model, evaluate and log it.

    Designed to be passed as the ``function`` argument of ``wandb.agent``, which
    calls it without arguments once per trial.

    Args:
        base_config: Mapping of fixed (non-swept) settings merged into the
            trial's W&B config. When omitted, the notebook-global ``config_run``
            (rebuilt by the sweep loop for the current model/location) is used,
            so that global must exist at call time.

    Side effects:
        Opens a W&B run in the "WattCast_tuning" project, logs ``eval_loss`` and
        a Plotly ``predictions`` figure, and finishes the run. The run is used as
        a context manager, so it is finished even if training or evaluation raises.
    """
    if base_config is None:
        # wandb.agent passes no arguments, so fall back to the global built by
        # the sweep loop for the model/location currently being tuned.
        base_config = config_run

    with wandb.init(project="WattCast_tuning") as run:
        # Merge the fixed settings into this trial's config (which presumably
        # already carries the swept hyperparameters supplied by the agent).
        run.config.update(base_config)
        config = run.config

        print("Getting data...")
        (pipeline,
         ts_train_piped, ts_val_piped, ts_test_piped,
         ts_train_weather_piped, ts_val_weather_piped, ts_test_weather_piped,
         trg_train_inversed, trg_val_inversed, trg_test_inversed) = data_pipeline(config)

        print("Getting model instance...")
        model_instance = get_model(config)
        # train_models operates on a list of models; a single one is tuned per trial.
        trained_models, _ = train_models([model_instance], ts_train_piped, ts_train_weather_piped, ts_val_piped, ts_val_weather_piped)

        print("Evaluating model...")
        predictions, score = predict_testset(trained_models[0],
                                             ts_test_piped[config.longest_ts_test_idx],
                                             ts_test_weather_piped[config.longest_ts_test_idx],
                                             config.n_lags, config.n_ahead, config.eval_stride, pipeline,
                                             )

        print("Plotting predictions...")
        # Side-by-side frame of target and prediction; dropna keeps only rows
        # where both columns have values.
        # NOTE(review): unlike ts_test_piped, trg_test_inversed is not indexed by
        # longest_ts_test_idx — confirm it is a single series.
        df_compare = pd.concat([trg_test_inversed.pd_dataframe(), predictions], axis=1).dropna()
        df_compare.columns = ['target', 'prediction']
        fig = px.line(df_compare, title='Predictions vs. Test Set')

        # One log call so the metric and the figure share the same W&B step.
        run.log({'eval_loss': score, 'predictions': fig})



### Running

In [None]:
# We keep the number of runs per sweep constant to ensure a fair playing field for all algorithms.
sweeps = 20  # number of runs (trials) per sweep, i.e. per model and location; passed to wandb.agent(count=...)

project_name = "WattCast_tuning"

scale_location_pairs = (
    # ('1_county', 'Sacramento'),
    # ('1_county', 'New_York'),
    ('1_county', 'Los_Angeles'),
    #('2_town', 'town_0'),
    # ('2_town', 'town_1'),
    # ('2_town', 'town_2'),
    # ('3_village', 'village_1'),
    # ('3_village', 'village_2'),
    #('3_village', 'village_0'),
    # ('4_neighborhood', 'neighborhood_0'),
    # ('4_neighborhood', 'neighborhood_1'),
    # ('4_neighborhood', 'neighborhood_2'),
    # ('5_building', 'building_0'),
    # ('5_building', 'building_1'),
    #('5_building', 'building_2'),
      )



models = [
        #'rf',
        'xgb', 
        #'gru', 
        #'lgbm',  
        #'nbeats',
        #'tft'
        ]


for scale, location in scale_location_pairs:

    for model in models:
        # Fixed (non-swept) settings for this model/location. train_eval_tuning()
        # reads this global and merges it into each trial's wandb.config.
        # NOTE(review): 'liklihood' is misspelled, but the key is presumably read
        # by project code (get_model / data_pipeline) — rename only together with it.
        config_run = {
            'spatial_scale': scale,
            'temp_resolution': 60,
            'location': location,
            'model': model,
            'horizon_in_hours': 24,
            'lookback_in_hours': 24,
            'boxcox': True,
            'liklihood': None,
            'weather': True,
            'holiday': True,
            'datetime_encodings': False,
            'heat_wave_binary': False,
        }

        # Relative path: relies on the working directory being the repo root.
        with open(f'sweep_configurations/config_sweep_{model}.json', 'r') as fp:
            sweep_config = json.load(fp)

        # Same name as before ("<model>sweep<scale>_<location>_<resolution>"),
        # so existing sweeps remain comparable.
        sweep_config['name'] = f"{model}sweep{config_run['spatial_scale']}_{config_run['location']}_{config_run['temp_resolution']}"

        sweep_id = wandb.sweep(sweep_config, project=project_name)
        # Pass the project explicitly so the agent looks up the sweep in the same
        # project it was created in.
        wandb.agent(sweep_id, train_eval_tuning, project=project_name, count=sweeps)


[107]	validation_0-rmse:0.01867
[108]	validation_0-rmse:0.01867
[109]	validation_0-rmse:0.01870
[110]	validation_0-rmse:0.01871
[111]	validation_0-rmse:0.01871
[112]	validation_0-rmse:0.01871
[113]	validation_0-rmse:0.01869
[114]	validation_0-rmse:0.01869
[115]	validation_0-rmse:0.01869
[116]	validation_0-rmse:0.01869
[117]	validation_0-rmse:0.01868
[118]	validation_0-rmse:0.01868
[0]	validation_0-rmse:0.12232
[1]	validation_0-rmse:0.09240
[2]	validation_0-rmse:0.07308
[3]	validation_0-rmse:0.06168
[4]	validation_0-rmse:0.05462
[5]	validation_0-rmse:0.05026
[6]	validation_0-rmse:0.04796
[7]	validation_0-rmse:0.04555
[8]	validation_0-rmse:0.04427
[9]	validation_0-rmse:0.04297
[10]	validation_0-rmse:0.04222
[11]	validation_0-rmse:0.04178
[12]	validation_0-rmse:0.04080
[13]	validation_0-rmse:0.04041
[14]	validation_0-rmse:0.04005
[15]	validation_0-rmse:0.03954
[16]	validation_0-rmse:0.03920
[17]	validation_0-rmse:0.03898
[18]	validation_0-rmse:0.03884
[19]	validation_0-rmse:0.03870
[20]	v


[88]	validation_0-rmse:0.03506
[89]	validation_0-rmse:0.03504
[90]	validation_0-rmse:0.03504
[91]	validation_0-rmse:0.03500
[92]	validation_0-rmse:0.03503
[93]	validation_0-rmse:0.03502
[94]	validation_0-rmse:0.03503
[95]	validation_0-rmse:0.03503
[96]	validation_0-rmse:0.03502
[97]	validation_0-rmse:0.03500
[98]	validation_0-rmse:0.03500
[99]	validation_0-rmse:0.03500
[100]	validation_0-rmse:0.03492
[101]	validation_0-rmse:0.03492
[102]	validation_0-rmse:0.03484
[103]	validation_0-rmse:0.03478
[104]	validation_0-rmse:0.03478
[105]	validation_0-rmse:0.03474
[106]	validation_0-rmse:0.03474
[107]	validation_0-rmse:0.03473
[108]	validation_0-rmse:0.03471
[109]	validation_0-rmse:0.03467
[110]	validation_0-rmse:0.03465
[111]	validation_0-rmse:0.03463
[112]	validation_0-rmse:0.03460
[113]	validation_0-rmse:0.03460
[114]	validation_0-rmse:0.03461
[115]	validation_0-rmse:0.03461
[116]	validation_0-rmse:0.03458
[117]	validation_0-rmse:0.03456
[118]	validation_0-rmse:0.03455
[119]	validation_0-

### Debugging