Download data from Our World in Data (OWID) and generate dataset files with train/validation splits

In [1]:
import pandas as pd
import numpy as np
import datetime as dt
import torch
from src import data
from torch.utils import data as tdt

# Directory holding the source CSV (config['SRC']) and the cached .pt dataset files.
DATA_DIR = 'data'

### Download

In [None]:
# IPython shell magic: fetch the current OWID snapshot as data/owid_<YYYY-MM-DD>.csv.
# NOTE(review): the output directory is hard-coded to "data/" — keep it in sync with DATA_DIR.
# NOTE(review): the Read cell below loads config['SRC'], not this freshly downloaded file.
!curl https://covid.ourworldindata.org/data/owid-covid-data.csv --output data/owid_$(date +%Y-%m-%d).csv

### Config

In [9]:
# Dataset-generation settings consumed by gen_dataset().
config = dict(
    # Per-day series fed to the model (smoothed with a 7-day rolling mean).
    FEATURES=['new_cases', 'new_deaths', 'temp_mean', 'humidity_mean', 'pressure_mean'],
    # Subset of FEATURES rescaled to values per 1000 inhabitants.
    POP_FEATURES=['new_cases', 'new_deaths'],
    # Static per-country covariates taken from each country's first kept row.
    AUX_FEATURES=['population_density', 'gdp_per_capita', 'hospital_beds_per_thousand', 'median_age'],
    VAL_RATIO=0.3,   # fraction of samples assigned to the validation split
    IP_SEQ_LEN=40,   # input window length (days)
    OP_SEQ_LEN=20,   # target window length (days)
    SRC="dataset_2020-09-15_v3.csv",
)

# Cache filename: window lengths are concatenated ("4020") and SRC keeps its
# ".csv" suffix, giving e.g. "ds_cdthp_pgba_4020_dataset_2020-09-15_v3.csv.pt".
# NOTE(review): FEATURES / VAL_RATIO are not encoded in the name.
fn = f"ds_cdthp_pgba_{config['IP_SEQ_LEN']}{config['OP_SEQ_LEN']}_{config['SRC']}.pt"

### Read

In [10]:
# Columns pulled from the source CSV: identifiers, case/death counts,
# static country covariates, and the weather series.
cols = [
    'location', 'date',
    'total_cases', 'new_cases', 'total_deaths', 'new_deaths',
    'population', 'population_density', 'gdp_per_capita',
    'hospital_beds_per_thousand', 'median_age',
    'temp_mean', 'humidity_mean', 'pressure_mean',
]
dates = ['date']  # parsed to datetime by read_csv

src_path = f"{DATA_DIR}/{config['SRC']}"
# Project helper (src.data) applied to the raw frame before any processing.
df = data.fix_anomalies_owid(
    pd.read_csv(src_path, usecols=cols, parse_dates=dates)
)
df.sample()  # display one random row as a sanity check

Unnamed: 0,location,date,total_cases,new_cases,total_deaths,new_deaths,population,population_density,median_age,gdp_per_capita,hospital_beds_per_thousand,temp_mean,pressure_mean,humidity_mean
23773,Luxembourg,2020-02-08,0.0,0.0,0.0,0.0,625976.0,231.447,39.7,94277.965,4.51,0.023483,1.02175,0.819167


### Prepare dataset

In [4]:
def gen_dataset(cfg, source_df=None):
    """Build shuffled train/validation splits of sliding-window samples.

    For each location in ``source_df`` (excluding the skip list) this:

    1. keeps the rows whose ``total_cases`` is at least 100, sorted by ``date``;
    2. drops the location if fewer than ``IP_SEQ_LEN + OP_SEQ_LEN`` rows remain;
    3. rescales ``cfg['POP_FEATURES']`` to values per 1000 inhabitants
       (using the first kept row's ``population``);
    4. smooths ``cfg['FEATURES']`` with a centred 7-day rolling mean
       (``min_periods=1``);
    5. emits one sample per window of ``IP_SEQ_LEN`` input days immediately
       followed by ``OP_SEQ_LEN`` target days.

    Every kept location is printed as ``index name n_rows population``.

    Args:
        cfg: Config dict with keys ``FEATURES``, ``POP_FEATURES``,
            ``AUX_FEATURES``, ``VAL_RATIO``, ``IP_SEQ_LEN`` and ``OP_SEQ_LEN``.
            Optional keys:
            ``SKIP_LOCATIONS`` is an iterable of location names to ignore
            (default: World, International, India).
            ``SEED`` is an int seed for a reproducible split (default: torch's
            global RNG, as before).
        source_df: Frame containing ``location``, ``date``, ``total_cases``,
            ``population`` and every configured feature column. Defaults to the
            notebook-global ``df`` loaded in the "Read" cell.

    Returns:
        ``(trn_set, val_set)``, two ``torch.utils.data.Subset`` objects over a
        ``TensorDataset`` of float32 tensors ``(inputs, aux, targets)``.
        When N > 0, these have shapes ``(N, IP_SEQ_LEN, len(FEATURES))``,
        ``(N, len(AUX_FEATURES))`` and ``(N, OP_SEQ_LEN, len(FEATURES))``.
        ``val_set`` holds ``int(VAL_RATIO * N)`` samples.

    NOTE(review): windows from one country overlap by up to
    IP_SEQ_LEN + OP_SEQ_LEN - 1 days and are split at random, so validation
    samples share days with training samples (optimistic validation scores).
    """
    if source_df is None:
        source_df = df  # notebook-global frame; resolved only when not supplied

    ip_len = cfg['IP_SEQ_LEN']
    op_len = cfg['OP_SEQ_LEN']
    win_len = ip_len + op_len
    skip = set(cfg.get('SKIP_LOCATIONS', ('World', 'International', 'India')))

    ip_trn, ip_aux, op_trn = [], [], []
    c = 0
    for country in source_df['location'].unique():
        if country in skip:
            continue
        country_df = source_df.loc[source_df['location'] == country].sort_values('date', kind='stable')
        # .copy() so the per-capita rescaling below writes to an independent frame
        # instead of a filtered slice (avoids SettingWithCopy / chained assignment).
        country_df = country_df.loc[country_df['total_cases'] >= 100].copy()
        if len(country_df) < win_len:
            continue

        c += 1
        pop = country_df['population'].iloc[0]
        print(c, country, len(country_df), pop)
        aux_ips = np.array(country_df[cfg['AUX_FEATURES']].iloc[0])
        country_df[cfg['POP_FEATURES']] = country_df[cfg['POP_FEATURES']] * 1000 / pop
        daily_cases = np.array(
            country_df[cfg['FEATURES']].rolling(7, center=True, min_periods=1).mean(),
            dtype=np.float32,
        )

        for i in range(len(country_df) - win_len + 1):
            ip_trn.append(daily_cases[i:i + ip_len])
            ip_aux.append(aux_ips)
            op_trn.append(daily_cases[i + ip_len:i + win_len])

    dataset = tdt.TensorDataset(
        torch.from_numpy(np.array(ip_trn, dtype=np.float32)),
        torch.from_numpy(np.array(ip_aux, dtype=np.float32)),
        torch.from_numpy(np.array(op_trn, dtype=np.float32)),
    )

    val_len = int(cfg['VAL_RATIO'] * len(dataset))
    trn_len = len(dataset) - val_len
    seed = cfg.get('SEED')
    # Only pass a generator when a seed is given, so the default path keeps
    # using torch's global RNG exactly as before.
    split_kwargs = {} if seed is None else {'generator': torch.Generator().manual_seed(seed)}
    trn_set, val_set = tdt.random_split(dataset, (trn_len, val_len), **split_kwargs)
    return trn_set, val_set

In [11]:
# Reuse a cached dataset if one exists for this filename; otherwise build and cache it.
# The try body is limited to torch.load so that errors raised while generating or
# saving propagate unmasked (the old `finally` turned them into a NameError on trn_set).
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may refuse
# the pickled Subset objects; for this trusted local cache, weights_only=False may be needed.
path = DATA_DIR + '/' + fn
try:
    ds = torch.load(path)
except FileNotFoundError:
    trn_set, val_set = gen_dataset(config)
    torch.save({'trn': trn_set, 'val': val_set, 'config': config}, path)
    print("Saved dataset to", fn)
else:
    trn_set, val_set, ds_cfg = ds['trn'], ds['val'], ds['config']
    print(fn, "already exists.")
    print(ds_cfg)
    # The filename only encodes window lengths and SRC; flag other config drift.
    if ds_cfg != config:
        print("Warning: cached dataset config differs from the current config; delete", fn, "to regenerate.")
print("Training data:", len(trn_set), "Validation data:", len(val_set))

1 Afghanistan 171 38928341.0
2 Albania 176 2877800.0
3 Algeria 179 43851043.0
4 Angola 97 32866268.0
5 Argentina 180 45195777.0
6 Armenia 181 2963234.0
7 Aruba 146 106766.0
8 Australia 190 25499881.0
9 Austria 191 9006400.0
10 Azerbaijan 173 10139175.0
11 Bahamas 115 393248.0
12 Bahrain 190 1701583.0
13 Bangladesh 162 164689383.0
14 Barbados 66 287371.0
15 Belarus 169 9449321.0
16 Belgium 195 11589616.0
17 Benin 132 12123198.0
18 Bermuda 142 62273.0
19 Bolivia 169 11673029.0
20 Bosnia and Herzegovina 177 3280815.0
21 Botswana 81 2351625.0
22 Brazil 185 212559409.0
23 Brunei 175 437483.0
24 Bulgaria 180 6948445.0
25 Burkina Faso 174 20903278.0
26 Burundi 92 11890781.0
27 Cambodia 172 16718971.0
28 Cameroon 169 26545864.0
29 Canada 188 37742157.0
30 Cape Verde 142 555988.0
31 Cayman Islands 119 65720.0
32 Central African Republic 130 4829764.0
33 Chad 136 16425859.0
34 Chile 183 19116209.0
35 China 241 1439323774.0
36 Colombia 181 50882884.0
37 Comoros 108 869595.0
38 Congo 153 5518092.0