In [1]:
import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import plotly.graph_objects as go
from keras.models import load_model
from keras.preprocessing.sequence import TimeseriesGenerator

In [2]:
# Path of the CSV file holding the total death counts per date.
file_name = "Case_PRB_Amerika_Total_Deaths.csv"
# Read the raw data into a DataFrame.
df = pd.read_csv(file_name)

# Window length: the number of preceding time steps the LSTM model receives as
# input for a single prediction (used as the window size when forecasting).
look_back = 5

In [3]:
# Convert the "date" strings to datetime values and use them as the row index.
# `infer_datetime_format` is deprecated since pandas 2.0 (consistent-format
# inference is now the default), and `set_axis(..., inplace=True)` was removed in
# pandas 2.0, so the re-indexed frame is rebound to `df` instead. The "date"
# column itself is kept because later cells still read df["date"].
df["date"] = pd.to_datetime(df["date"])
df = df.set_axis(df["date"], axis="index")

In [4]:
# Total deaths as a NumPy column vector (one row per date) -- the 2-D layout
# that is passed to TimeseriesGenerator.
death_data = df["total_deaths"].values.reshape((-1, 1))

# The dates, kept separately for the x-axis of the plot.
date_data = df["date"]

In [5]:
# Cut the series into overlapping windows of `look_back` consecutive values so
# they can be fed through the network one sequence at a time. Each window is
# paired with the value that immediately follows it; e.g. with a window length
# of 3: [1, 2, 3, 4, 5] -> inputs [[1, 2, 3], [2, 3, 4]], targets [4, 5].
# The window length must equal `look_back`, which is also the input length
# used when forecasting future values.
data_generator = TimeseriesGenerator(
    death_data, death_data, length=look_back, batch_size=1
)

In [6]:
# Restore the previously trained model from the "Trained_Model" path on disk.
model = load_model("Trained_Model")

In [7]:
# Forecast horizon: number of days after the last observation to predict.
num_prediction = 30

# In-sample predictions for every generator window, plus the ground truth, both
# flattened to 1-D arrays so each can be plotted as a single line.
prediction = model.predict(data_generator).reshape(-1)
death_data = death_data.reshape(-1)

# Function for forecasting future values.
def predict(num_prediction, model, history=None, window=None):
    """Recursively forecast ``num_prediction`` future values with ``model``.

    Starting from the last ``window`` values of ``history``, the model predicts
    one step ahead; that prediction is appended and becomes part of the next
    input window (autoregressive roll-out).

    Parameters
    ----------
    num_prediction : int
        Number of future steps to generate.
    model
        Object with a Keras-style ``predict(x, verbose=...)`` method. It receives
        an array of shape ``(1, window, 1)``, and element ``[0][0]`` of its result
        is taken as the next value.
    history : array-like, optional
        Observed series; flattened to 1-D. Defaults to the global ``death_data``.
    window : int, optional
        Input window length. Defaults to the global ``look_back``.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``num_prediction + 1``: the last observed value
        followed by the forecasts. The leading observed value lets the forecast
        line start where the ground-truth line ends when plotted.

    Raises
    ------
    ValueError
        If ``history`` contains fewer than ``window`` values.
    """
    # Resolve the defaults at call time so the function keeps working with the
    # notebook globals while also being usable (and testable) on its own.
    if history is None:
        history = death_data
    if window is None:
        window = look_back

    series = np.asarray(history, dtype=float).reshape(-1)
    if series.size < window:
        raise ValueError(
            f"need at least {window} observed values to forecast, got {series.size}"
        )

    # A Python list avoids the quadratic cost of np.append inside the loop.
    values = list(series[-window:])
    for _ in range(num_prediction):
        x = np.asarray(values[-window:], dtype=float).reshape((1, window, 1))
        # verbose=0: avoid printing one progress bar per forecast step.
        values.append(model.predict(x, verbose=0)[0][0])

    # Keep the last observed value as the first element (see Returns).
    return np.asarray(values[window - 1:], dtype=float)

# Function for creating the dates that belong to the forecast values.
def predict_dates(num_prediction):
    """Return the timestamps for the forecast produced by ``predict``.

    The range starts at the last date in ``df['date']`` (the date of the last
    observed value, which ``predict`` keeps as its first element) and contains
    ``num_prediction + 1`` consecutive timestamps at pandas' default daily
    frequency, returned as a list.
    """
    final_observed = df['date'].values[-1]
    return pd.date_range(final_observed, periods=num_prediction + 1).tolist()

# Forecast beyond the last observed date.
forecast = predict(num_prediction, model)
forecast_dates = predict_dates(num_prediction)

# Generator sample i uses positions i .. i+look_back-1 as input and predicts
# position i+look_back, so the in-sample predictions belong to the dates from
# index look_back onward. Plotting them against all of `date_data` would shift
# the prediction curve look_back days too early.
prediction_dates = date_data.iloc[look_back:]
assert len(prediction_dates) == len(prediction), (
    "in-sample predictions do not line up with the dates; "
    "check that the generator length equals look_back"
)

# Plot the results in a single chart.
trace1 = go.Scatter(
    x=prediction_dates,
    y=prediction,
    mode="lines",
    name="Prediction"
)
trace2 = go.Scatter(
    x=date_data,
    y=death_data,
    mode="lines",
    name="Ground Truth"
)
trace3 = go.Scatter(
    x=forecast_dates,
    y=forecast,
    mode="lines",
    name="Future prediction"
)
layout = go.Layout(
    title="Total deaths prediction",
    xaxis={'title': "Date"},
    yaxis={'title': "Total Deaths"}
)
fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
fig.show()