# 2-stage user-query building approach

Steps:
- STAGE 1: Extract conditions in natural language (spans from the input) corresponding to relevant metadata fields
    - Output (simplified one):
    ```python
    [
        ("contains English and French", "languages"),
        ("from Huggigface platform", "platform"),
    ]
    ```
    

- STAGE 2: Process each natural language condition separately (we can dynamically create new output schemas)
    - Output schema (for one condition):
    ```python   
    class Condition(BaseModel):
        values: list[DYNAMIC_TYPE]
        comp: Literal[">", "<", ...]
        log_op: Literal["AND", "OR"]
    ```

Problems:
- It's not easy to model relationships (logical opetaros) between multiple Conditions associated with different types
    - I suppose for now we will not support such use cases: E.g., "Retrieve models that are either in Slovak or that have at least 1 million datapoints."
    - We implicitly apply AND operator in between the Conditions. OR operator can only be applied in between values pertaining to a specific metadata field

In [18]:
import json

from langchain_ollama import ChatOllama
from dotenv import load_dotenv
from langchain_core.prompt_values import HumanMessage

from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chains.query_constructor.base import (
    StructuredQueryOutputParser,
    get_query_constructor_prompt
)
from langchain_core.prompts import FewShotChatMessagePromptTemplate
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate

from langchain_core.structured_query import Comparator, Operator
from langchain.retrievers.self_query.milvus import MilvusTranslator

In [8]:
import sys
sys.path.append("./src")

In [9]:
load_dotenv()
MODEL_NAME = "llama3.1:8b"

model = ChatOllama(model=MODEL_NAME, num_predict=4096, num_ctx=8192,)

In [10]:
from llm_metadata_filter import DatasetMetadataTemplate, Llama_ManualFunctionCalling

In [11]:
from pydantic import BaseModel, Field

class NaturalLanguageCondition(BaseModel):
    condition: str = Field(..., description="Natural language condition corresponding to a particular metadata field we use for filtering")
    field: str = Field(..., description="Name of the metadata field")

class NaturalLanguageConditions(BaseModel):
    """Extraction of natural language conditions found within a user query"""
    conditions: list[NaturalLanguageCondition] = Field(..., description="Natural language conditions")

In [12]:
attribute_types = {
    "platform": "string",
    "date_published": "string",
    "year": "integer",
    "month": "integer",
    "domains": "string",
    "task_types": "string",
    "license": "string",
    "size_in_mb": "float",
    "num_datapoints": "integer",
    "size_category": "string",
    "modalities": "string",
    "data_formats": "string",
    "languages": "string",
}

In [13]:
metadata_field_info = [
    {
        "name": name, 
        "description": field.description, 
        "type": attribute_types[name]
    } for name, field in DatasetMetadataTemplate.model_fields.items()
]

In [14]:
user_prompt = """
    Your task is to extract user-defined conditions from a query, focusing on metadata fields relevant to filtering that are specified in the schema below. 
    Identify each condition explicitly mentioned in the query and assign it to the appropriate metadata field.

    Extract the conditions only from the last user query as the other ones are used as an examples

    A simple schema below briefly describes all the metadata fields we use for filtering purposes:
    {model_schema}

    **Instructions:**
    1. Extract conditions as they appear in the query in their natural language form. You may slightly modify their word structure if necessary to perserve the logic regarding particular condition.
    2. For each condition, determine the metadata field it pertains to, based on its meaning.
    3. If a condition does not clearly pertain to a known metadata field, exclude it.
"""

In [15]:
examples = [
    {
        "input": "Retrieve HuggingFace datasets about stocks",
        "output": {
            "conditions": [
                {
                    "condition": "HuggingFace datasets", 
                    "field": "platform"
                }
            ]
        }
    },
    {
        "input": "Show me the summarization news datasets containing both the French as well as English data. The dataset however can't include any German data nor any Slovak data.",
        "output": {
            "conditions": [
                {
                    "condition": "summarization datasets", 
                    "field": "task_types"
                },
                {
                    "condition": "containing both the French as well as English data", 
                    "field": "languages"
                },
                {
                    "condition": "can\'t include any German data nor any Slovak data", 
                    "field": "languages"
                },
            ]
        }
    },
    {
        "input": "Find all chocolate datasets created after January 1, 2022, that are represented in textual or image format with its dataset size smaller than 500 000KB.",
        "output": {
            "conditions": [
                {
                    "condition": "datasets created after January 1, 2022", 
                    "field": "date_published"
                },
                {
                    "condition": "represented in textual or image format", 
                    "field": "modalities"
                },
                {
                    "condition": "dataset size smaller than 500 000KB", 
                    "field": "size_in_mb"
                },
            ]
        }
    },
    {
        "input": "Datasets that have either have over 50k datapoints but fewer than 100k, or datasets that have MIT or apache-2.0 license",
        "output": {
            "conditions": [
                {
                    "condition": "have over 50k datapoints but fewer than 100k", 
                    "field": "num_datapoints"
                },
                {
                    "condition": "represented in textual or image format", 
                    "field": "modalities"
                },
                {
                    "condition": "dataset size smaller than 500 000KB", 
                    "field": "size_in_mb"
                },
            ]
        }
    },
    {
        "input": "Search for COVID-19 datasets",
        "output": {
            "conditions": []
        }
    },
]

In [None]:
example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "User Query: {query}"),
        ("ai", "{output}"),
    ]
)

In [None]:


fewshot_prompt = FewShotChatMessagePromptTemplate(
    examples=examples,
    example_prompt=example_prompt,
)

In [36]:
modified_user_prompt = user_prompt.format(model_schema=json.dumps(metadata_field_info))

In [57]:
from langchain_core.prompts import ChatPromptTemplate

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "system_prompt"),
        HumanMessage(modified_user_prompt),
        # ("user", modified_user_prompt),
        fewshot_prompt,
        ("human", "User Query: {query}"),
    ]
)

In [64]:
final_prompt.invoke({"query": "test"}).to_messages()

[SystemMessage(content='system_prompt', additional_kwargs={}, response_metadata={}),
 HumanMessage(content='\n    Your task is to extract user-defined conditions from a query, focusing on metadata fields relevant to filtering that are specified in the schema below. \n    Identify each condition explicitly mentioned in the query and assign it to the appropriate metadata field.\n\n    Extract the conditions only from the last user query as the other ones are used as an examples\n\n    A simple schema below briefly describes all the metadata fields we use for filtering purposes:\n    [{"name": "platform", "description": "The platform where the asset is hosted. ONLY PERMITTED VALUES: [\'huggingface\', \'openml\', \'zenodo\']", "type": "string"}, {"name": "date_published", "description": "The original publication date of the asset in the format \'YYYY-MM-DD\'.", "type": "string"}, {"name": "year", "description": "The year extracted from the publication date.", "type": "integer"}, {"name": "

In [54]:
final_prompt.messages

[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='system_prompt'), additional_kwargs={}),
 HumanMessage(content='\n    Your task is to extract user-defined conditions from a query, focusing on metadata fields relevant to filtering that are specified in the schema below. \n    Identify each condition explicitly mentioned in the query and assign it to the appropriate metadata field.\n\n    Extract the conditions only from the last user query as the other ones are used as an examples\n\n    A simple schema below briefly describes all the metadata fields we use for filtering purposes:\n    [{"name": "platform", "description": "The platform where the asset is hosted. ONLY PERMITTED VALUES: [\'huggingface\', \'openml\', \'zenodo\']", "type": "string"}, {"name": "date_published", "description": "The original publication date of the asset in the format \'YYYY-MM-DD\'.", "type": "string"}, {"name": "year", "description": "The 

In [59]:
import json
json.dumps(metadata_field_info)

'[{"name": "platform", "description": "The platform where the asset is hosted. ONLY PERMITTED VALUES: [\'huggingface\', \'openml\', \'zenodo\']", "type": "string"}, {"name": "date_published", "description": "The original publication date of the asset in the format \'YYYY-MM-DD\'.", "type": "string"}, {"name": "year", "description": "The year extracted from the publication date.", "type": "integer"}, {"name": "month", "description": "The month extracted from the publication date.", "type": "integer"}, {"name": "domains", "description": "The AI technical domains of the asset, describing the type of data and AI task involved. ONLY PERMITTED VALUES: [\'NLP\', \'Computer Vision\', \'Audio Processing\']. Leave the list empty if not specified.", "type": "string"}, {"name": "task_types", "description": "The machine learning tasks supported by this asset. Acceptable values include task types found on HuggingFace (e.g., \'token-classification\', \'question-answering\', ...). Leave the list empty

In [73]:
system_prompt = Llama_ManualFunctionCalling.populate_tool_prompt(NaturalLanguageConditions)
user_prompt = user_prompt.format(
    model_schema=json.dumps(metadata_field_info),
    examples=""
)

In [None]:
from langchain.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage

In [75]:
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(system_prompt),
    HumanMessage(user_prompt)    
])

In [79]:
print(user_prompt)


Your task is to extract user-defined conditions from a query, focusing on metadata fields relevant to filtering that are specified in the schema below. 
Identify each condition explicitly mentioned in the query and assign it to the appropriate metadata field.

Extract the conditions only from the last user query as the other ones are used as an examples

A simple schema below briefly describes all the metadata fields we use for filtering purposes:
[{"name": "platform", "description": "The platform where the asset is hosted. ONLY PERMITTED VALUES: ['huggingface', 'openml', 'zenodo']", "type": "string"}, {"name": "date_published", "description": "The original publication date of the asset in the format 'YYYY-MM-DD'.", "type": "string"}, {"name": "year", "description": "The year extracted from the publication date.", "type": "integer"}, {"name": "month", "description": "The month extracted from the publication date.", "type": "integer"}, {"name": "domains", "description": "The AI technic

In [76]:
user_prompt

'\nYour task is to extract user-defined conditions from a query, focusing on metadata fields relevant to filtering that are specified in the schema below. \nIdentify each condition explicitly mentioned in the query and assign it to the appropriate metadata field.\n\nExtract the conditions only from the last user query as the other ones are used as an examples\n\nA simple schema below briefly describes all the metadata fields we use for filtering purposes:\n[{"name": "platform", "description": "The platform where the asset is hosted. ONLY PERMITTED VALUES: [\'huggingface\', \'openml\', \'zenodo\']", "type": "string"}, {"name": "date_published", "description": "The original publication date of the asset in the format \'YYYY-MM-DD\'.", "type": "string"}, {"name": "year", "description": "The year extracted from the publication date.", "type": "integer"}, {"name": "month", "description": "The month extracted from the publication date.", "type": "integer"}, {"name": "domains", "description":

In [77]:
prompt

ChatPromptTemplate(input_variables=[], input_types={}, partial_variables={}, messages=[SystemMessage(content='\n        You have access to the following functions:\n\n        Use the function \'NaturalLanguageConditions\' to \'Extraction of natural language conditions found within a user query\':\n        {"$defs": {"NaturalLanguageCondition": {"properties": {"condition": {"description": "Natural language condition corresponding to a particular metadata field we use for filtering", "title": "Condition", "type": "string"}, "field": {"description": "Name of the metadata field", "title": "Field", "type": "string"}}, "required": ["condition", "field"], "title": "NaturalLanguageCondition", "type": "object"}}, "description": "Extraction of natural language conditions found within a user query", "name": "NaturalLanguageConditions", "parameters": {"type": "object", "properties": {"conditions": {"description": "Natural language conditions", "items": {"$ref": "#/$defs/NaturalLanguageCondition"},

In [None]:
# TODO incorporate few shot examples into the user prompt + inject information

In [18]:
user_query = (
    "Retrieve all the translation Stanford datasets with at least 10k datapoints and has over 100k KB in size" +
    "and the dataset should have contain Slovak language, Polish language, but no Czech language."
)