From 204138f3c3d9a35a4ee13e902c301151f2fe635b Mon Sep 17 00:00:00 2001 From: evgeny Date: Thu, 4 Sep 2025 13:30:04 +0100 Subject: [PATCH] chore: mock headers to avoid warnings in the test run --- ably/realtime/realtime_channel.py | 127 +++++++++++++++++++-- test/ably/realtime/realtimechannel_test.py | 61 +++++++++- 2 files changed, 180 insertions(+), 8 deletions(-) diff --git a/ably/realtime/realtime_channel.py b/ably/realtime/realtime_channel.py index 326c23a6..01ecbf04 100644 --- a/ably/realtime/realtime_channel.py +++ b/ably/realtime/realtime_channel.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio import logging -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Dict, Any from ably.realtime.connection import ConnectionState from ably.transport.websockettransport import ProtocolMessageAction from ably.rest.channel import Channel, Channels as RestChannels @@ -14,10 +14,75 @@ if TYPE_CHECKING: from ably.realtime.realtime import AblyRealtime + from ably.util.crypto import CipherParams log = logging.getLogger(__name__) +class ChannelOptions: + """Channel options for Ably Realtime channels + + Attributes + ---------- + cipher : CipherParams, optional + Requests encryption for this channel when not null, and specifies encryption-related parameters. + params : Dict[str, str], optional + Channel parameters that configure the behavior of the channel. + """ + + def __init__(self, cipher: Optional[CipherParams] = None, params: Optional[dict] = None): + self.__cipher = cipher + self.__params = params + # Validate params + if self.__params and not isinstance(self.__params, dict): + raise AblyException("params must be a dictionary", 40000, 400) + + @property + def cipher(self): + """Get cipher configuration""" + return self.__cipher + + @property + def params(self) -> Dict[str, str]: + """Get channel parameters""" + return self.__params + + def __eq__(self, other): + """Check equality with another ChannelOptions instance""" + if not isinstance(other, ChannelOptions): + return False + + return (self.__cipher == other.__cipher and + self.__params == other.__params) + + def __hash__(self): + """Make ChannelOptions hashable""" + return hash(( + self.__cipher, + tuple(sorted(self.__params.items())) if self.__params else None, + )) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation""" + result = {} + if self.__cipher is not None: + result['cipher'] = self.__cipher + if self.__params: + result['params'] = self.__params + return result + + @classmethod + def from_dict(cls, options_dict: Dict[str, Any]) -> 'ChannelOptions': + """Create ChannelOptions from dictionary""" + if not isinstance(options_dict, dict): + raise AblyException("options must be a dictionary", 40000, 400) + + return cls( + cipher=options_dict.get('cipher'), + params=options_dict.get('params'), + ) + + class RealtimeChannel(EventEmitter, Channel): """ Ably Realtime Channel @@ -43,7 +108,7 @@ class RealtimeChannel(EventEmitter, Channel): Unsubscribe to messages from a channel """ - def __init__(self, realtime: AblyRealtime, name: str): + def __init__(self, realtime: AblyRealtime, name: str, channel_options: Optional[ChannelOptions] = None): EventEmitter.__init__(self) self.__name = name self.__realtime = realtime @@ -51,15 +116,36 @@ def __init__(self, realtime: AblyRealtime, name: str): self.__message_emitter = EventEmitter() self.__state_timer: Optional[Timer] = None self.__attach_resume = False + self.__attach_serial: Optional[str] = None self.__channel_serial: Optional[str] = None self.__retry_timer: Optional[Timer] = None self.__error_reason: Optional[AblyException] = None + self.__channel_options = channel_options or ChannelOptions() + self.__params: Optional[Dict[str, str]] = None # Used to listen to state changes internally, if we use the public event emitter interface then internals # will be disrupted if the user called .off() to remove all listeners self.__internal_state_emitter = EventEmitter() - Channel.__init__(self, realtime, name, {}) + # Pass channel options as dictionary to parent Channel class + Channel.__init__(self, realtime, name, self.__channel_options.to_dict()) + + async def set_options(self, channel_options: ChannelOptions) -> None: + """Set channel options""" + should_reattach = self.should_reattach_to_set_options(channel_options) + self.set_options_without_reattach(channel_options) + + if should_reattach: + self._attach_impl() + state_change = await self.__internal_state_emitter.once_async() + if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED): + raise state_change.reason + + def set_options_without_reattach(self, channel_options: ChannelOptions) -> None: + """Internal method""" + self.__channel_options = channel_options + # Update parent class options + self.options = channel_options.to_dict() # RTL4 async def attach(self) -> None: @@ -108,6 +194,7 @@ def _attach_impl(self): # RTL4c attach_msg = { "action": ProtocolMessageAction.ATTACH, + "params": self.__channel_options.params, "channel": self.name, } @@ -292,8 +379,6 @@ def _on_message(self, proto_msg: dict) -> None: action = proto_msg.get('action') # RTL4c1 channel_serial = proto_msg.get('channelSerial') - if channel_serial: - self.__channel_serial = channel_serial # TM2a, TM2c, TM2f Message.update_inner_message_fields(proto_msg) @@ -303,6 +388,10 @@ def _on_message(self, proto_msg: dict) -> None: exception = None resumed = False + self.__attach_serial = channel_serial + self.__channel_serial = channel_serial + self.__params = proto_msg.get('params') + if error: exception = AblyException.from_dict(error) @@ -327,6 +416,7 @@ def _on_message(self, proto_msg: dict) -> None: self._request_state(ChannelState.ATTACHING) elif action == ProtocolMessageAction.MESSAGE: messages = Message.from_encoded_array(proto_msg.get('messages')) + self.__channel_serial = channel_serial for message in messages: self.__message_emitter._emit(message.name, message) elif action == ProtocolMessageAction.ERROR: @@ -431,6 +521,12 @@ def __on_retry_timer_expire(self) -> None: log.info("RealtimeChannel retry timer expired, attempting a new attach") self._request_state(ChannelState.ATTACHING) + def should_reattach_to_set_options(self, new_options: ChannelOptions) -> bool: + """Internal method""" + if self.state != ChannelState.ATTACHING and self.state != ChannelState.ATTACHED: + return False + return self.__channel_options != new_options + # RTL23 @property def name(self) -> str: @@ -453,6 +549,11 @@ def error_reason(self) -> Optional[AblyException]: """An AblyException instance describing the last error which occurred on the channel, if any.""" return self.__error_reason + @property + def params(self) -> Dict[str, str]: + """Get channel parameters""" + return self.__params + class Channels(RestChannels): """Creates and destroys RealtimeChannel objects. @@ -466,7 +567,7 @@ class Channels(RestChannels): """ # RTS3 - def get(self, name: str) -> RealtimeChannel: + def get(self, name: str, options: Optional[ChannelOptions] = None) -> RealtimeChannel: """Creates a new RealtimeChannel object, or returns the existing channel object. Parameters @@ -474,11 +575,23 @@ def get(self, name: str) -> RealtimeChannel: name: str Channel name + options: ChannelOptions or dict, optional + Channel options for the channel """ if name not in self.__all: - channel = self.__all[name] = RealtimeChannel(self.__ably, name) + channel = self.__all[name] = RealtimeChannel(self.__ably, name, options) else: channel = self.__all[name] + # Update options if channel is not attached or currently attaching + if options and channel.should_reattach_to_set_options(options): + raise AblyException( + 'Channels.get() cannot be used to set channel options that would cause the channel to ' + 'reattach. Please, use RealtimeChannel.setOptions() instead.', + 400, + 40000 + ) + elif options: + channel.set_options_without_reattach(options) return channel # RTS4 diff --git a/test/ably/realtime/realtimechannel_test.py b/test/ably/realtime/realtimechannel_test.py index 488f3059..a41c46b1 100644 --- a/test/ably/realtime/realtimechannel_test.py +++ b/test/ably/realtime/realtimechannel_test.py @@ -1,6 +1,6 @@ import asyncio import pytest -from ably.realtime.realtime_channel import ChannelState, RealtimeChannel +from ably.realtime.realtime_channel import ChannelState, RealtimeChannel, ChannelOptions from ably.transport.websockettransport import ProtocolMessageAction from ably.types.message import Message from test.ably.testapp import TestApp @@ -468,3 +468,62 @@ async def test_channel_error_cleared_upon_connect_from_terminal_state(self): assert channel.error_reason is None await ably.close() + + async def test_channel_params_received_by_relatime(self): + ably = await TestApp.get_ably_realtime() + channel_name = random_string(5) + channel = ably.channels.get(channel_name, ChannelOptions(params={ + "rewind": "1" + })) + await channel.attach() + assert channel.params["rewind"] == "1" + + await ably.close() + + async def test_channel_params_unknown_params_skipped_by_relatime(self): + ably = await TestApp.get_ably_realtime() + channel_name = random_string(5) + channel = ably.channels.get(channel_name, ChannelOptions(params={ + "rewind": "1", + "foo": "bar" + })) + await channel.attach() + assert channel.params["rewind"] == "1" + assert channel.params.get("foo") is None + + await ably.close() + + async def test_channel_params_as_dict(self): + ably = await TestApp.get_ably_realtime() + channel_name = random_string(5) + channel = ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"})) + await channel.attach() + assert channel.params["delta"] == "vcdiff" + + await ably.close() + + async def test_channel_get_channel_with_same_params(self): + ably = await TestApp.get_ably_realtime() + channel_name = random_string(5) + channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"})) + await channel.attach() + same_channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"})) + assert channel == same_channel + + await ably.close() + + async def test_channel_get_channel_with_different_params(self): + ably = await TestApp.get_ably_realtime() + channel_name = random_string(5) + channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"})) + await channel.attach() + + with pytest.raises(AblyException) as exception: + ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"})) + + assert exception.value.code == 40000 + assert exception.value.status_code == 400 + + assert channel.params == {"rewind": "1"} + + await ably.close()