forked from DVL-Sejong/SIRD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
72 lines (54 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from SIRD.datatype import Country, PreprocessInfo, DatasetInfo
from SIRD.io import load_links, load_initial_dict, load_dataset, load_regions
from SIRD.io import save_setting, save_region_result, save_result_dict
from SIRD.loader import DataLoader
from SIRD.sird import SIRD
from SIRD.util import get_predict_period_from_dataset, generate_dataframe
import argparse
def get_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior of the original zero-argument call.

    Returns:
        argparse.Namespace with two attributes:
            country (str): one of 'italy', 'india', 'us', 'china'
                (default 'italy').
            y_frames (int): number of y frames (default 3).
    """
    parser = argparse.ArgumentParser(description='test')

    parser.add_argument(
        "--country", type=str, default='italy',
        choices=['italy', 'india', 'us', 'china'],
        help="Country name",
    )

    # The help text previously said "x frames", although this option feeds
    # the y_frames setting of the dataset.
    parser.add_argument(
        "--y_frames", type=int, default=3,
        help="Number of y frames for generating dataset",
    )

    return parser.parse_args(argv)
def main(args):
    """Run the SIRD prediction for every region of the selected country.

    Two preprocessing settings are built (identical except for ``increase``),
    the prediction period is derived from the initial values, and the dataset
    is loaded. For each region, one SIRD model is run per loader sample; every
    full prediction is stored under a zero-padded sample index, and row 0 of
    each prediction is accumulated into a per-region frame that is saved once
    the region is done. The nested region -> sample -> prediction dict is
    saved after all regions are processed.

    Args:
        args: Parsed arguments exposing ``country`` and ``y_frames``.
    """
    country = Country(args.country.upper())
    link_df = load_links(country)

    def make_preprocess_info(increase):
        # Both settings share every option except `increase`.
        return PreprocessInfo(country=country, start=link_df['start_date'], end=link_df['end_date'],
                              increase=increase, daily=True, remove_zero=True,
                              smoothing=True, window=5, divide=False)

    pre_info = make_preprocess_info(True)
    test_info = make_preprocess_info(False)

    initial_dict = load_initial_dict(country, pre_info, test_info)
    predict_dates = get_predict_period_from_dataset(pre_info, initial_dict, args.y_frames)
    first_date, last_date = predict_dates[0], predict_dates[-1]
    predict_info = DatasetInfo(x_frames=0, y_frames=args.y_frames,
                               test_start=first_date, test_end=last_date)
    save_setting(predict_info, 'data_info')

    dataset = load_dataset(country, pre_info, test_info)

    # The three setting hashes together identify the result location.
    result_hash = f'{pre_info.get_hash()}_{test_info.get_hash()}_{predict_info.get_hash()}'

    results = {}
    for region in load_regions(country):
        print(f'Predict {region.upper()}')
        samples = DataLoader(predict_info, dataset, region)
        first_rows = generate_dataframe([], ['susceptible', 'infected', 'recovered', 'deceased'], 'date')
        predictions = {}

        for sample_idx in range(len(samples)):
            # The middle element of each sample is not used here.
            x, _, initial_values = samples[sample_idx]
            predicted = SIRD(predict_info, initial_values).predict(x)
            predictions[str(sample_idx).zfill(4)] = predicted
            # NOTE(review): if this is a pandas DataFrame, `.append` was
            # removed in pandas 2.0 — confirm the pinned pandas version.
            first_rows = first_rows.append(predicted.iloc[0, :])

        results[region] = predictions
        save_region_result(country, result_hash, region, first_rows)

    save_result_dict(country, result_hash, results)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the prediction.
    main(get_args())