# Pretraining on Wind and Solar Data for Transfer Learning

In [0]:
# Colab environment setup: authenticate to Google Cloud, clone flow-forecast from
# Cloud Source Repositories, install it plus its requirements, and log in to W&B.
# Order matters: the repo must be cloned/checked out/installed before
# `flood_forecast` can be imported.
import os
import pandas as pd
from google.colab import auth
from datetime import datetime  # NOTE(review): not used anywhere in this notebook
auth.authenticate_user()
# Requires access to GCP project gmap-997, which hosts this mirror of the repo.
!gcloud source repos clone github_aistream-peelout_flow-forecast --project=gmap-997
os.chdir('/content/github_aistream-peelout_flow-forecast')
# Training code used below lives on the covid_fixes branch.
!git checkout -t origin/covid_fixes
!python setup.py develop
!pip install -r requirements.txt
!mkdir data
from flood_forecast.trainer import train_function
# NOTE(review): task-geo is not referenced by any later cell; confirm it is still needed.
!pip install git+https://github.com/CoronaWhy/task-geo.git
# Interactive step: needs a W&B API key for the sweeps below.
!wandb login

In [0]:
# Get wind data
!wget https://storage.googleapis.com/coronaviruspublicdata/forecast_2/wind.csv
!wget https://storage.googleapis.com/coronaviruspublicdata/forecast_2/solar.csv

## Pre training on Wind Data

In [4]:
import pandas as pd
wind = pd.read_csv('wind.csv')
# Daily series: keep only the calendar date. It stays an ordinary column (the frame
# keeps its default RangeIndex) because it is copied into every per-country CSV below.
wind['datetime'] = pd.to_datetime(wind['time']).dt.date
wind.head()

Unnamed: 0,time,AT,BE,BG,CH,CZ,DE,DK,EE,ES,FI,FR,EL,HR,HU,IE,IT,LT,LU,LV,NL,NO,PL,PT,RO,SI,SK,SE,UK,datetime
0,1986-01-01 00:00:00,0.047786,0.02302,0.04894,0.065907,0.041685,0.031583,0.017365,0.014149,0.079043,0.0058,0.049822,0.051933,0.030507,0.025005,0.011889,0.04625,0.03588,0.014839,0.019004,0.014293,0.010351,0.029919,0.076675,0.029107,0.015193,0.054001,0.017463,0.030419,1986-01-01
1,1986-01-02 00:00:00,0.045921,0.036297,0.067995,0.077502,0.026427,0.023506,0.014981,0.015682,0.119019,0.007176,0.06309,0.115133,0.035716,0.039431,0.016575,0.051848,0.016988,0.01951,0.013771,0.020373,0.006469,0.031359,0.1069,0.044379,0.024623,0.034362,0.008086,0.022146,1986-01-02
2,1986-01-03 00:00:00,0.067308,0.021352,0.101287,0.10368,0.057274,0.046181,0.023478,0.00957,0.106574,0.004687,0.048678,0.123855,0.051901,0.029249,0.054734,0.062773,0.01735,0.056217,0.011871,0.010782,0.007217,0.027554,0.160308,0.047235,0.032093,0.023788,0.010004,0.060345,1986-01-03
3,1986-01-04 00:00:00,0.043833,0.050756,0.039337,0.075418,0.025843,0.025011,0.020003,0.008595,0.13506,0.004102,0.092991,0.089767,0.055547,0.04434,0.016779,0.055305,0.019638,0.055925,0.013604,0.030366,0.007998,0.025986,0.208236,0.03751,0.028663,0.018115,0.009546,0.030981,1986-01-04
4,1986-01-05 00:00:00,0.082394,0.014302,0.033055,0.090867,0.065186,0.028168,0.016261,0.00978,0.095232,0.005172,0.045049,0.074312,0.081576,0.082401,0.038972,0.102499,0.020079,0.018873,0.013913,0.012728,0.007241,0.047764,0.115451,0.037254,0.057101,0.072843,0.013872,0.023346,1986-01-05


In [5]:
# Getting the real countries' name from 2 letter code

!pip install pycountry
import pycountry
names = {}
for code in wind.columns:
    try:
        names[code] = pycountry.countries.get(alpha_2=code).name
    except:
        print(code)

# For some reason, these two were not present
names['EL'] = 'Greece'
names['UK'] = 'United Kingdom'

Collecting pycountry
[?25l  Downloading https://files.pythonhosted.org/packages/16/b6/154fe93072051d8ce7bf197690957b6d0ac9a21d51c9a1d05bd7c6fdb16f/pycountry-19.8.18.tar.gz (10.0MB)
[K     |████████████████████████████████| 10.0MB 71kB/s 
[?25hBuilding wheels for collected packages: pycountry
  Building wheel for pycountry (setup.py) ... [?25l[?25hdone
  Created wheel for pycountry: filename=pycountry-19.8.18-py2.py3-none-any.whl size=10627361 sha256=f398b4cdbe7925f75177f82e6f7dfabcbe5c1deb87caba9c22abe3f176ef29ea
  Stored in directory: /root/.cache/pip/wheels/a2/98/bf/f0fa1c6bf8cf2cbdb750d583f84be51c2cd8272460b8b36bd3
Successfully built pycountry
Installing collected packages: pycountry
Successfully installed pycountry-19.8.18
time
EL
UK
datetime


In [0]:
# Replace country codes with full names (keys of `names` that are not columns are ignored).
wind = wind.rename(columns=names)
# Calendar features used as extra model inputs (relevant_cols in make_config_file).
# Parse 'time' once and use the vectorised .dt accessor instead of per-row lambdas;
# .dt.weekday uses the same Monday=0 convention as datetime.weekday().
wind_times = pd.to_datetime(wind['time'])
wind['month'] = wind_times.dt.month
wind['weekday'] = wind_times.dt.weekday

In [7]:
# Make a separate dataframe for each country's data and save each one to its own CSV
# file (wind/<Country>.csv). The pandas index is written as the first column.
# makedirs(exist_ok=True) keeps the cell re-runnable, unlike `!mkdir`.
os.makedirs('wind', exist_ok=True)
country_wise = {}
for country in names.values():
    country_wise[country] = wind[['datetime', 'month', 'weekday', country]]
    country_wise[country].to_csv(os.path.join('wind', country + '.csv'))
# Spot-check one country explicitly rather than by dict position.
country_wise['Germany'].head()

Unnamed: 0,datetime,month,weekday,Germany
0,1986-01-01,1,2,0.031583
1,1986-01-02,1,3,0.023506
2,1986-01-03,1,4,0.046181
3,1986-01-04,1,5,0.025011
4,1986-01-05,1,6,0.028168


In [8]:
# Last five rows: the series ends on 2015-12-31 and now carries month/weekday columns.
wind.iloc[-5:]

Unnamed: 0,time,Austria,Belgium,Bulgaria,Switzerland,Czechia,Germany,Denmark,Estonia,Spain,Finland,France,Greece,Croatia,Hungary,Ireland,Italy,Lithuania,Luxembourg,Latvia,Netherlands,Norway,Poland,Portugal,Romania,Slovenia,Slovakia,Sweden,United Kingdom,datetime,month,weekday
10952,2015-12-27 00:00:00,0.119231,0.077324,0.153964,0.134614,0.081821,0.055697,0.0151,0.038722,0.127358,0.01306,0.102966,0.163328,0.108146,0.044005,0.017154,0.112251,0.014419,0.092283,0.030101,0.01993,0.004189,0.034805,0.170225,0.136308,0.127401,0.047375,0.008486,0.028426,2015-12-27,12,6
10953,2015-12-28 00:00:00,0.128963,0.087688,0.151576,0.125772,0.054648,0.063808,0.024135,0.040206,0.103774,0.011533,0.089861,0.16448,0.096811,0.063488,0.010358,0.121303,0.010499,0.091828,0.036198,0.067833,0.00463,0.023683,0.059809,0.118209,0.12972,0.048307,0.008479,0.020515,2015-12-28,12,0
10954,2015-12-29 00:00:00,0.094872,0.028177,0.070883,0.121686,0.056463,0.054009,0.006732,0.029413,0.126492,0.008377,0.077485,0.170938,0.071934,0.074617,0.011983,0.119,0.017571,0.031521,0.012752,0.042996,0.004232,0.039898,0.173905,0.033584,0.064835,0.057409,0.011607,0.05522,2015-12-29,12,1
10955,2015-12-30 00:00:00,0.086713,0.05481,0.100528,0.114044,0.097525,0.041072,0.006402,0.01474,0.123499,0.00381,0.091785,0.064106,0.076847,0.132149,0.027625,0.108127,0.027267,0.014551,0.031294,0.057475,0.003595,0.067484,0.065125,0.106734,0.026071,0.134139,0.009297,0.011749,2015-12-30,12,2
10956,2015-12-31 00:00:00,0.111772,0.095547,0.115314,0.031523,0.122601,0.029161,0.012323,0.01921,0.092276,0.004629,0.067331,0.077266,0.108809,0.131938,0.022887,0.081279,0.02249,0.078602,0.018948,0.071013,0.003991,0.091865,0.087822,0.132066,0.114436,0.131965,0.015428,0.045504,2015-12-31,12,3


In [0]:
# Config file for W&B sweeps

def make_config_file(file_path, df_len):
  """Build the flow-forecast config for one per-country wind CSV inside a W&B sweep run.

  Starts a W&B run in project "pretrained-wind-updated", reads the sweep
  hyperparameters (forecast_history, out_seq_length, batch_size, lr) from
  ``wandb.config`` and logs the finished config back to that run.

  Args:
    file_path: path to a per-country CSV such as ``wind/Germany.csv``. The file
      name without directory and extension is the country/target column name.
    df_len: number of rows in that CSV; train_end / valid_start / valid_end are
      placed at 70%, 70% + 1 and 90% of it.

  Returns:
    dict: the config passed to ``train_function("PyTorch", ...)``.
  """
  wandb.init(project="pretrained-wind-updated")
  wandb_config = wandb.config
  # The target column is the country name, i.e. the CSV file name without its
  # directory and extension. basename/splitext also work for paths without a
  # directory and for names that contain a '.'.
  target = os.path.splitext(os.path.basename(file_path))[0]
  train_number = df_len * .7
  validation_number = df_len * .9
  config_default = {
    "model_name": "MultiAttnHeadSimple",
    "model_type": "PyTorch",
    "model_params": {
      # presumably equal to len(relevant_cols): target + month + weekday
      "number_time_series": 3,
      "seq_len": wandb_config["forecast_history"],
      "output_seq_len": wandb_config["out_seq_length"],
      "forecast_length": wandb_config["out_seq_length"]
    },
    "dataset_params": {
      "class": "default",
      "training_path": file_path,
      "validation_path": file_path,
      "test_path": file_path,
      "batch_size": wandb_config["batch_size"],
      "forecast_history": wandb_config["forecast_history"],
      "forecast_length": wandb_config["out_seq_length"],
      "train_end": int(train_number),
      "valid_start": int(train_number + 1),
      "valid_end": int(validation_number),
      "target_col": [target],
      "relevant_cols": [target, "month", "weekday"],
      "scaler": "StandardScaler",
      "interpolate": False
    },
    "training_params": {
      "criterion": "MSE",
      "optimizer": "Adam",
      "optim_params": {},
      "lr": wandb_config["lr"],
      "epochs": 10,
      "batch_size": wandb_config["batch_size"]
    },
    "GCS": False,
    "sweep": True,
    "wandb": False,
    "forward_params": {},
    "metrics": ["MSE"],
    "inference_params": {
      "datetime_start": "2010-01-01",
      "hours_to_forecast": 2000,
      "test_csv_path": file_path,
      "decoder_params": {
        "decoder_function": "simple_decode",
        "unsqueeze_dim": 1
      },
      "dataset_params": {
        "file_path": file_path,
        "forecast_history": wandb_config["forecast_history"],
        "forecast_length": wandb_config["out_seq_length"],
        "relevant_cols": [target, "month", "weekday"],
        "target_col": [target],
        "scaling": "StandardScaler",
        "interpolate_param": False
      }
    }
  }
  wandb.config.update(config_default)
  return config_default

# Grid search: 1 batch size x 1 lr x 2 histories x 2 horizons = 4 runs per sweep.
sweep_config = {
  "name": "Default sweep",
  "method": "grid",
  "parameters": {
        "batch_size": {
            "values": [2]
        },
        "lr":{
            "values":[0.001]
        },
        "forecast_history":{
            "values":[1, 2]
        },
        "out_seq_length":{
            "values":[1, 2]
        }
    }
}

Run the sweeps

In [10]:
import os
import pandas as pd

# Per-country CSVs written earlier; each one gets its own pretraining sweep below.
list(os.listdir('wind'))

['Spain.csv',
 'Switzerland.csv',
 'Austria.csv',
 'Italy.csv',
 'Slovenia.csv',
 'Portugal.csv',
 'Germany.csv',
 'Ireland.csv',
 'Sweden.csv',
 'Denmark.csv',
 'Romania.csv',
 'Hungary.csv',
 'Estonia.csv',
 'Bulgaria.csv',
 'Croatia.csv',
 'Slovakia.csv',
 'Netherlands.csv',
 'Latvia.csv',
 'Norway.csv',
 'Czechia.csv',
 'Poland.csv',
 'United Kingdom.csv',
 'Lithuania.csv',
 'Finland.csv',
 'Belgium.csv',
 'Greece.csv',
 'Luxembourg.csv',
 'France.csv']

In [0]:
# This definitely works
import wandb
for country in os.listdir('wind'):
    file_path = 'wind/'+country
    full_len = len(pd.read_csv(file_path))
    sweep_id = wandb.sweep(sweep_config, project="pretrained-wind-updated")
    wandb.agent(sweep_id, lambda:train_function("PyTorch", make_config_file(file_path, full_len)))
    !gsutil cp -n -r model_save gs://coronaviruspublicdata/pretrained

In [0]:
# It gave some errors, maybe some encoding ones.
import wandb
!mkdir model_save
for country in os.listdir('wind'):
    file_path = 'wind/'+country
    full_len = len(pd.read_csv(file_path))
    sweep_id = wandb.sweep(sweep_config, project="pretrained-wind-updated")
    paths = []
    if len(os.listdir("model_save"))>1:
        print("will use transfer")
        weight_files = filter(lambda x: x.endswith(".pth"), os.listdir("model_save")) 
        for weight_file in weight_files:
          paths.append(os.path.join("model_save", weight_file)) 
        correct_file = max(paths, key = os.path.getctime)
        print(correct_file) 
        wandb.agent(sweep_id, lambda:train_function("PyTorch", make_config_file(correct_file, full_len)))
    else:
        wandb.agent(sweep_id, lambda:train_function("PyTorch", make_config_file(file_path, full_len)))
    print("sucessfully completed sweep for: " + file_path)
    !gsutil cp -n -r model_save gs://coronaviruspublicdata/pretrained

The cell above did run to completion; it just took about 13 hours.

**Check out the sweeps here:**  *https://app.wandb.ai/pranjalya/pretrained-wind-updated*

## Pre Training on Solar Data

In [0]:
import pandas as pd
# Raw solar file: an hourly LocalTime column followed by many Power(MW) columns.
solar = pd.read_csv(filepath_or_buffer='solar.csv')

In [0]:
# First five rows of the raw solar frame.
solar.iloc[:5]

Unnamed: 0,LocalTime,Power(MW),Power(MW).1,Power(MW).2,Power(MW).3,Power(MW).4,Power(MW).5,Power(MW).6,Power(MW).7,Power(MW).8,Power(MW).9,Power(MW).10,Power(MW).11,Power(MW).12,Power(MW).13,Power(MW).14,Power(MW).15,Power(MW).16,Power(MW).17,Power(MW).18,Power(MW).19,Power(MW).20,Power(MW).21,Power(MW).22,Power(MW).23,Power(MW).24,Power(MW).25,Power(MW).26,Power(MW).27,Power(MW).28,Power(MW).29,Power(MW).30,Power(MW).31,Power(MW).32,Power(MW).33,Power(MW).34,Power(MW).35,Power(MW).36,Power(MW).37,Power(MW).38,...,Power(MW).97,Power(MW).98,Power(MW).99,Power(MW).100,Power(MW).101,Power(MW).102,Power(MW).103,Power(MW).104,Power(MW).105,Power(MW).106,Power(MW).107,Power(MW).108,Power(MW).109,Power(MW).110,Power(MW).111,Power(MW).112,Power(MW).113,Power(MW).114,Power(MW).115,Power(MW).116,Power(MW).117,Power(MW).118,Power(MW).119,Power(MW).120,Power(MW).121,Power(MW).122,Power(MW).123,Power(MW).124,Power(MW).125,Power(MW).126,Power(MW).127,Power(MW).128,Power(MW).129,Power(MW).130,Power(MW).131,Power(MW).132,Power(MW).133,Power(MW).134,Power(MW).135,Power(MW).136
0,2006-01-01 00:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2006-01-01 01:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,2006-01-01 02:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,2006-01-01 03:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,2006-01-01 04:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [0]:
# Summary statistics per Power(MW) column (quartiles spelled out explicitly).
solar.describe(percentiles=[.25, .5, .75])

Unnamed: 0,Power(MW),Power(MW).1,Power(MW).2,Power(MW).3,Power(MW).4,Power(MW).5,Power(MW).6,Power(MW).7,Power(MW).8,Power(MW).9,Power(MW).10,Power(MW).11,Power(MW).12,Power(MW).13,Power(MW).14,Power(MW).15,Power(MW).16,Power(MW).17,Power(MW).18,Power(MW).19,Power(MW).20,Power(MW).21,Power(MW).22,Power(MW).23,Power(MW).24,Power(MW).25,Power(MW).26,Power(MW).27,Power(MW).28,Power(MW).29,Power(MW).30,Power(MW).31,Power(MW).32,Power(MW).33,Power(MW).34,Power(MW).35,Power(MW).36,Power(MW).37,Power(MW).38,Power(MW).39,...,Power(MW).97,Power(MW).98,Power(MW).99,Power(MW).100,Power(MW).101,Power(MW).102,Power(MW).103,Power(MW).104,Power(MW).105,Power(MW).106,Power(MW).107,Power(MW).108,Power(MW).109,Power(MW).110,Power(MW).111,Power(MW).112,Power(MW).113,Power(MW).114,Power(MW).115,Power(MW).116,Power(MW).117,Power(MW).118,Power(MW).119,Power(MW).120,Power(MW).121,Power(MW).122,Power(MW).123,Power(MW).124,Power(MW).125,Power(MW).126,Power(MW).127,Power(MW).128,Power(MW).129,Power(MW).130,Power(MW).131,Power(MW).132,Power(MW).133,Power(MW).134,Power(MW).135,Power(MW).136
count,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,...,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0,5832.0
mean,97.990021,72.601543,72.617678,181.308368,163.551646,64.378909,68.242027,118.852452,79.61214,101.382236,75.003927,77.286043,69.262551,78.501012,73.46677,79.95667,78.386968,73.687637,72.940055,73.225892,77.1518,77.681207,77.547582,77.543021,76.565895,76.831962,76.478326,78.114215,77.010734,70.163134,77.808933,70.796005,70.287329,77.684585,76.008676,76.927332,77.182888,72.178669,77.44477,77.26905,...,73.652795,69.24405,72.905161,177.831687,64.941169,73.224571,54.165501,143.012654,78.924091,69.42296,68.952023,54.896039,75.234688,53.857613,64.792473,61.921416,72.59237,54.449314,72.350617,77.925634,72.908951,151.620405,41.334774,77.7893,70.294016,75.423577,41.880676,76.978515,163.73654,76.747531,77.040929,76.639506,123.436848,77.488529,76.99417,123.925892,75.624348,72.612363,61.697565,74.11418
std,121.215824,100.775117,102.617019,255.57798,227.711359,91.088087,96.013131,167.768777,109.091165,142.448961,105.413903,106.748239,96.573719,108.095548,103.759917,109.660217,111.123343,104.098793,103.39274,103.677634,108.577382,108.463883,108.651308,106.827237,105.974013,107.632612,105.833609,110.312326,107.852568,98.031765,109.073624,96.789302,97.835354,108.632086,106.472926,107.922032,106.690069,102.596689,107.988948,107.989447,...,92.000652,97.306594,100.894989,216.932139,91.735845,103.56603,75.38495,198.830916,108.327914,96.93925,96.811196,76.699214,105.744341,74.698423,91.477063,85.882681,100.702548,76.392924,101.918631,109.124838,103.196428,186.609687,57.09526,109.048518,96.184135,106.200268,57.693641,108.015106,228.709291,106.271689,107.825797,105.532143,170.698819,109.003895,107.316424,170.578983,93.098493,103.11633,86.162491,104.574143
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,5.65,0.1,0.1,0.2,0.2,0.1,0.1,0.2,0.1,0.1,0.2,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.2,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.0,0.1,0.1,0.0,0.1,0.0,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,...,3.7,0.1,0.1,8.2,0.1,0.1,0.0,0.1,0.1,0.1,0.1,0.1,0.1,0.0,0.1,0.0,0.0,0.0,0.1,0.1,0.1,6.75,0.0,0.1,0.0,0.1,0.0,0.1,0.2,0.1,0.1,0.1,0.2,0.1,0.1,0.1,4.5,0.1,0.1,0.1
75%,215.8,140.9,138.15,353.325,325.225,121.05,132.225,228.175,161.75,200.15,143.35,152.85,135.575,156.225,140.55,160.825,151.225,140.525,136.1,138.575,148.4,152.925,152.6,154.75,151.45,150.125,154.325,149.025,150.05,139.025,151.15,144.9,137.725,150.5,148.325,148.9,151.925,132.725,152.85,152.625,...,163.325,134.45,142.325,401.125,124.75,138.5,106.2,290.425,156.825,138.25,131.425,105.625,145.925,105.825,123.225,124.125,141.825,104.825,137.325,154.1,137.55,338.775,84.2,150.725,142.3,142.625,84.75,148.525,327.0,150.425,150.0,152.225,248.4,148.2,152.2,253.475,170.125,136.425,125.025,139.8
max,371.1,349.0,354.8,949.0,801.0,321.8,330.3,608.9,379.6,499.7,371.3,368.8,354.9,381.7,361.2,383.2,403.6,369.8,378.5,378.9,380.8,375.0,392.2,370.1,365.8,371.3,365.3,400.4,388.0,345.7,367.9,341.1,354.7,374.7,378.6,374.6,373.1,356.6,382.3,378.0,...,280.8,341.4,345.6,648.8,323.2,373.0,274.3,713.2,383.2,345.3,340.4,273.1,369.7,258.0,326.4,302.0,346.6,279.1,352.2,371.9,356.4,560.2,197.9,376.5,334.4,376.3,202.6,372.6,814.4,365.8,370.2,369.4,615.1,387.1,391.4,592.2,283.8,366.8,299.7,376.1


In [0]:
# Last five rows of the raw solar frame.
solar.iloc[-5:]

Unnamed: 0,LocalTime,Power(MW),Power(MW).1,Power(MW).2,Power(MW).3,Power(MW).4,Power(MW).5,Power(MW).6,Power(MW).7,Power(MW).8,Power(MW).9,Power(MW).10,Power(MW).11,Power(MW).12,Power(MW).13,Power(MW).14,Power(MW).15,Power(MW).16,Power(MW).17,Power(MW).18,Power(MW).19,Power(MW).20,Power(MW).21,Power(MW).22,Power(MW).23,Power(MW).24,Power(MW).25,Power(MW).26,Power(MW).27,Power(MW).28,Power(MW).29,Power(MW).30,Power(MW).31,Power(MW).32,Power(MW).33,Power(MW).34,Power(MW).35,Power(MW).36,Power(MW).37,Power(MW).38,...,Power(MW).97,Power(MW).98,Power(MW).99,Power(MW).100,Power(MW).101,Power(MW).102,Power(MW).103,Power(MW).104,Power(MW).105,Power(MW).106,Power(MW).107,Power(MW).108,Power(MW).109,Power(MW).110,Power(MW).111,Power(MW).112,Power(MW).113,Power(MW).114,Power(MW).115,Power(MW).116,Power(MW).117,Power(MW).118,Power(MW).119,Power(MW).120,Power(MW).121,Power(MW).122,Power(MW).123,Power(MW).124,Power(MW).125,Power(MW).126,Power(MW).127,Power(MW).128,Power(MW).129,Power(MW).130,Power(MW).131,Power(MW).132,Power(MW).133,Power(MW).134,Power(MW).135,Power(MW).136
5827,2006-08-31 19:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5828,2006-08-31 20:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5829,2006-08-31 21:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5830,2006-08-31 22:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5831,2006-08-31 23:00:00,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [0]:
# Keep the timestamp (renamed LocalTime -> datetime) and only the first Power(MW) series.
solar = solar.rename(columns={'LocalTime': 'datetime'})[['datetime', 'Power(MW)']]
# Display only: the datetime-indexed view is shown but not stored.
solar.set_index('datetime', drop=False)

Unnamed: 0_level_0,datetime,Power(MW)
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1
2006-01-01 00:00:00,2006-01-01 00:00:00,0.0
2006-01-01 01:00:00,2006-01-01 01:00:00,0.0
2006-01-01 02:00:00,2006-01-01 02:00:00,0.0
2006-01-01 03:00:00,2006-01-01 03:00:00,0.0
2006-01-01 04:00:00,2006-01-01 04:00:00,0.0
...,...,...
2006-08-31 19:00:00,2006-08-31 19:00:00,0.0
2006-08-31 20:00:00,2006-08-31 20:00:00,0.0
2006-08-31 21:00:00,2006-08-31 21:00:00,0.0
2006-08-31 22:00:00,2006-08-31 22:00:00,0.0


So the first 1470 rows (indices 0–1469) and the last 6 rows (indices 5826–5831) are just 0. Let's remove those rows, as they may just add bias.

In [0]:
# Drop leading rows 0-1469 and trailing rows 5826-5831, reported above as zero-only.
# .copy() makes `solar` an independent frame, so the column assignments in later
# cells do not trigger pandas' SettingWithCopyWarning / chained-assignment issues.
solar = solar.iloc[1470:5826].copy()
solar.describe()

Unnamed: 0,Power(MW)
count,4356.0
mean,108.773531
std,125.396285
min,0.0
25%,0.0
50%,29.45
75%,231.725
max,371.1


In [0]:
# Calendar features used as extra model inputs (relevant_cols in make_config_file).
# Parse the timestamps once and use the vectorised .dt accessor; assign() returns a
# new frame, so this never writes into a view of the sliced `solar`.
# NOTE(review): the data is hourly, so an hour-of-day feature (.dt.hour) may matter
# more for solar than weekday; adding it requires updating relevant_cols and
# number_time_series too.
solar_times = pd.to_datetime(solar['datetime'])
solar = solar.assign(month=solar_times.dt.month, weekday=solar_times.dt.weekday)

In [0]:
# Persist the cleaned series (index included) for the solar sweep, then show its end.
solar.to_csv(path_or_buf='selected_solar.csv')
solar.iloc[-5:]

Unnamed: 0,datetime,Power(MW),month,weekday
5821,2006-08-31 13:00:00,111.3,8,3
5822,2006-08-31 14:00:00,201.1,8,3
5823,2006-08-31 15:00:00,203.8,8,3
5824,2006-08-31 16:00:00,135.5,8,3
5825,2006-08-31 17:00:00,52.9,8,3


In [0]:
# Config file for W&B sweeps (solar).
# NOTE(review): this redefines make_config_file and shadows the wind version for every
# later cell; a single parameterised helper would avoid the duplication.

def make_config_file(file_path, df_len):
  """Return the flow-forecast config for one solar sweep trial.

  Opens a W&B run in project "pretrained-solar-updated", pulls the trial's
  hyperparameters out of ``wandb.config`` and records the assembled config on
  that run. The target is always the "Power(MW)" column, with month/weekday as
  additional inputs.

  Args:
    file_path: CSV used for training, validation, test and inference.
    df_len: row count of that CSV; split points are taken at 70% and 90% of it.

  Returns:
    dict: config consumed by ``train_function("PyTorch", ...)``.
  """
  wandb.init(project="pretrained-solar-updated")
  hp = wandb.config
  history, horizon = hp["forecast_history"], hp["out_seq_length"]
  batch, learning_rate = hp["batch_size"], hp["lr"]
  cut_train = df_len * .7
  cut_valid = df_len * .9

  model_section = {
    "number_time_series": 3,
    "seq_len": history,
    "output_seq_len": horizon,
    "forecast_length": horizon,
  }
  data_section = {
    "class": "default",
    "training_path": file_path,
    "validation_path": file_path,
    "test_path": file_path,
    "batch_size": batch,
    "forecast_history": history,
    "forecast_length": horizon,
    "train_end": int(cut_train),
    "valid_start": int(cut_train + 1),
    "valid_end": int(cut_valid),
    "target_col": ["Power(MW)"],
    "relevant_cols": ["Power(MW)", "month", "weekday"],
    "scaler": "StandardScaler",
    "interpolate": False,
  }
  train_section = {
    "criterion": "MSE",
    "optimizer": "Adam",
    "optim_params": {},
    "lr": learning_rate,
    "epochs": 10,
    "batch_size": batch,
  }
  inference_section = {
    "datetime_start": "2006-08-22",
    "hours_to_forecast": 150,
    "test_csv_path": file_path,
    "decoder_params": {"decoder_function": "simple_decode", "unsqueeze_dim": 1},
    "dataset_params": {
      "file_path": file_path,
      "forecast_history": history,
      "forecast_length": horizon,
      "relevant_cols": ["Power(MW)", "month", "weekday"],
      "target_col": ["Power(MW)"],
      "scaling": "StandardScaler",
      "interpolate_param": False,
    },
  }
  config_default = {
    "model_name": "MultiAttnHeadSimple",
    "model_type": "PyTorch",
    "model_params": model_section,
    "dataset_params": data_section,
    "training_params": train_section,
    "GCS": False,
    "sweep": True,
    "wandb": False,
    "forward_params": {},
    "metrics": ["MSE"],
    "inference_params": inference_section,
  }
  wandb.config.update(config_default)
  return config_default

# Grid search: 3 batch sizes x 2 learning rates x 4 histories x 2 horizons = 48 runs.
sweep_config = {
  "name": "Default sweep",
  "method": "grid",
  "parameters": {
    "batch_size": {"values": [2, 3, 4]},
    "lr": {"values": [0.001, 0.01]},
    "forecast_history": {"values": [1, 2, 3, 5]},
    "out_seq_length": {"values": [1, 2]},
  },
}

In [0]:
# Running the sweep
import wandb
file_path = 'selected_solar.csv'
full_len = len(pd.read_csv(file_path))
sweep_id = wandb.sweep(sweep_config, project="pretrained-solar-updated")
wandb.agent(sweep_id, lambda:train_function("PyTorch", make_config_file(file_path, full_len)))
!gsutil cp -n -r model_save gs://coronaviruspublicdata/pretrained

This cell also ran to completion.

**Check out the sweeps here:**  *https://app.wandb.ai/pranjalya/pretrained-solar-updated*