In [0]:
%run ./0.Config

In [0]:
def shortlist_top_vendors(_: str) -> str:
    """
    Rank vendors per route by a weighted KPI score and persist the top 3.

    Reads the quoted routes (processed/top_vendors.csv) and the vendor KPI
    sheet (vendor_data/vendor_data.csv) from ADLS, inner-joins them on
    "Vendor Email", and scores every row on a 0-100 scale from min-max
    normalised KPIs:

        40%  Total Quoted Cost   (lower is better, so its scale is inverted)
        30%  On-time Delivery
        10%  Service Rating
        10%  Safety Compliance
        10%  route_count         (distinct Route IDs the vendor appears on)

    Within each "Route ID", rows are ordered by descending "Final Score" and
    ranked with ``rank()``; rows with Rank <= 3 are kept. ``rank()`` gives
    tied scores the same rank, so a route can keep more than three rows.

    The result overwrites both the CSV output in ADLS
    (processed/Shortlisted_vendors) and the Unity Catalog table
    genai_catalog.genai_schema.shortlisted_vendors (spaces in column names
    replaced by underscores for the table).

    Args:
        _: Ignored. Present only so the function matches the single-string
           input the LangChain ``Tool`` passes from the ZeroShotAgent.

    Returns:
        A confirmation message naming the Unity Catalog table.
    """
    base_uri = "abfss://vendor-rfq@genaiautomationsa.dfs.core.windows.net"

    top_vendors = spark.read.option("header", True).option("inferSchema", "true").csv(
        f"{base_uri}/processed/top_vendors.csv")
    vendor_data = spark.read.option("header", True).option("inferSchema", "true").csv(
        f"{base_uri}/vendor_data/vendor_data.csv")

    top_vendors = top_vendors.withColumn("Total Quoted Cost", col("Total Quoted Cost").cast("double"))
    vendor_data = vendor_data.withColumn("On-time Delivery", col("On-time Delivery").cast("double")) \
                             .withColumn("Service Rating", col("Service Rating").cast("double")) \
                             .withColumn("Safety Compliance", col("Safety Compliance").cast("double"))

    joined_df = top_vendors.join(vendor_data, on="Vendor Email", how="inner")
    route_counts = joined_df.groupBy("Vendor Email").agg(countDistinct("Route ID").alias("route_count"))
    joined_df = joined_df.join(route_counts, on="Vendor Email", how="left")

    def normalize(df, colname, new_colname, min_val, max_val, inverse=False):
        """Add `new_colname`: `colname` min-max scaled to [0, 1], flipped when `inverse`.

        A constant column (max == min) has no spread to scale, so its non-null
        rows get 1.0 instead of a division by zero (null with ANSI mode off,
        a runtime error with it on). A column with no non-null values yields
        nulls, as before.
        """
        if min_val is None or max_val is None:
            return df.withColumn(new_colname, F.lit(None).cast("double"))
        if max_val == min_val:
            return df.withColumn(new_colname, F.when(col(colname).isNotNull(), F.lit(1.0)))
        if inverse:
            numerator = F.lit(max_val) - col(colname)
        else:
            numerator = col(colname) - F.lit(min_val)
        return df.withColumn(new_colname, numerator / (F.lit(max_val) - F.lit(min_val)))

    # (source column, normalised column, lower-is-better)
    kpis = [
        ("Total Quoted Cost", "norm_cost", True),
        ("On-time Delivery", "norm_ontime", False),
        ("Service Rating", "norm_rating", False),
        ("Safety Compliance", "norm_safety", False),
        ("route_count", "norm_routes", False),
    ]
    # One aggregation job collects every bound, rather than two jobs per KPI
    # (each of which would re-read and re-join both CSVs).
    bounds = joined_df.agg(
        *[F.min(col(src)).alias(f"{dst}_min") for src, dst, _inv in kpis],
        *[F.max(col(src)).alias(f"{dst}_max") for src, dst, _inv in kpis],
    ).first()
    for src, dst, inverse in kpis:
        joined_df = normalize(
            joined_df, src, dst, bounds[f"{dst}_min"], bounds[f"{dst}_max"], inverse=inverse)

    joined_df = joined_df.withColumn(
        "Final Score",
        F.round(
            100 * (
                0.40 * col("norm_cost") +
                0.30 * col("norm_ontime") +
                0.10 * col("norm_rating") +
                0.10 * col("norm_safety") +
                0.10 * col("norm_routes")
            ), 2
        )
    )

    # desc() puts null scores (e.g. unparseable KPI values) last within a route.
    window_spec = Window.partitionBy("Route ID").orderBy(col("Final Score").desc())
    ranked_df = joined_df.withColumn("Rank", rank().over(window_spec))
    final_df = ranked_df.filter(col("Rank") <= 3)

    final_df = final_df.select(
        "Vendor ID",
        # Qualified by source DataFrame; presumably vendor_data also has a
        # "Vendor Name" column, which would make the bare name ambiguous.
        top_vendors["Vendor Name"],
        "Vendor Email",
        "Route ID",
        "Total Quoted Cost",
        "Final Score",
        "Rank")

    # Save to ADLS as CSV.
    final_df.write \
        .mode("overwrite") \
        .option("header", "true") \
        .csv(f"{base_uri}/processed/Shortlisted_vendors")

    # Save to Unity Catalog; spaces are replaced so the names are valid Delta columns.
    final_df_cleaned = final_df.toDF(*[name.replace(" ", "_") for name in final_df.columns])
    final_df_cleaned.write.mode("overwrite").option("mergeSchema", "true").saveAsTable(
        "genai_catalog.genai_schema.shortlisted_vendors")

    return "✅ Top 3 vendors per route have been shortlisted and saved to Unity Catalog genai_catalog.genai_schema.shortlisted_vendors"

# Wrap the scorer as a single-input LangChain tool for the agent below.
# return_direct=True hands the tool's return string back as the agent's
# final answer instead of passing it through the LLM again.
shortlist_top_vendors_tool = Tool(
    func=shortlist_top_vendors,
    name="shortlist_top_vendors",
    return_direct=True,
    description=(
        "Ranks and selects top 3 vendors per route based on multiple KPIs. "
        "Use when asked to calculate top vendors."
    ),
)


In [0]:
# Build a zero-shot ReAct agent whose only tool is the vendor shortlister,
# backed by the Databricks-served Llama 4 Maverick endpoint, then ask it to
# run the shortlist.
tools = [shortlist_top_vendors_tool]

# NOTE(review): initialize_agent is also deprecated in recent LangChain
# releases; migrating to create_react_agent + AgentExecutor (or LangGraph)
# needs imports in 0.Config — TODO.
agent = initialize_agent(
    tools=tools,
    llm=ChatDatabricks(endpoint="databricks-llama-4-maverick"),
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True
)

# Chain.run() is deprecated; invoke() takes the input mapping and returns a
# dict whose "output" key holds the same final answer run() returned, so the
# cell still displays the tool's confirmation string.
agent.invoke({"input": "Calculate top 3 vendors per route"})["output"]