## Outline

- [載入資料](#載入資料)
    - [資料前處理](#資料前處理)
- [綜觀全球疫情走勢](#綜觀全球疫情走勢)
- [begin](#begin-------------------------------------------------------)
    - [台灣](#台灣)
    - [美國](#美國)
    - [印度](#印度)
    - [英國](#英國)        
    - [俄羅斯](#俄羅斯)        
    - [日本](#日本)        
    - [南韓](#南韓)        
    - [荷蘭](#荷蘭)
- [預測各國疫情走勢](#預測各國疫情走勢)
- [比較各國資訊](#比較各國資訊)

In [1]:
import gc
import os
from pathlib import Path
import random
import sys
from tqdm import tqdm
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

from IPython.core.display import display, HTML

# --- plotly ---
from plotly import tools, subplots
import plotly.offline as py
py.init_notebook_mode(connected=True)
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff
import plotly.io as pio
pio.templates.default = "plotly_dark"

# --- models ---
from sklearn import preprocessing
from sklearn.model_selection import KFold
import lightgbm as lgb
import xgboost as xgb
import catboost as cb

# --- setup ---
pd.set_option('max_columns', 50)

## 載入資料

使用到的資料集: [COVID-19/csse_covid_19_data/csse_covid_19_time_series/](https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_time_series)

- 全球
    - `confirmed_global_df` : dataframe，儲存從 2020/1/22 到 2020/5/30 各國確診人數

    - `deaths_global_df` : dataframe，儲存從 2020/1/22 到 2020/5/30 各國死亡人數

    - `recovered_global_df` : dataframe，儲存從 2020/1/22 到 2020/5/30 各國康復人數
- 美國
    - `confirmed_us_df` : dataframe，儲存從 2020/1/22 到 2020/5/30 美國確診人數
    
    - `deaths_us_df` : dataframe，儲存從 2020/1/22 到 2020/5/30 美國死亡人數

In [2]:
import requests

for filename in ['time_series_covid19_confirmed_global.csv',
                 'time_series_covid19_deaths_global.csv',
                 'time_series_covid19_recovered_global.csv',
                 'time_series_covid19_confirmed_US.csv',
                 'time_series_covid19_deaths_US.csv']:
    print(f'Downloading {filename}')
    url = f'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/{filename}'
    myfile = requests.get(url)
    open(filename, 'wb').write(myfile.content)

confirmed_global_df = pd.read_csv('time_series_covid19_confirmed_global.csv')
deaths_global_df = pd.read_csv('time_series_covid19_deaths_global.csv')
recovered_global_df = pd.read_csv('time_series_covid19_recovered_global.csv')

Downloading time_series_covid19_confirmed_global.csv
Downloading time_series_covid19_deaths_global.csv
Downloading time_series_covid19_recovered_global.csv
Downloading time_series_covid19_confirmed_US.csv
Downloading time_series_covid19_deaths_US.csv


### 資料前處理

變更日期的格式，由 mm/dd/yy 改成 yy-mm-dd

In [3]:
def _convert_date_str(df):
    try:
        df.columns = list(df.columns[:4]) + [datetime.strptime(d, "%m/%d/%y").date().strftime("%Y-%m-%d") for d in df.columns[4:]]
    except:
        print('_convert_date_str failed with %y, try %Y')
        df.columns = list(df.columns[:4]) + [datetime.strptime(d, "%m/%d/%Y").date().strftime("%Y-%m-%d") for d in df.columns[4:]]

In [4]:
_convert_date_str(confirmed_global_df)
_convert_date_str(deaths_global_df)
_convert_date_str(recovered_global_df)

將鑽石公主號的資料移除，約旦河西岸和加薩走廊的資料包含負值，所以一併移除

In [5]:
# Filter out problematic data points (The West Bank and Gaza had a negative value, cruise ships were associated with Canada, etc.)
removed_states = "Recovered|Grand Princess|Diamond Princess"
removed_countries = "US|The West Bank and Gaza"

confirmed_global_df.rename(columns={"Province/State": "Province_State", "Country/Region": "Country_Region"}, inplace=True)
deaths_global_df.rename(columns={"Province/State": "Province_State", "Country/Region": "Country_Region"}, inplace=True)
recovered_global_df.rename(columns={"Province/State": "Province_State", "Country/Region": "Country_Region"}, inplace=True)

confirmed_global_df = confirmed_global_df[~confirmed_global_df["Province_State"].replace(np.nan, "nan").str.match(removed_states)]
deaths_global_df    = deaths_global_df[~deaths_global_df["Province_State"].replace(np.nan, "nan").str.match(removed_states)]
recovered_global_df = recovered_global_df[~recovered_global_df["Province_State"].replace(np.nan, "nan").str.match(removed_states)]

confirmed_global_df = confirmed_global_df[~confirmed_global_df["Country_Region"].replace(np.nan, "nan").str.match(removed_countries)]
deaths_global_df    = deaths_global_df[~deaths_global_df["Country_Region"].replace(np.nan, "nan").str.match(removed_countries)]
recovered_global_df = recovered_global_df[~recovered_global_df["Country_Region"].replace(np.nan, "nan").str.match(removed_countries)]

將所有日期合併到同一欄位，該日期的累積人數合併到另一欄位

In [6]:
confirmed_global_melt_df = confirmed_global_df.melt(
    id_vars=['Country_Region', 'Province_State', 'Lat', 'Long'], value_vars=confirmed_global_df.columns[4:], var_name='Date', value_name='ConfirmedCases')

deaths_global_melt_df = deaths_global_df.melt(
    id_vars=['Country_Region', 'Province_State', 'Lat', 'Long'], value_vars=confirmed_global_df.columns[4:], var_name='Date', value_name='Deaths')

recovered_global_melt_df = deaths_global_df.melt(
    id_vars=['Country_Region', 'Province_State', 'Lat', 'Long'], value_vars=confirmed_global_df.columns[4:], var_name='Date', value_name='Recovered')

recovered_global_melt_df.head()

Unnamed: 0,Country_Region,Province_State,Lat,Long,Date,Recovered
0,Afghanistan,,33.0,65.0,2020-01-22,0
1,Albania,,41.1533,20.1683,2020-01-22,0
2,Algeria,,28.0339,1.6596,2020-01-22,0
3,Andorra,,42.5063,1.5218,2020-01-22,0
4,Angola,,-11.2027,17.8739,2020-01-22,0


In [7]:
train = confirmed_global_melt_df.merge(deaths_global_melt_df, on=['Country_Region', 'Province_State', 'Lat', 'Long', 'Date'])
train = train.merge(recovered_global_melt_df, on=['Country_Region', 'Province_State', 'Lat', 'Long', 'Date'])

In [8]:
# --- US ---
confirmed_us_df = pd.read_csv('time_series_covid19_confirmed_US.csv')
deaths_us_df = pd.read_csv('time_series_covid19_deaths_US.csv')

# 丟掉不須用到的欄位
confirmed_us_df.drop(['UID', 'iso2', 'iso3', 'code3', 'FIPS', 'Admin2', 'Combined_Key'], inplace=True, axis=1)
deaths_us_df.drop(['UID', 'iso2', 'iso3', 'code3', 'FIPS', 'Admin2', 'Combined_Key', 'Population'], inplace=True, axis=1)

# 將 Long_ 欄位改名為 Long
confirmed_us_df.rename({'Long_': 'Long'}, axis=1, inplace=True)
deaths_us_df.rename({'Long_': 'Long'}, axis=1, inplace=True)

# 變更日期的格式，由 mm/dd/yy 改成 yy-mm-dd
_convert_date_str(confirmed_us_df)
_convert_date_str(deaths_us_df)

# clean
# 丟掉不屬於美國的地區
confirmed_us_df = confirmed_us_df[~confirmed_us_df.Province_State.str.match("Diamond Princess|Grand Princess|Recovered|Northern Mariana Islands|American Samoa")]
deaths_us_df = deaths_us_df[~deaths_us_df.Province_State.str.match("Diamond Princess|Grand Princess|Recovered|Northern Mariana Islands|American Samoa")]

# --- Aggregate by province state ---
#confirmed_us_df.groupby(['Country_Region', 'Province_State'])
confirmed_us_df = confirmed_us_df.groupby(['Country_Region', 'Province_State']).sum().reset_index()
deaths_us_df = deaths_us_df.groupby(['Country_Region', 'Province_State']).sum().reset_index()

# remove lat, long.
confirmed_us_df.drop(['Lat', 'Long'], inplace=True, axis=1)
deaths_us_df.drop(['Lat', 'Long'], inplace=True, axis=1)

# 合併日期
confirmed_us_melt_df = confirmed_us_df.melt(
    id_vars=['Country_Region', 'Province_State'], value_vars=confirmed_us_df.columns[2:], var_name='Date', value_name='ConfirmedCases')
deaths_us_melt_df = deaths_us_df.melt(
    id_vars=['Country_Region', 'Province_State'], value_vars=deaths_us_df.columns[2:], var_name='Date', value_name='Deaths')

# 將美國確診及死亡的資料合併
train_us = confirmed_us_melt_df.merge(deaths_us_melt_df, on=['Country_Region', 'Province_State', 'Date'])

In [9]:
train = pd.concat([train, train_us], axis=0, sort=False)

train_us.rename({'Country_Region': 'country', 'Province_State': 'province', 'Date': 'date', 'ConfirmedCases': 'confirmed', 'Deaths': 'fatalities'}, axis=1, inplace=True)
train_us['country_province'] = train_us['country'].fillna('') + '/' + train_us['province'].fillna('')

In [10]:
train.head()

Unnamed: 0,Country_Region,Province_State,Lat,Long,Date,ConfirmedCases,Deaths,Recovered
0,Afghanistan,,33.0,65.0,2020-01-22,0,0,0.0
1,Albania,,41.1533,20.1683,2020-01-22,0,0,0.0
2,Algeria,,28.0339,1.6596,2020-01-22,0,0,0.0
3,Andorra,,42.5063,1.5218,2020-01-22,0,0,0.0
4,Angola,,-11.2027,17.8739,2020-01-22,0,0,0.0


In [11]:
train.rename({'Country_Region': 'country', 'Province_State': 'province', 'Id': 'id', 'Date': 'date', 'ConfirmedCases': 'confirmed', 'Deaths': 'fatalities', 'Recovered': 'recovered'}, axis=1, inplace=True)
train['country_province'] = train['country'].fillna('') + '/' + train['province'].fillna('')

# test.rename({'Country_Region': 'country', 'Province_State': 'province', 'Id': 'id', 'Date': 'date', 'ConfirmedCases': 'confirmed', 'Fatalities': 'fatalities'}, axis=1, inplace=True)
# test['country_province'] = test['country'].fillna('') + '/' + test['province'].fillna('')
train.head()

Unnamed: 0,country,province,Lat,Long,date,confirmed,fatalities,recovered,country_province
0,Afghanistan,,33.0,65.0,2020-01-22,0,0,0.0,Afghanistan/
1,Albania,,41.1533,20.1683,2020-01-22,0,0,0.0,Albania/
2,Algeria,,28.0339,1.6596,2020-01-22,0,0,0.0,Algeria/
3,Andorra,,42.5063,1.5218,2020-01-22,0,0,0.0,Andorra/
4,Angola,,-11.2027,17.8739,2020-01-22,0,0,0.0,Angola/


## 綜觀全球疫情走勢

觀察 `ww_df` 的各個屬性

In [12]:
ww_df = train.groupby('date')[['confirmed', 'fatalities']].sum().reset_index()
ww_df['new_case'] = ww_df['confirmed'] - ww_df['confirmed'].shift(1)
ww_df['growth_factor'] = ww_df['new_case'] / ww_df['new_case'].shift(1)
ww_df.tail()

Unnamed: 0,date,confirmed,fatalities,new_case,growth_factor
142,2020-06-12,7644065,425774,129536.0,0.936015
143,2020-06-13,7778686,430041,134621.0,1.039255
144,2020-06-14,7912490,433388,133804.0,0.993931
145,2020-06-15,8034266,436893,121776.0,0.910107
146,2020-06-16,8173745,443679,139479.0,1.145373


觀察 `ww_melt_df` 的各個屬性

In [13]:
ww_melt_df = pd.melt(ww_df, id_vars=['date'], value_vars=['confirmed', 'fatalities', 'new_case'])
ww_melt_df

Unnamed: 0,date,variable,value
0,2020-01-22,confirmed,555.0
1,2020-01-23,confirmed,654.0
2,2020-01-24,confirmed,941.0
3,2020-01-25,confirmed,1434.0
4,2020-01-26,confirmed,2118.0
...,...,...,...
436,2020-06-12,new_case,129536.0
437,2020-06-13,new_case,134621.0
438,2020-06-14,new_case,133804.0
439,2020-06-15,new_case,121776.0


### 全球確診/死亡案例 (折線圖)

- 2020/4/2 確診人數突破 1M，且死亡人數為 52K
- 2020/5/1 確診人數突破 3.3M，且死亡人數為 238K
- **好消息! 每日新增確診案例曲線從 2020/4/4 開始趨於平緩至今**

In [14]:
fig = px.line(ww_melt_df, x="date", y="value", color='variable', 
              title="Worldwide Confirmed/Death Cases Over Time")
fig.show()

### 全球確診/死亡案例 (折線圖) (取log)

- 比較 2020/3 初和 2020/3 底，確診案例成長率的上升速度略為增加

In [15]:
fig = px.line(ww_melt_df, x="date", y="value", color='variable',
              title="Worldwide Confirmed/Death Cases Over Time (Log scale)",
             log_y=True)
fig.show()

### 全球死亡率 (折線圖)

- 可以明顯看到，死亡率在 2020/5 開始下降

In [16]:
ww_df['mortality'] = ww_df['fatalities'] / ww_df['confirmed']

fig = px.line(ww_df, x="date", y="mortality", 
              title="Worldwide Mortality Rate Over Time")
fig.show()

列出有多少國家位於何種確診案例的數量等級

In [17]:
country_df = train.groupby(['date', 'country'])[['confirmed', 'fatalities']].sum().reset_index()
target_date = country_df['date'].max()

print('Date: ', target_date)
for i in [1, 10, 100, 1000, 10000]:
    n_countries = len(country_df.query('(date == @target_date) & confirmed > @i'))
    print(f'{n_countries} countries have more than {i} confirmed cases')

Date:  2020-06-16
188 countries have more than 1 confirmed cases
184 countries have more than 10 confirmed cases
165 countries have more than 100 confirmed cases
121 countries have more than 1000 confirmed cases
62 countries have more than 10000 confirmed cases


列出所有國家

In [18]:
countries = country_df['country'].unique()
print(f'{len(countries)} countries are in dataset:\n{countries}')

188 countries are in dataset:
['Afghanistan' 'Albania' 'Algeria' 'Andorra' 'Angola'
 'Antigua and Barbuda' 'Argentina' 'Armenia' 'Australia' 'Austria'
 'Azerbaijan' 'Bahamas' 'Bahrain' 'Bangladesh' 'Barbados' 'Belarus'
 'Belgium' 'Belize' 'Benin' 'Bhutan' 'Bolivia' 'Bosnia and Herzegovina'
 'Botswana' 'Brazil' 'Brunei' 'Bulgaria' 'Burkina Faso' 'Burma' 'Burundi'
 'Cabo Verde' 'Cambodia' 'Cameroon' 'Canada' 'Central African Republic'
 'Chad' 'Chile' 'China' 'Colombia' 'Comoros' 'Congo (Brazzaville)'
 'Congo (Kinshasa)' 'Costa Rica' "Cote d'Ivoire" 'Croatia' 'Cuba' 'Cyprus'
 'Czechia' 'Denmark' 'Diamond Princess' 'Djibouti' 'Dominica'
 'Dominican Republic' 'Ecuador' 'Egypt' 'El Salvador' 'Equatorial Guinea'
 'Eritrea' 'Estonia' 'Eswatini' 'Ethiopia' 'Fiji' 'Finland' 'France'
 'Gabon' 'Gambia' 'Georgia' 'Germany' 'Ghana' 'Greece' 'Grenada'
 'Guatemala' 'Guinea' 'Guinea-Bissau' 'Guyana' 'Haiti' 'Holy See'
 'Honduras' 'Hungary' 'Iceland' 'India' 'Indonesia' 'Iran' 'Iraq'
 'Ireland' 'Israel'

### 目前確診案例前 30 國家 (折線圖)

In [19]:
country_df = train.groupby(['date', 'country'])[['confirmed', 'fatalities']].sum().reset_index()
top_country_df = country_df.query('(date == @target_date) & (confirmed > 1000)').sort_values('confirmed', ascending=False)

In [20]:
shapes = []
for i in (20, 40, 60):
    shapes.append({'type': 'line', 'xref': 'x', 'yref': 'y', 'x0': i, 'y0': 0, 'x1': i, 'y1': 1})
layout = go.Layout(shapes = shapes)

## begin-------------------------------------------------------

In [21]:
from datetime import timedelta

# plot function
def plot_by_country(country_str, events, anno_pos = 0.5, anno_angle = -90):
    country_data = pd.DataFrame()
    for i in range(len(country_df)):
        if country_df['country'][i] == country_str:
            country_data = country_data.append(country_df.iloc[i])
    country_data.index = np.arange(len(country_data))
    fig = px.line(country_data,
                  x='date', y='confirmed', color='country',
                  title=f'Confirmed Cases of {country_str} till {target_date}',
                  height = 500)
    ymax = country_data['confirmed'][len(country_data) - 1]
    for event in events:
        fig.add_trace(go.Scatter(x=['2020-' + event[0], '2020-' + event[0]], y=[0, ymax], name=event[0] + ' ' + event[1],
                                 mode = 'lines', line=dict(dash='dash', color=(event[2]))))
        yesterday = (datetime.strptime('2020-' + event[0], "%Y-%m-%d").date() - timedelta(days=1)).strftime("%Y-%m-%d")
        fig.add_annotation(x=yesterday, y=ymax * anno_pos, text=event[0],
                           showarrow=False, textangle=anno_angle, bgcolor = '#101010')
    fig.show()
    
# plot function (log)
def plot_by_country2(country_str, events, anno_pos = 0.5, anno_angle = -90):
    country_data = pd.DataFrame()
    for i in range(len(country_df)):
        if country_df['country'][i] == country_str:
            country_data = country_data.append(country_df.iloc[i])
    country_data.index = np.arange(len(country_data))
    country_data['confirmed'] = country_data['confirmed'].apply(lambda x : np.log10(x + 1))
    fig = px.line(country_data,
                  x='date', y='confirmed', color='country',
                  title=f'Confirmed Cases of {country_str} till {target_date} (log)',
                  height = 500)
    ymax = country_data['confirmed'][len(country_data) - 1]
    for event in events:
        fig.add_trace(go.Scatter(x=['2020-' + event[0], '2020-' + event[0]], y=[0, ymax], name=event[0] + ' ' + event[1],
                                 mode = 'lines', line=dict(dash='dash', color=(event[2]))))
        yesterday = (datetime.strptime('2020-' + event[0], "%Y-%m-%d").date() - timedelta(days=1)).strftime("%Y-%m-%d")
        fig.add_annotation(x=yesterday, y=ymax * anno_pos, text=event[0],
                           showarrow=False, textangle=anno_angle, bgcolor = '#101010')
    fig.show()

# plot function (log slope)
def plot_by_country3(country_str, events, anno_pos = 0.5, anno_angle = -90):
    country_data = pd.DataFrame()
    for i in range(len(country_df)):
        if country_df['country'][i] == country_str:
            country_data = country_data.append(country_df.iloc[i])
    country_data.index = np.arange(len(country_data))
    country_data['confirmed'] = country_data['confirmed'].apply(lambda x : np.log10(x + 1))
    slope = [0]
    for i in range(1, len(country_data)):
        slope.append(country_data['confirmed'][i] - country_data['confirmed'][i - 1])
    country_data['slope'] = slope
    fig = px.line(country_data,
                  x='date', y='slope', color='country',
                  title=f'Confirmed Cases of {country_str} till {target_date} (log slope)',
                  height = 500)
    ymax = country_data['slope'][len(country_data) - 1]
    line_y = anno_pos*0.5 if anno_pos > 1 else 0.5
    for event in events:
        fig.add_trace(go.Scatter(x=['2020-' + event[0], '2020-' + event[0]], y=[0, line_y], name=event[0] + ' ' + event[1],
                                 mode = 'lines', line=dict(dash='dash', color=(event[2]))))
        yesterday = (datetime.strptime('2020-' + event[0], "%Y-%m-%d").date() - timedelta(days=1)).strftime("%Y-%m-%d")
        fig.add_annotation(x=yesterday, y=float(0.5 * anno_pos), text=event[0],
                           showarrow=False, textangle=anno_angle, bgcolor = '#101010')
    fig.show()

# plot fuction (new cases ratio)
def plot_by_country4(country_str, events, anno_pos = 0.5, anno_angle = -90):
    country_data = pd.DataFrame()
    for i in range(len(country_df)):
        if country_df['country'][i] == country_str:
            country_data = country_data.append(country_df.iloc[i])
    country_data.index = np.arange(len(country_data))
    new_case = [0]
    dmax = 0
    for i in range(1, len(country_data)):
        d = country_data['confirmed'][i] - country_data['confirmed'][i - 1]
        if d > dmax:
            dmax = d
        new_case.append(d)
    country_data['new case'] = new_case
    country_data['new case'] = country_data['new case'].apply(lambda x : x / dmax)
    fig = px.line(country_data,
                  x='date', y='new case', color='country',
                  title=f'New Cases Ratio of {country_str} till {target_date}',
                  height = 500)
    line_y = anno_pos if anno_pos > 1 else 1
    for event in events:
        fig.add_trace(go.Scatter(x=['2020-' + event[0], '2020-' + event[0]], y=[0, line_y], name=event[0] + ' ' + event[1],
                                 mode = 'lines', line=dict(dash='dash', color=(event[2]))))
        yesterday = (datetime.strptime('2020-' + event[0], "%Y-%m-%d").date() - timedelta(days=1)).strftime("%Y-%m-%d")
        fig.add_annotation(x=yesterday, y=float(0.5 * anno_pos), text=event[0],
                           showarrow=False, textangle=anno_angle, bgcolor = '#101010')
    fig.show()
    
color_1 = '#ffff66' # 淺黃---限制境外移入 (病例)
color_2 = '#66ff66' # 淺綠---擴大篩檢 (列出隱藏病例)
color_3 = '#00ffff' # 藍綠---口罩禁止出口 (人民開始意識到疫情嚴重性)
color_4 = '#ff6666' # 粉紅---封城/宵禁/禁足
color_5 = '#ff66ff' # 粉紫---保持社交距離/強制戴口罩
color_6 = '#ffffff' # 白---其他

### 台灣

In [22]:
events = [['01-24', '限制醫療用及N95等口罩出口', color_3],
          ['02-06', '透過電視播送重要防疫訊息', color_5],
          ['03-21', '全球旅遊疫情等級第三級', color_1],
          ['05-29', '有需求者可自費檢驗武漢肺炎', color_2]]

plot_by_country('Taiwan*', events)
plot_by_country2('Taiwan*', events)
plot_by_country3('Taiwan*', events)

In [23]:
plot_by_country4('Taiwan*', events)

### 美國

In [24]:
events = [['02-20', '限制中國入境美國', color_1],
          ['03-03', '解除檢驗限制，擴大採檢', color_2],
          ['03-11', '加緊歐洲入境限制', color_1],
          ['03-18', '簽署國防生產法', color_3],
          ['04-02', '國防生產法用於N95口罩和呼吸機', color_3],
          ['04-13', '美國15州公立學校全學年停課', color_4],
          ['05-09', '47州陸續復工', color_6]]

plot_by_country('US', events)
plot_by_country2('US', events)
plot_by_country3('US', events)

In [25]:
plot_by_country4('US', events)

### 印度

In [26]:
events = [['03-09', '開始封鎖邊境', color_1],
          ['03-19', '禁止口罩出口', color_3],
          ['03-25', '全國封鎖', color_4],
          ['04-07', '強制戴口罩', color_5],
          ['05-11', '擴大篩檢無症狀者', color_2]]

plot_by_country('India', events)
plot_by_country2('India', events)
plot_by_country3('India', events)

In [27]:
plot_by_country4('India', events)

### 英國

In [28]:
events = [['03-23', 'lockdown', color_4],
          ['04-02', '提供更高的每日病毒檢測數量', color_2],
          ['04-10', '從特定地區回國者須居家隔離', color_6],
          ['04-28', '擴大篩檢65歲以上老人及出外工作的人', color_2],
          ['05-22', '對進入英國的人的14天自我隔離', color_6],
          ['06-08', '所有入境者一律都要隔離14天', color_6]]

plot_by_country('United Kingdom', events)
plot_by_country2('United Kingdom', events)
plot_by_country3('United Kingdom', events)

In [29]:
plot_by_country4('United Kingdom', events)

### 俄羅斯

In [30]:
events = [['02-20', '禁止中國公民入境', color_1],
          ['03-01', '暫時禁止出口醫用口罩', color_3],
          ['03-16', '限制歐洲航班', color_1],
          ['03-30', '封城', color_4],]

plot_by_country('Russia', events)
plot_by_country2('Russia', events)
plot_by_country3('Russia', events)

In [31]:
plot_by_country4('Russia', events)

### 日本

In [32]:
events = [['02-01', '限制湖北省相關旅客及所有外國籍患者入境', color_1],
          ['02-13', '允許政府在入境前實施隔離及停留措施', color_6],
          ['02-27', '請求全國學校停課', color_4],
          ['03-04', '病毒檢測加入公費保險範圍', color_2],
          ['03-24', '宣布2020東京奧運延期舉行', color_6],
          ['04-16', '對全國發布緊急事態宣言', color_4]]

plot_by_country('Japan', events)
plot_by_country2('Japan', events)
plot_by_country3('Japan', events)

In [33]:
plot_by_country4('Japan', events)

### 南韓

In [34]:
events = [['02-07', '450所幼兒園、70多所小學、30所中學停課', color_4],
          ['02-22', '暫停人群密集的活動', color_5],
          ['03-05', '禁止口罩出口，實施口罩限購', color_3],
          ['04-01', '入境旅客實施14天隔離檢疫', color_2],
          ['05-08', '夜總會等娛樂場所停業一個月', color_5]]

plot_by_country('Korea, South', events)
plot_by_country2('Korea, South', events)
plot_by_country3('Korea, South', events)

In [35]:
plot_by_country4('Korea, South', events)

### 荷蘭

In [36]:
events = [['03-09', '要求公民停止握手', color_5],
          ['03-13', '取消病例數高國家的航班', color_1]]

plot_by_country('Netherlands', events)
plot_by_country2('Netherlands', events)
plot_by_country3('Netherlands', events)

In [37]:
plot_by_country4('Netherlands', events)

## 預測各國疫情走勢

In [38]:
country_df['date'] = pd.to_datetime(country_df['date'])
def sigmoid(t, M, beta, alpha, offset=0):
    alpha += offset
    return M / (1 + np.exp(-beta * (t - alpha)))

def error(x, y, params):
    M, beta, alpha = params
    y_pred = sigmoid(x, M, beta, alpha)
    
    # apply weight, latest number is more important than past.
    weight = np.arange(len(y_pred)) ** 2
    loss_mse = np.mean((y_pred - y) ** 2 * weight)
    return loss_mse

def gen_random_color(min_value=0, max_value=256) -> str:
    """Generate random color for plotly"""
    r, g, b = np.random.randint(min_value, max_value, 3)
    return f'rgb({r},{g},{b})'

In [39]:
# train and predict

# 請注意模型預設的感染人數必須超過 1000，並且在感染人數超過1000後才開始當作 train data，
# 隨便放入確診數低於 1000 的國家會直接噴 Error !!!
# 後來透過參數列，改成可以指定國家和最低確診人數
# 使用的模型為 scipy.optimize.minimize (非線性的優化?)
# m, beta, alpha 的意義需要更深入釐清

def fit_sigmoid(target_countries, least_confirmed = -1, exclude_days=0, pred_end = '2020-07-01'):
    target_country_df_list = []
    pred_df_list = []
    flag = 0;
    if least_confirmed == -1:
        flag = 1
    for target_country in target_countries:
        print('target_country', target_country)
        # --- Train ---
        target_country_df = country_df.query('country == @target_country')
        if flag == 1:
            least_confirmed = target_country_df['confirmed'].max() // 100
        
        #train_start_date = target_country_df['date'].min()
        train_start_date = target_country_df.query(f'confirmed > {least_confirmed}')['date'].min()
        train_end_date = pd.to_datetime(target_date) - pd.Timedelta(f'{exclude_days} days')
        target_date_df = target_country_df.query('(date >= @train_start_date) & (date <= @train_end_date)')
        if len(target_date_df) <= 7:
            print('WARNING: the data is not enough, use 7 more days...')
            train_start_date -= pd.Timedelta('7 days')
            target_date_df = target_country_df.query('(date >= @train_start_date) & (date <= @train_end_date)')

        y = target_date_df['confirmed'].values
        x = np.arange(len(y))

        lossfun = lambda params: error(x, y, params)
        res = sp.optimize.minimize(lossfun, x0=[np.max(y) * 5, 0.04, 2 * len(y) / 3.], method='nelder-mead')
        M, beta, alpha = res.x
        # sigmoid_models[key] = (M, beta, alpha)
        # np.clip(sigmoid(list(range(len(data), len(data) + steps)), M, beta, alpha), 0, None).astype(int)

        # --- Pred ---
        pred_start_date = target_country_df['date'].min()
        pred_end_date = pd.to_datetime(pred_end)
        days = int((pred_end_date - pred_start_date) / pd.Timedelta('1 days'))
        #print('pred start', pred_start_date, 'end', pred_end_date, 'days', days)

        x = np.arange(days)
        offset = (train_start_date - pred_start_date) / pd.Timedelta('1 days')
        
        print('offset =', offset, 'least confirmed =', least_confirmed, 'M =', M, 'beta =', beta, 'alpha =', alpha)
        
        y_pred = sigmoid(x, M, beta, alpha, offset=offset)
        # target_country_df['confirmed_pred'] = y_pred

        all_dates = [pred_start_date + np.timedelta64(x, 'D') for x in range(days)]
        pred_df = pd.DataFrame({
            'date': all_dates,
            'country': target_country,
            'confirmed_pred': y_pred,
        })

        target_country_df_list.append(target_country_df)
        pred_df_list.append(pred_df)
    return target_country_df_list, pred_df_list

In [40]:
# plot predictions

def plot_sigmoid_fitting(target_countries, target_country_df_list, pred_df_list, title=''):
    fig = go.Figure()
    ymax = 0
    for i in range(len(target_countries)):
        target_country = target_countries[i]
        target_country_df = target_country_df_list[i]
        pred_df = pred_df_list[i]
        color = gen_random_color(min_value=20)
        tmp = pred_df['confirmed_pred'].max()
        if tmp > ymax:
            ymax = tmp
        # Prediction
        fig.add_trace(go.Scatter(
            x=pred_df['date'], y=pred_df['confirmed_pred'],
            name=f'{target_country}_pred',
            line=dict(color=color, dash='dash')
        ))

        # Ground truth
        fig.add_trace(go.Scatter(
            x=target_country_df['date'], y=target_country_df['confirmed'],
            mode='markers', name=f'{target_country}_actual',
            line=dict(color=color),
        ))
    fig.add_trace(go.Scatter(x=['2020-06-08', '2020-06-08'], y=[0, ymax],mode='lines', name='2020-06-08',line=dict(color=color_4)))
    fig.update_layout(title=title, xaxis_title='Date', yaxis_title='Confirmed cases')
    fig.show()

In [41]:
target_countries = ['Taiwan*']
title = ''
target_country_df_list, pred_df_list = fit_sigmoid(target_countries, exclude_days=7, pred_end = '2020-08-01')
plot_sigmoid_fitting(target_countries, target_country_df_list, pred_df_list, title='Logistic fitting with TW, train data from 2020-01-27 to 2020-06-08')

target_country Taiwan*
offset = 5.0 least confirmed = 4 M = 439.32494933158193 beta = 0.1334825854717855 alpha = 58.01817676657952


In [42]:
target_countries = ['US', 'India', 'United Kingdom', 'Russia']
target_country_df_list, pred_df_list = fit_sigmoid(target_countries, exclude_days=7, pred_end = '2020-08-01')
plot_sigmoid_fitting(target_countries, target_country_df_list, pred_df_list, title='Prediction from 06-08 to 08-01')

target_country US
offset = 59.0 least confirmed = 21375 M = 2292259.005662643 beta = 0.048220832192738844 alpha = 43.67403452442561
target_country India
offset = 74.0 least confirmed = 3540 M = 711462.9052275483 beta = 0.061180483292522486 alpha = 72.41930647990128
target_country United Kingdom
offset = 58.0 least confirmed = 2996 M = 302142.74595158116 beta = 0.06774493211070778 alpha = 36.7409833576344
target_country Russia
offset = 75.0 least confirmed = 5447 M = 567572.31061553 beta = 0.0742126063796045 alpha = 41.92411866186069


In [43]:
target_countries = ['Japan', 'Korea, South', 'Netherlands']
target_country_df_list, pred_df_list = fit_sigmoid(target_countries, exclude_days=7, pred_end = '2020-08-01')
plot_sigmoid_fitting(target_countries, target_country_df_list, pred_df_list, title='Sigmoid fitting with all latest data')

target_country Japan
offset = 35.0 least confirmed = 174 M = 16787.224285979544 beta = 0.12151625988974021 alpha = 49.40434025662397
target_country Korea, South
offset = 30.0 least confirmed = 121 M = 11725.473031832069 beta = 0.03686519170574022 alpha = -0.7912161527075783
target_country Netherlands
offset = 49.0 least confirmed = 492 M = 47263.02218060543 beta = 0.08130926199207118 alpha = 30.333217173381477


# 比較各國資訊

In [44]:
def plot_countries(country_list, column = 'confirmed', log = False):
    tmp_df = country_df[country_df['country'].isin(country_list)]
    if log == True:
        tmp_df[column] = tmp_df[column].apply(lambda x : np.log10(x + 1))
    fig = px.line(tmp_df, x='date', y=column, color='country',
                  title=f'{column} cases of {target_date} ' + ('(log)' if log == True else ''))
    fig.show()

country_list = ['United Kingdom', 'Italy', 'Brazil', 'US']
plot_countries(country_list)
plot_countries(country_list, log = True)
plot_countries(country_list, 'fatalities')