Skip to content

Commit

Permalink
add notification rules support
Browse files Browse the repository at this point in the history
  • Loading branch information
mjurbanski-reef committed Apr 11, 2024
1 parent 71d01f2 commit aaf8569
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 37 deletions.
17 changes: 15 additions & 2 deletions b2sdk/_internal/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import pathlib
from contextlib import suppress
from typing import Sequence
from typing import Iterable, Sequence

from .encryption.setting import EncryptionSetting, EncryptionSettingFactory
from .encryption.types import EncryptionMode
Expand All @@ -37,7 +37,7 @@
from .filter import Filter, FilterMatcher
from .http_constants import LIST_FILE_NAMES_MAX_LIMIT
from .progress import AbstractProgressListener, DoNothingProgressListener
from .raw_api import LifecycleRule
from .raw_api import LifecycleRule, NotificationRule, NotificationRuleResponse
from .replication.setting import ReplicationConfiguration, ReplicationConfigurationFactory
from .transfer.emerge.executor import AUTO_CONTENT_TYPE
from .transfer.emerge.unbound_write_intent import UnboundWriteIntentGenerator
Expand Down Expand Up @@ -1492,6 +1492,19 @@ def as_dict(self):
def __repr__(self):
return f'Bucket<{self.id_},{self.name},{self.type_}>'

def get_notification_rules(self) -> list[NotificationRuleResponse]:
"""
Get all notification rules for this bucket.
"""
return self.api.session.get_bucket_notification_rules(self.id_)

def set_notification_rules(self,
rules: Iterable[NotificationRule]) -> list[NotificationRuleResponse]:
"""
Set notification rules for this bucket.
"""
return self.api.session.set_bucket_notification_rules(self.id_, rules)


class BucketFactory:
"""
Expand Down
55 changes: 55 additions & 0 deletions b2sdk/_internal/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import logging
import re
import typing
import warnings
from abc import ABCMeta
from typing import Any
Expand Down Expand Up @@ -574,6 +575,47 @@ class DestinationDirectoryDoesntAllowOperation(DestinationDirectoryError):
pass


class EventTypeError(BadRequest):
pass


class EventTypeCategoriesError(EventTypeError):
pass


class EventTypeOverlapError(EventTypeError):
pass


class EventTypesEmptyError(EventTypeError):
pass


class EventTypeInvalidError(EventTypeError):
pass


def _event_type_invalid_error(code: str, message: str, **_) -> B2Error:
from b2sdk._internal.raw_api import EVENT_TYPE

valid_types = sorted(typing.get_args(EVENT_TYPE))
return EventTypeInvalidError(
f"Event Type error: {message!r}. Valid types: {sorted(valid_types)!r}", code
)


_error_handlers: dict[tuple[int, str | None], typing.Callable] = {
(400, "event_type_categories"):
lambda code, message, **_: EventTypeCategoriesError(message, code),
(400, "event_type_overlap"):
lambda code, message, **_: EventTypeOverlapError(message, code),
(400, "event_types_empty"):
lambda code, message, **_: EventTypesEmptyError(message, code),
(400, "event_type_invalid"):
_event_type_invalid_error,
}


@trace_call(logger)
def interpret_b2_error(
status: int,
Expand All @@ -583,6 +625,19 @@ def interpret_b2_error(
post_params: dict[str, Any] | None = None
) -> B2Error:
post_params = post_params or {}

handler = _error_handlers.get((status, code))
if handler:
error = handler(
status=status,
code=code,
message=message,
response_headers=response_headers,
post_params=post_params
)
if error:
return error

if status == 400 and code == "already_hidden":
return FileAlreadyHidden(post_params.get('fileName'))
elif status == 400 and code == 'bad_json':
Expand Down
105 changes: 102 additions & 3 deletions b2sdk/_internal/raw_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
from abc import ABCMeta, abstractmethod
from enum import Enum, unique
from logging import getLogger
from typing import Any
from typing import Any, Iterable

from .utils.escape import unprintable_to_hex
from .utils.typing import JSON

try:
from typing_extensions import NotRequired, TypedDict
from typing_extensions import Literal, NotRequired, TypedDict
except ImportError:
from typing import NotRequired, TypedDict
from typing import Literal, NotRequired, TypedDict

from .encryption.setting import EncryptionMode, EncryptionSetting
from .exception import (
Expand Down Expand Up @@ -73,6 +73,8 @@
'shareFiles',
'writeFiles',
'deleteFiles',
'readBucketNotifications',
'writeBucketNotifications',
]

# API version number to use when calling the service
Expand Down Expand Up @@ -102,6 +104,67 @@ class LifecycleRule(TypedDict):
daysFromUploadingToHiding: NotRequired[int | None]


class NameValueDict(TypedDict):
name: str
value: str


class NotificationTargetConfiguration(TypedDict):
"""
Notification Target Configuration.
`hmacSha256SigningSecret`, if present, has to be a string of 32 alphanumeric characters.
"""
# TODO: add URL to the documentation

targetType: Literal['webhook']
url: str
customHeaders: NotRequired[list[NameValueDict] | None]
hmacSha256SigningSecret: NotRequired[str]


EVENT_TYPE = Literal[
'b2:ObjectCreated:*', 'b2:ObjectCreated:Upload', 'b2:ObjectCreated:MultipartUpload',
'b2:ObjectCreated:Copy', 'b2:ObjectCreated:Replica', 'b2:ObjectCreated:MultipartReplica',
'b2:ObjectDeleted:*', 'b2:ObjectDeleted:Delete', 'b2:ObjectDeleted:LifecycleRule',
'b2:HideMarkerCreated:*', 'b2:HideMarkerCreated:Hide', 'b2:HideMarkerCreated:LifecycleRule',]


class _NotificationRule(TypedDict):
"""
Notification Rule.
"""
eventTypes: list[EVENT_TYPE]
isEnabled: bool
name: str
objectNamePrefix: str
targetConfiguration: NotificationTargetConfiguration
suspensionReason: NotRequired[str]


class NotificationRule(_NotificationRule):
"""
Notification Rule.
When creating or modifying a notification rule, `isSuspended` and `suspensionReason` are ignored.
"""
isSuspended: NotRequired[bool]


class NotificationRuleResponse(_NotificationRule):
isSuspended: bool


def notification_rule_response_to_request(rule: NotificationRuleResponse) -> NotificationRule:
"""
Convert NotificationRuleResponse to NotificationRule.
"""
rule = rule.copy()
for key in ('isSuspended', 'suspensionReason'):
rule.pop(key, None)
return rule


class AbstractRawApi(metaclass=ABCMeta):
"""
Direct access to the B2 web apis.
Expand Down Expand Up @@ -415,6 +478,18 @@ def get_download_url_by_id(self, download_url, file_id):
def get_download_url_by_name(self, download_url, bucket_name, file_name):
return download_url + '/file/' + bucket_name + '/' + b2_url_encode(file_name)

@abstractmethod
def set_bucket_notification_rules(
self, api_url: str, account_auth_token: str, bucket_id: str,
rules: Iterable[NotificationRule]
) -> list[NotificationRuleResponse]:
pass

@abstractmethod
def get_bucket_notification_rules(self, api_url: str, account_auth_token: str,
bucket_id: str) -> list[NotificationRuleResponse]:
pass


class B2RawHTTPApi(AbstractRawApi):
"""
Expand Down Expand Up @@ -1088,6 +1163,30 @@ def copy_part(
except AccessDenied:
raise SSECKeyError()

def set_bucket_notification_rules(
self, api_url: str, account_auth_token: str, bucket_id: str, rules: list[NotificationRule]
) -> list[NotificationRuleResponse]:
return self._post_json(
api_url,
'b2_set_bucket_notification_rules',
account_auth_token,
**{
'bucketId': bucket_id,
'eventNotificationRules': rules,
},
)["eventNotificationRules"]

def get_bucket_notification_rules(self, api_url: str, account_auth_token: str,
bucket_id: str) -> list[NotificationRuleResponse]:
return self._get_json(
api_url,
'b2_get_bucket_notification_rules',
account_auth_token,
**{
'bucketId': bucket_id,
},
)["eventNotificationRules"]


def _add_range_header(headers, range_):
if range_ is not None:
Expand Down
71 changes: 70 additions & 1 deletion b2sdk/_internal/raw_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import logging
import random
import re
import secrets
import threading
import time
from contextlib import contextmanager, suppress
from typing import Iterable

from requests.structures import CaseInsensitiveDict

Expand All @@ -40,6 +42,7 @@
MissingPart,
NonExistentBucket,
PartSha1Mismatch,
ResourceNotFound,
SourceReplicationConflict,
SSECKeyError,
Unauthorized,
Expand All @@ -54,7 +57,14 @@
)
from .file_version import UNVERIFIED_CHECKSUM_PREFIX
from .http_constants import FILE_INFO_HEADER_PREFIX, HEX_DIGITS_AT_END
from .raw_api import ALL_CAPABILITIES, AbstractRawApi, LifecycleRule, MetadataDirectiveMode
from .raw_api import (
ALL_CAPABILITIES,
AbstractRawApi,
LifecycleRule,
MetadataDirectiveMode,
NotificationRule,
NotificationRuleResponse,
)
from .replication.setting import ReplicationConfiguration
from .replication.types import ReplicationStatus
from .stream.hashing import StreamWithHash
Expand Down Expand Up @@ -542,6 +552,7 @@ def __init__(
self.bucket_info = bucket_info or {}
self.cors_rules = cors_rules or []
self.lifecycle_rules = lifecycle_rules or []
self._notification_rules = []
self.options_set = options_set or set()
self.revision = 1
self.upload_url_counter = iter(range(200))
Expand Down Expand Up @@ -1160,6 +1171,44 @@ def _chunks_number(self, content_length, chunk_size):
def _next_file_id(self):
return str(next(self.file_id_counter))

def get_notification_rules(self) -> list[NotificationRule]:
return self._notification_rules

def set_notification_rules(self,
rules: Iterable[NotificationRule]) -> list[NotificationRuleResponse]:
old_rules_by_name = {rule["name"]: rule for rule in self._notification_rules}
new_rules: list[NotificationRuleResponse] = []
for rule in rules:
for field in ("isSuspended", "suspensionReason"):
rule.pop(field, None)
old_rule = old_rules_by_name.get(rule["name"], {"targetConfiguration": {}})
new_rule = {
**{
"isSuspended": False,
"suspensionReason": "",
},
**old_rule,
**rule,
"targetConfiguration":
{
**old_rule.get("targetConfiguration", {}),
**rule.get("targetConfiguration", {}),
},
}
new_rules.append(new_rule)
self._notification_rules = new_rules
return self._notification_rules

def simulate_notification_rule_suspension(
self, rule_name: str, reason: str, is_suspended: bool | None = None
) -> None:
for rule in self._notification_rules:
if rule["name"] == rule_name:
rule["isSuspended"] = bool(reason) if is_suspended is None else is_suspended
rule["suspensionReason"] = reason
return
raise ResourceNotFound(f"Rule {rule_name} not found")


class RawSimulator(AbstractRawApi):
"""
Expand Down Expand Up @@ -1235,6 +1284,8 @@ def expire_auth_token(self, auth_token):

def create_account(self):
"""
Simulate creating an account.
Return (accountId, masterApplicationKey) for a newly created account.
"""
# Pick the IDs for the account and the key
Expand Down Expand Up @@ -1973,3 +2024,21 @@ def _get_bucket_by_name(self, bucket_name):
if bucket_name not in self.bucket_name_to_bucket:
raise NonExistentBucket(bucket_name)
return self.bucket_name_to_bucket[bucket_name]

def set_bucket_notification_rules(
self, api_url: str, account_auth_token: str, bucket_id: str,
rules: Iterable[NotificationRule]
):
bucket = self._get_bucket_by_id(bucket_id)
self._assert_account_auth(
api_url, account_auth_token, bucket.account_id, 'writeBucketNotifications'
)
return bucket.set_notification_rules(rules)

def get_bucket_notification_rules(self, api_url: str, account_auth_token: str,
bucket_id: str) -> list[NotificationRule]:
bucket = self._get_bucket_by_id(bucket_id)
self._assert_account_auth(
api_url, account_auth_token, bucket.account_id, 'readBucketNotifications'
)
return bucket.get_notification_rules()
8 changes: 8 additions & 0 deletions b2sdk/_internal/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,3 +572,11 @@ def update_file_legal_hold(
file_name,
legal_hold,
)

def get_bucket_notification_rules(self, bucket_id):
return self._wrap_default_token(self.raw_api.get_bucket_notification_rules, bucket_id)

def set_bucket_notification_rules(self, bucket_id, rules):
return self._wrap_default_token(
self.raw_api.set_bucket_notification_rules, bucket_id, rules
)
Loading

0 comments on commit aaf8569

Please sign in to comment.