Skip to content

Commit ad42cd4

Browse files
authored
Merge pull request weaviate#367 from weaviate/hybrid_fusion
Add hybrid fusion type
2 parents a6a0069 + 4d6cd3f commit ad42cd4

File tree

3 files changed

+65
-9
lines changed

3 files changed

+65
-9
lines changed

integration/test_graphql.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import weaviate
1010
from weaviate import Tenant
1111
from weaviate.data.replication import ConsistencyLevel
12+
from weaviate.gql.get import HybridFusion
1213

1314
schema = {
1415
"classes": [
@@ -197,9 +198,14 @@ def test_bm25_no_result(client):
197198

198199

199200
@pytest.mark.parametrize("query", ["sponges", "sponges\n"])
200-
def test_hybrid(client, query: str):
201+
@pytest.mark.parametrize("fusion_type", [HybridFusion.RANKED, HybridFusion.RELATIVE_SCORE, None])
202+
def test_hybrid(client, query: str, fusion_type: Optional[HybridFusion]):
201203
"""Test hybrid search with alpha=0.5 to have a combination of BM25 and vector search."""
202-
result = client.query.get("Ship", ["name", "description"]).with_hybrid(query, alpha=0.5).do()
204+
result = (
205+
client.query.get("Ship", ["name", "description"])
206+
.with_hybrid(query, alpha=0.5, fusion_type=fusion_type)
207+
.do()
208+
)
203209

204210
# will find more results. "The Crusty Crab" is still first, because it matches with the BM25 search
205211
assert len(result["data"]["Get"]["Ship"]) >= 1

test/gql/test_get.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,15 @@
66

77
from test.util import check_error_message
88
from weaviate.data.replication import ConsistencyLevel
9-
from weaviate.gql.get import GetBuilder, BM25, Hybrid, LinkTo, GroupBy, AdditionalProperties
9+
from weaviate.gql.get import (
10+
GetBuilder,
11+
BM25,
12+
Hybrid,
13+
LinkTo,
14+
GroupBy,
15+
AdditionalProperties,
16+
HybridFusion,
17+
)
1018

1119
mock_connection_v117 = Mock()
1220
mock_connection_v117.server_version = "1.17.4"
@@ -70,34 +78,61 @@ def test_get_references(property_name: str, in_class: str, properties: List[str]
7078

7179

7280
@pytest.mark.parametrize(
73-
"query,vector,alpha,properties,expected",
81+
"query,vector,alpha,properties,fusion_type,expected",
7482
[
7583
(
7684
"query",
7785
[1, 2, 3],
7886
0.5,
7987
None,
88+
None,
8089
'hybrid:{query: "query", vector: [1, 2, 3], alpha: 0.5}',
8190
),
82-
("query", None, None, None, 'hybrid:{query: "query"}'),
83-
("query", None, None, ["prop1"], 'hybrid:{query: "query", properties: ["prop1"]}'),
91+
("query", None, None, None, None, 'hybrid:{query: "query"}'),
92+
("query", None, None, ["prop1"], None, 'hybrid:{query: "query", properties: ["prop1"]}'),
8493
(
8594
"query",
8695
None,
8796
None,
8897
["prop1", "prop2"],
98+
None,
8999
'hybrid:{query: "query", properties: ["prop1","prop2"]}',
90100
),
101+
(
102+
"query",
103+
None,
104+
None,
105+
None,
106+
HybridFusion.RANKED,
107+
'hybrid:{query: "query", fusionType: rankedFusion}',
108+
),
109+
(
110+
"query",
111+
None,
112+
None,
113+
None,
114+
HybridFusion.RELATIVE_SCORE,
115+
'hybrid:{query: "query", fusionType: relativeScoreFusion}',
116+
),
117+
(
118+
"query",
119+
None,
120+
None,
121+
None,
122+
"relativeScoreFusion",
123+
'hybrid:{query: "query", fusionType: relativeScoreFusion}',
124+
),
91125
],
92126
)
93127
def test_hybrid(
94128
query: str,
95129
vector: Optional[List[float]],
96130
alpha: Optional[float],
97131
properties: Optional[List[str]],
132+
fusion_type: HybridFusion,
98133
expected: str,
99134
):
100-
hybrid = Hybrid(query, alpha, vector, properties)
135+
hybrid = Hybrid(query, alpha, vector, properties, fusion_type)
101136
assert str(hybrid) == expected
102137

103138

weaviate/gql/get.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
GraphQL `Get` command.
33
"""
44
from dataclasses import dataclass, Field, fields
5+
from enum import Enum
56
from json import dumps
67
from typing import List, Union, Optional, Dict, Tuple
78

@@ -21,7 +22,7 @@
2122
Sort,
2223
)
2324
from weaviate.types import UUID
24-
from weaviate.util import image_encoder_b64, _capitalize_first_letter, get_valid_uuid
25+
from weaviate.util import image_encoder_b64, _capitalize_first_letter, get_valid_uuid, BaseEnum
2526
from weaviate.warnings import _Warnings
2627

2728
try:
@@ -44,12 +45,18 @@ def __str__(self) -> str:
4445
return "bm25:{" + ret + "}"
4546

4647

48+
class HybridFusion(str, BaseEnum):
49+
RANKED = "rankedFusion"
50+
RELATIVE_SCORE = "relativeScoreFusion"
51+
52+
4753
@dataclass
4854
class Hybrid:
4955
query: str
5056
alpha: Optional[float]
5157
vector: Optional[List[float]]
5258
properties: Optional[List[str]]
59+
fusion_type: Optional[HybridFusion]
5360

5461
def __str__(self) -> str:
5562
ret = f'query: "{util.strip_newlines(self.query)}"'
@@ -60,6 +67,11 @@ def __str__(self) -> str:
6067
if self.properties is not None and len(self.properties) > 0:
6168
props = '","'.join(self.properties)
6269
ret += f', properties: ["{props}"]'
70+
if self.fusion_type is not None:
71+
if isinstance(self.fusion_type, Enum):
72+
ret += f", fusionType: {self.fusion_type.value}"
73+
else:
74+
ret += f", fusionType: {self.fusion_type}"
6375

6476
return "hybrid:{" + ret + "}"
6577

@@ -1035,6 +1047,7 @@ def with_hybrid(
10351047
alpha: Optional[float] = None,
10361048
vector: Optional[List[float]] = None,
10371049
properties: Optional[List[str]] = None,
1050+
fusion_type: Optional[HybridFusion] = None,
10381051
):
10391052
"""Get objects using bm25 and vector, then combine the results using a reciprocal ranking algorithm.
10401053
@@ -1053,8 +1066,10 @@ def with_hybrid(
10531066
properties: Optional[List[str]]:
10541067
Which properties should be searched by BM25. Does not have any effect for vector search. If None or empty
10551068
all properties are searched.
1069+
fusion_type: Optional[HybridFusion]:
1070+
Which fusion type should be used to merge the keyword (BM25) and vector search results. If None, no fusionType is added to the query.
10561071
"""
1057-
self._hybrid = Hybrid(query, alpha, vector, properties)
1072+
self._hybrid = Hybrid(query, alpha, vector, properties, fusion_type)
10581073
self._contains_filter = True
10591074
return self
10601075

0 commit comments

Comments
 (0)