In [None]:
from google.cloud import storage
from vertexai.language_models import TextEmbeddingModel
import pandas as pd
import numpy as np
import csv
import os
import json
import logging
from google.api_core import retry
from google.api_core import exceptions

# Set up logging: INFO-level root configuration plus the module-level logger
# that the rest of this notebook writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# GCS client and handles. `bucket` is reused further down for the checkpoint
# and embeddings objects; `blob` points at the source dataset.
storage_client = storage.Client()
bucket_name = "mw-llm-poc-dataproc"
file_path = "merged_dataset.csv"
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(file_path)

# Local paths: the downloaded dataset, the generated embeddings CSV, and the
# JSON checkpoint file.
local_file = "/tmp/merged_dataset_1.csv"
embeddings_csv = "/tmp/embeddings.csv"
checkpoint_file = "/tmp/checkpoint.json"

# Configurable summary column name (the text column that gets embedded)
SUMMARY_COLUMN = "Human_Readable_Summary"

# Download file locally (overwrites local_file on every run)
logger.info(f"Downloading {file_path} from GCS to {local_file}")
blob.download_to_filename(local_file)

# Debug: Check column names by reading only the first 5 rows, and fail fast
# if the summary column is absent.
try:
    df = pd.read_csv(local_file, nrows=5)
    logger.info(f"Columns in {local_file}: {df.columns.tolist()}")
    logger.info(f"First few rows:\n{df.head()}")
    if SUMMARY_COLUMN not in df.columns:
        logger.error(f"Column '{SUMMARY_COLUMN}' not found. Available columns: {df.columns.tolist()}")
        raise KeyError(f"Column '{SUMMARY_COLUMN}' not found in {local_file}")
except Exception as e:
    # This handler also catches the KeyError raised just above, so a missing
    # column is logged twice before the exception propagates.
    logger.error(f"Error reading {local_file}: {e}")
    raise

# Initialize the embedding model used by get_embeddings()
logger.info("Initializing textembedding-gecko@003 model")
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@003")

# Function to generate embeddings
def get_embeddings(texts):
    """Embed a batch of texts with the module-level ``model``.

    Args:
        texts: The inputs handed unchanged to ``model.get_embeddings``.

    Returns:
        A list holding the ``values`` of each embedding object the model
        returned, or ``None`` if calling the model or reading its results
        raised. Failures are logged and never propagated.
    """
    vectors = []
    try:
        for result in model.get_embeddings(texts):
            vectors.append(result.values)
    except Exception as e:
        logger.error(f"Error generating embeddings: {e}")
        return None
    return vectors

# Checkpoint management
def load_checkpoint():
    """Return the pipeline checkpoint as a dict.

    Lookup order:
      1. the local ``checkpoint_file`` if it exists;
      2. otherwise ``checkpoint.json`` in the bucket, which is first
         downloaded to ``checkpoint_file``;
      3. otherwise a fresh checkpoint with zero rows processed and every
         step flag set to ``False``.
    """
    if not os.path.exists(checkpoint_file):
        remote_checkpoint = bucket.blob("checkpoint.json")
        if not remote_checkpoint.exists():
            # Nothing saved anywhere yet: start from a blank slate.
            fresh_checkpoint = {
                "step1_rows_processed": 0,
                "step1_complete": False,
                "step2_complete": False,
                "step3_complete": False,
                "step4_complete": False,
                "step5_complete": False
            }
            return fresh_checkpoint
        remote_checkpoint.download_to_filename(checkpoint_file)

    # Either the file was already present or it has just been downloaded.
    with open(checkpoint_file, "r") as handle:
        return json.load(handle)

def save_checkpoint(rows_processed, step1_complete=None, step2_complete=None, step3_complete=None, step4_complete=None, step5_complete=None):
    """Persist pipeline progress locally and, best-effort, to GCS.

    The current checkpoint is loaded, updated, written to ``checkpoint_file``
    and then uploaded as ``checkpoint.json`` to the bucket.

    Args:
        rows_processed: Stored as ``step1_rows_processed``.
        step1_complete .. step5_complete: New value for the matching flag.
            ``None`` (the default) leaves the stored flag untouched, so
            recording one step never clears another step's flag. Pass
            ``False`` explicitly to reset a flag.

    Upload failures are logged as warnings and do not raise; the local file
    remains the source of truth in that case.
    """
    checkpoint = load_checkpoint()
    checkpoint["step1_rows_processed"] = rows_processed

    requested_flags = {
        "step1_complete": step1_complete,
        "step2_complete": step2_complete,
        "step3_complete": step3_complete,
        "step4_complete": step4_complete,
        "step5_complete": step5_complete,
    }
    for flag_name, flag_value in requested_flags.items():
        if flag_value is None:
            # Not specified by the caller: keep what is stored, and only
            # fill in False when the key has never been written.
            checkpoint.setdefault(flag_name, False)
        else:
            checkpoint[flag_name] = flag_value

    # Write to a sibling temp file and rename it into place so an interrupted
    # write cannot leave a truncated checkpoint that breaks the next load.
    temp_path = f"{checkpoint_file}.tmp"
    with open(temp_path, "w") as f:
        json.dump(checkpoint, f)
    os.replace(temp_path, checkpoint_file)

    # Only rate-limit responses are retried; anything else falls through to
    # the warning below.
    @retry.Retry(predicate=retry.if_exception_type(exceptions.TooManyRequests))
    def upload_with_retry():
        checkpoint_blob = bucket.blob("checkpoint.json")
        checkpoint_blob.upload_from_filename(checkpoint_file)
    try:
        upload_with_retry()
    except Exception as e:
        logger.warning(f"Failed to upload checkpoint to GCS: {e}. Continuing with local checkpoint.")

# Step 1: Generate embeddings
# Tunables: rows read from the dataset per pandas chunk, and texts sent to the
# embedding model per request.
chunk_size = 10000
batch_size = 50  # Token limit safe

# Resume point recorded by a previous run (0 for a fresh checkpoint).
checkpoint = load_checkpoint()
rows_processed = total_rows_processed = checkpoint["step1_rows_processed"]

# Where, if anywhere, a previously generated embeddings file can be found.
embeddings_blob = bucket.blob("embeddings.csv")
embeddings_exists_locally = os.path.exists(embeddings_csv)
embeddings_exists_in_gcs = embeddings_blob.exists()

# A checkpoint that claims Step 1 is done is only believable if its output
# still exists somewhere; otherwise start Step 1 over. The reset is written to
# the local checkpoint file only.
embeddings_missing_everywhere = not embeddings_exists_locally and not embeddings_exists_in_gcs
if embeddings_missing_everywhere and checkpoint.get("step1_complete", False):
    logger.warning("Step 1 marked as complete, but embeddings file is missing both locally and in GCS. Resetting Step 1 checkpoint.")
    rows_processed = total_rows_processed = 0
    checkpoint["step1_complete"] = False
    checkpoint["step1_rows_processed"] = 0
    with open(checkpoint_file, "w") as f:
        json.dump(checkpoint, f)

# Upper bound on the number of rows Step 1 embeds before it stops.
MAX_ROWS = 630000


def _write_embedding_batch(writer, ids, summaries, failure_message):
    """Embed one batch and append exactly one CSV row per input.

    When embedding fails, every input is still written with an empty
    embedding field, so the number of rows in the output keeps matching the
    number of source rows counted in the checkpoint.

    Args:
        writer: ``csv.writer`` for the embeddings file.
        ids: Values of the ``id`` column for this batch.
        summaries: The texts to embed, aligned with ``ids``.
        failure_message: Warning logged when the batch cannot be embedded.

    Returns:
        The number of rows written, always ``len(summaries)``.
    """
    embeddings = get_embeddings(summaries)
    # A short result would make zip() drop rows silently while the row counter
    # still advances, so a length mismatch is handled like a failure.
    if embeddings and len(embeddings) == len(summaries):
        for id_val, summary, embedding in zip(ids, summaries, embeddings):
            writer.writerow([id_val, summary, ",".join(map(str, embedding))])
    else:
        logger.warning(failure_message)
        for id_val, summary in zip(ids, summaries):
            writer.writerow([id_val, summary, ""])
    return len(summaries)


if checkpoint.get("step1_complete", False):
    logger.info("Step 1 already complete for all 630,000 rows.")
    if not embeddings_exists_locally and embeddings_exists_in_gcs:
        logger.info(f"Downloading {embeddings_csv} from GCS since it’s missing locally")
        embeddings_blob.download_to_filename(embeddings_csv)
else:
    logger.info("Regenerating embeddings.")

    # Resuming appends to the local file. If that file is gone, appending
    # would create an empty, headerless file and the upload at the end would
    # replace the copy in GCS with it, so restore it or start over instead.
    if rows_processed > 0 and not os.path.exists(embeddings_csv):
        if embeddings_exists_in_gcs:
            # NOTE(review): assumes the GCS copy holds the rows counted in
            # the checkpoint - confirm if runs can be interrupted between a
            # checkpoint upload and the embeddings upload.
            logger.info(f"Resuming Step 1: downloading {embeddings_csv} from GCS since it is missing locally")
            embeddings_blob.download_to_filename(embeddings_csv)
        else:
            logger.warning(
                f"Checkpoint reports {rows_processed} rows processed but no embeddings file exists "
                "locally or in GCS. Restarting Step 1 from the first row."
            )
            rows_processed = 0

    rows_at_start = rows_processed
    mode = "a" if rows_processed > 0 else "w"
    with open(embeddings_csv, mode, newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL, lineterminator="\n")
        if rows_processed == 0:
            writer.writerow(["id", SUMMARY_COLUMN, "embedding"])

        # Skip the data rows already embedded; line 0 (the header) is kept.
        skiprows = range(1, rows_processed + 1) if rows_processed > 0 else None
        data_chunks = pd.read_csv(local_file, chunksize=chunk_size, skiprows=skiprows, header=0)
        total_rows_processed = rows_processed
        for chunk_idx, chunk in enumerate(data_chunks):
            logger.info(f"Processing chunk {chunk_idx + 1} with {len(chunk)} rows")
            logger.info(f"Chunk columns: {chunk.columns.tolist()}")
            if SUMMARY_COLUMN not in chunk.columns:
                logger.error(f"Column '{SUMMARY_COLUMN}' not found in chunk. Available columns: {chunk.columns.tolist()}")
                raise KeyError(f"Column '{SUMMARY_COLUMN}' not found in chunk")

            chunk_ids = chunk["id"].tolist()
            chunk_summaries = [str(value) for value in chunk[SUMMARY_COLUMN].tolist()]
            for start in range(0, len(chunk_ids), batch_size):
                batch_ids = chunk_ids[start:start + batch_size]
                batch_summaries = chunk_summaries[start:start + batch_size]
                if len(batch_summaries) < batch_size:
                    failure_message = "Skipping remaining batch due to embedding error"
                else:
                    failure_message = "Skipping batch due to embedding error"
                total_rows_processed += _write_embedding_batch(
                    writer, batch_ids, batch_summaries, failure_message
                )

            # Hand buffered rows to the OS before recording them as done, so
            # the checkpoint never counts rows still sitting in the buffer.
            f.flush()
            save_checkpoint(total_rows_processed, step1_complete=False)
            logger.info(f"Total rows processed so far: {total_rows_processed}")
            if total_rows_processed >= MAX_ROWS:
                break

        f.flush()
        if total_rows_processed > rows_at_start:
            # New rows were written, so the BigQuery load in Step 2 must run
            # again on the updated file.
            save_checkpoint(total_rows_processed, step1_complete=True, step2_complete=False)
        else:
            save_checkpoint(total_rows_processed, step1_complete=True)
    logger.info("Step 1 complete: Embeddings generated and saved to local CSV.")
    # Upload embeddings to GCS for persistence
    embeddings_blob = bucket.blob("embeddings.csv")
    embeddings_blob.upload_from_filename(embeddings_csv)
    logger.info(f"Uploaded {embeddings_csv} to gs://{bucket_name}/embeddings.csv for persistence")

INFO:__main__:Downloading merged_dataset.csv from GCS to /tmp/merged_dataset_1.csv
INFO:__main__:Columns in /tmp/merged_dataset_1.csv: ['id', 'Human_Readable_Summary']
INFO:__main__:First few rows:
   id                             Human_Readable_Summary
0   1  During the period of 2024 June to 2024 July, i...
1   2  During the period of 2024 June to 2024 July, i...
2   3  During the period of 2024 June to 2024 July, i...
3   4  During the period of 2024 June to 2024 July, i...
4   5  During the period of 2024 June to 2024 July, i...
INFO:__main__:Initializing textembedding-gecko@003 model
INFO:__main__:Regenerating embeddings.
INFO:__main__:Processing chunk 1 with 0 rows
INFO:__main__:Chunk columns: ['id', 'Human_Readable_Summary']
INFO:__main__:Total rows processed so far: 625671
INFO:__main__:Step 1 complete: Embeddings generated and saved to local CSV.
INFO:__main__:Uploaded /tmp/embeddings.csv to gs://mw-llm-poc-dataproc/embeddings.csv for persistence


In [None]:
# Step 2: Load embeddings into BigQuery
from google.cloud import bigquery

client = bigquery.Client()
project_id = "856598595188"
dataset_id = "mw_llm_poc"
table_id = "embeddings_final_table"
table_ref = f"{project_id}.{dataset_id}.{table_id}"

# Re-read the checkpoint: the copy held in memory was loaded before Step 1 ran
# and does not reflect anything Step 1 saved since.
checkpoint = load_checkpoint()

if checkpoint.get("step2_complete", False):
    logger.info("Step 2 already complete based on checkpoint.")
else:
    # Ensure embeddings file exists before proceeding
    if not os.path.exists(embeddings_csv):
        if embeddings_blob.exists():
            logger.info(f"Downloading {embeddings_csv} from GCS since it’s missing locally")
            embeddings_blob.download_to_filename(embeddings_csv)
        else:
            raise FileNotFoundError(f"Embeddings file {embeddings_csv} not found locally or in GCS. Cannot proceed with Step 2.")

    logger.info(f"Creating dataset {dataset_id} if it doesn’t exist")
    dataset = bigquery.Dataset(f"{project_id}.{dataset_id}")
    dataset = client.create_dataset(dataset, exists_ok=True)

    # The embedding is loaded as the comma-joined string written by Step 1
    # and converted to an array by the query further down.
    schema = [
        bigquery.SchemaField("id", "INTEGER", mode="REQUIRED"),
        bigquery.SchemaField(SUMMARY_COLUMN, "STRING", mode="REQUIRED"),
        bigquery.SchemaField("embedding", "STRING", mode="NULLABLE"),
    ]
    logger.info(f"Creating table {table_id}")
    table = bigquery.Table(table_ref, schema=schema)
    table = client.create_table(table, exists_ok=True)

    # The load job reads from GCS, so push the local file there first.
    gcs_path = f"gs://{bucket_name}/embeddings.csv"
    logger.info(f"Uploading {embeddings_csv} to {gcs_path}")
    blob = bucket.blob("embeddings.csv")
    blob.upload_from_filename(embeddings_csv)

    logger.info(f"Loading {gcs_path} into {table_ref}")
    job_config = bigquery.LoadJobConfig(
        source_format=bigquery.SourceFormat.CSV,
        skip_leading_rows=1,
        schema=schema,
        # Replace any previous contents so re-running the step is idempotent.
        write_disposition="WRITE_TRUNCATE",
        allow_quoted_newlines=True,
    )
    load_job = client.load_table_from_uri(gcs_path, table_ref, job_config=job_config)
    load_job.result()
    if load_job.errors:
        logger.error(f"Load job failed: {load_job.errors}")
        raise Exception(f"Load job failed: {load_job.errors}")
    else:
        logger.info("Embeddings loaded into BigQuery.")

    logger.info("Converting embedding column to ARRAY<FLOAT64>")
    # SPLIT alone yields ARRAY<STRING>; each element is cast so the column
    # really is ARRAY<FLOAT64>. WITH OFFSET / ORDER BY keeps the original
    # element order inside the rebuilt array.
    query = f"""
        CREATE OR REPLACE TABLE `{table_ref}` AS
        SELECT
            id,
            {SUMMARY_COLUMN},
            CASE
                WHEN embedding IS NOT NULL AND embedding != ''
                THEN ARRAY(
                    SELECT CAST(part AS FLOAT64)
                    FROM UNNEST(SPLIT(embedding, ',')) AS part WITH OFFSET pos
                    ORDER BY pos
                )
                ELSE NULL
            END AS embedding
        FROM `{table_ref}`
    """
    query_job = client.query(query)
    query_job.result()
    # Pass step1_complete explicitly so that recording Step 2 never clears
    # the Step 1 flag.
    save_checkpoint(total_rows_processed, step1_complete=True, step2_complete=True)
    logger.info("Step 2 complete: Embeddings stored in BigQuery.")

INFO:__main__:Step 2 already complete based on checkpoint.


In [None]:
#step3: Retrieving Similar Summaries
def retrieve_similar_summaries(query_text, top_k=5):
    """Return the stored summaries whose embeddings best match a query.

    The query text is embedded with ``get_embeddings`` and scored in BigQuery
    against every row of ``table_ref`` by summing the element-wise products
    of the two vectors (a dot product). Rows whose embedding length differs
    from the query vector's are excluded.

    Args:
        query_text: Natural-language query to embed.
        top_k: Maximum number of rows to return; must convert to an integer.

    Returns:
        A list of ``{"id": ..., "summary": ...}`` dicts ordered by descending
        dot product. An empty list is returned when ``top_k`` is invalid or
        not positive, when the query cannot be embedded, or when the BigQuery
        query fails; the cause is logged.
    """
    # top_k is interpolated into the SQL text, so make sure only an integer
    # can reach it.
    try:
        top_k = int(top_k)
    except (TypeError, ValueError):
        logger.error("Invalid top_k value: %r", top_k)
        return []
    if top_k <= 0:
        return []

    query_embedding = get_embeddings([query_text])
    # Covers both None (embedding call failed) and an empty result list,
    # which would otherwise raise IndexError on [0] below.
    if not query_embedding:
        logger.error("Failed to generate embedding for query: %s", query_text)
        return []

    query_embedding_str = ', '.join(map(str, query_embedding[0]))
    logger.info("Generated query embedding with length: %d", len(query_embedding[0]))

    # Each stored element is cast to FLOAT64 before multiplying, so the query
    # works whether the column holds strings or floats.
    similarity_query = f"""
        SELECT
            t.id,
            t.{SUMMARY_COLUMN},
            SUM(CAST(t.embedding[offset(i)] AS FLOAT64) * q_emb) AS dot_product
        FROM `{table_ref}` t
        CROSS JOIN UNNEST([{query_embedding_str}]) AS q_emb WITH OFFSET i
        WHERE ARRAY_LENGTH(t.embedding) = {len(query_embedding[0])}
        GROUP BY t.id, t.{SUMMARY_COLUMN}
        ORDER BY dot_product DESC
        LIMIT {top_k}
    """

    try:
        logger.info("Executing similarity query: %s", query_text)
        query_job = client.query(similarity_query)
        results = query_job.result()
        retrieved = [{"id": row["id"], "summary": row[SUMMARY_COLUMN]} for row in results]
        logger.info("Retrieved %d summaries", len(retrieved))
        return retrieved
    except Exception as e:
        logger.error("Error executing similarity query: %s", str(e))
        return []

# Test Step 3: run one retrieval and print what came back.
# Note: `query` rebinds the module-level name that Step 2 used for its SQL
# text; both `query` and `top_summaries` are read again by the Step 4 test.
query = "If Comcast had served 2.5 million households instead, what would its market share have been?"
top_summaries = retrieve_similar_summaries(query, top_k=5)
# Prints nothing when retrieval failed, since an empty list is returned then.
for summary in top_summaries:
    print(f"ID: {summary['id']}, Summary: {summary['summary']}")
logger.info("Step 3 complete: Retrieval implemented.")

INFO:__main__:Generated query embedding with length: 768
INFO:__main__:Executing similarity query: If Comcast had served 2.5 million households instead, what would its market share have been?
INFO:__main__:Retrieved 5 summaries
INFO:__main__:Step 3 complete: Retrieval implemented.


ID: 202992, Summary: In the USA, during 2024, July, Comcast providing Fixed Broadband services served 38,458,487 households out of a total of 136,378,955 in the country, resulting in a market share of 28.2%.
ID: 201822, Summary: In the state of North Carolina (NC) in USA, during 2024, July, Comcast providing Fixed Broadband services served 19,418 households out of a total of 4,780,848 in the country, resulting in a market share of 0.41%.
ID: 201414, Summary: In the state of Maryland (MD) in USA, during 2024, July, Comcast providing Fixed Broadband services served 1,532,831 households out of a total of 2,812,665 in the country, resulting in a market share of 54.50%.
ID: 201558, Summary: In the state of New York (NY) in USA, during 2024, July, Comcast providing Fixed Broadband services served 111,597 households out of a total of 7,137,002 in the country, resulting in a market share of 1.56%.
ID: 201392, Summary: In the state of California (CA) in USA, during 2024, July, Comcast providing

In [None]:
#Step4: Generating Context-Aware Responses Using Gemini 2.0 Flash
from vertexai.preview.generative_models import GenerativeModel

# Shared model handle used by generate_response().
gen_model = GenerativeModel("gemini-2.0-flash")


def generate_response(query, retrieved_summaries):
    """Answer ``query`` with the generative model, grounded in summaries.

    Args:
        query: The user's question; inserted into the prompt as-is.
        retrieved_summaries: Iterable of dicts with a ``"summary"`` key. The
            summaries are joined with newlines to form the prompt context.

    Returns:
        The model's response text, or the literal string
        ``"Error generating response"`` if the model call or reading its
        text raised (the error is logged).
    """
    summary_lines = []
    for item in retrieved_summaries:
        summary_lines.append(item["summary"])
    context = "\n".join(summary_lines)

    prompt = f"""You are a professional data analyst providing concise and accurate answers based on the given context. Follow these rules:
1. Answer the query directly in a single, complete sentence, focusing only on the specific information requested (e.g., if the query asks for the number of households, provide the number in a sentence; if it asks for market share, provide the market share in a sentence; if it asks for flow share, provide the flow share in a sentence).
2. Do not include extraneous details not directly relevant to the query (e.g., do not include zip code-level data for a state-level query, or market share for a query about households).
3. If the requested information is not available in the context, respond with: "The available data does not include information on [requested information] for [location/time]."
4. Keep the response simple and professional, without emojis, bold text, or unnecessary formatting.
5. Use the context provided to extract the most relevant information for the query, ensuring the response includes the time period, location, and provider mentioned in the query.

Query: {query}
Context: {context}
Answer based on the context:
"""
    logger.info(f"Generating response for query: {query}")
    try:
        # Reading .text stays inside the try: it is part of what may fail.
        answer = gen_model.generate_content(prompt).text
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "Error generating response"
    return answer

# Test Step 4: answer the Step 3 test query using the summaries it retrieved.
# `query` and `top_summaries` are the module-level values set by that test.
response = generate_response(query, top_summaries)
print(f"Generated Response: {response}")
logger.info("Step 4 complete: Generation implemented.")

INFO:__main__:Generating response for query: If Comcast had served 2.5 million households instead, what would its market share have been?
INFO:__main__:Step 4 complete: Generation implemented.


Generated Response: The available data does not include information on Comcast's market share in July 2024 if it had served 2.5 million households.

