Copyright 2021 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
!git clone https://github.com/google-research/google-research.git

import sys
sys.path.append('./google-research')

In [None]:
# Verify credentials for BQ.

from colabtools import auth
from colabtools import bigquery

from colabtools import adhoc_import
from covid_epidemiology import colab_utils

# Obtain the user's OAuth2 credentials for the BigQuery scopes, then create
# the BigQuery client (in the model-training project) used by every query in
# this notebook.
creds = auth.get_user_oauth2_credentials(bigquery.SCOPES)
client = bigquery.Client(
    credentials=creds,
    project=colab_utils.constants.PROJECT_ID_MODEL_TRAINING,
)

## Utils

### Prospective utils

In [None]:
import datetime

# Parse the prospective evaluation window from the project constants
# ('YYYY-MM-DD' strings) into datetime.date objects.
prospective_start, prospective_end = (
    datetime.datetime.strptime(date_str, '%Y-%m-%d').date()
    for date_str in (colab_utils.constants.PROSPECTIVE_START_DATE,
                     colab_utils.constants.PROSPECTIVE_END_DATE))

import pandas as pd
import numpy as np
import time
def gather_prospective_data(
    forecast_pd, pred_feature_key, gt_feature_keys, loc_key, locale,
    use_latest_gt, debug=False):
  """Pairs each prospective forecast with its ground truth (GT).

  Rows are grouped by (forecast_date, location) and each group is collapsed by
  `colab_utils.gather_data_from_prospective_row` (presumably into a dict of
  {dates, preds, gts} -- confirm against colab_utils).

  Args:
    forecast_pd: DataFrame of historical forecasts with at least the columns
      `forecast_date`, `prediction_date`, `loc_key` and `pred_feature_key`
      (plus `prefecture_name` when `locale == 'japan'`). The caller's
      DataFrame is not modified.
    pred_feature_key: Prediction column to evaluate, e.g.
      'cumulative_confirmed'. It is renamed to 'predictions' before gathering.
    gt_feature_keys: List of GT feature names passed to
      `colab_utils.get_all_gt`.
    loc_key: Location column; renamed to 'location_name' in the result.
    locale: Either 'japan' or 'state'.
    use_latest_gt: Forwarded to `colab_utils.get_gt_version_names`.
    debug: If True, prints the forecast-date -> GT-version map; also forwarded
      to `colab_utils.gather_data_from_prospective_row`.

  Returns:
    DataFrame with a `forecast_date` column (str), a `location_name` column
    and the gathered value column.
  """
  assert isinstance(gt_feature_keys, list)
  assert locale in ('japan', 'state')
  if locale == 'japan':
    # Sanity check that we're using ex "Hyogo" and not "Hyōgo".
    assert set(colab_utils.kaz_locations_to_open_covid_map().keys()).issubset(
        set(forecast_pd.prefecture_name.unique()))

  # Only keep true forecasts (prediction date strictly after forecast date) and
  # the columns we need. `.copy()` makes this an independent frame, so the
  # column assignment and rename below neither raise SettingWithCopyWarning
  # nor can touch the caller's data.
  forecast_pd = forecast_pd[
      forecast_pd.prediction_date > forecast_pd.forecast_date]
  forecast_pd = forecast_pd[
      ['forecast_date', 'prediction_date', loc_key, pred_feature_key]].copy()
  assert forecast_pd[pred_feature_key].notna().all()

  # Attach the GT version to use for each forecast date.
  min_version = {
      'japan': '2020-10-07 20:01:32 UTC',
      'state': None,
  }[locale]
  forecast_dates = list(forecast_pd.forecast_date.unique())
  gt_versions = colab_utils.get_gt_version_names(
      forecast_dates, locale, min_version, use_latest_gt=use_latest_gt,
      client=client)
  forecast_to_version_map = dict(zip(forecast_dates, gt_versions))
  forecast_pd['gt_version'] = forecast_pd.forecast_date.replace(
      forecast_to_version_map)
  print('Got all GT versions.')
  if debug:
    for k, v in forecast_to_version_map.items():
      print(f'{k}: {v}')

  # Collect every GT version that must be read.
  gt_versions_to_use = forecast_pd.gt_version.unique().tolist()
  if locale == 'japan' and not use_latest_gt:
    # Add versions for the 9 days after the last forecast date as a buffer so
    # automatic increment can work.
    gt_versions_to_use += colab_utils.get_gt_version_names(
        [max(forecast_dates) + datetime.timedelta(days=i) for i in range(1, 10)],
        locale,
        min_version,
        use_latest_gt=use_latest_gt,
        client=client)
  gt_df = colab_utils.get_all_gt(
      # Get the GT for "day 0" in case we want to compute incident cases/deaths,
      # which is always computed as a delta from GT at day 0. So we start 2 days
      # before the first prediction date, instead of 1 day.
      min(forecast_pd.prediction_date) - datetime.timedelta(days=2),
      # NOTE(review): no date buffer is added past the last prediction date;
      # only the Japan GT *version* list above is padded. Confirm a date
      # buffer is not needed for auto-increment.
      max(forecast_pd.prediction_date),
      locale=locale,
      bq_client=client,
      version=gt_versions_to_use,
      feature_keys=gt_feature_keys)
  print('Finished reading GT.')

  # Check that the assigned versions are a subset of the available versions.
  available_versions = gt_df.version.dt.strftime('%Y-%m-%d %H:%M:%S+00:00')
  assert forecast_pd.gt_version.isin(available_versions).all()

  # Get predictions and GT in one row per (forecast_date, location) group.
  start_time = time.time()
  forecast_pd = forecast_pd.rename(columns={pred_feature_key: 'predictions'})
  gathered_data_df = forecast_pd.groupby(['forecast_date', loc_key]).apply(
      colab_utils.gather_data_from_prospective_row,
      gt_df=gt_df,
      locale=locale,
      available_versions=np.unique(available_versions.values),
      debug=debug)
  gathered_data_df = gathered_data_df.to_frame().reset_index().convert_dtypes()
  gathered_data_df['forecast_date'] = gathered_data_df.forecast_date.astype(str)
  gathered_data_df = gathered_data_df.rename(
      columns={loc_key: 'location_name'})
  total_time_min = (time.time() - start_time) / 60.0
  print(f'Total time: {total_time_min:.2f} min')
  # Each row is one (forecast_date, location) item, not one date.
  print(f'Time per item: {total_time_min / len(gathered_data_df):.2f} min / item')

  return gathered_data_df

### Plotting and filtering utils

In [None]:
%matplotlib inline
import matplotlib
from matplotlib import pyplot as plt

def plot_events(events, d_min, max_y, line_top=0.5):
  """Marks each event with a red vertical line; labels those after `d_min`.

  Args:
    events: Iterable of (date, name, ybuff, xbuff) tuples. `xbuff` shifts the
      label's x position by that many days; `ybuff` is added to `line_top`
      before scaling by `max_y` to get the label's y position.
    d_min: Only events strictly after this date receive a text annotation.
    max_y: y value used to scale `line_top` and `ybuff` into data units.
    line_top: Top of each vertical line, passed as `ymax` to `plt.axvline`.
  """
  for when, label, y_offset, x_offset_days in events:
    plt.axvline(when, ymin=0.0, ymax=line_top, color='r')
    if not when > d_min:
      continue
    label_x = when + datetime.timedelta(days=x_offset_days)
    label_y = (line_top + y_offset) * max_y
    plt.annotate(
        label,
        xy=(when, line_top * max_y),
        fontsize=16,
        xytext=(label_x, label_y),
        arrowprops={'arrowstyle': '->'},
    )

def calculate_mape(base_pd, calculate_mape_apply_fn_args):
  """Computes one MAPE value per forecast date.

  Args:
    base_pd: DataFrame with a `forecast_date` column of 'YYYY-MM-DD' strings.
    calculate_mape_apply_fn_args: Keyword arguments forwarded to
      `colab_utils.calculate_mape_apply_fn`.

  Returns:
    (dates, values): the forecast dates as `datetime.date` objects and the
    matching MAPE values, with NaN results dropped.
  """
  per_date_mape = (
      base_pd
      .groupby('forecast_date')
      .apply(colab_utils.calculate_mape_apply_fn,
             **calculate_mape_apply_fn_args)
      .dropna())
  dates = [
      datetime.datetime.strptime(date_str, '%Y-%m-%d').date()
      for date_str in per_date_mape.index
  ]
  return dates, per_date_mape.values

def write_to_cns(xs, ys, name):
  """Writes `xs` as 'forecast_date' and `ys` as 'pred' via write_csv_to_cns."""
  columns = {'forecast_date': xs, 'pred': ys}
  colab_utils.write_csv_to_cns(graph_name=name, data_dict=columns)
  
def print_mape_stats(avg_type, locale, prosp_confirmed_ys, prosp_deaths_ys):
  """Prints the mean MAPE and its 95% confidence interval for each metric.

  Each line has the form
  `{avg_type}_{locale}_prospective_{metric}_MAPE: mean [lower, upper]`.

  Args:
    avg_type: Free-form label for the averaging method, e.g. 'Micro average'.
    locale: Either 'Japan' or 'US State'.
    prosp_confirmed_ys: Per-forecast-date MAPE values for confirmed cases, or
      None to skip this metric.
    prosp_deaths_ys: Per-forecast-date MAPE values for deaths, or None to skip.
  """
  assert locale in ['Japan', 'US State']

  def _print(dat, window, metric, ignore_nan):
    if dat is None:
      return
    m, l, u = colab_utils.mean_confidence_interval(
        dat, confidence=0.95, ignore_nan=ignore_nan)
    # Sanity-check the reported mean with a tolerance: the two means may come
    # from different code paths, so exact float equality is fragile.
    # np.isclose(nan, x) is False, so a NaN mean still fails this check.
    assert np.isclose(m, np.nanmean(dat))
    print(f'{avg_type}_{locale}_{window}_{metric}_MAPE: {m:.2f} [{l:.2f}, {u:.2f}]')

  _print(prosp_confirmed_ys, 'prospective', 'confirmed', ignore_nan=False)
  _print(prosp_deaths_ys, 'prospective', 'deaths', ignore_nan=False)

## Plots

### Japan

#### Load prospective data.

In [None]:
# Load prospective Japan data.
# Only fetch forecasts whose forecast_date lies in the (inclusive) prospective
# window, i.e. the period we evaluate against ground truth.
bq_table = 'bigquery-public-data.covid19_public_forecasts.japan_prefecture_28d_historical'
# The f-strings are concatenated directly, so keep the trailing space after
# the table name; otherwise the SQL reads "...`table`where ...".
q = (f"select * from `{bq_table}` "
     f"where forecast_date >= '{prospective_start.isoformat()}' "
     f"and forecast_date <= '{prospective_end.isoformat()}'")
jp_forecast_pd = client.query(q).to_dataframe()

jp_prosp_confirmed_df = gather_prospective_data(
    jp_forecast_pd,
    pred_feature_key='cumulative_confirmed',
    gt_feature_keys=['kaz_confirmed_cases', 'open_gt_jp_confirmed_cases'],
    loc_key='prefecture_name',
    locale='japan',
    use_latest_gt=False)
jp_prosp_deaths_df = gather_prospective_data(
    jp_forecast_pd,
    pred_feature_key='cumulative_deaths',
    gt_feature_keys=['kaz_deaths', 'open_gt_jp_deaths'],
    loc_key='prefecture_name',
    locale='japan',
    use_latest_gt=False)

#### Plot micro average.

In [None]:
# Micro average over the 47 expected prefectures, using 4-week values.
calculate_mape_apply_fn_args = dict(
    average_type='micro',
    expected_num_locations=47,
    min_count=None,
    min_mae=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(
    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)
jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(
    jp_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
plt.title('MAPE vs training date, micro average')
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the micro-averaged Japan series.
print_mape_stats(
    locale='Japan',
    avg_type='Micro average',
    prosp_deaths_ys=jp_prosp_deaths_ys,
    prosp_confirmed_ys=jp_prosp_confirmed_ys)

#### Plot macro average.

In [None]:
# Macro average over the 47 expected prefectures, using 4-week values.
calculate_mape_apply_fn_args = dict(
    average_type='macro',
    expected_num_locations=47,
    min_count=None,
    min_mae=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(
    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)
jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(
    jp_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
plt.title('MAPE vs training date, macro average')
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the macro-averaged Japan series.
print_mape_stats(
    locale='Japan',
    avg_type='Macro average',
    prosp_deaths_ys=jp_prosp_deaths_ys,
    prosp_confirmed_ys=jp_prosp_confirmed_ys)

#### Plot macro average, with min count.

In [None]:
# Macro average with a per-metric minimum count: 1000 for confirmed cases and
# 10 for deaths (set just before each calculate_mape call below).
calculate_mape_apply_fn_args = dict(
    average_type='macro',
    expected_num_locations=47,
    min_mae=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
calculate_mape_apply_fn_args['min_count'] = 1000
jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(
    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)
calculate_mape_apply_fn_args['min_count'] = 10
jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(
    jp_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
plt.title('MAPE vs training date, macro average (>=1K cases or 10 deaths)')
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.legend()
plt.grid()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the count-thresholded Japan series.
print_mape_stats(
    locale='Japan',
    avg_type='Macro average, with min count',
    prosp_deaths_ys=jp_prosp_deaths_ys,
    prosp_confirmed_ys=jp_prosp_confirmed_ys)

#### Plot macro average, with min mae.

In [None]:
# Macro average restricted by a minimum MAE of 10.
# NOTE(review): value_type is 'daily' here, while every other configuration in
# this notebook uses '4week' -- confirm this is intended.
calculate_mape_apply_fn_args = dict(
    average_type='macro',
    expected_num_locations=47,
    min_mae=10.0,
    min_count=None,
    value_type='daily',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
jp_prosp_confirmed_xs, jp_prosp_confirmed_ys = calculate_mape(
    jp_prosp_confirmed_df, calculate_mape_apply_fn_args)
jp_prosp_deaths_xs, jp_prosp_deaths_ys = calculate_mape(
    jp_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(jp_prosp_confirmed_xs, jp_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(jp_prosp_deaths_xs, jp_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
min_mae_threshold_title = (
    f'MAPE vs training date (>= {calculate_mape_apply_fn_args["min_mae"]} MAE)')
plt.title(min_mae_threshold_title)
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the MAE-thresholded Japan series.
print_mape_stats(
    locale='Japan',
    avg_type='Macro average, with mae threshold',
    prosp_deaths_ys=jp_prosp_deaths_ys,
    prosp_confirmed_ys=jp_prosp_confirmed_ys)

### US

#### Load prospective data.

In [None]:
# Load data from BQ.
# Only fetch forecasts whose forecast_date lies in the (inclusive) prospective
# window, i.e. the period we evaluate against ground truth.
bq_table = 'bigquery-public-data.covid19_public_forecasts.state_28d_historical'
# The f-strings are concatenated directly, so keep the trailing space after
# the table name; otherwise the SQL reads "...`table`where ...".
q = (f"select * from `{bq_table}` "
     f"where forecast_date >= '{prospective_start.isoformat()}' "
     f"and forecast_date <= '{prospective_end.isoformat()}'")
us_forecast_pd = client.query(q).to_dataframe()

# The US has double Hawaii, which we should remove: keep only the first row
# for each (fips, name, prediction_date, forecast_date) key.
with_double_hawaii = len(us_forecast_pd)
us_forecast_pd_dedupped = us_forecast_pd.drop_duplicates(
    subset=['state_fips_code', 'state_name', 'prediction_date', 'forecast_date'])
dups_dropped = with_double_hawaii - len(us_forecast_pd_dedupped)
print(f'dropped {dups_dropped} rows...')

us_prosp_confirmed_df = gather_prospective_data(
    us_forecast_pd_dedupped,
    pred_feature_key='cumulative_confirmed',
    gt_feature_keys=['jhu_state_confirmed_cases'],
    loc_key='state_name',
    locale='state',
    use_latest_gt=False)
us_prosp_deaths_df = gather_prospective_data(
    us_forecast_pd_dedupped,
    pred_feature_key='cumulative_deaths',
    gt_feature_keys=['jhu_state_deaths'],
    loc_key='state_name',
    locale='state',
    use_latest_gt=False)

#### Plot micro average.


In [None]:
# Micro average over the 51 expected US locations, using 4-week values.
calculate_mape_apply_fn_args = dict(
    average_type='micro',
    expected_num_locations=51,
    min_mae=None,
    min_count=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(
    us_prosp_confirmed_df, calculate_mape_apply_fn_args)
us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(
    us_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(us_prosp_deaths_xs, us_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
plt.title('MAPE vs training date, Micro average')
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the micro-averaged US series.
print_mape_stats(
    locale='US State',
    avg_type='Micro average',
    prosp_deaths_ys=us_prosp_deaths_ys,
    prosp_confirmed_ys=us_prosp_confirmed_ys)

#### Plot macro average.

In [None]:
# Macro average over the 51 expected US locations, using 4-week values.
calculate_mape_apply_fn_args = dict(
    average_type='macro',
    expected_num_locations=51,
    min_mae=None,
    min_count=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(
    us_prosp_confirmed_df, calculate_mape_apply_fn_args)
us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(
    us_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(us_prosp_deaths_xs, us_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
plt.title('MAPE vs training date, macro average')
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the macro-averaged US series.
print_mape_stats(
    locale='US State',
    avg_type='Macro average',
    prosp_deaths_ys=us_prosp_deaths_ys,
    prosp_confirmed_ys=us_prosp_confirmed_ys)

#### Plot macro average, with min mae.

In [None]:
# Macro average restricted by a minimum MAE of 100, using 4-week values.
calculate_mape_apply_fn_args = dict(
    average_type='macro',
    expected_num_locations=51,
    min_mae=100.0,
    min_count=None,
    value_type='4week',
)

# Prospective MAPE series (x: forecast date, y: MAPE).
us_prosp_confirmed_xs, us_prosp_confirmed_ys = calculate_mape(
    us_prosp_confirmed_df, calculate_mape_apply_fn_args)
us_prosp_deaths_xs, us_prosp_deaths_ys = calculate_mape(
    us_prosp_deaths_df, calculate_mape_apply_fn_args)

plt.figure(figsize=(16, 6))
plt.plot(us_prosp_confirmed_xs, us_prosp_confirmed_ys,
         marker='o', color='b', label='confirmed, prospective')
plt.plot(us_prosp_deaths_xs, us_prosp_deaths_ys,
         marker='o', color='g', label='deaths, prospective')

# Titles, axes and legend.
min_mae_threshold_title = (
    f'MAPE vs training date (>= {calculate_mape_apply_fn_args["min_mae"]} MAE)')
plt.title(min_mae_threshold_title)
plt.xlabel('train date')
plt.ylabel('MAPE')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Mean MAPE with 95% CI for the MAE-thresholded US series.
print_mape_stats(
    locale='US State',
    avg_type='Macro average, with min mae threshold',
    prosp_deaths_ys=us_prosp_deaths_ys,
    prosp_confirmed_ys=us_prosp_confirmed_ys)