# TODO modifications

Steps:
- Build StructuredQuery (with limited comparators => EQ, GT, GTE, LT, LTE)
- Translate to Milvus Query
- Postprocessing:
    - Check validity of values assigned to individual fields
    - Regex: Replace EQ => ARRAY_CONTAINS for array metadata fields
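
The regex postprocessing step could look roughly like the sketch below. It assumes the translated expression uses the `field == value` form, and that the fields listed in `ARRAY_FIELDS` are the array-typed ones; both the set and the helper name `eq_to_array_contains` are hypothetical and should be adapted to the real schema.

```python
import re

# Assumption: these metadata fields are stored as arrays in Milvus.
ARRAY_FIELDS = {"domains", "task_types", "modalities", "data_formats", "languages"}


def eq_to_array_contains(expr: str, array_fields=ARRAY_FIELDS) -> str:
    """Rewrite `field == value` as `ARRAY_CONTAINS(field, value)` for array fields."""
    if not array_fields:
        return expr
    names = "|".join(re.escape(name) for name in sorted(array_fields))
    # Match: <array field> == <double-quoted string or plain number>
    pattern = re.compile(
        rf'\b({names})\s*==\s*("(?:[^"\\]|\\.)*"|-?\d+(?:\.\d+)?)'
    )
    return pattern.sub(r"ARRAY_CONTAINS(\1, \2)", expr)


expr = '(( task_types == "translation" ) and not(( languages == "cz" )) and ( year == 2022 ))'
print(eq_to_array_contains(expr))
```

Scalar fields such as `year` are left untouched; only comparisons on the listed array fields are rewritten.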

Problems with the current approach:
- Doesn't adhere to the schema or to the specified valid values (e.g. for enums)
- Hallucination

Potential solution: 2-stage user query building approach

In [1]:
from langchain_ollama import ChatOllama
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage

from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.chains.query_constructor.base import (
    StructuredQueryOutputParser,
    get_query_constructor_prompt
)
from langchain_core.structured_query import Comparator, Operator
from langchain.retrievers.self_query.milvus import MilvusTranslator

In [2]:
import sys
sys.path.append("./src")

In [3]:
load_dotenv()
MODEL_NAME = "llama3.1:8b"

model = ChatOllama(model=MODEL_NAME, num_predict=4096, num_ctx=8192)

In [4]:
# import os
# from langchain_openai import AzureChatOpenAI

# model = AzureChatOpenAI(
#     azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT"],
# )

In [5]:
from llm_metadata_filter import DatasetMetadataTemplate

In [20]:
attribute_types = {
    "platform": "string",
    "date_published": "string",
    "year": "integer",
    "month": "integer",
    "domains": "string",
    "task_types": "string",
    "license": "string",
    "size_in_mb": "float",
    "num_datapoints": "integer",
    "size_category": "string",
    "modalities": "string",
    "data_formats": "string",
    "languages": "string",
}

In [21]:
metadata_field_info = [
    AttributeInfo(
        name=name, 
        description=field.description, 
        type=attribute_types[name]
    )
    for name, field in DatasetMetadataTemplate.model_fields.items()
]
document_content_description = DatasetMetadataTemplate.__doc__

In [22]:
allowed_comparators = [
    Comparator.EQ,
    Comparator.GT,
    Comparator.GTE,
    Comparator.LT,
    Comparator.LTE
]
allowed_operators = [
    Operator.AND, 
    Operator.OR, 
    Operator.NOT, 
]

In [23]:
examples = [
    (
        "Retrieve HuggingFace datasets about stocks",
        {
            "filter": 'eq("platform", "huggingface")',
            "query": "stock datasets"
        }
    ),
    (
        "Show me the summarization news datasets containing both the French as well as English data. The dataset however can't include any German data nor any Slovak data.",
        {
            "filter": 'and(eq("task_types", "summarization"), and(eq("languages", "fr"), eq("languages", "en")), not( or(eq("languages", "de"), eq("languages", "sk"))))',
            "query": "news datasets"
        },
    ),
    (
        "Find all chocolate datasets created after January 1, 2022, that are represented in textual or image format with its dataset size smaller than 500 000KB.",
        {
            "filter": 'and(gt("date_published", "2022-01-01"), or(eq("modalities", "text"), eq("modalities", "image")), lt("size_in_mb", 488))',
            "query": "chocolate datasets"
        },
    ),
    (
        "COVID-19 datasets",
        {
            "filter": "NO_FILTER",
            "query": "COVID-19 datasets"
        }
    )
]
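
The third example silently converts the requested size from KB to the MB unit of `size_in_mb`. A minimal sketch of that arithmetic, assuming 1 MB = 1024 KB (the helper `kb_to_mb` is hypothetical and not part of the notebook):

```python
def kb_to_mb(size_kb: float) -> float:
    """Convert kilobytes to megabytes, assuming 1 MB = 1024 KB."""
    return size_kb / 1024


# 500 000 KB from the example query, truncated to whole megabytes.
print(int(kb_to_mb(500_000)))  # 488
```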

In [24]:
custom_schema = """
<< Structured Request Schema >>

When responding use a markdown code snippet with a JSON object formatted in the following schema:

```json
{{{{
    "filter": string, \\ logical condition statement for filtering documents
    "query": string \\ text string to compare to document contents
}}}}
```
Your response should only consist of the said schema with no prefix or suffix. Respond only to the last user query as the others are only examples.

The query string should contain only text that is expected to match the contents of documents or its main description. Any conditions in the filter should not be mentioned in the query as well.

A logical condition statement is composed of one or more comparison and logical operation statements.

A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string):  name of attribute to apply the comparison to
- `val` (string): is the comparison value

A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to

Make sure that you only use the comparators and logical operators listed above and no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that the filter values are equal to one of the values listed in the 'valid_values' field, which holds the only permitted values of specific attributes.
Make sure to include only those filters that are explicitly defined in the user query. Don't try to infer new ones based on the context.
Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. To this end, you may need to convert the filters to comply with expected values or formats.
Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value.
"""
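
The quadruple braces in the schema are there because, as far as I can tell, the schema text goes through two formatting passes before it reaches the model: once when the comparators and operators are filled in, and once when the final prompt is rendered. The sketch below reproduces that behaviour with plain `str.format` on a shortened, hypothetical template.

```python
template = """{{{{
    "filter": string
}}}}
comparators: {allowed_comparators}"""

# First pass: fills in the comparators and halves the escaped braces.
first_pass = template.format(allowed_comparators="eq | gt")
# Second pass: the remaining doubled braces collapse into literal ones.
second_pass = first_pass.format()

print(first_pass)
print(second_pass)
```

With only doubled braces the second pass would raise an error or treat the JSON body as a placeholder, which is why the schema needs four.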

In [25]:
prompt = get_query_constructor_prompt(
    document_contents="Gist of the dataset",
    attribute_info=metadata_field_info,
    allowed_comparators=allowed_comparators,
    allowed_operators=allowed_operators,
    examples=examples,
    schema_prompt=custom_schema
)
output_parser = StructuredQueryOutputParser.from_components(fix_invalid=True)
query_constructor = prompt | model

In [38]:
user_query = (
    "Retrieve all the translation Stanford datasets with at least 10k datapoints and has over 100k KB in size " +
    "and the dataset should contain Slovak language, Polish language, but no Czech language."
)

In [204]:
user_query = (
    "Retrieve all translation datasets that either have at least 10k datapoints and has over 100k KB in size " +
    "or they contain Slovak language and Polish language, but no Czech language."
)

In [26]:
user_query = "Retrieve all the translation datasets from AIOD platform"

In [45]:
output = query_constructor.invoke({"query": user_query})
print(output.content)

{
    "filter": "and(eq(\"task_types\", \"translation\"), eq(\"domains\", \"Stanford\"), gte(\"num_datapoints\", 10e4), gt(\"size_in_mb\", 100), and(eq(\"languages\", \"sk\"), eq(\"languages\", \"pl\")), not(eq(\"languages\", \"cz\")))",
    "query": "translation datasets"
}


In [49]:
StructuredQueryOutputParser.from_components(fix_invalid=False).invoke(output)

StructuredQuery(query='translation datasets', filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='task_types', value='translation'), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='domains', value='Stanford'), Comparison(comparator=<Comparator.GTE: 'gte'>, attribute='num_datapoints', value=100000.0), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='size_in_mb', value=100), Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='languages', value='sk'), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='languages', value='pl')]), Operation(operator=<Operator.NOT: 'not'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='languages', value='cz')])]), limit=None)

In [47]:
MilvusTranslator().visit_structured_query(
    StructuredQueryOutputParser.from_components(fix_invalid=True).invoke(output)
)


('translation datasets',
 {'expr': '(( task_types == "translation" ) and ( domains == "Stanford" ) and ( num_datapoints >= 100000.0 ) and ( size_in_mb > 100 ) and (( languages == "sk" ) and ( languages == "pl" )) and not(( languages == "cz" )))'})

Main problems to tackle:
- Hallucinations
- Not adhering to the schema or to the permitted values of individual fields
    - If the model creates a filter that doesn't adhere to the schema of a particular metadata field (e.g. datasets from the AIoD platform), we should add that information to the query text instead
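
One way to sketch the fallback described above: check each comparison against the permitted values and move anything invalid into the query text. The `Comparison` dataclass is a minimal stand-in for LangChain's node of the same name, and `VALID_VALUES` and `split_valid` are hypothetical; the real permitted values would come from the metadata schema. Only a flat list of comparisons is handled here, not nested operations.

```python
from dataclasses import dataclass


@dataclass
class Comparison:
    """Minimal stand-in for LangChain's Comparison node."""
    comparator: str
    attribute: str
    value: object


# Hypothetical permitted values, for illustration only.
VALID_VALUES = {
    "platform": {"huggingface", "openml", "zenodo"},
    "languages": {"en", "fr", "de", "sk", "pl", "cs"},
}


def split_valid(comparisons, query, valid_values=VALID_VALUES):
    """Keep comparisons with permitted values; fold the rest into the query text."""
    kept, extra_terms = [], []
    for comp in comparisons:
        allowed = valid_values.get(comp.attribute)
        # Attributes without a value list are not restricted in this sketch.
        if allowed is None or comp.value in allowed:
            kept.append(comp)
        else:
            extra_terms.append(str(comp.value))
    if extra_terms:
        query = f"{query} {' '.join(extra_terms)}"
    return kept, query


comparisons = [
    Comparison("eq", "task_types", "translation"),
    Comparison("eq", "platform", "AIOD"),
]
kept, query = split_valid(comparisons, "translation datasets")
print([c.attribute for c in kept])  # ['task_types']
print(query)  # translation datasets AIOD
```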