## Creating a table for forecasting

In [0]:
from pyspark.sql.types import DoubleType, StringType, StructType, StructField, IntegerType

# Explicit schema for the county-level CSV: every column is nullable; `date`,
# `county` and `state` are read as strings and the remaining columns as doubles.
schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("date", StringType()),
        ("county", StringType()),
        ("state", StringType()),
        ("fips", DoubleType()),
        ("cases", DoubleType()),
        ("deaths", DoubleType()),
    )
])

# Read the CSV (first row is the header) using the schema above instead of
# letting Spark infer the column types.
covid_df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .schema(schema)
    .load("/databricks-datasets/COVID/covid-19-data/us-counties.csv")
)



In [0]:
# Render the DataFrame loaded from us-counties.csv for a quick visual check.
display(covid_df)

Writing the Spark DataFrame into a table

In [0]:
# Persist the county-level data as the managed table `covid_t`.
# The default save mode raises if the table already exists, which made this
# cell fail on every re-run of the notebook. Overwriting keeps the cell
# idempotent under "Restart & Run All".
covid_df.write.mode("overwrite").saveAsTable("covid_t")

In [0]:
%sql
-- Read back every row of the covid_t table written by the previous cell.
SELECT * FROM covid_t

In [0]:

import pyspark.pandas as ps
# Load the COVID data into a pandas-on-Spark DataFrame for the forecasting steps below.
# NOTE(review): the path points at the dataset directory, not a single CSV file, so
# presumably every file found there is parsed — confirm this is intended rather than
# the `us-counties.csv` file that `covid_df` was loaded from.
df = ps.read_csv("/databricks-datasets/COVID/covid-19-data")
# Parse `date` into datetimes; errors='coerce' yields missing values for
# unparseable entries instead of raising.
df["date"] = ps.to_datetime(df['date'], errors='coerce')
# Cast the case counts to integers.
df["cases"] = df["cases"].astype(int)
display(df)

In [0]:
import databricks.automl
import logging

In [0]:
# Raise the py4j logger threshold to WARNING so its lower-level messages do
# not flood the cell output.
logging.getLogger("py4j").setLevel(logging.WARNING)

# Run the AutoML forecasting experiment on the pandas-on-Spark frame `df`.
# The following cells read `summary` (`summary.output_table_name`,
# `summary.best_trial.mlflow_run_id`), but the notebook never assigned it, so a
# fresh "Run All" failed with NameError. The call below restores that step.
summary = databricks.automl.forecast(
    df,
    target_col="cases",          # value to forecast
    time_col="date",             # timestamp column parsed earlier in the notebook
    horizon=30,                  # number of periods to forecast ...
    frequency="d",               # ... where one period is a day
    primary_metric="mdape",
    # Needed so the predictions are saved to a table and
    # `summary.output_table_name` is populated for the cells below.
    output_database="default",
)

In [0]:
# Show the text representation of the AutoML result object.
# NOTE(review): `summary` must already exist in the kernel when this cell runs.
print(summary)

In [0]:
# Print the name of the table referenced by `summary.output_table_name`
# (read back with `spark.table` in the next cell).
print(summary.output_table_name)

In [0]:
# Load the table named by `summary.output_table_name` and render it.
# NOTE(review): `spark.table` presumably returns a Spark DataFrame, not a pandas
# one, despite the `_pd` suffix — confirm before using pandas-only APIs on it.
forecast_pd = spark.table(summary.output_table_name)
display(forecast_pd)

In [0]:
import mlflow.pyfunc
from mlflow.tracking import MlflowClient
 
# NOTE(review): despite its name, `run_id` holds an MlflowClient instance, not a
# run id, and it is not used anywhere else in this cell.
run_id = MlflowClient()
# MLflow run id of the best trial recorded on the AutoML summary.
trial_id = summary.best_trial.mlflow_run_id
 
# Build a "runs:/<run id>/model" URI for that trial and load the logged model
# through the generic pyfunc interface.
model_uri = "runs:/{run_id}/model".format(run_id=trial_id)
pyfunc_model = mlflow.pyfunc.load_model(model_uri)

In [0]:
# Call `predict_timeseries()` on the wrapped Python model and render the result.
# NOTE(review): this reaches through the private `_model_impl` attribute of the
# pyfunc wrapper, which may change between MLflow versions — confirm.
forecasts = pyfunc_model._model_impl.python_model.predict_timeseries()
display(forecasts)

In [0]:
import matplotlib.pyplot as plt

# Observed series: average of `cases` per date, collected to a local pandas
# DataFrame with columns `date` and `y`.
df_true = df.groupby("date").agg(y=("cases", "avg")).reset_index().to_pandas()

# Predict again with include_history=True so the curve also covers the
# historical dates, not only the forecast horizon.
forecasts = pyfunc_model._model_impl.python_model.predict_timeseries(include_history=True)

# Pass the datetime columns to matplotlib directly. `Series.dt.to_pydatetime()`
# is deprecated in recent pandas releases and matplotlib handles datetime64
# values natively.
fcst_t = forecasts['ds']

fig, ax = plt.subplots(facecolor='w', figsize=(10, 6))
ax.plot(df_true['date'], df_true['y'], 'k.', label='Observed data points')
ax.plot(fcst_t, forecasts['yhat'], ls='-', c='#0072B2', label='Forecasts')
ax.fill_between(fcst_t, forecasts['yhat_lower'], forecasts['yhat_upper'],
                color='#0072B2', alpha=0.2, label='Uncertainty interval')

# Title and axis labels so the figure can be read on its own.
ax.set_title('Observed vs. forecast cases')
ax.set_xlabel('Date')
ax.set_ylabel('Cases (average per date)')
ax.legend()
plt.show()