Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 120 additions & 7 deletions ably/realtime/realtime_channel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
import logging
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Dict, Any
from ably.realtime.connection import ConnectionState
from ably.transport.websockettransport import ProtocolMessageAction
from ably.rest.channel import Channel, Channels as RestChannels
Expand All @@ -14,10 +14,75 @@

if TYPE_CHECKING:
from ably.realtime.realtime import AblyRealtime
from ably.util.crypto import CipherParams

log = logging.getLogger(__name__)


class ChannelOptions:
"""Channel options for Ably Realtime channels

Attributes
----------
cipher : CipherParams, optional
Requests encryption for this channel when not null, and specifies encryption-related parameters.
params : Dict[str, str], optional
Channel parameters that configure the behavior of the channel.
"""

def __init__(self, cipher: Optional[CipherParams] = None, params: Optional[dict] = None):
self.__cipher = cipher
self.__params = params
# Validate params
if self.__params and not isinstance(self.__params, dict):
raise AblyException("params must be a dictionary", 40000, 400)

@property
def cipher(self):
"""Get cipher configuration"""
return self.__cipher

@property
def params(self) -> Dict[str, str]:
"""Get channel parameters"""
return self.__params

def __eq__(self, other):
"""Check equality with another ChannelOptions instance"""
if not isinstance(other, ChannelOptions):
return False

return (self.__cipher == other.__cipher and
self.__params == other.__params)

def __hash__(self):
"""Make ChannelOptions hashable"""
return hash((
self.__cipher,
tuple(sorted(self.__params.items())) if self.__params else None,
))

def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary representation"""
result = {}
if self.__cipher is not None:
result['cipher'] = self.__cipher
if self.__params:
result['params'] = self.__params
return result

@classmethod
def from_dict(cls, options_dict: Dict[str, Any]) -> 'ChannelOptions':
"""Create ChannelOptions from dictionary"""
if not isinstance(options_dict, dict):
raise AblyException("options must be a dictionary", 40000, 400)

return cls(
cipher=options_dict.get('cipher'),
params=options_dict.get('params'),
)


class RealtimeChannel(EventEmitter, Channel):
"""
Ably Realtime Channel
Expand All @@ -43,23 +108,44 @@ class RealtimeChannel(EventEmitter, Channel):
Unsubscribe to messages from a channel
"""

def __init__(self, realtime: AblyRealtime, name: str):
def __init__(self, realtime: AblyRealtime, name: str, channel_options: Optional[ChannelOptions] = None):
EventEmitter.__init__(self)
self.__name = name
self.__realtime = realtime
self.__state = ChannelState.INITIALIZED
self.__message_emitter = EventEmitter()
self.__state_timer: Optional[Timer] = None
self.__attach_resume = False
self.__attach_serial: Optional[str] = None
self.__channel_serial: Optional[str] = None
self.__retry_timer: Optional[Timer] = None
self.__error_reason: Optional[AblyException] = None
self.__channel_options = channel_options or ChannelOptions()
self.__params: Optional[Dict[str, str]] = None

# Used to listen to state changes internally, if we use the public event emitter interface then internals
# will be disrupted if the user called .off() to remove all listeners
self.__internal_state_emitter = EventEmitter()

Channel.__init__(self, realtime, name, {})
# Pass channel options as dictionary to parent Channel class
Channel.__init__(self, realtime, name, self.__channel_options.to_dict())

async def set_options(self, channel_options: ChannelOptions) -> None:
"""Set channel options"""
should_reattach = self.should_reattach_to_set_options(channel_options)
self.set_options_without_reattach(channel_options)

if should_reattach:
self._attach_impl()
state_change = await self.__internal_state_emitter.once_async()
if state_change.current in (ChannelState.SUSPENDED, ChannelState.FAILED):
raise state_change.reason

def set_options_without_reattach(self, channel_options: ChannelOptions) -> None:
"""Internal method"""
self.__channel_options = channel_options
# Update parent class options
self.options = channel_options.to_dict()

# RTL4
async def attach(self) -> None:
Expand Down Expand Up @@ -108,6 +194,7 @@ def _attach_impl(self):
# RTL4c
attach_msg = {
"action": ProtocolMessageAction.ATTACH,
"params": self.__channel_options.params,
"channel": self.name,
}

Expand Down Expand Up @@ -292,8 +379,6 @@ def _on_message(self, proto_msg: dict) -> None:
action = proto_msg.get('action')
# RTL4c1
channel_serial = proto_msg.get('channelSerial')
if channel_serial:
self.__channel_serial = channel_serial
# TM2a, TM2c, TM2f
Message.update_inner_message_fields(proto_msg)

Expand All @@ -303,6 +388,10 @@ def _on_message(self, proto_msg: dict) -> None:
exception = None
resumed = False

self.__attach_serial = channel_serial
self.__channel_serial = channel_serial
self.__params = proto_msg.get('params')

if error:
exception = AblyException.from_dict(error)

Expand All @@ -327,6 +416,7 @@ def _on_message(self, proto_msg: dict) -> None:
self._request_state(ChannelState.ATTACHING)
elif action == ProtocolMessageAction.MESSAGE:
messages = Message.from_encoded_array(proto_msg.get('messages'))
self.__channel_serial = channel_serial
for message in messages:
self.__message_emitter._emit(message.name, message)
elif action == ProtocolMessageAction.ERROR:
Expand Down Expand Up @@ -431,6 +521,12 @@ def __on_retry_timer_expire(self) -> None:
log.info("RealtimeChannel retry timer expired, attempting a new attach")
self._request_state(ChannelState.ATTACHING)

def should_reattach_to_set_options(self, new_options: ChannelOptions) -> bool:
"""Internal method"""
if self.state != ChannelState.ATTACHING and self.state != ChannelState.ATTACHED:
return False
return self.__channel_options != new_options

# RTL23
@property
def name(self) -> str:
Expand All @@ -453,6 +549,11 @@ def error_reason(self) -> Optional[AblyException]:
"""An AblyException instance describing the last error which occurred on the channel, if any."""
return self.__error_reason

@property
def params(self) -> Dict[str, str]:
"""Get channel parameters"""
return self.__params


class Channels(RestChannels):
"""Creates and destroys RealtimeChannel objects.
Expand All @@ -466,19 +567,31 @@ class Channels(RestChannels):
"""

# RTS3
def get(self, name: str) -> RealtimeChannel:
def get(self, name: str, options: Optional[ChannelOptions] = None) -> RealtimeChannel:
"""Creates a new RealtimeChannel object, or returns the existing channel object.

Parameters
----------

name: str
Channel name
options: ChannelOptions or dict, optional
Channel options for the channel
"""
if name not in self.__all:
channel = self.__all[name] = RealtimeChannel(self.__ably, name)
channel = self.__all[name] = RealtimeChannel(self.__ably, name, options)
else:
channel = self.__all[name]
# Update options if channel is not attached or currently attaching
if options and channel.should_reattach_to_set_options(options):
raise AblyException(
'Channels.get() cannot be used to set channel options that would cause the channel to '
'reattach. Please, use RealtimeChannel.setOptions() instead.',
400,
40000
)
elif options:
channel.set_options_without_reattach(options)
return channel

# RTS4
Expand Down
61 changes: 60 additions & 1 deletion test/ably/realtime/realtimechannel_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import pytest
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel
from ably.realtime.realtime_channel import ChannelState, RealtimeChannel, ChannelOptions
from ably.transport.websockettransport import ProtocolMessageAction
from ably.types.message import Message
from test.ably.testapp import TestApp
Expand Down Expand Up @@ -468,3 +468,62 @@ async def test_channel_error_cleared_upon_connect_from_terminal_state(self):
assert channel.error_reason is None

await ably.close()

async def test_channel_params_received_by_relatime(self):
ably = await TestApp.get_ably_realtime()
channel_name = random_string(5)
channel = ably.channels.get(channel_name, ChannelOptions(params={
"rewind": "1"
}))
await channel.attach()
assert channel.params["rewind"] == "1"

await ably.close()

async def test_channel_params_unknown_params_skipped_by_relatime(self):
ably = await TestApp.get_ably_realtime()
channel_name = random_string(5)
channel = ably.channels.get(channel_name, ChannelOptions(params={
"rewind": "1",
"foo": "bar"
}))
await channel.attach()
assert channel.params["rewind"] == "1"
assert channel.params.get("foo") is None

await ably.close()

async def test_channel_params_as_dict(self):
ably = await TestApp.get_ably_realtime()
channel_name = random_string(5)
channel = ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))
await channel.attach()
assert channel.params["delta"] == "vcdiff"

await ably.close()

async def test_channel_get_channel_with_same_params(self):
ably = await TestApp.get_ably_realtime()
channel_name = random_string(5)
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
await channel.attach()
same_channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
assert channel == same_channel

await ably.close()

async def test_channel_get_channel_with_different_params(self):
ably = await TestApp.get_ably_realtime()
channel_name = random_string(5)
channel = ably.channels.get(channel_name, ChannelOptions(params={"rewind": "1"}))
await channel.attach()

with pytest.raises(AblyException) as exception:
ably.channels.get(channel_name, ChannelOptions(params={"delta": "vcdiff"}))

assert exception.value.code == 40000
assert exception.value.status_code == 400

assert channel.params == {"rewind": "1"}

await ably.close()