Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from aleph_message.models import ItemHash, Chain, MessageType
from sqlalchemy import func, select, update, text, delete, nullsfirst, nullslast
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import insert, array
from sqlalchemy.orm import selectinload, load_only
from sqlalchemy.sql import Insert, Select
from sqlalchemy.sql.elements import literal
Expand Down Expand Up @@ -105,7 +105,7 @@ def make_matching_messages_query(
)
if tags:
select_stmt = select_stmt.where(
MessageDb.content["content"]["tags"].contains(tags)
MessageDb.content["content"]["tags"].has_any(array(tags))
)
if channels:
select_stmt = select_stmt.where(MessageDb.channel.in_(channels))
Expand Down
4 changes: 2 additions & 2 deletions src/aleph/db/accessors/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
Float,
case,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import JSONB, array
from sqlalchemy.orm import aliased
from sqlalchemy.sql import Select

Expand Down Expand Up @@ -230,7 +230,7 @@ def filter_post_select_stmt(
select_stmt = select_stmt.where(literal_column("original_type").in_(post_types))
if tags:
select_stmt = select_stmt.where(
literal_column("content", type_=JSONB)["tags"].astext.in_(tags)
literal_column("content", type_=JSONB)["tags"].has_any(array(tags))
)
if channels:
select_stmt = select_stmt.where(literal_column("channel").in_(channels))
Expand Down
72 changes: 69 additions & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from pathlib import Path
from typing import Any, Dict, Sequence, cast
from typing import Any, Dict, Sequence, cast, Tuple

import pytest
import pytest_asyncio
from aleph_message.models import AggregateContent, PostContent
from sqlalchemy import insert
Expand All @@ -10,12 +11,13 @@
from aleph.db.models import (
MessageDb,
ChainTxDb,
AggregateElementDb, message_confirmations,
AggregateElementDb,
message_confirmations,
)
from aleph.db.models.posts import PostDb
from aleph.toolkit.timestamp import timestamp_to_datetime
from aleph.types.db_session import DbSessionFactory

import datetime as dt

# TODO: remove the raw parameter, it's just to avoid larger refactorings
async def _load_fixtures(
Expand Down Expand Up @@ -122,3 +124,67 @@ async def fixture_posts(
session.commit()

return posts


@pytest.fixture
def post_with_refs_and_tags() -> Tuple[MessageDb, PostDb]:
message = MessageDb(
item_hash="1234",
sender="0xdeadbeef",
type="POST",
chain="ETH",
signature=None,
item_type="storage",
item_content=None,
content={"content": {"tags": ["original", "mainnet"], "swap": "this"}},
time=dt.datetime(2023, 5, 1, tzinfo=dt.timezone.utc),
channel=None,
size=254,
)

post = PostDb(
item_hash=message.item_hash,
owner=message.sender,
type=None,
ref="custom-ref",
amends=None,
channel=None,
content=message.content["content"],
creation_datetime=message.time,
latest_amend=None,
)

return message, post


@pytest.fixture
def amended_post_with_refs_and_tags(post_with_refs_and_tags: Tuple[MessageDb, PostDb]):
original_message, original_post = post_with_refs_and_tags

amend_message = MessageDb(
item_hash="5678",
sender="0xdeadbeef",
type="POST",
chain="ETH",
signature=None,
item_type="storage",
item_content=None,
content={"content": {"tags": ["amend", "mainnet"], "don't": "swap"}},
time=dt.datetime(2023, 5, 2, tzinfo=dt.timezone.utc),
channel=None,
size=277,
)

amend_post = PostDb(
item_hash=amend_message.item_hash,
owner=original_message.sender,
type="amend",
ref=original_message.item_hash,
amends=original_message.item_hash,
channel=None,
content=amend_message.content["content"],
creation_datetime=amend_message.time,
latest_amend=None,
)

return amend_message, amend_post
70 changes: 64 additions & 6 deletions tests/api/test_list_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import aiohttp
import pytest

from aleph.db.models import MessageDb, PostDb
from aleph.types.db_session import DbSessionFactory
from .utils import get_messages_by_keys

MESSAGES_URI = "/api/v0/messages.json"
Expand Down Expand Up @@ -149,18 +151,74 @@ async def test_get_messages_multiple_hashes(fixture_messages, ccn_api_client):


@pytest.mark.asyncio
async def test_get_messages_filter_by_tags(fixture_messages, ccn_api_client):
async def test_get_messages_filter_by_tags(
fixture_messages,
ccn_api_client,
session_factory: DbSessionFactory,
post_with_refs_and_tags: Tuple[MessageDb, PostDb],
amended_post_with_refs_and_tags: Tuple[MessageDb, PostDb]
):
"""
Tests getting messages by tags.
There's no example in the fixtures, we just test that the endpoint returns a 200.
# TODO: add a POST message fixture with tags.
"""

tags = ["mainnet"]
message_db, _ = post_with_refs_and_tags
amend_message_db, _ = amended_post_with_refs_and_tags

response = await ccn_api_client.get(
MESSAGES_URI, params={"tags": ",".join(tags)}
)
with session_factory() as session:
session.add_all([message_db, amend_message_db])
session.commit()

# Matching tag for both messages
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 2

# Matching tags for both messages
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "original,amend"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 2

# Matching the original tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "original"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == message_db.item_hash

# Matching the amend tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "amend"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == amend_message_db.item_hash

# No match
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "not-a-tag"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 0

# Matching the amend tag with other tags
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "amend,not-a-tag,not-a-tag-either"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == amend_message_db.item_hash


@pytest.mark.asyncio
async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client):
"""
Tests getting messages by tags.
There's no example in the fixtures, we just test that the endpoint returns a 200.
"""

# Matching tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 0
Expand Down
Loading