In [0]:
# !pip install databricks_langchain

In [0]:
# %pip install --force-reinstall typing-extensions
# %restart_python

In [0]:
from pyspark.sql import functions as F
from databricks_langchain import ChatDatabricks
from pyspark.sql.types import ArrayType, StringType
import json
import ast

In [0]:
# Load the table of summarized conversations; later cells read its
# `convo_summary` column.
GOLD_SUMMARIZED_TABLE = "gold.conversations_summarized"
df_gold_summarized = spark.table(GOLD_SUMMARIZED_TABLE)

In [0]:
# CREATE TAXONOMY USING LLM
# Draw a random sample of conversation summaries, join them into one text block,
# and ask the chat model for a short list of broad category labels.
DATABRICKS_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=DATABRICKS_ENDPOINT, temperature=0.0)

# Upper bound on the number of summaries sent to the model.
# NOTE(review): the original comment called this "10% of dataset" -- the table
# size is not visible here, so confirm before relying on that ratio.
CLASSIFICATION_SAMPLE_LIMIT = 5000
# Seed the shuffle so a re-run does not silently draw a brand-new sample (and
# therefore a different taxonomy). Spark documents rand() as non-deterministic
# in the general case, so an identical sample is only expected while the input
# data and its partitioning stay the same.
CLASSIFICATION_SAMPLE_SEED = 42
df_gold_sampled = (
    df_gold_summarized
    .orderBy(F.rand(CLASSIFICATION_SAMPLE_SEED))
    .limit(CLASSIFICATION_SAMPLE_LIMIT)
)

# Prepare summaries as a single concatenated string, one blank line between
# consecutive summaries (the separator the system prompt describes).
summaries_concat = df_gold_sampled.select(
    F.concat_ws("\n\n", F.collect_list("convo_summary")).alias("all_summaries")
).limit(1)

summaries_str = summaries_concat.collect()[0]['all_summaries']
print(len(summaries_str))

# Fail loudly instead of sending an empty prompt: with no summaries the model
# could only invent categories, and that made-up taxonomy would be used downstream.
if not summaries_str or not summaries_str.strip():
    raise ValueError(
        "No conversation summaries were sampled from df_gold_summarized; "
        "refusing to ask the LLM for a taxonomy of an empty input."
    )

sys_prompt = """
You are a classification assistant. INPUT: a single text block containing 5000+ short summaries, each separated by a blank line. TASK: read **all** summaries in full before responding, and then generate a final list (array) of **no more than 15** broad, meaningful, non-overlapping categories that represent the themes present across all summaries.

REQUIREMENTS:
1. Do not output anything until you have processed the entire input.
2. Produce between 1 and 15 categories, using fewer if appropriate.
3. Categories must be:
   • Distinct and non-overlapping  
   • Broad and meaningful (avoid hyper-specific labels)  
   • Representative of major themes across the summaries  
4. The **output format must be exactly**:

   ["category_one", "category_two", "category_three", ...]

5. Category strings must be:
   • lowercase_snake_case  
   • 1–5 words  
   • concise, thematic, and mutually exclusive  
6. Internal reasoning steps:
   a. Scan all summaries and identify recurring topics.  
   b. Draft 10–25 candidate themes.  
   c. Merge overlapping themes; eliminate categories that are too narrow.  
   d. Validate that all summaries can reasonably fit under one of the final categories.  
   e. Limit the final categories to a max of 15.  
7. Do **not** provide explanations, descriptions, or examples.  
8. Output only the final array. No commentary.

Now read the provided summaries and output the final category list.
""".strip()

# System message carries the instructions; the user message carries the data.
msgs = [
    {"role": "system", "content": sys_prompt},
    {"role": "user", "content": summaries_str},
]

# Raw reply content from the model; it is expected to be the category array as text.
response = llm.invoke(msgs).content
        
# Parse string representation of list into actual list
def _parse_category_list(raw):
    """Best-effort conversion of the LLM reply into a list of category names.

    Two candidate texts are tried in order: the whole (stripped) reply, then the
    span from the first '[' to the last ']' (which copes with code fences or
    commentary wrapped around the array). Each candidate is parsed first as
    strict JSON and then as a Python literal.

    Args:
        raw: The reply content returned by the model.

    Returns:
        A list of stripped, non-empty, de-duplicated strings in their original
        order, or [] when no list of strings can be recovered. Never raises for
        malformed input.
    """
    if not isinstance(raw, str):
        # Anything other than plain text cannot be parsed by the loaders below.
        return []

    text = raw.strip()
    candidates = [text]
    start_idx = text.find('[')
    end_idx = text.rfind(']')
    # Require the closing bracket to come after the opening one; otherwise the
    # slice would be empty or meaningless.
    if start_idx != -1 and end_idx > start_idx:
        candidates.append(text[start_idx:end_idx + 1])

    for candidate in candidates:
        for loader in (json.loads, ast.literal_eval):
            try:
                parsed = loader(candidate)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # json.JSONDecodeError is a ValueError; the rest are what
                # ast.literal_eval may raise on malformed or oversized input.
                continue
            if not isinstance(parsed, (list, tuple)):
                # e.g. a dict or bare string: valid syntax, but not the array
                # that was asked for -- keep looking.
                continue
            cleaned = [
                item.strip()
                for item in parsed
                if isinstance(item, str) and item.strip()
            ]
            if cleaned:
                # dict.fromkeys drops repeats while preserving first-seen order.
                return list(dict.fromkeys(cleaned))
    return []


categories_list = _parse_category_list(response)
if not categories_list:
    # Keep the original best-effort behaviour (empty list) but make it visible.
    print("WARNING: no categories could be parsed from the LLM response; categories_list is empty.")
print(len(categories_list))

In [0]:
# Persist the parsed category list as a JSON file in the workspace and
# report where it was written.
CATEGORIES_JSON_PATH = '/Workspace/Users/ammarbagharib@gmail.com/categories_list.json'
with open(CATEGORIES_JSON_PATH, 'w') as out_file:
    json.dump(categories_list, out_file)
print(f'categories_list saved to {CATEGORIES_JSON_PATH}')