# Example Predictor: Linear Rollout Predictor

This example contains basic functionality for training and evaluating a linear predictor that rolls out predictions day-by-day.

First, a training data set is created from historical case and npi data.

Second, a linear model is trained to predict future cases from prior case data along with prior and future npi data.
The model is an off-the-shelf sklearn Lasso model, that uses a positive weight constraint to enforce the assumption that increased npis has a negative correlation with future cases.

Third, a sample evaluation set is created, and the predictor is applied to this evaluation set to produce prediction results in the correct format.

## Training

In [1]:
import pickle
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split

### Copy the data locally

In [2]:
# Main source for the training data
DATA_URL = 'https://raw.githubusercontent.com/OxCGRT/covid-policy-tracker/master/data/OxCGRT_latest.csv'
# Local file
DATA_FILE = 'data/OxCGRT_latest.csv'

In [3]:
import os
import urllib.request
if not os.path.exists('data'):
    os.mkdir('data')
urllib.request.urlretrieve(DATA_URL, DATA_FILE)

('data/OxCGRT_latest.csv', <http.client.HTTPMessage at 0x7fa6118e7b70>)

In [4]:
# Load historical data from local file
df = pd.read_csv(DATA_FILE, 
                 parse_dates=['Date'],
                 encoding="ISO-8859-1",
                 dtype={"RegionName": str,
                        "RegionCode": str},
                 error_bad_lines=False)

In [5]:
df.columns

Index(['CountryName', 'CountryCode', 'RegionName', 'RegionCode',
       'Jurisdiction', 'Date', 'C1_School closing', 'C1_Flag',
       'C2_Workplace closing', 'C2_Flag', 'C3_Cancel public events', 'C3_Flag',
       'C4_Restrictions on gatherings', 'C4_Flag', 'C5_Close public transport',
       'C5_Flag', 'C6_Stay at home requirements', 'C6_Flag',
       'C7_Restrictions on internal movement', 'C7_Flag',
       'C8_International travel controls', 'E1_Income support', 'E1_Flag',
       'E2_Debt/contract relief', 'E3_Fiscal measures',
       'E4_International support', 'H1_Public information campaigns',
       'H1_Flag', 'H2_Testing policy', 'H3_Contact tracing',
       'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
       'H6_Facial Coverings', 'H6_Flag', 'M1_Wildcard', 'ConfirmedCases',
       'ConfirmedDeaths', 'StringencyIndex', 'StringencyIndexForDisplay',
       'StringencyLegacyIndex', 'StringencyLegacyIndexForDisplay',
       'GovernmentResponseIndex', 'Gove

In [6]:
# For testing, restrict training data to that before a hypothetical predictor submission date
HYPOTHETICAL_SUBMISSION_DATE = np.datetime64("2020-07-31")
df = df[df.Date <= HYPOTHETICAL_SUBMISSION_DATE]

In [7]:
# Add RegionID column that combines CountryName and RegionName for easier manipulation of data
df['GeoID'] = df['CountryName'] + '__' + df['RegionName'].astype(str)

In [8]:
# Add new cases column
df['NewCases'] = df.groupby('GeoID').ConfirmedCases.diff().fillna(0)

In [9]:
# Keep only columns of interest
id_cols = ['CountryName',
           'RegionName',
           'GeoID',
           'Date']
cases_col = ['NewCases']
npi_cols = ['C1_School closing',
            'C2_Workplace closing',
            'C3_Cancel public events',
            'C4_Restrictions on gatherings',
            'C5_Close public transport',
            'C6_Stay at home requirements',
            'C7_Restrictions on internal movement',
            'C8_International travel controls',
            'H1_Public information campaigns',
            'H2_Testing policy',
            'H3_Contact tracing',
            'H6_Facial Coverings']
df = df[id_cols + cases_col + npi_cols]

In [10]:
# Fill any missing case values by interpolation and setting NaNs to 0
df.update(df.groupby('GeoID').NewCases.apply(
    lambda group: group.interpolate()).fillna(0))

In [11]:
# Fill any missing NPIs by assuming they are the same as previous day
for npi_col in npi_cols:
    df.update(df.groupby('GeoID')[npi_col].ffill().fillna(0))

In [12]:
df

Unnamed: 0,CountryName,RegionName,GeoID,Date,NewCases,C1_School closing,C2_Workplace closing,C3_Cancel public events,C4_Restrictions on gatherings,C5_Close public transport,C6_Stay at home requirements,C7_Restrictions on internal movement,C8_International travel controls,H1_Public information campaigns,H2_Testing policy,H3_Contact tracing,H6_Facial Coverings
0,Aruba,,Aruba__nan,2020-01-01,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,Aruba,,Aruba__nan,2020-01-02,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,Aruba,,Aruba__nan,2020-01-03,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,Aruba,,Aruba__nan,2020-01-04,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,Aruba,,Aruba__nan,2020-01-05,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
87064,Zimbabwe,,Zimbabwe__nan,2020-07-27,78.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87065,Zimbabwe,,Zimbabwe__nan,2020-07-28,192.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87066,Zimbabwe,,Zimbabwe__nan,2020-07-29,113.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0
87067,Zimbabwe,,Zimbabwe__nan,2020-07-30,62.0,3.0,1.0,2.0,3.0,1.0,2.0,2.0,4.0,2.0,1.0,1.0,4.0


In [13]:
df['CountryName'].unique()

array(['Aruba', 'Afghanistan', 'Angola', 'Albania', 'Andorra',
       'United Arab Emirates', 'Argentina', 'Australia', 'Austria',
       'Azerbaijan', 'Burundi', 'Belgium', 'Benin', 'Burkina Faso',
       'Bangladesh', 'Bulgaria', 'Bahrain', 'Bahamas',
       'Bosnia and Herzegovina', 'Belarus', 'Belize', 'Bermuda',
       'Bolivia', 'Brazil', 'Barbados', 'Brunei', 'Bhutan', 'Botswana',
       'Central African Republic', 'Canada', 'Switzerland', 'Chile',
       'China', "Cote d'Ivoire", 'Cameroon',
       'Democratic Republic of Congo', 'Congo', 'Colombia', 'Comoros',
       'Cape Verde', 'Costa Rica', 'Cuba', 'Cyprus', 'Czech Republic',
       'Germany', 'Djibouti', 'Dominica', 'Denmark', 'Dominican Republic',
       'Algeria', 'Ecuador', 'Egypt', 'Eritrea', 'Spain', 'Estonia',
       'Ethiopia', 'Finland', 'Fiji', 'France', 'Faeroe Islands', 'Gabon',
       'United Kingdom', 'Georgia', 'Ghana', 'Guinea', 'Gambia', 'Greece',
       'Greenland', 'Guatemala', 'Guam', 'Guyana', 'Hong Ko

In [14]:
# Set number of past days to use to make predictions
nb_lookback_days = 30

# Create training data across all countries for predicting one day ahead
X_cols = cases_col + npi_cols
y_col = cases_col
X_samples = []
y_samples = []
geo_ids = df.GeoID.unique()
for g in geo_ids:
    gdf = df[df.GeoID == g]
    all_case_data = np.array(gdf[cases_col])
    all_npi_data = np.array(gdf[npi_cols])

    # Create one sample for each day where we have enough data
    # Each sample consists of cases and npis for previous nb_lookback_days
    nb_total_days = len(gdf)
    for d in range(nb_lookback_days, nb_total_days - 1):
        X_cases = all_case_data[d-nb_lookback_days:d]

        # Take negative of npis to support positive
        # weight constraint in Lasso.
        X_npis = -all_npi_data[d - nb_lookback_days:d]

        # Flatten all input data so it fits Lasso input format.
        X_sample = np.concatenate([X_cases.flatten(),
                                   X_npis.flatten()])
        y_sample = all_case_data[d + 1]
        X_samples.append(X_sample)
        y_samples.append(y_sample)

X_samples = np.array(X_samples)
y_samples = np.array(y_samples).flatten()

In [15]:
# Helpful function to compute mae
def mae(pred, true):
    return np.mean(np.abs(pred - true))

In [16]:
# Split data into train and test sets
X_train, X_test, y_train, y_test = train_test_split(X_samples,
                                                    y_samples,
                                                    test_size=0.2,
                                                    random_state=301)

In [17]:
# # Create and train Lasso model.
# # Set positive=True to enforce assumption that cases are positively correlated
# # with future cases and npis are negatively correlated.
# model = Lasso(alpha=0.1,
#               precompute=True,
#               max_iter=10000,
#               positive=True,
#               selection='random')
# # Fit model
# model.fit(X_train, y_train)

In [18]:
from sklearn.linear_model import SGDRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

model = make_pipeline(StandardScaler(), SGDRegressor(max_iter=1000, tol=1e-3))
model.fit(X_train, y_train)

Pipeline(steps=[('standardscaler', StandardScaler()),
                ('sgdregressor', SGDRegressor())])

In [19]:
# Evaluate model
train_preds = model.predict(X_train)
train_preds = np.maximum(train_preds, 0) # Don't predict negative cases
print('Train MAE:', mae(train_preds, y_train))

test_preds = model.predict(X_test)
test_preds = np.maximum(test_preds, 0) # Don't predict negative cases
print('Test MAE:', mae(test_preds, y_test))

Train MAE: 361.647883786552
Test MAE: 382.09622578029956


In [20]:
# # Inspect the learned feature coefficients for the model
# # to see what features it's paying attention to.

# # Give names to the features
# x_col_names = []
# for d in range(-nb_lookback_days, 0):
#     x_col_names.append('Day ' + str(d) + ' ' + cases_col[0])
# for d in range(-nb_lookback_days, 1):
#     for col_name in npi_cols:
#         x_col_names.append('Day ' + str(d) + ' ' + col_name)

# # View non-zero coefficients
# for (col, coeff) in zip(x_col_names, list(model.coef_)):
#     if coeff != 0.:
#         print(col, coeff)
# print('Intercept', model.intercept_)

In [21]:
# Save model to file
if not os.path.exists('models'):
    os.mkdir('models')
with open('models/model.pkl', 'wb') as model_file:
    pickle.dump(model, model_file)

## Evaluation

Now that the predictor has been trained and saved, this section contains the functionality for evaluating it on sample evaluation data.

In [22]:
# Reload the module to get the latest changes
import predict
from importlib import reload
reload(predict)
from predict import predict_df

In [23]:
list_countries = sorted(list(set(df.CountryName)))
hist_ips_df = pd.read_csv("data/2020-09-30_historical_ip.csv",
                              parse_dates=['Date'],
                              encoding="ISO-8859-1",
                              dtype={"RegionName": str},
                              error_bad_lines=True)
hist_ips_df = hist_ips_df[hist_ips_df.CountryName.isin(list_countries)]
hist_ips_df.to_csv("data/2020-09-30_historical_ip_new.csv" , index = False) 

In [None]:
%%time
preds_df = predict_df("2020-08-01", "2020-08-31", path_to_ips_file="data/2020-09-30_historical_ip_new.csv", verbose=True)


Predicting for Aruba__nan
2020-08-01: 930.2627582369983
2020-08-02: 837.6230469813972
2020-08-03: 1648.9951498447485
2020-08-04: 1291.3062976736046
2020-08-05: 1737.2755242505261
2020-08-06: 2913.497703760582
2020-08-07: 2014.932134812107
2020-08-08: 4599.38792518506
2020-08-09: 3272.3383932611023
2020-08-10: 5221.969334477735
2020-08-11: 4753.723032911489
2020-08-12: 6340.460698774552
2020-08-13: 6141.763893346028
2020-08-14: 8205.205486820396
2020-08-15: 7569.844776543533
2020-08-16: 11028.890474102516
2020-08-17: 8611.518217493862
2020-08-18: 12913.925904929138
2020-08-19: 11713.0692582801
2020-08-20: 13701.283155741196
2020-08-21: 16035.83589399966
2020-08-22: 16765.04646864475
2020-08-23: 18421.480219803387
2020-08-24: 20930.208754123018
2020-08-25: 21500.53239484783
2020-08-26: 24695.537088067093
2020-08-27: 27635.736347271966
2020-08-28: 27568.67236841266
2020-08-29: 35683.8201273596
2020-08-30: 32279.74151497898
2020-08-31: 41661.68846296061

Predicting for Afghanistan__nan
20

2020-08-28: 74024.55328130691
2020-08-29: 80268.9562419204
2020-08-30: 88697.88667773732
2020-08-31: 96200.41630559022

Predicting for Burundi__nan
2020-08-01: 922.3095242569884
2020-08-02: 868.2757617697871
2020-08-03: 1589.5887685794203
2020-08-04: 1427.1893047225717
2020-08-05: 2041.7929381530846
2020-08-06: 2551.638838227487
2020-08-07: 2713.23221494491
2020-08-08: 4029.849742013482
2020-08-09: 3874.95844631848
2020-08-10: 5088.08623116754
2020-08-11: 5218.599370136873
2020-08-12: 6092.814516137273
2020-08-13: 6991.318131817169
2020-08-14: 7558.444405233465
2020-08-15: 8926.097420780323
2020-08-16: 9601.917205865911
2020-08-17: 10275.790561136608
2020-08-18: 11708.621122269997
2020-08-19: 12405.966850463621
2020-08-20: 13575.436881454221
2020-08-21: 15466.91558608594
2020-08-22: 16329.176154977024
2020-08-23: 18039.661332815274
2020-08-24: 19567.998925653123
2020-08-25: 20693.424230349963
2020-08-26: 23667.95080369925
2020-08-27: 24541.578072735872
2020-08-28: 27916.4185762845
2020

2020-08-26: 27275.040495460566
2020-08-27: 26209.902437358036
2020-08-28: 32513.83686621752
2020-08-29: 33074.991804676276
2020-08-30: 37322.55446427177
2020-08-31: 40916.13458238477

Predicting for Bolivia__nan
2020-08-01: 1089.83550599847
2020-08-02: 1127.903007912413
2020-08-03: 1695.528551811481
2020-08-04: 1533.2189227051304
2020-08-05: 2093.730520078559
2020-08-06: 2701.5247277389526
2020-08-07: 2988.7298449546215
2020-08-08: 4053.576417858274
2020-08-09: 4366.293887505509
2020-08-10: 5003.488539697529
2020-08-11: 5718.830399808792
2020-08-12: 5913.30708552141
2020-08-13: 7663.995420207042
2020-08-14: 7740.272703512027
2020-08-15: 9826.91404067752
2020-08-16: 9876.140451155026
2020-08-17: 11287.337815304336
2020-08-18: 12459.97361441019
2020-08-19: 13825.584551220709
2020-08-20: 15384.941272546994
2020-08-21: 17404.89350888141
2020-08-22: 18934.540881874178
2020-08-23: 20866.136896116255
2020-08-24: 23146.51974447266
2020-08-25: 24553.866500210122
2020-08-26: 28711.82543660259
20

2020-08-01: 937.573295931549
2020-08-02: 849.617790270912
2020-08-03: 1598.1426816072953
2020-08-04: 1391.81815354228
2020-08-05: 1819.1217012292084
2020-08-06: 2863.2085194720084
2020-08-07: 2385.4472696922794
2020-08-08: 4253.557176720535
2020-08-09: 3771.9074603844374
2020-08-10: 4932.9439231637125
2020-08-11: 5398.476356783102
2020-08-12: 5918.252327152743
2020-08-13: 7342.357131780546
2020-08-14: 7430.483950743052
2020-08-15: 9459.309768370993
2020-08-16: 9612.578522751803
2020-08-17: 11078.403282308675
2020-08-18: 12186.091237303139
2020-08-19: 13926.683255784337
2020-08-20: 15175.490624439002
2020-08-21: 17149.032779295885
2020-08-22: 18798.206856510427
2020-08-23: 20729.06679299568
2020-08-24: 22897.529332023227
2020-08-25: 24942.529604731353
2020-08-26: 28166.038913125394
2020-08-27: 29998.4136862139
2020-08-28: 34247.7462483496
2020-08-29: 36639.933116614695
2020-08-30: 40723.73549467092
2020-08-31: 44486.67032347545

Predicting for Cote d'Ivoire__nan
2020-08-01: 945.77620296

In [None]:
# Check the predictions
preds_df.head()

# Validation
This is how the predictor is going to be called during the competition.  
!!! PLEASE DO NOT CHANGE THE API !!!

In [None]:
!python3 predict.py -s 2020-08-01 -e 2020-08-04 -ip data/2020-09-30_historical_ip_new.csv -o predictions/2020-08-01_2020-08-04.csv

In [None]:
!head predictions/2020-08-01_2020-08-04.csv

# Test cases
We can generate a prediction file. Let's validate a few cases...

In [None]:
import os
from predictor_validation import validate_submission

def validate(start_date, end_date, ip_file, output_file):
    # First, delete any potential old file
    try:
        os.remove(output_file)
    except OSError:
        pass
    
    # Then generate the prediction, calling the official API
    !python3 predict.py -s {start_date} -e {end_date} -ip {ip_file} -o {output_file}
    
    # And validate it
    errors = validate_submission(start_date, end_date, ip_file, output_file)
    if errors:
        for error in errors:
            print(error)
    else:
        print("All good!")

## 4 days, no gap
- All countries and regions
- Official number of cases is known up to start_date
- Intervention Plans are the official ones

In [None]:
validate(start_date="2020-08-01",
         end_date="2020-08-04",
         ip_file="data/2020-09-30_historical_ip_new.csv",
         output_file="predictions/val_4_days.csv")

## 1 month in the future
- 2 countries only
- there's a gap between date of last known number of cases and start_date
- For future dates, Intervention Plans contains scenarios for which predictions are requested to answer the question: what will happen if we apply these plans?

In [None]:
%%time
validate(start_date="2021-01-01",
         end_date="2021-01-31",
         ip_file="validation/data/future_ip.csv",
         output_file="predictions/val_1_month_future.csv")

## 180 days, from a future date, all countries and regions
- Prediction start date is 1 week from now. (i.e. assuming submission date is 1 week from now)  
- Prediction end date is 6 months after start date.  
- Prediction is requested for all available countries and regions.  
- Intervention plan scenario: freeze last known intervention plans for each country and region.  

As the number of cases is not known yet between today and start date, but the model relies on them, the model has to predict them in order to use them.  
This test is the most demanding test. It should take less than 1 hour to generate the prediction file.

### Generate the scenario

In [None]:
from datetime import datetime, timedelta

start_date = datetime.now() + timedelta(days=7)
start_date_str = start_date.strftime('%Y-%m-%d')
end_date = start_date + timedelta(days=180)
end_date_str = end_date.strftime('%Y-%m-%d')
print(f"Start date: {start_date_str}")
print(f"End date: {end_date_str}")

In [None]:
from validation.scenario_generator import get_raw_data, generate_scenario, NPI_COLUMNS
DATA_FILE = 'data/OxCGRT_latest.csv'
latest_df = get_raw_data(DATA_FILE, latest=True)
scenario_df = generate_scenario(start_date_str, end_date_str, latest_df, countries=None, scenario="Freeze")
scenario_file = "predictions/180_days_future_scenario.csv"
scenario_df.to_csv(scenario_file, index=False)
print(f"Saved scenario to {scenario_file}")

### Check it

In [None]:
%%time
validate(start_date=start_date_str,
         end_date=end_date_str,
         ip_file=scenario_file,
         output_file="predictions/val_6_month_future.csv")

## SPAIN

In [None]:
start_date = datetime.now() + timedelta(days=7)
start_date_str = start_date.strftime('%Y-%m-%d')
end_date = start_date + timedelta(days=180)
end_date_str = end_date.strftime('%Y-%m-%d')
print(f"Start date: {start_date_str}")
print(f"End date: {end_date_str}")

DATA_FILE = 'data/OxCGRT_latest.csv'
latest_df = get_raw_data(DATA_FILE, latest=True)
scenario_df = generate_scenario(start_date_str, end_date_str, latest_df, countries=['France'], scenario="Freeze")
scenario_file = "predictions/180_days_future_scenario_france_freeze.csv"
scenario_df.to_csv(scenario_file, index=False)
print(f"Saved scenario to {scenario_file}")

validate(start_date=start_date_str,
         end_date=end_date_str,
         ip_file=scenario_file,
         output_file="predictions/val_6_month_future_france_freeze.csv")

In [None]:
df_verify = pd.read_csv('predictions/val_6_month_future_spain_max.csv')
x = df_verify['Date'].to_numpy()
y = df_verify['PredictedDailyNewCases'].to_numpy()

%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(x,y)