In [None]:
from datetime import date, datetime, timedelta
from timeit import default_timer as timer
import os, io, re, pandas as pd, numpy as np
import logging
from ibmpairs import paw
import scipy.stats
from scipy.stats import spearmanr
import requests
import json
import yaml
import geopandas as gpd
from itertools import cycle
from IPython.core.display import HTML
# library used for visualization
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
# package supporting GAM implementation in R
import packages.gaa.gam as gamModel
import packages.gaa.analysis as analysisHelper

# COVID-19 Geospatial Correlation & Association
This Notebook can be used to determine the [**Spearman's rank correlation coefficient**](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient) or a [**log-linear generalized additive model (GAM)**](https://en.wikipedia.org/wiki/Generalized_additive_model) between COVID-19 cases and Geospatial & Temporal information out of [**IBM PAIRS Geoscope**](https://ibmpairs.mybluemix.net/).

The Notebook will run through 10 steps.

If you just want to run the Notebook you need to provide input for **Step 1** & **Step 2**. You can then run the analysis for the included countries which are **The Netherlands (NL), France (FR), Denmark (DK), Sweden (SE), Germany (DE), United States (US)** & **India (IN)**. Steps 3 up to 10 will run without additional input, but can be altered where necessary.

If you want to extend the Notebook with additional countries you need to extend **Step 4** & **Step 5** following the examples of the already included countries.

Please note that for DK & SE you need to download the COVID-19 files manually. The sources are specified in **Step 4**. The other countries download the COVID-19 data directly from a URL.

### Steps

1. **Create User Config File**: an IBM PAIRS User Account with API access is needed. These credentials must be stored in a private configuration file.

 NB: If no account is available the Notebook can still be used. For this, do not create the config file as defined in Step 1. The Notebook will continue but only with IBM PAIRS Local Cached data. As example a cached NL dataset for TemperatureAboveGround, RelativeHumiditySurface, WindSpeed, UVIndex is included.
 
2. **Define Analysis**: the Notebook must be configured to run the analysis for the desired Country, Time & Variables

3. **Set Global Variables**: set global variables that are used in the remainder of the Notebook

4. **Get Country Geospatial Data**: collect metadata of the country (e.g. regions & population) and geospatial vector data

5. **Get Country COVID-19 Data Set**: collect COVID-19 case data (e.g. hospitalized, recovered, deceased patients)

6. **Get IBM PAIRS Geospatial & Temporal data**: collect the geospatial information needed for the analysis

7. **Add Rolling Windows & Time Shifts**: add rolling windows & time shifts to model e.g. incubation time
8. **Merge COVID-19 & IBM PAIRS datasets**: merge the data on COVID-19 with the geospatial dataset
9. **Determine Spearman or GAM**: calculate the spearman correlation coefficient or GAM association and the significance of it
10. **Create Visualizations**: create various plots (e.g. line charts, SPLOMs, Choropleths) to visualize the results of the input & analysis result

## 1. Create User Config File
To use this Notebook you need to create a YAML file **private-user-config.yml** in the root folder (one level above the folder that contains this Notebook). This file holds the access credentials to IBM PAIRS. The YAML file needs the following structure:
```
    ibm_pairs:
        server: https://pairs.res.ibm.com
        user:
        api_key: 
```
If you don't have access to the IBM PAIRS API, then you can request access via the [IBM PAIRS Access procedure](../IBMPAIRS-Access.md).

If you want to use the Notebook without access to IBM PAIRS you can. However, you are restricted to the IBM PAIRS Cache File that is included in this GitHub Repo for NL.

In [None]:
BASE_FILE_PATH          = os.path.realpath("../")
USER_CONFIG             = None
USER_CONFIG_FILE        = 'private-user-config.yml'

# Load the private IBM PAIRS credentials. A missing, unreadable or malformed file is not fatal:
# USER_CONFIG stays None and the Notebook continues with the cached IBM PAIRS data only.
# The two failure modes are reported separately so a YAML syntax error is not shown as "Cannot find".
try:
    with open(BASE_FILE_PATH + "/" + USER_CONFIG_FILE, 'r') as user_config_file:
        USER_CONFIG = yaml.safe_load(user_config_file)
except OSError:
    print('Cannot find {}/{}'.format(BASE_FILE_PATH,USER_CONFIG_FILE))
    print('\nPlease create this file or continue while using cached IBM PAIRS only.')
except yaml.YAMLError as yaml_error:
    print('Cannot parse {}/{}: {}'.format(BASE_FILE_PATH, USER_CONFIG_FILE, yaml_error))
    print('\nPlease fix this file or continue while using cached IBM PAIRS only.')

## 2. Define Analysis
The variable **ANALYSIS** controls what analysis is executed by running the Notebook. The 'analysis' key must be set to a value that is present in the **ANALYSIS_DEFINITIONS** dictionary.

Each **ANALYSIS_DEFINITION** must have a unique key and contains 3 values:
- *country_code*: ISO-3166 of the country for which the analysis is run. Seven countries are made available in this notebook (NL, FR, DK, SE, DE, US, IN) but more can be added.

    The country_code determines:
    1. the geospatial boundaries
    2. the COVID-19 datasource
    3. the data queried from IBM PAIRS
    
    
- *model*: MODEL_DEFINITIONS key, used to define the algorithm & parameters used for the analysis
- *time_window*: TIME_WINDOW_DEFINITIONS key, used to restrict the analysis to a specific time slice

Each **MODEL_DEFINITION** must have a unique key and contains the parameters appropriate for the given model. The current Notebook supports *spearman* and *gam*, additional models might be added in later versions.

The **spearman** model has 7 values:
- *model*: controls the algorithm to use, set to 'spearman' to run a spearman correlation model.
- *predictor*: first rank variable in the spearman algorithm. Must be an existing key in PAIRS_QUERY_DEFINITIONS. Only the first predictor specified is currently used in the analysis & visualization.
- *outcome*: second rank variable(s) in the spearman algorithm. The correlation between the *predictor* and all *outcome* variables specified is determined. The *outcome* variables refer to the columns available in the cleansed data source. Regular expressions can be used to specify e.g. wildcards.
- *alpha*: threshold to determine whether the correlation is deemed significant.
- *rolling_windows*: the number of days for which the rolling window is calculated. Multiple windows can be specified.
- *rolling_window_type*: the arithmetic operation applied for the rolling_window. Supported are 'mean' and 'sum'
- *time_shifts*: the number of days the predictor is shifted. The time shift can be used to model the impact of a time lag between the predictor and the outcome. Multiple time shifts can be specified.

The **gam** model has 8 values:
- *model*: controls the algorithm to use, set to 'gam' to run a log-linear GAM.
- *independent_variables*: predictor variables for which it is determined if there is an association with the dependent variable(s) specified in the *outcome* field. Must be an existing key in PAIRS_QUERY_DEFINITIONS.
- *control_variables*: [confounding variables](https://en.wikipedia.org/wiki/Confounding) which are variables, other than the independent variables in focus, that may affect the outcome and thus, may lead to erroneous conclusions about the relationship between the independent and outcome variables. Must be an existing key in PAIRS_QUERY_DEFINITIONS and / or DOW (Day of Week).
- *outcome*: dependent variable(s) in the GAM algorithm. The association between the *independent_variables* and all *outcome* variables specified is determined. The *outcome* variables refer to the columns available in the cleansed data source. Regular expressions can be used to specify e.g. wildcards.
- *alpha*: threshold to determine whether the association is deemed significant.
- *rolling_windows*: the number of days for which the rolling window is calculated. Multiple windows can be specified.
- *rolling_window_type*: the arithmetic operation applied for the rolling_window. Supported are 'mean' and 'sum'
- *time_shifts*: the number of days the predictor is shifted. The time shift can be used to model the impact of a time lag between the predictor and the outcome. Multiple time shifts can be specified.

Each **TIME_WINDOW_DEFINITION** must have a unique key and controls the time window for which the analysis is run. Each definition has 2 values:

- *window_start*: the first date to include in the analysis.
- *window_end*: the last date to include in the analysis.

A date can be fixed (e.g. date(2020, 3, 6)) or relative (e.g. date.today()) as long as it is a valid [Python datetime.date](https://docs.python.org/3.8/library/datetime.html).

NB: The cases in a given country are presumably heavily influenced by the time of introduction of the virus in the region and the measures taken to control the outbreak. Therefore one must carefully consider the influence of the Time Window on the observed correlation.

As an **example only**, the Notebook contains a definition for each region. The examples are based on the [Oxford COVID-19 Government Response Tracker](https://covidtracker.bsg.ox.ac.uk/). The *window_start* is set to the first day with 100+ cases and the *window_end* is set to 14 days after the Stringency Index is 70+ for the given country.

Each **PAIRS_QUERY_DEFINITION** must have a unique key and controls what data is queried from IBM PAIRS. The *predictor* value(s) determine the definition used in the given model. Each definition has 2 values:

- *layer_id*: ID of the layer in the IBM PAIRS Dataset that is used in the model (e.g. 49311 for the *UV Index* layer out of the *Current and historical weather (IBM TWC)* dataset)
- *aggregation*: the arithmetic operation applied for the temporal aggregation if there are multiple measurements for the same raster in IBM PAIRS. Supported are 'None', 'Min', 'Max', 'Mean', 'Sum'.


In [None]:
# Key of the ANALYSIS_DEFINITIONS entry that is executed when the Notebook is run
ANALYSIS                = {'analysis': 'GAM_ARTICLE_4_PLOT_1'}

# Every definition combines a country code with a MODEL_DEFINITIONS key
# and a TIME_WINDOW_DEFINITIONS key
ANALYSIS_DEFINITIONS    = {
    'ARTICLE_2_PLOT_1': {
        'country_code': 'NL',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'ARTICLE_2_PLOT_1',
    },
    'ARTICLE_2_PLOT_2': {
        'country_code': 'NL',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'ARTICLE_2_PLOT_2',
    },
    'NL_UVIndex_ToLockdown': {
        'country_code': 'NL',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'NL_Start->LockDown+14d',
    },
    'FR_UVIndex_ToLockdown': {
        'country_code': 'FR',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'FR_Start->LockDown+14d',
    },
    'DK_UVIndex_ToLockdown': {
        'country_code': 'DK',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'DK_Start->LockDown+14d',
    },
    'SE_UVIndex_Today': {
        'country_code': 'SE',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'SE_Start->Today',
    },
    'DE_UVIndex_ToLockdown': {
        'country_code': 'DE',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'DE_Start->LockDown+14d',
    },
    'US_UVIndex_Today': {
        'country_code': 'US',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'US_Start->Today',
    },
    'IN_UVIndex_Today': {
        'country_code': 'IN',
        'model':        'SpearmanUVIndex_AllOutcomeAvailable',
        'time_window':  'WHO_PandemicStart->Today',
    },
    'Spearman_ARTICLE_4_PLOT_1': {
        'country_code': 'NL',
        'model':        'SpearmanTemp_LimitOutcome',
        'time_window':  'ARTICLE_4_PLOT_1',
    },
    'GAM_ARTICLE_4_PLOT_1': {
        'country_code': 'NL',
        'model':        'Gam_UTR_WD_LimitOutcome',
        'time_window':  'ARTICLE_4_PLOT_1',
    },
}

# Algorithm & parameters per model. 'spearman' models use a 'predictor' list,
# 'gam' models use 'independent_variables' and 'control_variables' instead.
MODEL_DEFINITIONS    = {
    'SpearmanUVIndex_LimitOutcome': {
        'model': 'spearman',
        'predictor': ['UVIndex'],
        'outcome': [
            'hospitalized_addition_population_weighted',
            'deceased_addition_population_weighted',
        ],
        'alpha': 0.001,
        'rolling_windows': [7],
        'rolling_window_type': 'mean',
        'time_shifts': [0, 7, 14],
    },

    'SpearmanUVIndex_AllOutcomeAvailable': {
        'model': 'spearman',
        'predictor': ['UVIndex'],
        # regular expression: matches every population weighted addition metric
        'outcome': ['.+_addition_population_weighted'],
        'alpha': 0.001,
        'rolling_windows': [7],
        'rolling_window_type': 'sum',
        'time_shifts': [0, 7, 14],
    },

    'SpearmanRelativeHumiditySurface_LimitOutcome': {
        'model': 'spearman',
        'predictor': ['RelativeHumiditySurface'],
        'outcome': [
            'hospitalized_addition_population_weighted',
            'deceased_addition_population_weighted',
        ],
        'alpha': 0.001,
        'rolling_windows': [7],
        'rolling_window_type': 'mean',
        'time_shifts': [0, 7, 14],
    },

    # NOTE(review): the key says "Temp" but the predictor is UVIndex - confirm this is intended
    'SpearmanTemp_LimitOutcome': {
        'model': 'spearman',
        'predictor': ['UVIndex'],
        'outcome': ['confirmed_addition_population_weighted'],
        'alpha': 0.05,
        'rolling_windows': [1, 7],
        'rolling_window_type': 'mean',
        'time_shifts': [0, 7, 14],
    },

    'Gam_UTR_WD_LimitOutcome': {
        'model': 'gam',
        'independent_variables': [
            'UVIndex',
            'TemperatureAboveGround',
            'RelativeHumiditySurface',
        ],
        'control_variables': ['WindSpeed', 'DOW'],
        'outcome': ['confirmed_addition'],
        'alpha': 0.05,
        'rolling_windows': [1, 7],
        'rolling_window_type': 'mean',
        'time_shifts': [0, 7, 14],
    },
}

# Time slices an analysis can be restricted to. 'window_start' and 'window_end' are
# datetime.date values; date.today() is evaluated when this cell is run.
TIME_WINDOW_DEFINITIONS    = {
    'ARTICLE_2_PLOT_1': {'window_start' : date(2020, 3, 6), 'window_end' : date(2020, 6, 19)},
    'ARTICLE_2_PLOT_2': {'window_start' : date(2020, 4, 1), 'window_end' : date(2020, 6, 19)},
    'ARTICLE_4_PLOT_1': {'window_start' : date(2020, 4, 1), 'window_end' : date(2020, 11, 15)},
    'NL_SecondWave': {'window_start' : date(2020, 8, 31), 'window_end' : date(2020, 11, 15)},
    'WHO_PandemicStart->Today' :{'window_start' : date(2020, 3, 11), 'window_end' : date.today()},
    'NL_Start->LockDown+14d' :{'window_start' : date(2020, 3, 6), 'window_end' : date(2020, 3, 30)},
    'FR_Start->LockDown+14d' :{'window_start' : date(2020, 3, 1), 'window_end' : date(2020, 5, 1)},
    'DK_Start->LockDown+14d' :{'window_start' : date(2020, 3, 10), 'window_end' : date(2020, 3, 27)},
    'SE_Start->Today' :{'window_start' : date(2020, 3, 7), 'window_end' : date.today()},
    'DE_Start->LockDown+14d' :{'window_start' : date(2020, 3, 1), 'window_end' : date(2020, 4, 4)},
    'US_Start->LockDown+14d' :{'window_start' : date(2020, 3, 3), 'window_end' : date(2020, 4, 2)},
    # Referenced by ANALYSIS_DEFINITIONS['US_UVIndex_Today']; without this entry that analysis
    # fails with a KeyError. The start date is taken from 'US_Start->LockDown+14d'.
    'US_Start->Today' :{'window_start' : date(2020, 3, 3), 'window_end' : date.today()},
    'IN_Start->LockDown+14d' :{'window_start' : date(2020, 3, 16), 'window_end' : date(2020, 4, 3)},
}

# IBM PAIRS query parameters per predictor: the layer to read and the temporal
# aggregation applied when a raster has multiple measurements
PAIRS_QUERY_DEFINITIONS = {
    'TemperatureAboveGround':  {'layer_id': '49257', 'aggregation': 'Mean'},
    'RelativeHumiditySurface': {'layer_id': '49252', 'aggregation': 'Mean'},
    'WindSpeed':               {'layer_id': '49313', 'aggregation': 'Mean'},
    'UVIndex':                 {'layer_id': '49311', 'aggregation': 'Sum'},
    'SolarRadiation':          {'layer_id': '49424', 'aggregation': 'Sum'},
    'Soiltemperature':         {'layer_id': '49446', 'aggregation': 'Sum'},
    'Surfacepressure':         {'layer_id': '49439', 'aggregation': 'Sum'},
    'Totalprecipitation':      {'layer_id': '49459', 'aggregation': 'Sum'},
    'Dewpoint':                {'layer_id': '49422', 'aggregation': 'Sum'},
}

## 3. Set Global Variables
In order to simplify the remainder of the Notebook a series of Global Variables are set. The coding convention in the Notebook is that Global Variables are defined in CAPITALS.

The Global Variables should **not be changed**, unless you want to alter the code in the Notebook.

In [None]:
# Define the analysis the Notebook will process
ANALYSIS_DEFINITION     = ANALYSIS_DEFINITIONS[ANALYSIS['analysis']]
ANALYSIS['country_code']= ANALYSIS_DEFINITION['country_code']
ANALYSIS['model']       = MODEL_DEFINITIONS[ANALYSIS_DEFINITION['model']]
ANALYSIS['time_window'] = TIME_WINDOW_DEFINITIONS[ANALYSIS_DEFINITION['time_window']]

# Define the predictors in case of a gam model, this model has two types of variables of which some are not retrieved from PAIRS.
# Only variables that have a PAIRS_QUERY_DEFINITIONS entry are kept as predictor. The list is rebuilt
# rather than calling .remove() while iterating over it, because removing during iteration skips
# the element that follows each removed one.
if ANALYSIS['model']['model'] == 'gam':
    ANALYSIS['model']['predictor'] = [
        variable
        for variable in ANALYSIS['model']['independent_variables'] + ANALYSIS['model']['control_variables']
        if variable in PAIRS_QUERY_DEFINITIONS
    ]

# Define the paths to locally store & retrieve files
GLOBAL_FILE_PATH        = BASE_FILE_PATH + '/data/Global/'
COUNTRY_FILE_PATH       = BASE_FILE_PATH + '/data/' + ANALYSIS_DEFINITION['country_code'] + '/'

# Add the file path to the ANALYSIS definition so we can cache the file
ANALYSIS['country_file_path'] = COUNTRY_FILE_PATH
ANALYSIS['cache_file']  = '{}IBMPAIRS_ANALYSIS_{}.csv'.format(COUNTRY_FILE_PATH,ANALYSIS['analysis'])

# Set the values to access IBM PAIRS
PAIRS_CACHE_REFRESH     = USER_CONFIG is not None # Used to control if only the local cache file is used
if PAIRS_CACHE_REFRESH:
    PAIRS_SERVER        = USER_CONFIG['ibm_pairs']['server']
    PAIRS_CREDENTIALS   = (USER_CONFIG['ibm_pairs']['user'], USER_CONFIG['ibm_pairs']['api_key'])

# Define the PAIRS_QUERY to run
PAIRS_QUERY             = {}
PAIRS_QUERY['alias']    = '_'.join(ANALYSIS['model']['predictor'])
PAIRS_QUERY['cache_file'] = '{}IBMPAIRS_{}.csv'.format(COUNTRY_FILE_PATH,'LocalCache')
PAIRS_QUERY['layers']   = {}

# Per predictor: register its PAIRS layer and record how many extra days of history it needs
additional_days_for_layer = []
ANALYSIS['pairs_query'] = {}
ANALYSIS['pairs_query']['layers'] = {}
for predictor in ANALYSIS['model']['predictor']:
    PAIRS_QUERY['layers'][predictor] = PAIRS_QUERY_DEFINITIONS[predictor]
    ANALYSIS['pairs_query']['layers'][predictor] = PAIRS_QUERY['layers'][predictor]
    additional_days_for_layer.append(
        max(ANALYSIS['model']['rolling_windows']) +
        max(ANALYSIS['model']['time_shifts']))

# Define the start & end date for the data we need from IBM PAIRS where we take into account:
#   - the extra days we need before the window_start to accommodate the rolling_windows
#   - the extra days we need before the window_start to accommodate the time_shifts
# The end date is capped at yesterday when the window ends today or later.
PAIRS_QUERY['start_date'] = ANALYSIS['time_window']['window_start'] \
                                - timedelta(days=(max(additional_days_for_layer) + 1))
PAIRS_QUERY['end_date'] = ANALYSIS['time_window']['window_end'] \
                                if ANALYSIS['time_window']['window_end'] < date.today() \
                                else date.today() - timedelta(days=1)

In [None]:
# Manual override: use the local IBM PAIRS cache file only, whatever value was set for
# PAIRS_CACHE_REFRESH in the Global Variables cell.
# NOTE(review): presumably this cell must be removed (or skipped) to query IBM PAIRS with the
# credentials from the user config file - confirm this override is intended.
PAIRS_CACHE_REFRESH = False

## 4. Get Country Geospatial Data
To define a (new) country in the Notebook, create a **country_metadata.json** JSON file and select the **geometry** out of a shapefile.

1. Create a **subfolder** in the *data* folder using the [*ISO 3166-1 alpha-2*](https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2) code for the given country: 
2. Create a **country_metadata.json** in that country folder with one top level object *country_metadata* and for each region create one entry as follows:

    "country_metadata": [
        {
            "iso3166-1_code": "NL",
            "iso3166-1_name": "The Netherlands",
            "iso3166-2_code": "NL-UT",
            "iso3166-2_name_en": "Utrecht",
            "population": 1354979,
            "covid_region_code": "26"
        },]
Each region is identified by an [ISO 3166-2](https://en.wikipedia.org/wiki/ISO_3166-2) code.
    - *covid_region_code*: unique identifier of a region as used by the COVID-19 data source. Sometimes these sources use ISO 3166-2 as well, but often another unique reference is used. Used to establish the merge between the country geospatial data and the COVID-19 data.
    - *population*: total amount of people in the region. Used to determine the *population weighted* COVID-19 metric.

 Add a second top level object *country_data_sources* to the JSON file to document the sources used and capture comments if any:

    "country_data_sources": {
        "iso3166-1":"https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2",
        "iso3166-2":"https://en.wikipedia.org/wiki/ISO_3166-2",
        "population":"https://opendata.cbs.nl/statline/#/CBS/nl/dataset/70072NED/table?fromstatweb"
    }


3. Extend the **get_country_geometry()** function with a new branch that filters the geospatial vector data from the shapefile. The geospatial data used in this Notebook is the [*Admin 1 – States, Provinces*](https://www.naturalearthdata.com/downloads/10m-cultural-vectors/) shapefile.

 Possibly the data must be filtered and / or cleansed to arrive at the clean data set for the country. The outcome must be a Geopandas Data frame with the following columns:

    - *iso3166-2_code*: unique identifier for the region
    - *geometry:* Polygon or Multi Polygon defining the boundaries of the region
    - *geometry_wkt:* geometry information in Well-Known text (WKT) format (the format used in PAIRS)

 NB: If you use another geospatial source, make sure the spatial reference system is WGS 84 (or EPSG:4326). This is required by IBM PAIRS.

In [None]:
# Locations of the country metadata JSON file and of the admin-1 shapefile
COUNTRY_METADATA_JSON   = '{}{}'.format(COUNTRY_FILE_PATH, 'country_metadata.json')
SHAPEFILE_FILE_PATH = '{}{}'.format(
    GLOBAL_FILE_PATH, 'ne_10m_admin_1_states_provinces/ne_10m_admin_1_states_provinces.shp')

# Function that loads the country_metadata JSON file & collects the geospatial vector data
def get_country_geospatial_data(country_code, country_metadata_file, shapefile):
    """Load the country metadata and merge it with the geospatial vector data.

    Args:
        country_code: country code passed on to get_country_geometry.
        country_metadata_file: path of the JSON file with a 'country_metadata' list and,
            optionally, a 'country_data_sources' object.
        shapefile: path of the shapefile passed on to get_country_geometry.

    Returns:
        Tuple (geometry_json, country_metadata_df) as returned by get_country_geometry.

    Raises:
        Exception: when the metadata file cannot be opened or parsed, or when
            get_country_geometry fails. The underlying error is chained as __cause__.
    """
    # load the JSON file that contains the meta data for the country
    try:
        with open(country_metadata_file, 'r', encoding='utf-8') as metadata_file:
            country_metadata = json.load(metadata_file)
    except OSError as error:
        raise Exception('ERROR: Cannot find {}. Please create this file to continue.'.format(country_metadata_file)) from error
    except ValueError as error:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses: the file
        # exists but is not valid UTF-8 JSON, so do not report it as missing
        raise Exception('ERROR: Cannot parse {}: {}. Please fix this file to continue.'.format(country_metadata_file, error)) from error

    country_metadata_df = pd.json_normalize(country_metadata['country_metadata'])

    # load the country geometry vector data & merge with the country_metadata
    try:
        geometry_json, country_metadata_df = get_country_geometry(country_code, country_metadata_df, shapefile)
    except Exception as inst:
        raise Exception('ERROR: {}. Please define a country_geometry function to continue.'.format(inst)) from inst

    print("Country Geospatial Data loaded. There are {} regions defined.\n".format(country_metadata_df.shape[0]))
    print("The following sources have been used to contruct the country geospatial data:")

    # the sources object only documents the metadata, so a file without it is still usable
    for key, value in country_metadata.get('country_data_sources', {}).items():
        print(" - {}: {}".format(key, value))

    return geometry_json, country_metadata_df

# Function that filters the geospatial vector data from the shapefile
# Extend when new ANALYSIS_DEFINITION['country_code']
def get_country_geometry(country_code, country_metadata_df, shapefile):
    """Select the regions of one country from the shapefile and attach them to the metadata.

    Returns a tuple of the regions as GeoJSON dictionary and the country metadata merged
    with the region geometry, including a 'geometry_wkt' column. Raises an Exception
    when no geometry selection is implemented for country_code.
    """
    # Read the (uncleansed) shapefile
    raw_shapes = gpd.read_file(shapefile)

    # Extra positional arguments (type filter, excluded region codes) for the countries
    # that can use the standard filter; FR needs its own cleansing function
    basic_filter_args = {
        'NL': ('Province',),
        'DK': ('Region',),
        'SE': ('County',),
        'DE': ('State',),
        'US': ('State', ['US-AK','US-HI']),
        'IN': (),
    }

    if country_code == 'FR':
        regions = get_FR_geometry_basic_filter(
            country_metadata_df, raw_shapes, country_code, 'Metropolitan department')
    elif country_code in basic_filter_args:
        regions = get_country_geometry_basic_filter(
            raw_shapes, country_code, *basic_filter_args[country_code])
    else:
        raise Exception('Geometry not implemented for country_code: ' + country_code)

    # Convert to GeoJSON
    regions_geojson = json.loads(regions.to_json())

    # Add the geometry information to the country metadata
    merged_metadata = country_metadata_df.merge(
        regions, left_on='iso3166-2_code', right_on='iso3166-2_code', how='inner')

    # Store the geometry value in WKT format as well since this is used for PAIRS
    merged_metadata['geometry_wkt'] = merged_metadata['geometry'].apply(lambda shape: shape.wkt)

    return regions_geojson, merged_metadata

# Standard function that can be used by most countries to filter the vector data from the shapefile
def get_country_geometry_basic_filter(shapefile_gdp, iso_filter, 
                                      type_filter = None, iso3166_2_code_exclusion_filter = None):
    """Select the regions of one country from the shapefile data.

    Args:
        shapefile_gdp: frame with the columns 'iso_a2', 'type_en', 'iso_3166_2' and 'geometry'.
        iso_filter: value the 'iso_a2' column must equal.
        type_filter: optional value the 'type_en' column must equal; None keeps all types.
        iso3166_2_code_exclusion_filter: optional list-like of region codes to drop.

    Returns:
        Frame with the columns 'iso3166-2_code' and 'geometry' for the selected regions.
    """
    selection = shapefile_gdp['iso_a2'] == iso_filter
    # identity checks against None: '== None' is evaluated element-wise on array-like
    # arguments, which makes the resulting truth test fail
    if type_filter is not None:
        selection = selection & (shapefile_gdp['type_en'] == type_filter)

    # rename() returns a new frame, so no column labels are assigned on a slice of the input
    country_region_gdp = shapefile_gdp[selection][['iso_3166_2', 'geometry']]
    country_region_gdp = country_region_gdp.rename(columns={'iso_3166_2': 'iso3166-2_code'})

    if iso3166_2_code_exclusion_filter is not None:
        excluded = country_region_gdp['iso3166-2_code'].isin(iso3166_2_code_exclusion_filter)
        country_region_gdp = country_region_gdp[~excluded]

    return country_region_gdp

# FR departments were reorganized in 2016, but the shapefile is using the old departments
# FR specific filtering and data cleansing is needed to align the geospatial information to current regions 
def get_FR_geometry_basic_filter(country_metadata_df, shapefile_gdp, iso_filter, type_filter):
    """Build the geometry of the current FR regions out of the old shapefile regions.

    Args:
        country_metadata_df: country metadata with the columns 'iso3166-2_code' and
            'FR_nashapefile_old_region_code' (comma separated 'region_cod' values).
        shapefile_gdp: shapefile data with the columns 'iso_a2', 'type_en', 'region_cod'
            and 'geometry'.
        iso_filter: value the 'iso_a2' column must equal.
        type_filter: value the 'type_en' column must equal.

    Returns:
        Frame with one row per 'iso3166-2_code' holding the dissolved 'geometry'.
    """
    country_region_gdp = shapefile_gdp[(shapefile_gdp['iso_a2'] == iso_filter) & (shapefile_gdp['type_en'] == type_filter)]
    country_region_gdp = country_region_gdp[['region_cod', 'geometry']]

    # The metadata file for FR contains a specific field that lists the old shapefile region
    # codes per region. The mapping rows are collected first and the frame is built once:
    # DataFrame.append copies the frame on every call and was removed in pandas 2.0.
    region_map_rows = []
    for _, region in country_metadata_df.iterrows():
        for reg_code in region['FR_nashapefile_old_region_code'].split(','):
            region_map_rows.append({'region_cod': reg_code, 'iso3166-2_code': region['iso3166-2_code']})
    fr_region_map = pd.DataFrame(region_map_rows, columns=['region_cod', 'iso3166-2_code'])

    # The mapping we use to add the iso3166-2_code to the shapefile. The filtered selection is
    # merged (not the complete shapefile) so that iso_filter and type_filter are actually applied.
    country_region_gdp = country_region_gdp.merge(fr_region_map,
                                                  left_on='region_cod', right_on='region_cod', how='inner')

    # We merge the geometry information so that we end up with the 13 regions
    country_region_gdp = country_region_gdp[['iso3166-2_code','geometry']]
    country_region_gdp = country_region_gdp.dissolve(by='iso3166-2_code')
    country_region_gdp = country_region_gdp.reset_index()
    return country_region_gdp
    

In [None]:
def create_region_choropleth(country_code, country_region_json, country_metadata_df):
    """Show a choropleth of the selected regions, coloured by the 'population' column."""
    chart_title = '<b>Country Geospatial Selection & Population: {}</b>'.format(country_code)
    figure = px.choropleth(
        data_frame=country_metadata_df,
        geojson=country_region_json,
        featureidkey="properties.iso3166-2_code",
        locations='iso3166-2_code',
        color='population',
        color_continuous_scale='Blues',
        title=chart_title,
    )

    # zoom in on the plotted regions and hide the base map
    figure.update_geos(fitbounds="locations", projection={'type':'mercator'}, visible=False)

    layout_options = dict(
        margin={"r":0,"t":40,"l":0,"b":0},
        height=400,
        yaxis=dict(position=0),
        font_size=8,
        title=dict(x=0,font_size=16),
        legend=dict(title_font_size=10, font_size=8),
        coloraxis_colorbar=dict(title="Population",lenmode="pixels", len=300),
    )
    figure.update_layout(**layout_options)
    figure.show()

In [None]:
# Load the country metadata & geospatial vector data and visualize the regions being analysed.
# A failure in either step is reported by printing the exception.
try:
    COUNTRY_REGION_JSON, COUNTRY_METADATA_DF = get_country_geospatial_data(
        ANALYSIS['country_code'], COUNTRY_METADATA_JSON, SHAPEFILE_FILE_PATH)
    create_region_choropleth(ANALYSIS['country_code'], COUNTRY_REGION_JSON, COUNTRY_METADATA_DF)
except Exception as load_error:
    print(load_error)

## 5. Get Country COVID-19 Data Set
The Country Specific COVID-19 data set is retrieved. Since available metrics differ widely between countries in terms of terminology, granularity and definition they need to be harmonized before standardized analysis can be run.

Ideally the following **metrics** are obtained on a regional level:
- **confirmed**: individual tested positive for COVID-19
- **hospitalized**: individual admitted to a general hospital and tested positive for COVID-19
- **hospitalized_icu**: individual admitted to a ICU unit in the hospital and tested positive for COVID-19
- **recovered**: individual confirmed to have recovered from COVID-19
- **deceased**: individual confirmed to have passed away with COVID-19 infection

Ideally for each of the metrics the following values are listed per day:
- **{metric}_addition**: increase of that given day
- **{metric}_subtraction**: decline of that given day
- **{metric}_absolute**: running sum of additions minus the running sum of subtractions
- **{metric}_cumulative**: total amount up to that day (running sum of additions)

If _cumulative metrics are provided, without corresponding _addition metric, then the _addition metric is derived from the _cumulative metric.

For each of the available metrics a **_population_weighted** value is added by taking the original input from the source, divided by the *population* as defined in the country metadata.

If a new country is added the **get_country_covid_data** function must be modified to add the country. Data retrieval and cleansing will be country specific but the net result of the code must be:
1. The COVID-19 source is obtained, cleansed and loaded into a Data Frame
2. The DF has a 'date' and 'region_code' column for correlation & mapping and matches the covid_region_code as defined in the country metadata
3. The DF has standardized metric columns per the definition above

In [None]:
# get the country specific covid data and change it into a standard format
def get_country_covid_data(analysis):
    """Collect the cleansed COVID-19 data for the country of the analysis.

    Args:
        analysis: dictionary with a 'country_code' entry; for DK and SE the
            'country_file_path' entry is passed to the country specific function.

    Returns:
        Data Frame of the country specific function, with 'date' converted to datetime
        and 'region_code' converted to str.

    Raises:
        Exception: when no data collection is implemented for the country code. Errors
            raised by the country specific function propagate unchanged; they used to be
            printed and followed by an unrelated UnboundLocalError.
    """
    country_code = analysis['country_code']
    if country_code == 'NL':
        covid_source_df_cleansed = get_NL_covid_data()
    elif country_code == 'FR':
        covid_source_df_cleansed = get_FR_covid_data()
    elif country_code == 'DK':
        covid_source_df_cleansed = get_DK_covid_data(analysis['country_file_path'])
    elif country_code == 'SE':
        covid_source_df_cleansed = get_SE_covid_data(analysis['country_file_path'])
    elif country_code == 'DE':
        covid_source_df_cleansed = get_DE_covid_data()
    elif country_code == 'US':
        covid_source_df_cleansed = get_US_covid_data()
    elif country_code == 'IN':
        covid_source_df_cleansed = get_IN_covid_data()
    else:
        raise Exception('Data Collection not implemented for country_code: {}'.format(country_code))

    #make sure that date & region_code have the correct data type
    covid_source_df_cleansed['date'] = pd.to_datetime(covid_source_df_cleansed['date'])
    covid_source_df_cleansed['region_code'] = covid_source_df_cleansed['region_code'].astype('str') 

    print('\nCOVID-19 measures in source (after cleanse): ' + str(covid_source_df_cleansed.shape[0]))
    return covid_source_df_cleansed
    
# Create a country specific function to:
# a) get the data
# b) align the columns to the standard columns (date, region_code, region_name, {metric}_{metric_type}
# c) cleans the data 
def get_NL_covid_data():
    """Download the Dutch provincial figures and rename them to the standard columns."""
    # a) get the data: the JSON document keeps the records under its 'data' key
    source_url = 'https://raw.githubusercontent.com/J535D165/CoronaWatchNL/master/data-json/data-provincial/RIVM_NL_provincial.json'
    payload = requests.get(source_url).json()
    provincial_df = pd.DataFrame(payload['data'])

    # b) map the columns
    column_map = {
        # date, region_code, region_name
        'Datum': 'date',
        'Provinciecode': 'region_code',
        'Provincienaam': 'region_name',
        # available metrics
        'totaalAantal': 'confirmed_addition',
        'totaalAantalCumulatief': 'confirmed_cumulative',
        'ziekenhuisopnameAantal': 'hospitalized_addition',
        'ziekenhuisopnameAantalCumulatief': 'hospitalized_cumulative',
        'overledenAantal': 'deceased_addition',
        'overledenAantalCumulatief': 'deceased_cumulative',
    }

    # c) no cleansing is applied to this source
    return provincial_df.rename(columns=column_map)

def get_FR_covid_data():
    """Download the French key figures and keep one standardised row per region and day."""
    # a) get the data
    source_url = "https://raw.githubusercontent.com/opencovid19-fr/data/master/dist/chiffres-cles.csv"
    csv_text = requests.get(source_url).content.decode('utf-8')
    key_figures_df = pd.read_csv(io.StringIO(csv_text))

    # b) map the columns
    column_map = {
        # date, region_code, region_name
        'date': 'date',
        'maille_code': 'region_code',
        'maille_nom': 'region_name',
        # available metrics ('hospitalises' / 'reanimation' are the
        # alternatives that are not mapped)
        'deces': 'deceased_cumulative',
        'nouvelles_hospitalisations': 'hospitalized_addition',
        'nouvelles_reanimations': 'hospitalized_icu_addition',
        'gueris': 'recovered_cumulative',
        # other columns used for filtering
        'source_nom': 'source',
    }
    key_figures_df = key_figures_df.rename(columns=column_map)

    # c) cleans the data
    # the source mixes several granularity levels, only region level is kept
    regional_df = key_figures_df[key_figures_df['granularite'].eq('region')]

    # A region/day can be reported by several sources; in that case only the
    # "Santé publique France" entry is kept.
    # NB! "Santé publique France" values may be NaN while other sources are not.
    # Only an issue around March 23-25.
    reported_once = ~regional_df.duplicated(['date', 'region_code'], keep=False)
    from_official_source = regional_df['source'].eq('Santé publique France')
    return regional_df[reported_once | from_official_source]

def get_DK_covid_data(country_file_path):
    """Read the manually downloaded Danish admissions file in the standard format.

    Args:
        country_file_path: folder prefix (including the trailing separator)
            that holds 'Newly_admitted_over_time.csv'.

    Returns:
        Long-format data frame with the columns 'date', 'region_code'
        (English region name), 'hospitalized_addition' and an empty
        'region_name'.

    Raises:
        Exception: if the file does not exist. Other read errors (for example
            a malformed file) propagate unchanged instead of being reported
            as a missing file.
    """
    # a) get the data
    covid_source_url = country_file_path + 'Newly_admitted_over_time.csv'
    try:
        covid_source_df = pd.read_csv(covid_source_url, sep=';')
    except FileNotFoundError as err:
        raise Exception('ERROR: Cannot find {}. Please download this file from the [covid_data_provider] to continue.'.format(covid_source_url)) from err
    print('WARN: {} is a local file. Please make sure you download the latest file from the [covid_data_provider]'.format(covid_source_url))

    covid_source_df["Dato"] = pd.to_datetime(covid_source_df["Dato"])
    covid_source_df = covid_source_df.drop(['Total'], axis=1)

    # NB! Translating region names to English for matching afterwards.
    # org regions: ['Hovedstaden', 'Sjælland', 'Syddanmark','Midtjylland','Nordjylland']
    # The names are assigned by position, so the file must list the regions
    # in exactly this order.
    region_names = ['Capital Region of Denmark', 'Region Zealand', 'Region of Southern Denmark',
                    'Central Denmark Region', 'North Denmark Region']
    covid_source_df.columns = ['Dato'] + region_names
    covid_source_df = pd.melt(covid_source_df, id_vars='Dato', value_vars=region_names,
                              var_name='Region', value_name='Admitted')

    # b) map the columns
    covid_source_df['region_name'] = ''
    covid_source_column_dict = {
        # map the columns for date, region_code
        'Dato': 'date', 'Region': 'region_code',
        # map the columns for the available metrics
        'Admitted': 'hospitalized_addition',
    }

    # c) no further cleansing is applied to this source
    return covid_source_df.rename(columns=covid_source_column_dict)

def get_SE_covid_data(country_file_path):
    """Read the manually downloaded Swedish county file in the standard format.

    Args:
        country_file_path: folder prefix (including the trailing separator)
            that holds 'region.csv'.

    Returns:
        Long-format data frame with the columns 'date', 'region_code'
        (county column name of the source), 'confirmed_addition' and an empty
        'region_name'.

    Raises:
        Exception: if the file does not exist. Other read errors propagate
            unchanged instead of being reported as a missing file.
    """
    # a) get the data
    covid_source_url = country_file_path + 'region.csv'
    try:
        covid_source_df = pd.read_csv(covid_source_url, sep=',', encoding='utf-8')
    except FileNotFoundError as err:
        raise Exception('ERROR: Cannot find {}. Please download this file from the [covid_data_provider] to continue.'.format(covid_source_url)) from err
    print('WARN: {} is a local file. Please make sure you download the latest file from the [covid_data_provider]'.format(covid_source_url))

    covid_source_df["Statistikdatum"] = pd.to_datetime(covid_source_df["Statistikdatum"])
    # the country level totals are not needed, only the per county columns
    covid_source_df = covid_source_df.drop(
        ['Totalt_antal_fall', 'Kumulativa_fall', 'Antal_avlidna', 'Kumulativa_avlidna',
         'Antal_intensivvardade', 'Kumulativa_intensivvardade'], axis=1)

    # No translating into English according to https://en.wikipedia.org/wiki/ISO_3166-2:SE
    # NB! Every remaining column except the date is a county holding new cases
    county_columns = [column for column in covid_source_df.columns if column != 'Statistikdatum']
    covid_source_df = pd.melt(covid_source_df, id_vars='Statistikdatum', value_vars=county_columns,
                              var_name='County', value_name='new_cases')

    # b) map the columns
    covid_source_df['region_name'] = ''
    covid_source_column_dict = {
        # map the columns for date, region_code
        'Statistikdatum': 'date', 'County': 'region_code',
        # map the columns for the available metrics
        'new_cases': 'confirmed_addition',
    }

    # c) no further cleansing is applied to this source
    return covid_source_df.rename(columns=covid_source_column_dict)

def get_DE_covid_data():
    """Download the German case records and sum them per reporting date and state."""
    # a) get the data
    source_url = "https://www.arcgis.com/sharing/rest/content/items/f10774f1c63e40168479a1feb6c7ca74/data"
    csv_text = requests.get(source_url).content.decode('utf-8')
    case_records_df = pd.read_csv(io.StringIO(csv_text))

    kept_columns = ['Meldedatum', 'Bundesland', 'AnzahlFall', 'AnzahlTodesfall']
    case_records_df = case_records_df[kept_columns].assign(
        Meldedatum=lambda frame: pd.to_datetime(frame["Meldedatum"]))
    # one row per reporting date and state (approach to be checked)
    daily_state_df = case_records_df.groupby(by=['Meldedatum', 'Bundesland']).sum().reset_index()

    # b) map the columns
    daily_state_df['region_name'] = ''
    column_map = {
        # date, region_code
        'Meldedatum': 'date',
        'Bundesland': 'region_code',
        # available metrics
        'AnzahlFall': 'confirmed_addition',
        'AnzahlTodesfall': 'deceased_addition',
    }

    # c) no further cleansing is applied to this source
    return daily_state_df.rename(columns=column_map)

def get_US_covid_data():
    """Download the US daily state figures and rename them to the standard columns."""
    # a) get the data
    source_url = 'https://covidtracking.com/api/v1/states/daily.json'
    daily_records = requests.get(source_url).json()
    states_df = pd.DataFrame(daily_records)

    # b) map the columns
    column_map = {
        # date, region_code, region_name
        'date': 'date',
        'fips': 'region_code',
        'state': 'region_name',
        # available metrics
        'positive': 'confirmed_cumulative',
        'hospitalizedCumulative': 'hospitalized_cumulative',
        'inIcuCumulative': 'hospitalized_icu_cumulative',
        'recovered': 'recovered_cumulative',
        'death': 'deceased_cumulative',
    }
    states_df = states_df.rename(columns=column_map)

    # c) cleans the data: the source date is parsed with the YYYYMMDD format
    states_df['date'] = pd.to_datetime(states_df['date'], format='%Y%m%d')
    return states_df

def get_IN_covid_data():
    """Download the Indian state-wise daily file and reshape it to one row per date and state."""
    # a) get the data
    source_url = "https://api.covid19india.org/csv/latest/state_wise_daily.csv"
    csv_text = requests.get(source_url).content.decode('utf-8')
    state_wise_df = pd.read_csv(io.StringIO(csv_text))

    # TT (total for the country) and UN (unknown region) are not regions
    state_wise_df = state_wise_df.drop(columns=["TT", "UN"])

    # melt the state columns into rows, then pivot the Status values into
    # columns to arrive at a data set compliant with the metamodel
    long_df = pd.melt(state_wise_df, id_vars=['Date', 'Status'],
                      var_name='region_code', value_name='cases')
    per_state_df = pd.pivot_table(long_df, values='cases', index=['Date', 'region_code'],
                                  columns=['Status'], aggfunc=np.sum).reset_index()

    # b) map the columns
    per_state_df['region_name'] = ''
    column_map = {
        # date, region_code
        'Date': 'date',
        'region_code': 'region_code',
        # available metrics
        'Confirmed': 'confirmed_addition',
        'Deceased': 'deceased_addition',
        'Recovered': 'recovered_addition',
    }

    # c) no further cleansing is applied to this source
    return per_state_df.rename(columns=column_map)

In [None]:
# Attach the country metadata to the COVID-19 source by translating the source
# region_code to the iso3166-2 standard. The translation is defined in the
# country metadata file by the covid_region_code.
def merge_country_meta_data(covid_source_df_cleansed, country_metadata_df):
    """Return the COVID-19 rows joined with their metadata, indexed by iso3166-2_code.

    Rows whose region_code has no mapping are reported and dropped. The result
    is sorted by iso3166-2_code and date.
    """
    report_columns = ['region_code', 'region_name', 'iso3166-2_code']

    exploded_metadata_df = explode_covid_region_code(country_metadata_df)
    merged_df = covid_source_df_cleansed.merge(
        exploded_metadata_df, left_on='region_code', right_on='covid_region_code', how='left')
    # '?' marks the rows that did not find a metadata entry
    merged_df['iso3166-2_code'] = merged_df['iso3166-2_code'].fillna('?')

    is_unmapped = merged_df['iso3166-2_code'].eq('?')
    if is_unmapped.any():
        print("\nWARN: Not all regions mapped to iso3166-2_code. They will be dropped. These are missing mappings:\n")
        print(merged_df[is_unmapped].groupby(report_columns)['date'].max())
        merged_df = merged_df[~is_unmapped]

    print("\nINFO: These regions are mapped to iso3166-2_code:\n")
    print(merged_df.groupby(report_columns)['date'].max())

    return merged_df.set_index(['iso3166-2_code']).sort_values(by=['iso3166-2_code', 'date'])

# multiple covid_region_codes can be mapped to one iso3166-2_code
# for this add comma separated IDs in the covid_region_code
# the explode function splits these IDs and adds them as separate rows in the dataframe
def explode_covid_region_code(country_metadata_df):
    """Return the metadata with one row per individual covid_region_code.

    Args:
        country_metadata_df: metadata with an 'iso3166-2_code' column and a
            'covid_region_code' column holding comma separated codes.

    Returns:
        A new data frame whose first two columns are 'iso3166-2_code' and
        'covid_region_code' (a single code per row), followed by the other
        metadata columns. The input frame is not modified. Rows without a
        covid_region_code produce no output rows.
    """
    split_codes = country_metadata_df['covid_region_code'].str.split(',')
    code_rows = [
        (iso_code, covid_code)
        for iso_code, covid_codes in zip(country_metadata_df['iso3166-2_code'], split_codes)
        # a missing covid_region_code is returned as NaN by .str.split, not as a list
        if isinstance(covid_codes, list)
        for covid_code in covid_codes
    ]
    exploded_df = pd.DataFrame(code_rows, columns=['iso3166-2_code', 'covid_region_code'])

    # bring the remaining metadata columns back for every exploded code
    other_metadata_df = country_metadata_df.drop(['covid_region_code'], axis=1)
    return exploded_df.merge(other_metadata_df, on='iso3166-2_code', how='inner')

In [None]:
# creates a Data Frame holding the metrics that are available for the analysis & visualizations
def get_available_metrics(covid_source_df):
    """List the standard metric columns that are present in the data frame.

    Candidate column names are built as '{metric}_{metric_type}' with an
    optional '_population_weighted' suffix.

    Args:
        covid_source_df: data frame whose columns are inspected.

    Returns:
        Data frame indexed by 'column_name' with the columns 'metric',
        'metric_type', 'metric_math' and 'label' (the label equals the
        metric). It is empty when no standard metric column is present.
    """
    metrics = ['confirmed', 'hospitalized', 'hospitalized_icu', 'recovered', 'deceased']
    metric_types = ['addition', 'substraction', 'absolute', 'cumulative']
    metric_maths = ['', 'population_weighted']

    # Collect the rows first and build the frame once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame for every added row.
    rows = []
    for metric in metrics:
        for metric_type in metric_types:
            for metric_math in metric_maths:
                column_name = metric + '_' + metric_type
                if metric_math:
                    column_name += '_' + metric_math
                if column_name in covid_source_df.columns:
                    rows.append({'column_name': column_name, 'metric': metric,
                                 'metric_type': metric_type, 'metric_math': metric_math,
                                 'label': metric})

    available_metrics = pd.DataFrame(
        rows, columns=['column_name', 'metric', 'metric_type', 'metric_math', 'label'])
    return available_metrics.set_index('column_name')

# add missing metric columns that can be derived from other metrics and remove everything else
def cleans_metrics(covid_source_df):
    """Reduce the source to the analysis columns and derive the missing metrics.

    The data frame is grouped by 'iso3166-2_code', which is not one of the
    kept columns and is therefore expected to be the name of the index.

    Args:
        covid_source_df: COVID-19 data with the columns 'date',
            'iso3166-2_name_en', 'population' and the standard metric columns.

    Returns:
        Tuple of (cleansed data frame, available metrics data frame). The
        cleansed frame holds, per metric, the addition, cumulative and
        population weighted columns that could be taken or derived.
    """
    available_metrics = get_available_metrics(covid_source_df)

    # only keep the columns we will include in the analysis
    covid_source_df = covid_source_df[['date', 'iso3166-2_name_en', 'population'] + available_metrics.index.tolist()]

    # drop all duplicates
    covid_source_df = covid_source_df.drop_duplicates()
    print('\nCOVID-19 measures in source (after dedup): ' + str(covid_source_df.shape[0]))

    # fill NaN for [cumulative] measurements with the last known value of the region
    # (GroupBy.ffill replaces the deprecated groupby().fillna(method='ffill'))
    covid_source_cumulative_metrics = available_metrics[available_metrics.metric_type == 'cumulative'].index.tolist()
    if covid_source_cumulative_metrics:
        covid_source_df[covid_source_cumulative_metrics] = \
            covid_source_df.groupby('iso3166-2_code')[covid_source_cumulative_metrics].ffill()

    # derive [addition] from [cumulative] measurements if not already included from the source
    #   [addition] metrics already in the source
    addition_metrics_available = available_metrics[available_metrics.metric_type == 'addition'].metric.tolist()
    #   [cumulative] metrics in the source but no [addition] metric
    cumulative_metrics_for_addition = available_metrics[
        (available_metrics.metric_type == 'cumulative') &
        ~(available_metrics.metric.isin(addition_metrics_available))].index.tolist()

    if cumulative_metrics_for_addition:
        addition_metrics_to_derive = [cm.replace('cumulative', 'addition') for cm in cumulative_metrics_for_addition]
        # GroupBy.diff keeps the original row index, which the column assignment
        # relies on; apply(lambda x: x.diff()) may prepend the group key to the index
        covid_source_df[addition_metrics_to_derive] = \
            covid_source_df.groupby('iso3166-2_code')[cumulative_metrics_for_addition].diff()

    # fill 0 for all NaN values (this includes the first derived addition of every region)
    covid_source_df = covid_source_df.fillna(value=0)

    # derive _cumulative metrics from _addition if not present
    available_metrics = get_available_metrics(covid_source_df)
    cumulative_metrics_available = available_metrics[available_metrics.metric_type == 'cumulative'].metric.tolist()
    addition_metrics_for_cumulative = available_metrics[
        (available_metrics.metric_type == 'addition') &
        ~(available_metrics.metric.isin(cumulative_metrics_available))].index.tolist()

    if addition_metrics_for_cumulative:
        cumulative_metrics_to_derive = [am.replace('addition', 'cumulative') for am in addition_metrics_for_cumulative]
        covid_source_df[cumulative_metrics_to_derive] = \
            covid_source_df.groupby('iso3166-2_code')[addition_metrics_for_cumulative].cumsum()

    # calculate population weighted for all metrics
    available_metrics = get_available_metrics(covid_source_df)
    for metric in available_metrics.index.tolist():
        metric_pw = metric + '_population_weighted'
        covid_source_df[metric_pw] = covid_source_df[metric] / covid_source_df['population']

    available_metrics = get_available_metrics(covid_source_df)
    return covid_source_df, available_metrics

def get_metrics_to_analyse(available_metrics, analysis):
    """Select the available metrics that fully match a requested outcome pattern.

    Every entry of analysis['model']['outcome'] is used as a regular
    expression against the index of available_metrics. Outcomes without any
    match are reported; an Exception is raised when nothing matches at all.
    """
    matched_metrics = []
    unmatched_outcomes = []

    for outcome_pattern in analysis['model']['outcome']:
        hits = [name for name in available_metrics.index if re.fullmatch(outcome_pattern, name)]
        if hits:
            matched_metrics.extend(hits)
        else:
            unmatched_outcomes.append(outcome_pattern)

    if unmatched_outcomes:
        print('WARN: These metrics are not available for analysis: {}'.format(unmatched_outcomes))

    if not matched_metrics:
        raise Exception('ERROR: There are no metrics available for analysis')
    print('INFO: These metrics will be analysed: {}'.format(matched_metrics))

    is_selected = available_metrics.index.isin(matched_metrics)
    return available_metrics[is_selected]

In [None]:
## Get the country COVID-19 source and clean the data to a standard format ##
# The three steps below must run in this order: load the country source,
# map it to the country metadata, then reduce/derive the metric columns.
covid_source_df_cleansed          = get_country_covid_data(ANALYSIS)
covid_source_df_cleansed          = merge_country_meta_data(covid_source_df_cleansed,COUNTRY_METADATA_DF)
COVID_SOURCE_DF,available_metrics = cleans_metrics(covid_source_df_cleansed)

## Determine which of the metrics we want to analyse are in the country input ##
# The selection is stored on the ANALYSIS definition under 'available_metrics'.
ANALYSIS['available_metrics']     = get_metrics_to_analyse(available_metrics,ANALYSIS)

## 6. Get IBM PAIRS Geospatial & Temporal data

The **ANALYSIS** definition with the **COUNTRY_METADATA** file is used to control:
- What geospatial filter is applied (i.e. NL Provinces with geometry from a shapefile)
- What temporal filter is applied (i.e. the Time Window with consideration of Rolling Windows & Time Shifts)
- What data set filter is applied (i.e. the IBM PAIRS data layer(s))
- What temporal aggregation is applied (i.e. how IBM PAIRS aggregates the data)

The code below will:
- Construct the IBM PAIRS Query
- Submit the query & process the results
- Cache the results

Caching the result is important since the (initial) data retrieval can take hour(s). The duration depends on the size of the geography, the length of the time period and the data set(s) requested. Therefore the results are cached to a local file. The cache file is checked when running a new analysis and only the missing data (if any) is retrieved.

In [None]:
def data_agg_column(datum, aggregation):
    """Name of the cache column holding the aggregated (summed) layer data."""
    return f'{datum}_data_agg_{aggregation}'

def data_points_column(datum):
    """Name of the cache column holding the number of data points of a layer."""
    return f'{datum}_data_points'

def data_avg_column(datum, aggregation):
    """Name of the cache column holding the averaged layer data."""
    return f'{datum}_data_avg_{aggregation}'

def init_pairs_cache(pairs_query):
    """Load the PAIRS cache file, or start an empty cache when there is none.

    As a side effect the column names used for each layer are stored on the
    layer definitions in pairs_query['layers'].
    """
    cache_path = pairs_query['cache_file']
    if not os.path.isfile(cache_path):
        print('Did not find PAIRS_QUERY_CACHE_FILE: ' + cache_path)
        # empty cache with date and region (same column names as in the source COVID data)
        cache_df = pd.DataFrame(columns=["date", "iso3166-2_code"])
    else:
        # retrieve the cached results
        print('Found PAIRS_QUERY_CACHE_FILE: ' + cache_path)
        cache_df = pd.read_csv(cache_path)

    # in both cases the date column is handed out as datetime
    cache_df["date"] = pd.to_datetime(cache_df["date"])

    # add the names of the columns we will use
    for layer_alias, layer_definition in pairs_query['layers'].items():
        layer_aggregation = layer_definition['aggregation']
        layer_definition['data_points_column'] = data_points_column(layer_alias)
        layer_definition['data_agg_column'] = data_agg_column(layer_alias, layer_aggregation)
        layer_definition['data_avg_column'] = data_avg_column(layer_alias, layer_aggregation)
    return cache_df

# Build the query for a polygon: one query layer is created for each day in the
# list of days to collect of every layer in pairs_query.
def create_pairs_query_by_poly(pairs_query,poly):
    """Return the PAIRS query definition restricted to the given polygon."""
    return create_pairs_query(pairs_query, geo_type='polygon', geo_data=poly)
    
def create_pairs_query(pairs_query,geo_type,geo_data):
    """Build the PAIRS query definition with one raster layer per layer and day.

    Args:
        pairs_query: query definition; 'alias' is used as the query name and
            every entry of 'layers' provides 'layer_id', 'aggregation' and
            'days_to_collect' (a collection of timestamps offering
            min()/max(), for example a DatetimeIndex).
        geo_type: spatial filter type; only "polygon" is implemented.
        geo_data: the polygon as WKT.

    Returns:
        dict with the keys 'layers', 'name', 'spatial' and 'temporal'. The
        overall temporal interval spans from the first day to collect up to
        one day after the last day to collect.

    Raises:
        Exception: if geo_type is not "polygon".
        ValueError: if none of the layers has a day to collect.
    """
    # The following helps when converting datetime objects to strings in ISO 8601-compliant format.
    iso8601 = '%Y-%m-%dT%H:%M:%SZ'
    one_day = timedelta(days=1)

    queryJson = {
        "layers": [],
        "name": pairs_query['alias']
    }

    if geo_type == "polygon":
        queryJson['spatial'] = {"type": "poly", "polygon": {"wkt": geo_data}}
    else:
        raise Exception("Geo Type not implemented")

    outer_dates = []
    for alias, layer in pairs_query['layers'].items():
        days_to_collect = layer['days_to_collect']
        # A layer that is already fully cached has no days left. Its min()/max()
        # would be NaT, which breaks the outer interval calculation below.
        if len(days_to_collect) == 0:
            continue

        outer_dates.append(days_to_collect.min())
        outer_dates.append(days_to_collect.max())

        for startdate in days_to_collect:
            enddate = startdate + one_day
            queryJson['layers'].append({
                "alias": alias + "_" + startdate.strftime('%Y%m%d'),
                "id": layer['layer_id'],
                "output": True,
                "aggregation": layer['aggregation'],
                "type": "raster",
                "temporal": {"intervals": [{"start": startdate.strftime(iso8601),
                                            "end": enddate.strftime(iso8601)}]}
            })

    if not outer_dates:
        raise ValueError("No days to collect for any layer of the PAIRS query")

    firstdate = min(outer_dates)
    lastdate = max(outer_dates) + one_day
    queryJson['temporal'] = {"intervals": [{"start": firstdate.strftime(iso8601), "end": lastdate.strftime(iso8601)}]}

    return queryJson

def _store_layer_result(pairs_cache, region_code, layer_date, layer_values):
    """Insert or update the cache row of one region and date; return the cache."""
    row_mask = (pairs_cache['iso3166-2_code'] == region_code) & (pairs_cache['date'] == layer_date)
    if not row_mask.any():
        # no row yet for this region/date: add one (DataFrame.append was removed in pandas 2.0)
        new_row = {'date': layer_date, 'iso3166-2_code': region_code}
        new_row.update(layer_values)
        return pd.concat([pairs_cache, pd.DataFrame([new_row])], ignore_index=True)

    # .loc is the accessor that supports a boolean row mask; .at only addresses
    # a single row/column label pair
    for column_name, value in layer_values.items():
        pairs_cache.loc[row_mask, column_name] = value
    return pairs_cache


def update_pairs_cache(pairs_cache, pairs_server, pairs_credentials, pairs_query, date_range, country_metadata_df):
    """Query IBM PAIRS for the data missing from the cache and store the results.

    For every region of country_metadata_df the days of date_range that are
    not yet cached are determined per layer. When at least one day is missing
    a PAIRS query is submitted for the region geometry, the returned layers
    are reduced to a sum, a number of data points and an average, and the
    cache file is rewritten.

    Args:
        pairs_cache: cache data frame with at least 'date' and 'iso3166-2_code'.
        pairs_server: passed on to paw.PAIRSQuery.
        pairs_credentials: passed on to paw.PAIRSQuery.
        pairs_query: query definition; every layer needs 'data_agg_column',
            'layer_id' and 'aggregation', and receives 'days_to_collect' as a
            side effect. 'cache_file' is the CSV file that is written.
        date_range: the days requested; must offer difference() (for example
            a DatetimeIndex).
        country_metadata_df: one row per region with 'iso3166-2_code' and
            'geometry_wkt'.

    Returns:
        The updated cache data frame.
    """
    start_time_overall = timer()
    total_regions = country_metadata_df.shape[0]
    print("Start IBM PAIRS Queries for {} regions.\n".format(total_regions))

    for region_id, (_, region) in enumerate(country_metadata_df.iterrows(), start=1):
        region_code = region['iso3166-2_code']

        days_to_collect = 0
        # determine the missing dates that are not yet in the cached results per layer
        for alias, layer in pairs_query['layers'].items():
            # we assume we need to get all days
            diff_data_range = date_range
            # but for columns already in the cache we only pull missing values
            if layer['data_agg_column'] in pairs_cache.columns:
                cached_date_range = pairs_cache[
                    (pairs_cache['iso3166-2_code'] == region_code) &
                    (pairs_cache[layer['data_agg_column']].notna())]['date']
                diff_data_range = date_range.difference(cached_date_range)

            layer['days_to_collect'] = diff_data_range
            days_to_collect += len(diff_data_range)

        if days_to_collect == 0:
            print("Skip IBM PAIRS Queries for: {} ({} of {}) (No days to collect).".format(region_code,region_id,total_regions))
            continue

        print("Start IBM PAIRS Queries for: {} ({} of {}).".format(region_code,region_id,total_regions))
        for alias, layer in pairs_query['layers'].items():
            print(" - " + alias + ": " + str(len(layer['days_to_collect'])) + " day(s) to collect.")

        queryJSON = create_pairs_query_by_poly(pairs_query, region['geometry_wkt'])
        logging.getLogger('ibmpairs.paw').setLevel(logging.ERROR)
        query = paw.PAIRSQuery(queryJSON, pairs_server, pairs_credentials)
        start_time_query = timer()
        print(" Submit query. Total elapse time: " + str(round((start_time_query - start_time_overall),1)))
        query.submit()
        query.poll_till_finished()
        end_time_query = timer()
        queryStatus = query.queryStatus.json()
        if queryStatus['status'] != 'Succeeded':
            print(" No download. PAIRS Query Status: " + queryStatus['status'] + "\n")
            continue

        print(" Download query result. Query elapse time: " + str(round((end_time_query - start_time_query),1)))
        query.download()
        query.create_layers()

        # the query returns one layer per day
        query_metadata = pd.DataFrame(query.metadata).transpose()
        for dla in query_metadata['datalayerAlias']:
            data_idx = query_metadata[query_metadata['datalayerAlias'] == dla].index[0]
            temporalAggregation = query_metadata.at[data_idx, 'temporalAggregation']

            # the alias was built as '<layer alias>_<YYYYMMDD>'
            # NOTE(review): a layer alias that itself contains '_' is not parsed correctly here
            alias_parts = dla.split('_')
            layer_alias = alias_parts[0]
            layer_date = datetime.strptime(alias_parts[1], '%Y%m%d')
            print("   - Layer returned for [" + layer_alias + "] for date: " + layer_date.strftime("%Y-%m-%d"))
            layer_data = query.data[data_idx]

            # non nan values correspond to measurement points in the area of interest, nan values are outside.
            # the average is the sum over the number of measurement points;
            # 0.001 is added to avoid a divide by zero.
            data_points = np.count_nonzero(~np.isnan(layer_data))
            data_sum = np.nansum(layer_data)
            data_avg = data_sum / (data_points + 0.001)

            # if data_sum is 0 it means there are no results and we skip adding the row
            # NOTE(review): a negative sum is skipped as well; confirm the queried layers cannot sum below zero
            if data_sum > 0:
                # column names in the cache file use the layer_alias as prefix
                layer_values = {
                    data_agg_column(layer_alias, temporalAggregation): data_sum,
                    data_points_column(layer_alias): data_points,
                    data_avg_column(layer_alias, temporalAggregation): data_avg,
                }
                pairs_cache = _store_layer_result(pairs_cache, region_code, layer_date, layer_values)

        # do a final save of the CSV with safety measure to make sure no duplicate entries are stored
        print("Finished IBM PAIRS Queries for: " + region_code + ". Save results to cache file.\n")
        pairs_cache.drop_duplicates(inplace=True)
        pairs_cache.to_csv(pairs_query['cache_file'], index=False)

    print("Finished all IBM PAIRS Queries.\n")
    return pairs_cache

In [None]:
# get the cached results previously obtained from PAIRS
# (this also stores the cache column names on the layers of PAIRS_QUERY)
PAIRS_CACHE_DF = init_pairs_cache(PAIRS_QUERY)

# get additional data from PAIRS or skip to only use the local cached data
if(PAIRS_CACHE_REFRESH):
    print("\nCollect data from " + PAIRS_QUERY['start_date'].strftime("%Y-%m-%d") + " up to and including " + PAIRS_QUERY['end_date'].strftime("%Y-%m-%d"))
    # create the date range, in daily increments, for the data we want from PAIRS
    requested_date_range = pd.date_range(start=PAIRS_QUERY['start_date'],end=PAIRS_QUERY['end_date'],freq='D')

    # query pairs & update the cache; only days missing from the cache are requested
    PAIRS_CACHE_DF = update_pairs_cache(
        PAIRS_CACHE_DF, PAIRS_SERVER, PAIRS_CREDENTIALS, PAIRS_QUERY, requested_date_range, COUNTRY_METADATA_DF)
else:
    print('\nSkip datacollect from IBM PAIRS. Only use cache.')

## 7. Add Rolling Windows & Time Shifts
The **ANALYSIS** definition controls what *Rolling Windows* and *Time Shifts* are used. These parameters allow the algorithm to be tuned to e.g. incorporate the influence of different incubation times.

A new column is added to the PAIRS_CACHE_DF for each *Rolling Window* and each *Time Shift* specified in the ANALYSIS.

In [None]:
# Helper functions to construct the column names for the Rolling Window and Time Shift columns
def time_shift_label(shift):
    """Label of a time shift, e.g. 'TS3D' for a shift of 3."""
    return f"TS{shift}D"

def time_shift_labels(time_shifts):
    """Labels of all given time shifts, in the order given."""
    return [time_shift_label(one_shift) for one_shift in time_shifts]

def rolling_window_column(metric,window,window_type):
    """Column name of a rolling window, e.g. '<metric>_rolling_mean_7D'."""
    window_suffix = str(window) + 'D'
    return '_'.join([metric + '_rolling', window_type, window_suffix])

def time_shift_column(metric,window,window_type,shift):
    """Column name of a time-shifted rolling window: '<rolling window column>_<time shift label>'."""
    window_column = rolling_window_column(metric, window, window_type)
    shift_suffix = time_shift_label(shift)
    return window_column + '_' + shift_suffix

# Add a rolling window column for each of the specified window_sizes.
# Shift the rolling windows for each of the time_shifts to introduce time-lagged correlation
def add_windows(pairs_cache, analysis):
    """Add the rolling window and time shift columns per region.

    The windows are counted in rows of the date-sorted region data, not in
    calendar days.

    Args:
        pairs_cache: cached PAIRS data with 'date', 'iso3166-2_code' and the
            average column of every predictor.
        analysis: analysis definition; analysis['model'] provides 'predictor',
            'rolling_windows', 'rolling_window_type' ('mean' or 'sum') and
            'time_shifts'; analysis['pairs_query']['layers'] provides the
            aggregation per predictor; analysis['cache_file'] is the CSV file
            the result is written to.

    Returns:
        New data frame with the date, region and predictor columns plus one
        column per rolling window and one per rolling window and time shift.

    Raises:
        Exception: if the rolling window type is neither 'mean' nor 'sum'.
    """
    model = analysis['model']
    layers = analysis['pairs_query']['layers']
    predictor_columns = [data_avg_column(predictor, layers[predictor]['aggregation'])
                         for predictor in model['predictor']]
    output_columns = ['date', 'iso3166-2_code'] + predictor_columns

    pairs_cache = pairs_cache.sort_values(by='date')
    # the region frames are collected and concatenated once at the end
    # (DataFrame.append was removed in pandas 2.0)
    region_frames = []
    for region in pairs_cache['iso3166-2_code'].unique():
        pairs_region = pairs_cache[pairs_cache['iso3166-2_code'] == region][output_columns].copy()

        for metric in predictor_columns:
            rolling_window_type = model['rolling_window_type']

            # calculate the rolling window
            for window in model['rolling_windows']:
                rolling_window_name = rolling_window_column(metric, window, rolling_window_type)
                rolling_metric = pairs_region[metric].rolling(window)
                if rolling_window_type == 'mean':
                    pairs_region[rolling_window_name] = rolling_metric.mean()
                elif rolling_window_type == 'sum':
                    pairs_region[rolling_window_name] = rolling_metric.sum()
                else:
                    raise Exception('Rolling Window Type not defined: ' + rolling_window_type)

                # shift the rolling window for each time_shift to do an easy time-lagged correlation
                for shift in model['time_shifts']:
                    time_shift_name = time_shift_column(metric, window, rolling_window_type, shift)
                    pairs_region[time_shift_name] = pairs_region[rolling_window_name].shift(shift)
        region_frames.append(pairs_region)

    # pd.concat cannot handle an empty list, an empty cache gives an empty frame
    new_pairs_cache = pd.concat(region_frames) if region_frames else pd.DataFrame()

    # stored separately from the PAIRS cache file so that the rolling window columns are recreated every time
    new_pairs_cache.to_csv(analysis['cache_file'], index=False)

    return new_pairs_cache

In [None]:
# calculate the rolling window(s) and set the time_shifts for time-lagged correlation
# NB! PAIRS_CACHE_DF is replaced by the frame returned by add_windows, which
# also writes it to ANALYSIS['cache_file']
print('\nAdd Rolling Windows & Time Shifts.')
PAIRS_CACHE_DF = add_windows(PAIRS_CACHE_DF, ANALYSIS)

## 8. Merge COVID-19 & IBM PAIRS datasets

At this stage we have available:

- **COVID_SOURCE_DF**: Country Metadata & COVID-19 dataset in a standardized format
- **PAIRS_CACHE_DF**: Geospatial dataset from PAIRS for the country & time window required

These two data sets are merged on the *date* and *iso3166-2 code* so that they can be analyzed & visualized.

In [None]:
# add the IBM PAIRS query results to the COVID-19 source
# (merge without 'how', so only region/date combinations present in both data sets are kept)
print('\nCOVID-19 measures in source (cleansed): ' + str(COVID_SOURCE_DF.shape[0]))
COVID_PAIRS_DF = COVID_SOURCE_DF.merge(PAIRS_CACHE_DF, on=['iso3166-2_code','date'])
print('COVID-19 measures in source with ' + PAIRS_QUERY['alias'] + ': ' + str(COVID_PAIRS_DF.shape[0]))

# take the time slice to analyse & visualize (window_start and window_end are both inclusive)
COVID_PAIRS_DF_TIME_SLICED = \
    COVID_PAIRS_DF[ \
     (COVID_PAIRS_DF['date'] >= pd.to_datetime(ANALYSIS['time_window']['window_start'])) & \
     (COVID_PAIRS_DF['date'] <= pd.to_datetime(ANALYSIS['time_window']['window_end']))]

# store both the full and the time sliced data set in the country folder
COVID_PAIRS_DF.to_csv(COUNTRY_FILE_PATH + "COVID_PAIRS_DF.csv",index=True)
COVID_PAIRS_DF_TIME_SLICED.to_csv(COUNTRY_FILE_PATH + "COVID_PAIRS_DF_TIME_SLICED.csv",index=True)

## 9. Determine Spearman or GAM
The selected **MODEL** in the **ANALYSIS** definition controls the type of algorithm that is run. This Notebook supports two different models:

### 1. Spearman
The coefficient is determined for the *predictor* and each of the *outcomes* defined in the model.

The correlation between these two variables is then determined for each *rolling_window* and each *time_shift*. The *alpha* variable determines whether the correlation is deemed statistically significant or not.

### 2. Generalized Additive Model (GAM)
The log-linear association is determined for each of the *independent_variables* and each of the *outcomes*. The *control_variables* are used as [confounding variables](https://en.wikipedia.org/wiki/Confounding): variables, other than the independent variables in focus, that may affect the outcome and thus may lead to erroneous conclusions about the relationship between the independent and outcome variables.

The association between the *independent_variables* and each of the *outcomes* is determined for each *rolling_window* and each *time_shift*. The *alpha* variable determines whether the association is deemed statistically significant or not.

In [None]:
# Helper function to construct the key under which an analysis result is stored
def analysis_result_key(outcome, predictor, window):
    """Key of an analysis result: '<outcome>_<predictor>_rolling_<window>D'."""
    window_suffix = "rolling_" + str(window) + "D"
    return "_".join([outcome, predictor, window_suffix])

def get_spearman_correlations(analysis, covid_pairs_df):
    """Run the Spearman correlation for every available metric and rolling window.

    For each combination the per-region result table is written to
    '<country_file_path>spearman_<result_key>.csv' and stored, together with
    descriptive metadata, in the returned dictionary under <result_key>.

    Parameters:
        analysis: analysis definition; reads 'model', 'pairs_query',
            'available_metrics' and 'country_file_path'.
        covid_pairs_df: DataFrame holding the outcome and predictor columns per region.

    Returns:
        dict mapping result key -> dict(model, label, rank_variable_one,
        rank_variable_two, predictor, significant_field,
        significant_field_display, data).
    """
    results_by_key = {}
    model = analysis['model']

    # Fix the predictor to the first specified in the model.
    # TODO: Support use of multiple predictors
    predictor = model['predictor'][0]
    predictor_column_name = data_avg_column(
        predictor, analysis['pairs_query']['layers'][predictor]['aggregation'])

    for metric_column_name, metric_row in analysis['available_metrics'].iterrows():
        print('\nPerform Spearman correlation for: ' + metric_column_name)

        for window in model['rolling_windows']:
            # the predictor column that belongs to this rolling window
            windowed_predictor_column = rolling_window_column(
                predictor_column_name, window, model['rolling_window_type'])

            correlation_df = determine_spearman_correlation(
                covid_pairs_df, metric_column_name, windowed_predictor_column,
                model['time_shifts'], model['alpha'])

            result_key = analysis_result_key(metric_column_name, predictor, window)
            results_by_key[result_key] = {
                'model': 'spearman',
                'label': metric_row.label,
                'rank_variable_one': metric_column_name,
                'rank_variable_two': windowed_predictor_column,
                'predictor': predictor,
                'significant_field': 'rho',
                'significant_field_display': 'rho',
                'data': correlation_df,
            }

            # write the results to a local file
            correlation_df.to_csv(
                analysis['country_file_path'] + "spearman_" + result_key + ".csv", index=True)

    return results_by_key

# Function to determine the spearman correlation coefficient between two rank variables for each time_shift
def determine_spearman_correlation(covid_pairs_df, rank_variable_one, rank_variable_two, time_shifts, alpha):
    """Compute the per-region Spearman correlation for every time shift.

    Parameters:
        covid_pairs_df: DataFrame with an 'iso3166-2_code' column, the column
            rank_variable_one and one '<rank_variable_two>_<shift label>'
            column per time shift.
        rank_variable_one: name of the first column to correlate.
        rank_variable_two: name stem of the time-shifted second column.
        time_shifts: iterable of shifts; each is turned into a label by time_shift_label.
        alpha: a p-value below this threshold marks the result as significant.

    Returns:
        DataFrame indexed by 'iso3166-2_code' with the columns
        '<label>_rho', '<label>_pval' and '<label>_significant' ('Y'/'N')
        for every time shift that had input data.
    """
    region_codes = covid_pairs_df['iso3166-2_code'].unique()
    analysis_df = pd.DataFrame(region_codes, columns=['iso3166-2_code']).set_index('iso3166-2_code')

    # resolve (label, shifted column name) for every time shift up front
    shifted_columns = []
    for shift in time_shifts:
        label = time_shift_label(shift)
        shifted_columns.append((label, rank_variable_two + "_" + label))

    for label, shifted_column in shifted_columns:
        for key, region_df in covid_pairs_df.groupby('iso3166-2_code'):
            # only rows where both variables are present can be ranked
            pair_df = region_df[[rank_variable_one, shifted_column]].dropna()

            if pair_df.empty:
                print('No input for Spearman Correlation: ', key, rank_variable_one, shifted_column)
                continue

            rho, pval = spearmanr(pair_df[rank_variable_one], pair_df[shifted_column])
            analysis_df.loc[key, label + '_rho'] = rho
            analysis_df.loc[key, label + '_pval'] = pval
            analysis_df.loc[key, label + '_significant'] = ('Y' if pval < alpha else 'N')

    return analysis_df

def get_gam_associations(analysis, covid_pairs_df):
    """Run the log-linear GAM association for every available metric and rolling window.

    The GAM is fitted once per metric and rolling window for all independent
    variables; the combined outcome is then split per predictor. Each
    per-predictor table is written to '<country_file_path>gam_<result_key>.csv'
    and stored, together with descriptive metadata, in the returned dictionary.

    Parameters:
        analysis: analysis definition; reads 'model', 'available_metrics' and
            'country_file_path' (and is passed to analysisHelper.PredictorToColumn).
        covid_pairs_df: DataFrame holding the outcome and predictor columns per region.

    Returns:
        dict mapping result key -> dict(model, label, predictor,
        significant_field, significant_field_display, data).
    """
    analysis_result_datasets = {}
    model = analysis['model']

    # Determine the association for each outcome & each rolling window
    for metric_column_name, metric_row in analysis['available_metrics'].iterrows():
        print('\nPerform log-linear GAM associations for: ' + metric_column_name)

        for window in model['rolling_windows']:
            # determine the association
            analysis_gam_df = gamModel.determineGam(
                gamModel.convertPDtoR(covid_pairs_df, metric_column_name,
                                      analysisHelper.PredictorToColumn(analysis)),
                model['independent_variables'],
                model['control_variables'],
                model['time_shifts'],
                model['rolling_window_type'],
                window,
                model['alpha'])

            # gam model has multiple predictors, so we must unwrap them to obtain results per predictor
            for predictor in model['independent_variables']:
                analysis_df = convert_gam_association(
                    covid_pairs_df,
                    analysis_gam_df.loc[(analysis_gam_df.predictor == predictor)],
                    model['time_shifts'])

                # add the analysis_results to the analysis_result_datasets dictionary
                result_key = analysis_result_key(metric_column_name, predictor, window)
                analysis_result_datasets[result_key] = \
                    dict(model = 'gam',
                         label = metric_row.label,
                         predictor = predictor,
                         significant_field = 'coeff',
                         significant_field_display = 'perc_change',
                         data = analysis_df)

                # write the results to a local file
                analysis_df.to_csv(
                    analysis['country_file_path'] + "gam_" + result_key + ".csv", index=True)

    return analysis_result_datasets

# Function to convert the outcome of the association to the DF format used for visualization
# Conversion is done to enable re-use of visualization logic
def convert_gam_association(covid_pairs_df, analysis_gam_df, time_shifts):
    """Reshape GAM results into one row per region with columns per time shift.

    Parameters:
        covid_pairs_df: DataFrame whose 'iso3166-2_code' column lists all regions.
        analysis_gam_df: GAM results with the columns 'iso3166-2_code',
            'time_shift' (compared against str(shift)), 'p_val', 'coeff'
            and 'perc_change'.
        time_shifts: iterable of shifts; each is turned into a label by time_shift_label.

    Returns:
        DataFrame indexed by 'iso3166-2_code'. For every time shift it holds
        '<label>_significant', '<label>_pval', '<label>_coeff' and
        '<label>_perc_change'. A region with a matching GAM result row is
        marked 'Y' and gets the values of the first matching row; a region
        without one is marked 'N' with NaN values.
    """
    # Create a table with all ISO3166-2_codes
    analysis_df = pd.DataFrame(covid_pairs_df['iso3166-2_code'].unique())
    analysis_df.columns = ['iso3166-2_code']
    analysis_df.set_index(['iso3166-2_code'], inplace=True)

    for shift in time_shifts:
        shift_label = time_shift_label(shift)
        shift_results = analysis_gam_df.loc[(analysis_gam_df.time_shift == str(shift))]

        # snapshot the region keys: columns are added to analysis_df inside the loop
        for key in list(analysis_df.index):
            # The mask must come from shift_results itself. A mask built on the
            # unfiltered frame only works through implicit index re-alignment,
            # which fails when the result index contains duplicate labels.
            region_result = shift_results.loc[(shift_results['iso3166-2_code'] == key)]

            if region_result.empty:
                analysis_df.loc[key, shift_label + '_significant'] = 'N'
                analysis_df.loc[key, shift_label + '_pval'] = np.nan
                analysis_df.loc[key, shift_label + '_coeff'] = np.nan
                analysis_df.loc[key, shift_label + '_perc_change'] = np.nan
            else:
                analysis_df.loc[key, shift_label + '_significant'] = 'Y'
                analysis_df.loc[key, shift_label + '_pval'] = float(region_result['p_val'].values[0])
                analysis_df.loc[key, shift_label + '_coeff'] = float(region_result['coeff'].values[0])
                analysis_df.loc[key, shift_label + '_perc_change'] = float(region_result['perc_change'].values[0])

    return analysis_df

In [None]:
# for all the _addition metrics available in the data set we will establish a correlation
# (or association), using the algorithm selected by the model in the ANALYSIS definition
if(ANALYSIS['model']['model'] == 'gam'):
    ANALYSIS_RESULTS_DATASETS = get_gam_associations(ANALYSIS,COVID_PAIRS_DF_TIME_SLICED)
elif(ANALYSIS['model']['model'] == 'spearman'):
    ANALYSIS_RESULTS_DATASETS = get_spearman_correlations(ANALYSIS,COVID_PAIRS_DF_TIME_SLICED)
else:
    # fail here instead of leaving ANALYSIS_RESULTS_DATASETS undefined,
    # or still holding the results of an earlier run with a different model
    raise ValueError(
        "Unsupported model '{}': expected 'gam' or 'spearman'".format(ANALYSIS['model']['model']))

## 10. Create Visualizations
To present the results of the analysis three types of plots are created:

1. **Overview Regions**: Line Charts to provide insight into the time series data for the predictor & outcome values
2. **Scatter Plot Matrix (spearman only)**: Scatter Plots to provide insight into the potential correlation between predictor & outcome
3. **Choropleth Maps**: Maps to show whether a Region shows a significant correlation between the predictor & outcome

A plot is created for each *predictor*, *outcome*, *time shift* and *rolling window*.

NB: Please note that the charts are interactive. By default the Overview Regions & SPLOM hide the charts for all regions. Selecting the Region Code in the legend will show the corresponding chart.

In [None]:
# template to control standard layout of plots: margins, base font size,
# left-aligned title and compact legend fonts
pio.templates["gaa_template"] = go.layout.Template(
    layout=dict(
        margin=dict(r=0, t=40, l=0, b=0),
        font=dict(size=8),
        title=dict(x=0, font=dict(size=16)),
        legend=dict(title=dict(font=dict(size=10)), font=dict(size=8)),
    )
)

# helper function to map the region codes to a color that is consistent on various plots
# colors used: https://www.carbondesignsystem.com/data-visualization/color-palettes#categorical-palettes
def categorical_color_map(names):
    """Return a dict that assigns every name a color from a fixed 14-color palette.

    Colors are handed out in palette order by position in names; once the
    palette is exhausted it starts again at the first color.
    """
    palette = (
        '#6929c4',  # Purple 70
        '#1192e8',  # Cyan 50
        '#005d5d',  # Teal 70
        '#9f1853',  # Magenta 70
        '#fa4d56',  # Red 50
        '#520408',  # Red 90
        '#198038',  # Green 60
        '#002d9c',  # Blue 80
        '#ee5396',  # Magenta 50
        '#b28600',  # Yellow 50
        '#009d9a',  # Teal 50
        '#012749',  # Cyan 90
        '#8a3800',  # Orange 70
        '#a56eff',  # Purple 50
    )
    color_by_name = {}
    for position, name in enumerate(names):
        color_by_name[name] = palette[position % len(palette)]
    return color_by_name
    
# 3-colorscale (green, white, red) using IBM Design Language colors: https://www.ibm.com/design/language/color/
def three_color_scale():
    """Return a discrete colorscale: three flat bands, each listed as a (start, color) and (end, color) pair."""
    green = 'rgb(36, 161, 72)'     # Green 50
    white = 'rgb(255, 255, 255)'   # White
    red = 'rgb(218, 30, 40)'       # Red 60

    scale = []
    for band_start, band_end, color in ((0.00, 0.33, green), (0.33, 0.66, white), (0.66, 1.00, red)):
        # same color at both ends of the band, so there is no gradient inside it
        scale.append((band_start, color))
        scale.append((band_end, color))
    return scale

# helper function to create a named trace with consistent appearance & behaviour
def create_trace_groups(trace_names, region_color_map):
    """Return, per trace name, the keyword arguments shared by all traces of that name.

    Each entry sets the trace name, uses the same value as legend group and
    takes the line color from region_color_map (which must contain every name).
    """
    trace_groups = {}
    for name in trace_names:
        line_style = {'color': region_color_map[name]}
        trace_groups[name] = {'name': name, 'legendgroup': name, 'line': line_style}
    return trace_groups

In [None]:
# create a series of plots that show the time series for each region:
#   - predictors per rolling window
#   - available metrics contained in the analysis
def create_region_overview_plots(covid_pairs_df, analysis,
                                 cases_by_population, region_color_map, row_height = 100, show_regions = True):
    """Show one figure with a line chart row per predictor/rolling window and per metric.

    Parameters:
        covid_pairs_df: DataFrame with 'iso3166-2_code', 'date', the rolling
            window predictor columns and the metric columns.
        analysis: analysis definition; reads 'country_code', 'time_window',
            'model', 'pairs_query' and 'available_metrics'.
        cases_by_population: multiplier applied to 'population_weighted' metrics.
        region_color_map: dict mapping region code -> line color.
        row_height: height of one subplot row.
        show_regions: passed as 'visible' to every trace.
    """
    # define a title for the overall plot
    title = "<b>Overview Regions {}</b>: {} - {}".format(
        analysis['country_code'],
        analysis['time_window']['window_start'].strftime('%d-%m-%Y'),
        analysis['time_window']['window_end'].strftime('%d-%m-%Y'))

    predictors = analysis['model']['predictor']
    rolling_windows = analysis['model']['rolling_windows']

    # one subplot title per predictor & rolling window ...
    subplot_titles = [
        predictor + " (Rolling Window: " + str(window) + "D)"
        for predictor in predictors
        for window in rolling_windows
    ]
    # ... followed by one per available metric
    for metric_column_name, metric_row in analysis['available_metrics'].iterrows():
        if metric_row.metric_math == 'population_weighted':
            metric_title = "{} ({} by {:,} people)".format(
                metric_row.metric, metric_row.metric_type, cases_by_population)
        else:
            metric_title = "{} ({})".format(metric_row.metric, metric_row.metric_type)
        subplot_titles.append(metric_title)

    # every title corresponds to exactly one row of the figure
    rows = len(subplot_titles)

    # create the subplot figure
    fig = make_subplots(rows, 1, subplot_titles=subplot_titles,
                        vertical_spacing=0.05, horizontal_spacing=0.05,
                        shared_xaxes=True, shared_yaxes=False)

    # create a style object so that each region is one tracegroup so we can link color & interaction
    region_codes = covid_pairs_df['iso3166-2_code'].unique()
    trace_groups = create_trace_groups(region_codes, region_color_map)

    # for each region create the traces
    for region_code in region_codes:
        region_df = covid_pairs_df[covid_pairs_df['iso3166-2_code'] == region_code]
        region_style = trace_groups[region_code]
        row_id = 1

        for predictor in predictors:
            predictor_column_name = data_avg_column(
                predictor, analysis['pairs_query']['layers'][predictor]['aggregation'])
            for window in rolling_windows:
                window_column = rolling_window_column(
                    predictor_column_name, window, analysis['model']['rolling_window_type'])
                # only the trace on the first row gets a legend entry for the region
                fig.add_trace(
                    go.Scatter(mode='lines',
                               x=region_df['date'],
                               y=region_df[window_column],
                               showlegend=(row_id == 1), visible=show_regions,
                               **region_style),
                    row_id, 1)
                row_id += 1

        for metric_column_name, metric_row in analysis['available_metrics'].iterrows():
            metric_values = region_df[metric_column_name]
            if metric_row.metric_math == 'population_weighted':
                metric_values = metric_values * cases_by_population

            fig.add_trace(
                go.Scatter(mode='lines+markers',
                           x=region_df['date'],
                           y=round(metric_values, 0),
                           showlegend=(row_id == 1), visible=show_regions,
                           **region_style),
                row_id, 1)
            row_id += 1

    fig.update_layout(
        title=title,
        legend_title_text='<b>Region</b>',
        template='plotly+gaa_template',
        height=(row_height*rows) + 25)
    fig.show()

# create a series of scatter plot matrix that shows per rolling window correlations between:
#   - available metrics contained in the analysis
#   - time shifts contained in the analysis
def create_splom(covid_pairs_df, analysis, window, color_discrete_map, show_regions = True):
    """Show a scatter plot matrix of the metrics and the time-shifted predictor for one rolling window.

    Parameters:
        covid_pairs_df: DataFrame with 'iso3166-2_code', the metric columns
            and the time-shifted predictor columns.
        analysis: analysis definition; reads 'country_code', 'time_window',
            'available_metrics', 'model' and 'pairs_query'.
        window: rolling window the matrix is created for.
        color_discrete_map: dict mapping region code -> color.
        show_regions: passed as 'visible' to every trace.
    """
    # define a title for the overall plot
    title = "<b>SPLOM (Population Weighted) {} for Rolling Window {}D</b>: {} - {}".format(
        analysis['country_code'],
        window,
        analysis['time_window']['window_start'].strftime('%d-%m-%Y'),
        analysis['time_window']['window_end'].strftime('%d-%m-%Y'))

    # the matrix starts with the available metrics, labelled with their display label
    available_metrics = analysis['available_metrics']
    splom_dimensions = available_metrics.index.tolist()
    splom_labels = available_metrics.label.to_dict()

    # Fix the predictor to the first specified in the model.
    # TODO: Support use of multiple predictors
    predictor = analysis['model']['predictor'][0]
    predictor_column_name = data_avg_column(
        predictor, analysis['pairs_query']['layers'][predictor]['aggregation'])

    # add one dimension per time shift of the predictor, labelled '<shift>D'
    for shift in analysis['model']['time_shifts']:
        shifted_column = time_shift_column(
            predictor_column_name, window, analysis['model']['rolling_window_type'], shift)
        splom_dimensions.append(shifted_column)
        splom_labels[shifted_column] = str(shift) + 'D'

    # create the splom
    fig = px.scatter_matrix(
        covid_pairs_df,
        dimensions=splom_dimensions,
        labels=splom_labels,
        color="iso3166-2_code",
        symbol="iso3166-2_code",
        hover_name="iso3166-2_code",
        color_discrete_map=color_discrete_map)
    # only the lower triangle of the matrix is drawn
    fig.update_traces(diagonal_visible=False, showupperhalf=False, showlowerhalf=True, visible=show_regions)
    fig.update_layout(
        title=title,
        legend_title_text='<b>Region Code</b>',
        template='plotly+gaa_template')
    fig.show()
    
# We build a Grid of Choropleths.
#  - Each row contains one metric type, with one sliding window
#  - Each column contains one time shift

# for visualization we map the 'significant' and 'significant_value' columns to 3 discrete values (-1, 0, 1)
def set_region_category(significant, significant_value):
    """Return 0 when not significant, else 1 for a value above zero and -1 otherwise."""
    if significant != "Y":
        return 0
    if significant_value > 0:
        return 1
    return -1

def set_region_hovertext(significant, significant_value_display):
    """Return 'N' when not significant, else 'Y (<value rounded to 2 decimals>)'."""
    if significant != "Y":
        return "N"
    rounded_value = round(significant_value_display, 2)
    return 'Y ({})'.format(str(rounded_value))

# create the choropleth object used to visualize a map on the grid
def get_choropleth_dict(geo_trace, geo_json, geo_choropleth_corr_dataset, shift_label):
    """Build the choropleth trace definition for one analysis result and one time shift.

    Note: two helper columns ('<shift_label>_region_category' and
    '<shift_label>_hovertext') are added to the 'data' DataFrame of the dataset.

    Parameters:
        geo_trace: name of the trace.
        geo_json: GeoJSON whose features carry 'properties.iso3166-2_code'.
        geo_choropleth_corr_dataset: analysis result with the keys 'data'
            (DataFrame indexed by region code), 'model' ('gam' or 'spearman'),
            'significant_field' and 'significant_field_display'.
        shift_label: label of the time shift whose columns are visualized.

    Returns:
        dict describing a choropleth trace.

    Raises:
        ValueError: if the dataset's model is neither 'gam' nor 'spearman'.
    """
    geo_choropleth_datum = geo_choropleth_corr_dataset['data']
    analysis_model = geo_choropleth_corr_dataset['model']

    # resolve the legend title first, so an unknown model fails with a clear
    # message instead of an unbound local variable further down
    legend_titles = {'gam': 'Association', 'spearman': 'Correlation'}
    if analysis_model not in legend_titles:
        raise ValueError(
            "Unsupported model '{}': expected 'gam' or 'spearman'".format(analysis_model))
    legend_title = legend_titles[analysis_model]

    significant_column = shift_label + '_significant'
    value_column = shift_label + '_' + geo_choropleth_corr_dataset['significant_field']
    display_column = shift_label + '_' + geo_choropleth_corr_dataset['significant_field_display']

    # set category to a 3-value to control discrete three_color_scale
    geo_choropleth_datum[shift_label + '_region_category'] = geo_choropleth_datum.\
        apply(lambda x: set_region_category(x[significant_column], x[value_column]), axis=1)

    geo_choropleth_datum[shift_label + '_hovertext'] = geo_choropleth_datum.\
        apply(lambda x: set_region_hovertext(x[significant_column], x[display_column]), axis=1)

    return dict(
            type = 'choropleth',
            geojson=geo_json, featureidkey="properties.iso3166-2_code", #JSON with Polygon & ID field to match the DF with color info
            locations=geo_choropleth_datum.index, #DF index holds the ID to match to JSON for coloring
            z = geo_choropleth_datum[shift_label + '_region_category'], # Data to be color-coded
            zmin=-1, zmax=1, # Force a scale from -1 to 1 to get a discrete color mapping
            colorscale = three_color_scale(),
            showscale=True,
            name=geo_trace,
            text=geo_choropleth_datum[shift_label + '_hovertext'],
            hoverinfo="text+location",
            colorbar=dict(
                title=dict(text=legend_title, side = "top", font=(dict(size=12))),
                outlinecolor='black', outlinewidth=1, y=0.90, xpad=10, ypad=50,
                tickmode="array",
                tickvals=[1, 0, -1],
                ticktext=["Yes (Positive)","No","Yes (Negative)"],
                ticks="outside",
                lenmode="pixels", len=150,
            )
    )
        
def create_choropleth_plots(analysis_results_datasets, country_region_json, analysis, row_height = 200, column_width = 250):
    """Create the choropleth grid(s) for the model selected in the analysis.

    For 'gam' one grid is created per independent variable; for 'spearman'
    one grid is created for the first predictor. Nothing is plotted for any
    other model value.

    Parameters:
        analysis_results_datasets: dict of analysis results keyed by result key.
        country_region_json: GeoJSON with the region polygons.
        analysis: analysis definition; reads 'model', 'country_code' and 'time_window'.
        row_height: height of one grid row, forwarded to the grid creation.
        column_width: width of one grid column, forwarded to the grid creation.
    """
    if(analysis['model']['model'] == 'gam'):
        for predictor in analysis['model']['independent_variables']:
            title = "<b>GAM on COVID-19 & {} {} </b>: {} - {}".\
                format(analysis['country_code'],predictor,analysis['time_window']['window_start'].strftime('%d-%m-%Y'),\
                       analysis['time_window']['window_end'].strftime('%d-%m-%Y'))

            # forward the sizing parameters so that the caller's values take effect
            create_choropleth_plots_for_predictor(title, analysis_results_datasets, country_region_json, analysis, predictor,
                                                  row_height=row_height, column_width=column_width)
    elif(analysis['model']['model'] == 'spearman'):
        predictor = analysis['model']['predictor'][0]
        # define a title for the overall plot
        title = "<b>Spearman Correlation on COVID-19 & {} {} </b>: {} - {}".\
            format(analysis['country_code'],predictor,analysis['time_window']['window_start'].strftime('%d-%m-%Y'),\
                   analysis['time_window']['window_end'].strftime('%d-%m-%Y'))

        # forward the sizing parameters so that the caller's values take effect
        create_choropleth_plots_for_predictor(title, analysis_results_datasets, country_region_json, analysis, predictor,
                                              row_height=row_height, column_width=column_width)

def create_choropleth_plots_for_predictor(title, analysis_results_datasets, country_region_json, analysis, predictor, row_height = 200, column_width = 250):
    """Show a grid of choropleth maps for one predictor.

    The grid has one row per available metric and rolling window, and one
    column per time shift.

    Parameters:
        title: title of the overall figure.
        analysis_results_datasets: dict of analysis results keyed by result key.
        country_region_json: GeoJSON with the region polygons.
        analysis: analysis definition; reads 'available_metrics' and 'model'.
        predictor: predictor the grid is created for.
        row_height: height of one grid row.
        column_width: width of one grid column.
    """
    time_shifts = analysis['model']['time_shifts']
    rolling_windows = analysis['model']['rolling_windows']

    # every grid row is identified by its result key and rolling window
    grid_rows = [
        (analysis_result_key(metric_column_name, predictor, window), window)
        for metric_column_name, _metric_row in analysis['available_metrics'].iterrows()
        for window in rolling_windows
    ]

    # Each of the figures will get a title
    row_titles = ["[Window {}D]".format(str(window)) for _result_key, window in grid_rows]
    subplot_titles = [
        "{} ({}D)".format(analysis_results_datasets[result_key]['label'], str(shift))
        for result_key, _window in grid_rows
        for shift in time_shifts
    ]

    # determine the size of the grid
    columns = len(time_shifts)
    rows = len(grid_rows)
    grid_width = column_width * columns
    grid_height = row_height * rows

    # create the subplot figure
    grid_specs = [[{"type": "choropleth"} for _ in range(columns)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=columns,
                        column_widths=[0.4] * columns, row_heights=[0.4] * rows,
                        specs=grid_specs, row_titles=row_titles, subplot_titles=subplot_titles,
                        shared_xaxes=True, vertical_spacing=0.05)

    # add the plots to the grid layout (row and column ids start at 1)
    for row_id, (result_key, _window) in enumerate(grid_rows, start=1):
        for col_id, shift_label in enumerate(time_shift_labels(time_shifts), start=1):
            choropleth = get_choropleth_dict(
                result_key, country_region_json, analysis_results_datasets[result_key], shift_label)
            fig.add_trace(choropleth, row=row_id, col=col_id)

    # show the grid
    fig.update_geos(fitbounds="locations", projection={'type':'mercator'}, visible=False)
    fig.update_layout(width=grid_width, height=grid_height, title_text=title, template='plotly+gaa_template')
    fig.show()

In [None]:
# parameters to control visualizations
CASES_BY_POPULATION     = 1000000 # value used to show weighted by population expressed in cases by ...
SHOW_REGIONS            = "legendonly" # whether traces in plots are shown by default or not can be True, False, "legendonly"
ROW_HEIGHT_OVERVIEW     = 100 # value to control height of rows in Overview Regions plot
ROW_HEIGHT_CHOROPLETH   = 200 # value passed as row height for the Choropleth grid
COLUMN_WIDTH_CHOROPLETH = 250 # value passed as column width for the Choropleth grid

# create a mapping between the iso3166-2_code and a color
# this is to control that different plots use the same color for the same region
REGION_COLOR_MAP = categorical_color_map(COVID_PAIRS_DF_TIME_SLICED['iso3166-2_code'].unique())

# to get a understanding of the data create overview plots
create_region_overview_plots(
    covid_pairs_df=COVID_PAIRS_DF_TIME_SLICED,
    analysis=ANALYSIS,
    cases_by_population=CASES_BY_POPULATION,
    region_color_map=REGION_COLOR_MAP,
    row_height=ROW_HEIGHT_OVERVIEW,
    show_regions=SHOW_REGIONS)

# to get an initial understanding of the possible correlation generate SPLOM(s)
if ANALYSIS['model']['model'] == 'spearman':
    for window in ANALYSIS['model']['rolling_windows']:
        create_splom(
            covid_pairs_df=COVID_PAIRS_DF_TIME_SLICED,
            analysis=ANALYSIS,
            window=window,
            color_discrete_map=REGION_COLOR_MAP,
            show_regions=SHOW_REGIONS)

# visualize the outcome of the regional correlation generate Choropleths
create_choropleth_plots(
    analysis_results_datasets=ANALYSIS_RESULTS_DATASETS,
    country_region_json=COUNTRY_REGION_JSON,
    analysis=ANALYSIS,
    row_height=ROW_HEIGHT_CHOROPLETH,
    column_width=COLUMN_WIDTH_CHOROPLETH)