diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index dc1c351ea..08be2a457 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -4,7 +4,7 @@ from aleph_message.models import ItemHash, Chain, MessageType from sqlalchemy import func, select, update, text, delete, nullsfirst, nullslast -from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import insert, array from sqlalchemy.orm import selectinload, load_only from sqlalchemy.sql import Insert, Select from sqlalchemy.sql.elements import literal @@ -105,7 +105,7 @@ def make_matching_messages_query( ) if tags: select_stmt = select_stmt.where( - MessageDb.content["content"]["tags"].contains(tags) + MessageDb.content["content"]["tags"].has_any(array(tags)) ) if channels: select_stmt = select_stmt.where(MessageDb.channel.in_(channels)) diff --git a/src/aleph/db/accessors/posts.py b/src/aleph/db/accessors/posts.py index 2f8fc6f3d..50a1e72d4 100644 --- a/src/aleph/db/accessors/posts.py +++ b/src/aleph/db/accessors/posts.py @@ -28,7 +28,7 @@ Float, case, ) -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, array from sqlalchemy.orm import aliased from sqlalchemy.sql import Select @@ -230,7 +230,7 @@ def filter_post_select_stmt( select_stmt = select_stmt.where(literal_column("original_type").in_(post_types)) if tags: select_stmt = select_stmt.where( - literal_column("content", type_=JSONB)["tags"].astext.in_(tags) + literal_column("content", type_=JSONB)["tags"].has_any(array(tags)) ) if channels: select_stmt = select_stmt.where(literal_column("channel").in_(channels)) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 151596cc3..10a4ffa12 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -1,7 +1,8 @@ import json from pathlib import Path -from typing import Any, Dict, Sequence, cast +from typing import Any, Dict, Sequence, cast, Tuple +import pytest import pytest_asyncio from aleph_message.models import AggregateContent, PostContent from sqlalchemy import insert @@ -10,12 +11,13 @@ from aleph.db.models import ( MessageDb, ChainTxDb, - AggregateElementDb, message_confirmations, + AggregateElementDb, + message_confirmations, ) from aleph.db.models.posts import PostDb from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.types.db_session import DbSessionFactory - +import datetime as dt # TODO: remove the raw parameter, it's just to avoid larger refactorings async def _load_fixtures( @@ -122,3 +124,67 @@ async def fixture_posts( session.commit() return posts + + +@pytest.fixture +def post_with_refs_and_tags() -> Tuple[MessageDb, PostDb]: + message = MessageDb( + item_hash="1234", + sender="0xdeadbeef", + type="POST", + chain="ETH", + signature=None, + item_type="storage", + item_content=None, + content={"content": {"tags": ["original", "mainnet"], "swap": "this"}}, + time=dt.datetime(2023, 5, 1, tzinfo=dt.timezone.utc), + channel=None, + size=254, + ) + + post = PostDb( + item_hash=message.item_hash, + owner=message.sender, + type=None, + ref="custom-ref", + amends=None, + channel=None, + content=message.content["content"], + creation_datetime=message.time, + latest_amend=None, + ) + + return message, post + + +@pytest.fixture +def amended_post_with_refs_and_tags(post_with_refs_and_tags: Tuple[MessageDb, PostDb]): + original_message, original_post = post_with_refs_and_tags + + amend_message = MessageDb( + item_hash="5678", + sender="0xdeadbeef", + type="POST", + chain="ETH", + signature=None, + item_type="storage", + item_content=None, + content={"content": {"tags": ["amend", "mainnet"], "don't": "swap"}}, + time=dt.datetime(2023, 5, 2, tzinfo=dt.timezone.utc), + channel=None, + size=277, + ) + + amend_post = PostDb( + item_hash=amend_message.item_hash, + owner=original_message.sender, + type="amend", + ref=original_message.item_hash, + amends=original_message.item_hash, + channel=None, + content=amend_message.content["content"], + creation_datetime=amend_message.time, + latest_amend=None, + ) + + return amend_message, amend_post diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index c317b38a9..d75877729 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -5,6 +5,8 @@ import aiohttp import pytest +from aleph.db.models import MessageDb, PostDb +from aleph.types.db_session import DbSessionFactory from .utils import get_messages_by_keys MESSAGES_URI = "/api/v0/messages.json" @@ -149,18 +151,74 @@ async def test_get_messages_multiple_hashes(fixture_messages, ccn_api_client): @pytest.mark.asyncio -async def test_get_messages_filter_by_tags(fixture_messages, ccn_api_client): +async def test_get_messages_filter_by_tags( + fixture_messages, + ccn_api_client, + session_factory: DbSessionFactory, + post_with_refs_and_tags: Tuple[MessageDb, PostDb], + amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb] +): """ Tests getting messages by tags. There's no example in the fixtures, we just test that the endpoint returns a 200. - # TODO: add a POST message fixture with tags. """ - tags = ["mainnet"] + message_db, _ = post_with_refs_and_tags + amend_message_db, _ = amended_post_with_refs_and_tags - response = await ccn_api_client.get( - MESSAGES_URI, params={"tags": ",".join(tags)} - ) + with session_factory() as session: + session.add_all([message_db, amend_message_db]) + session.commit() + + # Matching tag for both messages + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 2 + + # Matching tags for both messages + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "original,amend"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 2 + + # Matching the original tag + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "original"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 1 + assert messages[0]["item_hash"] == message_db.item_hash + + # Matching the amend tag + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "amend"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 1 + assert messages[0]["item_hash"] == amend_message_db.item_hash + + # No match + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "not-a-tag"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 0 + + # Matching the amend tag with other tags + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "amend,not-a-tag,not-a-tag-either"}) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert len(messages) == 1 + assert messages[0]["item_hash"] == amend_message_db.item_hash + + +@pytest.mark.asyncio +async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client): + """ + Tests getting messages by tags. + There's no example in the fixtures, we just test that the endpoint returns a 200. + """ + + # Matching tag + response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"}) assert response.status == 200, await response.text() messages = (await response.json())["messages"] assert len(messages) == 0 diff --git a/tests/api/test_posts.py b/tests/api/test_posts.py index bc25e7088..72bc8f5a2 100644 --- a/tests/api/test_posts.py +++ b/tests/api/test_posts.py @@ -5,8 +5,6 @@ from aleph.db.models import MessageDb from aleph.db.models.posts import PostDb -import datetime as dt - from aleph.types.db_session import DbSessionFactory POSTS_URI = "/api/v1/posts.json" @@ -55,70 +53,6 @@ async def test_get_posts(ccn_api_client, fixture_posts: Sequence[PostDb]): assert_posts_equal(posts, fixture_posts) -@pytest.fixture -def post_with_refs_and_tags() -> Tuple[MessageDb, PostDb]: - message = MessageDb( - item_hash="1234", - sender="0xdeadbeef", - type="POST", - chain="ETH", - signature=None, - item_type="storage", - item_content=None, - content={}, - time=dt.datetime(2023, 5, 1, tzinfo=dt.timezone.utc), - channel=None, - size=254, - ) - - post = PostDb( - item_hash=message.item_hash, - owner=message.sender, - type=None, - ref="custom-ref", - amends=None, - channel=None, - content={"swap": "this"}, - creation_datetime=message.time, - latest_amend=None, - ) - - return message, post - - -@pytest.fixture -def amended_post_with_refs_and_tags(post_with_refs_and_tags: Tuple[MessageDb, PostDb]): - original_message, original_post = post_with_refs_and_tags - - amend_message = MessageDb( - item_hash="5678", - sender="0xdeadbeef", - type="POST", - chain="ETH", - signature=None, - item_type="storage", - item_content=None, - content={}, - time=dt.datetime(2023, 5, 2, tzinfo=dt.timezone.utc), - channel=None, - size=277, - ) - - amend_post = PostDb( - item_hash=amend_message.item_hash, - owner=original_message.sender, - type="amend", - ref=original_message.item_hash, - amends=original_message.item_hash, - channel=None, - content={"don't": "swap"}, - creation_datetime=amend_message.time, - latest_amend=None, - ) - - return amend_message, amend_post - - @pytest.mark.asyncio async def test_get_posts_refs( ccn_api_client, @@ -233,3 +167,141 @@ async def test_get_amended_posts_refs( assert post["original_item_hash"] == original_post_db.item_hash assert post["ref"] == original_post_db.ref assert post["content"] == amend_post_db.content + + +@pytest.mark.asyncio +async def test_get_posts_tags( + ccn_api_client, + session_factory: DbSessionFactory, + fixture_posts: Sequence[PostDb], + post_with_refs_and_tags: Tuple[MessageDb, PostDb], +): + message_db, post_db = post_with_refs_and_tags + + with session_factory() as session: + session.add_all(fixture_posts) + session.add(message_db) + session.add(post_db) + session.commit() + + # Match one tag + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": "mainnet"} + ) + assert response.status == 200, await response.text() + response_json = await response.json() + assert len(response_json["posts"]) == 1 + assert response_json["pagination_total"] == 1 + + post = response_json["posts"][0] + assert post["item_hash"] == post_db.item_hash + assert post["original_item_hash"] == post_db.item_hash + assert post["content"] == post_db.content + + # Unknown tag + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": "not-a-tag"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 0 + assert response_json["pagination_total"] == 0 + + # Search for several tags + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": f"mainnet,not-a-ref"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 1 + assert response_json["pagination_total"] == 1 + + post = response_json["posts"][0] + assert post["item_hash"] == post_db.item_hash + assert post["original_item_hash"] == post_db.item_hash + assert post["ref"] == post_db.ref + assert post["content"] == post_db.content + + # Check for several matching tags + # Search for several tags + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": f"original,mainnet"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 1 + assert response_json["pagination_total"] == 1 + + post = response_json["posts"][0] + assert post["item_hash"] == post_db.item_hash + assert post["original_item_hash"] == post_db.item_hash + assert post["ref"] == post_db.ref + assert post["content"] == post_db.content + + +@pytest.mark.asyncio +async def test_get_amended_posts_tags( + ccn_api_client, + session_factory: DbSessionFactory, + fixture_posts: Sequence[PostDb], + post_with_refs_and_tags: Tuple[MessageDb, PostDb], + amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb], +): + original_message_db, original_post_db = post_with_refs_and_tags + amend_message_db, amend_post_db = amended_post_with_refs_and_tags + + original_post_db.latest_amend = amend_post_db.item_hash + + with session_factory() as session: + session.add_all(fixture_posts) + session.add(original_message_db) + session.add(original_post_db) + session.add(amend_message_db) + session.add(amend_post_db) + session.commit() + + # Match one tag + response = await ccn_api_client.get("/api/v0/posts.json", params={"tags": "amend"}) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 1 + assert response_json["pagination_total"] == 1 + + post = response_json["posts"][0] + assert post["item_hash"] == amend_post_db.item_hash + assert post["original_item_hash"] == original_post_db.item_hash + assert post["ref"] == original_post_db.ref + assert post["content"] == amend_post_db.content + + # Unknown tag + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": "not-a-tag"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 0 + assert response_json["pagination_total"] == 0 + + # Tag of the original + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": "original"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 0 + assert response_json["pagination_total"] == 0 + + # Search for several tags + response = await ccn_api_client.get( + "/api/v0/posts.json", params={"tags": "mainnet,not-a-tag"} + ) + assert response.status == 200 + response_json = await response.json() + assert len(response_json["posts"]) == 1 + assert response_json["pagination_total"] == 1 + + post = response_json["posts"][0] + assert post["item_hash"] == amend_post_db.item_hash + assert post["original_item_hash"] == original_post_db.item_hash + assert post["ref"] == original_post_db.ref + assert post["content"] == amend_post_db.content