# Pretraining on Wind and Solar Data for Transfer Learning

## Pretraining on Wind Data

In [0]:
# One-time Colab environment setup: authenticate, fetch and install the
# flow-forecast repository, and log in to Weights & Biases.
import os
import pandas as pd
from google.colab import auth
from datetime import datetime
# Authenticate this Colab session (presumably required by the gcloud clone below).
auth.authenticate_user()
# Clone the flow-forecast repository from Cloud Source Repositories.
!gcloud source repos clone github_aistream-peelout_flow-forecast --project=gmap-997
# Work from inside the clone so the relative paths used below resolve against it.
os.chdir('/content/github_aistream-peelout_flow-forecast')
!git checkout -t origin/covid_fixes
# Install the checked-out package in development mode, then its requirements.
!python setup.py develop
!pip install -r requirements.txt
!mkdir data
# Imported only after the install steps; flood_forecast is presumably provided
# by the package installed from the clone above.
from flood_forecast.trainer import train_function
!pip install git+https://github.com/CoronaWhy/task-geo.git
# Log in to Weights & Biases; wandb is used by the sweep cells further down.
!wandb login

In [0]:
# Download the wind and solar CSV datasets from Google Cloud Storage
!wget https://storage.googleapis.com/coronaviruspublicdata/forecast_2/wind.csv
!wget https://storage.googleapis.com/coronaviruspublicdata/forecast_2/solar.csv

In [0]:
import pandas as pd

# Load the raw wind dataset: one row per timestamp, one column per country code.
wind = pd.read_csv('wind.csv')
# Add a date-only version of the timestamp column.
wind['datetime'] = pd.to_datetime(wind['time']).dt.date
# The default integer index is kept on purpose. `set_index` returns a new frame,
# so calling it without assigning the result (as this cell used to do) has no
# effect on `wind`; the dead call was dropped.
wind.head()

Unnamed: 0,time,AT,BE,BG,CH,CZ,DE,DK,EE,ES,FI,FR,EL,HR,HU,IE,IT,LT,LU,LV,NL,NO,PL,PT,RO,SI,SK,SE,UK,datetime
0,1986-01-01 00:00:00,0.047786,0.02302,0.04894,0.065907,0.041685,0.031583,0.017365,0.014149,0.079043,0.0058,0.049822,0.051933,0.030507,0.025005,0.011889,0.04625,0.03588,0.014839,0.019004,0.014293,0.010351,0.029919,0.076675,0.029107,0.015193,0.054001,0.017463,0.030419,1986-01-01
1,1986-01-02 00:00:00,0.045921,0.036297,0.067995,0.077502,0.026427,0.023506,0.014981,0.015682,0.119019,0.007176,0.06309,0.115133,0.035716,0.039431,0.016575,0.051848,0.016988,0.01951,0.013771,0.020373,0.006469,0.031359,0.1069,0.044379,0.024623,0.034362,0.008086,0.022146,1986-01-02
2,1986-01-03 00:00:00,0.067308,0.021352,0.101287,0.10368,0.057274,0.046181,0.023478,0.00957,0.106574,0.004687,0.048678,0.123855,0.051901,0.029249,0.054734,0.062773,0.01735,0.056217,0.011871,0.010782,0.007217,0.027554,0.160308,0.047235,0.032093,0.023788,0.010004,0.060345,1986-01-03
3,1986-01-04 00:00:00,0.043833,0.050756,0.039337,0.075418,0.025843,0.025011,0.020003,0.008595,0.13506,0.004102,0.092991,0.089767,0.055547,0.04434,0.016779,0.055305,0.019638,0.055925,0.013604,0.030366,0.007998,0.025986,0.208236,0.03751,0.028663,0.018115,0.009546,0.030981,1986-01-04
4,1986-01-05 00:00:00,0.082394,0.014302,0.033055,0.090867,0.065186,0.028168,0.016261,0.00978,0.095232,0.005172,0.045049,0.074312,0.081576,0.082401,0.038972,0.102499,0.020079,0.018873,0.013913,0.012728,0.007241,0.047764,0.115451,0.037254,0.057101,0.072843,0.013872,0.023346,1986-01-05


In [0]:
# Getting the real countries' name from 2 letter code

!pip install pycountry
import pycountry
names = {}
for code in wind.columns:
    try:
        names[code] = pycountry.countries.get(alpha_2=code).name
    except:
        print(code)

# For some reason, these two were not present
names['EL'] = 'Greece'
names['UK'] = 'United Kingdom'

Collecting pycountry
[?25l  Downloading https://files.pythonhosted.org/packages/16/b6/154fe93072051d8ce7bf197690957b6d0ac9a21d51c9a1d05bd7c6fdb16f/pycountry-19.8.18.tar.gz (10.0MB)
[K     |████████████████████████████████| 10.0MB 2.7MB/s 
[?25hBuilding wheels for collected packages: pycountry
  Building wheel for pycountry (setup.py) ... [?25l[?25hdone
  Created wheel for pycountry: filename=pycountry-19.8.18-py2.py3-none-any.whl size=10627361 sha256=cc9554e8cf6849198cb62e2d68fa6b20b9df33d86087e0184e32b029a6370e70
  Stored in directory: /root/.cache/pip/wheels/a2/98/bf/f0fa1c6bf8cf2cbdb750d583f84be51c2cd8272460b8b36bd3
Successfully built pycountry
Installing collected packages: pycountry
Successfully installed pycountry-19.8.18
time
EL
UK
datetime


In [0]:
# Replace country-code column names with full country names; columns without an
# entry in `names` keep their current name.
wind = wind.rename(columns=names)
# Parse the timestamp column once and derive all calendar features from it.
timestamps = pd.to_datetime(wind['time'])
wind['year'] = timestamps.dt.year
wind['month'] = timestamps.dt.month
wind['weekday'] = timestamps.dt.weekday  # Monday=0 ... Sunday=6

In [0]:
# Making seperate dataframes for each country's data and saving in seperate CSV files
!mkdir wind
country_wise = {}
for country in names.values():
    country_wise[country] = wind[['datetime', 'year', 'month', 'weekday', country]]
    country_wise[country].to_csv('wind/'+country+'.csv')
country_wise[list(country_wise.keys())[5]].head()

Unnamed: 0,datetime,year,month,weekday,Germany
0,1986-01-01,1986,1,2,0.031583
1,1986-01-02,1986,1,3,0.023506
2,1986-01-03,1986,1,4,0.046181
3,1986-01-04,1986,1,5,0.025011
4,1986-01-05,1986,1,6,0.028168


In [0]:
# Show the last rows of the wind dataframe
wind.tail()

Unnamed: 0,time,Austria,Belgium,Bulgaria,Switzerland,Czechia,Germany,Denmark,Estonia,Spain,Finland,France,Greece,Croatia,Hungary,Ireland,Italy,Lithuania,Luxembourg,Latvia,Netherlands,Norway,Poland,Portugal,Romania,Slovenia,Slovakia,Sweden,United Kingdom,datetime,year,month,weekday
10952,2015-12-27 00:00:00,0.119231,0.077324,0.153964,0.134614,0.081821,0.055697,0.0151,0.038722,0.127358,0.01306,0.102966,0.163328,0.108146,0.044005,0.017154,0.112251,0.014419,0.092283,0.030101,0.01993,0.004189,0.034805,0.170225,0.136308,0.127401,0.047375,0.008486,0.028426,2015-12-27,2015,12,6
10953,2015-12-28 00:00:00,0.128963,0.087688,0.151576,0.125772,0.054648,0.063808,0.024135,0.040206,0.103774,0.011533,0.089861,0.16448,0.096811,0.063488,0.010358,0.121303,0.010499,0.091828,0.036198,0.067833,0.00463,0.023683,0.059809,0.118209,0.12972,0.048307,0.008479,0.020515,2015-12-28,2015,12,0
10954,2015-12-29 00:00:00,0.094872,0.028177,0.070883,0.121686,0.056463,0.054009,0.006732,0.029413,0.126492,0.008377,0.077485,0.170938,0.071934,0.074617,0.011983,0.119,0.017571,0.031521,0.012752,0.042996,0.004232,0.039898,0.173905,0.033584,0.064835,0.057409,0.011607,0.05522,2015-12-29,2015,12,1
10955,2015-12-30 00:00:00,0.086713,0.05481,0.100528,0.114044,0.097525,0.041072,0.006402,0.01474,0.123499,0.00381,0.091785,0.064106,0.076847,0.132149,0.027625,0.108127,0.027267,0.014551,0.031294,0.057475,0.003595,0.067484,0.065125,0.106734,0.026071,0.134139,0.009297,0.011749,2015-12-30,2015,12,2
10956,2015-12-31 00:00:00,0.111772,0.095547,0.115314,0.031523,0.122601,0.029161,0.012323,0.01921,0.092276,0.004629,0.067331,0.077266,0.108809,0.131938,0.022887,0.081279,0.02249,0.078602,0.018948,0.071013,0.003991,0.091865,0.087822,0.132066,0.114436,0.131965,0.015428,0.045504,2015-12-31,2015,12,3


In [0]:
# Config file for WanDB sweeps

def make_config_file(file_path, df_len):
  """Build the training configuration for one wandb sweep run.

  Starts a wandb run in the "pretrain-wind" project, reads the swept
  hyper-parameters (``forecast_history``, ``out_seq_length``, ``batch_size``
  and ``lr``) from ``wandb.config`` and fills them into the config dict.

  Args:
    file_path: Path to a per-country CSV such as ``"wind/Germany.csv"``. The
      file name without its extension is used as the target column name, and
      the same file is used as training, validation and test path.
    df_len: Number of rows in that CSV. Used to derive the split indices:
      ``train_end`` at 70% and ``valid_end`` at 90% of the rows.

  Returns:
    The configuration dictionary; it is also pushed into ``wandb.config``.
  """
  # Start the run first so that wandb.config holds this run's sweep parameters.
  # The returned run handle is not needed here.
  wandb.init(project="pretrain-wind")
  wandb_config = wandb.config
  train_number = df_len * .7
  validation_number = df_len * .9
  # Derive the column name from the file name. Unlike splitting on '.' and '/',
  # this works for any directory depth and for paths such as "./wind/Spain.csv".
  target = os.path.splitext(os.path.basename(file_path))[0]
  relevant_cols = [target, "month", "weekday", "year"]
  config_default = {
    "model_name": "MultiAttnHeadSimple",
    "model_type": "PyTorch",
    "model_params": {
      # presumably must match the number of relevant_cols below -- TODO confirm
      "number_time_series": 4,
      "seq_len": wandb_config["forecast_history"],
      "output_seq_len": wandb_config["out_seq_length"],
      "forecast_length": wandb_config["out_seq_length"]
    },
    "dataset_params": {
      "class": "default",
      "training_path": file_path,
      "validation_path": file_path,
      "test_path": file_path,
      "batch_size": wandb_config["batch_size"],
      "forecast_history": wandb_config["forecast_history"],
      "forecast_length": wandb_config["out_seq_length"],
      "train_end": int(train_number),
      "valid_start": int(train_number+1),
      "valid_end": int(validation_number),
      "target_col": [target],
      # Fresh list per section so the two sections never share a mutable object.
      "relevant_cols": list(relevant_cols),
      "scaler": "StandardScaler",
      "interpolate": False
    },
    "training_params": {
      "criterion": "MSE",
      "optimizer": "Adam",
      "optim_params": {},
      "lr": wandb_config["lr"],
      "epochs": 10,
      "batch_size": wandb_config["batch_size"]
    },
    "GCS": False,
    "sweep": True,
    "wandb": False,
    "forward_params": {},
    "metrics": ["MSE"],
    "inference_params": {
      "datetime_start": "2010-01-01",
      "hours_to_forecast": 2000,
      "test_csv_path": file_path,
      "decoder_params": {
        "decoder_function": "simple_decode",
        "unsqueeze_dim": 1
      },
      "dataset_params": {
        "file_path": file_path,
        "forecast_history": wandb_config["forecast_history"],
        "forecast_length": wandb_config["out_seq_length"],
        "relevant_cols": list(relevant_cols),
        "target_col": [target],
        "scaling": "StandardScaler",
        "interpolate_param": False
      }
    }
  }
  wandb.config.update(config_default)
  return config_default

# Grid-search definition handed to wandb.sweep: every combination of the
# candidate values listed for each hyper-parameter is tried.
sweep_config = dict(
  name="Default sweep",
  method="grid",
  parameters={
    hyperparam: {"values": candidates}
    for hyperparam, candidates in (
      ("batch_size", [2]),
      ("lr", [0.001]),
      ("forecast_history", [1, 2]),
      ("out_seq_length", [1, 2]),
    )
  },
)

Run sweep

In [0]:
import os
import pandas as pd

# List the per-country CSV files in the wind/ directory; the sweep cell below
# iterates over the entries of this directory.
os.listdir('wind')

['Lithuania.csv',
 'Bulgaria.csv',
 'Ireland.csv',
 'Netherlands.csv',
 'Slovenia.csv',
 'Croatia.csv',
 'Spain.csv',
 'Germany.csv',
 'Romania.csv',
 'Greece.csv',
 'Hungary.csv',
 'Portugal.csv',
 'Belgium.csv',
 'Austria.csv',
 'Norway.csv',
 'Poland.csv',
 'United Kingdom.csv',
 'Latvia.csv',
 'France.csv',
 'Denmark.csv',
 'Switzerland.csv',
 'Finland.csv',
 'Estonia.csv',
 'Italy.csv',
 'Luxembourg.csv',
 'Sweden.csv',
 'Slovakia.csv',
 'Czechia.csv']

In [0]:
import wandb
for country in os.listdir('wind'):
    file_path = 'wind/'+country
    full_len = len(pd.read_csv(file_path))
    sweep_id = wandb.sweep(sweep_config, project="pretrain-wind")
    wandb.agent(sweep_id, lambda:train_function("PyTorch", make_config_file(file_path, full_len)))
    !gsutil cp -n -r model_save gs://coronaviruspublicdata/pretrained

Believe me, the above cell executed. It just took... 13 hours to do so.