# Influenza Forecast

##### What are the goals of this Jupyter notebook?
1. One goal of this Jupyter notebook was to quickly implement forecast and influenza data visualizations. This made it fast to explore the model predictions further and to try out different visualizations, both to gain insights and to select visualizations for the project presentation via the [forecasting webapp](https://influenzaforecast.herokuapp.com/). 
2. The interested reader can explore the code underlying the forecasts and plots presented via the [forecasting webapp](https://influenzaforecast.herokuapp.com/). This code is located in the modules InfluenzaForecastPy3.py and DataFormatting.py, which are imported into this notebook.

##### What is the setup of the forecast?
At this point only a brief recap is provided; a more comprehensive description of the project can be found on the [forecasting webapp site](https://influenzaforecast.herokuapp.com/). The goal of the project is to forecast the number of reported influenza infections for each of the sixteen states in Germany. This is done by classifying whether a given threshold of reported infections per 100 000 inhabitants will be crossed. The threshold 0.8 is critical: once it is crossed, an influenza wave usually occurs. Once the threshold 7.0 is crossed, the wave is relatively severe. Such a forecast would allow hospitals, doctor's offices and companies to plan in advance, so that hospitals and doctor's offices could avoid being overwhelmed and companies could prepare their schedules for an increased number of absences. The following three data sources are used.
1. The [Robert Koch Institute](https://survstat.rki.de/Content/Query/Create.aspx) provides the reported number of influenza cases in Germany per state and week. An infection is only reported after a laboratory test has been performed, so only a fraction of the people infected by and suffering from influenza are detected. Despite this, up to a scaling factor, the reported number of influenza cases should be a good proxy for the actual number of influenza cases. 
2. The [Deutscher Wetterdienst](ftp://ftp-cdc.dwd.de/pub/CDC/observations_germany/climate/daily/kl/historical/) provides daily weather data for Germany, for instance the mean, minimum and maximum of temperature, relative humidity and precipitation. It turns out that this feature class provides the least predictive power.
3. [Google Flu Trends](https://www.google.org/flutrends/about/) provides numbers reflecting the frequency of influenza related search queries on a state level per week.
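As a minimal sketch of the classification target, assuming weekly case counts and a state's population are available, reported cases can be converted into an incidence per 100 000 inhabitants and then compared against the 0.8 and 7.0 thresholds. The helper names and numbers below are hypothetical and not taken from InfluenzaForecastPy3.py; "crossed" is read here as "strictly above".

```python
from typing import List

WAVE_THRESHOLD = 0.8    # incidence at which an influenza wave usually starts
SEVERE_THRESHOLD = 7.0  # incidence indicating a relatively severe wave


def incidence_per_100k(cases: List[int], population: int) -> List[float]:
    """Convert weekly reported case counts into cases per 100 000 inhabitants."""
    return [c * 100_000 / population for c in cases]


def threshold_labels(incidence: List[float], threshold: float) -> List[int]:
    """1 if the weekly incidence is strictly above the threshold, else 0."""
    return [int(value > threshold) for value in incidence]


# Hypothetical state with 4 million inhabitants and five weeks of reports.
weekly_cases = [12, 40, 160, 420, 90]
incidence = incidence_per_100k(weekly_cases, population=4_000_000)
print(incidence)                                      # [0.3, 1.0, 4.0, 10.5, 2.25]
print(threshold_labels(incidence, WAVE_THRESHOLD))    # [0, 1, 1, 1, 1]
print(threshold_labels(incidence, SEVERE_THRESHOLD))  # [0, 0, 0, 1, 0]
```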

### Overview:

##### 1.) Classification: 
- Classification for a specific year
- Cross-validation for model and feature selection
- Grid search for model parameter selection
- Training and predicting for all validation years (2005-2015 excluding 2009) and all forecasting weeks

##### 2.) Exploratory Visualizations:
- Visualizing state commonalities
- Statistics of the wave start, end, length, height 
- Influenza map
- Visualizing the relation between the start week of the wave and the wave length
- Visualizing the relation between start week of the wave and the wave peak
- Classification features
- Overall number of infected per state


###### The following button hides or unhides the code cells.

In [None]:
from IPython.display import HTML

HTML('''<script>
code_show=false; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Click here to toggle on/off the raw code."></form>''')

In [None]:
from datetime import datetime
from isoweek import Week
import numpy as np
import pandas as pd

# Pipeline
from sklearn.pipeline import Pipeline

# Scaling
from sklearn.preprocessing import StandardScaler

# Classification
from sklearn.neural_network import MLPClassifier

from bokeh.io import push_notebook, output_notebook, show
from bokeh.layouts import column
from bokeh.plotting import figure
from bokeh.models import (
    Range1d, 
    LinearAxis,
    ColumnDataSource,
    HoverTool,
    LinearColorMapper,
    BasicTicker,
    FixedTicker,
    PrintfTickFormatter,
    ColorBar,
    DatetimeTickFormatter)

from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display

from InfluenzaForecastPy3 import *
from Data.DataFormatting import DataFrameProvider

In [None]:
# Bokeh plots are shown inline.
output_notebook()

### 1.) Classification
An object of type ForecastProvider is created and used throughout "1.) Classification". The particulars of the forecast, such as the forecasting horizon and the forecasting model, can be specified via its parameters. If you modify them, make sure that the weeks_in_advance_int parameter equals the length of classification_pipeline_per_week_list, since each forecasting distance gets its own model. 
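Since each forecasting distance gets its own model, each model also needs its own training targets. The following is a minimal sketch, assuming a single state's weekly incidence series and a fixed window of past weeks as features. The function name, window length and data are hypothetical; the project's actual feature construction also uses weather and Google Flu Trends data.

```python
from typing import Dict, List, Tuple


def make_horizon_dataset(series: List[float], window: int, horizon: int,
                         threshold: float) -> Tuple[List[List[float]], List[int]]:
    """Pair each run of `window` consecutive weeks with a 0/1 label telling whether
    the value `horizon` weeks after the run's last week exceeds `threshold`."""
    X, y = [], []
    # `end` is the exclusive end of the feature window; the target week is end - 1 + horizon.
    for end in range(window, len(series) - horizon + 1):
        X.append(series[end - window:end])
        y.append(int(series[end - 1 + horizon] > threshold))
    return X, y


# Hypothetical weekly incidence of one state.
series = [0.1, 0.2, 0.5, 0.9, 1.5, 2.0, 1.0, 0.4]
# One training set (and hence one model) per forecasting distance, 1 to 5 weeks ahead.
datasets: Dict[int, Tuple[List[List[float]], List[int]]] = {
    h: make_horizon_dataset(series, window=3, horizon=h, threshold=0.8) for h in range(1, 6)}
print({h: len(y) for h, (X, y) in datasets.items()})  # {1: 5, 2: 4, 3: 3, 4: 2, 5: 1}
```

Here horizon 1 plays the role of forecasting_week_index = 0 (a one-week forecast); the further ahead the target, the fewer labelled examples a finite series yields.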

In [None]:
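weeks_in_advance = 5
# Build a separate pipeline per forecasting week; [Pipeline(...)]*5 would reuse one object five times.
forecast_provider = ForecastProvider(weeks_in_advance_int=weeks_in_advance, classification_pipeline_per_week_list=[Pipeline(
                     steps=[('preprocessing', StandardScaler()),
                            ('regressor', MLPClassifier(hidden_layer_sizes=(20, 20, 10), alpha=1,
                                                        learning_rate='adaptive', batch_size=3000,
                                                        random_state=1341, max_iter=1000))])
                     for _ in range(weeks_in_advance)])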
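Because every forecasting week is meant to have its own model, it matters whether the list passed as classification_pipeline_per_week_list holds independent objects. In Python, `[x] * 5` creates five references to one single object; whether ForecastProvider copies the pipelines internally is not verified here. A small illustration with a hypothetical ToyModel class standing in for a pipeline:

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a scikit-learn pipeline; it only records what it was fitted on."""

    def __init__(self):
        self.fitted_on = None

    def fit(self, label):
        self.fitted_on = label
        return self


# [obj] * n repeats the reference, so all five entries are the same object.
aliased = [ToyModel()] * 5
aliased[0].fit('week 1')
print([m.fitted_on for m in aliased])      # every entry reports 'week 1'

# A comprehension of fresh copies yields five independent models.
template = ToyModel()
independent = [copy.deepcopy(template) for _ in range(5)]
independent[0].fit('week 1')
print([m.fitted_on for m in independent])  # ['week 1', None, None, None, None]
```

With real scikit-learn estimators, `sklearn.base.clone` is the usual way to obtain an unfitted copy of a template pipeline.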


#### Classification for a specific year
The following cell trains the specified models, one for each forecasting week. Accuracy, Precision, Recall and, if the model provides prediction probabilities, ROC AUC and Log Loss are printed per week. In addition, the ROC AUC is plotted for each forecasting week. Finally, for each forecasting week, the prediction probabilities and the actual values are plotted. The prediction probabilities indicate whether the threshold of reported infections per 100 000 inhabitants will be crossed.   
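For reference, the printed scores can be sketched in plain Python. This is an illustration, not the project's implementation; the 0.5 cutoff that turns probabilities into class predictions and the clipping constant are assumptions.

```python
import math
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], proba: Sequence[float],
                   cutoff: float = 0.5, eps: float = 1e-15) -> Dict[str, float]:
    """Accuracy, precision, recall, ROC AUC and log loss for 0/1 labels and
    predicted probabilities of the positive class ("threshold crossed")."""
    y_pred = [int(p >= cutoff) for p in proba]
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    # Log loss: mean negative log-likelihood, probabilities clipped away from 0.
    log_loss = -sum(math.log(max(p, eps)) if t == 1 else math.log(max(1.0 - p, eps))
                    for t, p in zip(y_true, proba)) / len(y_true)

    # ROC AUC: chance that a random positive week is scored above a random
    # negative week (ties count one half), i.e. the Mann-Whitney formulation.
    pos = [p for t, p in zip(y_true, proba) if t == 1]
    neg = [p for t, p in zip(y_true, proba) if t == 0]
    auc = sum(1.0 if a > b else 0.5 if a == b else 0.0
              for a in pos for b in neg) / (len(pos) * len(neg))

    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        "roc_auc": auc,
        "log_loss": log_loss,
    }


metrics = binary_metrics([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(metrics)  # accuracy 0.75, precision 1.0, recall 0.5, roc_auc 0.75, log_loss ≈ 0.472
```

ROC AUC and Log Loss use the probabilities directly, while Accuracy, Precision and Recall depend on the chosen cutoff.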

In [None]:
# Valid prediction years are 2005-2008,2010-2014. The outlier year 2009 is not included since the swine flu occurred in 2009. It could be included 
# by initializing the forecast_provider differently (see __init__ documentation).
X_influenza_test, y_test, y_pred_ndarray, test_year_week_per_row_list, complete_unique_year_week_list = forecast_provider.classification_forecast_for_year(prediction_year=2014)

The following plot shows the predictions for the whole prediction horizon. The forecasting visualization on the [forecasting webapp site](https://influenzaforecast.herokuapp.com/) is based on this plot. Pressing play or moving the slider changes the present date, so the progression of the predictions can be followed throughout the year. The plot displays the influenza waves of all sixteen states for the year specified above, so sixteen waves can be observed. The waves can be matched to states via the alphabetical ordering of the states. Clicking the play button multiple times increases the playback speed.
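The date axis of this plot is built from (year, week) pairs. Assuming the RKI reporting weeks follow ISO 8601 numbering (an assumption, not stated in this notebook), the Monday of each week can be obtained from the standard library. Note that `strptime`'s `%U` directive counts Sunday-based weeks and can disagree with ISO numbering by a week. `iso_week_monday` is a hypothetical helper and needs Python 3.8 or later.

```python
from datetime import date, datetime


def iso_week_monday(year: int, week: int) -> date:
    """Monday of ISO 8601 week `week` in ISO year `year` (Python 3.8+)."""
    return date.fromisocalendar(year, week, 1)


# ISO week 1 of 2014 starts on Monday, 30 December 2013 ...
print(iso_week_monday(2014, 1))                            # 2013-12-30
# ... whereas %U counts Sunday-based weeks and lands one week later here.
print(datetime.strptime('2014W1 MON', '%YW%U %a').date())  # 2014-01-06
```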

In [None]:
def update_map_forecast(row_index_multiple):
    # The Play widget advances by one per tick; integer division keeps each row on screen for 28 ticks.
    row_index = row_index_multiple // 28

    current_year_week_tuple = test_year_week_per_row_list[row_index]
    # Position of the present (year, week) in the complete list of weeks.
    current_year_week_tuple_index = complete_unique_year_week_list.index(current_year_week_tuple)

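    # Monday of each week, assuming (year, week) pairs are ISO 8601 weeks
    # (%G: ISO year, %V: ISO week, %u: ISO weekday, 1 = Monday).
    plot_x = [datetime.strptime('{} {} 1'.format(year_week[0], year_week[1]), '%G %V %u')
              for year_week in complete_unique_year_week_list]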
    
    # Gray steps: the input weeks followed by the actual values over the forecasting horizon.
    steps1.data_source.data['x'] = plot_x[current_year_week_tuple_index - X_influenza_test.shape[1]: current_year_week_tuple_index + y_test.shape[1]]
    steps1.data_source.data['y'] = np.hstack([X_influenza_test[row_index, :], y_test[row_index, :]])

    # Red steps: predictions scaled by 0.8 so that a predicted crossing is drawn at the wave threshold.
    steps2.data_source.data['x'] = plot_x[current_year_week_tuple_index: current_year_week_tuple_index + y_pred_ndarray.shape[1]]
    steps2.data_source.data['y'] = y_pred_ndarray[row_index, :] * 0.8

    # Black vertical line at the present week.
    line1.data_source.data['x'] = [plot_x[current_year_week_tuple_index]] * 2
    
    push_notebook(handle=handle_p_forecast)
    

In [None]:
p = figure(title="Actual Numbers vs Predictions for Different States", x_axis_type='datetime', y_range=(0, 16),
           x_axis_label='Date', y_axis_label='# Reported Influenza Infections per 100 000 Inhabitants')

source1 = ColumnDataSource(data=dict(x=[1]*(X_influenza_test.shape[1]+y_test.shape[1]), y=[1]*(X_influenza_test.shape[1]+y_test.shape[1])))
source2 = ColumnDataSource(data=dict(x=[1]*(y_test.shape[1]), y=[1]*(y_test.shape[1])))
source4 = ColumnDataSource(data=dict(x=[X_influenza_test.shape[1]]*2, y=[0, 15]))

steps1 = p.step(x='x', y='y', source=source1, color='gray', legend_label='Actual Influenza Numbers')

steps2 = p.step(x='x', y='y', source=source2, color='red', legend_label='Prediction: is threshold crossed?')

line1 = p.line(x='x', y='y', source=source4, color='black', legend_label='Present Moment')
    
p.xaxis.formatter = DatetimeTickFormatter(
    years=["%D %B %Y"]
)

p.xaxis.major_label_orientation = 3.0 / 4
p.xaxis[0].ticker.desired_num_ticks = 30        

handle_p_forecast = show(p, notebook_handle=True)


play = widgets.Play(min=0, max=X_influenza_test.shape[0]*7, step=1)
slider = widgets.IntSlider(min=0, max=X_influenza_test.shape[0]*7, step=1)
my_widgets = widgets.jslink((play, 'value'), (slider, 'value'))
my_box = widgets.VBox([play, slider])
    
out = widgets.interactive_output(update_map_forecast, {'row_index_multiple': play}) 
display(out, my_box)


##### Cross-validation for model and feature selection
The Accuracy, Precision, Recall and, if the model provides prediction probabilities, ROC AUC and Log Loss of a cross-validation for the specified forecasting_week_index are printed, and scores_model is returned. In this evaluation the forecasting week is fixed, but cross-validation is performed, so all years are evaluated. forecasting_week_index is zero-based (forecasting_week_index = 0 refers to a one-week forecast) and must therefore be smaller than the weeks_in_advance_int instance variable. 
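The description does not spell out how the folds are formed. One plausible scheme consistent with "all years are evaluated" is leave-one-year-out cross-validation, sketched here with a hypothetical helper; the year range follows the overview above (2005-2015 excluding 2009).

```python
from typing import Iterator, List, Tuple


def leave_one_year_out(years: List[int]) -> Iterator[Tuple[List[int], int]]:
    """Yield (training years, held-out year); every year is held out exactly once."""
    for test_year in years:
        yield [year for year in years if year != test_year], test_year


# Validation years 2005-2015 without the swine flu year 2009.
years = [year for year in range(2005, 2016) if year != 2009]
folds = list(leave_one_year_out(years))
print(len(folds))                 # 10
print(folds[0][1], folds[-1][1])  # 2005 2015
```

Holding out whole years rather than randomly chosen weeks keeps strongly correlated neighbouring weeks of the same season from appearing in both training and test data. Whether ForecastProvider.do_cross_validation splits exactly this way is not confirmed here.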

In [None]:
return_crossvalidation = forecast_provider.do_cross_validation(classification_pipeline=Pipeline(
                     steps=[('preprocessing', StandardScaler()),
                            ('regressor', MLPClassifier(hidden_layer_sizes=(20, 20, 10), alpha=1,
                                                        learning_rate='adaptive', batch_size=3000,
                                                        random_state=1341, max_iter=1000))]), forecasting_week_index=1, wave_threshold=0.8)

##### Grid search for model parameter selection
As above, the forecasting week is fixed, but all years are considered during cross-validation. Cross-validation is used to determine the best-performing model parameters. The best parameters are printed, and both the grid search's best_score_ attribute and the best parameters are returned.
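The grid search evaluates every combination of the listed values, each with a full cross-validation. A small sketch of the expansion follows; `expand_grid` is a hypothetical helper mirroring what scikit-learn's ParameterGrid does, although ParameterGrid may order the combinations differently. The `<step>__<param>` keys follow scikit-learn's Pipeline convention for addressing a parameter of a named step.

```python
from itertools import product
from typing import Any, Dict, List


def expand_grid(parameters: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Every combination of the listed values, one dict per candidate model."""
    keys = list(parameters)
    return [dict(zip(keys, values)) for values in product(*(parameters[k] for k in keys))]


parameters = {'regressor__batch_size': [500, 3000, 7000],
              'regressor__hidden_layer_sizes': [(10, 10), (10, 20, 10)],
              'regressor__alpha': [0.001, 0.1, 1.0]}
candidates = expand_grid(parameters)
print(len(candidates))  # 18 = 3 * 2 * 3

# "regressor__alpha" sets the alpha parameter of the pipeline step named "regressor".
step, param = 'regressor__alpha'.split('__', 1)
print(step, param)      # regressor alpha
```

The grid above therefore fits 18 candidate models per cross-validation fold, which is worth keeping in mind before adding further values.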

In [None]:
return_gridsearch = forecast_provider.do_gridsearch(classification_pipeline=Pipeline(
                     steps=[('preprocessing', StandardScaler()),
                            ('regressor', MLPClassifier(hidden_layer_sizes=(20, 20, 10), alpha=1,
                                                        learning_rate='adaptive', batch_size=3000,
                                                        random_state=1341, max_iter=1000))]),
                      parameters={'regressor__batch_size': [500, 3000, 7000],
                                  'regressor__hidden_layer_sizes': [(10, 10), (10, 20, 10)],
                                  'regressor__alpha': [0.001, 0.1, 1.0]},
                      forecasting_week_index=0, wave_threshold=0.8)


##### Training and predicting for all validation years (2005-2015 excluding 2009) and all forecasting weeks
The next cell computes the actual values and the predictions for all forecasting weeks and, per forecasting week, for all validation years. The cell after it visualizes the results: Accuracy, Precision, Recall, F2 score and, if the model provides prediction probabilities, ROC AUC and Log Loss, as well as the confusion matrices, are shown per week. All metrics except the confusion matrices show the value per year in addition to the overall value.
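The F2 score is the F-beta score with beta = 2, which weights recall more heavily than precision, so missed threshold crossings lower it more than false alarms do. A minimal sketch of the formula, F_beta = (1 + beta²) · P · R / (beta² · P + R), with a hypothetical helper name:

```python
def f_beta(precision: float, recall: float, beta: float = 2.0) -> float:
    """F-beta score; beta > 1 weights recall more heavily than precision."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# A model that is always right when it predicts a crossing (precision 1.0)
# but catches only half of the crossings (recall 0.5):
print(round(f_beta(1.0, 0.5, beta=2), 4))  # 0.5556
print(round(f_beta(1.0, 0.5, beta=1), 4))  # 0.6667 (ordinary F1)
```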

In [None]:
output_by_week_list, _, _ = forecast_provider.get_formatted_pred_probas_and_actual_values()

In [None]:
row_metrics, row_confusion_threshold1, row_confusion_threshold2 = visualize_and_save_metrics(output_by_week_list, proba_threshold_list=[0.5, 0.5], threshold_list=[0.8, 7.0], threshold_color_list=["#036564", "#550b1d"], file_str='')
show(column([row_metrics, row_confusion_threshold1, row_confusion_threshold2]))

### 2.) Exploratory Visualizations
So far, the classification forecast has been explored. The remaining part of this notebook explores influenza wave statistics. For instance, the influenza waves of different states are compared to each other, and a violin plot shows the relation between the start of an influenza wave and its intensity. The next two cells load the data and define a function that provides the widget functionality.

In [None]:
# Load the rolling-window data frame (pandas.DataFrame) from the DataFormatting module.
# This is the feature data set used for visualization and prediction.
dataFrameProvider = DataFrameProvider()
rolling_window_df = dataFrameProvider.getFeaturesDF(number_of_temp_intervals=12, number_of_humid_intervals=12, number_of_prec_intervals=12)

In [None]:
def get_ui_out(update_func):
    """Return (ui, out): toggle buttons for the sixteen states, arranged in two rows,
    and an output widget that calls update_func(**{state_name: is_selected})."""
    button_names = ['Baden-Wuerttemberg', 'Bayern', 'Berlin', 'Brandenburg', 'Bremen', 'Hamburg',
                    'Hessen', 'Mecklenburg-Vorpommern', 'Niedersachsen', 'Nordrhein-Westfalen',
                    'Rheinland-Pfalz', 'Saarland', 'Sachsen', 'Sachsen-Anhalt',
                    'Schleswig-Holstein', 'Thueringen']
    # One toggle button per state; all states are selected initially.
    buttons = [widgets.ToggleButton(value=True, description=name, tooltip=name)
               for name in button_names]
    button_dict = dict(zip(button_names, buttons))

    # First eight states in the upper row, the remaining eight in the lower row.
    half = len(buttons) // 2
    ui = widgets.VBox([widgets.HBox(buttons[:half]), widgets.HBox(buttons[half:])])

    out = widgets.interactive_output(update_func, button_dict)
    return ui, out

##### Visualizing state commonalities
The following plot shows the reported number of influenza infections per state from 2005 to 2015. Clicking on a state in the legend toggles the visibility of the corresponding curve. Interestingly, the waves start, end and peak in roughly the same week across states, although the height of the waves and the total number of infected from 2005 to 2015 vary significantly, as the final plot of this notebook shows. 

In [None]:
show(visualize_state_commonalities(rolling_window_df))

##### Statistics of the wave start, end, length, height 
Looking at their distributions, a typical wave start and end can be narrowed down considerably. The plot also shows the average wave length and reveals that waves with high peaks are relatively rare. The wave start is defined as the first week in a season in which the reported influenza numbers exceed 2.0 infected per 100 000 inhabitants; the wave end is defined analogously.
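A minimal sketch of these statistics for one season of weekly incidence. Reading "defined analogously" as "the last week above 2.0" is an assumption, as is counting weeks as 0-based positions within the season; the data are made up.

```python
from typing import Dict, List, Optional


def wave_statistics(season: List[float], threshold: float = 2.0) -> Optional[Dict[str, float]]:
    """Start, end, length and height of the wave in one season of weekly incidence.

    Start: first week above `threshold`; end: last week above `threshold`
    (assumed reading). Weeks are 0-based positions within `season`.
    """
    above = [i for i, value in enumerate(season) if value > threshold]
    if not above:
        return None  # no wave in this season
    start, end = above[0], above[-1]
    return {"start": start, "end": end, "length": end - start + 1,
            "height": max(season[start:end + 1])}


season = [0.2, 0.5, 1.8, 3.1, 6.4, 9.0, 5.2, 1.9, 2.3, 0.7]
print(wave_statistics(season))  # {'start': 3, 'end': 8, 'length': 6, 'height': 9.0}
```

Under this reading, a single week dipping below the threshold inside the wave (week 7 here) does not end it.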

In [None]:
def update_wave_stats_distribution(**selected_by_state):
    state_list = [state for state, selected in selected_by_state.items() if selected]
    if not state_list:
        print('Please select at least one state.')
    else:
        show(visualize_wave_stats_distributions(rolling_window_df, states=state_list))

ui, out = get_ui_out(update_wave_stats_distribution)
display(ui, out)

##### Influenza map
The following plot shows that there is also a spatial relation between the wave starts of different states. Clicking the play button multiple times increases the playback speed.

In [None]:
p_map, p_patches_map, year_week_list = visualize_infection_numbers_on_map(rolling_window_df)
handle_p_map = show(p_map, notebook_handle=True)

def update_map(year_week_index):
    # Each week is held for 7 ticks of the Play widget.
    year_week_index = year_week_index // 7
    # The patch data source holds one column of rates per week, keyed '0', '1', ...
    p_patches_map.data_source.data['rate'] = p_patches_map.data_source.data[str(year_week_index)]
    p_map.title.text = ('Number of Newly Reported Influenza Infections per 100 000 Inhabitants in Week '
                        + str(year_week_list[year_week_index][1]) + ' of Year ' + str(year_week_list[year_week_index][0]))
    push_notebook(handle=handle_p_map)

max_tick = 7 * len(year_week_list) - 1  # the last tick still maps to the last week
play = widgets.Play(min=0, max=max_tick, step=1)
slider = widgets.IntSlider(min=0, max=max_tick, step=1)
my_widgets = widgets.jslink((play, 'value'), (slider, 'value'))
my_box = widgets.VBox([play, slider])
    
out = widgets.interactive_output(update_map, {'year_week_index': play}) 
display(out, my_box)

##### Visualizing the relation between the start week of the wave and the wave length
The following figure shows a heatmap of the wave length in weeks for a given start week. States can be toggled by clicking the corresponding state buttons. The plot suggests that the earlier a wave starts, the longer it lasts.
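The heatmap summarizes (start week, length) pairs over states and seasons. A rough numeric check of the "earlier start, longer wave" pattern is the mean length per start week. The helper and the pairs below are made up for illustration, and start weeks are assumed to be numbered within the season so that sorting is chronological.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def mean_length_by_start_week(waves: List[Tuple[int, int]]) -> Dict[int, float]:
    """Average wave length (weeks) per start week, given (start_week, length) pairs."""
    lengths: Dict[int, List[int]] = defaultdict(list)
    for start_week, length in waves:
        lengths[start_week].append(length)
    return {week: sum(values) / len(values) for week, values in sorted(lengths.items())}


# Hypothetical (start week, length) pairs for a few state-seasons.
waves = [(1, 14), (1, 12), (3, 11), (5, 9), (5, 8), (8, 6)]
print(mean_length_by_start_week(waves))  # {1: 13.0, 3: 11.0, 5: 8.5, 8: 6.0}
```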

In [None]:
def update_wave_length(**selected_by_state):
    state_list = [state for state, selected in selected_by_state.items() if selected]
    if not state_list:
        print('Please select at least one state.')
    else:
        show(visualize_wave_start_vs_length_via_heatmap(rolling_window_df, states=state_list))

ui, out = get_ui_out(update_wave_length)
display(ui, out)

##### Visualizing the relation between start week of the wave and the wave peak
The following figure shows a violin plot of the wave peak, measured in reported number of infected, for a given start week. States can be toggled by clicking the corresponding state buttons. As above, the plot suggests that the earlier a wave starts, the more severe it tends to be.
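To put a single number on this tendency, one option (not used in the notebook itself) is Spearman's rank correlation between start week and peak height. It only assumes a monotonic relation; a negative value means that earlier starts go with higher peaks. The values below are made up.

```python
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman rank correlation: Pearson correlation of the average ranks."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(x)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)


start_weeks = [1, 2, 4, 6, 9]
peaks = [30.0, 24.5, 18.0, 20.0, 7.5]
print(round(spearman(start_weeks, peaks), 2))  # -0.9
```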

In [None]:
def update_wave_severity(**selected_by_state):
    state_list = [state for state, selected in selected_by_state.items() if selected]
    if not state_list:
        print('Please select at least one state.')
    else:
        # Show the violin plot for the selected states.
        show(visualize_wave_start_vs_severity_via_violin(rolling_window_df, states=state_list))

ui, out = get_ui_out(update_wave_severity)
display(ui, out)

##### Classification features
The following plot visualizes the features. Clicking on an entry in the legend shows or hides the corresponding feature. A state is selected by clicking on the corresponding button.

In [None]:
toggle_buttons = widgets.ToggleButtons(
    options=['Baden-Wuerttemberg', 'Bayern', 'Berlin', 'Brandenburg', 'Bremen', 'Hamburg',
             'Hessen', 'Mecklenburg-Vorpommern', 'Niedersachsen', 'Nordrhein-Westfalen',
             'Rheinland-Pfalz', 'Saarland', 'Sachsen', 'Sachsen-Anhalt',
             'Schleswig-Holstein', 'Thueringen'],
    description=' ',
)


def update_state_features(state_str):
    # Plot all features of the selected state.
    show(visualize_data_per_state(rolling_window_df, state_str))

out = widgets.interactive_output(update_state_features, {'state_str': toggle_buttons})
display(toggle_buttons, out)

##### Overall number of infected per state
The following bar plot shows a clear difference between states with respect to the total number of reported influenza infections from 2005 to 2015 (excluding the outlier year 2009, the year of the swine flu pandemic). It would be interesting to investigate whether these significant differences are also caused by state-wise differences in how often samples are taken and sent to a laboratory to be tested for the influenza virus.   
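The bar heights are totals over all weeks, summed per state with the swine flu year left out. Below is a minimal sketch over hypothetical (state, year, weekly value) records. Whether the plot sums raw case counts or incidence per 100 000 inhabitants is not stated here; summing incidence would make states of different population size directly comparable.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def total_per_state(records: Iterable[Tuple[str, int, float]],
                    excluded_years: Tuple[int, ...] = (2009,)) -> Dict[str, float]:
    """Sum weekly values per state, skipping the excluded years."""
    totals: Dict[str, float] = defaultdict(float)
    for state, year, value in records:
        if year not in excluded_years:
            totals[state] += value
    return dict(totals)


records = [('Berlin', 2008, 1.5), ('Berlin', 2009, 40.0), ('Berlin', 2010, 2.5),
           ('Bremen', 2008, 0.5), ('Bremen', 2010, 1.0)]
print(total_per_state(records))  # {'Berlin': 4.0, 'Bremen': 1.5}
```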

In [None]:
show(visualize_overall_reported_cases(rolling_window_df))