# In [None]:
import io
from datetime import datetime, timedelta

import pandas as pd
import requests
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.providers.google.cloud.operators.bigquery import BigQueryExecuteQueryOperator
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
from google.cloud import bigquery
from google.cloud import storage
from statsmodels.tsa.seasonal import STL

# NOTE(review): the Google provider's BigQuery operator module does not appear
# to export BigQueryCreateModelOperator. The import is kept but guarded so that
# a missing class no longer prevents the whole DAG file from being parsed.
try:
    from airflow.providers.google.cloud.operators.bigquery import BigQueryCreateModelOperator
except ImportError:
    BigQueryCreateModelOperator = None

# DAG setup
# Defaults inherited by every task: a single retry after five minutes and no
# e-mail notifications on failure or retry.
default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2023, 1, 1),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5),
}
dag = DAG(
    'furniture_forecast_pipeline',
    default_args=default_args,
    description='End-to-end pipeline for furniture demand forecasting with ARIMA_PLUS, Poisson regression, and automated retraining',
    schedule_interval='@monthly',
    # None of the tasks reads the run's logical date, so backfilling one run per
    # month since start_date would only repeat identical work and overwrite the
    # same objects and tables. Only the most recent interval is scheduled.
    catchup=False,
)

# Task 1: Ingest Data
def ingest_data(timeout=60):
    """Download the raw sales data from the API and store it as CSV in GCS.

    The JSON payload is converted to a DataFrame and uploaded to
    ``raw_data/furniture_sales_data.csv`` in the forecasting bucket.

    Args:
        timeout: Seconds to wait for the API before giving up.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the API does not answer within ``timeout``.
    """
    # Without a timeout a stalled connection would block the worker forever.
    response = requests.get(
        "https://api.example.com/furniture_sales_data",
        timeout=timeout,
    )
    # Fail the task on 4xx/5xx instead of uploading an error payload as data.
    response.raise_for_status()
    df = pd.DataFrame(response.json())

    storage_client = storage.Client()
    bucket = storage_client.get_bucket("furniture-forecasting-bucket")
    bucket.blob("raw_data/furniture_sales_data.csv").upload_from_string(df.to_csv(index=False), 'text/csv')

# Airflow task that executes ingest_data.
ingest_task = PythonOperator(dag=dag, python_callable=ingest_data, task_id='ingest_data')

# Task 2: Preprocess Data
def preprocess_data():
    """Clean the raw sales CSV and write the result back to GCS.

    Reads ``raw_data/furniture_sales_data.csv`` from the forecasting bucket,
    drops every row whose ``sales`` value is missing and uploads the remaining
    rows to ``preprocessed_data/preprocessed_sales_data.csv``.
    """
    storage_client = storage.Client()
    bucket = storage_client.get_bucket("furniture-forecasting-bucket")
    raw_csv = bucket.blob("raw_data/furniture_sales_data.csv").download_as_text()
    # read_csv treats a plain string as a path/URL, so the downloaded text has
    # to be wrapped in a file-like buffer to be parsed as CSV content.
    df = pd.read_csv(io.StringIO(raw_csv))

    # Rows without a sales figure are removed; no fill step is needed afterwards
    # because no missing sales values remain.
    df = df.dropna(subset=['sales'])

    bucket.blob("preprocessed_data/preprocessed_sales_data.csv").upload_from_string(df.to_csv(index=False), 'text/csv')

# Airflow task that executes preprocess_data.
preprocess_task = PythonOperator(dag=dag, python_callable=preprocess_data, task_id='preprocess_data')

# Task 3: Load Data to BigQuery
# Replaces the BigQuery sales table with the CSV uploaded by preprocess_data.
load_to_bigquery_task = GCSToBigQueryOperator(
    dag=dag,
    task_id='load_data_to_bigquery',
    # Source object in Cloud Storage.
    bucket='furniture-forecasting-bucket',
    source_objects=['preprocessed_data/preprocessed_sales_data.csv'],
    source_format='CSV',
    skip_leading_rows=1,  # the first row of the CSV is the header
    # Destination table; existing rows are overwritten on every run.
    destination_project_dataset_table='your_project.your_dataset.sales_data',
    write_disposition='WRITE_TRUNCATE',
)

# Task 4: STL Decomposition
def stl_decomposition(period=12):
    """Split the sales series into trend, seasonal and residual components.

    Reads ``preprocessed_data/preprocessed_sales_data.csv`` from the
    forecasting bucket, runs an STL decomposition on the ``sales`` column and
    uploads the frame, extended by ``trend``, ``seasonal`` and ``residual``
    columns, to ``decomposed_data/stl_decomposed_sales_data.csv``.

    Args:
        period: Number of observations per seasonal cycle. The default of 12
            presumably reflects monthly data with a yearly cycle — confirm
            against the API's sampling frequency.
    """
    storage_client = storage.Client()
    bucket = storage_client.get_bucket("furniture-forecasting-bucket")
    csv_text = bucket.blob("preprocessed_data/preprocessed_sales_data.csv").download_as_text()
    # read_csv treats a plain string as a path/URL, so the downloaded text has
    # to be wrapped in a file-like buffer to be parsed as CSV content.
    df = pd.read_csv(io.StringIO(csv_text), parse_dates=['date'])

    # The series has a plain integer index, so STL cannot infer the cycle
    # length and needs it passed explicitly via `period`. STL's `seasonal`
    # argument is the smoother length and must be odd, so an even cycle
    # length such as 12 cannot be passed there; the default is kept.
    result = STL(df['sales'], period=period).fit()
    df['trend'] = result.trend
    df['seasonal'] = result.seasonal
    df['residual'] = result.resid

    bucket.blob("decomposed_data/stl_decomposed_sales_data.csv").upload_from_string(df.to_csv(index=False), 'text/csv')

# Airflow task that executes stl_decomposition.
stl_decomposition_task = PythonOperator(dag=dag, python_callable=stl_decomposition, task_id='stl_decomposition')

# Task 5a: Train ARIMA_PLUS Model on Seasonal Component
# CREATE MODEL is an ordinary BigQuery statement, so it is submitted through
# the query operator the rest of this DAG already uses. The model name lives
# in the SQL itself, hence no separate model_id argument.
train_arima_task = BigQueryExecuteQueryOperator(
    task_id='train_arima_plus',
    sql="""
        CREATE OR REPLACE MODEL `your_project.your_dataset.seasonal_forecast_model`
        OPTIONS(
            model_type = 'ARIMA_PLUS',
            time_series_timestamp_col = 'date',
            time_series_data_col = 'seasonal',
            horizon = 12,
            auto_arima = TRUE
        ) AS
        SELECT
            date,
            seasonal
        FROM
            `your_project.your_dataset.stl_decomposed_sales_data`
        ORDER BY
            date
    """,
    # The operator defaults to legacy SQL, which does not support BigQuery ML
    # statements or backtick-quoted table names.
    use_legacy_sql=False,
    dag=dag,
)

# Task 5b: Train Poisson Regression Model on Trend Component
# Trains the model that forecasts the trend component.
#
# BigQuery ML has no 'GLM' model type, `model_registry` does not select a
# regression family, and the time_series_* options are only accepted by
# time-series models. The forecasting task downstream calls ML.FORECAST on this
# model, which requires a time-series model as well, so the trend is fitted
# with ARIMA_PLUS. Task id and model name are unchanged so that the rest of the
# DAG keeps working.
# NOTE(review): if a Poisson regression is a hard requirement it has to be
# trained outside BigQuery ML — confirm with the model owner.
train_poisson_task = BigQueryExecuteQueryOperator(
    task_id='train_poisson',
    sql="""
        CREATE OR REPLACE MODEL `your_project.your_dataset.trend_forecast_model`
        OPTIONS(
            model_type = 'ARIMA_PLUS',
            time_series_timestamp_col = 'date',
            time_series_data_col = 'trend',
            horizon = 12,
            auto_arima = TRUE
        ) AS
        SELECT
            date,
            trend
        FROM
            `your_project.your_dataset.stl_decomposed_sales_data`
        ORDER BY
            date
    """,
    # The operator defaults to legacy SQL, which does not support BigQuery ML
    # statements or backtick-quoted table names.
    use_legacy_sql=False,
    dag=dag,
)

# Task 6a: Forecasting with ARIMA_PLUS for Seasonal Component
# Writes a 12-step forecast of the seasonal component, including prediction
# interval bounds, to the arima_plus_forecast table (overwritten on each run).
forecast_arima_task = BigQueryExecuteQueryOperator(
    task_id='forecast_arima_plus',
    sql="""
        SELECT
            forecast_timestamp,
            forecast_value,
            prediction_interval_lower_bound,
            prediction_interval_upper_bound
        FROM
            ML.FORECAST(
                MODEL `your_project.your_dataset.seasonal_forecast_model`,
                STRUCT(12 AS horizon)
            )
    """,
    # The operator defaults to legacy SQL, which understands neither
    # ML.FORECAST nor backtick-quoted model names.
    use_legacy_sql=False,
    destination_dataset_table='your_project.your_dataset.arima_plus_forecast',
    write_disposition='WRITE_TRUNCATE',
    dag=dag,
)

# Task 6b: Forecasting with Poisson Model for Trend Component
# Writes a 12-step forecast of the trend component to the poisson_forecast
# table (overwritten on each run).
# NOTE(review): ML.FORECAST only works on time-series models — confirm that
# trend_forecast_model is trained as one.
forecast_poisson_task = BigQueryExecuteQueryOperator(
    task_id='forecast_poisson',
    sql="""
        SELECT
            forecast_timestamp,
            forecast_value
        FROM
            ML.FORECAST(
                MODEL `your_project.your_dataset.trend_forecast_model`,
                STRUCT(12 AS horizon)
            )
    """,
    # The operator defaults to legacy SQL, which understands neither
    # ML.FORECAST nor backtick-quoted model names.
    use_legacy_sql=False,
    destination_dataset_table='your_project.your_dataset.poisson_forecast',
    write_disposition='WRITE_TRUNCATE',
    dag=dag,
)

# Task 7: Combine Forecasts
def combine_forecasts():
    """Add the seasonal and trend forecasts and store the sum as CSV in GCS.

    Joins the two forecast tables on ``forecast_timestamp``, sums their
    ``forecast_value`` columns into ``combined_forecast`` and uploads the
    result to ``predictions/combined_forecast.csv`` in the forecasting bucket.
    """
    client = bigquery.Client()
    # ORDER BY keeps the exported rows in chronological order; a join on its
    # own guarantees no particular row order.
    query = """
        SELECT
            a.forecast_timestamp,
            a.forecast_value + b.forecast_value AS combined_forecast
        FROM
            `your_project.your_dataset.arima_plus_forecast` AS a
        JOIN
            `your_project.your_dataset.poisson_forecast` AS b
        ON
            a.forecast_timestamp = b.forecast_timestamp
        ORDER BY
            a.forecast_timestamp
    """
    combined_forecast_df = client.query(query).to_dataframe()

    # Upload through the storage client, like the other tasks in this file,
    # instead of handing pandas a gs:// URL, which needs the extra gcsfs
    # package to be installed on the worker.
    bucket = storage.Client().get_bucket("furniture-forecasting-bucket")
    bucket.blob("predictions/combined_forecast.csv").upload_from_string(
        combined_forecast_df.to_csv(index=False), 'text/csv'
    )

# Airflow task that executes combine_forecasts.
combine_forecast_task = PythonOperator(dag=dag, python_callable=combine_forecasts, task_id='combine_forecasts')

# Task 8a: Evaluate ARIMA_PLUS Model
# Stores the ML.EVALUATE output for the seasonal model in the
# arima_plus_eval_metrics table (overwritten on each run).
# NOTE(review): for ARIMA_PLUS models BigQuery ML documents
# perform_aggregation, horizon and confidence_level as the ML.EVALUATE options;
# `split_fraction` does not look like one of them. Error metrics such as
# mean_absolute_error also appear to require actual data as a second argument.
# Confirm and adjust the query.
evaluate_arima_task = BigQueryExecuteQueryOperator(
    task_id='evaluate_arima_plus',
    sql="""
        SELECT
            *
        FROM
            ML.EVALUATE(
                MODEL `your_project.your_dataset.seasonal_forecast_model`,
                STRUCT(0.8 AS split_fraction)
            )
    """,
    # The operator defaults to legacy SQL, which understands neither
    # ML.EVALUATE nor backtick-quoted model names.
    use_legacy_sql=False,
    destination_dataset_table='your_project.your_dataset.arima_plus_eval_metrics',
    write_disposition='WRITE_TRUNCATE',
    dag=dag,
)

# Task 8b: Evaluate Poisson Model
# Stores the ML.EVALUATE output for the trend model in the
# poisson_eval_metrics table (overwritten on each run).
# NOTE(review): `split_fraction` does not look like a documented ML.EVALUATE
# option — confirm the accepted STRUCT fields for this model type and adjust
# the query.
evaluate_poisson_task = BigQueryExecuteQueryOperator(
    task_id='evaluate_poisson',
    sql="""
        SELECT
            *
        FROM
            ML.EVALUATE(
                MODEL `your_project.your_dataset.trend_forecast_model`,
                STRUCT(0.8 AS split_fraction)
            )
    """,
    # The operator defaults to legacy SQL, which understands neither
    # ML.EVALUATE nor backtick-quoted model names.
    use_legacy_sql=False,
    destination_dataset_table='your_project.your_dataset.poisson_eval_metrics',
    write_disposition='WRITE_TRUNCATE',
    dag=dag,
)

# Task 9: Check Retraining Condition
def check_retraining_condition(threshold=10.0):
    """Choose the branch to follow based on the seasonal model's MAE.

    Args:
        threshold: Mean absolute error above which the models are retrained.

    Returns:
        The task ids of both retraining tasks when the stored MAE exceeds
        ``threshold``, otherwise the id of the skip task.

    Raises:
        ValueError: If the metrics table holds no usable
            ``mean_absolute_error`` value.
    """
    client = bigquery.Client()
    # The metrics table is rewritten (WRITE_TRUNCATE) from `SELECT *` over
    # ML.EVALUATE on every run, so it only ever holds the latest evaluation and
    # has no `evaluation_date` column to order by.
    query = """
        SELECT mean_absolute_error
        FROM `your_project.your_dataset.arima_plus_eval_metrics`
        LIMIT 1
    """
    rows = list(client.query(query).result())
    mae = rows[0].mean_absolute_error if rows else None
    if mae is None:
        # Fail with a clear message rather than a bare StopIteration/TypeError.
        raise ValueError(
            "No mean_absolute_error available in arima_plus_eval_metrics; "
            "cannot decide whether to retrain"
        )

    # A branch callable must return ids of tasks directly downstream of the
    # branch operator; both retraining tasks are selected together.
    if mae > threshold:
        return ['retrain_arima_plus', 'retrain_poisson']
    return 'skip_retrain'

# Branching task: the value returned by check_retraining_condition selects
# which downstream task(s) run.
check_retraining_task = BranchPythonOperator(
    dag=dag,
    python_callable=check_retraining_condition,
    task_id='check_retraining_condition',
)

# Task 10a: Retrain ARIMA_PLUS Model (if required)
# Rebuilds the seasonal model when the branch task selects retraining.
# CREATE MODEL is an ordinary BigQuery statement, so it is submitted through
# the query operator the rest of this DAG already uses. The model name lives
# in the SQL itself, hence no separate model_id argument.
retrain_arima_task = BigQueryExecuteQueryOperator(
    task_id='retrain_arima_plus',
    sql="""
        CREATE OR REPLACE MODEL `your_project.your_dataset.seasonal_forecast_model`
        OPTIONS(
            model_type = 'ARIMA_PLUS',
            time_series_timestamp_col = 'date',
            time_series_data_col = 'seasonal',
            horizon = 12,
            auto_arima = TRUE
        ) AS
        SELECT
            date,
            seasonal
        FROM
            `your_project.your_dataset.stl_decomposed_sales_data`
        ORDER BY
            date
    """,
    # The operator defaults to legacy SQL, which does not support BigQuery ML
    # statements or backtick-quoted table names.
    use_legacy_sql=False,
    dag=dag,
)

# Task 10b: Retrain Poisson Model (if required)
# Rebuilds the trend model when the branch task selects retraining.
#
# BigQuery ML has no 'GLM' model type, `model_registry` does not select a
# regression family, and the time_series_* options are only accepted by
# time-series models. ML.FORECAST is called on this model elsewhere in the DAG,
# which requires a time-series model as well, so the trend is fitted with
# ARIMA_PLUS. Task id and model name are unchanged.
# NOTE(review): if a Poisson regression is a hard requirement it has to be
# trained outside BigQuery ML — confirm with the model owner.
retrain_poisson_task = BigQueryExecuteQueryOperator(
    task_id='retrain_poisson',
    sql="""
        CREATE OR REPLACE MODEL `your_project.your_dataset.trend_forecast_model`
        OPTIONS(
            model_type = 'ARIMA_PLUS',
            time_series_timestamp_col = 'date',
            time_series_data_col = 'trend',
            horizon = 12,
            auto_arima = TRUE
        ) AS
        SELECT
            date,
            trend
        FROM
            `your_project.your_dataset.stl_decomposed_sales_data`
        ORDER BY
            date
    """,
    # The operator defaults to legacy SQL, which does not support BigQuery ML
    # statements or backtick-quoted table names.
    use_legacy_sql=False,
    dag=dag,
)

# Task 11: Dummy Task to Skip Retraining
# No-op branch target taken when retraining is not required.
skip_retrain_task = DummyOperator(dag=dag, task_id='skip_retrain')

# DAG Dependencies
# The model-training queries read the BigQuery table
# `your_project.your_dataset.stl_decomposed_sales_data`, but the STL task only
# writes a CSV to Cloud Storage. This task loads that CSV into the table so the
# training step has data to read; it mirrors the settings of the raw-data load.
load_decomposed_task = GCSToBigQueryOperator(
    task_id='load_decomposed_data_to_bigquery',
    bucket='furniture-forecasting-bucket',
    source_objects=['decomposed_data/stl_decomposed_sales_data.csv'],
    destination_project_dataset_table='your_project.your_dataset.stl_decomposed_sales_data',
    source_format='CSV',
    skip_leading_rows=1,
    write_disposition='WRITE_TRUNCATE',
    dag=dag,
)

# Ingest -> clean -> load raw table -> decompose -> load decomposed table.
ingest_task >> preprocess_task >> load_to_bigquery_task >> stl_decomposition_task
stl_decomposition_task >> load_decomposed_task
# Both component models are trained and forecast in parallel, then merged.
load_decomposed_task >> [train_arima_task, train_poisson_task]
train_arima_task >> forecast_arima_task
train_poisson_task >> forecast_poisson_task
[forecast_arima_task, forecast_poisson_task] >> combine_forecast_task
# Evaluation feeds the branch that decides between retraining and skipping.
combine_forecast_task >> [evaluate_arima_task, evaluate_poisson_task]
[evaluate_arima_task, evaluate_poisson_task] >> check_retraining_task
check_retraining_task >> [retrain_arima_task, retrain_poisson_task, skip_retrain_task]


