In [1]:
import os
import pickle
import sys

import numpy as np
from tabulate import tabulate
from tqdm import tqdm


sys.path.append("../")
sys.path.append("../covid19_inference")

import covid19_soccer

In [18]:
from tabulate import tabulate


In [2]:
# Automatically re-import edited project modules (e.g. covid19_soccer) before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
# Countries analysed in this notebook; downstream cells iterate over this list,
# so its order is also the row order of the results table.
countries = [
    "England",
    "Scotland",
    "Germany",
    "France",
    "Spain",
    "Slovakia",
    "Portugal",
    "Netherlands",
    "Italy",
    "Czechia",
    "Belgium",
    "Austria",
]

def load(fstr):
    """Unpickle and return the object stored in the file at path ``fstr``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on trusted
    trace files produced by this project.
    """
    with open(fstr, mode="rb") as handle:
        obj = pickle.load(handle)
    return obj

# Directory holding the pickled (model, trace) tuples of the main analysis.
TRACE_DIR = "/data.nst/smohr/covid19_soccer_data/main_traces/"

# Sampler settings (tune, draws, max_treedepth) to try, longest run first:
# for each country the first pickle that exists and unpickles cleanly is used.
SAMPLER_RUNS = [(4000, 8000, 12), (2000, 4000, 12), (1000, 1500, 12)]


def trace_path(country, tune, draws, max_treedepth):
    """Return the pickle path of the main-model run for ``country`` and one sampler setting."""
    return (
        TRACE_DIR
        + "-beta=False"
        + f"-country={country}"
        + "-offset_data=0"
        + "-prior_delay=-1"
        + "-width_delay_prior=0.1"
        + "-sigma_incubation=-1.0"
        + "-median_width_delay=1.0"
        + "-interval_cps=10.0"
        + "-f_fem=0.2"
        + "-uc=True"
        + "-len=normal"
        + f"-t={tune}"
        + f"-d={draws}"
        + f"-max_treedepth={max_treedepth}.pkl"
    )


traces, models, dls = [], [], []
for country in tqdm(countries):
    model = None
    for tune, draws, max_treedepth in SAMPLER_RUNS:
        path = trace_path(country, tune, draws, max_treedepth)
        if not os.path.exists(path):
            continue
        try:
            model, trace = load(path)
        except Exception as exc:
            # Best effort: report the unreadable pickle and fall back to the next, shorter run.
            print(exc)
            continue
        print(f"Use {draws} sample runs for {country}")
        break
    if model is None:
        print(trace_path(country, *SAMPLER_RUNS[-1]), " not found")
        continue

    # Drop chains whose mean `lp` sample statistic (mean over axis 1, presumably
    # the draw axis) exceeds -200; per the original analysis this should only
    # affect 2 chains in France. Only `trace.posterior` is filtered —
    # `trace.sample_stats` keeps all chains.
    mask = np.mean(trace.sample_stats.lp, axis=1) > -200
    trace.posterior = trace.posterior.sel(chain=~mask)

    dl = covid19_soccer.dataloader.Dataloader_gender(countries=[country])
    models.append(model)
    traces.append(trace)
    dls.append(dl)

# Later cells pair traces[i] with countries[i]; a skipped country breaks that pairing.
if len(traces) != len(countries):
    print(
        f"WARNING: loaded {len(traces)} of {len(countries)} countries; "
        "`traces`/`models`/`dls` are no longer index-aligned with `countries`."
    )

  0%|          | 0/12 [00:00<?, ?it/s]

Use 1500 sample runs for England


  8%|▊         | 1/12 [01:16<13:56, 76.04s/it]

Use 8000 sample runs for Scotland


 17%|█▋        | 2/12 [06:35<36:31, 219.19s/it]

Use 8000 sample runs for Germany


 25%|██▌       | 3/12 [11:41<38:47, 258.64s/it]

Use 4000 sample runs for France


 33%|███▎      | 4/12 [12:57<24:52, 186.57s/it]

Use 8000 sample runs for Spain


 42%|████▏     | 5/12 [17:12<24:38, 211.22s/it]

Use 8000 sample runs for Slovakia


 50%|█████     | 6/12 [22:40<25:06, 251.10s/it]

Use 4000 sample runs for Portugal


 58%|█████▊    | 7/12 [25:13<18:15, 219.05s/it]

Use 8000 sample runs for Netherlands


 67%|██████▋   | 8/12 [30:18<16:25, 246.28s/it]

Use 8000 sample runs for Italy


 75%|███████▌  | 9/12 [35:10<13:02, 260.70s/it]

Use 8000 sample runs for Czechia


 83%|████████▎ | 10/12 [40:19<09:11, 275.63s/it]

Use 8000 sample runs for Belgium


 92%|█████████▏| 11/12 [44:54<04:35, 275.49s/it]

Use 8000 sample runs for Austria


100%|██████████| 12/12 [49:52<00:00, 249.37s/it]


In [14]:
from multiprocessing import Pool
import arviz as az

def _nanmax(values):
    """Return the largest non-NaN entry of ``values``, or NaN if there is none.

    The built-in ``max`` is order-dependent when NaNs are present
    (``max([nan, 2.0])`` is nan but ``max([2.0, nan])`` is 2.0) and raises
    on an empty sequence; R-hat can be NaN when within-chain variance is zero.
    """
    valid = [value for value in values if not np.isnan(value)]
    return max(valid) if valid else float("nan")


# Substrings that mark the variables whose convergence matters for the analysis.
RELEVANT_VAR_PATTERNS = ("lambda", "R_t", "factor_female", "alpha")


def get_max_rhat(country, trace, var_patterns=RELEVANT_VAR_PATTERNS):
    """Compute R-hat of every variable in ``trace`` and report the maxima.

    Parameters
    ----------
    country : str
        Label used only in the printed summary.
    trace :
        Object accepted by ``az.rhat`` (here: the loaded inference trace).
    var_patterns : iterable of str
        A variable counts as "relevant" if its name contains any of these substrings.

    Returns
    -------
    (list of float, list of float)
        Per-variable maximum R-hat for the relevant variables (each variable
        listed once) and for all variables, both in ``data_vars`` order.
    """
    rhat = az.rhat(trace)
    # Reduce each variable over all of its dimensions to a single maximum.
    rhat_max = rhat.max()
    all_vars = list(rhat.data_vars)
    # `any` ensures a name matching several patterns is counted only once.
    chosen_vars = [
        var for var in all_vars if any(pattern in var for pattern in var_patterns)
    ]

    rhats_R_t = [float(rhat_max[var]) for var in chosen_vars]
    rhats_all = [float(rhat_max[var]) for var in all_vars]
    print(f"{country} max rhat of R_t {_nanmax(rhats_R_t)} ")
    print(f"{country} max overall rhat {_nanmax(rhats_all)} ")

    return rhats_R_t, rhats_all
    

# Number of worker processes used for the R-hat computation.
N_WORKERS = 6

# The loading cell appends to `traces` only for countries it could load, so
# check the positional pairing with `countries` before labelling any results.
assert len(traces) == len(countries), (
    f"{len(traces)} traces loaded for {len(countries)} countries; "
    "R-hat results would be attributed to the wrong country"
)

rhats = []
with Pool(processes=N_WORKERS) as pool:
    # Submit one job per country, then collect in submission order so that
    # rhats[i] belongs to countries[i].
    results = [
        pool.apply_async(get_max_rhat, (country, trace))
        for country, trace in zip(countries, traces)
    ]
    for res in results:
        rhats.append(res.get())


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


England max rhat of R_t 2.26046650711145 
England max overall rhat 3.2903463960110844 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


France max rhat of R_t 1.0441943450642983 
France max overall rhat 2.7155046606491227 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Scotland max rhat of R_t 1.003280366413768 
Scotland max overall rhat 1.3451921538974065 
Germany max rhat of R_t 1.004522335118012 
Germany max overall rhat 2.1864668463112746 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Portugal max rhat of R_t 1.0016870174018668 
Portugal max overall rhat 1.108063109459645 
Spain max rhat of R_t 1.0171315883037155 
Spain max overall rhat 3.307441441974928 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Slovakia max rhat of R_t 1.0047990382654182 
Slovakia max overall rhat 1.8882625992770752 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Netherlands max rhat of R_t 1.5897480426200736 
Netherlands max overall rhat 1.5897480426200736 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Italy max rhat of R_t 1.0004530871956616 
Italy max overall rhat 1.0603302653943796 


  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Czechia max rhat of R_t 1.0079267461336 
Czechia max overall rhat 1.3227219918863116 
Belgium max rhat of R_t 1.0128744270596246 
Belgium max overall rhat 1.4554976676470603 
Austria max rhat of R_t 1.002032818308459 
Austria max overall rhat 1.0651445571412297 


In [20]:
# rhats[i] is the (relevant, all) R-hat pair returned by get_max_rhat for countries[i].
# Keep the values numeric and let tabulate format them: pre-formatted strings such
# as "1.00" are re-parsed as numbers and printed as "1", losing the fixed precision.
table = [
    [country, np.nanmax(relevant_rhats)]
    for country, (relevant_rhats, _all_rhats) in zip(countries, rhats)
]
header = ["Country", "Max. R-hat of relevant variables"]

print(tabulate(table, header, tablefmt="latex_raw", floatfmt=".2f"))

\begin{tabular}{lr}
\hline
 Country     &   Max. R-hat of relevant variables \\
\hline
 England     &                               2.26 \\
 Scotland    &                               1    \\
 Germany     &                               1    \\
 France      &                               1.04 \\
 Spain       &                               1.02 \\
 Slovakia    &                               1    \\
 Portugal    &                               1    \\
 Netherlands &                               1.59 \\
 Italy       &                               1    \\
 Czechia     &                               1.01 \\
 Belgium     &                               1.01 \\
 Austria     &                               1    \\
\hline
\end{tabular}
