In [None]:
import os
import plotly.express as px
import pandas as pd
import numpy as np
import json

from utils import *
from train import *
from evaluation import *
import wandb
wandb.login()

import warnings
# NOTE(review): this silences *all* warnings (including deprecation/convergence ones).
warnings.filterwarnings('ignore')

# Set seed. Only numpy's global RNG is seeded here; other libraries the models
# may use (e.g. xgboost / torch) are not seeded by this cell.
np.random.seed(42)

# Set working directory to the git repo root.
# The notebook is expected to live one directory below the root. The guard makes
# re-running this cell a no-op instead of walking further up the tree each time.
REPO_ROOT_NAME = "WattCast"
if os.path.basename(os.getcwd()) != REPO_ROOT_NAME:
    os.chdir(r"..")
print("Current working directory: " + os.getcwd())
# Compare the final path component exactly (a plain suffix check would also
# accept directories such as ".../MyWattCast").
assert os.path.basename(os.getcwd()) == REPO_ROOT_NAME, (
    f"Expected the '{REPO_ROOT_NAME}' repo root as working directory, got {os.getcwd()}"
)
dir_path = os.path.join(os.getcwd(), 'data', 'clean_data')
model_dir = os.path.join(os.getcwd(), 'models')



In [None]:
def train_eval_tuning(run_config=None):
    """Run one W&B sweep trial: build the data, train a model, evaluate and log it.

    Intended as the ``function`` argument of ``wandb.agent``, which calls it
    with no arguments once per trial.

    Args:
        run_config: Base configuration dict merged into the W&B run config.
            Defaults to the notebook-global ``config_run`` so existing
            ``wandb.agent(sweep_id, train_eval_tuning)`` calls keep working.

    Logs to the active W&B run:
        ``eval_loss``: the score returned by ``predict_testset``.
        ``predictions``: a plotly line chart of target vs. prediction.
    """
    if run_config is None:
        # NOTE(review): hidden dependency on the notebook-global `config_run`,
        # which is assigned in the sweep loop cell before each wandb.agent call.
        run_config = config_run

    # Using the run as a context manager guarantees it is finished even if a
    # step below raises, so a failed trial does not leave a dangling run.
    with wandb.init(project="WattCast_tuning") as run:
        run.config.update(run_config)
        config = run.config

        print("Getting data...")
        (
            pipeline,
            ts_train_piped, ts_val_piped, ts_test_piped,
            ts_train_weather_piped, ts_val_weather_piped, ts_test_weather_piped,
            _trg_train_inversed, _trg_val_inversed, trg_test_inversed,
        ) = data_pipeline(config)

        print("Getting model instance...")
        model_instance = get_model(config)
        trained_models, _ = train_models(
            [model_instance],
            ts_train_piped, ts_train_weather_piped,
            ts_val_piped, ts_val_weather_piped,
        )

        print("Evaluating model...")
        # Evaluate on the test series selected by `config.longest_ts_test_idx`.
        test_idx = config.longest_ts_test_idx
        predictions, score = predict_testset(
            trained_models[0],
            ts_test_piped[test_idx],
            ts_test_weather_piped[test_idx],
            config.n_lags, config.n_ahead, config.eval_stride, pipeline,
        )

        print("Plotting predictions...")
        # Align target and predictions on their index and drop rows missing either
        # side. The rename below assumes exactly one column from each frame.
        df_compare = pd.concat([trg_test_inversed.pd_dataframe(), predictions], axis=1).dropna()
        df_compare.columns = ['target', 'prediction']
        fig = px.line(df_compare, title='Predictions vs. Test Set')

        # Log metric and figure in a single call so they share the same step.
        run.log({'eval_loss': score, 'predictions': fig})



### Running

In [3]:
# --- Run parameters ---------------------------------------------------------

# Number of trials each W&B agent runs per sweep (used as `count` in wandb.agent).
sweeps = 20

# (spatial_scale, location) pairs to tune. Uncomment a line to include it.
scale_location_pairs = (
    ('1_county', 'Los_Angeles'),
    # ('1_county', 'Sacramento'),
    # ('1_county', 'New_York'),
    # ('2_town', 'town_0'),
    # ('2_town', 'town_1'),
    # ('2_town', 'town_2'),
    # ('3_village', 'village_0'),
    # ('3_village', 'village_1'),
    # ('3_village', 'village_2'),
    # ('4_neighborhood', 'neighborhood_0'),
    # ('4_neighborhood', 'neighborhood_1'),
    # ('4_neighborhood', 'neighborhood_2'),
    # ('5_building', 'building_0'),
    # ('5_building', 'building_1'),
    # ('5_building', 'building_2'),
)

# Model identifiers to tune; each needs a matching
# sweep_configurations/config_sweep_<model>.json file.
models = [
    'xgb',
    # 'rf',
    # 'gru',
    # 'lgbm',
    # 'nbeats',
    # 'tft',
]


for scale, location in scale_location_pairs:

    for model in models:
        # Base run config for this (scale, location, model) combination.
        # train_eval_tuning() reads this global and merges it into each trial's wandb config.
        config_run = {
            'spatial_scale': scale,
            'temp_resolution': 60,
            'location': location,
            'model': model,
            'horizon_in_hours': 24,
            'lookback_in_hours': 24,
            'boxcox': True,
            # NOTE(review): key spelling kept as-is; downstream code may read 'liklihood'.
            'liklihood': None,
            'weather': True,
            'holiday': True,
            'datetime_encodings': False,
        }

        # Load the per-model W&B sweep definition (path relative to the repo root).
        sweep_config_path = os.path.join('sweep_configurations', f'config_sweep_{model}.json')
        with open(sweep_config_path, 'r') as fp:
            sweep_config = json.load(fp)

        # Consistent '_' separators, e.g. "xgb_sweep_1_county_Los_Angeles_60".
        sweep_config['name'] = (
            f"{model}_sweep_{config_run['spatial_scale']}_"
            f"{config_run['location']}_{config_run['temp_resolution']}"
        )

        sweep_id = wandb.sweep(sweep_config, project="WattCast_tuning")
        # Run `sweeps` trials of train_eval_tuning for this sweep.
        wandb.agent(sweep_id, train_eval_tuning, count=sweeps)


[48]	validation_0-rmse:0.18002
[49]	validation_0-rmse:0.17983
[50]	validation_0-rmse:0.17986
[51]	validation_0-rmse:0.17973
[52]	validation_0-rmse:0.17985
[53]	validation_0-rmse:0.18016
[54]	validation_0-rmse:0.18031
[55]	validation_0-rmse:0.18039
[56]	validation_0-rmse:0.18011
[57]	validation_0-rmse:0.17965
[58]	validation_0-rmse:0.17973
[59]	validation_0-rmse:0.17959
[60]	validation_0-rmse:0.18138
[61]	validation_0-rmse:0.18151
[62]	validation_0-rmse:0.18156
[63]	validation_0-rmse:0.18141
[64]	validation_0-rmse:0.18145
[65]	validation_0-rmse:0.18132
[66]	validation_0-rmse:0.18158
[67]	validation_0-rmse:0.18095
[68]	validation_0-rmse:0.18082
[69]	validation_0-rmse:0.17987
[70]	validation_0-rmse:0.17989
[71]	validation_0-rmse:0.18142
[72]	validation_0-rmse:0.18136
[73]	validation_0-rmse:0.18137
[74]	validation_0-rmse:0.18150
[75]	validation_0-rmse:0.18161
[76]	validation_0-rmse:0.18160
[77]	validation_0-rmse:0.18174
[78]	validation_0-rmse:0.18179
[79]	validation_0-rmse:0.18111
[80]	val

2017-08-16 23:00:00 2017-08-17 00:00:00
2017-08-17 23:00:00 2017-08-18 00:00:00
2017-08-18 23:00:00 2017-08-19 00:00:00
2017-08-19 23:00:00 2017-08-20 00:00:00
2017-08-20 23:00:00 2017-08-21 00:00:00
2017-08-21 23:00:00 2017-08-22 00:00:00
2017-08-22 23:00:00 2017-08-23 00:00:00
2017-08-23 23:00:00 2017-08-24 00:00:00
2017-08-24 23:00:00 2017-08-25 00:00:00
2017-08-25 23:00:00 2017-08-26 00:00:00
2017-08-26 23:00:00 2017-08-27 00:00:00
2017-08-27 23:00:00 2017-08-28 00:00:00
2017-08-28 23:00:00 2017-08-29 00:00:00
2017-08-29 23:00:00 2017-08-30 00:00:00
2017-08-30 23:00:00 2017-08-31 00:00:00
2017-08-31 23:00:00 2017-09-01 00:00:00
2017-09-01 23:00:00 2017-09-02 00:00:00
2017-09-02 23:00:00 2017-09-03 00:00:00
2017-09-03 23:00:00 2017-09-04 00:00:00
2017-09-04 23:00:00 2017-09-05 00:00:00
2017-09-05 23:00:00 2017-09-06 00:00:00
2017-09-06 23:00:00 2017-09-07 00:00:00
2017-09-07 23:00:00 2017-09-08 00:00:00
2017-09-08 23:00:00 2017-09-09 00:00:00
2017-09-09 23:00:00 2017-09-10 00:00:00


VBox(children=(Label(value='0.066 MB of 0.066 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

### Debugging