# In [None]:  (Jupyter cell marker left over from the notebook export)
#!/usr/bin/env python3

import os
import sqlite3
import argparse
import time
import pandas as pd
from tqdm import tqdm
from typing import List

import google_genai  # Google GenAI SDK for Gemini
from google_genai import client as genai_client

# -------------------- Utils & API wrapper --------------------

def init_gemini_client(model: str = "gemini-2.0-flash", location: str = "global"):
    """Build the Gemini/GenAI client used for query generation.

    Args:
        model: Model name forwarded to the client constructor.
        location: Location forwarded to the client constructor.

    Returns:
        Whatever ``genai_client.Client`` constructs for these settings.
    """
    # No credentials are passed explicitly: they are expected to come from
    # environment variables or Application Default Credentials (ADC).
    # NOTE(review): the official Google Gen AI SDK is normally imported as
    # `from google import genai`, and its Client is not known to accept a
    # `model` keyword -- confirm that `google_genai.client.Client` exists with
    # this signature in the target environment.
    client_options = {"model": model, "location": location}
    return genai_client.Client(**client_options)


def ask_gemini_generate_queries(client, chunk: str, num_queries: int) -> List[str]:
    """Ask Gemini to generate up to `num_queries` queries for the given chunk.

    The model reply is treated as newline-separated text. Blank lines are
    dropped and a leading list marker ("1. ", "2) ", "- ", "* ", bullet) is
    removed from each line, so the returned strings contain only query text.

    Args:
        client: Object exposing ``generate_content(prompt=...)`` whose result
            has a ``text`` attribute (string or ``None``).
        chunk: Document chunk embedded verbatim in the prompt.
        num_queries: Maximum number of queries to return.

    Returns:
        At most `num_queries` non-empty query strings, in reply order. Empty
        when `num_queries` is not positive or the reply carries no text.
    """
    import re  # local import keeps this block self-contained

    # Nothing was requested: skip the API call entirely. This also avoids the
    # surprising "all but the last n" slice a negative count would produce.
    if num_queries <= 0:
        return []

    prompt = (
        f"Here is a document chunk:\n\"\"\"\n{chunk}\n\"\"\"\n"
        f"Generate {num_queries} unique user queries that would "
        f"retrieve or refer to this chunk in a search/retrieval system."
    )
    response = client.generate_content(prompt=prompt)
    # Guard against a reply without text (None) instead of raising
    # AttributeError on `.strip()`.
    text = (response.text or "").strip()

    # The marker must be followed by whitespace, so queries that merely start
    # with a number ("3.5 inch drives", "2024 report") are left untouched.
    list_marker = re.compile(r"^(?:\d+[.)]|[-*\u2022])\s+")
    cleaned = (list_marker.sub("", line.strip()).strip() for line in text.split("\n"))
    queries = [query for query in cleaned if query]

    # The model may return more lines than requested; keep the first ones.
    return queries[:num_queries]


# -------------------- Main workflow --------------------

def main(args):
    """Generate query/chunk pairs for every type-'C' chunk and save them as CSV.

    Reads ``(id, content)`` rows from ``chunks_table`` in the sqlite database
    at ``args.db_path``, asks Gemini for ``args.num_queries_per_chunk`` queries
    per chunk, and writes one ``query,chunk`` row per generated query to
    ``args.output_csv``.

    Args:
        args: Parsed CLI namespace providing ``db_path``, ``batch_size``,
            ``num_queries_per_chunk``, ``gemini_model``, ``location``,
            ``output_csv`` and ``sleep_seconds``.

    Raises:
        ValueError: If ``args.batch_size`` is not a positive integer.
    """
    # A zero step makes range() raise and a negative one silently processes
    # nothing, so reject both up front with a clear message.
    if args.batch_size <= 0:
        raise ValueError("--batch_size must be a positive integer")

    # Connect to ChromaDB (sqlite) or adjust to your backend. All rows are
    # fetched eagerly, so the connection is released right away and is closed
    # even if the query fails.
    conn = sqlite3.connect(args.db_path)
    try:
        cursor = conn.cursor()
        # Query only chunk rows
        cursor.execute("SELECT id, content FROM chunks_table WHERE type = 'C'")
        rows = cursor.fetchall()
    finally:
        conn.close()
    print(f"Found {len(rows)} chunks (type='C')")

    gemini = init_gemini_client(model=args.gemini_model, location=args.location)

    # Collect plain records and build the DataFrame once at the end:
    # DataFrame.append was removed in pandas 2.0 and was quadratic anyway.
    records = []
    total = len(rows)
    for start in tqdm(range(0, total, args.batch_size), desc="Processing chunk-batches"):
        batch = rows[start : start + args.batch_size]
        for _chunk_id, chunk_content in batch:
            queries = ask_gemini_generate_queries(gemini, chunk_content, args.num_queries_per_chunk)
            records.extend({"query": q, "chunk": chunk_content} for q in queries)
        # Pause between batches to respect API rate limits; there is nothing
        # to wait for after the final batch.
        if start + args.batch_size < total:
            time.sleep(args.sleep_seconds)

    # Explicit columns keep the CSV header present even when no pairs exist.
    pairs_df = pd.DataFrame(records, columns=["query", "chunk"])

    # Persist to CSV
    pairs_df.to_csv(args.output_csv, index=False)
    print(f"Saved {len(pairs_df)} query-chunk pairs to {args.output_csv}")


if __name__ == "__main__":
    # Command-line entry point: declare the options as data, register them in
    # order, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(description="Automate query generation for chunks using Gemini API")
    option_specs = (
        ("--db_path", {"type": str, "required": True, "help": "Path to ChromaDB sqlite DB"}),
        ("--batch_size", {"type": int, "default": 10, "help": "Number of chunks per API batch"}),
        (
            "--num_queries_per_chunk",
            {"type": int, "default": 20, "help": "Number of queries to generate per chunk"},
        ),
        ("--gemini_model", {"type": str, "default": "gemini-2.0-flash", "help": "Gemini model to use"}),
        ("--location", {"type": str, "default": "global", "help": "Cloud location for Gemini API"}),
        ("--output_csv", {"type": str, "default": "generated_pairs.csv", "help": "Output CSV path"}),
        (
            "--sleep_seconds",
            {"type": float, "default": 1.0, "help": "Sleep seconds between batches (API rate control)"},
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    main(args)