In [1]:
import numpy as np 
import pandas as pd
from matplotlib import pyplot as plt
from datetime import datetime


In [2]:
# Notebook-wide configuration.
# should_save_data gates the optional CSV exports in the `if should_save_data:` cells.
should_save_data = False
# Number of consecutive time steps per sample window (same value as the
# `n_times=5` defaults of the dataset helper functions defined further down).
n_times = 5

In [3]:
# Load the two raw snapshots (both file names carry a 21 March 2021 stamp).
# low_memory=False lets pandas infer each column's dtype from the whole file
# instead of chunk by chunk, avoiding mixed-dtype warnings.
data_ox = pd.read_csv("OxCGRT_march_21_2021.csv",low_memory=False)
data_owid = pd.read_csv("owid-covid-data_march_21_2021.csv",low_memory=False)

### `data_ox` holds the Oxford (OxCGRT) database, which includes the following columns:

In [4]:
data_ox.columns

Index(['CountryName', 'CountryCode', 'RegionName', 'RegionCode',
       'Jurisdiction', 'Date', 'C1_School closing', 'C1_Flag',
       'C2_Workplace closing', 'C2_Flag', 'C3_Cancel public events', 'C3_Flag',
       'C4_Restrictions on gatherings', 'C4_Flag', 'C5_Close public transport',
       'C5_Flag', 'C6_Stay at home requirements', 'C6_Flag',
       'C7_Restrictions on internal movement', 'C7_Flag',
       'C8_International travel controls', 'E1_Income support', 'E1_Flag',
       'E2_Debt/contract relief', 'E3_Fiscal measures',
       'E4_International support', 'H1_Public information campaigns',
       'H1_Flag', 'H2_Testing policy', 'H3_Contact tracing',
       'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
       'H6_Facial Coverings', 'H6_Flag', 'H7_Vaccination policy', 'H7_Flag',
       'H8_Protection of elderly people', 'H8_Flag', 'M1_Wildcard',
       'ConfirmedCases', 'ConfirmedDeaths', 'StringencyIndex',
       'StringencyIndexForDisplay', 'Stringenc

### `data_owid` holds the Our World in Data (OWID) database, which includes the following columns:

In [5]:
data_owid.columns

Index(['iso_code', 'continent', 'location', 'date', 'total_cases', 'new_cases',
       'new_cases_smoothed', 'total_deaths', 'new_deaths',
       'new_deaths_smoothed', 'total_cases_per_million',
       'new_cases_per_million', 'new_cases_smoothed_per_million',
       'total_deaths_per_million', 'new_deaths_per_million',
       'new_deaths_smoothed_per_million', 'reproduction_rate', 'icu_patients',
       'icu_patients_per_million', 'hosp_patients',
       'hosp_patients_per_million', 'weekly_icu_admissions',
       'weekly_icu_admissions_per_million', 'weekly_hosp_admissions',
       'weekly_hosp_admissions_per_million', 'new_tests', 'total_tests',
       'total_tests_per_thousand', 'new_tests_per_thousand',
       'new_tests_smoothed', 'new_tests_smoothed_per_thousand',
       'positive_rate', 'tests_per_case', 'tests_units', 'total_vaccinations',
       'people_vaccinated', 'people_fully_vaccinated', 'new_vaccinations',
       'new_vaccinations_smoothed', 'total_vaccinations_per_hun

####  The table should be: 

#### State: (26 features)

###### Geographic and general state( 9 features): 
(1) CountryName,  <br />
(2) Date,  <br />
(3) population,  <br />
(4) population_density,  <br />
(5) median_age,  <br />
(6) gdp_per_capita,  <br />
(7) aged_65_older/population,  <br />
(8) life_expectancy,  <br />
(9)human_development_index. <br />

###### Corona state (9 features): 
(1) total_cases per million, <br />
(2)total_deaths per million, <br />
(3) new_cases_per_million, <br />
(4) new_deaths_per_million,<br />
(5) new_tests_per_thousand, <br />
(6) positive_rate, <br />
(7) people_fully_vaccinated_per_hundred, <br />
(8)icu_patients_per_million, <br />
(9) StringencyIndex(oxford) <br />

###### Health care state (6 features):  
(1) cardiovasc_death_rate, <br />
(2) diabetes_prevalence, <br />
(3) female_smokers+ male_smokers/population,<br />
(4) hospital_beds_per_thousand, <br />
(5) hosp_patients_per_million, <br />
(6) icu_patients_per_million. <br />


#### Policy: ( 23 features )
(1) C1_School closing,         + 0.5 X(2)  C1_flag <br />
(3) C2_Workplace closing,      + 0.5 X (4)  C2_flag <br />
(5) C3_Cancel public events,   + 0.5 X (6)  C3_flag <br />
(7) C4_Restrictions on gatherings,+  0.5X (8)  C4_flag <br />
(9) C5_Close public transport,    + 0.5X(10) C5_flag  <br />
(11) C6_Stay at home requirements, + 0.5X (12) C6_Flag <br />
(13) C7_Restrictions on internal movement +0.5X(14) C7_Flag <br />
(15) C8_International travel controls, +0.5X(16)C8_Flag  <br />
(17) H1_Public information campaigns, +0.5X(18)H1_Flag <br />
(19) H2_Testing policy, +0.5X (20) H2_Flag <br />
(21) H6_Facial Coverings, +0.5X(22) H6_Flag <br />
(23) C9_Vaccinate_n: percentage of the population vaccinated (new_vaccinations[t+1]/population) <br />


### seir models:

![image info](./images/SEIR-SEIRS.png)

####  (1)

\begin{split}\begin{aligned}
\frac{dS}{dt} & = -\frac{\beta SI}{N}\\
\frac{dE}{dt} & = \frac{\beta SI}{N} - \sigma E\\
\frac{dI}{dt} & = \sigma E - \gamma I\\
\frac{dR}{dt} & = \gamma I
\end{aligned}\end{split}

### We want to learn:
\begin{equation}
\begin{split}
& \beta - \textrm{Rate of spread, the probability of transmitting disease between a susceptible and an infectious individual } \\ 
& \sigma - \textrm{incubation rate, the rate of latent individuals becoming infectious} \\ 
& \gamma - \textrm{Recovery rate, = 1/D, is determined by the average duration, D, of infection}  \\ 
& \xi - \textrm{rate at which recovered individuals return to the susceptible state (SEIRS variant only; it does not appear in the SEIR equations written here)} \\ 
\end{split}
\end{equation}

#### (2) SEIR with vital dynamics: (enabling vital dynamics (births and deaths) )

\begin{split}\begin{aligned}
\frac{dS}{dt} & = \mu N - \nu S - \frac{\beta SI}{N}\\
\frac{dE}{dt} & = \frac{\beta SI}{N} - \nu E - \sigma E\\
\frac{dI}{dt} & = \sigma E - \gamma I - \nu I\\
\frac{dR}{dt} & = \gamma I - \nu R
\end{aligned}\end{split}

### Addition:
\begin{equation}
\begin{split}
& \mu - \textrm{birth rate } \\ 
& \nu - \textrm{death rates} \\ 
\end{split}
\end{equation}

### The output: 
\begin{equation}
\begin{split}
& S - \textrm{susceptible population} \\
& E - \textrm{exposed population (infected but not yet infectious)} \\
& I - \textrm{infected}, \\
& R - \textrm{removed population (either by death or recovery)}\\
& N = S+E+I+R
\end{split}
\end{equation}

In [6]:
# Parse both date columns into datetime64 so the two sources can be joined on
# date. The OWID file uses ISO dates (YYYY-MM-DD); the Oxford file uses compact
# YYYYMMDD values.
data_owid['date'] = pd.to_datetime(data_owid['date'], format='%Y-%m-%d')
data_ox['Date'] = pd.to_datetime(data_ox['Date'], format='%Y%m%d')

In [7]:
# Keep only rows without a region name, i.e. the country-level records.
is_country_level = data_ox['RegionName'].isna()
data_ox = data_ox[is_country_level]

In [8]:
data_ox.columns

Index(['CountryName', 'CountryCode', 'RegionName', 'RegionCode',
       'Jurisdiction', 'Date', 'C1_School closing', 'C1_Flag',
       'C2_Workplace closing', 'C2_Flag', 'C3_Cancel public events', 'C3_Flag',
       'C4_Restrictions on gatherings', 'C4_Flag', 'C5_Close public transport',
       'C5_Flag', 'C6_Stay at home requirements', 'C6_Flag',
       'C7_Restrictions on internal movement', 'C7_Flag',
       'C8_International travel controls', 'E1_Income support', 'E1_Flag',
       'E2_Debt/contract relief', 'E3_Fiscal measures',
       'E4_International support', 'H1_Public information campaigns',
       'H1_Flag', 'H2_Testing policy', 'H3_Contact tracing',
       'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
       'H6_Facial Coverings', 'H6_Flag', 'H7_Vaccination policy', 'H7_Flag',
       'H8_Protection of elderly people', 'H8_Flag', 'M1_Wildcard',
       'ConfirmedCases', 'ConfirmedDeaths', 'StringencyIndex',
       'StringencyIndexForDisplay', 'Stringenc

### Actions data base data_ox

In [9]:
# Replace every NaN in the Oxford frame with 0 so the index arithmetic in the
# next cell never propagates NaN.
# NOTE(review): this also zero-fills missing case/death counts and policy
# levels, so "not reported" becomes indistinguishable from a real 0 — confirm
# this is intended.
data_ox=data_ox.replace(np.nan, 0)

In [10]:
# Build one composite index per policy: (level + 0.5 * flag) * 2, i.e.
# 2 * level + flag. C8 is copied unchanged because no C8 flag column is used.
# Each entry is (policy code, level column, flag column or None).
policy_sources = [
    ('C1', 'C1_School closing', 'C1_Flag'),
    ('C2', 'C2_Workplace closing', 'C2_Flag'),
    ('C3', 'C3_Cancel public events', 'C3_Flag'),
    ('C4', 'C4_Restrictions on gatherings', 'C4_Flag'),
    ('C5', 'C5_Close public transport', 'C5_Flag'),
    ('C6', 'C6_Stay at home requirements', 'C6_Flag'),
    ('C7', 'C7_Restrictions on internal movement', 'C7_Flag'),
    ('C8', 'C8_International travel controls', None),
    ('H1', 'H1_Public information campaigns', 'H1_Flag'),
    ('H6', 'H6_Facial Coverings', 'H6_Flag'),
    ('H8', 'H8_Protection of elderly people', 'H8_Flag'),
]
for policy_code, level_column, flag_column in policy_sources:
    if flag_column is None:
        data_ox[policy_code + '_index'] = data_ox[level_column]
    else:
        data_ox[policy_code + '_index'] = (data_ox[level_column] + 0.5 * data_ox[flag_column]) * 2

# TODO: data_ox['C9_index'] (vaccination policy) needs the OWID data base and is
# therefore not built here.

# Raw level/flag columns that are now folded into the *_index columns.
raw_policy_columns = [
    'C1_School closing', 'C2_Workplace closing', 'C3_Cancel public events',
    'C4_Restrictions on gatherings', 'C5_Close public transport',
    'C6_Stay at home requirements', 'C7_Restrictions on internal movement',
    'C8_International travel controls', 'H1_Public information campaigns',
    'H6_Facial Coverings',
    'C1_Flag', 'C2_Flag', 'C3_Flag', 'C4_Flag', 'C5_Flag', 'C6_Flag', 'C7_Flag',
    'H1_Flag', 'H6_Flag',
]
data_ox = data_ox.drop(columns=raw_policy_columns)

# Delete the economic stuff and every other column the model does not use.
unused_columns = [
    'E2_Debt/contract relief', 'E3_Fiscal measures',
    'E4_International support', 'H2_Testing policy', 'H3_Contact tracing',
    'H4_Emergency investment in healthcare', 'H5_Investment in vaccines',
    'H7_Vaccination policy', 'H7_Flag', 'H8_Protection of elderly people', 'H8_Flag',
    'M1_Wildcard', 'StringencyIndexForDisplay',
    'StringencyLegacyIndex', 'StringencyLegacyIndexForDisplay',
    'GovernmentResponseIndex', 'GovernmentResponseIndexForDisplay',
    'ContainmentHealthIndexForDisplay',
    'EconomicSupportIndex', 'EconomicSupportIndexForDisplay', 'Jurisdiction',
    'E1_Income support', 'E1_Flag',
    'CountryName', 'RegionName', 'RegionCode',
]
data_ox = data_ox.drop(columns=unused_columns)



In [11]:
# Optional export of the reduced Oxford frame; skipped unless should_save_data
# is True. At this point the columns still carry their Oxford names (the rename
# to the OWID naming happens in the next cell).
if should_save_data:
    data_ox.to_csv('modified_Oxford.csv', index=False) 

In [12]:
# Align the Oxford key/count column names with the OWID naming so both frames
# can be merged on ('iso_code', 'date').
oxford_to_owid_names = {
    'CountryCode': 'iso_code',
    'Date': 'date',
    'ConfirmedCases': 'total_cases',
    'ConfirmedDeaths': 'total_deaths',
}
data_ox = data_ox.rename(columns=oxford_to_owid_names)

In [13]:
data_ox.columns

Index(['iso_code', 'date', 'total_cases', 'total_deaths', 'StringencyIndex',
       'ContainmentHealthIndex', 'C1_index', 'C2_index', 'C3_index',
       'C4_index', 'C5_index', 'C6_index', 'C7_index', 'C8_index', 'H1_index',
       'H6_index', 'H8_index'],
      dtype='object')

#### Creating Geographic and General State:

In [14]:
# OWID columns that are not part of any feature group. total_cases and
# total_deaths are dropped here as well; the merged frame gets those two
# columns from the Oxford data instead.
owid_columns_to_drop = [
    'location', 'continent', 'total_cases', 'new_cases',
    'new_cases_smoothed', 'total_deaths', 'new_deaths',
    'new_deaths_smoothed',
    'new_cases_smoothed_per_million',
    'new_deaths_smoothed_per_million', 'reproduction_rate', 'icu_patients',
    'hosp_patients',
    'weekly_icu_admissions',
    'weekly_icu_admissions_per_million', 'weekly_hosp_admissions',
    'weekly_hosp_admissions_per_million', 'new_tests', 'total_tests',
    'total_tests_per_thousand',
    'new_tests_smoothed', 'new_tests_smoothed_per_thousand',
    'tests_per_case', 'tests_units', 'total_vaccinations',
    'people_vaccinated', 'people_fully_vaccinated', 'new_vaccinations',
    'total_vaccinations_per_hundred',
    'people_vaccinated_per_hundred',
    'new_vaccinations_smoothed_per_million', 'stringency_index',
    'aged_70_older', 'extreme_poverty',
    'handwashing_facilities',
]
data_owid = data_owid.drop(columns=owid_columns_to_drop)

In [15]:
data_owid.columns

Index(['iso_code', 'date', 'total_cases_per_million', 'new_cases_per_million',
       'total_deaths_per_million', 'new_deaths_per_million',
       'icu_patients_per_million', 'hosp_patients_per_million',
       'new_tests_per_thousand', 'positive_rate', 'new_vaccinations_smoothed',
       'people_fully_vaccinated_per_hundred', 'population',
       'population_density', 'median_age', 'aged_65_older', 'gdp_per_capita',
       'cardiovasc_death_rate', 'diabetes_prevalence', 'female_smokers',
       'male_smokers', 'hospital_beds_per_thousand', 'life_expectancy',
       'human_development_index'],
      dtype='object')

In [16]:
# Zero-fill missing values, then collapse the two smoker columns into a single
# 'smokers' column (male + female), appended as the last column.
data_owid = data_owid.replace(np.nan, 0)
data_owid = (
    data_owid
    .assign(smokers=data_owid['male_smokers'] + data_owid['female_smokers'])
    .drop(columns=['male_smokers', 'female_smokers'])
)

In [17]:
data_owid.columns

Index(['iso_code', 'date', 'total_cases_per_million', 'new_cases_per_million',
       'total_deaths_per_million', 'new_deaths_per_million',
       'icu_patients_per_million', 'hosp_patients_per_million',
       'new_tests_per_thousand', 'positive_rate', 'new_vaccinations_smoothed',
       'people_fully_vaccinated_per_hundred', 'population',
       'population_density', 'median_age', 'aged_65_older', 'gdp_per_capita',
       'cardiovasc_death_rate', 'diabetes_prevalence',
       'hospital_beds_per_thousand', 'life_expectancy',
       'human_development_index', 'smokers'],
      dtype='object')

In [18]:
# Optional export of the reduced OWID frame. Guarded by should_save_data like
# the other CSV exports in this notebook, so a plain re-run does not silently
# overwrite files on disk.
if should_save_data:
    data_owid.to_csv('modified_owid.csv', index=False)

## combine databases

In [19]:
# Inner join: keep only the (country, day) pairs present in both sources.
data_combined = data_owid.merge(data_ox, how='inner', on=['iso_code', 'date'])

In [20]:
# Every feature group starts with the (iso_code, date) key: the dataset builders
# below slice features positionally with `[:, 2:]`, so the key must occupy the
# first two columns of each group.
key_columns = ['iso_code', 'date']

# Static, per-country descriptors.
geographic_columns = key_columns + [
    'population',
    'population_density',
    'median_age',
    'gdp_per_capita',
    'aged_65_older',
    'life_expectancy',
    'human_development_index',
]

# Epidemic state per country and day.
covid_columns = key_columns + [
    'total_cases_per_million',
    'total_cases',
    'total_deaths_per_million',
    'total_deaths',
    'people_fully_vaccinated_per_hundred',
    'hosp_patients_per_million',
    'icu_patients_per_million',
    'new_tests_per_thousand',
    'new_cases_per_million',
    'new_deaths_per_million',
    'positive_rate',
    'StringencyIndex',
    'ContainmentHealthIndex',
]

# Health-care descriptors.
health_columns = key_columns + [
    'cardiovasc_death_rate',
    'smokers',
    'diabetes_prevalence',
    'hospital_beds_per_thousand',
]

# Government actions (composite policy indices plus the vaccination rate).
policy_columns = key_columns + [
    'C1_index',
    'C2_index',
    'C3_index',
    'C4_index',
    'C5_index',
    'C6_index',
    'C7_index',
    'C8_index',
    'H1_index',
    'H6_index',
    'H8_index',
    'new_vaccinations_smoothed',
]

# Features read from neighbouring countries.
neighbor_columns = key_columns + [
    'total_cases_per_million',
    'total_deaths_per_million',
]

In [21]:
# Slice the merged frame into one frame per feature group.
data_geographic = data_combined.loc[:, geographic_columns]
data_covid = data_combined.loc[:, covid_columns]
data_health = data_combined.loc[:, health_columns]
data_policies = data_combined.loc[:, policy_columns]
data_neighbors = data_combined.loc[:, neighbor_columns]
# Alpha-3 codes of every country that survived the merge.
relevant_countires_3 = data_geographic['iso_code'].unique()

## save all data files

In [22]:
# Optional dated export of every feature-group frame (today's date is embedded
# in each file name).
if should_save_data:
    date = datetime.now().date()
    exports = [
        (data_geographic, f"data_geographic_{date}.csv"),
        (data_covid, f"data_covid_{date}.csv"),
        (data_health, f"data_health_{date}.csv"),
        (data_policies, f"data_policies_{date}.csv"),
        (data_neighbors, f"data_neighbors_{date}.csv"),
    ]
    for frame, csv_path in exports:
        frame.to_csv(csv_path)

## Create information about the distance between countries

In [23]:
# Pairwise country distance matrix; its first (unnamed) column holds the row
# labels, which are renamed to 'alpha-2' below.
# keep_default_na=False stops pandas from parsing the literal label "NA" (the
# ISO alpha-2 code of Namibia) as a missing value; empty cells are still read
# as NaN through na_values=[""].
# NOTE(review): assumes missing distances, if any, are stored as empty cells
# rather than as text such as "NaN" — confirm against the file.
distance_matrix = pd.read_csv("distance-matrix.csv", keep_default_na=False, na_values=[""])
distance_matrix = distance_matrix.rename(columns={'Unnamed: 0':'alpha-2'})

In [24]:
# ISO code conversion table (alpha-2 <-> alpha-3).
# keep_default_na=False stops pandas from parsing the literal code "NA" (the
# ISO alpha-2 code of Namibia) as a missing value; empty cells are still read
# as NaN through na_values=[""].
country_conversion_data = pd.read_csv("country_iso_conversion.csv", keep_default_na=False, na_values=[""])

In [25]:
# Only the two code columns are needed for the alpha-2 <-> alpha-3 lookups.
country_conversion_data = country_conversion_data.loc[:, ['alpha-2', 'alpha-3']]

In [26]:
# Attach the alpha-2 code to every row's alpha-3 country code. A left join keeps
# countries without a conversion entry (their alpha-2 becomes NaN).
relevant_countries = data_geographic['iso_code']
merged_geogrpahic_data = relevant_countries.to_frame().merge(
    country_conversion_data,
    how='left',
    left_on='iso_code',
    right_on='alpha-3',
)[['alpha-2', 'alpha-3']]

In [27]:
# Distinct alpha-2 codes of the countries present in the merged data; passed to
# get_nearest_countries as the set of admissible neighbours. May contain NaN
# for countries the left join could not convert.
relevant_countires_2 = merged_geogrpahic_data['alpha-2'].unique()

## Create nearest-countries functions

In [89]:
def get_nearest_countries(distance_params, country_iso_code, n_countries=4):
    """Return the alpha-3 codes of the nearest relevant countries.

    Parameters
    ----------
    distance_params : dict
        'country_distances'  - DataFrame whose 'alpha-2' column labels the rows
                               and whose remaining columns hold the distances.
        'country_conversion' - DataFrame with 'alpha-2' and 'alpha-3' columns.
        'relevant_countires' - iterable of alpha-2 codes allowed as neighbours.
    country_iso_code : str
        Alpha-3 code of the country whose neighbours are requested.
    n_countries : int
        Maximum number of neighbours to return.

    Returns
    -------
    list of str
        Up to ``n_countries`` alpha-3 codes ordered from nearest to farthest.
        Empty when the country is missing from the conversion table or from the
        distance matrix.
    """
    country_distances = distance_params['country_distances']
    country_conversion = distance_params['country_conversion']
    # Drop NaN entries (left-join misses) so a NaN row label can never match.
    relevant_countires = {
        code for code in distance_params['relevant_countires'] if pd.notna(code)
    }

    # Translate the requested alpha-3 code into the alpha-2 code used by the matrix.
    iso_2_matches = country_conversion.loc[
        country_conversion['alpha-3'] == country_iso_code, 'alpha-2'
    ]
    if iso_2_matches.empty:
        return []
    country_iso_2 = iso_2_matches.item()

    specific_country = country_distances[country_distances['alpha-2'] == country_iso_2].to_numpy()
    if specific_country.size == 0:
        return []

    # Column 0 is the row label; the remaining columns are distances.
    # Assumes the distance columns follow the same order as the rows, so a
    # column position can be used as a row position to recover the alpha-2 code.
    sorted_distances_index = np.argsort(specific_country[0, 1:], axis=0)
    all_iso_2 = country_distances['alpha-2'].to_numpy()

    iso_3_countries = []
    # Walk every candidate from nearest to farthest instead of a fixed-size
    # window, so sparsely covered regions still yield n_countries neighbours.
    for candidate_index in sorted_distances_index:
        if len(iso_3_countries) >= n_countries:
            break
        if candidate_index >= len(all_iso_2):
            continue
        candidate_iso_2 = all_iso_2[candidate_index]
        # Exclude the country itself by code rather than by sort position.
        if candidate_iso_2 == country_iso_2 or candidate_iso_2 not in relevant_countires:
            continue
        iso_3_matches = country_conversion.loc[
            country_conversion['alpha-2'] == candidate_iso_2, 'alpha-3'
        ]
        if iso_3_matches.empty:
            # No alpha-3 code known for this candidate; try the next one.
            continue
        iso_3_countries.append(iso_3_matches.item())
    return iso_3_countries

In [99]:
def get_neighbors_dataset(current_iso, current_date, neighbors_isos, data_per_iso, n_times=5, n_neighbors=4):
    """Collect the feature windows of a country's neighbours.

    Parameters
    ----------
    current_iso : str
        Key of the reference country in ``data_per_iso``; only used to read the
        number of feature columns.
    current_date : datetime-like
        First day of the window; the window is [current_date, current_date + n_times days).
    neighbors_isos : sequence of str
        Neighbour keys, nearest first. Entries beyond ``n_neighbors`` are ignored.
    data_per_iso : dict
        Maps a country key to its DataFrame; the first two columns are skipped
        and a 'date' column is required.
    n_times : int
        Window length in days.
    n_neighbors : int
        Number of neighbour slots in the output.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_times, n_features, n_neighbors). Slots of neighbours
        with no data, and leading rows of incomplete windows, stay zero.
    """
    n_features = data_per_iso[current_iso].shape[1]-2
    neighbor_data = np.zeros(shape=(n_times, n_features, n_neighbors))
    window_end = current_date + pd.to_timedelta(n_times, unit='d')
    # Cap at n_neighbors so extra entries cannot index past the last slot.
    for i_n, n in enumerate(neighbors_isos[:n_neighbors]):
        neighbor_frame = data_per_iso.get(n)
        if neighbor_frame is None:
            # Neighbour has no rows in the data: keep its slot zero-filled.
            continue
        in_window = (neighbor_frame['date'] < window_end) & (neighbor_frame['date'] >= current_date)
        relevant_data = neighbor_frame[in_window]
        # Right-align the available rows; missing leading rows remain zero.
        neighbor_data[n_times-len(relevant_data):, :, i_n] = relevant_data.to_numpy()[:, 2:]
    return neighbor_data

Test the nearest-countries function:

In [100]:
# Smoke test: look up the four nearest relevant countries of Israel.
distance_params = {
    'country_distances': distance_matrix,
    'country_conversion': country_conversion_data,
    'relevant_countires': relevant_countires_2,
}
get_nearest_countries(distance_params, 'ISR', 4)

['PSE', 'JOR', 'LBN', 'CYP']

## create x and y for model training and validation

In [105]:
def create_input_dataset(data, n_times = 5, n_neighbors=4, with_y = False, is_neighbors = False, **distance_params):
    """Turn a per-country time series frame into sliding-window samples.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame whose first two columns are 'iso_code' and 'date'; every other
        column is treated as a feature.
    n_times : int
        Number of consecutive rows in each input window.
    n_neighbors : int
        Number of neighbour slots (only used when ``is_neighbors`` is True).
    with_y : bool
        When True, also return the feature row that follows each window.
    is_neighbors : bool
        When True, each sample holds the windows of the country's nearest
        neighbours instead of the country's own window.
    **distance_params
        Forwarded as a dict to ``get_nearest_countries``
        ('country_distances', 'country_conversion', 'relevant_countires').

    Returns
    -------
    batch_isos : numpy.ndarray
        Country code of every sample.
    x_out_arr : numpy.ndarray
        (n_samples, n_times, n_features), or
        (n_samples, n_times, n_features, n_neighbors) when ``is_neighbors``.
    y_out_arr : numpy.ndarray
        (n_samples, n_features) when ``with_y``; empty otherwise.
    """
    # One date-sorted frame per country. Grouping by the scalar key (not a
    # one-element list) keeps the dict keys plain strings on every pandas version.
    data_per_iso = {k: v.sort_values(by='date') for (k, v) in data.groupby('iso_code')}
    x_out = []
    y_out = []
    batch_isos = []
    for iso, country_frame in data_per_iso.items():
        # Positional slice: drop the leading iso_code/date columns.
        x_full = country_frame.to_numpy()[:, 2:]
        dates = np.sort(country_frame['date'].unique())
        if is_neighbors:
            neighbors_isos = get_nearest_countries(distance_params, country_iso_code = iso, n_countries=n_neighbors)
        # A window starting at i_d is usable only if the row right after it
        # exists (i_d + n_times < dates.size).
        for i_d in range(dates.size - n_times):
            if is_neighbors:
                x_out.append(get_neighbors_dataset(iso, dates[i_d], neighbors_isos, data_per_iso, n_times, n_neighbors))
            else:
                x_out.append(x_full[i_d:i_d + n_times, :])
            batch_isos.append(iso)
            if with_y:
                # Target: the feature row immediately following the window.
                y_out.append(x_full[i_d + n_times, :])
    x_out_arr = np.array(x_out)
    y_out_arr = np.array(y_out)
    batch_isos = np.array(batch_isos)
    return batch_isos, x_out_arr, y_out_arr

In [106]:
# Build the sliding-window arrays for every feature group. The window length
# comes from the notebook-level n_times setting so that it is honoured here.
# Only the covid group also yields targets (the row following each window).
_, x_geo, _ = create_input_dataset(data_geographic, n_times=n_times)
batch_isos, x_covid, y_out = create_input_dataset(data_covid, n_times=n_times, with_y=True)
_, x_health, _ = create_input_dataset(data_health, n_times=n_times)
_, x_policies, _ = create_input_dataset(data_policies, n_times=n_times)

In [107]:
# Neighbour windows: for every sample, the features of the 4 nearest relevant
# countries. The window length comes from the notebook-level n_times setting.
_, x_neighbors, _ = create_input_dataset(
    data_neighbors,
    n_times=n_times,
    n_neighbors=4,
    with_y=False,
    is_neighbors=True,
    country_distances=distance_matrix,
    country_conversion=country_conversion_data,
    relevant_countires=relevant_countires_2,
)

In [115]:
# Persist every model input/target array; np.save appends the ".npy" suffix.
arrays_to_save = {
    "x_neighbors": x_neighbors,
    "x_health": x_health,
    "x_covid": x_covid,
    "x_geo": x_geo,
    "x_policies": x_policies,
    "y_out": y_out,
    "batch_isos": batch_isos,
}
for file_stem, array in arrays_to_save.items():
    np.save(file_stem, array)