In [None]:
!pip install kfp

Collecting kfp
  Downloading kfp-2.12.1.tar.gz (345 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m345.4/345.4 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting kfp-pipeline-spec==0.6.0 (from kfp)
  Downloading kfp_pipeline_spec-0.6.0-py3-none-any.whl.metadata (293 bytes)
Collecting kfp-server-api<2.5.0,>=2.1.0 (from kfp)
  Downloading kfp_server_api-2.4.0.tar.gz (83 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting kubernetes<31,>=8.0.0 (from kfp)
  Downloading kubernetes-30.1.0-py2.py3-none-any.whl.metadata (1.5 kB)
Collecting urllib3<2.0.0 (from kfp)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Downloading kfp_pipe

In [None]:
# --- Imports and Config ---
import kfp
from kfp import dsl
from kfp.dsl import Input, Output, Model, Dataset, Artifact, OutputPath
from typing import NamedTuple, Optional
import google.cloud.aiplatform as aip
import os
import shutil
import tempfile
import logging # Import logging

# --- Configuration ---
# Placeholders: replace YOUR-PROJECT-ID / YOUR-BUCKET before running.
PROJECT_ID = "YOUR-PROJECT-ID"
REGION = "us-central1"
PIPELINE_ROOT = "gs://YOUR-BUCKET/meridian-pipeline-root"  # KFP pipeline root for run artifacts

# BigQuery tables (both live in BQ_DATASET).
BQ_DATASET = "meridiansampledataset"  # Your dataset
BQ_TABLE_NAME = "meridiantable"  # input data used for training
BQ_SUMMARY_TABLE_NAME = "meridian_media_summary_report"  # output table; replaced on every run
OUTPUT_GCS_DIR = f"{PIPELINE_ROOT}/outputs"  # HTML reports are uploaded under this prefix

# Model settings: parameters of the LogNormal ROI prior, then the settings
# forwarded to Meridian's sample_posterior.
ROI_MU = 0.2
ROI_SIGMA = 0.9
N_CHAINS = 7
N_ADAPT = 500
N_BURNIN = 500
N_KEEP = 1000
RANDOM_SEED = 1

# Date window passed to the HTML summary report.
REPORT_START_DATE = '2021-01-25'
REPORT_END_DATE = '2024-01-15'

# Container images for the CPU components and the GPU training component.
STANDARD_BASE_IMAGE = "python:3.10-slim"
GPU_BASE_IMAGE = "gcr.io/deeplearning-platform-release/tf-gpu.2-15.py310"

# KFP components only ship their own function source and cannot read this
# module's globals, so every component re-declares this filename; keep all
# copies in sync.
MERIDIAN_MODEL_FILENAME = "model_save.pkl"
PIPELINE_NAME = "meridian-mmm-gpu-bq-pipeline"  # pipeline name / job display name
PIPELINE_JSON = f"{PIPELINE_NAME}.json"  # compiled pipeline spec path

# --- Configure logging for this notebook kernel only ---
# Components run in their own containers and never execute this line; their
# logging.info output depends on how the container's root logger is configured.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# --- train_meridian_model ---
@dsl.component(
    base_image=GPU_BASE_IMAGE,
    packages_to_install=[
        "google-meridian[and-cuda]", "numpy<2","tensorflow_probability", "pandas",
        "google-cloud-storage", "arviz", "matplotlib", "dill",
        "google-cloud-bigquery","db-dtypes",
        "pyarrow"
    ],
)
def train_meridian_model(
    project_id: str,
    bq_dataset: str,
    bq_table_name: str,
    roi_mu: float, roi_sigma: float, n_chains: int,
    n_adapt: int, n_burnin: int, n_keep: int, seed: int,
    output_model: Output[Model],
    n_prior_draws: int = 500,
):
    """Train a Meridian MMM on a BigQuery table and save it as a KFP Model artifact.

    KFP runs this function in its own container and ships only its source, so
    every import and constant it needs is declared inside the body.

    Args:
        project_id: GCP project used for the BigQuery client and the table id.
        bq_dataset: BigQuery dataset holding the input table.
        bq_table_name: Input table; must contain the columns mapped in
            ``coord_to_columns`` below.
        roi_mu: First parameter of the LogNormal ``roi_m`` prior.
        roi_sigma: Second parameter of the LogNormal ``roi_m`` prior.
        n_chains: Forwarded to ``Meridian.sample_posterior``.
        n_adapt: Forwarded to ``Meridian.sample_posterior``.
        n_burnin: Forwarded to ``Meridian.sample_posterior``.
        n_keep: Forwarded to ``Meridian.sample_posterior``.
        seed: Forwarded to ``Meridian.sample_posterior``.
        output_model: Output artifact; the model is saved to
            ``<output_model.path>/model_save.pkl``.
        n_prior_draws: Draw count passed to ``Meridian.sample_prior``
            (defaults to 500, the value previously hard-coded).

    Raises:
        ValueError: If the query returns no rows or the time column is missing.
    """
    # --- Imports inside component (only this function's source is shipped) ---
    import datetime
    import logging
    import os
    import time

    import dill  # noqa: F401  (kept from the original import list)
    import numpy as np  # noqa: F401
    import pandas as pd
    import tensorflow as tf
    import tensorflow_probability as tfp
    from google.cloud import bigquery
    from meridian import constants
    from meridian.data import load
    from meridian.model import model, prior_distribution, spec

    # Must match the filename the downstream components load.
    MERIDIAN_MODEL_FILENAME = "model_save.pkl"

    def _normalize_time_column(frame, col):
        """Rewrite ``frame[col]`` in place as 'YYYY-MM-DD' strings when it holds dates.

        BigQuery DATE/TIMESTAMP columns come back from ``to_dataframe`` as
        date-like values; this pipeline feeds Meridian string dates instead.
        """
        # Sniff the first non-null value so a leading NULL row cannot hide the type.
        non_null = frame[col].dropna()
        sample = non_null.iloc[0] if not non_null.empty else None
        if pd.api.types.is_datetime64_any_dtype(frame[col]) or isinstance(sample, (pd.Timestamp, datetime.date)):
            frame[col] = pd.to_datetime(frame[col]).dt.strftime('%Y-%m-%d')
            logging.info(f"Conversion of '{col}' complete.")
        elif pd.api.types.is_string_dtype(frame[col]):
            logging.info(f"Column '{col}' is already string type. Checking format (first row): {sample}")
        else:
            logging.warning(f"Column '{col}' is not a recognized datetime or string type ({frame[col].dtype}). Meridian might still fail.")

    # --- GPU setup ---
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        logging.info(f"GPUs available: {gpus}")
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logging.info("Enabled memory growth for GPUs.")
        except RuntimeError as e:
            # TensorFlow raises RuntimeError if the GPUs were already initialized.
            logging.error(f"Error setting memory growth: {e}")
    else:
        logging.warning("No GPU detected by TensorFlow. Running on CPU.")

    # --- Define Mappings ---
    # HERE FOR THE SAMPLE DATASET: change the channel column names to match your own table.
    coord_to_columns = load.CoordToColumns(
        time='time', geo='geo', controls=['GQV', 'Competitor_Sales'], population='population',
        kpi='conversions', revenue_per_kpi='revenue_per_conversion',
        media=[f'Channel{i}_impression' for i in range(5)],
        media_spend=[f'Channel{i}_spend' for i in range(5)],
        organic_media=['Organic_channel0_impression'], non_media_treatments=['Promo'],
    )
    correct_media_to_channel = {f'Channel{i}_impression': f'Channel_{i}' for i in range(5)}
    correct_media_spend_to_channel = {f'Channel{i}_spend': f'Channel_{i}' for i in range(5)}

    # --- BigQuery Data Loading ---
    bq_table_full_id = f"{project_id}.{bq_dataset}.{bq_table_name}"
    logging.info(f"Attempting to load data from BigQuery table: {bq_table_full_id}")

    try:
        client = bigquery.Client(project=project_id)
        logging.info("BigQuery client created successfully.")
    except Exception as e:
        logging.error(f"Failed to create BigQuery client: {e}")
        raise

    # Identifiers cannot be query parameters; the id comes from pipeline inputs.
    sql_query = f"SELECT * FROM `{bq_table_full_id}`"
    logging.info(f"Executing query: {sql_query}")

    try:
        df = client.query(sql_query).to_dataframe()
        logging.info(f"Successfully loaded {len(df)} rows and {len(df.columns)} columns from BigQuery.")
        # Fail with a clear message instead of an IndexError further down.
        if df.empty:
            raise ValueError(f"BigQuery table {bq_table_full_id} returned no rows; nothing to train on.")

        time_col_name = coord_to_columns.time
        if time_col_name not in df.columns:
            logging.error(f"Specified time column '{time_col_name}' not found in DataFrame!")
            raise ValueError(f"Time column '{time_col_name}' defined in coord_to_columns not found in BigQuery results.")
        logging.info(f"Converting time column '{time_col_name}' to string format 'YYYY-MM-DD'")
        _normalize_time_column(df, time_col_name)

        logging.info("First 5 rows of loaded data (post-conversion):")
        logging.info(df.head().to_string())  # to_string keeps the table readable in logs
    except Exception as e:
        logging.error(f"Error loading data from BigQuery or processing DataFrame: {e}")
        raise

    # --- Use DataFrameDataLoader ---
    logging.info("Initializing Meridian DataFrameDataLoader...")
    try:
        loader = load.DataFrameDataLoader(
            df=df,
            kpi_type='non_revenue',
            coord_to_columns=coord_to_columns,
            media_to_channel=correct_media_to_channel,
            media_spend_to_channel=correct_media_spend_to_channel,
        )
        data = loader.load()
        logging.info("Data successfully loaded into Meridian InputData format.")
    except Exception as e:
        logging.error(f"Error during Meridian data loading process (DataFrameDataLoader): {e}")
        raise

    # --- Model configuration and sampling ---
    logging.info("Configuring model...")
    prior = prior_distribution.PriorDistribution(
        roi_m=tfp.distributions.LogNormal(roi_mu, roi_sigma, name=constants.ROI_M)
    )
    model_spec_obj = spec.ModelSpec(prior=prior)
    mmm = model.Meridian(input_data=data, model_spec=model_spec_obj)

    logging.info("Sampling prior...")
    mmm.sample_prior(n_prior_draws)
    logging.info(f"Sampling posterior with {n_chains} chains...")
    start_time = time.time()
    mmm.sample_posterior(
        n_chains=n_chains, n_adapt=n_adapt, n_burnin=n_burnin, n_keep=n_keep, seed=seed
    )
    end_time = time.time()
    logging.info(f"Posterior sampling complete. Duration: {end_time - start_time:.2f} seconds.")

    # --- Save ---
    save_file_path = os.path.join(output_model.path, MERIDIAN_MODEL_FILENAME)
    logging.info(f"Saving model artifact using model.save_mmm to file: {save_file_path}")
    try:
        # The artifact path is treated as a directory holding the model file.
        os.makedirs(output_model.path, exist_ok=True)
        model.save_mmm(mmm, save_file_path)
        logging.info("Model saved successfully using meridian.model.model.save_mmm.")
    except Exception as e:
        logging.error(f"meridian.model.model.save_mmm failed: {e}")
        raise

    output_model.metadata["framework"] = "Meridian"
    output_model.metadata["saved_filename"] = MERIDIAN_MODEL_FILENAME
    output_model.metadata["description"] = f"Trained Meridian MMM model (BQ Input, saved via save_mmm to {MERIDIAN_MODEL_FILENAME})"
    logging.info("Training component finished.")


# --- generate_summary_report HTML ---
@dsl.component(
    base_image=STANDARD_BASE_IMAGE,
    packages_to_install=[
        "google-meridian",
        "tensorflow", "tensorflow_probability",
        "pandas", "numpy", "arviz", "matplotlib", "google-cloud-storage","dill"
    ],
)
def generate_summary_report(
    model_artifact: Input[Model],
    output_gcs_dir: str,
    report_filename: str,
    start_date: str,
    end_date: str,
    summary_report_artifact: Output[Artifact],
):
    """Render Meridian's HTML model-results summary and upload it to GCS.

    Args:
        model_artifact: Model artifact from ``train_meridian_model``; its
            directory must contain ``model_save.pkl``.
        output_gcs_dir: Destination prefix; must be ``gs://<bucket>[/...]``.
        report_filename: Relative file name of the report, used both locally
            and as the object name under ``output_gcs_dir``.
        start_date: Forwarded to ``Summarizer.output_model_results_summary``.
        end_date: Forwarded to ``Summarizer.output_model_results_summary``.
        summary_report_artifact: Output artifact; its ``uri`` is pointed at
            the uploaded report.

    Raises:
        FileNotFoundError: If the model file is missing or Meridian did not
            write the report.
        ValueError: If ``output_gcs_dir`` or ``report_filename`` is invalid.
    """
    import logging
    import os
    import tempfile
    from urllib.parse import urlparse

    import dill  # noqa: F401  (original note: needed by load_mmm; presumably for unpickling)
    from google.cloud import storage
    from meridian.analysis import summarizer
    from meridian.model import model

    # Must match the filename written by train_meridian_model.
    MERIDIAN_MODEL_FILENAME = "model_save.pkl"

    def join_gcs_uri(base_uri, name):
        """Append ``name`` to a gs:// prefix with exactly one '/' separator.

        os.path.join is meant for local paths and would discard the bucket if
        ``name`` were absolute.
        """
        return f"{base_uri.rstrip('/')}/{name}"

    def upload_local_file_to_gcs(local_path, gcs_uri):
        """Upload ``local_path`` to the object addressed by ``gs://bucket/path``."""
        parsed_uri = urlparse(gcs_uri)
        bucket = storage.Client().bucket(parsed_uri.netloc)
        bucket.blob(parsed_uri.path.lstrip('/')).upload_from_filename(local_path)
        logging.info(f"File {local_path} uploaded to {gcs_uri}")

    model_dir_path = model_artifact.path
    load_file_path = os.path.join(model_dir_path, MERIDIAN_MODEL_FILENAME)
    logging.info(f"Attempting to load model from file: {load_file_path}")
    if not os.path.exists(load_file_path):
        raise FileNotFoundError(f"Expected model file {MERIDIAN_MODEL_FILENAME} not found in {model_dir_path}")
    try:
        mmm = model.load_mmm(load_file_path)
        logging.info("Model loaded successfully for HTML report generation.")
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise

    # A bare "gs://" has no bucket; an absolute filename would escape both the
    # temp dir and the GCS prefix.
    if not output_gcs_dir.startswith("gs://") or not urlparse(output_gcs_dir).netloc:
        raise ValueError("output_gcs_dir must be a GCS path (gs://...)")
    if not report_filename or os.path.isabs(report_filename):
        raise ValueError(f"report_filename must be a non-empty relative file name, got {report_filename!r}")
    final_gcs_uri = join_gcs_uri(output_gcs_dir, report_filename)

    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Generating summary HTML report locally in: {temp_dir}")
        local_report_source_path = os.path.join(temp_dir, report_filename)
        try:
            mmm_summarizer = summarizer.Summarizer(mmm)
            mmm_summarizer.output_model_results_summary(
                filename=report_filename,
                filepath=temp_dir,
                start_date=start_date,
                end_date=end_date,
            )
            # Verify the file exists before reporting that Meridian wrote it.
            if not os.path.exists(local_report_source_path):
                logging.error(f"Meridian did not create the expected local HTML report file: {local_report_source_path}")
                raise FileNotFoundError(f"HTML Report file not created locally by Meridian at {local_report_source_path}")
            logging.info(f"Meridian saved HTML report locally to: {local_report_source_path}")
            logging.info(f"Manually uploading {local_report_source_path} to {final_gcs_uri}")
            upload_local_file_to_gcs(local_report_source_path, final_gcs_uri)
            summary_report_artifact.uri = final_gcs_uri
            summary_report_artifact.metadata["gcs_path"] = final_gcs_uri
            summary_report_artifact.metadata["filename"] = report_filename
            logging.info(f"Set KFP artifact URI for HTML report to: {summary_report_artifact.uri}")
        except Exception as e:
            logging.error(f"Failed to generate or upload HTML summary report: {e}")
            raise
    logging.info("HTML Summary report component finished.")


# --- Generate and Save Summary Table to BigQuery ---
@dsl.component(
    base_image=STANDARD_BASE_IMAGE,
    packages_to_install=[
        "google-meridian", # Base meridian package
        "tensorflow", "tensorflow_probability", # Dependencies for meridian
        "pandas",
        "google-cloud-bigquery",
        "pandas_gbq",
        "google-cloud-storage", # Needed for artifact loading
        "dill", # For loading the model
        "pyarrow",
        "db-dtypes"
    ]
)
def generate_and_save_summary_bq(
    model_artifact: Input[Model],
    project_id: str,
    bq_dataset: str,
    bq_table_name: str, # Target table for this summary
    bq_output_table: Output[Artifact], # Output artifact to track the BQ table
):
    """Compute Meridian's media summary table and write it to BigQuery.

    The table comes from ``visualizer.MediaSummary(mmm).summary_table()``, with
    a best-effort ``Summarizer`` fallback if that API is missing. Columns are
    renamed to BigQuery-safe snake_case, object columns are coerced to numeric
    or string, and the result replaces ``{project_id}.{bq_dataset}.{bq_table_name}``.

    Args:
        model_artifact: Model artifact from ``train_meridian_model``; its
            directory must contain ``model_save.pkl``.
        project_id: GCP project of the destination table.
        bq_dataset: Destination dataset.
        bq_table_name: Destination table (overwritten on every run).
        bq_output_table: Output artifact recording the table id and a console link.

    Raises:
        FileNotFoundError: If the model file is missing.
        ValueError: If no known Meridian API yields a summary DataFrame.
    """
    import io
    import logging
    import numbers
    import os
    import re

    import dill  # noqa: F401  (original note: may be needed by load_mmm)
    import pandas as pd
    import pandas_gbq
    from meridian.analysis import visualizer
    from meridian.model import model

    # Must match the filename written by train_meridian_model.
    MERIDIAN_MODEL_FILENAME = "model_save.pkl"

    def _log_frame_info(frame):
        """Log ``DataFrame.info()``, which writes to a buffer instead of returning text."""
        buffer = io.StringIO()
        frame.info(buf=buffer)
        logging.info(buffer.getvalue())

    def _bq_column_name(col):
        """Convert a summary-table header into a BigQuery-safe snake_case column name."""
        name = str(col).lower()
        name = name.replace('% ', 'pct_').replace(' ', '_').replace('.', '').replace('(', '').replace(')', '')
        # BigQuery column names: letters, digits and underscores; no leading digit.
        name = re.sub(r'[^0-9a-z_]', '_', name)
        return name if name and not name[0].isdigit() else f"_{name}"

    def _is_real_number(value):
        """True for int/float-like values, excluding bool."""
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    def _coerce_object_column(series, col):
        """Return ``series`` (object dtype) as numeric when it holds only numbers, else as str."""
        non_null = series.dropna()
        first_val = non_null.iloc[0] if not non_null.empty else None
        if isinstance(first_val, (tuple, list)) or (isinstance(first_val, str) and first_val.strip().startswith(('(', '['))):
            logging.info(f"Converting object column '{col}' to string for BQ.")
            return series.astype(str)
        # Mixed frames often store plain numbers as 'object'; is_numeric_dtype on an
        # object Series is always False, so inspect the values themselves.
        if not non_null.empty and non_null.map(_is_real_number).all():
            try:
                converted = pd.to_numeric(series)
                logging.info(f"Converted object column '{col}' back to numeric.")
                return converted
            except (ValueError, TypeError):
                logging.warning(f"Could not convert object column '{col}' to numeric, keeping as object/string.")
                return series.astype(str)
        logging.info(f"Converting object column '{col}' to string for BQ.")
        return series.astype(str)

    # --- Load model ---
    model_dir_path = model_artifact.path
    load_file_path = os.path.join(model_dir_path, MERIDIAN_MODEL_FILENAME)
    logging.info(f"Attempting to load model from file: {load_file_path} for BQ summary")
    if not os.path.exists(load_file_path):
        raise FileNotFoundError(f"Expected model file {MERIDIAN_MODEL_FILENAME} not found in {model_dir_path}")
    try:
        mmm = model.load_mmm(load_file_path)
        logging.info("Model loaded successfully for BQ summary generation.")
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise

    # --- Build the summary table (only the Meridian call is guarded) ---
    logging.info("Generating media summary table using visualizer.MediaSummary...")
    try:
        summary_df = visualizer.MediaSummary(mmm).summary_table()
        logging.info("Successfully generated summary DataFrame.")
    except AttributeError:
        logging.error("AttributeError: Could not find 'MediaSummary' or 'summary_table' in 'meridian.analysis.visualizer'. "
                      "Perhaps the class/method name is different or in another module (e.g., summarizer)?")
        logging.warning("Attempting fallback using meridian.analysis.summarizer.Summarizer...")
        try:
            from meridian.analysis import summarizer
            mmm_summarizer = summarizer.Summarizer(mmm)
            if hasattr(mmm_summarizer, 'get_summary_dataframe'):
                # NOTE(review): hypothetical method name; confirm it exists in your Meridian version.
                summary_df = mmm_summarizer.get_summary_dataframe()
                logging.info("Successfully generated summary DataFrame using Summarizer fallback.")
            elif hasattr(mmm_summarizer, '_create_summary_table'):
                # Private API: last resort, likely to break across Meridian versions.
                summary_df = mmm_summarizer._create_summary_table()
                logging.info("Successfully generated summary DataFrame using Summarizer fallback (_create_summary_table).")
            else:
                logging.error("Fallback failed: Summarizer does not have a known method to return the summary DataFrame.")
                raise ValueError("Could not generate summary DataFrame using known Meridian methods.")
        except Exception as fallback_e:
            logging.error(f"Error during Summarizer fallback: {fallback_e}")
            raise
    except Exception as e:
        logging.error(f"Failed to generate summary table: {e}")
        raise

    logging.info("First 5 rows of summary DataFrame:")
    logging.info(summary_df.head().to_string())
    logging.info("\nDataFrame Info:")
    _log_frame_info(summary_df)

    # --- Prepare DataFrame for BigQuery ---
    new_columns = [_bq_column_name(col) for col in summary_df.columns]
    summary_df.columns = new_columns
    logging.info(f"Renamed DataFrame columns for BQ compatibility: {new_columns}")

    for col in summary_df.columns:
        if summary_df[col].dtype == 'object':
            summary_df[col] = _coerce_object_column(summary_df[col], col)

    # str() conversions above turn missing values into these literals; map them back to NA.
    nan_strings = ['nan', 'NaN', 'None', '(nan, nan)', 'nan (nan, nan)']
    summary_df = summary_df.fillna(pd.NA).replace(nan_strings, [pd.NA] * len(nan_strings))

    # Preserve a plain integer row index as a regular column.
    if summary_df.index.name is None and pd.api.types.is_integer_dtype(summary_df.index):
        summary_df = summary_df.reset_index()
        index_col_name = 'original_index'
        if index_col_name in summary_df.columns:  # Avoid collision
            index_col_name = 'row_index'
        summary_df = summary_df.rename(columns={'index': index_col_name})
        logging.info(f"Reset DataFrame index and added column '{index_col_name}'.")

    logging.info("Final DataFrame Schema before BQ Upload:")
    _log_frame_info(summary_df)
    logging.info("First 5 rows before BQ Upload:")
    logging.info(summary_df.head().to_string())

    # --- Save to BigQuery ---
    bq_table_full_id = f"{project_id}.{bq_dataset}.{bq_table_name}"
    logging.info(f"Attempting to save summary DataFrame to BigQuery table: {bq_table_full_id}")
    try:
        # DataFrame.to_gbq is deprecated since pandas 2.2; call pandas_gbq directly.
        pandas_gbq.to_gbq(
            summary_df,
            destination_table=f"{bq_dataset}.{bq_table_name}",
            project_id=project_id,
            if_exists='replace',
        )
        logging.info(f"Successfully wrote summary data to BigQuery table: {bq_table_full_id}")

        bq_output_table.metadata["table_id"] = bq_table_full_id
        bq_output_table.uri = f"https://console.cloud.google.com/bigquery?project={project_id}&ws=!1m5!1m4!4m3!1s{project_id}!2s{bq_dataset}!3s{bq_table_name}"  # console link to the table
    except Exception as e:
        logging.error(f"Failed to write DataFrame to BigQuery: {e}")
        logging.error(f"DataFrame dtypes:\n{summary_df.dtypes}")
        raise

    logging.info("Generate and Save Summary to BQ component finished.")


# --- run_budget_optimization ---
@dsl.component(
    base_image=STANDARD_BASE_IMAGE,
    packages_to_install=[
        "google-meridian",
        "pandas", "numpy", "google-cloud-storage", "dill"
    ],
)
def run_budget_optimization(
    model_artifact: Input[Model],
    output_gcs_dir: str,
    report_filename: str,
    optimization_report_artifact: Output[Artifact],
):
    """Run Meridian budget optimization and upload its HTML summary to GCS.

    Calls ``optimizer.BudgetOptimizer(mmm).optimize()`` with no arguments and
    writes the result via ``output_optimization_summary``.

    Args:
        model_artifact: Model artifact from ``train_meridian_model``; its
            directory must contain ``model_save.pkl``.
        output_gcs_dir: Destination prefix; must be ``gs://<bucket>[/...]``.
        report_filename: Relative file name of the report, used both locally
            and as the object name under ``output_gcs_dir``.
        optimization_report_artifact: Output artifact; its ``uri`` is pointed
            at the uploaded report.

    Raises:
        FileNotFoundError: If the model file is missing or Meridian did not
            write the report.
        ValueError: If ``output_gcs_dir`` or ``report_filename`` is invalid.
    """
    import logging
    import os
    import tempfile
    from urllib.parse import urlparse

    import dill  # noqa: F401  (original note: may be needed by load_mmm)
    from google.cloud import storage
    from meridian.analysis import optimizer
    from meridian.model import model

    # Must match the filename written by train_meridian_model.
    MERIDIAN_MODEL_FILENAME = "model_save.pkl"

    def join_gcs_uri(base_uri, name):
        """Append ``name`` to a gs:// prefix with exactly one '/' separator.

        os.path.join is meant for local paths and would discard the bucket if
        ``name`` were absolute.
        """
        return f"{base_uri.rstrip('/')}/{name}"

    def upload_local_file_to_gcs(local_path, gcs_uri):
        """Upload ``local_path`` to the object addressed by ``gs://bucket/path``."""
        parsed_uri = urlparse(gcs_uri)
        bucket = storage.Client().bucket(parsed_uri.netloc)
        bucket.blob(parsed_uri.path.lstrip('/')).upload_from_filename(local_path)
        logging.info(f"File {local_path} uploaded to {gcs_uri}")

    model_dir_path = model_artifact.path
    load_file_path = os.path.join(model_dir_path, MERIDIAN_MODEL_FILENAME)
    logging.info(f"Attempting to load model from file: {load_file_path}")
    if not os.path.exists(load_file_path):
        raise FileNotFoundError(f"Expected model file {MERIDIAN_MODEL_FILENAME} not found in {model_dir_path}")
    try:
        mmm = model.load_mmm(load_file_path)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Model loading failed: {e}")
        raise

    # A bare "gs://" has no bucket; an absolute filename would escape both the
    # temp dir and the GCS prefix.
    if not output_gcs_dir.startswith("gs://") or not urlparse(output_gcs_dir).netloc:
        raise ValueError("output_gcs_dir must be a GCS path (gs://...)")
    if not report_filename or os.path.isabs(report_filename):
        raise ValueError(f"report_filename must be a non-empty relative file name, got {report_filename!r}")
    final_gcs_uri = join_gcs_uri(output_gcs_dir, report_filename)

    with tempfile.TemporaryDirectory() as temp_dir:
        logging.info(f"Running optimization and generating report locally in: {temp_dir}")
        local_report_source_path = os.path.join(temp_dir, report_filename)
        try:
            budget_optimizer = optimizer.BudgetOptimizer(mmm)
            optimization_results = budget_optimizer.optimize()
            logging.info("Optimization calculation complete.")
            optimization_results.output_optimization_summary(
                filename=report_filename,
                filepath=temp_dir,
            )
            # Verify the file exists before reporting that Meridian wrote it.
            if not os.path.exists(local_report_source_path):
                logging.error(f"Meridian did not create the expected local report file: {local_report_source_path}")
                raise FileNotFoundError(f"Optimization report file not created locally by Meridian at {local_report_source_path}")
            logging.info(f"Meridian saved optimization report locally to: {local_report_source_path}")
            logging.info(f"Manually uploading {local_report_source_path} to {final_gcs_uri}")
            upload_local_file_to_gcs(local_report_source_path, final_gcs_uri)
            optimization_report_artifact.uri = final_gcs_uri
            optimization_report_artifact.metadata["gcs_path"] = final_gcs_uri
            optimization_report_artifact.metadata["filename"] = report_filename
            logging.info(f"Set KFP artifact URI to: {optimization_report_artifact.uri}")
        except Exception as e:
            logging.error(f"Failed during budget optimization or reporting/uploading: {e}")
            raise
    logging.info("Optimization component finished.")


# --- Pipeline Definition ---
@dsl.pipeline(
    name=PIPELINE_NAME,
    description="Runs Meridian MMM (GPU) reading from BigQuery, saves summary table to BQ",
    pipeline_root=PIPELINE_ROOT,
)
def meridian_pipeline(
    project_id: str = PROJECT_ID,
    bq_dataset: str = BQ_DATASET,
    bq_table_name: str = BQ_TABLE_NAME,  # input data table
    summary_bq_table_name: str = BQ_SUMMARY_TABLE_NAME,  # output summary table
    output_gcs_dir: str = OUTPUT_GCS_DIR,
    roi_mu: float = ROI_MU,
    roi_sigma: float = ROI_SIGMA,
    n_chains: int = N_CHAINS,
    n_adapt: int = N_ADAPT,
    n_burnin: int = N_BURNIN,
    n_keep: int = N_KEEP,
    seed: int = RANDOM_SEED,
    report_start_date: str = REPORT_START_DATE,
    report_end_date: str = REPORT_END_DATE,
    summary_report_filename: str = "summary_output.html",  # HTML report
    optimization_report_filename: str = "optimization_output.html",
):
    """Train a Meridian model on one T4 GPU, then fan out to three consumers.

    The BigQuery summary, HTML summary report and budget optimization tasks
    depend only on the trained model, so they have no ordering among
    themselves; chain ``.after(...)`` on a task to serialize them.
    """
    train_task = train_meridian_model(
        project_id=project_id,
        bq_dataset=bq_dataset,
        bq_table_name=bq_table_name,
        roi_mu=roi_mu,
        roi_sigma=roi_sigma,
        n_chains=n_chains,
        n_adapt=n_adapt,
        n_burnin=n_burnin,
        n_keep=n_keep,
        seed=seed,
    )
    train_task.set_cpu_limit("16").set_memory_limit("64G")
    train_task.set_accelerator_limit(1).set_accelerator_type('NVIDIA_TESLA_T4')

    # Every downstream task consumes the same trained-model artifact.
    trained_model = train_task.outputs["output_model"]

    bq_summary_task = generate_and_save_summary_bq(
        model_artifact=trained_model,
        project_id=project_id,
        bq_dataset=bq_dataset,
        bq_table_name=summary_bq_table_name,
    )
    html_report_task = generate_summary_report(
        model_artifact=trained_model,
        output_gcs_dir=output_gcs_dir,
        report_filename=summary_report_filename,
        start_date=report_start_date,
        end_date=report_end_date,
    )
    optimization_task = run_budget_optimization(
        model_artifact=trained_model,
        output_gcs_dir=output_gcs_dir,
        report_filename=optimization_report_filename,
    )

    # Downstream tasks share one CPU/memory budget; adjust here as needed.
    for downstream_task in (bq_summary_task, html_report_task, optimization_task):
        downstream_task.set_cpu_limit("16").set_memory_limit("64G")


# --- Pipeline Compilation and Execution ---
if __name__ == "__main__":
    # Compile the pipeline definition into a spec file that Vertex AI can run.
    pipeline_compiler = kfp.compiler.Compiler()
    pipeline_compiler.compile(pipeline_func=meridian_pipeline, package_path=PIPELINE_JSON)
    print(f"Pipeline compiled to {PIPELINE_JSON}")

    aip.init(project=PROJECT_ID, location=REGION, staging_bucket=PIPELINE_ROOT)
    print(f"Initialized Vertex AI SDK for project {PROJECT_ID} in {REGION}")

    # Runtime values for every pipeline parameter; they mirror the defaults in
    # meridian_pipeline's signature, so edit here to override a single run.
    run_parameters = dict(
        project_id=PROJECT_ID,
        bq_dataset=BQ_DATASET,
        bq_table_name=BQ_TABLE_NAME,
        summary_bq_table_name=BQ_SUMMARY_TABLE_NAME,
        output_gcs_dir=OUTPUT_GCS_DIR,
        roi_mu=ROI_MU,
        roi_sigma=ROI_SIGMA,
        n_chains=N_CHAINS,
        n_adapt=N_ADAPT,
        n_burnin=N_BURNIN,
        n_keep=N_KEEP,
        seed=RANDOM_SEED,
        report_start_date=REPORT_START_DATE,
        report_end_date=REPORT_END_DATE,
        summary_report_filename="summary_output.html",
        optimization_report_filename="optimization_output.html",
    )

    job = aip.PipelineJob(
        display_name=PIPELINE_NAME,
        template_path=PIPELINE_JSON,
        pipeline_root=PIPELINE_ROOT,
        parameter_values=run_parameters,
        enable_caching=True,  # set False to force every step to re-run
    )

    print("Submitting pipeline job...")
    job.submit()
    # NOTE(review): _dashboard_uri() is a private SDK method and may change between releases.
    print(f"Pipeline job submitted. View in Cloud Console: {job._dashboard_uri()}")

Pipeline compiled to meridian-mmm-gpu-bq-pipeline-v2.json
Initialized Vertex AI SDK for project cloud-llm-preview2 in us-central1
Submitting pipeline job...


INFO:google.cloud.aiplatform.pipeline_jobs:Creating PipelineJob
INFO:google.cloud.aiplatform.pipeline_jobs:PipelineJob created. Resource name: projects/323656405210/locations/us-central1/pipelineJobs/meridian-mmm-gpu-bq-pipeline-v2-20250402170850
INFO:google.cloud.aiplatform.pipeline_jobs:To use this PipelineJob in another session:
INFO:google.cloud.aiplatform.pipeline_jobs:pipeline_job = aiplatform.PipelineJob.get('projects/323656405210/locations/us-central1/pipelineJobs/meridian-mmm-gpu-bq-pipeline-v2-20250402170850')
INFO:google.cloud.aiplatform.pipeline_jobs:View Pipeline Job:
https://console.cloud.google.com/vertex-ai/locations/us-central1/pipelines/runs/meridian-mmm-gpu-bq-pipeline-v2-20250402170850?project=323656405210


Pipeline job submitted. View in Cloud Console: https://console.cloud.google.com/vertex-ai/locations/us-central1/pipelines/runs/meridian-mmm-gpu-bq-pipeline-v2-20250402170850?project=323656405210
