In [None]:
import sys
# Prepend the parent directory ('..') to the module search path so that the imports in the
# following cells also look one level above the notebook's working directory.
# NOTE(review): presumably this is where the project packages live — confirm.
sys.path.insert(0, '..')

In [None]:
from datetime import datetime
from typing import Dict, List
import os
import pandas as pd
from pandas import DataFrame, Series
import wandb
import pycountry
from flood_forecast.time_model import PyTorchForecast
from flood_forecast.trainer import train_function
from wandb.wandb_run import Run

In [None]:
# Load the raw wind table and prepare it in a single chain:
#   1. derive a calendar-date column from the 'time' column,
#   2. make that date the index (named 'datetime'; the helper column is dropped),
#   3. convert 'time' itself to a second-resolution datetime dtype.
wind: DataFrame = (
    pd.read_csv('../data/wind.csv')
    .assign(datetime=lambda frame: pd.to_datetime(frame['time']).dt.date)
    .set_index('datetime')
    .assign(time=lambda frame: frame['time'].astype('datetime64[s]'))
)
wind.head()

In [None]:
# Map every column label that is a valid ISO 3166-1 alpha-2 code to its country name.
# Labels that cannot be resolved (e.g. the non-country 'time' column) are printed so they
# can be reviewed, and are left out of the mapping.
names: Dict[str, str] = {}
for code in wind.columns:
    try:
        country = pycountry.countries.get(alpha_2=code)
    except LookupError:
        # Some pycountry versions raise (KeyError) instead of returning None on a miss.
        country = None
    if country is None:
        print(code)
        continue
    names[code] = country.name

# 'EL' and 'UK' are not ISO 3166-1 alpha-2 codes (ISO assigns 'GR' and 'GB'), so pycountry
# cannot resolve them; add them by hand.
names['EL'] = 'Greece'
names['UK'] = 'United Kingdom'

In [None]:
# Replace the country-code column labels with full country names and add two calendar
# features derived from the 'time' column. The vectorized `.dt` accessor is used instead of
# a per-row Python lambda, and a new frame is produced instead of mutating in place so the
# cell can be re-run safely.
# The 'year' feature stays disabled, as in the original cell:
# wind['year'] = pd.to_datetime(wind['time']).dt.year
wind = (
    wind
    .rename(columns=names)
    .assign(
        month=lambda frame: pd.to_datetime(frame['time']).dt.month,
        # Monday is 0 and Sunday is 6, the same convention as datetime.weekday().
        weekday=lambda frame: pd.to_datetime(frame['time']).dt.weekday,
    )
)

In [None]:
# Write the prepared table to disk; the index is stored as a column labelled 'datetime'.
wind.to_csv(
    '../data/wind_train.csv',
    index_label='datetime',
    index=True,
)

In [None]:
# Config file for W&B (wandb) sweeps

def make_config_file(file_path: str, df_len: int) -> Dict:
    """Build the configuration dictionary passed to ``train_function``.

    Args:
        file_path: Path of the CSV file. The same path is used for the training,
            validation and test paths as well as for the inference CSV / dataset path.
        df_len: Number of rows in that CSV; used to place the split boundaries.

    Returns:
        A new nested dictionary with the model, dataset, training, early-stopping and
        inference settings. The split boundaries are computed as
        ``train_end = int(df_len * .7)``, ``valid_start = int(df_len * .7 + 1)`` and
        ``valid_end = int(df_len * .9)``.
    """
    # Values that appear in several sections are defined once.
    history_steps = 90
    horizon_steps = 30
    batch_size = 64
    scaler_name = "StandardScaler"
    target_columns = ['Austria']
    feature_columns = [
        'Austria', 'Belgium', 'Bulgaria', 'Switzerland', 'Czechia',
        'Germany', 'Denmark', 'Estonia', 'Spain', 'Finland', 'France', 'Greece',
        'Croatia', 'Hungary', 'Ireland', 'Italy', 'Lithuania', 'Luxembourg',
        'Latvia', 'Netherlands', 'Norway', 'Poland', 'Portugal', 'Romania',
        'Slovenia', 'Slovakia', 'Sweden', 'United Kingdom', 'month', 'weekday',
    ]

    # Row positions that delimit the training and validation ranges.
    train_boundary: float = df_len * .7
    validation_boundary: float = df_len * .9

    model_section = {
        # NOTE(review): 30 equals len(feature_columns); presumably the two must stay in
        # sync — confirm against the model's expectations.
        "n_time_series": 30,
        "n_head": 8,
        "forecast_history": history_steps,
        "n_embd": 1,
        "num_layer": 5,
        "dropout": 0.1,
        "q_len": 1,
        "scale_att": False,
        "forecast_length": horizon_steps,
        "additional_params": {},
    }

    dataset_section = {
        "class": "default",
        "training_path": file_path,
        "validation_path": file_path,
        "test_path": file_path,
        "batch_size": batch_size,
        "forecast_history": history_steps,
        "forecast_length": horizon_steps,
        "train_end": int(train_boundary),
        "valid_start": int(train_boundary + 1),
        "valid_end": int(validation_boundary),
        # Copies keep every section's list an independent object.
        "target_col": list(target_columns),
        "relevant_cols": list(feature_columns),
        "scaler": scaler_name,
        "interpolate": False,
        "sort_column": "time",
    }

    training_section = {
        "criterion": "DilateLoss",
        "optimizer": "Adam",
        "optim_params": {},
        "lr": 0.001,
        "epochs": 2,
        "batch_size": batch_size,
    }

    inference_section = {
        "datetime_start": "2010-01-01",
        "hours_to_forecast": 2000,
        "test_csv_path": file_path,
        "decoder_params": {
            "decoder_function": "simple_decode",
            "unsqueeze_dim": 1,
        },
        "dataset_params": {
            "file_path": file_path,
            "forecast_history": history_steps,
            "forecast_length": horizon_steps,
            "relevant_cols": list(feature_columns),
            "target_col": list(target_columns),
            "scaling": scaler_name,
            "interpolate_param": False,
        },
    }

    return {
        "model_name": "DecoderTransformer",
        "model_type": "PyTorch",
        "takes_target": False,
        "model_params": model_section,
        "dataset_params": dataset_section,
        "training_params": training_section,
        "early_stopping": {"patience": 3},
        "GCS": False,
        "sweep": False,
        "wandb": False,
        "forward_params": {},
        "metrics": ["RMSE", "MAPE"],
        "inference_params": inference_section,
    }

In [None]:
# Location of the prepared training CSV and its row count (used to size the data splits).
file_path: str = '../data/wind_train.csv'
full_len: int = pd.read_csv(file_path).shape[0]

In [None]:
# Build the configuration dictionary for the prepared CSV so it can be inspected.
conf_file: Dict = make_config_file(
    file_path=file_path,
    df_len=full_len,
)

In [None]:
# Start a Weights & Biases run in the "pretrained-wind-updated" project.
run: Run = wandb.init(
    project="pretrained-wind-updated",
)

In [None]:
# Train the model. A fresh configuration dictionary is built for this call rather than
# reusing `conf_file`.
trained_model: PyTorchForecast = train_function(
    "PyTorch",
    make_config_file(file_path=file_path, df_len=full_len),
)