Skip to content

Air new methods #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion aixplain/modules/model/index_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
self.split_length = split_length
self.split_overlap = split_overlap


class IndexModel(Model):
def __init__(
self,
Expand Down Expand Up @@ -151,7 +152,7 @@ def search(self, query: str, top_k: int = 10, filters: List[IndexFilter] = []) -
"data": query or uri,
"dataType": value_type,
"filters": [filter.to_dict() for filter in filters],
"payload": {"uri": uri, "value_type": value_type, "top_k": top_k}
"payload": {"uri": uri, "value_type": value_type, "top_k": top_k},
}
return self.run(data=data)

Expand Down Expand Up @@ -246,3 +247,49 @@ def delete_record(self, record_id: Text) -> ModelResponse:
if response.status == "SUCCESS":
return response
raise Exception(f"Failed to delete record: {response.error_message}")

def retrieve_records_with_filter(self, filter: IndexFilter) -> ModelResponse:
    """Fetch the index records that satisfy a single filter.

    The filter is serialized with ``filter.to_dict()`` and submitted through
    ``self.run`` under the ``retrieve_by_filter`` action.

    Args:
        filter (IndexFilter): Criteria the returned records must satisfy.

    Returns:
        ModelResponse: The run response, returned only when its status is
            ``"SUCCESS"``.

    Raises:
        Exception: If the response status is not ``"SUCCESS"``; the message
            includes the response's ``error_message``.

    Example:
        >>> from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator
        >>> my_filter = IndexFilter(field="category", value="world", operator=IndexFilterOperator.EQUALS)
        >>> index_model.retrieve_records_with_filter(my_filter)
    """
    payload = {"action": "retrieve_by_filter", "data": filter.to_dict()}
    result = self.run(data=payload)
    succeeded = result.status == "SUCCESS"
    # Guard clause: surface the backend error instead of handing back a failed response.
    if not succeeded:
        raise Exception(f"Failed to retrieve records with filter: {result.error_message}")
    return result

def delete_records_by_date(self, date: float) -> ModelResponse:
    """Ask the index to delete records selected by a date.

    The value is forwarded unchanged through ``self.run`` under the
    ``delete_by_date`` action.

    NOTE(review): how the backend compares record dates against ``date``
    (exact match vs. a cutoff) is not visible in this method -- confirm
    against the service documentation.

    Args:
        date (float): The date, expressed as a timestamp, that selects the
            records to delete.

    Returns:
        ModelResponse: The run response, returned only when its status is
            ``"SUCCESS"``.

    Raises:
        Exception: If the response status is not ``"SUCCESS"``; the message
            includes the response's ``error_message``.

    Example:
        >>> index_model.delete_records_by_date(1717708800)
    """
    payload = {"action": "delete_by_date", "data": date}
    result = self.run(data=payload)
    succeeded = result.status == "SUCCESS"
    # Guard clause: surface the backend error instead of handing back a failed response.
    if not succeeded:
        raise Exception(f"Failed to delete records by date: {result.error_message}")
    return result
94 changes: 94 additions & 0 deletions tests/functional/model/run_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,97 @@ def test_index_model_air_with_splitter(embedding_model, supplier_params):
assert str(response.status) == "SUCCESS"
assert "berlin" in response.data.lower()
index_model.delete()


def _test_records():
    """Build the five sample text records shared by the filter/date index tests.

    Every record is a ``DataType.TEXT`` record with an empty ``uri`` and two
    attributes: ``category`` and ``date``.

    Returns:
        list: ``Record`` objects with ids ``doc1`` through ``doc5``, in that order.
    """
    from aixplain.modules.model.record import Record
    from aixplain.enums import DataType

    # (id, text, category, date) -- one row per record, in insertion order.
    samples = (
        (
            "doc1",
            "Artificial intelligence is transforming industries worldwide, from healthcare to finance.",
            "technology",
            1751464788,
        ),
        (
            "doc2",
            "The Mona Lisa, painted by Leonardo da Vinci, is one of the most famous artworks in history.",
            "art",
            1751464790,
        ),
        (
            "doc3",
            "Machine learning algorithms are being used to predict patient outcomes in hospitals.",
            "technology",
            1751464795,
        ),
        (
            "doc4",
            "The Earth orbits the Sun once every 365.25 days, creating the calendar year.",
            "science",
            1751464798,
        ),
        (
            "doc5",
            "Quantum computing promises to solve complex problems that are currently intractable for classical computers.",
            "technology",
            1751464801,
        ),
    )
    return [
        Record(
            value=text,
            value_type=DataType.TEXT,
            id=record_id,
            uri="",
            attributes={"category": category, "date": date},
        )
        for record_id, text, category, date in samples
    ]


@pytest.fixture(scope="function")
def setup_index_with_test_records():
    """Create a fresh index populated with the sample records and remove it afterwards.

    Yields:
        The index model returned by ``IndexFactory.create`` after the records
        from ``_test_records()`` have been upserted into it.

    The index is deleted in a ``finally`` clause so that it is cleaned up even
    when the upsert itself raises before the test body gets to run.
    """
    from aixplain.factories import IndexFactory
    from aixplain.enums import EmbeddingModel
    from aixplain.factories.index_factory.utils import AirParams
    from uuid import uuid4

    # Clean up all existing indexes.
    # NOTE(review): this deletes every index returned by IndexFactory.list(),
    # not only those created by this test module -- make sure the tests run
    # against a dedicated account/team.
    for index in IndexFactory.list()["results"]:
        index.delete()

    params = AirParams(
        name=f"Test Index {uuid4()}",  # random suffix keeps the index name unique per run
        description="Test index for filter/date tests",
        embedding_model=EmbeddingModel.OPENAI_ADA002,
    )
    index_model = IndexFactory.create(params=params)
    try:
        index_model.upsert(_test_records())
        yield index_model
    finally:
        # Runs after the test finishes and also if the upsert above fails,
        # so the remote index is never left behind.
        index_model.delete()


def test_retrieve_records_with_filter(setup_index_with_test_records):
    """Filtering on ``category == "technology"`` returns exactly the three technology records."""
    from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator

    index = setup_index_with_test_records
    technology_only = IndexFilter(field="category", value="technology", operator=IndexFilterOperator.EQUALS)

    result = index.retrieve_records_with_filter(technology_only)

    assert result.status == "SUCCESS"
    # Three of the five fixture records carry the "technology" category.
    assert len(result.details) == 3
    for entry in result.details:
        assert entry["metadata"]["category"] == "technology"


def test_delete_records_by_date(setup_index_with_test_records):
    """Deleting by date 1751464796 leaves two of the five fixture records in the index."""
    from aixplain.modules.model.index_model import IndexFilter, IndexFilterOperator

    index = setup_index_with_test_records

    deletion = index.delete_records_by_date(1751464796)
    assert deletion.status == "SUCCESS"
    assert deletion.data == "2"  # two records are expected to be left

    # Every fixture record has a positive date, so "date > 0" selects whatever is left.
    match_everything = IndexFilter(field="date", value=0, operator=IndexFilterOperator.GREATER_THAN)
    remaining = index.retrieve_records_with_filter(match_everything)
    assert remaining.status == "SUCCESS"
    assert len(remaining.details) == 2