## Exploratory Data Analysis
This notebook will guide you through an exploratory data analysis of your time series data, providing the following visualizations and statistics along with an explanation of the results.

### Visualizations
- Actuals plot
- Seasonal decomposition
- ACF/PACF

### Stationarity Tests
- Augmented Dickey-Fuller

### Baseline Analysis (Optional)

To start, run the first cell.

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
from decimal import Decimal
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.stattools import adfuller, acf, pacf
import json
from snowflake.snowpark.context import get_active_session
from dateutil.relativedelta import relativedelta
from datetime import timedelta, datetime

# Obtain the Snowpark session through get_active_session(); every query in
# this notebook is issued through it.
session = get_active_session()
# JSON string stored as the session's query tag; the same string is reused
# later as the comment on the saved baseline results table.
query_tag = '{"origin":"sf_sit", "name":"sit_forecasting", "version":{"major":1, "minor":0}, "attributes":{"component":"eda"}}'
session.query_tag = query_tag
st.success("Import and setup successful.")

# SETUP

- Select your source table and optionally view a sample of your data set.
- You will also need to set your date/time series field and target value.
- Finally, you will need to specify whether this is a single-series or multi-series analysis.  If multi-series, you will have to select the series (partition) columns and decide how many series you want to analyze.

In [None]:
# Set database, schema and table for analysis
TS_DB = "FORECAST_MODEL_BUILDER"
TS_SCHEMA = "BASE"
TS_TABLE_NM = "DAILY_PARTITIONED_SAMPLE_DATA"

# Set Time and Target columns
TIME_PERIOD_COLUMN = "ORDER_TIMESTAMP"
TARGET_COLUMN = "TARGET"

# Single or Multi-Series - Set as 1 for multi-series or 0 for single series
MULTISERIES = 0

# If Multi-Series, set Partitions
# - NOTE: the dataframe-creation cell builds GROUP_IDENTIFIER from these columns
#   regardless of MULTISERIES, so they are referenced in the SQL even for a
#   single-series analysis.
PARTITION_COLUMNS = ["STORE_ID", "PRODUCT_ID"]

# If Multi-Series, determine how many partitions/partition groups you want to visualize in this analysis.
# - If you have multiple partitions, one partition combination counts as one group
# - NOTE: the dataframe-creation cell applies this limit to the loaded data
#   regardless of MULTISERIES.
VISUALIZE_COUNT = 5

# Aggregate data for use by seasonal decomposition.
# - If your existing data is at a granularity that is smaller than an hour, use this value to aggregate to hour, day or higher.
# - Options: Hourly = "h", Daily = "d", Weekly = "w", Monthly = "m"
# - NOTE(review): this string is passed to pandas `asfreq`; confirm the alias
#   is accepted by the installed pandas version.
FREQUENCY = "d"

# Set analysis period for your seasonal decomposition.
# - This value will determine for what period seasonality patterns are analyzed.
# - You can iterate on this value and rerun to understand multiple period patterns.
# - Options: Daily = 1, Weekly = 7, Monthly = 30, Annual = 365
PERIOD = 30

# Baseline analysis window and output table.
# - START_DATE_BASELINE_FORECAST and END_DATE_BASELINE_FORECAST are strings in "YYYY-MM-DD" format.
# - Rows before the start date are used as history; the start-to-end range is forecast.
START_DATE_BASELINE_FORECAST = "2023-12-01"  # Example start date for testing
END_DATE_BASELINE_FORECAST = "2023-12-31"  # Example end date for testing
BASELINE_RESULT_TBL_NM = "BASELINE_RESULTS"

st.success("Variable setup is successful.")

# Create your dataframe

Run the next cell to create the dataframe that will be used for analysis.

Cells such as the one below are collapsed to allow for the user to have a seamless experience running the code without looking at it.  If users are interested in reviewing the underlying code, they can simply click the Code button next to the Run button to expand the code cell.

In [None]:
# Switch the session to the configured database and schema.
session.use_schema(f"{TS_DB}.{TS_SCHEMA}")
# The returned schema name is not captured or displayed here.
# NOTE(review): presumably a leftover sanity check — confirm it is still wanted.
session.get_fully_qualified_current_schema()

# Display names for the seasonal-decomposition periods offered in the setup cell
value_to_label = {
    1: "Daily",
    7: "Weekly",
    30: "Monthly",
    365: "Annual",
}


def get_label(PERIOD):
    """Return the display name for a seasonal period.

    Looks PERIOD up in ``value_to_label`` and falls back to "Unknown" when
    the value has no entry there.
    """
    try:
        return value_to_label[PERIOD]
    except KeyError:
        return "Unknown"


period_label = get_label(PERIOD)

# SQL expression that concatenates each partition column's name and value,
# e.g. 'STORE_ID' || '_' || STORE_ID || '_' || 'PRODUCT_ID' || '_' || PRODUCT_ID
partition_expr = " || '_' || ".join(
    [f"'{col}' || '_' || {col}" for col in PARTITION_COLUMNS]
)

# Create the GROUP_IDENTIFIER expression
group_identifier_expr = f"CONCAT({partition_expr})"

# Pick up to VISUALIZE_COUNT distinct partition groups.
# NOTE(review): this limit is applied even when MULTISERIES == 0, so the
# single-series analysis only aggregates the selected groups — confirm that
# this is intended.
limit_query = f"""select distinct {group_identifier_expr}  AS GROUP_IDENTIFIER from {TS_DB}.{TS_SCHEMA}.{TS_TABLE_NM} limit {VISUALIZE_COUNT}"""
limit_sdf = session.sql(limit_query).to_pandas()

# NULL identifiers (a NULL partition value makes the whole concatenation NULL)
# can never match an IN list, so they are dropped here.
group_identifiers = tuple(limit_sdf["GROUP_IDENTIFIER"].dropna().tolist())
if not group_identifiers:
    raise ValueError(
        f"No partition groups found in {TS_DB}.{TS_SCHEMA}.{TS_TABLE_NM}; "
        "check the table and PARTITION_COLUMNS settings."
    )

# Render the IN list explicitly instead of interpolating the tuple's repr:
# the repr yields a trailing comma for a single value and switches to double
# quotes for values containing an apostrophe, both of which are invalid SQL.
group_identifier_list = ", ".join(
    "'" + str(identifier).replace("'", "''") + "'"
    for identifier in group_identifiers
)

data_query = f"""select {group_identifier_expr} as GROUP_IDENTIFIER,* from {TS_DB}.{TS_SCHEMA}.{TS_TABLE_NM} where {group_identifier_expr} in ({group_identifier_list}) """
data_sdf = session.sql(data_query).to_pandas()

st.success(
    "You have successfully created your dataset.  Here is a sample of 10 records."
)
st.write(data_sdf.head(10))

# Visualize Your Dataset

Run the next cell to visualize your time series data actual trends.

In [None]:
def _plot_actuals(x_values, y_values, label, title, figsize):
    """Draw and show one line chart of the target over time.

    Args:
        x_values: Values for the x axis (the time period column).
        y_values: Values for the y axis (the target column).
        label: Line label passed to matplotlib.
        title: Chart title.
        figsize: Figure size as a (width, height) tuple in inches.
    """
    plt.figure(figsize=figsize)
    plt.plot(x_values, y_values, label=label, marker="o")
    plt.title(title, fontsize=9)
    plt.xlabel(TIME_PERIOD_COLUMN, fontsize=8)
    plt.ylabel(TARGET_COLUMN, fontsize=8)
    plt.tick_params(axis="both", labelsize=8)  # Default tick label size
    plt.xticks(fontsize=7, fontname="Arial")  # Font and size for x-axis ticks
    plt.yticks(fontsize=7, fontname="Arial")  # Font and size for y-axis ticks
    plt.grid()
    plt.tight_layout()
    plt.show()


# Plot the actuals: one chart for the summed series, or one chart per partition
if MULTISERIES == 0:
    # Sum every numeric column per time period; groupby returns the periods sorted
    aggregated_data = (
        data_sdf.groupby(TIME_PERIOD_COLUMN).sum(numeric_only=True).reset_index()
    )
    _plot_actuals(
        aggregated_data[TIME_PERIOD_COLUMN],
        aggregated_data[TARGET_COLUMN],
        label="Single Series Analysis",
        title=f"Single Series Analysis: {TARGET_COLUMN} Over Time",
        figsize=(6.5, 3),
    )
else:
    series_ids = data_sdf["GROUP_IDENTIFIER"].unique()
    for series_id in series_ids:
        # The source query has no ORDER BY, so sort each series by time;
        # otherwise the connected line can jump back and forth.
        series_data = data_sdf[data_sdf["GROUP_IDENTIFIER"] == series_id].sort_values(
            by=TIME_PERIOD_COLUMN
        )
        _plot_actuals(
            series_data[TIME_PERIOD_COLUMN],
            series_data[TARGET_COLUMN],
            label=f"Series {series_id}",
            title=f"Series {series_id}: {TARGET_COLUMN} Over Time",
            figsize=(7, 4),
        )

# Visualize and Analyze Seasonal Decomposition & ACF/PACF and ADF

**Seasonality** in time series modeling refers to patterns that repeat at fixed intervals, such as daily, weekly, monthly, or yearly cycles. These patterns can significantly impact the accuracy of time series models, and understanding seasonality is crucial for making accurate predictions. Using Seasonal Decomposition and ACF/PACF analysis, we can understand seasonality better and build models that account for seasonality.  

**Stationarity** in time series modeling refers to the property of a time series where its statistical properties, such as the mean, variance, and autocorrelation structure, remain constant over time.  The ADF test is a unit root test that will check whether a time series is stationary.  When time series are non-stationary, other data engineering techniques are necessary to account for the changing time series.  

Run the next cell to visualize your Seasonal Decomposition and ACF/PACF analysis, and to analyze ADF for stationarity. 

In [None]:
def _plot_decomposition(result):
    """Plot the observed, trend, seasonal and residual components in four stacked panels."""
    components = (
        (result.observed, "Observed", "blue"),
        (result.trend, "Trend", "orange"),
        (result.seasonal, "Seasonal", "green"),
        (result.resid, "Residuals", "red"),
    )
    plt.figure(figsize=(6.5, 4))
    for position, (values, label, color) in enumerate(components, start=1):
        plt.subplot(4, 1, position)
        plt.plot(values, label=label, color=color)
        plt.legend(loc="upper left", fontsize=6)
        plt.tick_params(axis="both", labelsize=8)  # Default tick label size
        plt.xticks(fontsize=7, fontname="Arial")  # Font and size for x-axis ticks
        plt.yticks(fontsize=7, fontname="Arial")  # Font and size for y-axis ticks
    plt.tight_layout()
    plt.show()


def _plot_acf_pacf(series, title_prefix=""):
    """Plot ACF above PACF for a series, using at most 40 lags.

    The lag count is capped below half the sample size because the partial
    autocorrelation can only be computed for lags up to 50% of the sample.
    """
    lags = min(40, len(series) // 2 - 1)
    if lags < 1:
        st.warning(
            f"Skipping ACF/PACF: only {len(series)} observation(s) available."
        )
        return

    fig, ax = plt.subplots(2, 1, figsize=(6.5, 4))

    plot_acf(series, lags=lags, ax=ax[0])
    ax[0].set_title(f"{title_prefix}ACF", fontsize=8)
    ax[0].tick_params(axis="both", labelsize=7)

    plot_pacf(series, lags=lags, ax=ax[1])
    ax[1].set_title(f"{title_prefix}PACF", fontsize=8)
    ax[1].tick_params(axis="both", labelsize=7)

    plt.tight_layout()
    plt.show()


def _report_adf(series, heading):
    """Run the Augmented Dickey-Fuller test and write its statistic, p-value and critical values."""
    adf_result = adfuller(series)
    st.write(heading)
    st.write(f"**ADF Statistic**: {adf_result[0]}")
    st.write(f"**p-value**: {adf_result[1]}")
    st.write("**Critical Values**:")
    for key, value in adf_result[4].items():
        st.write(f"   **{key}**: {value}")


def _analyze_series(series, series_id=None):
    """Run seasonal decomposition, ACF/PACF and ADF for one prepared series.

    Args:
        series: Target values indexed by a regular DatetimeIndex with
            missing values already filled or dropped.
        series_id: Partition identifier used in headings and titles, or
            None for the single-series analysis.
    """
    suffix = "" if series_id is None else f" for Series {series_id}"
    title_prefix = "" if series_id is None else f"Series {series_id}: "

    if series_id is not None:
        st.write(f"# Seasonal Decomposition for Series {series_id}")
    st.write(f"## {period_label} Seasonal Decomposition{suffix}")

    # seasonal_decompose needs two complete cycles; warn instead of aborting
    # so one short series does not stop the remaining analysis.
    if len(series) < 2 * PERIOD:
        st.warning(
            f"Skipping seasonal decomposition: {len(series)} observation(s) "
            f"available but at least {2 * PERIOD} are required for PERIOD = {PERIOD}."
        )
    else:
        result = seasonal_decompose(series, model="additive", period=PERIOD)
        _plot_decomposition(result)

    st.write(f"## ACF and PACF Analysis{suffix}")
    _plot_acf_pacf(series, title_prefix)

    _report_adf(series, f"## Augmented Dickey-Fuller{suffix}")


if MULTISERIES == 0:
    # Sort the data
    data_sdf.sort_values(by=TIME_PERIOD_COLUMN, inplace=True)
    st.write(data_sdf)

    # Combine all partitions into one series by summing the target per time period
    aggregate_data = data_sdf.groupby([TIME_PERIOD_COLUMN])[TARGET_COLUMN].sum()

    # asfreq needs a datetime index
    aggregate_data.index = pd.to_datetime(aggregate_data.index)

    # Put the series on a regular frequency, then forward-fill the gaps
    aggregate_data = aggregate_data.asfreq(FREQUENCY).ffill().dropna()

    _analyze_series(aggregate_data)

else:
    # Capture the series order before sorting so series are reported in the
    # order they first appear in the loaded data.
    series_ids = data_sdf["GROUP_IDENTIFIER"].unique()

    # Sort once; the order does not change between loop iterations
    data_sdf.sort_values(by=["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN], inplace=True)

    for series_id in series_ids:
        # Filter data for the current series
        series_data = data_sdf[data_sdf["GROUP_IDENTIFIER"] == series_id]

        # Index the target by time without modifying the filtered slice in place
        time_series = series_data.set_index(TIME_PERIOD_COLUMN)[TARGET_COLUMN]

        # asfreq needs a datetime index (same conversion as the single-series path)
        time_series.index = pd.to_datetime(time_series.index)

        # Put the series on a regular frequency, then forward-fill the gaps
        time_series = time_series.asfreq(FREQUENCY).ffill().dropna()

        _analyze_series(time_series, series_id)

# Perform Baseline Analysis - Optional

If you would like to perform basic non-ML-based forecasting (baseline modeling) on your data which can be used as a comparison for advanced modeling later, the next cell provides for a simple forecasting method.
- First you will pick which range of dates you want to forecast.  
    - The code will automatically set this as test data and use the rest of the data for forecasting. 
    - Your range of options will be for dates within the most recent 12 months of your data. You should have at least two years of prior data available for forecasting.
- The forecasting method used here is calculating the average of the target value from the same day of week (Monday = 0 through Sunday = 6) of the same ISO week of year (1-53) from prior years.  For example, to predict 2024/12/02 data:
    - December 2, 2024 is a Monday, so a day of week value of 0
    - It falls into the week of year 49
    - The same dates from 2022 and 2023 that are the day of week = 0 and week of year = 49 are:
        - 2022: 2022/12/05
        - 2023: 2023/12/04
    - The forecast will find the target variable values for those dates and average them to predict target for 2024/12/02


In [None]:
BASELINE_RUNTIME = datetime.now()

forecast_start = pd.to_datetime(START_DATE_BASELINE_FORECAST)
forecast_end = pd.to_datetime(END_DATE_BASELINE_FORECAST)

# Calendar features that define "the same day" across years
SEASONAL_KEYS = ["DAY_OF_WEEK", "WEEK_OF_YEAR"]


def _seasonal_average_forecast(history, future, group_keys=()):
    """Forecast each future row as the mean historical target for its calendar slot.

    Args:
        history: Rows used for fitting; must contain TARGET_COLUMN, the
            SEASONAL_KEYS columns and any ``group_keys`` columns.
        future: Rows to forecast; must contain the same key columns.
        group_keys: Extra columns that separate independent series.

    Returns:
        ``future`` with a FORECAST column added; rows without matching
        history get NaN.
    """
    keys = [*group_keys, *SEASONAL_KEYS]
    seasonal_means = (
        history.groupby(keys)[TARGET_COLUMN].mean().rename("FORECAST").reset_index()
    )
    return future.merge(seasonal_means, on=keys, how="left")


def _plot_forecast_vs_target(ax, frame, title, x_label):
    """Plot the FORECAST and target columns of ``frame`` against time on ``ax``."""
    ax.plot(
        frame[TIME_PERIOD_COLUMN],
        frame["FORECAST"],
        label="Forecast",
        marker="o",
        linestyle="-",
        markersize=3,
    )
    ax.plot(
        frame[TIME_PERIOD_COLUMN],
        frame[TARGET_COLUMN],
        label="Target",
        marker="x",
        linestyle="--",
        markersize=3,
    )

    ax.set_xlabel(x_label, fontsize=9)
    ax.set_ylabel("Value", fontsize=9)
    ax.set_title(title, fontsize=9)

    # Rotate x-axis labels for better visibility
    ax.tick_params(axis="x", rotation=45, labelsize=8)
    ax.tick_params(axis="y", labelsize=8)

    # Start the Y axis at 0; skipped when there is nothing to plot because
    # matplotlib rejects NaN axis limits.
    y_max = frame[["FORECAST", TARGET_COLUMN]].max().max()
    if pd.notna(y_max) and y_max > 0:
        ax.set_ylim(0, y_max)

    ax.legend(fontsize=8)


# Extract time related features. The time column is converted to datetime
# BEFORE it is compared with the forecast start date below.
data_sdf[TIME_PERIOD_COLUMN] = pd.to_datetime(data_sdf[TIME_PERIOD_COLUMN])
data_sdf["DAY_OF_WEEK"] = data_sdf[TIME_PERIOD_COLUMN].dt.dayofweek  # Monday=0, Sunday=6
data_sdf["WEEK_OF_YEAR"] = data_sdf[TIME_PERIOD_COLUMN].dt.isocalendar().week
data_sdf["YEAR"] = data_sdf[TIME_PERIOD_COLUMN].dt.year
data_sdf.sort_values(by=["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN], inplace=True)

# History used for forecasting: everything before the forecast window
data = data_sdf[data_sdf[TIME_PERIOD_COLUMN] < forecast_start].copy()

# One row per day of the forecast window, with the same calendar features.
# NOTE(review): these are midnight timestamps; actuals are joined on exact
# timestamp equality, so this presumably expects daily data — confirm.
forecast_dates = pd.date_range(start=forecast_start, end=forecast_end, freq="D")
forecast_df = pd.DataFrame({TIME_PERIOD_COLUMN: forecast_dates})
forecast_df["DAY_OF_WEEK"] = forecast_df[TIME_PERIOD_COLUMN].dt.dayofweek
forecast_df["WEEK_OF_YEAR"] = forecast_df[TIME_PERIOD_COLUMN].dt.isocalendar().week

if MULTISERIES == 0:
    # Collapse all partitions into one series: total target per time period
    aggregated_data = (
        data.groupby([TIME_PERIOD_COLUMN, *SEASONAL_KEYS])[TARGET_COLUMN]
        .sum()
        .reset_index()
    )

    forecast_results_df = _seasonal_average_forecast(aggregated_data, forecast_df)[
        [TIME_PERIOD_COLUMN, "FORECAST"]
    ]

    # Actual totals per time period over the full dataset, for comparison
    aggjoin_df = (
        data_sdf.groupby([TIME_PERIOD_COLUMN, *SEASONAL_KEYS])[TARGET_COLUMN]
        .sum()
        .reset_index()
    )

    # Attach the actual target to each forecast date
    merged_df = pd.merge(
        forecast_results_df,
        aggjoin_df[[TIME_PERIOD_COLUMN, TARGET_COLUMN]],
        on=[TIME_PERIOD_COLUMN],
        how="left",
    )[[TIME_PERIOD_COLUMN, "FORECAST", TARGET_COLUMN]]

    st.write(merged_df)

    fig, ax = plt.subplots(figsize=(7.5, 3))
    _plot_forecast_vs_target(
        ax, merged_df, title="Forecast vs Target", x_label="Order Timestamp"
    )
    plt.tight_layout()
    plt.show()

else:
    # Every series present in the history gets a forecast for every date
    series_frame = pd.DataFrame({"GROUP_IDENTIFIER": data["GROUP_IDENTIFIER"].unique()})
    future = series_frame.merge(forecast_df, how="cross")

    forecast_results_df = _seasonal_average_forecast(
        data, future, group_keys=["GROUP_IDENTIFIER"]
    )[["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN, "FORECAST"]]

    # Attach the actual target for each series and forecast date
    merged_df = pd.merge(
        forecast_results_df,
        data_sdf[["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN, TARGET_COLUMN]],
        on=["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN],
        how="left",
    )[["GROUP_IDENTIFIER", TIME_PERIOD_COLUMN, "FORECAST", TARGET_COLUMN]]

    st.write(merged_df)

    group_ids = merged_df["GROUP_IDENTIFIER"].unique()
    num_groups = len(group_ids)

    if num_groups == 0:
        st.warning("No history is available before the forecast start date.")
    else:
        # squeeze=False keeps `axes` two-dimensional even for a single group
        fig, axes = plt.subplots(
            nrows=num_groups, ncols=1, figsize=(7.5, 3 * num_groups), squeeze=False
        )
        for ax, group_id in zip(axes[:, 0], group_ids):
            _plot_forecast_vs_target(
                ax,
                merged_df[merged_df["GROUP_IDENTIFIER"] == group_id],
                title=f"Group {group_id} - Forecast vs Target",
                x_label="Time Period",
            )
        plt.tight_layout()
        plt.show()

st.success("You have successfully run your baseline analysis.")

# Calculate MAPE. Zero targets are masked to NaN first: dividing by them
# would give an infinite percentage error and make the overall mean infinite.
nonzero_target = merged_df[TARGET_COLUMN].where(merged_df[TARGET_COLUMN] != 0)
merged_df["APE"] = (
    (nonzero_target - merged_df["FORECAST"]) / nonzero_target
).abs() * 100
merged_df["BASELINE_RUNTIME"] = BASELINE_RUNTIME

# Mean skips NaN, so dates without an actual or a forecast are excluded
overall_mape = merged_df["APE"].mean()
print(f"The Overall MAPE (Mean Absolute Percentage Error) is {overall_mape:.2f}%.")

# Persist the comparison, replacing any previous baseline results table
merged_df = session.create_dataframe(merged_df)
merged_df.write.save_as_table(
    BASELINE_RESULT_TBL_NM, mode="overwrite", comment=query_tag
)