In [1]:
import os
from getpass import getpass

# Enable LangChain tracing and group this notebook's runs under one project name.
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "Text-mining-for-taxonomy",
    }
)

In [2]:
import logging
import operator
from typing import Annotated, List, Optional, TypedDict

# Root handler only reports WARNING and above; the notebook gets its own named logger.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(name="tnt-llm")

class _DocRequired(TypedDict):
    """Keys every document carries from the start."""

    id: str
    content: str


class Doc(_DocRequired, total=False):
    """A single chat log plus the annotations the pipeline adds to it.

    ``id`` and ``content`` are always present. ``summary`` and ``explanation``
    are filled in by the summarization step and ``category`` by a later step,
    so documents may lack these keys until then; they are therefore declared
    non-required (``total=False``) rather than merely ``Optional``.
    """

    summary: Optional[str]
    explanation: Optional[str]
    category: Optional[str]


class TaxonomyGenerationState(TypedDict):
    """State passed between the taxonomy-generation steps."""

    # The raw docs; summaries are injected into them in the first step.
    documents: List[Doc]
    # Minibatches stored as indices into ``documents`` to keep the state concise.
    minibatches: List[List[int]]
    # Candidate taxonomies (full trajectory). The ``operator.add`` annotation is
    # presumably used as a reducer so updates are concatenated rather than
    # replacing the list (LangGraph convention — confirm against the graph).
    clusters: Annotated[List[List[dict]], operator.add]

# Summarize docs

Chat logs can get quite long. Our taxonomy generation step needs to see large, diverse minibatches to be able to adequately capture the distribution of categories. To ensure they can all fit efficiently into the context window, we first summarize each chat log. Downstream steps will use these summaries instead of the raw doc content.

In [None]:
import re
from langchain import hub
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnablePassthrough

# Pull the shared summarization prompt and pre-fill its length limits.
# NOTE(review): assumes the hub prompt declares {summary_length} and
# {explanation_length} placeholders — confirm against the hub prompt. The
# previous keyword `summary_prompt=20` did not match the `*_length` naming
# and would have left the summary length unfilled.
summary_prompt = hub.pull("wfh/tnt-llm-summary-generation").partial(
    summary_length=20, explanation_length=30
)

def parse_summary(xml_string: str) -> dict:
    """Extract the summary and explanation from a summarization response.

    The response is expected to wrap each field in XML-style tags, i.e.
    ``<summary>...</summary>`` and ``<explanation>...</explanation>``.

    Args:
        xml_string: Raw text returned by the summarization model.

    Returns:
        A dict with ``"summary"`` and ``"explanation"`` keys. Each value is the
        whitespace-stripped text of the first matching tag pair, or ``""`` when
        that tag is missing.
    """

    def _first_tag(tag: str) -> str:
        # Non-greedy match takes the first pair; re.DOTALL lets content span lines.
        match = re.search(rf"<{tag}>(.*?)</{tag}>", xml_string, re.DOTALL)
        return match.group(1).strip() if match else ""

    return {"summary": _first_tag("summary"), "explanation": _first_tag("explanation")}

# LLM step: prompt -> gpt-4o-mini -> plain string. The custom run name makes this
# step easy to find in traces (tracing is enabled via LANGCHAIN_TRACING_V2).
summary_llm_chain = (
    summary_prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
).with_config(run_name="GenerateSummary")

# Full summarizer: the LLM text is parsed into {"summary", "explanation"}.
summary_chain = summary_llm_chain | RunnableLambda(parse_summary)

# Now combine as a "map" operation in a map-reduce chain 
# Input: state
# Output: state U summaries
# Processes docs in parallel
def get_content(state: TaxonomyGenerationState):
    """Build one summarization input per document in ``state``.

    Returns a list of ``{"content": ...}`` dicts, in document order, which is
    the shape fed to ``summary_chain.batch`` in ``map_step``.
    """
    inputs = []
    for document in state["documents"]:
        inputs.append({"content": document["content"]})
    return inputs

map_step = RunnablePassthrough.assign(
    # "Map" phase: turn the state into one input per document and summarize them
    # all through summary_chain.batch (or .abatch when invoked asynchronously).
    # The input state passes through unchanged, with the results added under
    # "summaries". Individual document failures are not handled here; doing so
    # per item would make this more robust.
    summaries=RunnableLambda(get_content)
    | RunnableLambda(summary_chain.batch, afunc=summary_chain.abatch),
)

def reduce_summaries(combined: dict) -> TaxonomyGenerationState:
    """Merge per-document summaries back into their source documents.

    Args:
        combined: The output of ``map_step``: the incoming state plus a
            ``"summaries"`` list holding one ``{"summary", "explanation"}`` dict
            per document, assumed to be in the same order as
            ``combined["documents"]``.

    Returns:
        A partial state update whose ``"documents"`` are copies of the input
        documents (all existing keys kept) with ``"summary"`` and
        ``"explanation"`` set.

    Raises:
        ValueError: If the number of summaries differs from the number of
            documents; ``zip`` would otherwise silently drop documents.
    """
    summaries = combined["summaries"]
    documents = combined["documents"]
    if len(summaries) != len(documents):
        raise ValueError(
            f"Expected one summary per document, got {len(summaries)} summaries "
            f"for {len(documents)} documents."
        )
    return {
        "documents": [
            # Keep every existing field (e.g. a pre-set "category") and inject
            # the new summary fields on top.
            {
                **doc,
                "summary": summ_info["summary"],
                "explanation": summ_info["explanation"],
            }
            for doc, summ_info in zip(documents, summaries)
        ]
    }


map_reduce_chain = map_step | reduce_summaries

# Split into Minibatches

Each minibatch contains a random sample of docs. This lets the flow identify inadequacies in the current taxonomy using new data.

In [None]:
import random

def get_minibatches(state: TaxonomyGenerationState, config: RunnableConfig):
    """Shuffle the documents and split their indices into equal-size minibatches.

    Options read from ``config["configurable"]`` (the key may be absent):
        batch_size: Number of document indices per minibatch (default 200).
        seed: Optional seed for a private ``random.Random``. When omitted, the
            global ``random`` module is used, so ``random.seed(...)`` applies.

    Returns:
        A partial state update ``{"minibatches": [...]}``. With fewer documents
        than ``batch_size``, a single unpadded batch is returned. Otherwise the
        trailing partial batch is topped up with indices drawn from the earlier
        full batches, so every batch has exactly ``batch_size`` entries and no
        batch repeats an index.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    configurable = (config or {}).get("configurable") or {}
    batch_size = configurable.get("batch_size", 200)
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    seed = configurable.get("seed")
    rng = random.Random(seed) if seed is not None else random

    indices = list(range(len(state["documents"])))
    rng.shuffle(indices)
    if len(indices) < batch_size:
        # Don't pad needlessly if we can't fill a single batch
        return {"minibatches": [indices]}

    full_span = (len(indices) // batch_size) * batch_size
    batches = [
        indices[start : start + batch_size] for start in range(0, full_span, batch_size)
    ]

    leftovers = indices[full_span:]
    if leftovers:
        # Pad only from already-batched indices so the last batch never contains
        # the same index twice. There is at least one full batch here, so
        # full_span >= batch_size > the number of padding elements needed.
        padding = rng.sample(indices[:full_span], batch_size - len(leftovers))
        batches.append(leftovers + padding)

    return {"minibatches": batches}

# Taxonomy Generation Utilities

In [None]:
from typing import Dict
from langchain_core.runnables import Runnable

def parse_taxa(output_text: str) -> Dict:
    """Extract the taxonomy from the generated output."""
    cluster_matches = re.findall(
        r"\s*<id>(.*?)</id>\s*<name>(.*?)</name>\s*<description>(.*?)</description>\s*",
        output_text,
        re.DOTALL,
    )
    clusters = [
        {'id': id.strip(), "name": name.strip(), "description": description.strip()}
        for id, description in cluster_matches
    ]

    