**Install dependency libraries**

In [1]:
!pip install -q google-adk google-genai


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


**Configure Gemini API Key**

In [15]:
import os
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "FALSE")
os.environ["GOOGLE_API_KEY"] = 'GOOGLE-API-KEY'

**Import necessary libraries**

In [3]:
import asyncio
import re
from typing import Optional, List, Dict, Any

from google.adk.agents import LlmAgent, SequentialAgent
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types

**Configuring Constants**

In [4]:
APP_NAME = "sql_to_pyspark_app"
USER_ID = "demo_user"
SESSION_ID = "sql_to_pyspark_session"

**SQL Normalizer Agent** - Clean and standardize the SQL code

In [5]:
sql_normalizer_agent = LlmAgent(
    name="SQL_Normalizer",
    model='gemini-2.5-flash-lite',
    instruction=(
        "You are a SQL normalization assistant.\n"
        "- Input: raw SQL query from the user.\n"
        "- Output: the same query, but formatted and canonicalized.\n"
        "- Do NOT explain anything. Return ONLY the SQL text.\n"
    ),
    description="Normalizes user SQL into a clean canonical form.",
    output_key="normalized_sql",
)

**PySpark Conversion Agent** - Converts SQL into PySpark DataFrames

In [6]:
sql_to_pyspark_agent = LlmAgent(
    name="SQL_to_PySpark_Converter",
    model='gemini-2.5-flash-lite',
    instruction=(
        "You convert SQL queries into equivalent PySpark DataFrame code.\n\n"
        "Input:\n"
        "  A normalized SQL query is available as {normalized_sql} in the state.\n\n"
        "General rules:\n"
        "  - Always include necessary imports at the top:\n"
        "        from pyspark.sql import functions as F\n"
        "        from pyspark.sql import SparkSession\n\n"
        "  - Assume a SparkSession named 'spark' already exists.\n"
        "  - Use spark.table(\"<db.table>\") for each base table in FROM/JOIN.\n"
        "  - Translate WHERE, SELECT, GROUP BY, HAVING, ORDER BY, JOIN into PySpark DataFrame API.\n"
        "  - Prefer method chaining on DataFrames.\n"
        "  - Use variable name 'final_df' for the final resulting DataFrame.\n\n"
        "Conditional logic (IMPORTANT):\n"
        "  - For SQL CASE WHEN / THEN / ELSE expressions, map them to PySpark using F.when / .otherwise.\n"
        "    Example:\n"
        "       CASE WHEN status = 'ACTIVE' THEN 1 ELSE 0 END AS is_active\n"
        "    should become something like:\n"
        "       df = df.withColumn(\n"
        "           'is_active',\n"
        "           F.when(F.col('status') == 'ACTIVE', F.lit(1)).otherwise(F.lit(0))\n"
        "       )\n\n"
        "  - For multiple WHEN branches, chain F.when calls:\n"
        "       CASE\n"
        "         WHEN score >= 90 THEN 'A'\n"
        "         WHEN score >= 75 THEN 'B'\n"
        "         ELSE 'C'\n"
        "       END AS grade\n"
        "    ->\n"
        "       df = df.withColumn(\n"
        "           'grade',\n"
        "           F.when(F.col('score') >= 90, F.lit('A'))\n"
        "            .when(F.col('score') >= 75, F.lit('B'))\n"
        "            .otherwise(F.lit('C'))\n"
        "       )\n\n"
        "  - For SQL IF(condition, a, b) expressions, also use F.when(condition, a).otherwise(b).\n\n"
        "Output format:\n"
        "  - Return ONLY valid Python code in a single ```python ... ``` block.\n"
        "  - No extra explanation or markdown outside the code block.\n"
    ),
    description="Converts normalized SQL into PySpark DataFrame code, including CASE using when/otherwise.",
    output_key="pyspark_code",
)

**Sequential Agent** - for ordered exexution of Normalizer -> Converter

In [7]:
sql_to_pyspark_pipeline = SequentialAgent(
    name="SQL_to_PySpark_Pipeline",
    sub_agents=[sql_normalizer_agent, sql_to_pyspark_agent],
    description="Pipeline: normalize SQL, then convert to PySpark.",
)

**Initialize Memory Service**

In [8]:
#Create Service Session
session_service = InMemorySessionService()

In [9]:
session = await session_service.create_session(
    app_name=APP_NAME,
    user_id=USER_ID,
    session_id=SESSION_ID,
)

#Create Runner
runner = Runner(
    agent=sql_to_pyspark_pipeline,
    app_name=APP_NAME,
    session_service=session_service,
)

In [10]:
def convert_sql_to_pyspark(sql_query: str) -> str:
    """
    Run the sequential agent pipeline on the given SQL and return PySpark code.
    """
    content = types.Content(
        role="user",
        parts=[types.Part(text=sql_query)],
    )

    final_text = ''

    for event in runner.run(
        user_id=USER_ID,
        session_id=SESSION_ID,
        new_message=content,
    ):
        if event.is_final_response():
            # The final response comes from the last sub-agent in the sequence
            if event.content and event.content.parts:
                final_text = event.content.parts[0].text

    if not final_text:
        raise RuntimeError("Agent did not return any final text response.")

    # Strip code fences if present
    if "```" in final_text:
        lines = []
        for line in final_text.splitlines():
            if line.strip().startswith("```"):
                continue
            lines.append(line)
        final_text = "\n".join(lines).strip()

    return final_text

In [11]:
def _evaluate_pyspark_rule_based(sql_query: str, pyspark_code: str, must_contain: List[str], nice_to_have: List[str] | None = None,) -> Dict[str, Any]:
    """
    - must_contain: list of substrings that MUST appear in the PySpark code
    - nice_to_have: optional list of substrings that give bonus score if present
    """
    if nice_to_have is None:
        nice_to_have = []

    missing_required = [pattern for pattern in must_contain if pattern not in pyspark_code]
    matched_required = len(must_contain) - len(missing_required)

    matched_optional = [pattern for pattern in nice_to_have if pattern in pyspark_code]

    # Simple scoring:
    # - semantic_correctness ≈ how many required patterns matched
    # - syntactic_validity: crude check that we see "df =" and "spark.table("
    # - readability: crude check for chaining / withColumn usage
    
    total_required = max(len(must_contain), 1)
    semantic_score = int(10 * matched_required / total_required)

    syntactic_score = 0
    if "spark.table(" in pyspark_code:
        syntactic_score = 7
    if "spark.table(" in pyspark_code and "groupBy(" in pyspark_code:
        syntactic_score = 9

    readability_score = 0
    if ".groupBy(" in pyspark_code or ".withColumn(" in pyspark_code:
        readability_score = 7
    if ".groupBy(" in pyspark_code and ".agg(" in pyspark_code:
        readability_score = 9

    # Small bonus for nice-to-have patterns
    bonus = min(len(matched_optional), 2)
    overall_score = max(0, min(10, int((semantic_score + syntactic_score + readability_score) / 3) + bonus))

    comments = []
    if missing_required:
        comments.append(f"Missing expected patterns: {missing_required}")
    if matched_optional:
        comments.append(f"Good: found optional patterns {matched_optional}")
    if not comments:
        comments.append("Looks structurally OK based on rule-based checks.")

    return {
        "semantic_correctness": semantic_score,
        "syntactic_validity": syntactic_score,
        "readability": readability_score,
        "overall_score": overall_score,
        "comments": " | ".join(comments),
    }

In [12]:
def run_sql_to_pyspark_tests(test_cases: List[Dict[str, Any]],) -> List[Dict[str, Any]]:
    """
    Run SQL → PySpark → rule-based evaluation for one or many test cases.

    test_cases example:
        [
            {
                "name": "case1",
                "sql": "...",
                "must_contain": ["groupBy(", ".agg("],
                "nice_to_have": ["F.when("],
            },
            ...
        ]
    """
    results: List[Dict[str, Any]] = []

    for case in test_cases:
        name = case.get("name", "unnamed_case")
        sql = case["sql"]
        must_contain = case.get("must_contain", [])
        nice_to_have = case.get("nice_to_have", [])

        # 1) Convert SQL → PySpark via your main agent
        pyspark_code = convert_sql_to_pyspark(sql)

        # 2) Evaluate using rule-based checks ONLY (no Gemini)
        evaluation = _evaluate_pyspark_rule_based(
            sql_query=sql,
            pyspark_code=pyspark_code,
            must_contain=must_contain,
            nice_to_have=nice_to_have,
        )

        results.append(
            {
                "name": name,
                "sql": sql,
                "pyspark_code": pyspark_code,
                "evaluation": evaluation,
            }
        )

    return results

**Test Cases** for evaluation

In [13]:
test_cases = [
    {
        "name": "simple_where",
        "sql": """
            SELECT
                id,
                name,
                country
            FROM sales.customers
            WHERE country = 'IN';
        """,
        "must_contain": [
            "spark.table(\"sales.customers\")",
            ".filter(",
        ],
        "nice_to_have": [
            "F.col(\"country\")",
        ],
    },
    {
        "name": "group_by_agg_having",
        "sql": """
            SELECT
                country,
                COUNT(*) AS customer_count,
                AVG(age) AS avg_age
            FROM sales.customers
            WHERE status = 'ACTIVE'
            GROUP BY country
            HAVING COUNT(*) >= 100
            ORDER BY customer_count DESC;
        """,
        "must_contain": [
            "spark.table(\"sales.customers\")",
            ".filter(",
            ".groupBy(",
            ".agg(",
            ".orderBy(",
        ],
        "nice_to_have": [
            "F.count(",
            "F.avg(",
        ],
    },
    {
        "name": "join_inner_group_by",
        "sql": """
            SELECT
                c.country,
                COUNT(DISTINCT c.id) AS active_customers,
                SUM(o.amount) AS total_amount
            FROM sales.customers c
            JOIN sales.orders o
              ON c.id = o.customer_id
            WHERE o.status = 'COMPLETED'
              AND o.order_date >= DATE '2024-01-01'
            GROUP BY c.country
            HAVING SUM(o.amount) > 100000
            ORDER BY total_amount DESC;
        """,
        "must_contain": [
            "spark.table(\"sales.customers\")",
            "spark.table(\"sales.orders\")",
            ".join(",
            ".groupBy(",
            ".agg(",
        ],
        "nice_to_have": [
            "F.sum(",
            "F.countDistinct(",
        ],
    },
    {
        "name": "case_when_flag",
        "sql": """
            SELECT
                id,
                status,
                CASE
                    WHEN status = 'ACTIVE' THEN 1
                    ELSE 0
                END AS is_active
            FROM sales.customers;
        """,
        "must_contain": [
            "spark.table(\"sales.customers\")",
            ".withColumn(",
            "F.when(",
            ".otherwise(",
        ],
        "nice_to_have": [
            "F.col(\"status\")",
        ],
    },
    {
        "name": "case_when_multi_segment",
        "sql": """
            SELECT
                customer_id,
                CASE
                    WHEN total_amount >= 10000 THEN 'HIGH'
                    WHEN total_amount >= 5000 THEN 'MEDIUM'
                    ELSE 'LOW'
                END AS customer_segment
            FROM sales.customer_agg;
        """,
        "must_contain": [
            "spark.table(\"sales.customer_agg\")",
            ".withColumn(",
            "F.when(",
            ".otherwise(",
        ],
        "nice_to_have": [
            ".when(",
            "F.col(\"total_amount\")",
        ],
    },
]

***Run the tests***

In [14]:
all_results = run_sql_to_pyspark_tests(test_cases)

for res in all_results:
    print("=" * 80)
    print("Test case:", res["name"])
    print("\nGenerated PySpark Code:\n", res["pyspark_code"])
    print("\nEvaluation:\n", res["evaluation"])


Test case: simple_where

Generated PySpark Code:
 from pyspark.sql import functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("sql_to_pyspark").getOrCreate()

customers_df = spark.table("sales.customers")

final_df = customers_df.filter(F.col("country") == "IN") \
    .select("id", "name", "country")

final_df.show()

Evaluation:
 {'semantic_correctness': 10, 'syntactic_validity': 7, 'readability': 0, 'overall_score': 6, 'comments': 'Good: found optional patterns [\'F.col("country")\']'}
Test case: group_by_agg_having

Generated PySpark Code:
 from pyspark.sql import functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("sql_to_pyspark").getOrCreate()

customers_df = spark.table("sales.customers")

final_df = customers_df.filter(F.col("status") == "ACTIVE") \
    .groupBy("country") \
    .agg(
        F.count("*").alias("customer_count"),
        F.avg("age").alias("avg_age")
    ) \
    .filter(F.col("customer_