In [None]:
import pandas as pd
import numpy as np

import json
import copy

# Czechia - load local area information

In [None]:
# Read the Czech cases/deaths CSV without any post-processing.
cz_csv_path = '../../data/raw_data_w_sources/cz_cases_deaths.csv'
cz_df = pd.read_csv(cz_csv_path)

In [None]:
# Re-read the Czech table and tidy it: parse day-first dates, keep only rows
# that have a 'LAU Unit' value, and expose the NUTS3 unit as the column 'area'.
cz_df = pd.read_csv('../../data/raw_data_w_sources/cz_cases_deaths.csv')
cz_df['date'] = pd.to_datetime(cz_df['date'], dayfirst=True)
has_lau_unit = cz_df['LAU Unit'].notnull()
cz_df = cz_df[has_lau_unit].rename(columns={'NUTS3 Unit': 'area'})

In [None]:
# Sum the rows of every (area, date) pair into one row per NUTS3 area and day.
# numeric_only=True keeps the historical behaviour of dropping text columns;
# newer pandas would otherwise concatenate strings, which breaks the later diff().
cz_timeseries_df = cz_df.groupby(['area', 'date']).sum(numeric_only=True)

In [None]:
# For every Czech NUTS3 area, record the 'Infected' and 'Deaths' values found on
# the first day of each month as '<MONTH>-cumcases' / '<MONTH>-cumdeaths'.
start_dates = ['2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
               '2020-09-01', '2020-10-01', '2020-11-01', '2020-12-01', '2021-01-01']
col_names = ['MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC', 'JAN']

nuts3_area_info_list = []
for nuts3_area in cz_timeseries_df.index.unique(0):
    area_ts = cz_timeseries_df.loc[nuts3_area]
    nuts3_area_dict = {'area': nuts3_area}

    for col_name, start_date in zip(col_names, start_dates):
        day_row = area_ts.loc[start_date]
        nuts3_area_dict[f'{col_name}-cumcases'] = day_row['Infected']
        nuts3_area_dict[f'{col_name}-cumdeaths'] = day_row['Deaths']

    nuts3_area_info_list.append(nuts3_area_dict)

In [None]:
# One row per NUTS3 area, keyed by the area name.
cz_area_df = pd.DataFrame(nuts3_area_info_list)
cz_area_df = cz_area_df.set_index('area')

In [None]:
# Day-over-day change within each area (index level 0). Selecting the column
# before diff() avoids differencing every column of the frame just to keep one.
cz_timeseries_df['new_cases'] = cz_timeseries_df.groupby(level=0)['Infected'].diff()
cz_timeseries_df['new_deaths'] = cz_timeseries_df.groupby(level=0)['Deaths'].diff()

# Switzerland - load local area information

In [None]:
# Read the Swiss cases/deaths CSV.
swiss_csv_path = '../../data/raw_data_w_sources/ch_cases_deaths.csv'
swiss_df = pd.read_csv(swiss_csv_path)

In [None]:
# Keep only the columns used below; any other column is discarded.
swiss_keep_cols = ['date', 'abbreviation_canton_and_fl', 'ncumul_conf', 'ncumul_deceased']
swiss_df = swiss_df.loc[:, swiss_df.columns.isin(swiss_keep_cols)]

In [None]:
# Use the common column name 'area' for the canton abbreviation.
swiss_df = swiss_df.rename(columns={'abbreviation_canton_and_fl': 'area'})

In [None]:
import json

# Load the canton lookup JSON that is later passed to DataFrame.replace().
canton_lookup_path = '../../data/raw_data_w_sources/ch_canton_lookup.json'
with open(canton_lookup_path, 'r') as lookup_file:
    swiss_canton_lookup = json.load(lookup_file)

In [None]:
# Parse the date column, and build the full daily range that every canton is
# reindexed onto.
swiss_df['date'] = pd.to_datetime(swiss_df['date'])
dates = pd.date_range(start='2020-03-01', end='2021-01-01')

In [None]:
# For every canton, reindex its rows onto the full daily `dates` range so that
# days without a report appear as rows of missing values (fill_value=None keeps
# the default missing-value fill), drop the now-redundant 'area' column inside
# each group, and flatten the grouped result back into ordinary columns.
# NOTE(review): the following cell renames a 'level_1' column to 'date', so the
# reindexed dates presumably come back under that name - confirm on the pandas
# version in use.
filled_swiss_df = swiss_df.set_index('date').groupby('area').apply(lambda x: x.reindex(dates, fill_value=None).drop('area', axis=1)).reset_index()

In [None]:
# Apply the canton lookup, name the date column, and index by (area, date).
swiss_ts_df = (
    filled_swiss_df
    .replace(swiss_canton_lookup)
    .rename(columns={'level_1': 'date'})
    .set_index(['area', 'date'])
)

In [None]:
# Fill the gaps in each canton's cumulative series and derive daily counts.
for canton in swiss_ts_df.index.unique(0):
    # Anchor the series at zero when the first day of the range has no value,
    # so interpolation has a starting point.
    if np.isnan(swiss_ts_df.loc[canton].loc[dates[0]]['ncumul_conf']):
        swiss_ts_df.loc[(canton, dates[0]), 'ncumul_conf'] = 0
    
    if np.isnan(swiss_ts_df.loc[canton].loc[dates[0]]['ncumul_deceased']):
        swiss_ts_df.loc[(canton, dates[0]), 'ncumul_deceased'] = 0
    
    # interpolate() with its default settings fills missing values linearly;
    # the daily counts are the first differences of the filled cumulative series.
    interp_df = swiss_ts_df.loc[canton].interpolate()
    interp_df['new_cases'] = interp_df['ncumul_conf'].diff()
    interp_df['new_deaths'] = interp_df['ncumul_deceased'].diff()
    # Write the filled values back one (canton, date) cell at a time; the daily
    # counts are rounded to the nearest whole number.
    for date in dates:
        swiss_ts_df.loc[(canton, date), 'ncumul_conf'] = interp_df.loc[date, 'ncumul_conf']
        swiss_ts_df.loc[(canton, date), 'ncumul_deceased'] = interp_df.loc[date, 'ncumul_deceased']
        swiss_ts_df.loc[(canton, date), 'new_cases'] = np.around(interp_df.loc[date, 'new_cases'])
        swiss_ts_df.loc[(canton, date), 'new_deaths'] = np.around(interp_df.loc[date, 'new_deaths'])

## Warning: there was some missing data here

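For example, some cantons have no entry for some days. To work around this, the missing cumulative values were filled by linear interpolation, and the daily counts were derived from the interpolated series.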

In [None]:
# Apply the canton lookup to the un-interpolated table and index it by
# (area, date) for the coverage check below.
swiss_df = (
    swiss_df
    .replace(swiss_canton_lookup)
    .set_index(['area', 'date'])
)

In [None]:
# Report, per canton, which share of the days between its first and last row
# actually has a row, together with the 'ncumul_conf' value of its last row.
for canton in swiss_df.index.unique(0):
    canton_df = swiss_df.loc[canton]
    # Every calendar day between the canton's first and last index entries.
    all_dates = pd.date_range(canton_df.index[0], canton_df.index[-1])
    print(f'Canton: {canton} has {100*float(len(canton_df.index)/len(all_dates)):.2f}% of dates and {canton_df.loc[canton_df.index[-1]]["ncumul_conf"]} cases all pandemic')

In [None]:
# For every canton, record the interpolated cumulative cases/deaths found on
# the first day of each month as '<MONTH>-cumcases' / '<MONTH>-cumdeaths'.
start_dates = ['2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
               '2020-09-01', '2020-10-01', '2020-11-01', '2020-12-01', '2021-01-01']
col_names = ['MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC', 'JAN']

canton_info_list = []
for canton in swiss_df.index.unique(0):
    canton_ts = swiss_ts_df.loc[canton]
    canton_dict = {'area': canton}

    for col_name, start_date in zip(col_names, start_dates):
        day_row = canton_ts.loc[start_date]
        canton_dict[f'{col_name}-cumcases'] = day_row['ncumul_conf']
        canton_dict[f'{col_name}-cumdeaths'] = day_row['ncumul_deceased']

    canton_info_list.append(canton_dict)

In [None]:
# One row per canton, keyed by the canton name.
swiss_area_df = pd.DataFrame(canton_info_list)
swiss_area_df = swiss_area_df.set_index('area')

In [None]:
# The cumulative columns are no longer needed; the remaining columns are kept.
swiss_ts_df = swiss_ts_df.drop(columns=['ncumul_conf', 'ncumul_deceased'])

In [None]:
# Expose the Swiss series under the `<country>_timeseries_df` name used for the
# other countries in this notebook (an alias, not a copy).
swiss_timeseries_df = swiss_ts_df

# Germany - load local area information

In [None]:
# The AGS dict holds information about Germany's local areas; the cells below
# read each entry's 'name' and 'state'.
ags_json_path = '../../data/raw_data_w_sources/de_ags.json'
with open(ags_json_path) as ags_file:
    ags_info_dict = json.load(ags_file)

In [None]:
# Load the per-AGS case table: drop the national total, index by calendar day.
cases_df = pd.read_csv('../../data/raw_data_w_sources/de_cases-rki-by-ags.csv')
cases_df = cases_df.drop('sum_cases', axis=1)
cases_df = cases_df.rename({'time_iso8601': 'date'}, axis=1)
# Parse the timestamps, then truncate them to whole days.
cases_df['date'] = pd.to_datetime(cases_df['date'])
cases_df['date'] = pd.to_datetime(cases_df['date'].dt.date)
cases_df = cases_df.set_index('date')

# Same treatment for the per-AGS death table. The drop result used to be bound
# to an unused name (`deaths`), leaving 'sum_deaths' in `deaths_df`.
deaths_df = pd.read_csv('../../data/raw_data_w_sources/de_deaths-rki-by-ags.csv')
deaths_df = deaths_df.drop('sum_deaths', axis=1)
deaths_df = deaths_df.rename({'time_iso8601': 'date'}, axis=1)
deaths_df['date'] = pd.to_datetime(deaths_df['date'])
deaths_df['date'] = pd.to_datetime(deaths_df['date'].dt.date)
deaths_df = deaths_df.set_index('date')

In [None]:
# For every German AGS area (except '3152', which is skipped), record the case
# and death values on each monthly snapshot date. The first snapshot here is
# 2020-03-02 rather than the 1st.
start_dates = ['2020-03-02', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
               '2020-09-01', '2020-10-01', '2020-11-01', '2020-12-01', '2021-01-01']
col_names = ['MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC', 'JAN']

ags_info_list = []
for ags, ags_info in ags_info_dict.items():
    if ags == '3152':
        continue

    ags_dict = {
        'area': ags_info['name'],
        'region': ags_info['state'],
        'country': 'Germany',
    }
    ags_cases = cases_df[ags]
    ags_deaths = deaths_df[ags]
    for col_name, start_date in zip(col_names, start_dates):
        ags_dict[f'{col_name}-cumcases'] = ags_cases.loc[start_date]
        ags_dict[f'{col_name}-cumdeaths'] = ags_deaths.loc[start_date]

    ags_info_list.append(ags_dict)

In [None]:
# Reload the per-AGS case table and convert it to day-over-day differences.
cases_df = pd.read_csv('../../data/raw_data_w_sources/de_cases-rki-by-ags.csv')
cases_df = cases_df.drop('sum_cases', axis=1)
cases_df = cases_df.rename({'time_iso8601': 'date'}, axis=1)
# Parse the timestamps, then truncate them to whole days.
cases_df['date'] = pd.to_datetime(cases_df['date'])
cases_df['date'] = pd.to_datetime(cases_df['date'].dt.date)
cases_df = cases_df.set_index('date')
cases_df = cases_df.diff()

# Same for deaths. Two fixes: the drop result is now kept on `deaths_df`, and
# the daily deaths are the differences of the DEATH table (the original
# differenced `cases_df` a second time, so `deaths_df` held no death data).
deaths_df = pd.read_csv('../../data/raw_data_w_sources/de_deaths-rki-by-ags.csv')
deaths_df = deaths_df.drop('sum_deaths', axis=1)
deaths_df = deaths_df.rename({'time_iso8601': 'date'}, axis=1)
deaths_df['date'] = pd.to_datetime(deaths_df['date'])
deaths_df['date'] = pd.to_datetime(deaths_df['date'].dt.date)
deaths_df = deaths_df.set_index('date')
deaths_df = deaths_df.diff()

In [None]:
# Build one record per (area, day) holding that day's case and death values for
# every German AGS area except '3152'.
Ds = pd.date_range('2020-03-02', '2021-01-01')

ags_time_series_list = []
for ags, ags_info in ags_info_dict.items():
    if ags == '3152':
        continue

    area_name = ags_info['name']
    area_cases = cases_df[ags]
    area_deaths = deaths_df[ags]
    ags_time_series_list.extend(
        {
            'area': area_name,
            'date': d,
            'new_cases': area_cases[d],
            'new_deaths': area_deaths[d],
        }
        for d in Ds
    )

In [None]:
# One row per (area, date) record collected in `ags_time_series_list`.
germany_ts_df = pd.DataFrame(ags_time_series_list)

In [None]:
# One row per German area with its monthly snapshot columns.
germany_area_df = pd.DataFrame(ags_info_list)

In [None]:
# Key the area table by area name.
germany_area_df = germany_area_df.set_index('area')

In [None]:
# Index the time series by (area, date).
germany_ts_df = germany_ts_df.set_index(['area', 'date'])

# UK - load ltla df

Note: for the UK, at the moment, we don't have region info for the LTLAs in Northern Ireland, Scotland, or Wales.

In [None]:
# Load the LTLA metadata: one row per feature's 'attributes', keyed by LTLA name,
# with the NUTS3 and region names under short column names.
with open('../../data/raw_data_w_sources/uk_ltla_info.json') as ltla_file:
    uk_ltla_info_dict = json.load(ltla_file)

ltla_attribute_rows = [feature['attributes'] for feature in uk_ltla_info_dict['features']]
uk_ltla_info_df = (
    pd.DataFrame(ltla_attribute_rows)
    .rename(columns={'LAU117NM': 'area', 'NUTS318NM': 'NUTS3', 'NUTS118NM': 'region'})
    .set_index('area')
)

In [None]:
# Load the UK table. `infer_datetime_format=True` was dropped: without
# `parse_dates` it has no effect, and it is deprecated in newer pandas.
# The 'date' column therefore stays as read from the CSV; it is converted with
# pd.to_datetime later, when the NUTS3-level table is built.
uk_df = pd.read_csv('../../data/raw_data_w_sources/uk_case_deaths.csv')
uk_df = uk_df.drop(['areaCode', 'newCasesByPublishDate', 'newDeaths28DaysByPublishDate'], axis=1)
# Overwrite areaType with a constant so it can serve as the 'country' column.
uk_df['areaType'] = 'UK'
uk_df = uk_df.rename({'areaType': 'country', 'areaName':'area', 'newCasesBySpecimenDate': 'new_cases', 'newDeaths28DaysByDeathDate': 'new_deaths'}, axis=1)
uk_df = uk_df.set_index(['area', 'date'])

In [None]:
# Order rows by the date level (index level 1), ascending, so the cumulative
# sums taken below run forward in time.
uk_df = uk_df.sort_index(level=[1],ascending=[True])

In [None]:
# Distinct values of the first index level (the area names).
uk_areas = uk_df.index.unique(0)

In [None]:
# Monthly snapshot days and the column-name prefix used for each of them.
month_starts = [
    ('MAR', '2020-03-01'),
    ('APR', '2020-04-01'),
    ('MAY', '2020-05-01'),
    ('JUN', '2020-06-01'),
    ('JUL', '2020-07-01'),
    ('AUG', '2020-08-01'),
    ('SEP', '2020-09-01'),
    ('OCT', '2020-10-01'),
    ('NOV', '2020-11-01'),
    ('DEC', '2020-12-01'),
    ('JAN', '2021-01-01'),
]
start_dates = [day for _, day in month_starts]
col_names = [label for label, _ in month_starts]

In [None]:
# For every LTLA, attach its region / NUTS3 name (or 'unknown' when the LTLA is
# absent from the lookup table) and the running case/death totals on each
# monthly snapshot date.
uk_area_list = []

for ltla in uk_areas:
    try:
        ltla_info = uk_ltla_info_df.loc[ltla]
        region = ltla_info['region']
        nuts3 = ltla_info['NUTS3']
    except KeyError:
        print(f'{ltla} missing in my lookup table')
        region = nuts3 = 'unknown'

    ltla_dict = {
        'name': ltla,
        'region': region,
        'NUTS3': nuts3,
        'country': 'UK',
    }

    ltla_ts = uk_df.loc[ltla]
    cum_cases = ltla_ts['new_cases'].cumsum()
    cum_deaths = ltla_ts['new_deaths'].cumsum()
    for col_name, start_date in zip(col_names, start_dates):
        ltla_dict[f'{col_name}-cumcases'] = cum_cases[start_date]
        ltla_dict[f'{col_name}-cumdeaths'] = cum_deaths[start_date]

    uk_area_list.append(ltla_dict)

In [None]:
# LTLA-level area table. Note that `uk_area_df` is reassigned to the NUTS3-level
# table further down in this notebook.
uk_area_df = pd.DataFrame(uk_area_list)

# UK LTLA converted to NUTS3

In [None]:
def NUTS3_lookup(ltla):
    """Return the 'NUTS3' entry of `uk_ltla_info_df` for `ltla`.

    Falls back to the string 'unknown' when the lookup raises KeyError
    (for example because `ltla` is not in the table's index).
    """
    try:
        return uk_ltla_info_df.loc[ltla]['NUTS3']
    except KeyError:
        return 'unknown'

In [None]:
# Flatten the (area, date) index back into columns so a NUTS3 column can be added.
nuts3_uk_df = uk_df.reset_index()

In [None]:
# Tag every row with its NUTS3 name ('unknown' where the lookup has no entry).
nuts3_uk_df['NUTS3'] = nuts3_uk_df['area'].map(NUTS3_lookup)

In [None]:
# Distinct dates and distinct NUTS3 names present in the table.
days = nuts3_uk_df['date'].unique()
nuts3_regions = nuts3_uk_df['NUTS3'].unique()

In [None]:
# Aggregate the LTLA rows into one daily series per NUTS3 region.
# The per-region frames are collected in a list and concatenated once:
# DataFrame.append in a loop copies the growing frame on every iteration and is
# no longer available in newer pandas.
nuts3_df_list = []

for nuts3_region in nuts3_regions:
    # LTLAs without a lookup entry cannot be assigned to a region.
    if nuts3_region == 'unknown':
        continue

    filtered_df = nuts3_uk_df.loc[nuts3_uk_df['NUTS3'] == nuts3_region]

    # numeric_only=True sums only the count columns, so text columns are dropped
    # instead of being concatenated on newer pandas.
    case_death_series = filtered_df.groupby('date').sum(numeric_only=True)
    case_death_series['area'] = nuts3_region
    nuts3_df_list.append(case_death_series)

nuts3_uk_df_merged = pd.concat(nuts3_df_list).reset_index()
nuts3_uk_df_merged['date'] = pd.to_datetime(nuts3_uk_df_merged['date'])
nuts3_uk_df_merged = nuts3_uk_df_merged.set_index(['area', 'date'])
nuts3_uk_df_merged = nuts3_uk_df_merged.sort_index(level=[1],ascending=[True])

In [None]:
# For every known NUTS3 region, record its region name and the running
# case/death totals on each monthly snapshot date.
uk_nuts3_area_list = []

for nuts3_region in nuts3_regions:
    if nuts3_region == 'unknown':
        continue

    # Take the region of the first LTLA listed under this NUTS3 name. `.iloc[0]`
    # is explicit positional access; `[0]` on this label-indexed Series relied on
    # the deprecated integer-position fallback.
    region_rows = uk_ltla_info_df.loc[uk_ltla_info_df['NUTS3'] == nuts3_region]
    nuts3_dict = {
        'area': nuts3_region,
        'region': region_rows['region'].iloc[0]
    }

    region_ts = nuts3_uk_df_merged.loc[nuts3_region]
    cum_cases = region_ts['new_cases'].cumsum()
    cum_deaths = region_ts['new_deaths'].cumsum()

    for col_name, start_date in zip(col_names, start_dates):
        nuts3_dict[f'{col_name}-cumcases'] = cum_cases[start_date]
        nuts3_dict[f'{col_name}-cumdeaths'] = cum_deaths[start_date]

    uk_nuts3_area_list.append(nuts3_dict)

In [None]:
# One row per NUTS3 region with its monthly snapshot columns.
uk_nuts3_area_df = pd.DataFrame(uk_nuts3_area_list)

In [None]:
# From here on the UK is represented at NUTS3 level: this replaces the
# LTLA-level `uk_area_df` built earlier.
uk_timeseries_df = nuts3_uk_df_merged
uk_area_df = uk_nuts3_area_df.set_index('area')

In [None]:
# Keep only areas whose region is none of Scotland, Wales or Northern Ireland.
excluded_regions = ['Scotland', 'Wales', 'Northern Ireland']
uk_area_df = uk_area_df.loc[~uk_area_df['region'].isin(excluded_regions)]

# Austria - load local area dataframe

In [None]:
# Lookup table keyed by the 'GKZ' column.
austria_ltla_lookup = (
    pd.read_csv('../../data/raw_data_w_sources/at_lau_lookup.csv')
    .set_index('GKZ')
)

def at_ltla_lookup(ltla):
    """Map a GKZ value to its state code from `austria_ltla_lookup`.

    Any value that is not in the lookup table's index is mapped to 'Vienna'.
    """
    if ltla not in austria_ltla_lookup.index:
        return 'Vienna'
    return austria_ltla_lookup.loc[ltla]['State Code (middle column of HASC)']

In [None]:
# Read the Austrian ';'-separated CSV, skipping its first line and any
# malformed rows.
austria_csv_path = '../../data/raw_data_w_sources/at_case_deaths.csv'
austria_df = pd.read_csv(austria_csv_path, error_bad_lines=False, delimiter=';', skiprows=1)

In [None]:
# Re-read the Austrian table and reduce it to area / region / daily counts.
# NOTE(review): `error_bad_lines` is deprecated in newer pandas in favour of
# `on_bad_lines='skip'`; it is kept here to match the pandas version this
# notebook was written against.
austria_df = pd.read_csv('../../data/raw_data_w_sources/at_case_deaths.csv', error_bad_lines=False, delimiter=';', skiprows=1)
# Column names in this file carry a leading space.
austria_df = austria_df.drop([' number of cases total',
       ' number of cases of 7 days', ' seven days of incidence cases',' number of total totals',
       ' number of held daily', ' number of healing total'], axis=1)
# Replace each district's GKZ code by its state; the column becomes 'region'.
austria_df[' GKZ'] = austria_df[' GKZ'].map(at_ltla_lookup)
austria_df = austria_df.rename({'Time': 'date', ' district': 'area', ' GKZ': 'region', ' number of inhabitants': 'population', ' number of cases': 'new_cases', ' number of dead daily': 'new_deaths'}, axis=1)
austria_df = austria_df.drop('population', axis=1)
# Time part is hours:minutes:seconds. The original format had %M and %H
# swapped, which only parses while both fields are valid in either role.
austria_df['date'] = pd.to_datetime(austria_df['date'], format='%d.%m.%Y %H:%M:%S')

In [None]:
# Index the Austrian table by (area, date).
austria_df = austria_df.set_index(['area', 'date'])

In [None]:
# For every Austrian area, record its region and the running case/death totals
# on each monthly snapshot date.
start_dates = ['2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
               '2020-09-01', '2020-10-01', '2020-11-01', '2020-12-01', '2021-01-01']
col_names = ['MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC', 'JAN']

aut_area_list = []
for area in austria_df.index.unique(0):
    area_ts = austria_df.loc[area]
    area_dict = {
        'area': area,
        # The region is read from the area's first row.
        'region': area_ts.iloc[0]['region'],
        'country': 'Austria',
    }

    cum_cases = area_ts['new_cases'].cumsum()
    cum_deaths = area_ts['new_deaths'].cumsum()
    for col_name, start_date in zip(col_names, start_dates):
        area_dict[f'{col_name}-cumcases'] = cum_cases[start_date]
        area_dict[f'{col_name}-cumdeaths'] = cum_deaths[start_date]

    aut_area_list.append(area_dict)

In [None]:
# Expose the Austrian table under the `<country>_timeseries_df` name (an alias,
# not a copy).
austria_timeseries_df = austria_df

In [None]:
# One row per Austrian area, keyed by the area name.
austria_area_df = pd.DataFrame(aut_area_list)
austria_area_df = austria_area_df.set_index('area')

# Austria - but done at a higher level

In [None]:
# Distinct values of the 'region' column; each is treated as one NUTS2 region below.
austria_nuts2_regions = austria_timeseries_df['region'].unique()

In [None]:
# Aggregate the Austrian areas into one daily series per region.
# The per-region frames are collected in a list and concatenated once:
# DataFrame.append in a loop copies the growing frame on every iteration and is
# no longer available in newer pandas.
austria_nuts2_df_list = []

for nuts2_region in austria_nuts2_regions:
    filtered_df = austria_timeseries_df.loc[austria_timeseries_df['region'] == nuts2_region]

    # numeric_only=True sums only the count columns, so the text 'region' column
    # is dropped instead of being concatenated on newer pandas.
    case_death_series = filtered_df.groupby('date').sum(numeric_only=True)
    case_death_series['area'] = nuts2_region
    austria_nuts2_df_list.append(case_death_series)

austria_nuts2_df_merged = pd.concat(austria_nuts2_df_list).reset_index()
austria_nuts2_df_merged = austria_nuts2_df_merged.set_index(['area', 'date'])
austria_nuts2_df_merged = austria_nuts2_df_merged.sort_index(level=[1],ascending=[True])

In [None]:
# Alias (not a copy) of the merged region-level table under the
# `*_timeseries_df` naming used elsewhere in this notebook.
austria_nuts2_timeseries_df = austria_nuts2_df_merged

In [None]:
# For every Austrian region, record the running case/death totals on each
# monthly snapshot date.
austria_nuts2_area_list = []

for nuts2_region in austria_nuts2_regions:
    region_ts = austria_nuts2_timeseries_df.loc[nuts2_region]
    cum_cases = region_ts['new_cases'].cumsum()
    cum_deaths = region_ts['new_deaths'].cumsum()

    nuts2_dict = {'area': nuts2_region}
    for col_name, start_date in zip(col_names, start_dates):
        nuts2_dict[f'{col_name}-cumcases'] = cum_cases[start_date]
        nuts2_dict[f'{col_name}-cumdeaths'] = cum_deaths[start_date]

    austria_nuts2_area_list.append(nuts2_dict)

In [None]:
# One row per Austrian region, keyed by the region name.
austria_nuts2_area_df = pd.DataFrame(austria_nuts2_area_list).set_index('area')

# Load Italy Case and Death Data

In [None]:
# Load the Italian table, index it by (area, calendar day) and derive daily
# deaths from the cumulative 'total_deaths' column.
italy_df = pd.read_csv('../../data/raw_data_w_sources/it_cases_deaths.csv', delimiter=',')
italy_df['date'] = pd.to_datetime(italy_df['date'])
# Truncate to whole days but convert back to datetime64, as the German cells do.
# Leaving plain `datetime.date` objects gives an object-dtype index, on which the
# later lookups by date string and by `pd.date_range` do not match.
italy_df['date'] = pd.to_datetime(italy_df['date'].dt.date)
italy_df = italy_df.set_index(['area', 'date'])
# Difference only the column that is needed, within each area.
italy_df['new_deaths'] = italy_df.groupby('area')['total_deaths'].diff()
italy_df = italy_df.drop('total_deaths', axis=1)

In [None]:
# Expose the Italian table under the `<country>_timeseries_df` name (an alias,
# not a copy).
italy_timeseries_df = italy_df

In [None]:
# For every Italian area, record the running case/death totals on each monthly
# snapshot date, then assemble the area table.
start_dates = ['2020-03-01', '2020-04-01', '2020-05-01', '2020-06-01', '2020-07-01', '2020-08-01',
               '2020-09-01', '2020-10-01', '2020-11-01', '2020-12-01', '2021-01-01']
col_names = ['MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC', 'JAN']

italy_area_list = []
for area in italy_timeseries_df.index.unique(0):
    area_ts = italy_timeseries_df.loc[area]
    cum_cases = area_ts['new_cases'].cumsum()
    cum_deaths = area_ts['new_deaths'].cumsum()

    area_dict = {'area': area}
    for col_name, start_date in zip(col_names, start_dates):
        area_dict[f'{col_name}-cumcases'] = cum_cases[start_date]
        area_dict[f'{col_name}-cumdeaths'] = cum_deaths[start_date]

    italy_area_list.append(area_dict)

italy_area_df = pd.DataFrame(italy_area_list).set_index('area')

# Actually Perform Thresholding

In [None]:
import matplotlib.pyplot as plt

def threshold_cumulative_cases(area_df, threshold = 200, start_str='AUG', end_str='JAN'):
    """Keep the areas whose case count over a period exceeds `threshold`.

    The period count is `area_df['<end_str>-cumcases'] - area_df['<start_str>-cumcases']`.

    Parameters
    ----------
    area_df : DataFrame with '<MONTH>-cumcases' columns, one row per area.
    threshold : areas are kept only when their period count is strictly greater.
    start_str, end_str : month prefixes selecting the two snapshot columns.

    Returns
    -------
    (filtered_df, percent_remaining)
        `filtered_df` has a single column 'cumulative_cases_<start>_<end>', sorted
        ascending. `percent_remaining` is the share of rows of `area_df` that were
        kept, in percent; it is 0.0 when `area_df` has no rows (previously a
        ZeroDivisionError).
    """
    total_cases = area_df[f'{end_str}-cumcases'] - area_df[f'{start_str}-cumcases']
    column_name = f'cumulative_cases_{start_str}_{end_str}'
    filtered_df = total_cases[total_cases > threshold].to_frame(column_name)
    filtered_df = filtered_df.sort_values(column_name)

    n_areas = len(area_df.index)
    percent_remaining = 100 * len(filtered_df.index) / n_areas if n_areas else 0.0
    return filtered_df, percent_remaining

def verify_timeseries(filtered_df, timeseries_df, plot_title='', n_regions=5,
                      start_date='2020-08-01', end_date='2020-12-30'):
    """Plot daily cases and deaths for the first `n_regions` areas of `filtered_df`.

    One subplot row per area: 'new_cases' in blue on the left axis and
    'new_deaths' in red on a twin right axis, both on a log scale.

    Parameters
    ----------
    filtered_df : frame whose index holds the area names; the first `n_regions`
        entries are plotted.
    timeseries_df : frame indexed by (area, date) with 'new_cases' and
        'new_deaths' columns.
    plot_title : figure-level title.
    n_regions : number of subplot rows and maximum number of areas plotted.
    start_date, end_date : bounds of the plotted daily range. The defaults are
        the window that used to be hard-coded.
    """
    areas_to_plot = filtered_df.index[:n_regions]

    plt.figure(figsize=(8, 8), dpi=300)

    Ds = pd.date_range(start_date, end_date)
    for p_i, area in enumerate(areas_to_plot):
        # Look the area's window up once and reuse it for both curves.
        area_window = timeseries_df.loc[area].loc[Ds]

        plt.subplot(n_regions, 1, p_i+1)
        plt.plot(area_window['new_cases'], color='tab:blue')
        plt.ylabel('cases')
        plt.yscale('log')
        plt.ylim([10**0.5, 10**3])
        # Deaths share the x axis but get their own y axis on the right.
        plt.twinx()
        plt.plot(area_window['new_deaths'], color='tab:red')
        plt.title(f'{area}')
        plt.yscale('log')
        plt.ylabel('deaths')
        plt.ylim([10**0.5, 10**2])

    plt.suptitle(plot_title)
    plt.tight_layout()
In [None]:
# Austrian regions with more than 10000 cases over the default period, and
# their daily curves.
filtered_df, prop_remaining = threshold_cumulative_cases(austria_nuts2_area_df, threshold=10000)
verify_timeseries(filtered_df, austria_nuts2_timeseries_df, plot_title='Austria - NUTS2 Level')

In [None]:
# Italian areas with more than 4000 cases over the default period, and their
# daily curves.
filtered_df, prop_remaining = threshold_cumulative_cases(italy_area_df, threshold=4000)
verify_timeseries(filtered_df, italy_timeseries_df, plot_title='Italy - NUTS2')

# how many areas remaining at different threshold vals

In [None]:
import numpy as np

In [None]:
# 40 evenly spaced case thresholds from 100 to 10000 (inclusive).
threshold_vals = np.linspace(100, 10000, 40)

In [None]:
# Parallel lists: entry i of each list describes the same country / level.
timeseries_dfs = [germany_ts_df, austria_timeseries_df, italy_timeseries_df, austria_nuts2_timeseries_df, uk_timeseries_df, swiss_timeseries_df]
area_dfs = [germany_area_df, austria_area_df, italy_area_df, austria_nuts2_area_df, uk_area_df, swiss_area_df]
# 'switzerland' was misspelled in this legend label.
titles = ['germany (nuts3)', 'austria (nuts3)', 'italy', 'austria (nuts2)', 'uk', 'switzerland']

In [None]:
# One curve per country: share of areas that survive each case threshold.
plt.figure(figsize=(4, 3), dpi=200)
for _ts_df, area_df, title in zip(timeseries_dfs, area_dfs, titles):
    # threshold_cumulative_cases returns (filtered_df, percent); keep the percent.
    p_remaining = np.array([threshold_cumulative_cases(area_df, t_val)[1] for t_val in threshold_vals])
    plt.plot(threshold_vals, p_remaining, label=title)

# Axis labels and legend only need to be applied once, after all curves exist.
plt.xlabel('case threshold')
plt.ylabel('percent areas remaining %')
plt.legend()

In [None]:
# Single threshold and the three countries inspected in the next cell.
# Entry i of each list describes the same country.
threshold_vals = [2000]
timeseries_dfs = [
    germany_ts_df,
    austria_timeseries_df,
    swiss_timeseries_df,
]
area_dfs = [
    germany_area_df,
    austria_area_df,
    swiss_area_df,
]
titles = [
    'germany (nuts3)',
    'austria (nuts3)',
    'switzerland',
]

In [None]:
# For each threshold and country, plot the surviving areas' daily curves and
# report the share of areas kept in the figure title.
for t_val in threshold_vals:
    for ts_df, area_df, title in zip(timeseries_dfs, area_dfs, titles):
        filtered_df, p = threshold_cumulative_cases(area_df, threshold=t_val)
        plot_title = f'Thresholded {t_val}\n{title}\n{p}% of areas remaining'
        verify_timeseries(filtered_df, ts_df, plot_title)

# Implement deaths stratification

In [None]:
import seaborn as sns

In [None]:
def compute_first_wave_deaths(area_df, start_str='MAR', end_str='JUL'):
    """Return each area's first-wave deaths, sorted ascending.

    The first wave is defined as the deaths between the `start_str` and
    `end_str` snapshots - by default from the first of March to the end of June
    (i.e. the start of July). The result is a one-column frame named
    'first_wave_deaths', indexed like `area_df`.
    """
    end_col = f'{end_str}-cumdeaths'
    start_col = f'{start_str}-cumdeaths'
    deaths_in_wave = area_df[end_col] - area_df[start_col]
    return deaths_in_wave.to_frame('first_wave_deaths').sort_values('first_wave_deaths')

In [None]:
# Back to the full set of countries / levels. Entry i of each list describes
# the same country.
timeseries_dfs = [
    germany_ts_df,
    austria_timeseries_df,
    italy_timeseries_df,
    austria_nuts2_timeseries_df,
    uk_timeseries_df,
    swiss_timeseries_df,
]
area_dfs = [
    germany_area_df,
    austria_area_df,
    italy_area_df,
    austria_nuts2_area_df,
    uk_area_df,
    swiss_area_df,
]
titles = [
    'germany (nuts3)',
    'austria (nuts3)',
    'italy',
    'austria (nuts2)',
    'uk',
    'switzerland',
]

In [None]:
# Histogram of first-wave deaths per area, one panel per country (3 x 2 grid).
plt.figure(figsize=(6, 6), dpi=300)
for panel_number, (area_df, title) in enumerate(zip(area_dfs, titles), start=1):
    plt.subplot(3, 2, panel_number)
    first_wave_deaths = compute_first_wave_deaths(area_df)
    sns.histplot(first_wave_deaths)
    plt.title(title)

plt.tight_layout()

In [None]:
# Display the distinct area names in the Italian area table.
italy_area_df.index.unique()

In [None]:
def threshold_and_stratify(t_val, area_df, n_groups=5, samples_per_group=3, required=None):
    """Threshold areas by case count, then sample them stratified by first-wave deaths.

    Areas passing `threshold_cumulative_cases(area_df, t_val)` are ordered by
    first-wave deaths (highest first) and split into `n_groups` consecutive
    groups of near-equal size. Up to `samples_per_group` areas are then drawn at
    random, without replacement, from every group.

    Parameters
    ----------
    t_val : case threshold passed to `threshold_cumulative_cases`.
    area_df : area table with the '<MONTH>-cumcases' / '<MONTH>-cumdeaths' columns.
    n_groups : number of strata.
    samples_per_group : target number of areas per stratum.
    required : optional iterable of area names that must be part of the result.
        Each one uses up one slot of the stratum it falls into.

    Returns
    -------
    list of area names: the required areas first, then the random draws group
    by group. A stratum with fewer eligible areas than open slots contributes
    all of its eligible areas.

    Raises
    ------
    ValueError if a required area is not among the thresholded areas.
    """
    filtered_df, p = threshold_cumulative_cases(area_df, t_val)
    print(f'Under this threshold, there are {p}% ({len(filtered_df.index)}) areas remaining')

    full_thresholded_df = area_df.loc[filtered_df.index]
    fw_deaths = compute_first_wave_deaths(full_thresholded_df)
    fw_deaths = fw_deaths.sort_values('first_wave_deaths', ascending=False)

    groups = np.array_split(list(fw_deaths.index), n_groups)

    # One counter per group; it was previously sized with a hard-coded 5.
    samples_remaining = samples_per_group * np.ones(shape=n_groups)

    samples = []

    if required is not None:
        for r in required:
            matching_groups = [i for i, group in enumerate(groups) if r in group]
            if not matching_groups:
                raise ValueError(f'required area {r!r} is not among the thresholded areas')
            samples.append(r)
            samples_remaining[matching_groups[0]] -= 1

    for g, n_samples in zip(groups, samples_remaining):
        # Required areas are already in `samples`; exclude them so they cannot
        # be drawn a second time.
        candidates = [area for area in g.tolist() if area not in samples]
        # Never ask for more areas than the group can supply.
        n_draw = min(int(n_samples), len(candidates))
        if n_draw > 0:
            samples.extend(np.random.choice(candidates, n_draw, replace=False).tolist())

    return samples