# Part 10: Logical and Semantic routing (逻辑和语义路由)
由 LLM 对用户问题进行分类，在执行 RAG 检索之前先选择合适的数据库。

## Logical routing (逻辑路由)
由llm进行分类，在rag前，路由到合适的数据库。

In [None]:
import os
from pprint import pprint

from typing import Literal

from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

# 定义llm的返回结果
class RouteQuery(BaseModel):
    """Route a user query to the most relevant data source."""

    # Closed set of data sources the router is allowed to pick from.
    datasource: Literal["python_docs", "js_docs", "golang_docs"] = Field(
        ...,
        # The original description stopped mid-sentence ("... choose "); the
        # completed sentence tells the model what it is choosing and why.
        description=(
            "Given a user question choose which datasource would be most "
            "relevant for answering their question"
        ),
    )

# Chat model client; model name, API key and endpoint come from the ARK_*
# environment variables. temperature=0.0 is used for the routing decision.
llm = ChatOpenAI(
    model=os.getenv("ARK_MODEL"),
    api_key=os.getenv("ARK_API_KEY"),
    base_url=os.getenv("ARK_API_URL"),
    temperature=0.0,
)
# Bind the RouteQuery schema to the model. Per the original note this does two
# things: it asks the LLM for JSON-formatted output, and it converts that
# output into the pydantic object defined above.
structured_llm = llm.with_structured_output(RouteQuery)

# Prompt that asks the LLM to choose the data source.
# Adjacent string literals are concatenated verbatim, so each fragment that
# ends a sentence needs its own trailing space (the original produced
# "data source.Based on").
system_prompt = (
    "You are an expert at routing a user question to "
    "the appropriate data source. "
    "Based on the programming language the question is referring to, "
    "route it to the relevant data source."
)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("human", "{question}"),
    ]
)

# Routing chain: prompt -> structured LLM output (a RouteQuery).
router = prompt | structured_llm

In [5]:
# Invoke the router with a sample question (a snippet that uses a Python library).
question = """Why doesn't the following code work:

from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(["human", "speak in {language}"])
prompt.invoke("french")
"""
# The router expects a dict whose "question" key fills the prompt placeholder.
result = router.invoke({"question": question})
pprint(result)

RouteQuery(datasource='python_docs')


In [7]:
# Pick the downstream chain from the routed data source (illustrative only:
# each branch returns a placeholder string rather than a real chain).
def choose_route(result):
    """Return the placeholder chain name for the data source picked by the router.

    Args:
        result: Object exposing a ``datasource`` string attribute, such as the
            ``RouteQuery`` produced by the router.

    Returns:
        ``"chain for python_docs"`` or ``"chain for js_docs"`` when the
        lowercased data source contains that name, otherwise
        ``"chain for golang_docs"``.
    """
    datasource = result.datasource.lower()
    if "python_docs" in datasource:
        return "chain for python_docs"
    if "js_docs" in datasource:
        return "chain for js_docs"
    # Fallback branch: use the same "chain for ..." form as the other branches
    # (the original returned the bare data source name here).
    return "chain for golang_docs"

from langchain_core.runnables import RunnableLambda

# Full pipeline: classify the question with the router, then map the
# classification to a chain name via choose_route.
full_chain = (
    router
    | RunnableLambda(choose_route)
)
# Reuses the sample question defined earlier in the notebook.
final_result = full_chain.invoke({"question": question})
pprint(final_result)

'chain for python_docs'


## Semantic routing 语义路由
根据用户的输入内容，计算query和prompt间的语义相似度，路由到合适的提示词模板

In [None]:
from langchain.utils.math import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Two prompt templates: one for physics questions, one for math questions.
# Each ends with a {query} placeholder for the user's question.
physics_template = """You are a very smart physics professor. \
You are great at answering questions about physics in a concise and easy to understand manner. \
When you don't know the answer to a question you admit that you don't know.

Here is a question:
{query}"""

math_template = """You are a very good mathematician. You are great at answering math questions. \
You are so good because you are able to break down hard problems into their component parts, \
answer the component parts, and then put them together to answer the broader question.

Here is a question:
{query}"""

from ark_embedding import ArkEmbeddings

# Embedding client.
# NOTE(review): this reads ALIYUN_* environment variables while the chat model
# above reads ARK_* ones — confirm that the mix is intended.
embd = ArkEmbeddings(
    model=os.getenv("ALIYUN_EMBEDDING_MODEL"),
    api_key=os.getenv("ALIYUN_API_KEY"),
    api_url=os.getenv("ALIYUN_API_URL"),
    batch_size=10
)

# Embed the prompt templates once, up front; the order of prompt_embeddings
# follows the order of prompt_templates.
prompt_templates = [physics_template, math_template]
prompt_embeddings = embd.embed_documents(prompt_templates)

In [None]:
# Route the question to the best-matching prompt template and fill it in.
def prompt_router(input):
    """Select the template most similar to the query and render it.

    Args:
        input: Mapping with a ``"query"`` key holding the user's question.

    Returns:
        The result of invoking the chosen ``PromptTemplate`` with ``input``.
    """
    # Embed the incoming question.
    question_vector = embd.embed_query(input["query"])
    # Score the question against every pre-embedded template; the first row
    # holds the scores for our single query.
    scores = cosine_similarity([question_vector], prompt_embeddings)[0]
    # Keep the template with the highest similarity score.
    best_template = prompt_templates[scores.argmax()]
    if best_template == math_template:
        print("Using MATH")
    else:
        print("Using PHYSICS")
    # Render the prompt explicitly (per the original note this is optional:
    # returning just the prompt template is also possible).
    return PromptTemplate.from_template(best_template).invoke(input)

from langchain_core.runnables import RunnableParallel

# Semantic-routing pipeline: wrap the raw question as {"query": ...}, pick and
# render the closest prompt, send it to the LLM, and return plain text.
chain = (
    RunnableParallel({"query": RunnablePassthrough()})
    | RunnableLambda(prompt_router)  # Renders the selected template with the question; the result is what the llm receives
    | llm  # NOTE(review): the original note says the chain also runs when the template is not rendered explicitly, attributing this to LangChain's "execution context" — confirm the actual mechanism
    | StrOutputParser()
)

result = chain.invoke("What's a black hole?")
pprint(result)

Using PHYSICS
("Of course. That's an excellent and fundamental question.\n"
 '\n'
 'In the simplest terms, a **black hole is a region of space where gravity is '
 'so intense that nothing, not even light, can escape from it.**\n'
 '\n'
 "Let's break that down:\n"
 '\n'
 '1.  **The "Point of No Return":** The outer boundary of a black hole is '
 'called the **event horizon**. Think of it as a one-way door. Once anything—a '
 'spaceship, a planet, or a particle of light (a photon)—crosses this '
 'boundary, it can never come back out. We cannot see what happens inside, '
 'hence the name "black hole."\n'
 '\n'
 '2.  **Why the Gravity is So Strong:** The extreme gravity comes from an '
 'immense amount of mass being crushed into a vanishingly small point at the '
 'very center, called a **singularity**. Imagine crushing the entire mass of '
 'our Sun into a sphere less than 4 miles across, or the entire Earth into a '
 'sphere the size of a marble. This incredible density warps the fabric

# Part 11: 查询构建

前提：创建的数据库中包含元数据。  
可以考虑采用带过滤条件的结构化查询：先按元数据过滤，再在过滤后的数据中检索。

In [2]:
from langchain_community.document_loaders import YoutubeLoader

# Try to load the transcript and video info of a YouTube video; on any failure
# fall back to a hand-written mock document so the rest of the notebook can run.
try:
    docs = YoutubeLoader.from_youtube_url(
        "https://www.youtube.com/watch?v=pbAd8O1Lvm4",
        add_video_info=True,
    ).load()
    
    print("成功加载YouTube字幕数据!")
    print(f"文档数量: {len(docs)}")
    print(f"第一个文档的元数据: {docs[0].metadata}")
    print(f"第一个文档的前200个字符: {docs[0].page_content[:200]}...")
    
except Exception as e:
    # Deliberate best-effort: any error (including one raised by the prints
    # above, e.g. indexing an empty docs list) switches to mock data.
    print(f"YouTube加载失败: {e}")
    print("使用模拟数据代替...")
    
    # Workaround: use mock YouTube video data
    from langchain_core.documents import Document
    
    # Mock transcript standing in for the video's subtitles
    mock_transcript = """
    Welcome to this tutorial on RAG systems. Today we'll discuss how to build 
    retrieval augmented generation systems from scratch. We'll cover topics like 
    semantic search, vector embeddings, and how to integrate them with language models.
    
    First, let's understand what RAG means. RAG stands for Retrieval Augmented Generation.
    It's a technique that combines information retrieval with text generation to create
    more accurate and informative responses.
    
    The key components of a RAG system include: document chunking, vector embeddings,
    semantic search, and response generation using language models.
    """
    
    # Single mock document carrying the metadata keys used later for filtering
    # (title, length, view_count, publish_date).
    docs = [Document(
        page_content=mock_transcript,
        metadata={
            "source": "https://www.youtube.com/watch?v=pbAd8O1Lvm4",
            "title": "RAG Tutorial Video",
            "author": "Tutorial Channel",
            "length": 600,
            "view_count": 10000,
            "publish_date": "2024-01-15"
        }
    )]
    
    print("使用模拟数据成功!")
    print(f"文档数量: {len(docs)}")
    print(f"元数据: {docs[0].metadata}")

docs[0].metadata

YouTube加载失败: HTTP Error 400: Bad Request
使用模拟数据代替...
使用模拟数据成功!
文档数量: 1
元数据: {'source': 'https://www.youtube.com/watch?v=pbAd8O1Lvm4', 'title': 'RAG Tutorial Video', 'author': 'Tutorial Channel', 'length': 600, 'view_count': 10000, 'publish_date': '2024-01-15'}


{'source': 'https://www.youtube.com/watch?v=pbAd8O1Lvm4',
 'title': 'RAG Tutorial Video',
 'author': 'Tutorial Channel',
 'length': 600,
 'view_count': 10000,
 'publish_date': '2024-01-15'}

In [17]:
from typing import Optional
from pydantic import BaseModel, Field

class TutorialSearch(BaseModel):
    """Structured search query: search terms plus optional filter conditions."""

    # Required free-text search terms.
    query: str = Field(..., description="用于语义搜索的主要查询词")
    title_search: str = Field(..., description="用于标题搜索的查询词")

    # Optional view-count bounds; None means "no filter".
    min_view_count: Optional[int] = Field(
        default=None,
        description="Minimum view count filter, inclusive. Only use if explicitly specified.",
    )
    max_view_count: Optional[int] = Field(
        default=None,
        description="Maximum view count filter, exclusive. Only use if explicitly specified.",
    )

    # Optional publish-date bounds, carried as strings.
    earliest_publish_date: Optional[str] = Field(
        default=None,
        description="Earliest publish date filter, inclusive. Only use if explicitly specified.",
    )
    latest_publish_date: Optional[str] = Field(
        default=None,
        description="Latest publish date filter, exclusive. Only use if explicitly specified.",
    )

    # Optional video-length bounds, in seconds.
    min_length_sec: Optional[int] = Field(
        default=None,
        description="Minimum video length in seconds, inclusive. Only use if explicitly specified.",
    )
    max_length_sec: Optional[int] = Field(
        default=None,
        description="Maximum video length in seconds, exclusive. Only use if explicitly specified.",
    )

    def pretty_print(self) -> None:
        """Print ``name: value`` for each field that is set.

        A field is skipped when its value is None or equals the default
        declared for it on the model.
        """
        for name, info in type(self).model_fields.items():
            current = getattr(self, name)
            if current is None or current == info.default:
                continue
            print(f"{name}: {current}")


In [18]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
import os

# System prompt instructing the model to turn a user question into a database query.
system = """You are an expert at converting user questions into database queries. \
You have access to a database of tutorial videos about a software library for building LLM-powered applications. \
Given a question, return a database query optimized to retrieve the most relevant results.

If there are acronyms or words you are not familiar with, do not try to rephrase them."""
# NOTE: this rebinds the name `prompt` used by the routing section earlier in the file.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("user", "{question}"),
    ]
)
# LLM client configured from the ARK_* environment variables
# (rebinds `llm` with the same settings as the earlier definition).
llm = ChatOpenAI(
    model=os.getenv("ARK_MODEL"),
    api_key=os.getenv("ARK_API_KEY"),
    base_url=os.getenv("ARK_API_URL"),
    temperature=0.0,
)
# Bind the TutorialSearch schema so results come back as TutorialSearch objects
# (pretty_print is called on them in the cells below).
structured_llm = llm.with_structured_output(TutorialSearch)
# Query-analysis chain: prompt -> structured LLM output.
query_analyzer = (
    prompt
    | structured_llm
)

In [19]:
# Plain topical question with no explicit filters.
res = query_analyzer.invoke({"question": "rag from scratch"})
res.pretty_print()

query: rag from scratch
title_search: rag from scratch


In [20]:
# Question that names a publication year.
query_analyzer.invoke(
    {"question": "videos on chat langchain published in 2023"}
).pretty_print()

query: chat langchain
title_search: chat langchain
earliest_publish_date: 2023-01-01
latest_publish_date: 2024-01-01


In [21]:
# Question that only gives an upper bound on the publish date.
query_analyzer.invoke(
    {"question": "videos that are focused on the topic of chat langchain that are published before 2024"}
).pretty_print()

query: chat langchain
title_search: chat langchain
latest_publish_date: 2024-01-01


In [22]:
# Question that limits the video length; `res` is reused further down when
# converting to a StructuredQuery.
res = query_analyzer.invoke(
    {
        "question": "how to use multi-modal models in an agent, only videos under 5 minutes"
    }
)
res.pretty_print()

query: multi-modal models in agent
title_search: multi-modal agent
max_length_sec: 300


In [23]:
from langchain.chains.query_constructor.base import (
    StructuredQueryOutputParser,
)

# Stock structured-query output parser built with default components.
# NOTE: `output_parser` is rebound later in this file to a
# CustomTutorialQueryOutputParser instance.
output_parser = StructuredQueryOutputParser.from_components()

In [42]:
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.structured_query import (
    StructuredQuery, 
    Comparison, 
    Operation, 
    Comparator, 
    Operator,
    FilterDirective
)
from typing import Optional, List
import json

# One entry per supported filter:
#   (key in the TutorialSearch JSON, metadata attribute, comparator, skip_when_falsy)
# Kept at module level rather than on the class so it is not treated as a
# model field by the parser's base class. The order of the entries fixes the
# order of the comparisons in the resulting filter.
_TUTORIAL_FILTER_SPECS = (
    # Title search uses a CONTAIN comparison.
    ("title_search", "title", Comparator.CONTAIN, True),
    # View-count range: inclusive lower bound, exclusive upper bound.
    ("min_view_count", "view_count", Comparator.GTE, False),
    ("max_view_count", "view_count", Comparator.LT, False),
    # Publish-date range.
    ("earliest_publish_date", "publish_date", Comparator.GTE, True),
    ("latest_publish_date", "publish_date", Comparator.LT, True),
    # Video-length range.
    ("min_length_sec", "length", Comparator.GTE, False),
    ("max_length_sec", "length", Comparator.LT, False),
)


def _tutorial_comparisons(data) -> List[FilterDirective]:
    """Build one Comparison for every filter field that is set in *data*."""
    comparisons: List[FilterDirective] = []
    for key, attribute, comparator, skip_when_falsy in _TUTORIAL_FILTER_SPECS:
        value = data.get(key)
        # Text filters are dropped when empty; numeric filters are dropped only
        # when missing/None, so that 0 remains a usable bound.
        if value is None or (skip_when_falsy and not value):
            continue
        comparisons.append(
            Comparison(comparator=comparator, attribute=attribute, value=value)
        )
    return comparisons


class CustomTutorialQueryOutputParser(BaseOutputParser[StructuredQuery]):
    """Output parser that converts TutorialSearch-shaped JSON into a StructuredQuery."""

    def parse(self, text: str) -> StructuredQuery:
        """Parse a JSON string and convert it into a StructuredQuery.

        Args:
            text: JSON object using the TutorialSearch field names.

        Returns:
            A StructuredQuery whose ``query`` is the "query" value (empty
            string when absent), whose ``filter`` is None, a single
            Comparison, or an AND Operation over several comparisons, and
            whose ``limit`` is None.

        Raises:
            ValueError: If anything goes wrong while parsing or converting;
                the underlying exception is chained as the cause.
        """
        try:
            data = json.loads(text)

            # Free-text part of the query (used for semantic search).
            query = data.get("query", "")

            comparisons = _tutorial_comparisons(data)

            # No filter when nothing is set, the bare comparison when there is
            # exactly one, otherwise all comparisons combined with AND.
            filter_expr: Optional[FilterDirective] = None
            if len(comparisons) == 1:
                filter_expr = comparisons[0]
            elif comparisons:
                filter_expr = Operation(
                    operator=Operator.AND,
                    arguments=comparisons,
                )

            # No explicit result limit is derived from the input.
            return StructuredQuery(query=query, filter=filter_expr, limit=None)

        except Exception as e:
            # Same exception type and message as before; `from e` makes the
            # original failure the explicit cause in the traceback.
            raise ValueError(f"解析查询时出错: {e}") from e

    @classmethod
    def from_components(cls) -> "CustomTutorialQueryOutputParser":
        """Factory method returning a new parser instance."""
        return cls()


output_parser = CustomTutorialQueryOutputParser.from_components()

In [44]:
# Serialize the last TutorialSearch result to JSON and convert it into a StructuredQuery.
res_parser = output_parser.invoke(res.model_dump_json())
# Bare expression so the notebook displays the value.
res_parser

StructuredQuery(query='multi-modal models in agent', filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.CONTAIN: 'contain'>, attribute='title', value='multi-modal agent'), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='length', value=300)]), limit=None)

结构化查询，参考文档：https://python.langchain.com/docs/how_to/self_query/#constructing-from-scratch-with-lcel