In [1]:
import os
import pandas as pd
import requests 
from AlphanumericsTeam.data.util import get_aug_oxford_df, filter_df_regions


REPO_ROOT = os.path.abspath(os.path.join(os.path.abspath(''), os.pardir, os.pardir, os.pardir))

In [2]:
#sample input and output

# IP until 30 sep
EXAMPLE_INPUT_FILE = os.path.join(REPO_ROOT, "covid_xprize/validation/data/2020-09-30_historical_ip.csv")
prediction_input_df = pd.read_csv(EXAMPLE_INPUT_FILE,
                                  parse_dates=['Date'],
                                  dtype={"RegionName": str},
                                  encoding="ISO-8859-1")

#print(prediction_input_df)

#prediction from Aug 1 to Aug 4
EXAMPLE_OUTPUT_FILE = os.path.join(REPO_ROOT, "2020-08-01_2020-08-04_predictions_example.csv")
prediction_output_df = pd.read_csv(EXAMPLE_OUTPUT_FILE,
                                   parse_dates=['Date'],
                                   encoding="ISO-8859-1")  


print(prediction_output_df)

     CountryName RegionName       Date  PredictedDailyNewCases  IsSpecialty
0          Aruba        NaN 2020-08-01                0.820071            0
1          Aruba        NaN 2020-08-02                0.872854            0
2          Aruba        NaN 2020-08-03                0.000000            0
3          Aruba        NaN 2020-08-04                0.000000            0
4    Afghanistan        NaN 2020-08-01               80.590128            0
..           ...        ...        ...                     ...          ...
763       Zambia        NaN 2020-08-04              172.532764            0
764     Zimbabwe        NaN 2020-08-01              178.485848            0
765     Zimbabwe        NaN 2020-08-02              142.449493            0
766     Zimbabwe        NaN 2020-08-03               84.436329            0
767     Zimbabwe        NaN 2020-08-04              199.259844            0

[768 rows x 5 columns]


In [3]:
# Input data for training

#DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'
#df = pd.read_csv(DATA_URL,
#                 parse_dates=['Date'],
#                 encoding="ISO-8859-1",
#                 dtype={"RegionName": str,
#                        "RegionCode": str},
#                 error_bad_lines=False)

# Has 6 additional columns 
# 'New Cases' 
# 'GeoID' 
# 'Holidays' 
# 'pop_2020' 
# 'area_km2' 
# 'density_perkm2'
df = get_aug_oxford_df() 
df = filter_df_regions(df)

# Final list of 180 countries and 56 regions
assert df.CountryName.unique().size == 180
assert df.RegionName.unique().size == 56 + 1 

In [4]:
print(df.sample(3))
df["DailyChangeConfirmedCases"] = df.groupby(["CountryName", "RegionName"]).ConfirmedCases.diff().fillna(0)
california_df = df[(df.CountryName == "United States") & (df.RegionName == "California")]
california_df[["CountryName", "RegionName", "Date", "ConfirmedCases", "DailyChangeConfirmedCases"]].tail(5)

         CountryName CountryCode   RegionName RegionCode Jurisdiction  \
41282      Hong Kong         HKG                     NaN    NAT_TOTAL   
87540  United States         USA  Mississippi      US_MS  STATE_TOTAL   
87120  United States         USA     Missouri      US_MO  STATE_TOTAL   

            Date  C1_School closing  C1_Flag  C2_Workplace closing  C2_Flag  \
41282 2020-04-22                3.0      1.0                   2.0      1.0   
87540 2020-07-07                3.0      1.0                   1.0      1.0   
87120 2020-05-06                3.0      1.0                   2.0      0.0   

       ...  ContainmentHealthIndex  ContainmentHealthIndexForDisplay  \
41282  ...                   60.90                             60.90   
87540  ...                   54.81                             54.81   
87120  ...                   53.53                             53.53   

       EconomicSupportIndex  EconomicSupportIndexForDisplay  NewCases  \
41282                 100.0 

Unnamed: 0,CountryName,RegionName,Date,ConfirmedCases,DailyChangeConfirmedCases
80187,United States,California,2020-12-19,1842557.0,40362.0
80188,United States,California,2020-12-20,1884033.0,41476.0
80189,United States,California,2020-12-21,1923887.0,39854.0
80190,United States,California,2020-12-22,,0.0
80191,United States,California,2020-12-23,,0.0


In [5]:
# TRAINING
import xprize_predictor
from importlib import reload
reload(xprize_predictor)
from xprize_predictor import XPrizePredictor

DATA_PATH = os.path.join("data", 'OxCGRT_latest_aug.csv')
predictor = XPrizePredictor(None, DATA_PATH)
 
predictor_model = predictor.train()

model_weights_file = "models/trained_model_weights.h5"
if not os.path.exists('models'):
    os.mkdir('models')
predictor_model.save_weights(model_weights_file)

Creating numpy arrays for Keras for each country...
Numpy arrays created
 len geos 234
69988 0.0 57.00170398805195
69988 0.0 57.00170398805195
Trial 0
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100


In [6]:
# TESTING AND PREDICTION ---
DATA_PATH = os.path.join("data", 'OxCGRT_latest_aug.csv')
model_weights_file = "models/trained_model_weights.h5"

import xprize_predictor
from importlib import reload
reload(xprize_predictor)

from xprize_predictor import XPrizePredictor

predictor = XPrizePredictor(model_weights_file, DATA_PATH)

NPIS_INPUT_FILE = "../../validation/data/2020-09-30_historical_ip.csv"
start_date = "2020-08-01"
end_date = "2020-08-31"
 
preds_df = predictor.predict(start_date, end_date, NPIS_INPUT_FILE)
preds_df.head()

Start and end date 2020-08-01 00:00:00 2020-08-31 00:00:00
days 31


Unnamed: 0,CountryName,RegionName,Date,PredictedDailyNewCases
0,Aruba,,2020-08-01,20.451585
1,Aruba,,2020-08-02,22.44224
2,Aruba,,2020-08-03,14.003275
3,Aruba,,2020-08-04,12.222513
4,Aruba,,2020-08-05,16.460395


In [7]:
result = "result"+ "_" + str(start_date) + "_" + end_date + ".csv"
preds_df.to_csv(result, index=False)

In [8]:
!python predict.py -s 2020-08-01 -e 2020-08-04 -ip ../../validation/data/2020-09-30_historical_ip.csv -o predictions/2020-08-01_2020-08-04.csv
!head predictions/2020-08-01_2020-08-04.csv

Generating predictions from 2020-08-01 to 2020-08-04...
Start and end date 2020-08-01 00:00:00 2020-08-04 00:00:00
days 4
Saved predictions to predictions/2020-08-01_2020-08-04.csv
Done!
CountryName,RegionName,Date,PredictedDailyNewCases
Aruba,,2020-08-01,20.4515848266264
Aruba,,2020-08-02,22.442240146954234
Aruba,,2020-08-03,14.003274992386533
Aruba,,2020-08-04,12.222513322200891
Afghanistan,,2020-08-01,248.95815257401222
Afghanistan,,2020-08-02,192.68296045392853
Afghanistan,,2020-08-03,178.74634712829203
Afghanistan,,2020-08-04,203.30220009377666
Angola,,2020-08-01,74.19987738955582


In [9]:
# Check the pediction file is valid
import os
from covid_xprize.validation.predictor_validation import validate_submission

def validate(start_date, end_date, ip_file, output_file):
    # First, delete any potential old file
    try:
        os.remove(output_file)
    except OSError:
        pass
    
    # Then generate the prediction, calling the official API
    !python predict.py -s {start_date} -e {end_date} -ip {ip_file} -o {output_file}
    
    # And validate it
    errors = validate_submission(start_date, end_date, ip_file, output_file)
    if errors:
        for error in errors:
            print(error)
    else:
        print("All good!")

In [10]:
validate(start_date="2020-08-01",
         end_date="2020-08-04",
         ip_file="../../validation/data/2020-09-30_historical_ip.csv",
         output_file="predictions/val_4_days.csv")

Generating predictions from 2020-08-01 to 2020-08-04...
Start and end date 2020-08-01 00:00:00 2020-08-04 00:00:00
days 4
Saved predictions to predictions/val_4_days.csv
Done!
All good!


In [11]:
%%time
validate(start_date="2021-01-01",
         end_date="2021-01-31",
         ip_file="../../validation/data/future_ip.csv",
         output_file="predictions/val_1_month_future.csv")

Generating predictions from 2021-01-01 to 2021-01-31...
Start and end date 2021-01-01 00:00:00 2021-01-31 00:00:00
days 31
Saved predictions to predictions/val_1_month_future.csv
Done!
All good!
CPU times: user 154 ms, sys: 39.3 ms, total: 194 ms
Wall time: 12.9 s


In [12]:
from datetime import datetime, timedelta

start_date = datetime.now() + timedelta(days=7)
start_date_str = start_date.strftime('%Y-%m-%d')
end_date = start_date + timedelta(days=180)
end_date_str = end_date.strftime('%Y-%m-%d')
print(f"Start date: {start_date_str}")
print(f"End date: {end_date_str}")

Start date: 2021-01-04
End date: 2021-07-03


In [13]:
from covid_xprize.validation.scenario_generator import get_raw_data, generate_scenario, NPI_COLUMNS
DATA_FILE = 'data/OxCGRT_latest.csv'
latest_df = get_raw_data(DATA_FILE, latest=True)
scenario_df = generate_scenario(start_date_str, end_date_str, latest_df, countries=None, scenario="Freeze")
scenario_file = "predictions/180_days_future_scenario.csv"
scenario_df.to_csv(scenario_file, index=False)
print(f"Saved scenario to {scenario_file}")

Saved scenario to predictions/180_days_future_scenario.csv


In [14]:
%%time
validate(start_date=start_date_str,
         end_date=end_date_str,
         ip_file=scenario_file,
         output_file="predictions/val_6_month_future.csv")

Generating predictions from 2021-01-04 to 2021-07-03...
Start and end date 2021-01-04 00:00:00 2021-07-03 00:00:00
days 181
Saved predictions to predictions/val_6_month_future.csv
Done!
Missing countries / regions: {'Brazil / Alagoas', 'Brazil / Amapa', 'Brazil / Bahia', 'United States Virgin Islands', 'Brazil / Minas Gerais', 'Canada / New Brunswick', 'Brazil / Rio Grande do Norte', 'Canada / Prince Edward Island', 'Brazil / Pernambuco', 'Brazil / Rondonia', 'Brazil / Paraiba', 'Canada / Newfoundland and Labrador', 'Canada / Yukon', 'Canada / Saskatchewan', 'Turkmenistan', 'Brazil / Acre', 'Brazil / Maranhao', 'Brazil / Mato Grosso do Sul', 'Brazil / Santa Catarina', 'Malta', 'Brazil / Tocantins', 'Canada / Manitoba', 'Brazil / Rio Grande do Sul', 'Brazil / Roraima', 'Brazil / Rio de Janeiro', 'Brazil / Sergipe', 'Canada / Northwest Territories', 'Canada / Ontario', 'Brazil / Sao Paulo', 'Canada / Nova Scotia', 'Brazil / Distrito Federal', 'Brazil / Mato Grosso', 'Brazil / Para', 'Can

In [15]:

#TESTING/DEBUGGING CODE
import pandas as pd
from util import add_features_df

df = pd.read_csv("../../validation/data/2020-09-30_historical_ip.csv",
                                parse_dates=['Date'],
                                encoding="ISO-8859-1",) 
df["RegionName"] = df["RegionName"].fillna(value="") 
add_features_df(df)

Unnamed: 0,CountryName,RegionName,Date,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings,Holidays,density_perkm2
0,Aruba,,2020-01-01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,593.0
1,Aruba,,2020-01-02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,593.0
2,Aruba,,2020-01-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,593.0
3,Aruba,,2020-01-04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,593.0
4,Aruba,,2020-01-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,593.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64659,Zimbabwe,,2020-09-26,2.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,3.0,1.0,38.0
64660,Zimbabwe,,2020-09-27,2.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,3.0,1.0,38.0
64661,Zimbabwe,,2020-09-28,2.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,3.0,0.0,38.0
64662,Zimbabwe,,2020-09-29,2.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,1.0,0.0,38.0
