Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 116 additions & 35 deletions src/aleph/handlers/forget.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,50 @@
from __future__ import annotations

import logging
from typing import Dict
from dataclasses import dataclass
from typing import Dict, Optional, List

from aioipfs.api import RepoAPI
from aioipfs.exceptions import NotPinnedError
from aleph_message.models import ForgetMessage, ItemType, MessageType, StoreContent
from aleph_message.models import ForgetMessage, MessageType
from aleph_message.models import ItemType

from aleph.model.filepin import PermanentPin
from aleph.model.hashes import delete_value
from aleph.model.messages import Message
from aleph.services.ipfs.common import get_ipfs_api
from aleph.storage import get_message_content
from aleph.utils import item_type_from_hash


@dataclass
class TargetMessageInfo:
item_hash: str
sender: str
type: MessageType
forgotten_by: List[str]
content_address: Optional[str]
content_item_hash: Optional[str]
content_item_type: Optional[ItemType]

@classmethod
def from_db_object(cls, message_dict: Dict) -> TargetMessageInfo:
content = message_dict.get("content", {})
content_item_type = content.get("item_type")

if content_item_type is not None:
content_item_type = ItemType(content_item_type)

return cls(
item_hash=message_dict["item_hash"],
sender=message_dict["sender"],
type=MessageType(message_dict["type"]),
forgotten_by=message_dict.get("forgotten_by", []),
content_address=content.get("address"),
content_item_hash=content.get("item_hash"),
content_item_type=content_item_type,
)


logger = logging.getLogger(__name__)


Expand All @@ -24,10 +57,12 @@ async def count_file_references(storage_hash: str) -> int:


async def file_references_exist(storage_hash: str) -> bool:
"""Check if references to a file on Aleph exist.
"""
return bool(await Message.collection.count_documents(
filter={"content.item_hash": storage_hash}, limit=1))
"""Check if references to a file on Aleph exist."""
return bool(
await Message.collection.count_documents(
filter={"content.item_hash": storage_hash}, limit=1
)
)


async def garbage_collect(storage_hash: str, storage_type: ItemType):
Expand All @@ -37,7 +72,12 @@ async def garbage_collect(storage_hash: str, storage_type: ItemType):
"""
logger.debug(f"Garbage collecting {storage_hash}")

if await PermanentPin.collection.count_documents(filter={"multihash": storage_hash}, limit=1) > 0:
if (
await PermanentPin.collection.count_documents(
filter={"multihash": storage_hash}, limit=1
)
> 0
):
logger.debug(f"Permanent pin will not be collected {storage_hash}")
return

Expand All @@ -49,8 +89,10 @@ async def garbage_collect(storage_hash: str, storage_type: ItemType):
storage_detected: ItemType = item_type_from_hash(storage_hash)

if storage_type != storage_detected:
raise ValueError(f"Inconsistent ItemType {storage_type} != {storage_detected} "
f"for hash '{storage_hash}'")
raise ValueError(
f"Inconsistent ItemType {storage_type} != {storage_detected} "
f"for hash '{storage_hash}'"
)

if storage_type == ItemType.ipfs:
api = await get_ipfs_api(timeout=5)
Expand All @@ -75,42 +117,49 @@ async def garbage_collect(storage_hash: str, storage_type: ItemType):
logger.debug(f"Removed from {storage_type}: {storage_hash}")


async def is_allowed_to_forget(target: Dict, by: ForgetMessage) -> bool:
"""Check if a forget message is allowed to 'forget' the target message given its hash.
"""
async def is_allowed_to_forget(
target_info: TargetMessageInfo, by: ForgetMessage
) -> bool:
"""Check if a forget message is allowed to 'forget' the target message given its hash."""
# Both senders are identical:
if by.sender == target.get("sender"):
if by.sender == target_info.sender:
return True
else:
# Content already forgotten, probably by someone else
if target_info.content_address is None:
return False

# The forget sender matches the content address:
target_content = await get_message_content(target)
if by.sender == target_content.value["address"]:
if by.sender == target_info.content_address:
return True
return False


async def forget_if_allowed(target_message: Dict, forget_message: ForgetMessage) -> None:
async def forget_if_allowed(
target_info: TargetMessageInfo, forget_message: ForgetMessage
) -> None:
"""Forget a message.

Remove the ‘content’ and ‘item_content’ sections of the targeted messages.
Add a field ‘removed_by’ that references to the processed FORGET message.
"""
target_hash = target_message["item_hash"]

if not target_message:
logger.info(f"Message to forget could not be found with id {target_hash}")
return
target_hash = target_info.item_hash

if target_message.get("type") == MessageType.forget:
logger.info(f"FORGET message may not be forgotten {target_hash} by {forget_message.item_hash}")
if target_info.type == MessageType.forget:
logger.info(
f"FORGET message may not be forgotten {target_hash} by {forget_message.item_hash}"
)
return

if not await is_allowed_to_forget(target_message, by=forget_message):
logger.info(f"Not allowed to forget {target_hash} by {forget_message.item_hash}")
# TODO: support forgetting the same message several times (if useful)
if target_info.forgotten_by:
logger.debug(f"Message content already forgotten: {target_hash}")
return

if target_message.get("content") is None:
logger.debug(f"Message content already forgotten: {target_message}")
if not await is_allowed_to_forget(target_info, by=forget_message):
logger.info(
f"Not allowed to forget {target_hash} by {forget_message.item_hash}"
)
return

logger.debug(f"Removing content for {target_hash}")
Expand All @@ -119,15 +168,46 @@ async def forget_if_allowed(target_message: Dict, forget_message: ForgetMessage)
"item_content": None,
"forgotten_by": [forget_message.item_hash],
}
await Message.collection.update_many(filter={"item_hash": target_hash}, update={"$set": updates})
await Message.collection.update_many(
filter={"item_hash": target_hash}, update={"$set": updates}
)

# TODO QUESTION: Should the removal be added to the CappedMessage collection for websocket
# updates ? Forget messages should already be published there, but the logic to validate
# them could be centralized here.

if target_message.get("type") == MessageType.store:
store_content = StoreContent(**target_message["content"])
await garbage_collect(store_content.item_hash, store_content.item_type)
if target_info.type == MessageType.store:
if (
target_info.content_item_type is None
or target_info.content_item_hash is None
):
raise ValueError(
f"Could not garbage collect content linked to STORE message {target_hash}."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this message more explicit ? Something like:

Suggested change
f"Could not garbage collect content linked to STORE message {target_hash}."
f"Information missing, could not garbage collect content linked to STORE message {target_hash}."

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why put the target_hash in the string and not as another argument ?

Suggested change
f"Could not garbage collect content linked to STORE message {target_hash}."
f"Could not garbage collect content linked to STORE message", target_hash

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually never passed several values to ValueError, I see that it works (it makes the args property a tuple), is there an upside in doing this besides avoiding the string formatting?

)

await garbage_collect(
target_info.content_item_hash, target_info.content_item_type
)


async def get_target_message_info(target_hash: str) -> Optional[TargetMessageInfo]:
message_dict = await Message.collection.find_one(
filter={"item_hash": target_hash},
projection={
"_id": 0,
"item_hash": 1,
"sender": 1,
"type": 1,
"forgotten_by": 1,
"content.address": 1,
"content.item_hash": 1,
"content.item_type": 1,
},
)
if message_dict is None:
return None

return TargetMessageInfo.from_db_object(message_dict)


async def handle_forget_message(message: Dict, content: Dict):
Expand All @@ -136,11 +216,12 @@ async def handle_forget_message(message: Dict, content: Dict):
logger.debug(f"Handling forget message {forget_message.item_hash}")

for target_hash in forget_message.content.hashes:
target_message = await Message.collection.find_one(filter={"item_hash": target_hash})
if target_message is None:
target_info = await get_target_message_info(target_hash)

if target_info is None:
logger.info(f"Message to forget could not be found with id {target_hash}")
continue

await forget_if_allowed(target_message=target_message, forget_message=forget_message)
await forget_if_allowed(target_info=target_info, forget_message=forget_message)

return True
11 changes: 7 additions & 4 deletions tests/storage/forget/test_forget_message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from aleph_message.models import ForgetMessage
from aleph.handlers.forget import forget_if_allowed
from aleph.handlers.forget import forget_if_allowed, TargetMessageInfo


@pytest.mark.asyncio
Expand Down Expand Up @@ -46,7 +46,8 @@ async def test_forget_inline_message(mocker):
message_mock = mocker.patch("aleph.handlers.forget.Message")
message_mock.collection.update_many = mocker.AsyncMock()

await forget_if_allowed(target_message, forget_message)
target_info = TargetMessageInfo.from_db_object(target_message)
await forget_if_allowed(target_info, forget_message)

message_mock.collection.update_many.assert_called_once_with(
filter={"item_hash": target_message["item_hash"]},
Expand Down Expand Up @@ -106,7 +107,8 @@ async def test_forget_store_message(mocker):
message_mock = mocker.patch("aleph.handlers.forget.Message")
message_mock.collection.update_many = mocker.AsyncMock()

await forget_if_allowed(target_message, forget_message)
target_info = TargetMessageInfo.from_db_object(target_message)
await forget_if_allowed(target_info, forget_message)

message_mock.collection.update_many.assert_called_once_with(
filter={"item_hash": target_message["item_hash"]},
Expand Down Expand Up @@ -171,7 +173,8 @@ async def test_forget_forget_message(mocker):
message_mock = mocker.patch("aleph.handlers.forget.Message")
message_mock.collection.update_many = mocker.AsyncMock()

await forget_if_allowed(target_message, forget_message)
target_info = TargetMessageInfo.from_db_object(target_message)
await forget_if_allowed(target_info, forget_message)

assert not message_mock.collection.update_many.called
assert not garbage_collect_mock.called