# This notebook tests parts of the code in isolation.

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

In [None]:
# Load the raw order export once; the cells below reuse `data`.
data = pd.read_csv(
    "raw_sql.csv",
    engine="pyarrow",
)
print("Data already imported!")

### Data cleaning

In [None]:
def process_sales_data(data: pd.DataFrame):
    """Clean raw order-line data into a per-SKU daily sales series.

    Steps, in order:
      1. Validate the input (type and required columns).
      2. Rename columns to short names ("date", "sku", "quant", "price").
      3. Parse "date" to naive datetimes normalized to midnight.
      4. Sum "quant" and "price" per (date, sku); rows whose date could not
         be parsed (NaT) are dropped by the groupby.
      5. Move "date" into a DatetimeIndex.
      6. For each SKU, insert every missing day between its first and last
         sale, with "quant" and "price" set to 0.
      7. Set "quant" to NaN inside runs of more than 7 consecutive zero days,
         marking them as missing rather than as genuine zero sales.

    Parameters:
    data (pd.DataFrame): Raw order-line data. Must contain the columns
        "order_id", "date_created", "fulfilled",
        "order_items_item_seller_sku", "order_items_quantity" and
        "order_items_unit_price". The caller's frame is not modified.

    Returns:
    pd.DataFrame: Frame indexed by a DatetimeIndex named "date" with the
        columns "sku", "quant" and "price".

    Raises:
    TypeError: If `data` is not a pandas DataFrame.
    ValueError: If any required column is missing.
    """
    # Validate before touching the object so that non-DataFrame input gets the
    # intended TypeError instead of e.g. an AttributeError from `.copy()`.
    _fail_if_invalid_sales_data(data)
    data = data.copy()

    # NOTE(review): "order_id" and "fulfilled" are required but never used below,
    # so unfulfilled orders are counted too — confirm this is intended.
    data = _rename_columns(data)
    data = _convert_date_column(data)
    data = _collapse_sales_data(data)
    data = _set_datetime_index(data)
    data = _mark_missing_data(data)
    data = _remove_long_zero_periods(data)

    return data


def _rename_columns(data: pd.DataFrame):
    """Renames columns to more understandable names."""
    return data.rename(
        columns={
            "date_created": "date",
            "order_items_quantity": "quant",
            "order_items_unit_price": "price",
            "order_items_item_seller_sku": "sku",
        }
    ).copy()


def _convert_date_column(data: pd.DataFrame):
    """Converts the 'date' column to datetime format, removes timezones, and normalizes to midnight."""
    data = data.copy()
    # Convert the 'date' column to datetime
    data["date"] = pd.to_datetime(data["date"], errors="coerce")
    # Remove timezone: if the timestamp is timezone-aware, use tz_convert(None) to remove the tz info.
    data["date"] = data["date"].apply(
        lambda x: x.tz_convert(None) if x is not pd.NaT and x.tzinfo is not None else x
    )
    # Normalize the datetime to midnight (i.e., keep only the date part)
    data["date"] = data["date"].dt.normalize()
    return data


def _collapse_sales_data(data: pd.DataFrame):
    """Aggregates sales data by date and SKU, summing the quantity and price."""
    data = data.copy()
    # Ensure that the 'date' column is converted to datetime and normalized
    data["date"] = pd.to_datetime(data["date"], errors="coerce").dt.normalize()
    # Group by both date and SKU
    return data.groupby(["date", "sku"], as_index=False).agg(
        {"quant": "sum", "price": "sum"}
    )


def _set_datetime_index(data: pd.DataFrame):
    """Sets 'date' as index and converts it to a DatetimeIndex."""
    data = data.copy()
    data = data.set_index("date")
    data.index = pd.to_datetime(data.index, errors="coerce")
    return data


def _mark_missing_data(data: pd.DataFrame) -> pd.DataFrame:
    """
    For each SKU in the data, reindex the DataFrame so that every day between the first and
    last available date is present. For missing days, the "quant" and "price" values will be
    set to 0, while the "sku" column will be filled appropriately.

    Args:
        data (pd.DataFrame): DataFrame with a datetime index and a "sku" column.

    Returns:
        pd.DataFrame: Reindexed DataFrame with all days present for each SKU.
    """
    df_list = []
    for sku in data["sku"].unique():
        sku_data = data[data["sku"] == sku].sort_index()
        # Drop rows with invalid (NaT) dates in the index
        sku_data = sku_data[sku_data.index.notna()]
        if sku_data.empty:
            continue

        # Normalize the index so that timestamps become midnight
        sku_data.index = sku_data.index.normalize()
        # Remove duplicate date entries (keeping the first occurrence)
        sku_data = sku_data[~sku_data.index.duplicated(keep="first")]

        start_date = sku_data.index.min()
        end_date = sku_data.index.max()
        if pd.isna(start_date) or pd.isna(end_date):
            continue

        full_date_range = pd.date_range(start=start_date, end=end_date, freq="D")
        sku_data_reindexed = sku_data.reindex(full_date_range)
        # Fill the "sku" column for missing days with the current SKU value.
        sku_data_reindexed["sku"] = sku
        # Fill missing 'quant' and 'price' with 0
        sku_data_reindexed["quant"] = sku_data_reindexed["quant"].fillna(0)
        sku_data_reindexed["price"] = sku_data_reindexed["price"].fillna(0)
        sku_data_reindexed.index.name = "date"
        df_list.append(sku_data_reindexed)

    if df_list:
        return pd.concat(df_list)
    else:
        return data.copy()


def _remove_long_zero_periods(
    data: pd.DataFrame, zero_period_days: int = 7
) -> pd.DataFrame:
    """
    For each SKU in the data, for periods where 'quant' is 0 continuously for longer than
    zero_period_days, set the 'quant' value to NaN. This indicates that the sales data in that
    period is missing or unreliable, so the model can handle it appropriately.

    Args:
        data (pd.DataFrame): DataFrame with a datetime index and a "sku" column.
        zero_period_days (int): Minimum number of consecutive days with quant equal to 0 to trigger setting to NaN.

    Returns:
        pd.DataFrame: DataFrame with long zero periods marked as missing (NaN).
    """
    df_list = []
    for sku in data["sku"].unique():
        sku_data = data[data["sku"] == sku].sort_index()
        # Create a mask for rows where quant is 0
        zero_mask = sku_data["quant"] == 0
        zero_data = sku_data[zero_mask]
        if zero_data.empty:
            df_list.append(sku_data)
            continue

        # Group consecutive dates in zero_data: if the difference between consecutive dates
        # is not exactly one day, then it's a new group.
        groups = (zero_data.index.to_series().diff() != pd.Timedelta(days=1)).cumsum()

        # For each group of consecutive zeros, if the group is longer than the threshold, mark those rows as missing
        for _, group in zero_data.groupby(groups):
            if len(group) > zero_period_days:
                sku_data.loc[group.index, "quant"] = np.nan
        df_list.append(sku_data)
    if df_list:
        return pd.concat(df_list).sort_index()
    else:
        return data.copy()


def _fail_if_invalid_sales_data(data: pd.DataFrame):
    """Raise an error if data is not a DataFrame or is missing required columns."""
    required_columns = {
        "order_id",
        "date_created",
        "fulfilled",
        "order_items_item_seller_sku",
        "order_items_quantity",
        "order_items_unit_price",
    }

    if not isinstance(data, pd.DataFrame):
        error_msg = f"'data' must be a pandas DataFrame, got {type(data)}."
        raise TypeError(error_msg)

    if not required_columns.issubset(data.columns):
        missing_columns = required_columns - set(data.columns)
        error_msg = f"'data' DataFrame is missing required columns: {missing_columns}"
        raise ValueError(error_msg)

In [None]:
# Run the full cleaning pipeline on the raw export loaded above.
clean_data = process_sales_data(data=data)

In [None]:
# Inspect the cleaned output of process_sales_data.
clean_data

In [None]:
def plot_timeseries_for_sku(
    data: pd.DataFrame, sku: str, start_date: str, end_date: str
):
    """
    Plot the daily "quant" series of one SKU between two dates using Plotly.

    Days in the window that have no row, or whose "quant" is NaN, are drawn as
    gaps in the line instead of being bridged.

    Args:
        data (pd.DataFrame): Sales data with 'sku' and 'quant' columns and either a
            DatetimeIndex or a 'date' column.
        sku (str): The SKU identifier to plot.
        start_date (str): Start of the window, inclusive ('YYYY-MM-DD').
        end_date (str): End of the window, inclusive ('YYYY-MM-DD'; compared as
            midnight of that day).

    Returns:
        fig: The Plotly figure object containing the plot.

    Raises:
        ValueError: If `data` has neither a DatetimeIndex nor a 'date' column, or
            if the SKU has no rows inside the window.
    """
    # Filter data for the specified SKU
    sku_data = data[data["sku"] == sku].copy()

    # Ensure the index is datetime; if not, use the 'date' column if available
    if not isinstance(sku_data.index, pd.DatetimeIndex):
        if "date" in sku_data.columns:
            sku_data["date"] = pd.to_datetime(sku_data["date"])
            sku_data.set_index("date", inplace=True)
        else:
            raise ValueError("Data must have a datetime index or a 'date' column.")

    # Filter data for the given time period
    mask = (sku_data.index >= start_date) & (sku_data.index <= end_date)
    sku_data = sku_data.loc[mask]

    # Must be checked before reindexing: the reindex below always yields one row
    # per day in the window, so an emptiness check after it could never fire.
    if sku_data.empty:
        raise ValueError(
            f"No data found for SKU {sku} between {start_date} and {end_date}."
        )

    # Reindex to include all dates in the period, leaving missing values as NaN.
    # NOTE: reindex fails if the SKU has more than one row on the same timestamp.
    full_date_range = pd.date_range(start=start_date, end=end_date, freq="D")
    sku_data = sku_data.reindex(full_date_range)
    sku_data.index.name = "date"
    # Fill the 'sku' column for missing rows
    sku_data["sku"] = sku

    # Reset index for Plotly
    sku_data_reset = sku_data.reset_index()

    # Create a Plotly line chart; NaN values become gaps because connectgaps=False
    fig = px.line(
        sku_data_reset,
        x="date",
        y="quant",
        title=f"Time Series for SKU {sku} from {start_date} to {end_date}",
        labels={"date": "Date", "quant": "Quantity"},
    )

    # Update traces to not connect gaps
    fig.update_traces(connectgaps=False)

    return fig

In [None]:
# Inspect the raw input; process_sales_data works on a copy, so this is unmodified.
data

In [None]:
# Plot one SKU over 2024 and save the chart as a standalone HTML file.
sku = "SHIELD"

fig = plot_timeseries_for_sku(clean_data, sku, "2024-01-01", "2025-01-01")
produces = f"{sku}_time_series.html"
# Persist the figure at the output path named by `produces`.
fig.write_html(produces)
fig