In [1]:
import pickle
import sys
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import datetime
import seaborn as sns
import glob
import arviz as az

sys.path.append("../")
sys.path.append("../covid19_inference")

import covid19_soccer
from covid19_soccer.plot.utils import get_from_trace
import covid19_inference as cov19

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
""" Matplotlib config
"""
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams['font.family'] = "sans-serif"
matplotlib.rcParams["figure.figsize"] = [3.4, 2.7]  # APS single column
matplotlib.rcParams["figure.dpi"] = 300  # this primarily affects the size on screen
#matplotlib.rcParams['axes.linewidth'] = 0.3
matplotlib.rcParams["axes.labelcolor"] = "black"
matplotlib.rcParams["axes.edgecolor"] = "black"
matplotlib.rcParams["xtick.color"] = "black"
matplotlib.rcParams["ytick.color"] = "black"
matplotlib.rcParams["xtick.labelsize"] = 8
matplotlib.rcParams["ytick.labelsize"] = 8
matplotlib.rcParams["axes.labelsize"] = 8
matplotlib.rcParams["axes.titlesize"]= 10
matplotlib.rcParams["legend.fontsize"] = 6
matplotlib.rcParams["legend.title_fontsize"] = 8

""" Colors
Done with rcParams / see plot.rcParams
"""


""" General configs
"""
fig_path = "./figures"
model_path = f"/data.nst/jdehning/covid_uefa_traces11/"

# Save figure as pdf and png        
save_kwargs = {
    "transparent":True,
    "dpi":300,
    "bbox_inches":"tight"
}

def load(fstr):
    with open(fstr, "rb") as f:
         return pickle.load(f)

In [4]:
# Extract filenames
import re
file_paths = list(glob.glob(model_path+"*.pickled"))
params = []
for fpath in file_paths:
    # Extract params
    string_keys = re.search(r'-([^$]*)-', fpath).group(1)
    
    # Replace minus signs
    for n in range(0,10):
        string_keys = string_keys.replace(f"-{n}",f"m{n}")
        
    params.append(dict(subString.split("=") for subString in string_keys.split("-")))

In [5]:
countries = np.unique([t["country"] for i, t in enumerate(params)])

In [42]:
# Iterate all countries with some multithreading beacause it takes forever otherwise
from concurrent.futures import ProcessPoolExecutor, as_completed
import os

os.environ["OMP_NUM_THREADS"] = "32"
os.environ["OPENBLAS_NUM_THREADS"] = "32"
os.environ["MKL_NUM_THREADS"] = "32"
os.environ["VECLIB_MAXIMUM_THREADS"] = "32"
os.environ["NUMEXPR_NUM_THREADS"] = "32"

def get_compares_and_sample_stats(country):
    """ For each country we want to compare model runs in different categories.

        - offset
        - prior_delay
        - median width delay
        - interval cps
    """
    
    model_indexs = [i for i, t in enumerate(params) if t["country"]==country and t["draws"] =="1000"]
    
    # (i) Offset
    unique_offsets = np.unique([t["offset_data"].replace("m","-") for i, t in enumerate(np.array(params)[model_indexs])])
    traces, models, dls, meta = [],[],[],[]
    for offset in unique_offsets:
        fstr=(f"/data.nst/jdehning/covid_uefa_traces11/UEFA"+
            f"-beta=False"+
            f"-country={country}"+
            f"-offset_data={int(offset)}"+
            f"-draw_delay=True"+
            f"-weighted_alpha_prior=0"+
            f"-prior_delay={-1}"+
            f"-width_delay_prior={0.1}"+
            f"-sigma_incubation=-1.0"+
            f"-median_width_delay=1.0"+
            f"-interval_cps=10.0"+
            f"-tune={500}"+
            f"-draws={1000}"+
            f"-max_treedepth={12}.pickled")
        model, trace = load(fstr)
        traces.append(trace)
        models.append(model)
        meta.append(f"Offset {offset}")
        dls.append(covid19_soccer.dataloader.Dataloader_gender(countries=[country],offset_data=int(offset)))
    
    compare = az.compare(dict(zip(unique_offsets,traces)))
    
    diagnostics = {}
    #.to_array().max()
    diagnostics["rhat"] = [az.rhat(trace) for trace in traces]
    diagnostics["ess"] = [az.ess(trace) for trace in traces]
    diagnostics["bfmi"] = [az.bfmi(trace) for trace in traces]
    return compare, diagnostics, traces,  country
    # Compare in these four for ["Scotland", "Germany", "France", "England", "Spain", "Czechia", "Italy"]
    
    
    # (ii) Prior delay & prior median width delay
    unique_delays = np.unique([t["prior_delay"].replace("m","-") for i, t in enumerate(np.array(params)[model_indexs])])
    unique_median_width_delay = np.unique([t["median_width_delay"].replace("m","-") for i, t in enumerate(np.array(params)[model_indexs])])
    traces, models, dls, meta = [],[],[],[]
    for delay in unique_delays:
        for median_width in unique_median_width_delay:
            fstr=(f"/data.nst/jdehning/covid_uefa_traces11/UEFA"+
                f"-beta=False"+
                f"-country={country}"+
                f"-offset_data={0}"+
                f"-draw_delay=True"+
                f"-weighted_alpha_prior=0"+
                f"-prior_delay={int(delay)}"+
                f"-width_delay_prior={0.1}"+
                f"-sigma_incubation=-1.0"+
                f"-median_width_delay={int(median_width)}"+
                f"-interval_cps=10.0"+
                f"-tune={500}"+
                f"-draws={1000}"+
                f"-max_treedepth={12}.pickled")
            model, trace = load(fstr)
            traces.append(trace)
            models.append(model)
            dls.append(covid19_soccer.dataloader.Dataloader_gender(countries=[country]))
            meta.append(f"Prior delay: {delay}, median_width {median_width}")
    
    compare = az.compare(dict(zip(meta,traces)))

    
    # (iii) Interval cps
    unique_cps_interval = np.unique([t["interval_cps"].replace("m","-") for i, t in enumerate(np.array(params)[model_indexs])])
    traces, models, dls, meta = [],[],[],[]
    for interval in unique_cps_interval:
        fstr=(f"/data.nst/jdehning/covid_uefa_traces11/UEFA"+
            f"-beta=False"+
            f"-country={country}"+
            f"-offset_data={0}"+
            f"-draw_delay=True"+
            f"-weighted_alpha_prior=0"+
            f"-prior_delay={-1}"+
            f"-width_delay_prior={0.1}"+
            f"-sigma_incubation=-1.0"+
            f"-median_width_delay=1.0"+
            f"-interval_cps={interval}"+
            f"-tune={500}"+
            f"-draws={1000}"+
            f"-max_treedepth={12}.pickled")
        model, trace = load(fstr)
        traces.append(trace)
        models.append(model)
        dls.append(covid19_soccer.dataloader.Dataloader_gender(countries=[country]))
        meta.append(f"Interval cps: {interval}")
        
    compare = az.compare(dict(zip(unique_offsets,traces)))
    sample_stats = dict(zip(unique_offsets,[trace.sample_stats for trace in traces]))                        
                        
                        
    #prior_delay = [-1, 2, 3, 4, 5, 6, 7, 8, 10, 12]
    #median_width_delay = [0.5, 1.0, 2.0]
    #interval_cps = [10.0, 6.0, 20.0]

    
    model_fpaths = np.array(file_paths)[model_indexs]
    
    models,traces = [],[]
    for fpath in model_fpaths:
        # Skip low number tune runs
        model, trace = load(fpath)
        models.append(model)
        traces.append(trace)
        
    compare = az.compare(dict(zip(model_indexs,traces)))
    sample_stats = dict(zip(model_indexs,[trace.sample_stats for trace in traces]))
    #az.plot_compare(compare)
    return compare, sample_stats, country

In [69]:
compare, diagnostics, traces, country = get_compares_and_sample_stats("Czechia")

  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


In [114]:
compare

Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale,rhat_max
14,0,-609.124668,36.300766,0.0,0.3524235,11.664768,0.0,False,log,2.001083
21,1,-613.748262,42.493815,4.623594,0.339791,11.9298,16.998987,False,log,1.381952
10,2,-618.902045,35.876766,9.777377,0.148097,11.509544,16.187423,False,log,1.482736
7,3,-620.025857,38.092066,10.901188,0.1596884,11.124308,16.021587,False,log,1.513088
4,4,-629.105759,39.713738,19.98109,0.0,11.444002,16.487619,False,log,1.94556
2,5,-632.517831,41.893724,23.393162,0.0,10.862746,16.395213,False,log,1.521534
0,6,-635.198596,42.397368,26.073928,0.0,11.204347,16.603706,False,log,1.277737
35,7,-641.783662,42.540731,32.658993,0.0,12.573217,16.638641,False,log,1.244198
-2,8,-648.309971,62.278173,39.185303,0.0,12.292169,18.502386,True,log,1.713746
-4,9,-648.483788,55.966528,39.359119,0.0,12.357784,17.663009,True,log,1.28571


In [123]:
diagnostics["ess"][0]


In [112]:
compare["rhat_max"] = [np.max(np.array(d.max().to_array())) for d in diagnostics["rhat"]]

In [113]:
compare

Unnamed: 0,rank,loo,p_loo,d_loo,weight,se,dse,warning,loo_scale,rhat_max
14,0,-609.124668,36.300766,0.0,0.3524235,11.664768,0.0,False,log,2.001083
21,1,-613.748262,42.493815,4.623594,0.339791,11.9298,16.998987,False,log,1.381952
10,2,-618.902045,35.876766,9.777377,0.148097,11.509544,16.187423,False,log,1.482736
7,3,-620.025857,38.092066,10.901188,0.1596884,11.124308,16.021587,False,log,1.513088
4,4,-629.105759,39.713738,19.98109,0.0,11.444002,16.487619,False,log,1.94556
2,5,-632.517831,41.893724,23.393162,0.0,10.862746,16.395213,False,log,1.521534
0,6,-635.198596,42.397368,26.073928,0.0,11.204347,16.603706,False,log,1.277737
35,7,-641.783662,42.540731,32.658993,0.0,12.573217,16.638641,False,log,1.244198
-2,8,-648.309971,62.278173,39.185303,0.0,12.292169,18.502386,True,log,1.713746
-4,9,-648.483788,55.966528,39.359119,0.0,12.357784,17.663009,True,log,1.28571


2.001082910444726

In [None]:
processes = []
with ProcessPoolExecutor() as executor:
    for country in countries:
        processes.append(executor.submit(get_compares_and_sample_stats, country))

compares = {}
sample_stats = {}
for task in as_completed(processes):
    compare, sample_stat, country = task.result()
    compares[country] = compare
    sample_stats[country] = sample_stat