### Note
If you are running all cells in one go, first fill in the credentials in the "Setting Credentials (Mapbox and Chart Studio)" section further down.

# Imports

In [None]:
# Additional Installations For Google Colab

!pip install TSErrors
!pip install chart_studio
!pip install plotly -U

In [None]:
import pandas as pd

import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import chart_studio.plotly as py
import chart_studio

import requests

from collections import Counter
from TSErrors import FindErrors

from sklearn.model_selection import ParameterGrid

from keras.models import Sequential
from keras.layers.convolutional import Conv1D, MaxPooling1D
from keras.layers import Dense, Flatten

from datetime import datetime
from datetime import date

from math import log
from math import e

from itertools import chain

import warnings
warnings.simplefilter("ignore")

# Data Pre-Processing

## Getting The Data 

In [None]:
# Data from the Johns Hopkins University CSSE dataset on GitHub
# https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_time_series

# The three global time-series files (one row per country/province, one column per date)
filenames = ['time_series_covid19_confirmed_global.csv',
             'time_series_covid19_deaths_global.csv',
             'time_series_covid19_recovered_global.csv']

url = 'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/'

# Load every time-series file in the order listed above
confirmed_global, deaths_global, recovered_global = [pd.read_csv(url + name) for name in filenames]

# Per-country snapshot table (latest totals plus coordinates)
country_cases = pd.read_csv('https://raw.githubusercontent.com/CSSEGISandData/COVID-19/web-data/data/cases_country.csv')

## Data Cleaning

In [None]:
# Simple Data Cleaning - Removing and renaming the Columns
# errors='ignore' makes the drops idempotent: re-running this cell no longer raises
# KeyError for columns that an earlier run already removed.

# Removing Province/State (rows are summed per country later) and the Lat/Long coordinates
confirmed_global.drop(columns=['Province/State', 'Lat', 'Long'], inplace=True, errors='ignore')
deaths_global.drop(columns=['Province/State', 'Lat', 'Long'], inplace=True, errors='ignore')
recovered_global.drop(columns=['Province/State', 'Lat', 'Long'], inplace=True, errors='ignore')
country_cases.drop(columns=["People_Tested", "People_Hospitalized"], inplace=True, errors='ignore')

# Renaming the columns for easier access
confirmed_global.rename(columns={"Country/Region": "country"}, inplace=True)
deaths_global.rename(columns={"Country/Region": "country"}, inplace=True)
recovered_global.rename(columns={"Country/Region": "country"}, inplace=True)

country_cases.rename(columns={
    "Country_Region": "country",
    "Last_Update": "last",
    "Confirmed": "confirmed",
    "Deaths": "deaths",
    "Recovered": "recovered",
    "Active": "active",
    "Mortality_Rate": "mortality"
}, inplace=True)

In [None]:
# Several countries span multiple rows (one per province/state); collapse each
# table to a single row per country by summing the per-date counts.
confirmed_global, deaths_global, recovered_global = [
    table.groupby(['country'], as_index=False).sum()
    for table in (confirmed_global, deaths_global, recovered_global)
]

In [None]:
# Countries ranked by confirmed cases, largest first, with a fresh 0..n-1 index
country_cases_sorted = (
    country_cases
    .sort_values("confirmed", ascending=False)
    .reset_index(drop=True)
)

## Error Corrections

In [None]:
# This value is being changed as there was an error in the original dataset that had to be modified
# NOTE(review): 178 is a row label in the per-country table produced by the groupby above, so it
# depends on the alphabetical list of countries at the time this was written. If the source adds
# or renames a country, this silently patches the wrong row. Selecting the row by country name
# would be safer — TODO confirm which country this correction was meant for.
confirmed_global.at[178, '5/20/20'] = 251667

## DataFrames

In [None]:
# First rows of the per-country confirmed-cases table (one column per date)
confirmed_global.head()

In [None]:
# First rows of the per-country deaths table
deaths_global.head()

In [None]:
# First rows of the per-country recoveries table
recovered_global.head()

In [None]:
# Countries with the most confirmed cases in the snapshot table
country_cases_sorted.head()

# Data Visualization - General Graphs

## Timeseries

### Code

In [None]:
def get_new_cases(country):
    """Confirmed-case counts of one country as a long (date, cases) frame.

    The wide ``confirmed_global`` table (one column per date) is unpivoted and
    only the rows of ``country`` are kept. The values are running totals;
    ``plot_timeseries(daily=True)`` differences them.

    Returns a DataFrame with columns ``date`` and ``cases`` and a 0..n-1 index.
    """
    long_form = confirmed_global.melt(id_vars=["country"], var_name="date", value_name="cases")
    selected = long_form.loc[long_form["country"] == country].drop(columns=["country"])
    selected.index = range(len(selected))
    return selected

In [None]:
def get_new_deaths(country):
    """Death counts of one country as a long (date, cases) frame.

    Unpivots ``deaths_global`` and keeps the rows of ``country``. The value column is
    named ``cases`` so that all three getters share one plotting path.

    Returns a DataFrame with columns ``date`` and ``cases`` and a 0..n-1 index.
    """
    long_form = deaths_global.melt(id_vars=["country"], var_name="date", value_name="cases")
    selected = long_form.loc[long_form["country"] == country].drop(columns=["country"])
    selected.index = range(len(selected))
    return selected

In [None]:
def get_new_recoveries(country):
    """Recovery counts of one country as a long (date, cases) frame.

    Unpivots ``recovered_global`` and keeps the rows of ``country``.

    Returns a DataFrame with columns ``date`` and ``cases`` and a 0..n-1 index.
    """
    long_form = recovered_global.melt(id_vars=["country"], var_name="date", value_name="cases")
    selected = long_form.loc[long_form["country"] == country].drop(columns=["country"])
    selected.index = range(len(selected))
    return selected

In [None]:
def get_plot(time_series, name):
    """Bar chart of ``time_series`` (columns ``date`` / ``cases``).

    ``name`` chooses the colour: red if it contains "deaths", blue if it contains
    "cases", green otherwise (e.g. recoveries).
    """
    if "deaths" in name:
        color = "#f54842"
    elif "cases" in name:
        color = "#45a2ff"
    else:
        color = "#42f587"
    palette = [color] * len(time_series)
    return px.bar(time_series, x="date", y="cases", color_discrete_sequence=palette)

In [None]:
def plot_timeseries(country_name, func_name, title, n=-90, daily=False):
    """Bar chart of one country's series over its last ``-n`` rows.

    Parameters
    ----------
    country_name : str
        Country passed to ``func_name``.
    func_name : callable
        ``get_new_cases`` / ``get_new_deaths`` / ``get_new_recoveries``. Its ``str()`` is
        passed to ``get_plot``, which picks the bar colour from it.
    title : str
        Chart title; also used (case-insensitively) to pick the y-axis label.
    n : int
        Negative start of the row slice; -90 keeps the last 90 rows.
    daily : bool
        If True, plot day-over-day differences of the running totals.
    """
    series = func_name(country_name)
    if daily:
        # assign() returns a new frame, so there is no write into a slice of `series`
        # (the original assigned a column on series[1:], a chained-assignment pattern).
        series = series.assign(cases=series["cases"].diff()).iloc[1:]
    series = series.iloc[n:]

    fig = get_plot(series, str(func_name))

    # Case-insensitive match: a title of "Deaths" used to be labelled "new cases".
    lowered = title.lower()
    if "death" in lowered:
        metric = "deaths"
    elif "recover" in lowered:
        metric = "recoveries"
    else:
        metric = "new cases"

    fig.update_layout(
        template='plotly_dark',
        title=title,
        xaxis_title="Date",
        yaxis_title=f'Number of {metric}',
    )
    return fig

### Examples

#### US - Confirmed Cases

In [None]:
# Cumulative confirmed cases in the US over the last 90 days (default n=-90)
plot_timeseries("US", get_new_cases, "Confirmed Cases")

#### India - Recoveries (Daily: Last month)

In [None]:
# Daily (differenced) recoveries in India over the last 30 days
plot_timeseries("India", get_new_recoveries, "Recoveries",n = -30,daily = True)

## Inter-Country : Line Plot

### Code

In [None]:
def unpivot(df):
    """Turn a wide country-by-date table into long rows.

    The first column must be ``country``; every following column is a date.
    Returns columns ``country`` / ``variable`` (the date) / ``value``.
    """
    date_columns = df.columns[1:]
    long_form = df.melt(id_vars=["country"], value_vars=date_columns)
    return long_form

In [None]:
def compare(df, *args):
    """Unpivot ``df`` and keep only the rows of the countries named in ``args``."""
    long_form = unpivot(df)
    wanted = long_form["country"].isin(list(args))
    return long_form[wanted]

In [None]:
def create_data(df):
    """Keep every 5th date, counting back from the latest, and rename for display.

    ``df`` is long-form (``country`` / ``variable`` / ``value``), with ``variable`` holding
    "m/d/yy" date strings. Returns a new frame with columns ``Country`` / ``Date`` /
    ``Cases``. The input is not modified.
    """
    # "m/d/yy" strings do not sort chronologically as text ("10/1/20" < "9/30/20"), so
    # sort on the parsed date. This keeps the latest day and spaces the samples 5 days apart.
    dates = sorted(
        set(df["variable"]),
        key=lambda d: pd.to_datetime(d, format="%m/%d/%y"),
        reverse=True,
    )
    sampled = df[df["variable"].isin(dates[::5])]
    # rename() without inplace returns a new frame, avoiding an in-place write on a slice
    return sampled.rename(columns={"country": "Country", "variable": "Date", "value": "Cases"})

In [None]:
def static_line(df, *args):
    """Line chart of the running totals of the countries named in ``args``.

    The every-5th-date sampling is done here rather than through the module-level
    ``create_data``. That name is rebound later in this notebook with different
    signatures, which broke this function whenever it was re-run after those cells.
    """
    long_form = compare(df, *args)
    # "m/d/yy" strings do not sort chronologically as text, so sort on parsed dates
    dates = sorted(
        set(long_form["variable"]),
        key=lambda d: pd.to_datetime(d, format="%m/%d/%y"),
        reverse=True,
    )
    ff = long_form[long_form["variable"].isin(dates[::5])].rename(
        columns={"country": "Country", "variable": "Date", "value": "Cases"}
    )
    fig = px.line(ff, x="Date", y="Cases", color="Country", template="plotly_dark",
                  range_y=[0, ff["Cases"].max()])
    fig.layout.update(hovermode="x")
    return fig

### Example

#### Recoveries - India, New Zealand, US, Brazil

In [None]:
# Recoveries of four countries on one line chart, sampled every 5th date
static_line(recovered_global,"India","New Zealand","US","Brazil")

## Intra-Country : All 3 Studies

### Code

In [None]:
def line_comparison(country):
    """Plot confirmed, deaths and recovered series of one country on a single chart.

    Each study's row for ``country`` (minus the leading country cell) becomes one
    line, plotted against the date columns of ``confirmed_global``.
    """
    studies = (
        ("confirmed", confirmed_global),
        ("deaths", deaths_global),
        ("recovered", recovered_global),
    )

    whole_df = pd.DataFrame()
    whole_df["dates"] = list(confirmed_global.columns[1:])
    for label, table in studies:
        row = table.loc[table['country'] == country].values.flatten()[1:]
        whole_df[label] = list(row)

    fig = go.Figure()
    for label, _ in studies:
        fig.add_trace(
            go.Scatter(x=whole_df["dates"], y=whole_df[label], mode="lines", name=label)
        )

    fig.update_layout(
        height=500,
        showlegend=True,
        template="plotly_dark",
        title_text=f"Analysis of {country.title()}",
        hovermode='x',
    )
    return fig

### Example

#### India

In [None]:
# Confirmed, deaths and recovered curves for India on one chart
line_comparison("India")

# Data Visualization - Animations

## Top Ten Affected 

### Code

In [None]:
def unpivot(df):
    """Melt a wide country-by-date table into long ``country`` / ``variable`` / ``value`` rows.

    Same behaviour as the earlier ``unpivot`` in this notebook; presumably repeated so
    this section can be run on its own.
    """
    return pd.melt(df, id_vars=["country"], value_vars=list(df.columns[1:]))

In [None]:
def take_top10(df):
    """Keep only the ten countries with the largest value on the latest date.

    ``df`` is long-form (country / variable / value). The "latest date" is the
    ``variable`` of the row whose label is the last index label. For output of
    ``unpivot`` that is the last date column.
    """
    latest_date = df["variable"][df.index[-1]]
    on_latest = df.loc[df["variable"] == latest_date]
    ranked = on_latest.sort_values(by=["value"], ascending=False)
    top_countries = ranked["country"].head(10).tolist()
    return df.loc[df["country"].isin(top_countries)]

In [None]:
def create_data(df):
    """Top-10 countries of long-form ``df``, every 5th date back from the latest, display names.

    Returns a new frame with columns ``Country`` / ``Date`` / ``Cases``.

    NOTE(review): this re-defines the earlier ``create_data`` (which has no top-10 filter),
    and a later cell rebinds the name again. Callers that need a stable helper should not
    rely on the global name.
    """
    new = take_top10(df)
    # "m/d/yy" strings do not sort chronologically as text ("10/1/20" < "9/30/20"),
    # so sort on the parsed date to keep the latest day and a regular 5-day spacing.
    dates = sorted(
        set(new["variable"]),
        key=lambda d: pd.to_datetime(d, format="%m/%d/%y"),
        reverse=True,
    )
    sampled = new[new["variable"].isin(dates[::5])]
    # rename() without inplace avoids modifying a slice of the caller's data
    return sampled.rename(columns={"country": "Country", "variable": "Date", "value": "Cases"})

In [None]:
def plot_fig(ff):
    """Animated bar chart (one frame per ``Date``) of ``Cases`` per ``Country``, legend hidden."""
    y_range = [0, ff["Cases"].max()]
    figure = px.bar(
        ff,
        x="Country",
        y="Cases",
        color="Country",
        template="plotly_dark",
        animation_frame="Date",
        animation_group="Country",
        range_y=y_range,
    )
    figure.layout.update(showlegend=False)
    return figure

In [None]:
def animated_barchart(df):
    """Animated bar chart of the ten countries with the most cases on the latest date.

    ``df`` is a wide country-by-date table. The every-5th-date sampling is done here
    instead of through the module-level ``create_data``. That name is later rebound to the
    map helper ``create_data(df, study, color)``, which broke this function on re-run.
    The original also applied the top-10 filter twice.
    """
    top = take_top10(unpivot(df))
    # "m/d/yy" strings do not sort chronologically as text, so sort on parsed dates
    dates = sorted(
        set(top["variable"]),
        key=lambda d: pd.to_datetime(d, format="%m/%d/%y"),
        reverse=True,
    )
    ff = top[top["variable"].isin(dates[::5])].rename(
        columns={"country": "Country", "variable": "Date", "value": "Cases"}
    )
    return plot_fig(ff)

### Examples

#### Confirmed Cases

In [None]:
# Top-10 countries by confirmed cases, animated over time
animated_barchart(confirmed_global)

#### Deaths

In [None]:
# Top-10 countries by deaths, animated over time
animated_barchart(deaths_global)

#### Recoveries 

In [None]:
# Top-10 countries by recoveries, animated over time
animated_barchart(recovered_global)

## Comparison (User's Choice)

### Code

In [None]:
def compare(df, *args):
    """Long-form rows of ``df`` restricted to the countries listed in ``args``.

    Same behaviour as the earlier ``compare`` in this notebook.
    """
    long_form = unpivot(df)
    return long_form.loc[long_form["country"].isin(args)]

In [None]:
def plot_fig_compare(ff):
    """Animated bar chart of ``Cases`` per ``Country`` with a shared x hover, legend shown."""
    y_range = [0, ff["Cases"].max()]
    figure = px.bar(
        ff,
        x="Country",
        y="Cases",
        color="Country",
        template="plotly_dark",
        animation_frame="Date",
        animation_group="Country",
        range_y=y_range,
    )
    figure.layout.update(hovermode="x")
    return figure

In [None]:
def create_comparison_animation(df, *args):
    """Animated bar chart comparing the countries named in ``args``.

    The every-5th-date sampling is done here rather than through the module-level
    ``create_data``. That name is rebound in this notebook: first to a top-10 variant,
    which silently truncated comparisons of more than ten countries, and then to the
    map helper ``create_data(df, study, color)``, which broke this function on re-run.
    """
    long_form = compare(df, *args)
    # "m/d/yy" strings do not sort chronologically as text, so sort on parsed dates
    dates = sorted(
        set(long_form["variable"]),
        key=lambda d: pd.to_datetime(d, format="%m/%d/%y"),
        reverse=True,
    )
    ff = long_form[long_form["variable"].isin(dates[::5])].rename(
        columns={"country": "Country", "variable": "Date", "value": "Cases"}
    )
    return plot_fig_compare(ff)

### Examples

#### Confirmed Cases - India, US, Australia, Brazil

In [None]:
# Confirmed cases of four chosen countries, animated over time
create_comparison_animation(confirmed_global,"India","US","Australia","Brazil")

#### Recoveries - India, US, Australia, Brazil

In [None]:
# Recoveries of the same four countries
create_comparison_animation(recovered_global,"India","US","Australia","Brazil")

#### Deaths - India, US, Brazil

In [None]:
# Deaths of three chosen countries
create_comparison_animation(deaths_global,"India","US","Brazil")

# Data Visualization - Choropleths

## Code

### Setting Credentials (Mapbox and Chart Studio)

In [None]:
# Set your credentials before running this cell!!
# Credentials come from the environment (CHART_STUDIO_USERNAME, CHART_STUDIO_API_KEY,
# MAPBOX_ACCESS_TOKEN) or are prompted for, so no secrets are stored in the notebook.
# NOTE(review): the username / API key / token previously hard-coded here were committed in
# plain text; treat them as leaked and rotate them.
import os
from getpass import getpass

chart_studio.tools.set_credentials_file(
    username=os.environ.get("CHART_STUDIO_USERNAME") or input("Chart Studio username: "),
    api_key=os.environ.get("CHART_STUDIO_API_KEY") or getpass("Chart Studio API key: "),
)
# Read by create_basic_layout when building the Mapbox layout
mapbox_access_token = os.environ.get("MAPBOX_ACCESS_TOKEN") or getpass("Mapbox access token: ")

### Formatting Data

In [None]:
def chainer(s):
    """Split each comma-joined string of Series ``s`` and flatten all pieces into one list."""
    pieces = []
    for parts in s.str.split(","):
        pieces.extend(parts)
    return pieces

In [None]:
def convert_df(df, cols):
    """Reshape one-row-per-place data into one row per (place, study) for the maps.

    Parameters
    ----------
    df : DataFrame
        Wide table with a name column, one numeric column per study, and coordinates.
    cols : list
        ``[name_col, [study_col, ...], lat_col, lon_col]``.

    Returns a new DataFrame with columns ``Country``, ``Lat``, ``Long_``, ``Count`` (the
    value as a string) and ``Study``. Each input row expands to ``len(cols[1])``
    consecutive rows, in the order of ``cols[1]``. Rows with any NaN are dropped first.

    Unlike the original, ``df`` itself is not modified. It used to lose its NaN rows and
    gain a helper ``New`` column in place. The discarded ``set_index`` call was a no-op
    and is gone.
    """
    name_col, study_cols = cols[0], list(cols[1])
    lat_col, lon_col = cols[-2], cols[-1]

    df = df.dropna()  # new frame; the caller's data is left untouched
    n_studies = len(study_cols)

    # Row-major: all studies of the first place, then all studies of the next place, ...
    counts = [
        str(value)
        for row in zip(*(df[col].to_numpy() for col in study_cols))
        for value in row
    ]

    return pd.DataFrame(
        {
            "Country": np.repeat(df[name_col], n_studies),
            "Lat": np.repeat(df[lat_col], n_studies),
            "Long_": np.repeat(df[lon_col], n_studies),
            "Count": counts,
            "Study": study_cols * len(df),
        }
    )

### Creating Trace

In [None]:
def create_hovertemplate(df, study, country):
    """Hover text ``"<emoji>: <count with thousands separators>"`` for one country/study.

    ``df`` has columns Study, Country and Count. Count may be a numeric string such as
    "1234" or "1234.0".

    Raises ValueError unless exactly one row matches ``study`` and ``country``.
    """
    key = study.lower()
    if key == "deaths":
        emoji = "💀"
    elif key in ("recovered", "recoveries"):
        # the country-level maps call the study "Recoveries"; it used to get 🏥
        emoji = "😷"
    else:
        emoji = "🏥"

    matches = df.loc[(df['Study'] == study) & (df['Country'] == country), 'Count']
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one {study!r} row for {country!r}, found {len(matches)}"
        )
    # .iloc[0] instead of float(<1-row Series>), which is deprecated in pandas >= 2.1
    count = int(float(matches.iloc[0]))
    return f"{emoji}: {format(count, ',d')}"

In [None]:
def create_data(df, study, color):
    """Build one ``scattermapbox`` trace dict per country for ``study``.

    Parameters
    ----------
    df : DataFrame
        Long frame from ``convert_df`` with columns Country, Lat, Long_, Count, Study.
    study : str
        Value of the ``Study`` column to draw (e.g. "confirmed", "Recoveries").
    color : str
        Marker colour.

    Marker size is ``log(count, 1.5)``. A country is skipped when it does not have exactly
    one row for ``study``, or when its count cannot be log-scaled (non-numeric, zero or
    negative). These are the cases the original bare ``except`` swallowed; other errors
    now propagate.

    NOTE(review): this rebinds the name ``create_data`` that the chart helpers defined
    earlier in the notebook call with a single argument.
    """
    countries = list(df["Country"].value_counts().index)
    # Drop incomplete rows on a new frame instead of mutating the caller's data
    df = df.dropna()
    data = []

    for country in countries:
        rows = df.loc[(df["Study"] == study) & (df["Country"] == country)]
        if len(rows) != 1:
            continue
        try:
            size = log(float(rows["Count"].iloc[0]), 1.5)
            hovertemplate = create_hovertemplate(df, study, country)
        except (ValueError, OverflowError):
            # math domain error for counts <= 0, or a count that is not a number
            continue
        data.append(
            dict(
                lat=rows["Lat"],
                lon=rows["Long_"],
                name=f"{country}",
                marker={"size": size, "opacity": 0.5, "color": color},
                type="scattermapbox",
                hovertemplate=hovertemplate,
            )
        )

    return data

### Creating Layout

In [None]:
def create_basic_layout(latitude, longitude, zoom):
    """Dark, legend-less 700px Mapbox layout centred on (latitude, longitude) at ``zoom``.

    Uses the notebook-level ``mapbox_access_token``.
    """
    mapbox = dict(
        accesstoken=mapbox_access_token,
        bearing=0,
        center=dict(lat=latitude, lon=longitude),
        pitch=0,
        zoom=zoom,
        style="dark",
    )
    return dict(
        height=700,
        margin=dict(t=0, b=0, l=0, r=0),
        font=dict(color="#FFFFFF", size=15),
        paper_bgcolor="#000000",
        showlegend=False,
        mapbox=mapbox,
    )

In [None]:
def update_layout(study, layout):
    """Add a title, a bottom-left caption and hover-label styling to ``layout``.

    ``layout`` is modified in place and also returned.
    """
    label = study.capitalize()
    caption = dict(
        text=f"{label} Cases",
        font=dict(color="#FFFFFF", size=14),
        borderpad=10,
        x=0.05,
        y=0.05,
        xref="paper",
        yref="paper",
        align="left",
        showarrow=False,
        bgcolor="black",
    )
    layout["title"] = label
    layout["annotations"] = [caption]
    layout["hoverlabel"] = dict(font_size=16, font_family="Rockwell", font_color="black")
    return layout

In [None]:
def get_lat_long(country, coord_df=None):
    """Return ``(lat, long)`` of ``country`` from a table with ``country``/``Lat``/``Long_``.

    ``coord_df`` defaults to the notebook-level ``country_cases_sorted``, looked up at call
    time. The original default was bound once, when the defining cell ran.

    Raises ValueError if the country is missing or appears more than once.
    """
    if coord_df is None:
        coord_df = country_cases_sorted
    rows = coord_df.loc[coord_df["country"] == country]
    if len(rows) != 1:
        raise ValueError(f"expected exactly one row for {country!r}, found {len(rows)}")
    # .iloc[0] instead of float(<1-row Series>), which is deprecated in pandas >= 2.1
    return float(rows["Lat"].iloc[0]), float(rows["Long_"].iloc[0])

### Getting Data For Country Plot

In [None]:
def get_country_wise_data(timeout=30):
    """Fetch per-province records (JSON list) from the ``/v2/jhucsse`` endpoint.

    ``timeout`` (seconds) keeps the cell from hanging on an unresponsive server.
    Raises ``requests.HTTPError`` on a non-2xx response instead of trying to parse
    an error page as JSON.
    """
    response = requests.get("https://corona.lmao.ninja/v2/jhucsse", timeout=timeout)
    response.raise_for_status()
    return response.json()

In [None]:
def choose_country(array, country):
    """Records of ``array`` (a list of dicts) whose ``"country"`` equals ``country``."""
    return list(filter(lambda record: record["country"] == country, array))

In [None]:
def get_country_frame(country):
    """Province-level table built from one country's records (e.g. from ``choose_country``).

    Each record needs ``province``, ``coordinates`` (``latitude``/``longitude``) and ``stats``
    (``confirmed``/``recovered``/``deaths``). Rows whose province is "Unknown" are dropped.

    Returns a DataFrame with columns Provinces, lat, lon, Confirmed, Recoveries, Deaths.
    """
    coordinates = [record["coordinates"] for record in country]
    statistics = [record["stats"] for record in country]
    provinces = [record["province"] for record in country]

    frame = pd.DataFrame()
    frame["Provinces"] = provinces
    frame["lat"] = [c["latitude"] for c in coordinates]
    frame["lon"] = [c["longitude"] for c in coordinates]
    frame["Confirmed"] = [s["confirmed"] for s in statistics]
    frame["Recoveries"] = [s["recovered"] for s in statistics]
    frame["Deaths"] = [s["deaths"] for s in statistics]
    return frame[frame["Provinces"] != "Unknown"]

### Creating Figure Object

In [None]:
def interactive_map(data, layout):
    """Bundle traces and a layout into the plain figure dict passed to ``py.iplot``."""
    return dict(data=data, layout=layout)

### Final Function

#### Global Plot Function

In [None]:
def plot_study(
    starting_df,
    cols,
    study_dict,
    location="global",
    zoom=2,
    latitude=20.59,
    longitude=78.96,
):
    """Map figure dict for one study (``study_dict = {"study": ..., "color": ...}``).

    ``starting_df`` and ``cols`` are passed to ``convert_df``. The map is centred on
    (latitude, longitude) at ``zoom``.
    NOTE(review): ``location`` is accepted but never used.
    """
    marker_color, study_name = study_dict["color"], study_dict["study"]
    traces = create_data(convert_df(starting_df, cols), study_name, marker_color)
    layout = update_layout(study_name, create_basic_layout(latitude, longitude, zoom))
    return interactive_map(traces, layout)

#### Country Plot Function

In [None]:
def plot_country(Country, data, study):
    """Province-level map of one study ("Confirmed", "Recoveries" or "Deaths") for a country.

    ``data`` is the record list from ``get_country_wise_data``. The map is centred on the
    country's coordinates from ``get_lat_long``.
    """
    records = choose_country(data, Country)
    provinces = get_country_frame(records)
    columns = ["Provinces", ["Confirmed", "Recoveries", "Deaths"], "lat", "lon"]

    if study == "Confirmed":
        color = "#45a2ff"
    elif study == "Deaths":
        color = "#f54842"
    else:
        color = "#42f587"
    spec = dict(study=study.title(), color=color)

    center_lat, center_lon = get_lat_long(Country)
    return plot_study(
        provinces,
        columns,
        spec,
        records,
        zoom=4.5,
        latitude=center_lat,
        longitude=center_lon,
    )

## Examples

### Global - Confirmed Cases

In [None]:
# Study specs: the "study" value must match a column listed in `columns` below
confirmed = dict(study="confirmed",color="#45a2ff")
recovered = dict(study="recovered",color="#42f587")
deaths = dict(study="deaths",color="#f54842")

# [name column, study columns, latitude column, longitude column] of country_cases_sorted
columns = ["country", ["deaths", "confirmed", "recovered"], "Lat", "Long_"]

figure = plot_study(country_cases_sorted, columns, confirmed)
# Renders through Chart Studio, so the credentials cell must have been run
py.iplot(figure)

### Japan - Recoveries Cases

In [None]:
# Province-level recoveries for Japan, using live data from the jhucsse endpoint
figure= plot_country("Japan",get_country_wise_data(),"Recoveries")
py.iplot(figure)

# Working With The Latest Data - Of Individual Countries

## Getting Data

In [None]:
def get_today_data(timeout=30):
    """Fetch the global summary (``/v2/all?yesterday``) and the per-province records (``/v2/jhucsse``).

    Returns ``(today_data, today_country_data)`` as parsed JSON.
    ``timeout`` (seconds) keeps the cell from hanging. Non-2xx responses raise
    ``requests.HTTPError`` instead of failing later while parsing an error page.
    """
    today_data = requests.get("https://corona.lmao.ninja/v2/all?yesterday", timeout=timeout)
    today_country_data = requests.get("https://corona.lmao.ninja/v2/jhucsse", timeout=timeout)

    today_data.raise_for_status()
    today_country_data.raise_for_status()

    return today_data.json(), today_country_data.json()

## Formatting Data 

In [None]:
def cases_object(array):
    """Sum confirmed/deaths/recovered over the records and collect their ``updatedAt`` stamps."""
    totals = {}
    for study in ("confirmed", "deaths", "recovered"):
        totals[study] = sum(record["stats"][study] for record in array)
    totals["updatedAt"] = [record["updatedAt"] for record in array]
    return totals

In [None]:
def choose_country(array, country):
    """Filter ``array`` down to the entries whose ``"country"`` key equals ``country``."""
    matches = []
    for entry in array:
        if entry["country"] == country:
            matches.append(entry)
    return matches

In [None]:
def get_final_object(country, array):
    """Summed stats and update stamps over every record of ``country`` in ``array``."""
    records = choose_country(array, country)
    summary = cases_object(records)
    return summary

In [None]:
def get_country_frame(country):
    """Province table for a list of country records; rows for province "Unknown" are dropped.

    Same behaviour as the earlier ``get_country_frame`` in this notebook.
    """
    fields = (
        ("Provinces", lambda rec: rec["province"]),
        ("lat", lambda rec: rec["coordinates"]["latitude"]),
        ("lon", lambda rec: rec["coordinates"]["longitude"]),
        ("Confirmed", lambda rec: rec["stats"]["confirmed"]),
        ("Recoveries", lambda rec: rec["stats"]["recovered"]),
        ("Deaths", lambda rec: rec["stats"]["deaths"]),
    )
    frame = pd.DataFrame()
    for column, pick in fields:
        frame[column] = [pick(rec) for rec in country]
    return frame[frame["Provinces"] != "Unknown"]

## Visualizing The Data

In [None]:
# Live global summary + per-province records, then India's province table
today_data,today_country_data = get_today_data()
country_stats = get_country_frame(choose_country(today_country_data, "India"))

### Bar Chart

In [None]:
def plot_province(data, metric, metric_name):
    """Bar chart of ``data[metric]`` per province.

    ``metric_name`` labels the y-axis and the bar trace. It used to be accepted but
    ignored, so every chart was labelled just "Cases".
    """
    fig = go.Figure()

    fig.add_trace(
        go.Bar(x=data["Provinces"], y=data[metric], name=metric_name)
    )

    fig.update_layout(
        title={
            "text": "Province Details",
            "y": 0.9,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        },
        template="plotly_dark",
        xaxis_title="Province",
        yaxis_title=metric_name,
    )

    return fig



#### Example - India

In [None]:
# Confirmed cases per Indian province
plot_province(country_stats, "Confirmed", "Confirmed Cases")

### Table Form

In [None]:
def table_province_data(data, metric):
    """Two-column table (Provinces, ``metric``) with the metric formatted as "1,234".

    Returns None when the table has at most one row (nothing worth tabulating).
    """
    table = pd.DataFrame(data={"Provinces": data["Provinces"], metric: data[metric]})
    table[metric] = table[metric].map("{:,d}".format)
    return None if len(table) <= 1 else table

#### Example - India

In [None]:
# Confirmed cases per Indian province as a formatted table
table_province_data(country_stats, "Confirmed")

Unnamed: 0,Provinces,Confirmed
0,Andaman and Nicobar Islands,4929
1,Andhra Pradesh,881273
2,Arunachal Pradesh,16696
3,Assam,215997
4,Bihar,250390
5,Chandigarh,19551
6,Chhattisgarh,276337
7,Dadra and Nagar Haveli and Daman and Diu,3374
8,Delhi,623415
9,Goa,50772


# Predictive TimeSeries Model : CNN

## Code

### Formatting The Data

In [None]:
def get_data(confirmed=None, deaths=None, recovered=None):
    """Per-country series as date-indexed frames (one column per country).

    Each argument is a wide table with a ``country`` column and one "m/d/yy" column per
    date. Omitted arguments default to ``confirmed_global`` / ``deaths_global`` /
    ``recovered_global``, looked up at call time. The original defaults were bound once,
    when the cell ran.

    Returns ``(deaths, recovered, confirmed)`` in that order.
    """
    if confirmed is None:
        confirmed = confirmed_global
    if deaths is None:
        deaths = deaths_global
    if recovered is None:
        recovered = recovered_global

    def _by_date(table):
        # Rows become dates, columns become countries
        wide = table.groupby("country").sum().T
        # Explicit format: infer_datetime_format is deprecated since pandas 2.0
        wide.index = pd.to_datetime(wide.index, format="%m/%d/%y")
        return wide

    return _by_date(deaths), _by_date(recovered), _by_date(confirmed)

In [None]:
def create_data_frame(dataframe, country):
    """Running-total series of one country and its day-over-day differences.

    Parameters
    ----------
    dataframe : str
        "deaths", "recovered" or "confirmed".
    country : str
        Column of the ``get_data`` frames to use.

    Returns ``(data, data_diff)``.
    - ``data`` is a date-indexed frame with a single ``Total`` column, without days whose
      total is 0.
    - ``data_diff`` holds its first differences, with the leading NaN row dropped.

    Raises ValueError for an unknown ``dataframe`` name. This used to fall through to an
    UnboundLocalError.
    """
    deaths, recovered, confirmed = get_data()
    tables = {"deaths": deaths, "recovered": recovered, "confirmed": confirmed}
    if dataframe not in tables:
        raise ValueError(f"dataframe must be one of {sorted(tables)}, got {dataframe!r}")

    table = tables[dataframe]
    data = pd.DataFrame(index=table.index, data=table[country].values, columns=["Total"])

    # drop every day whose total is exactly 0
    data = data[(data != 0).all(axis=1)]

    # removing the first value from data_diff as it had no previous value and is a NaN after differencing
    data_diff = data.diff()[1:]

    return data, data_diff

### Series Creation

In [None]:
def make_series(df_name, country, steps):
    """Sliding-window supervised samples from a country's daily differences.

    Each sample is ``steps`` consecutive daily differences (X) and the difference that
    follows them (y).

    Returns ``(data, data_diff, X, y)``: the frames from ``create_data_frame`` plus
    ``X`` of shape (samples, steps) and ``y`` of shape (samples,).
    """
    data, data_diff = create_data_frame(df_name, country)

    series = np.array(data_diff["Total"])
    starts = range(len(series) - steps)

    windows = [series[start:start + steps] for start in starts]
    targets = [series[start + steps] for start in starts]

    return data, data_diff, np.array(windows), np.array(targets)

### Error : MASE

In [None]:
def mase(y_true, y_pred):
    """Mean absolute scaled error of ``y_pred`` against ``y_true``, computed by TSErrors.FindErrors."""
    return FindErrors(y_true, y_pred).mase()

### Parameter Grid

In [None]:
def create_param_grid():
    """Hyperparameter grid searched by ``hyperparameter_tuning``.

    2 filter counts x 2 dense sizes x 2 epoch counts x 3 x 3 activations = 72 combinations.
    """
    sizes = (60, 70)
    activations = ("swish", "relu", "tanh")
    return ParameterGrid(
        {
            "filters": sizes,
            "nodes": sizes,
            "epochs": sizes,
            "activation1": activations,
            "activation2": activations,
        }
    )

### Compiling The Model

In [None]:

def compile_model(p, steps=14):
    """Build and compile the 1-D CNN regressor.

    Parameters
    ----------
    p : dict
        One entry of ``create_param_grid()``. Uses ``filters``, ``nodes``, ``activation1``
        and ``activation2``; ``epochs`` is read by the callers when fitting.
    steps : int
        Input window length. It must match the window used to build X (14 throughout this
        notebook); it used to be hard-coded.

    Returns a Sequential model mapping a (steps, 1) window to one value, compiled with
    the Adam optimizer and MSE loss.
    """
    model = Sequential()
    model.add(
        Conv1D(
            filters=p["filters"],
            kernel_size=2,
            activation=p["activation1"],
            input_shape=(steps, 1),
        )
    )
    model.add(MaxPooling1D(pool_size=2))
    model.add(Flatten())
    model.add(Dense(p["nodes"], activation=p["activation2"]))
    model.add(Dense(1))
    model.compile(optimizer="adam", loss="mse")

    return model

### Hyperparameter Tuning

In [None]:
def hyperparameter_tuning(grid, X_train, y_train):
    """Fit one model per parameter set in ``grid`` and score it with MASE.

    Returns a DataFrame with columns ``MASE`` and ``Parameters``, one row per grid entry.

    NOTE(review): scoring is in-sample (fit and predict on X_train), which favours models
    that overfit. A held-out validation slice would be a sounder selection criterion.
    """
    # reshape once to the (samples, steps, 1) input shape the Conv1D layer expects
    X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))

    records = []
    for p in grid:
        model = compile_model(p)
        model.fit(X_train, y_train, epochs=p["epochs"], verbose=0)

        # flattening the predictions to a 1D array to calculate the MASE
        predictions = model.predict(X_train, verbose=0).flatten()

        records.append({"MASE": mase(y_train, predictions), "Parameters": p})

    # DataFrame.append was removed in pandas 2.0; build the frame once from the records
    return pd.DataFrame(records, columns=["MASE", "Parameters"])

In [None]:
def get_best_params(parameters):
    """Parameter dict of the row with the lowest MASE.

    ``parameters`` is the frame from ``hyperparameter_tuning`` (columns MASE, Parameters).
    Raises ValueError if it is empty.
    """
    if parameters.empty:
        raise ValueError("no tuning results to choose from")

    # sort the dataframe based on MASE values
    best_row = parameters.sort_values("MASE").iloc[0]

    # Select by column label. The old `.values[2]` relied on reset_index() inserting
    # an extra column in front of MASE and Parameters.
    return best_row["Parameters"]

### Testing The Model

In [None]:
def test_model(p, X_train, X_test, y_train, y_test, data):
    """Fit a CNN with parameters ``p`` on the training windows and score it on the test windows.

    The model predicts daily differences. These are accumulated on top of the last total
    before the test period, so the score compares running totals.

    Returns the MASE between the actual and predicted running totals.
    Raises ValueError if the test split is empty.
    """
    n_test = len(y_test)
    if n_test == 0:
        # data["Total"][-0:] would be the whole series, producing a meaningless score
        raise ValueError("test split is empty; not enough data for a train/test split")

    model = compile_model(p)

    # reshaping the sets to the (samples, steps, 1) input shape the Conv1D layer expects
    X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))
    X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))

    model.fit(X_train, y_train, epochs=p["epochs"], verbose=0)

    # predicting results of X_test
    predictions = model.predict(X_test, verbose=0).flatten()

    totals = data["Total"]
    # Start is the value just before the test set. .iloc is needed because integer keys
    # on a DatetimeIndex Series are a deprecated positional fallback in pandas 2.x.
    start = totals.iloc[-n_test - 1]
    predictions_cumulative = []
    for step in predictions:
        start = start + step
        predictions_cumulative.append(start)

    # The actual running totals over the test period
    y_test_cumulative = totals.iloc[-n_test:]

    return mase(y_test_cumulative, predictions_cumulative)

### Fitting The Final Model

In [None]:
def make_final_model(p, X, y):
    """Fit a fresh CNN with parameters ``p`` on all windows ``X`` and targets ``y``."""
    samples, width = X.shape[0], X.shape[1]
    model = compile_model(p)
    # Conv1D expects (samples, steps, 1)
    model.fit(X.reshape((samples, width, 1)), y, epochs=p["epochs"], verbose=0)
    return model

### Forecasting The Next 14 Days

In [None]:
def forecast(data_diff, data, n, model, steps=14):
    """Recursively forecast ``n`` future days and return them as running totals.

    At each step the model receives the latest ``steps`` daily differences. These are
    real values from ``data_diff["Total"]`` topped up with the forecasts made so far.
    The model's single predicted value is appended. The predicted differences are then
    accumulated on top of the last observed total ``data["Total"].iloc[-1]``.

    ``steps`` must equal the model's input length (14 for ``compile_model``'s default).
    The original used ``n`` as the window length, so any horizon other than 14 failed
    in ``reshape``.

    Raises ValueError if ``steps`` < 1 or there are fewer than ``steps`` observed differences.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    history = list(data_diff["Total"])
    if len(history) < steps:
        raise ValueError(f"need at least {steps} daily differences, got {len(history)}")

    predicted = []
    for _ in range(n):
        window = (history[-steps:] + predicted)[-steps:]
        inp = np.array(window).reshape(1, steps, 1)
        future = model.predict(inp, verbose=0)
        predicted.append(list(future.flatten())[0])

    forecast_cumulative = []
    # .iloc: integer keys on a DatetimeIndex Series are a deprecated positional fallback
    start = data["Total"].iloc[-1]
    for step in predicted:
        start = start + step
        forecast_cumulative.append(start)

    return forecast_cumulative

### Plotting The Forecast

In [None]:
def plot_graph(data, pred):
    """Plot the observed ``data["Total"]`` followed by the forecast ``pred``.

    Forecast dates start the day after the last observed date, with one date per
    prediction. This used to be fixed at 14 dates regardless of ``len(pred)``.
    """
    datelist = pd.date_range(data.index[-1], periods=len(pred) + 1).tolist()[1:]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=data.index, y=data["Total"],
                   mode="lines", name="Up till now")
    )
    fig.add_trace(go.Scatter(x=datelist, y=pred,
                             mode="lines", name="Predictions*"))
    fig.update_layout(template="plotly_dark")

    return fig

### Flatline Check - Naive Forecast

In [None]:
def check_slope(x, y):
    """Return True unless the most frequent slope between consecutive points is exactly 0.

    Used as a flat-line test on the latest totals. When several slopes are equally frequent,
    the one that occurs first wins. With fewer than two points there are no slopes, and the
    result is True.
    """
    slopes = np.diff(y) / np.diff(x)
    ranked = Counter(slopes).most_common(1)
    if not ranked:
        return True
    dominant_slope = ranked[0][0]
    return bool(dominant_slope != 0)

In [None]:
def naive_forecast(study, country):
    """Flat 14-day forecast that repeats the last observed total.

    Returns ``(1, fig, predictions)``:
    - ``1`` is a fixed placeholder in the MASE slot (no error is measured here).
    - ``fig`` plots the history and the flat forecast.
    - ``predictions`` is a list of 14 copies of the last total.
    """
    df, _ = create_data_frame(study, country)
    datelist = pd.date_range(df.index[-1], periods=15).tolist()[1:]
    # .iloc: integer keys on a DatetimeIndex Series are a deprecated positional fallback
    predictions = [df["Total"].iloc[-1]] * 14

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=df.index, y=df["Total"], mode="lines", name="Up till now")
    )
    fig.add_trace(
        go.Scatter(x=datelist, y=predictions,
                   mode="lines", name="Predictions*")
    )
    fig.update_layout(template="plotly_dark")
    return 1, fig, predictions

### Final Function

In [None]:
def cnn_predict(df_name, country):
    """Forecast the next 14 days of ``df_name`` ("confirmed"/"deaths"/"recovered") for ``country``.

    Steps:
    1. Tune the CNN on the first 85% of the windowed samples.
    2. Score it (MASE on running totals) on the remaining samples.
    3. Use the CNN for the forecast if MASE <= 1, or if the most frequent slope of the
       last five totals is non-zero. Otherwise fall back to ``naive_forecast``.

    Returns ``(predictions, MASE, fig)``. ``predictions`` holds the first 7 forecast
    days, with columns Date ("dd/mm/YYYY") and Cases.
    """
    steps = 14
    data, data_diff, X, y = make_series(df_name, country, steps)
    grid = create_param_grid()

    # Split on the number of samples. X has len(data_diff) - steps rows, so splitting
    # at 85% of len(data_diff) left a test set far smaller than 15%, and an empty one
    # for short series.
    n = len(X) * 17 // 20
    X_train, X_test, y_train, y_test = X[:n], X[n:], y[:n], y[n:]

    parameters = hyperparameter_tuning(grid, X_train, y_train)
    p = get_best_params(parameters)
    MASE = (test_model(p, X_train, X_test, y_train, y_test, data)).round(2)

    if MASE <= 1 or check_slope([1, 2, 3, 4, 5], data["Total"].iloc[-5:]):
        cnn = make_final_model(p, X, y)
        f = forecast(data_diff, data, 14, cnn)
        f = list(map(int, f))
        fig = plot_graph(data, f)
    else:
        MASE, fig, f = naive_forecast(df_name, country)

    datelist = pd.date_range(data.index[-1], periods=8).tolist()[1:]
    predictions = pd.DataFrame(
        data={"Date": list(map(lambda x: x.strftime('%d/%m/%Y'), datelist)), "Cases": f[:7]})

    return predictions, MASE, fig

## Examples

### India - Confirmed Cases

In [None]:
# Runs the full 72-combination grid search, then a final fit, so this cell is slow
pred,_,figure = cnn_predict("confirmed","India")

In [None]:
# First 7 forecast days
pred

Unnamed: 0,Date,Cases
0,29/12/2020,10244919
1,30/12/2020,10265629
2,31/12/2020,10285156
3,01/01/2021,10304596
4,02/01/2021,10322177
5,03/01/2021,10338294
6,04/01/2021,10354309


In [None]:
# History plus 14-day forecast for India's confirmed cases
figure.show()

### US - Deaths

In [None]:
# Same pipeline for US deaths (slow: full grid search)
pred,_,figure = cnn_predict("deaths","US")

In [None]:
# First 7 forecast days
pred

Unnamed: 0,Date,Cases
0,29/12/2020,337320
1,30/12/2020,340197
2,31/12/2020,342742
3,01/01/2021,344414
4,02/01/2021,345597
5,03/01/2021,346561
6,04/01/2021,348086


In [None]:
# History plus 14-day forecast for US deaths
figure.show()

### Japan - Recoveries

In [None]:
# Same pipeline for Japan's recoveries (slow: full grid search)
pred,_,figure = cnn_predict("recovered","Japan")

In [None]:
# First 7 forecast days
pred

Unnamed: 0,Date,Cases
0,29/12/2020,187066
1,30/12/2020,189929
2,31/12/2020,192972
3,01/01/2021,196288
4,02/01/2021,199424
5,03/01/2021,202356
6,04/01/2021,205179


In [None]:
# History plus 14-day forecast for Japan's recoveries
figure.show()