In [1]:
import sys

# Prepend the parent directory to the module search path so that local
# packages (presumably flood_forecast, imported in the next cell) resolve.
sys.path[:0] = ['..']

In [2]:
import copy
import os
from datetime import datetime
from typing import Dict, List

import pandas as pd
import pycountry
import wandb
from pandas import DataFrame, Series
from wandb.wandb_run import Run

from flood_forecast.time_model import PyTorchForecast
from flood_forecast.trainer import train_function

In [3]:
wind: DataFrame = pd.read_csv('../data/wind.csv')

# Index rows by calendar date (the date part of the 'time' column).
parsed_time: Series = pd.to_datetime(wind['time'])
wind['datetime'] = parsed_time.dt.date
wind = wind.set_index('datetime', drop=True)

# Store 'time' itself as a real datetime column (second resolution).
wind['time'] = wind['time'].astype('datetime64[s]')
wind.head()

Unnamed: 0_level_0,time,AT,BE,BG,CH,CZ,DE,DK,EE,ES,...,LV,NL,NO,PL,PT,RO,SI,SK,SE,UK
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1986-01-01,1986-01-01,0.047786,0.02302,0.04894,0.065907,0.041685,0.031583,0.017365,0.014149,0.079043,...,0.019004,0.014293,0.010351,0.029919,0.076675,0.029107,0.015193,0.054001,0.017463,0.030419
1986-01-02,1986-01-02,0.045921,0.036297,0.067995,0.077502,0.026427,0.023506,0.014981,0.015682,0.119019,...,0.013771,0.020373,0.006469,0.031359,0.1069,0.044379,0.024623,0.034362,0.008086,0.022146
1986-01-03,1986-01-03,0.067308,0.021352,0.101287,0.10368,0.057274,0.046181,0.023478,0.00957,0.106574,...,0.011871,0.010782,0.007217,0.027554,0.160308,0.047235,0.032093,0.023788,0.010004,0.060345
1986-01-04,1986-01-04,0.043833,0.050756,0.039337,0.075418,0.025843,0.025011,0.020003,0.008595,0.13506,...,0.013604,0.030366,0.007998,0.025986,0.208236,0.03751,0.028663,0.018115,0.009546,0.030981
1986-01-05,1986-01-05,0.082394,0.014302,0.033055,0.090867,0.065186,0.028168,0.016261,0.00978,0.095232,...,0.013913,0.012728,0.007241,0.047764,0.115451,0.037254,0.057101,0.072843,0.013872,0.023346


In [4]:
# Map each column's two-letter code to its English country name (ISO 3166-1 alpha-2
# lookup via pycountry). Columns that are not country codes are printed and skipped.
names: Dict[str, str] = {}
for code in wind.columns:
    try:
        country = pycountry.countries.get(alpha_2=code)
    except LookupError:
        # Some pycountry versions raise KeyError (a LookupError) for unknown codes
        # instead of returning None; treat both the same way.
        country = None
    if country is None:
        # Not an ISO alpha-2 code (e.g. the 'time' column); report it so it can be mapped by hand.
        print(code)
    else:
        names[code] = country.name

# 'EL' (Greece) and 'UK' (United Kingdom) are EU/Eurostat codes; the ISO 3166-1
# alpha-2 codes are 'GR' and 'GB', so pycountry cannot resolve them.
# NOTE(review): make_config_file hard-codes the resulting names (e.g. 'Czechia',
# 'Netherlands'); pycountry's official names may differ between versions — confirm they match.
names['EL'] = 'Greece'
names['UK'] = 'United Kingdom'

time
EL
UK


In [5]:
# Replace the two-letter codes with country names; columns without a mapping
# (e.g. 'time') are left unchanged.
wind = wind.rename(columns=names)

# Calendar features from the timestamp column, computed with the vectorised .dt
# accessors (parsed once) instead of a per-row Python lambda.
# month: 1-12; weekday: Monday=0 .. Sunday=6.
timestamps: Series = pd.to_datetime(wind['time'])
wind['month'] = timestamps.dt.month
wind['weekday'] = timestamps.dt.weekday

In [6]:
# Export the feature frame for flood_forecast; the date index is written out as a
# leading 'datetime' column.
train_csv_path: str = '../data/wind_train.csv'
wind.to_csv(train_csv_path, index_label='datetime', index=True)

In [7]:
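# Build the flood_forecast config (model, dataset, training and inference params).
# Note: W&B sweeps and logging are switched off inside it ("sweep": False, "wandb": False).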

def make_config_file(file_path: str, df_len: int,
                     forecast_history: int = 90, forecast_length: int = 30) -> Dict:
    """Build the flood_forecast config for a DecoderTransformer on the wind CSV.

    The same CSV is used for training, validation, test and inference. Rows are
    split by position: ``train_end`` is ``int(0.7 * df_len)``, ``valid_start`` is
    ``int(0.7 * df_len + 1)`` and ``valid_end`` is ``int(0.9 * df_len)``. Whether
    those bounds are inclusive or exclusive is up to flood_forecast's loader.

    Args:
        file_path: Path of the CSV used by every dataset section of the config.
        df_len: Number of data rows in that CSV; drives the split points.
        forecast_history: Length of the history window. It is written to the model,
            dataset and inference-dataset sections so they always agree.
        forecast_length: Length of the forecast window, written to the same three
            sections.

    Returns:
        The nested config dict passed to ``train_function``.
    """
    # Single source of truth for the columns, shared (as separate copies) by the
    # training dataset and the inference dataset so the two cannot drift apart.
    target_col: List[str] = ['Austria']
    relevant_cols: List[str] = ['Austria', 'Belgium', 'Bulgaria', 'Switzerland', 'Czechia',
                                'Germany', 'Denmark', 'Estonia', 'Spain', 'Finland', 'France', 'Greece',
                                'Croatia', 'Hungary', 'Ireland', 'Italy', 'Lithuania', 'Luxembourg',
                                'Latvia', 'Netherlands', 'Norway', 'Poland', 'Portugal', 'Romania',
                                'Slovenia', 'Slovakia', 'Sweden', 'United Kingdom', 'month', 'weekday']

    # Kept as floats until the int() casts below so the split points are exactly
    # the same as int(df_len * .7) etc.
    train_number: float = df_len * .7
    validation_number: float = df_len * .9

    config_default = {
        "model_name": "DecoderTransformer",
        "model_type": "PyTorch",
        "takes_target": False,
        "model_params": {
            # Derived from relevant_cols so it stays in sync with the input columns
            # (the previously hard-coded value, 30, equals len(relevant_cols)).
            "n_time_series": len(relevant_cols),
            "n_head": 8,
            "forecast_history": forecast_history,
            "n_embd": 1,
            "num_layer": 5,
            "dropout": 0.1,
            "q_len": 1,
            "scale_att": False,
            "forecast_length": forecast_length,
            "additional_params": {}
        },
        "dataset_params": {
            "class": "default",
            "training_path": file_path,
            "validation_path": file_path,
            "test_path": file_path,
            "batch_size": 64,
            "forecast_history": forecast_history,
            "forecast_length": forecast_length,
            "train_end": int(train_number),
            "valid_start": int(train_number + 1),
            "valid_end": int(validation_number),
            "target_col": list(target_col),
            "relevant_cols": list(relevant_cols),
            "scaler": "StandardScaler",
            "interpolate": False,
            "sort_column": "time",
        },
        "training_params": {
            "criterion": "DilateLoss",
            "optimizer": "Adam",
            "optim_params": {},
            "lr": 0.001,
            "epochs": 10,
            "batch_size": 64
        },
        "early_stopping": {
            "patience": 3
        },
        "GCS": False,
        "sweep": False,
        "wandb": False,
        "forward_params": {},
        "metrics": ["DilateLoss"],
        "inference_params": {
            "datetime_start": "2010-01-01",
            "hours_to_forecast": 2000,
            "test_csv_path": file_path,
            "decoder_params": {
                "decoder_function": "simple_decode",
                "unsqueeze_dim": 1
            },
            "dataset_params": {
                "file_path": file_path,
                "forecast_history": forecast_history,
                "forecast_length": forecast_length,
                "relevant_cols": list(relevant_cols),
                "target_col": list(target_col),
                "scaling": "StandardScaler",
                "interpolate_param": False
            }
        },
    }

    return config_default

In [8]:
file_path: str = '../data/wind_train.csv'
# Number of data rows in the exported CSV (header excluded); it drives the
# train/validation split points computed by make_config_file.
full_len: int = pd.read_csv(file_path).shape[0]

In [9]:
# Assemble the flood_forecast config from the exported CSV path and its row count.
conf_file: Dict = make_config_file(file_path=file_path, df_len=full_len)

In [10]:
# W&B project that receives this notebook's run.
# NOTE(review): the training config sets "wandb": False, and the training output
# prints an empty W&B config, so flood_forecast may not log into this run — confirm.
WANDB_PROJECT: str = "pretrained-wind-updated"
run = wandb.init(project=WANDB_PROJECT)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mloloheia[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Train with the config assembled in `conf_file` instead of rebuilding it. Pass a
# deep copy in case train_function mutates its params, so re-running this cell
# always starts from the same config.
trained_model: PyTorchForecast = train_function("PyTorch", copy.deepcopy(conf_file))

interpolate should be below
[]
Now loading ../data/wind_train.csv
scaling now
interpolate should be below
[]
Now loading ../data/wind_train.csv
scaling now
interpolate should be below
[]
Now loading ../data/wind_train.csv
scaling now
Using Wandb config:
{}
Torch is using cuda
running torch_single_train


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
Using non-full backward hooks on a Module that does not take as input a single Tensor or a tuple of Tensors is deprecated and will be removed in future versions. This hook will be missing some of the grad_input. Please use register_full_backward_hook to get the documented behavior.
Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach()

The running loss is: 
746.5062863826752
The number of items in train is: 118
The loss for epoch 0
6.326324460870128
Computing validation loss
running torch_single_train
The running loss is: 
557.404262304306
The number of items in train is: 118
The loss for epoch 1
4.723764934782254
Computing validation loss
running torch_single_train
The running loss is: 
562.8397541046143
The number of items in train is: 118
The loss for epoch 2
4.769828424615375
Computing validation loss
running torch_single_train
The running loss is: 
539.7109055519104
The number of items in train is: 118
The loss for epoch 3
4.573821233490766
Computing validation loss
running torch_single_train
The running loss is: 
544.1412591934204
The number of items in train is: 118
The loss for epoch 4
4.611366603334071
Computing validation loss
1
running torch_single_train
The running loss is: 
532.4913876056671
The number of items in train is: 118
The loss for epoch 5
4.512638878014128
Computing validation loss
running torc