In [None]:
import copy
import os
import pickle
import sys

import numpy as np

## Make the parent directory importable so the project's `src` package resolves below.
## (Relative to the current working directory, which is normally the notebook's folder.)
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from src.sentence_based_bert import *
from src.model_settings import *
from src.data import *

In [None]:
## Settings object passed to every pipeline step below; later cells index it by key
## (e.g. 'epochs', 'model_name').
model_settings = initialize_model_settings()

# Data processing

In [None]:
## Where to fetch the raw data from: "s3", or "local" if the raw data is already present locally.
RAW_DATA_SOURCE = "s3"
## Set to True to skip fetching/processing/uploading and just download the
## already-processed data from S3.
DOWNLOAD_PROCESSED_ONLY = False

if DOWNLOAD_PROCESSED_ONLY:
    download_processed_data_s3(model_settings)
else:
    get_raw_data(model_settings, RAW_DATA_SOURCE)
    process_data(model_settings)
    upload_processed_data_s3(model_settings)


# Summary statistics

In [None]:
## Generate summary statistics (presumably over the data produced by the processing step above).
gen_sum_stats(model_settings)


# Train model

In [None]:
## Train one model per stage, then check how training went.
for stage in (1, 2):
    ## The third return value (training data before balancing) is not needed for training.
    df_train, df_test, _ = gen_train_test_data(model_settings, stage)

    train_bert(model_settings, df_train, stage, prev_epoch=None, from_s3=False)

    ## Check that training and testing performance have stabilized over the final epochs.
    assess_training(model_settings, stage, df_train=df_train, df_test=df_test)


# Evaluate model

In [None]:
## Score each stage's model on the unbalanced training data, the test data and the forecast data.
for stage in (1, 2):

    ## Results are compiled on the training data as it was before balancing (third return value).
    _, df_test, df_train_unbal = gen_train_test_data(model_settings, stage)
    df_forecast, _ = gen_forecast_data(model_settings, stage)

    ## Load the model at epoch model_settings['epochs'] (presumably the last training epoch).
    ## Pass from_s3=True instead when the model is not present locally; from_s3=False uses the local copy.
    model, _, device = create_model(model_settings, stage, prev_epoch=model_settings['epochs'], from_s3=False)

    for header, df_split, split_name in (
        ("Training: ", df_train_unbal, "train"),
        ("Testing: ", df_test, "test"),
        ("Forecast: ", df_forecast, "forecast"),
    ):
        print(header)
        calc_results(model_settings, stage, model, device, df_split, split_name)


# Interpret results

In [None]:
## Inspect the articles for a single date (dates are written MM-DD-YYYY).
slct_date = '02-02-2020'
extract_articles(model_settings, slct_date)

In [None]:
## Inspect a later date (MM-DD-YYYY), additionally passing a keyword:
## '北京' is "Beijing" (presumably used to filter the returned articles).
slct_date = '06-30-2020'
kword = '北京'
extract_articles(model_settings, slct_date, kword)

# Robustness check

### Translating Covid terms to Sars terms

In [None]:
## Re-score the forecast data with Covid terms translated to Sars terms, using each stage's
## already-trained model.
for stage in (1, 2):
    df_forecast, _ = gen_forecast_data(model_settings, stage, to_translate=True)

    ## Pass from_s3=True instead when the model is not present locally; from_s3=False uses the local copy.
    model, _, device = create_model(
        model_settings,
        stage,
        prev_epoch=model_settings['epochs'],
        from_s3=False,
    )

    print("Forecast: ")
    calc_results(model_settings, stage, model, device, df_forecast, "forecast", to_translate=True)


### Filtering out foreign country-related articles

In [None]:
## Use new model folder name

model_settings_domestic_only = model_settings
model_settings_domestic_only['model_name'] = "Sentence_based_domestic_only"


In [None]:
## Train model

for stage in [1, 2]:

    df_train, df_test, _ = gen_train_test_data(model_settings_domestic_only, stage, domestic_only=True)

    train_bert(model_settings_domestic_only, df_train, stage, prev_epoch=None, from_s3=False)

    assess_training(model_settings_domestic_only, stage, df_train=df_train, df_test=df_test)

In [None]:
## Evaluate model

for stage in [1, 2]:

    _, df_test, df_train_unbal = gen_train_test_data(model_settings, stage, domestic_only=True)
    df_forecast, _ = gen_forecast_data(model_settings, stage, domestic_only=True)

    model, _, device = create_model(model_settings, stage, prev_epoch=model_settings['epochs'], from_s3=False)
    
    print("Training: ")
    calc_results(model_settings, stage, model, device, df_train_unbal, "train")
    
    print("Testing: ")
    calc_results(model_settings, stage, model, device, df_test ,       "test")
    
    print("Forecast: ")
    calc_results(model_settings, stage, model, device, df_forecast,    "forecast")
