In [1]:
import os

import numpy as np  #  type: ignore
import matplotlib.pyplot as plt  #  type: ignore
import pandas as pd  #  type: ignore

from sklearn.metrics import mean_squared_error, mean_absolute_error
from statsmodels.tsa.stattools import acf, pacf  #  type: ignore
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf  #  type: ignore


In [2]:
def load_data(path="../files/input/sutter.csv"):
    """Read the series CSV and index it by its ``date`` column.

    Parameters
    ----------
    path : str or path-like, default "../files/input/sutter.csv"
        Location of the CSV file. The default reproduces the original
        hard-coded path, so existing ``load_data()`` calls are unchanged.

    Returns
    -------
    pandas.DataFrame
        The file's contents with ``date`` as the index. No ``parse_dates``
        is applied, so the index keeps whatever dtype ``read_csv`` infers
        (strings for typical date text).
    """
    return pd.read_csv(path).set_index("date")

In [None]:
def plot_time_series(df, yt_col="yt_true", n_test=24, tick_step=12):
    """Plot the observed series together with every prediction column.

    Parameters
    ----------
    df : pandas.DataFrame
        Holds the observed series in ``yt_col``. Every column whose name
        starts with ``"yt_pred"`` is drawn as an additional line.
    yt_col : str, default "yt_true"
        Column with the observed values.
    n_test : int or None, default 24
        Length of the trailing segment to mark off. A dashed vertical line
        is drawn at x position ``len(df) - n_test``. The line is skipped when
        ``n_test`` is None, non-positive, or larger than the series.
    tick_step : int, default 12
        Every ``tick_step``-th index value is used as an x tick label.
    """
    y_true = df[yt_col]
    n_obs = len(y_true)

    _, ax = plt.subplots(figsize=(12, 4))

    # observed values
    ax.plot(y_true, ".-k", linewidth=0.5, label=yt_col)
    ax.grid(color="lightgray", linestyle="--", linewidth=0.5)

    # prediction columns; cycle through the palette so a 7th prediction
    # column does not raise IndexError on the 6-character color string
    palette = "rbgcmy"
    pred_cols = [
        col for col in df.columns
        if isinstance(col, str) and col.startswith("yt_pred")
    ]
    for i, col in enumerate(pred_cols):
        ax.plot(df[col], ".-", color=palette[i % len(palette)], linewidth=0.7, label=col)

    # division line before the last n_test points. Series.min/max skip NaN,
    # whereas builtin min/max can return NaN (and then no line is drawn) or
    # raise on an empty series.
    if n_test is not None and 0 < n_test <= n_obs:
        split = n_obs - n_test
        ax.plot([split, split], [y_true.min(), y_true.max()], "--", linewidth=2)

    # set ticks, then labels with rotation/size in the same call, so the
    # rotation applies to the labels that are actually displayed
    ax.set_xticks(range(0, n_obs, tick_step))
    ax.set_xticklabels(y_true.index[::tick_step], rotation=90, fontsize=8)
    ax.tick_params(axis="y", labelsize=8)
    ax.legend()
    plt.show()


In [4]:

def acf_pacf_plots(z, lags=24):
    """Draw the ACF (left) and PACF (right) correlograms of ``z`` side by side.

    Parameters
    ----------
    z : array-like
        Series handed unchanged to statsmodels' ``plot_acf`` / ``plot_pacf``.
    lags : int, default 24
        Number of lags shown in both panels.
    """

    def format_plot(ax):
        """Apply the shared minimalist styling to one correlogram panel."""
        for side in ("top", "bottom", "right"):
            ax.spines[side].set_visible(False)
        # statsmodels presumably draws the lag stems first and the confidence
        # band second; the band is absent when no alpha is used, so only
        # recolor the collections that actually exist (confirm on upgrades)
        artists = ax.collections
        if len(artists) > 0:
            artists[0].set_color("k")
        if len(artists) > 1:
            artists[1].set_color("gray")
        ax.grid(color="lightgray", linestyle="--", linewidth=0.5)
        ax.set_ylim(-1.03, 1.03)
        ax.tick_params(labelsize=8)
        ax.set_title(ax.get_title(), fontsize=8)

    _, (ax_acf, ax_pacf) = plt.subplots(1, 2, figsize=(9, 3))

    plot_acf(z, lags=lags, ax=ax_acf, color="k")
    format_plot(ax_acf)

    plot_pacf(z, lags=lags, ax=ax_pacf, color="k")
    format_plot(ax_pacf)

    plt.show()
