Skip to content

Commit 468bde6

Browse files
authoredFeb 19, 2025
Support dicts in lists in metadata (#148)
* Support dicts in lists in metadata This allows cases like: ``` metadata = { "entities": [ {"entity": "Bell", "type": "PEOPLE"}, ... ] } ``` Note: For stores that perform shredding this works by JSON encoding the entire item `{"entity": "Bell", "type": "PEOPLE"}` into the key. This means that equality on the items of `entities` are supported, by digging into fields won't be. * lint/fmt * lint
1 parent b99bddc commit 468bde6

File tree

16 files changed

+172
-63
lines changed

16 files changed

+172
-63
lines changed
 

‎data/animals.jsonl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
{"id": "aardvark", "text": "the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["burrowing", "nocturnal", "ants", "savanna"], "habitat": "savanna"}}
2-
{"id": "albatross", "text": "the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.", "metadata": {"type": "bird", "number_of_legs": 2, "keywords": ["seabird", "wingspan", "ocean"], "habitat": "marine"}}
1+
{"id": "aardvark", "text": "the aardvark is a nocturnal mammal known for its burrowing habits and long snout used to sniff out ants.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["burrowing", "nocturnal", "ants", "savanna"], "habitat": "savanna", "tags": [{"a": 5, "b": 7}, {"a": 8, "b": 10}]}}
2+
{"id": "albatross", "text": "the albatross is a large seabird with the longest wingspan of any bird, allowing it to glide effortlessly over oceans.", "metadata": {"type": "bird", "number_of_legs": 2, "keywords": ["seabird", "wingspan", "ocean"], "habitat": "marine", "tags": [{"a": 5, "b": 8}, {"a": 8, "b": 10}]}}
33
{"id": "alligator", "text": "alligators are large reptiles with powerful jaws and are commonly found in freshwater wetlands.", "metadata": {"type": "reptile", "number_of_legs": 4, "keywords": ["reptile", "jaws", "wetlands"], "diet": "carnivorous", "nested": { "a": 5 }}}
44
{"id": "alpaca", "text": "alpacas are domesticated mammals valued for their soft wool and friendly demeanor.", "metadata": {"type": "mammal", "number_of_legs": 4, "keywords": ["wool", "domesticated", "friendly"], "origin": "south america", "nested": { "a": 5 }}}
55
{"id": "ant", "text": "ants are social insects that live in colonies and are known for their teamwork and strength.", "metadata": {"type": "insect", "number_of_legs": 6, "keywords": ["social", "colonies", "strength", "pollinator"], "diet": "omnivorous", "nested": { "a": 6 }}}

‎packages/graph-retriever/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ dependencies = [
4242
"numpy>=1.26.4",
4343
"typing-extensions>=4.12.2",
4444
"pytest>=8.3.4",
45+
"immutabledict>=4.2.1",
4546
]
4647

4748
[project.urls]

‎packages/graph-retriever/src/graph_retriever/adapters/base.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from collections.abc import Iterable, Sequence
66
from typing import Any
77

8+
from immutabledict import immutabledict
9+
810
from graph_retriever.content import Content
911
from graph_retriever.edges import Edge, IdEdge, MetadataEdge
1012
from graph_retriever.utils.run_in_executor import run_in_executor
@@ -355,8 +357,8 @@ async def aadjacent(
355357

356358
def _metadata_filter(
357359
self,
360+
edge: Edge,
358361
base_filter: dict[str, Any] | None = None,
359-
edge: Edge | None = None,
360362
) -> dict[str, Any]:
361363
"""
362364
Return a filter for the `base_filter` and incoming edges from `edge`.
@@ -376,10 +378,8 @@ def _metadata_filter(
376378
:
377379
The metadata dictionary to use for the given filter.
378380
"""
379-
metadata_filter = {**(base_filter or {})}
380381
assert isinstance(edge, MetadataEdge)
381-
if edge is None:
382-
metadata_filter
383-
else:
384-
metadata_filter[edge.incoming_field] = edge.value
385-
return metadata_filter
382+
value = edge.value
383+
if isinstance(value, immutabledict):
384+
value = dict(value)
385+
return {edge.incoming_field: value, **(base_filter or {})}

‎packages/graph-retriever/src/graph_retriever/edges/_base.py

+13
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from typing import Any, TypeAlias
55

6+
from immutabledict import immutabledict
7+
68
from graph_retriever import Content
79

810

@@ -34,6 +36,17 @@ class MetadataEdge(Edge):
3436
The value associated with the key for this edge
3537
"""
3638

39+
def __init__(self, incoming_field: str, value: Any) -> None:
40+
# `self.field = value` and `setattr(self, "field", value)` -- don't work
41+
# because of frozen. we need to call `__setattr__` directly (as the
42+
# default `__init__` would do) to initialize the fields of the frozen
43+
# dataclass.
44+
object.__setattr__(self, "incoming_field", incoming_field)
45+
46+
if isinstance(value, dict):
47+
value = immutabledict(value)
48+
object.__setattr__(self, "value", value)
49+
3750
incoming_field: str
3851
value: Any
3952

‎packages/graph-retriever/src/graph_retriever/testing/adapter_tests.py

+85
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ class AdapterComplianceCase(abc.ABC):
5656

5757
id: str
5858
expected: list[str]
59+
5960
requires_nested: bool = False
61+
requires_dict_in_list: bool = False
6062

6163

6264
@dataclass
@@ -191,6 +193,50 @@ class AdjacentCase(AdapterComplianceCase):
191193
"horse",
192194
],
193195
),
196+
AdjacentCase(
197+
id="numeric",
198+
query="domesticated hunters",
199+
edges={
200+
MetadataEdge("number_of_legs", 0),
201+
},
202+
k=20, # more than match the filter so we get all
203+
expected=[
204+
"barracuda",
205+
"cobra",
206+
"dolphin",
207+
"eel",
208+
"fish",
209+
"jellyfish",
210+
"manatee",
211+
"narwhal",
212+
],
213+
),
214+
AdjacentCase(
215+
id="two_edges_diff_field",
216+
query="domesticated hunters",
217+
edges={
218+
MetadataEdge("type", "reptile"),
219+
MetadataEdge("number_of_legs", 0),
220+
},
221+
k=20, # more than match the filter so we get all
222+
expected=[
223+
"alligator",
224+
"barracuda",
225+
"chameleon",
226+
"cobra",
227+
"crocodile",
228+
"dolphin",
229+
"eel",
230+
"fish",
231+
"gecko",
232+
"iguana",
233+
"jellyfish",
234+
"komodo dragon",
235+
"lizard",
236+
"manatee",
237+
"narwhal",
238+
],
239+
),
194240
AdjacentCase(
195241
id="one_ids",
196242
query="domesticated hunters",
@@ -262,6 +308,39 @@ class AdjacentCase(AdapterComplianceCase):
262308
"komodo dragon", # reptile
263309
],
264310
),
311+
AdjacentCase(
312+
id="dict_in_list",
313+
query="domesticated hunters",
314+
edges={
315+
MetadataEdge("tags", {"a": 5, "b": 7}),
316+
},
317+
expected=[
318+
"aardvark",
319+
],
320+
requires_dict_in_list=True,
321+
),
322+
AdjacentCase(
323+
id="dict_in_list_multiple",
324+
query="domesticated hunters",
325+
edges={
326+
MetadataEdge("tags", {"a": 5, "b": 7}),
327+
MetadataEdge("tags", {"a": 5, "b": 8}),
328+
},
329+
expected=[
330+
"aardvark",
331+
"albatross",
332+
],
333+
requires_dict_in_list=True,
334+
),
335+
AdjacentCase(
336+
id="absent_dict",
337+
query="domesticated hunters",
338+
edges={
339+
MetadataEdge("tags", {"a": 5, "b": 10}),
340+
},
341+
expected=[],
342+
requires_dict_in_list=True,
343+
),
265344
AdjacentCase(
266345
id="nested",
267346
query="domesticated hunters",
@@ -318,6 +397,10 @@ def supports_nested_metadata(self) -> bool:
318397
"""Return whether nested metadata is expected to work."""
319398
return True
320399

400+
def supports_dict_in_list(self) -> bool:
401+
"""Return whether dicts can appear in list fields in metadata."""
402+
return True
403+
321404
def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
322405
"""
323406
Override to change the expected behavior of a case.
@@ -346,6 +429,8 @@ def expected(self, method: str, case: AdapterComplianceCase) -> list[str]:
346429
"""
347430
if not self.supports_nested_metadata() and case.requires_nested:
348431
pytest.xfail("nested metadata not supported")
432+
if not self.supports_dict_in_list() and case.requires_dict_in_list:
433+
pytest.xfail("dict-in-list fields is not supported")
349434
return case.expected
350435

351436
@pytest.fixture(params=GET_CASES, ids=lambda c: c.id)

‎packages/langchain-graph-retriever/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ classifiers = [
4242
dependencies = [
4343
"backoff>=2.2.1",
4444
"graph-retriever",
45+
"immutabledict>=4.2.1",
4546
"langchain-core>=0.3.29",
4647
"networkx>=3.4.2",
4748
"pydantic>=2.10.4",

‎packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/astra.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from graph_retriever.utils import merge
1313
from graph_retriever.utils.batched import batched
1414
from graph_retriever.utils.top_k import top_k
15+
from immutabledict import immutabledict
1516
from typing_extensions import override
1617

1718
try:
@@ -107,13 +108,27 @@ def with_user_filters(
107108
) -> dict[str, Any]:
108109
return filter if encoded else codec.encode_filter(filter)
109110

111+
def process_value(v: Any) -> Any:
112+
if isinstance(v, immutabledict):
113+
return dict(v)
114+
else:
115+
return v
116+
110117
for k, v in metadata.items():
111118
for v_batch in batched(v, 100):
112-
batch = list(v_batch)
113-
if len(batch) == 1:
114-
yield (with_user_filters({k: batch[0]}, encoded=False))
119+
batch = [process_value(v) for v in v_batch]
120+
if isinstance(batch[0], dict):
121+
if len(batch) == 1:
122+
yield with_user_filters({k: {"$all": [batch[0]]}}, encoded=False)
123+
else:
124+
yield with_user_filters(
125+
{"$or": [{k: {"$all": [v]}} for v in batch]}, encoded=False
126+
)
115127
else:
116-
yield (with_user_filters({k: {"$in": batch}}, encoded=False))
128+
if len(batch) == 1:
129+
yield (with_user_filters({k: batch[0]}, encoded=False))
130+
else:
131+
yield (with_user_filters({k: {"$in": batch}}, encoded=False))
117132

118133
for id_batch in batched(ids, 100):
119134
ids = list(id_batch)

‎packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/langchain.py

-32
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from graph_retriever import Content
88
from graph_retriever.adapters import Adapter
9-
from graph_retriever.edges import Edge, MetadataEdge
109
from langchain_core.documents import Document
1110
from langchain_core.embeddings import Embeddings
1211
from langchain_core.runnables import run_in_executor
@@ -358,37 +357,6 @@ async def _aget(
358357
**kwargs,
359358
)
360359

361-
def _metadata_filter(
362-
self,
363-
base_filter: dict[str, Any] | None = None,
364-
edge: Edge | None = None,
365-
) -> dict[str, Any]:
366-
"""
367-
Return a filter for the `base_filter` and incoming edges from `edge`.
368-
369-
Parameters
370-
----------
371-
base_filter :
372-
Any base metadata filter that should be used for search.
373-
Generally corresponds to the user specified filters for the entire
374-
traversal. Should be combined with the filters necessary to support
375-
nodes with an *incoming* edge matching `edge`.
376-
edge :
377-
An optional edge which should be added to the filter.
378-
379-
Returns
380-
-------
381-
:
382-
The metadata dictionary to use for the given filter.
383-
"""
384-
metadata_filter = {**(base_filter or {})}
385-
assert isinstance(edge, MetadataEdge)
386-
if edge is None:
387-
metadata_filter
388-
else:
389-
metadata_filter[edge.incoming_field] = edge.value
390-
return metadata_filter
391-
392360

393361
class ShreddedLangchainAdapter(LangchainAdapter[StoreT]):
394362
"""

‎packages/langchain-graph-retriever/src/langchain_graph_retriever/adapters/open_search.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, vector_store: OpenSearchVectorSearch):
5454
self._id_field = "_id"
5555

5656
def _build_filter(
57-
self, filter: dict[str, str] | None = None
57+
self, filter: dict[str, Any] | None = None
5858
) -> list[dict[str, Any]] | None:
5959
"""
6060
Build a filter query for OpenSearch based on metadata.
@@ -68,17 +68,24 @@ def _build_filter(
6868
-------
6969
:
7070
Filter query for OpenSearch.
71+
72+
Raises
73+
------
74+
ValueError
75+
If the query is not supported by OpenSearch adapter.
7176
"""
7277
if filter is None:
7378
return None
74-
return [
75-
{
76-
"terms" if isinstance(value, list) else "term": {
77-
f"metadata.{key}.keyword": value
78-
}
79-
}
80-
for key, value in filter.items()
81-
]
79+
80+
filters = []
81+
for key, value in filter.items():
82+
if isinstance(value, list):
83+
filters.append({"terms": {f"metadata.{key}": value}})
84+
elif isinstance(value, dict):
85+
raise ValueError("Open Search doesn't suport dictionary searches.")
86+
else:
87+
filters.append({"term": {f"metadata.{key}": value}})
88+
return filters
8289

8390
@override
8491
def _search(
@@ -92,9 +99,8 @@ def _search(
9299
# use an efficient_filter to collect results that
93100
# are near the embedding vector until up to 'k'
94101
# documents that match the filter are found.
95-
kwargs["efficient_filter"] = {
96-
"bool": {"must": self._build_filter(filter=filter)}
97-
}
102+
query = {"bool": {"must": self._build_filter(filter=filter)}}
103+
kwargs["efficient_filter"] = query
98104

99105
if k == 0:
100106
return []

‎packages/langchain-graph-retriever/src/langchain_graph_retriever/transformers/shredding.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ def restore_documents(
110110
and value == self.static_value
111111
):
112112
original_key, original_value = split_key
113+
value = json.loads(original_value)
113114
if original_key not in new_doc.metadata:
114115
new_doc.metadata[original_key] = []
115-
new_doc.metadata[original_key].append(original_value)
116+
new_doc.metadata[original_key].append(value)
116117
else:
117118
# Retain non-shredded metadata as is
118119
new_doc.metadata[key] = value
@@ -137,7 +138,7 @@ def shredded_key(self, key: str, value: Any) -> str:
137138
str
138139
the shredded key
139140
"""
140-
return f"{key}{self.path_delimiter}{value}"
141+
return f"{key}{self.path_delimiter}{json.dumps(value)}"
141142

142143
def shredded_value(self) -> str:
143144
"""

‎packages/langchain-graph-retriever/tests/adapters/test_astra.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,8 @@ def astra_config(enabled_stores: set[str]) -> Iterator[_AstraConfig | None]:
212212
assert found, f"Keyspace '{keyspace}' not created"
213213
yield _AstraConfig(token=token, keyspace=keyspace, api_endpoint=api_endpoint)
214214

215-
admin.drop_keyspace(keyspace)
215+
if keyspace != "default_keyspace":
216+
admin.drop_keyspace(keyspace)
216217

217218

218219
class TestAstraAdapter(AdapterComplianceSuite):

‎packages/langchain-graph-retriever/tests/adapters/test_cassandra.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def adapter(
8383
)
8484
docs = list(shredder.transform_documents(animal_docs))
8585
store.add_documents(docs)
86-
yield CassandraAdapter(store, shredder, {"keywords"})
86+
yield CassandraAdapter(store, shredder, {"keywords", "tags"})
8787

8888
if session:
8989
session.shutdown()

‎packages/langchain-graph-retriever/tests/adapters/test_chroma.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def remove_nested_metadata(doc: Document) -> Document:
4848
collection_metadata={"hnsw:space": "cosine"},
4949
)
5050

51-
yield ChromaAdapter(store, shredder, nested_metadata_fields={"keywords"})
51+
yield ChromaAdapter(
52+
store, shredder, nested_metadata_fields={"keywords", "tags"}
53+
)
5254

5355
store.delete_collection()

0 commit comments

Comments
 (0)
Failed to load comments.