diff --git a/sdk/storage/azure-storage-blob-changefeed/azure/__init__.py b/sdk/storage/azure-storage-blob-changefeed/azure/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-blob-changefeed/azure/__init__.py +++ b/sdk/storage/azure-storage-blob-changefeed/azure/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-blob-changefeed/azure/storage/__init__.py b/sdk/storage/azure-storage-blob-changefeed/azure/storage/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-blob-changefeed/azure/storage/__init__.py +++ b/sdk/storage/azure-storage-blob-changefeed/azure/storage/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/__init__.py b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/__init__.py index 59cb70146572..0d1f7edf5dc6 100644 --- a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/__init__.py +++ b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: str +__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore diff --git a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_change_feed_client.py b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_change_feed_client.py index 31db5231dafb..395877c6112f 100644 --- a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_change_feed_client.py +++ b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_change_feed_client.py @@ -9,10 +9,11 @@ Any, Dict, Optional, Union, TYPE_CHECKING ) +from typing_extensions import Self from azure.core.paging import ItemPaged from azure.core.tracing.decorator import distributed_trace -from azure.storage.blob import BlobServiceClient # pylint: disable=no-name-in-module +from azure.storage.blob._blob_service_client import BlobServiceClient from azure.storage.blob._shared.base_client import parse_connection_str from ._models import ChangeFeedPaged @@ -20,7 +21,7 @@ from azure.core.credentials import AzureNamedKeyCredential, AzureSasCredential, TokenCredential -class ChangeFeedClient(object): +class ChangeFeedClient: """A client to interact with a specific account change feed. :param str account_url: @@ -60,18 +61,18 @@ class ChangeFeedClient(object): :caption: Creating the ChangeFeedClient from a URL to a public blob (no auth needed). 
""" def __init__( - self, account_url: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> None: + self, account_url: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> None: self._blob_service_client = BlobServiceClient(account_url, credential, **kwargs) @classmethod def from_connection_string( - cls, conn_str: str, - credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long - **kwargs: Any - ) -> "ChangeFeedClient": + cls, conn_str: str, + credential: Optional[Union[str, Dict[str, str], "AzureNamedKeyCredential", "AzureSasCredential", "TokenCredential"]] = None, # pylint: disable=line-too-long + **kwargs: Any + ) -> Self: """Create ChangeFeedClient from a Connection String. :param str conn_str: @@ -99,7 +100,7 @@ def from_connection_string( return cls(account_url, credential=credential, **kwargs) @distributed_trace - def list_changes(self, **kwargs: Any) -> ItemPaged[Dict]: + def list_changes(self, **kwargs: Any) -> ItemPaged[Dict[str, Any]]: """Returns a generator to list the change feed events. The generator will lazily follow the continuation tokens returned by the service. @@ -135,4 +136,5 @@ def list_changes(self, **kwargs: Any) -> ItemPaged[Dict]: container_client, results_per_page=results_per_page, page_iterator_class=ChangeFeedPaged, - **kwargs) + **kwargs + ) diff --git a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_models.py b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_models.py index e0a69bb923b2..172783cb4e87 100644 --- a/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_models.py +++ b/sdk/storage/azure-storage-blob-changefeed/azure/storage/blob/changefeed/_models.py @@ -4,52 +4,137 @@ # license information. 
 # --------------------------------------------------------------------------
 # pylint: disable=too-few-public-methods
+
 import collections
 import copy
 import json
 from datetime import datetime
-
-from azure.storage.blob._shared.avro.datafile import DataFileReader
-from azure.storage.blob._shared.avro.avro_io import DatumReader
+from typing import (
+    Any, Dict, Iterator, List, Optional, Union,
+    TYPE_CHECKING
+)
+from typing_extensions import Self
 
 from azure.core.exceptions import HttpResponseError
 from azure.core.paging import PageIterator
+from azure.storage.blob._shared.avro.avro_io import DatumReader
+from azure.storage.blob._shared.avro.datafile import DataFileReader
+
+if TYPE_CHECKING:
+    from collections import deque
+    from azure.storage.blob._blob_client import BlobClient
+    from azure.storage.blob._container_client import ContainerClient
+    from azure.storage.blob.aio._container_client_async import ContainerClient as AsyncContainerClient
+
 
 # ===============================================================================================
 SEGMENT_COMMON_PATH = "idx/segments/"
 PATH_DELIMITER = "/"
 # ===============================================================================================
 
 
-class ChangeFeedPaged(PageIterator):
-    """An Iterable of change feed events
+class Segment:
+    def __init__(
+        self, client: Union["ContainerClient", "AsyncContainerClient"],
+        segment_path: str,
+        page_size: int,
+        segment_cursor: Optional[Dict[str, Any]] = None
+    ) -> None:
+        self.client = client
+        self.segment_path = segment_path
+        self.page_size = page_size
+        self.shards: "deque" = collections.deque()
+        self.cursor: Dict[str, Any] = {"ShardCursors": {}, "SegmentPath": self.segment_path}
+        self._initialize(segment_cursor=segment_cursor)
+        # cursor is in this format:
+        # {"SegmentPath": path, "CurrentShardPath": shard_path, "ShardCursors": {shard_path: shard_cursor}}
+
+    def __iter__(self) -> Self:
+        return self
+
+    def __next__(self) -> List[Dict[str, Any]]:
+        segment_events: List[Dict[str, Any]] = []
+        while len(segment_events) < self.page_size and self.shards:
+            shard = self.shards.popleft()
+            try:
+                event = next(shard)
+                segment_events.append(event)
+                self.shards.append(shard)
+            except StopIteration:
+                pass
+
+            # update cursor
+            self.cursor["ShardCursors"][shard.shard_path] = shard.cursor
+            self.cursor["CurrentShardPath"] = shard.shard_path
+
+        if not segment_events:
+            raise StopIteration
+
+        return segment_events
+
+    next = __next__  # Python 2 compatibility.
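# A self-contained sketch of the round-robin drain that Segment.__next__
# implements above: pop a shard from the left of the deque, take one event,
# and re-append the shard unless it is exhausted. Toy iterables stand in for
# Shard objects; all names here are illustrative, not part of this package.
import collections

def round_robin_page(iterables, page_size):
    shards = collections.deque(iter(it) for it in iterables)
    page = []
    while len(page) < page_size and shards:
        shard = shards.popleft()
        try:
            page.append(next(shard))
            shards.append(shard)  # shard still has events; rotate it to the back
        except StopIteration:
            pass  # exhausted shards drop out of the rotation
    return page

print(round_robin_page([[1, 2], [3], [4, 5]], page_size=4))  # [1, 3, 4, 2]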
+
+    def _initialize(self, segment_cursor=None):
+        segment_content = self.client.get_blob_client(self.segment_path).download_blob().readall()
+        segment_content = segment_content.decode()
+        segment_dict = json.loads(segment_content)
+
+        raw_shard_paths = segment_dict["chunkFilePaths"]
+        shard_paths = []
+        # strip the "$blobchangefeed/" container prefix from each raw shard path
+        for raw_shard_path in raw_shard_paths:
+            shard_paths.append(raw_shard_path.replace("$blobchangefeed/", "", 1))
+
+        # TODO: we can optimize to initiate shards in parallel
+        if not segment_cursor:
+            for shard_path in shard_paths:
+                self.shards.append(Shard(self.client, shard_path))
+        else:
+            start_shard_path = segment_cursor["CurrentShardPath"]
+            shard_cursors = {shard_cursor["CurrentChunkPath"][:-10]: shard_cursor
+                             for shard_cursor in segment_cursor["ShardCursors"]}
+
+            if shard_paths:
+                # Initialize all shards using the shard cursors
+                for shard_path in shard_paths:
+                    self.shards.append(Shard(self.client, shard_path, shard_cursors.get(shard_path)))
+
+                # Move the shard after start_shard_path to the leftmost position; the leftmost
+                # shard is the next shard to read based on the continuation token.
+                while self.shards[0].shard_path != start_shard_path:
+                    self.shards.append(self.shards.popleft())
+                self.shards.append(self.shards.popleft())
+
+
-    :ivar int results_per_page:
-        The maximum number of results retrieved per API call.
-    :ivar str continuation_token:
-        The continuation token to retrieve the next page of results.
-    :ivar current_page:
-        The current page of listed results.
-    :vartype current_page: list(dict)
-    :param ~azure.storage.blob.ContainerClient or ~azure.storage.blob.aio.ContainerClient:
+class ChangeFeedPaged(PageIterator):
+    """An Iterable of change feed events.
+
+    :param ~azure.storage.blob.ContainerClient or ~azure.storage.blob.aio.ContainerClient container_client:
         the client to get change feed events.
-    :param int results_per_page:
-        The maximum number of blobs to retrieve per
-        call.
+    :param Optional[int] results_per_page:
+        The maximum number of blobs to retrieve per call.
-    :param datetime start_time:
+    :param Optional[datetime] start_time:
         Filters the results to return only events which happened after this time.
-    :param datetime end_time:
+    :param Optional[datetime] end_time:
         Filters the results to return only events which happened before this time.
-    :param str continuation_token:
-        An continuation token with which to start listing events from the previous position.
+    :param Optional[str] continuation_token:
+        A continuation token with which to start listing events from the previous position.
""" + + results_per_page: int + """The maximum number of results retrieved per API call.""" + continuation_token: str + """The continuation token to retrieve the next page of results.""" + current_page: Optional[List[Dict[str, Any]]] + """The current page of listed results.""" + def __init__( - self, container_client, - results_per_page=None, - start_time=None, - end_time=None, - continuation_token=None): + self, container_client: Union["ContainerClient", "AsyncContainerClient"], + results_per_page: Optional[int] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + continuation_token: Optional[str] = None + ) -> None: if (start_time or end_time) and continuation_token: raise ValueError("start_time/end_time and continuation_token shouldn't be specified at the same time") super(ChangeFeedPaged, self).__init__( @@ -57,7 +142,7 @@ def __init__( extract_data=self._extract_data_cb, continuation_token=continuation_token or "" ) - dict_continuation_token = json.loads(continuation_token) if continuation_token else None # type: dict + dict_continuation_token = json.loads(continuation_token) if continuation_token else None if dict_continuation_token and (container_client.primary_hostname != dict_continuation_token["UrlHost"]): # pylint: disable=unsubscriptable-object raise ValueError("The token is not for the current storage account.") @@ -87,8 +172,14 @@ def _extract_data_cb(self, event_list): return json.dumps(cursor), self.current_page -class ChangeFeed(object): - def __init__(self, client, page_size, start_time=None, end_time=None, cf_cursor=None): +class ChangeFeed: + def __init__( + self, client: Union["ContainerClient", "AsyncContainerClient"], + page_size: int, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + cf_cursor: Optional[Dict[str, Any]] = None + ) -> None: self.client = client self.page_size = page_size self._segment_paths_generator = None @@ -110,11 +201,11 @@ def __init__(self, client, page_size, start_time=None, end_time=None, cf_cursor= "UrlHost": self.client.primary_hostname} self._initialize(cur_segment_cursor=cur_segment_cursor) - def __iter__(self): + def __iter__(self) -> Self: return self - def __next__(self): - change_feed = [] + def __next__(self) -> List[Dict[str, Any]]: + change_feed: List[Dict[str, Any]] = [] remaining_to_load = self.page_size if not self.current_segment: @@ -218,87 +309,23 @@ def _is_later_than_end_time(self, segment_path): return segment_date > opaque_end_date -class Segment(object): - def __init__(self, client, segment_path, page_size, segment_cursor=None): - self.client = client - self.segment_path = segment_path - self.page_size = page_size - self.shards = collections.deque() - self.cursor = {"ShardCursors": {}, "SegmentPath": self.segment_path} - self._initialize(segment_cursor=segment_cursor) - # cursor is in this format: - # {"segment_path", path, "CurrentShardPath": shard_path, "segment_cursor": ShardCursors dict} - - def __iter__(self): - return self - - def __next__(self): - segment_events = [] - while len(segment_events) < self.page_size and self.shards: - shard = self.shards.popleft() - try: - event = next(shard) - segment_events.append(event) - self.shards.append(shard) - except StopIteration: - pass - - # update cursor - self.cursor["ShardCursors"][shard.shard_path] = shard.cursor - self.cursor["CurrentShardPath"] = shard.shard_path - - if not segment_events: - raise StopIteration - - return segment_events - - next = __next__ # Python 2 compatibility. 
- - def _initialize(self, segment_cursor=None): - segment_content = self.client.get_blob_client(self.segment_path).download_blob().readall() - segment_content = segment_content.decode() - segment_dict = json.loads(segment_content) - - raw_shard_paths = segment_dict["chunkFilePaths"] - shard_paths = [] - # to strip the overhead of all raw shard paths - for raw_shard_path in raw_shard_paths: - shard_paths.append(raw_shard_path.replace("$blobchangefeed/", "", 1)) - - # TODO: we can optimize to initiate shards in parallel - if not segment_cursor: - for shard_path in shard_paths: - self.shards.append(Shard(self.client, shard_path)) - else: - start_shard_path = segment_cursor["CurrentShardPath"] - shard_cursors = {shard_cursor["CurrentChunkPath"][:-10]: shard_cursor - for shard_cursor in segment_cursor["ShardCursors"]} - - if shard_paths: - # Initialize all shards using the shard cursors - for shard_path in shard_paths: - self.shards.append(Shard(self.client, shard_path, shard_cursors.get(shard_path))) - - # the move the shard behind start_shard_path one to the left most place, the left most shard is the next - # shard we should read based on continuation token. - while self.shards[0].shard_path != start_shard_path: - self.shards.append(self.shards.popleft()) - self.shards.append(self.shards.popleft()) - - -class Shard(object): - def __init__(self, client, shard_path, shard_cursor=None): +class Shard: + def __init__( + self, client: Union["ContainerClient", "AsyncContainerClient"], + shard_path: str, + shard_cursor: Optional[Dict[str, Any]] = None + ) -> None: self.client = client self.shard_path = shard_path self.current_chunk = None - self.unprocessed_chunk_path_props = [] + self.unprocessed_chunk_path_props: "deque" = collections.deque() self.cursor = None # to track the chunk info we are reading self._initialize(shard_cursor=shard_cursor) - def __iter__(self): + def __iter__(self) -> Self: return self - def __next__(self): + def __next__(self) -> Dict[str, Any]: next_event = None while not next_event and self.current_chunk: temp_chunk = self.current_chunk @@ -335,54 +362,13 @@ def _get_next_chunk(self, chunk_cursor=None): return None -class Chunk(object): - def __init__(self, client, chunk_path, chunk_cursor=None): - self.client = client - self.chunk_path = chunk_path - self.file_reader = None - self.cursor = {"CurrentChunkPath": chunk_path} # to track the current position in avro file - self._data_stream = None - self._initialize(chunk_cursor=chunk_cursor) - - def __iter__(self): - return self - - def __next__(self): - try: - event = next(self.file_reader) - self.cursor["EventIndex"] = self._data_stream.event_index - self.cursor["BlockOffset"] = self._data_stream.object_position - return event - except StopIteration as exc: - self.cursor["EventIndex"] = self._data_stream.event_index - self.cursor["BlockOffset"] = self._data_stream.object_position - raise StopIteration from exc - - next = __next__ # Python 2 compatibility. 
- - def _initialize(self, chunk_cursor=None): - # To get all events in a chunk - blob_client = self.client.get_blob_client(self.chunk_path) - - file_offset = chunk_cursor.get("BlockOffset") if chunk_cursor else 0 - - # An offset means the avro data doesn't have avro header, - # so only when the data stream has a offset we need header stream to help - header_stream = ChangeFeedStreamer(blob_client) if file_offset else None - self._data_stream = ChangeFeedStreamer(blob_client, chunk_file_start=file_offset) - self.file_reader = DataFileReader(self._data_stream, DatumReader(), header_reader=header_stream) - - event_index = chunk_cursor.get("EventIndex") if chunk_cursor else 0 - for _ in range(0, event_index): - next(self.file_reader) - - -class ChangeFeedStreamer(object): - """ - File-like streaming iterator. - """ +class ChangeFeedStreamer: + """File-like streaming iterator.""" - def __init__(self, blob_client, chunk_file_start=0): + def __init__( + self, blob_client: "BlobClient", + chunk_file_start: int = 0 + ) -> None: self._chunk_file_start = chunk_file_start or 0 # this value will never be updated self._download_offset = self._chunk_file_start # range start of the next download self.object_position = self._chunk_file_start # track the most recently read sync marker position @@ -395,27 +381,27 @@ def __init__(self, blob_client, chunk_file_start=0): self._iterator = blob_client.download_blob(offset=self._chunk_file_start, length=length).chunks() if length > 0 else iter([]) - def __len__(self): + def __len__(self) -> int: return self._download_offset - def __iter__(self): + def __iter__(self) -> Iterator[bytes]: return self._iterator @staticmethod - def seekable(): + def seekable() -> bool: return True - def __next__(self): + def __next__(self) -> bytes: next_chunk = next(self._iterator) self._download_offset += len(next_chunk) return next_chunk next = __next__ # Python 2 compatibility. 
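# DataFileReader only needs a seekable, file-like object, which is exactly the
# surface ChangeFeedStreamer exposes and this hunk annotates. A toy in-memory
# stand-in showing the same read/seek/tell/seekable contract; this class is
# illustrative only (stdlib, not part of this package):
class _BytesStream:
    def __init__(self, data: bytes, start: int = 0) -> None:
        self._data = data
        self._point = start  # current read position, like ChangeFeedStreamer._point

    @staticmethod
    def seekable() -> bool:
        return True

    def tell(self) -> int:
        return self._point

    def seek(self, offset: int, whence: int = 0) -> None:
        # whence=0: absolute position; whence=1: relative to the current position
        self._point = offset if whence == 0 else self._point + offset

    def read(self, size: int) -> bytes:
        data = self._data[self._point:self._point + size]
        self._point += len(data)
        return data

stream = _BytesStream(b"Obj\x01...avro blocks...")
assert stream.read(4) == b"Obj\x01"  # the Avro magic bytes a header parser checks first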
- def tell(self): + def tell(self) -> int: return self._point - def seek(self, offset, whence=0): + def seek(self, offset: int, whence: int = 0) -> None: if whence == 0: self._point = self._chunk_file_start + offset elif whence == 1: @@ -425,7 +411,7 @@ def seek(self, offset, whence=0): if self._point < self._chunk_file_start: self._point = self._chunk_file_start - def read(self, size): + def read(self, size: int) -> bytes: try: # keep downloading file content until the buffer has enough bytes to read while self._point + size > self._download_offset: @@ -454,8 +440,54 @@ def read(self, size): return data - def track_object_position(self): + def track_object_position(self) -> None: self.object_position = self.tell() - def set_object_index(self, event_index): + def set_object_index(self, event_index: int) -> None: self.event_index = event_index + + +class Chunk: + def __init__( + self, client: Union["ContainerClient", "AsyncContainerClient"], + chunk_path: str, + chunk_cursor: Optional[Dict[str, Any]] = None + ) -> None: + self.client = client + self.chunk_path = chunk_path + self.file_reader: DataFileReader = None # type: ignore [assignment] + self.cursor: Dict[str, Any] = {"CurrentChunkPath": chunk_path} # to track the current position in avro file + self._data_stream: ChangeFeedStreamer = None # type: ignore [assignment] + self._initialize(chunk_cursor=chunk_cursor) + + def __iter__(self) -> Self: + return self + + def __next__(self) -> Dict[str, Any]: + try: + event: Dict[str, Any] = next(self.file_reader) + self.cursor["EventIndex"] = self._data_stream.event_index + self.cursor["BlockOffset"] = self._data_stream.object_position + return event + except StopIteration as exc: + self.cursor["EventIndex"] = self._data_stream.event_index + self.cursor["BlockOffset"] = self._data_stream.object_position + raise StopIteration from exc + + next = __next__ # Python 2 compatibility. 
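# The chunk cursor above stores "BlockOffset" (where the current Avro block
# begins) and "EventIndex" (how many events were already returned from that
# block). Resuming, as _initialize below does, seeks to the offset and then
# discards the already-delivered events. The skip-ahead step in isolation
# (names illustrative, not part of this package):
def _skip_delivered(reader, event_index):
    for _ in range(event_index):
        next(reader)  # these events were already handed to the caller
    return reader

events = _skip_delivered(iter([{"id": 0}, {"id": 1}, {"id": 2}]), event_index=2)
print(next(events))  # {'id': 2} -- the first event not yet delivered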
+
+    def _initialize(self, chunk_cursor=None):
+        # To get all events in a chunk
+        blob_client = self.client.get_blob_client(self.chunk_path)
+
+        file_offset = chunk_cursor.get("BlockOffset") if chunk_cursor else 0
+
+        # A nonzero offset means the data stream starts mid-file and carries no Avro header,
+        # so only when the stream has an offset do we need a separate header stream to supply it
+        header_stream = ChangeFeedStreamer(blob_client) if file_offset else None
+        self._data_stream = ChangeFeedStreamer(blob_client, chunk_file_start=file_offset)
+        self.file_reader = DataFileReader(self._data_stream, DatumReader(), header_reader=header_stream)
+
+        event_index = chunk_cursor.get("EventIndex") if chunk_cursor else 0
+        for _ in range(0, event_index):
+            next(self.file_reader)
diff --git a/sdk/storage/azure-storage-blob-changefeed/pyproject.toml b/sdk/storage/azure-storage-blob-changefeed/pyproject.toml
index 6b5bea8d0363..b04c8ccc0c0e 100644
--- a/sdk/storage/azure-storage-blob-changefeed/pyproject.toml
+++ b/sdk/storage/azure-storage-blob-changefeed/pyproject.toml
@@ -1,5 +1,5 @@
 [tool.azure-sdk-build]
-mypy = false
+mypy = true
 pyright = false
-type_check_samples = false
+type_check_samples = true
 black = false
diff --git a/sdk/storage/azure-storage-blob-changefeed/samples/change_feed_samples.py b/sdk/storage/azure-storage-blob-changefeed/samples/change_feed_samples.py
index 28d176eedd86..4e1d966e3748 100644
--- a/sdk/storage/azure-storage-blob-changefeed/samples/change_feed_samples.py
+++ b/sdk/storage/azure-storage-blob-changefeed/samples/change_feed_samples.py
@@ -76,24 +76,6 @@ def list_range_of_events(self):
         for event in change_feed:
             print(event)
 
-    def list_events_using_continuation_token(self):
-
-        # Instantiate a ChangeFeedClient
-        cf_client = ChangeFeedClient("https://{}.blob.core.windows.net".format(self.ACCOUNT_NAME),
-                                     credential=self.ACCOUNT_KEY)
-        # to get continuation token
-        change_feed = cf_client.list_changes(results_per_page=2).by_page()
-        change_feed_page1 = next(change_feed)
-        for event in change_feed_page1:
-            print(event)
-        token = change_feed.continuation_token
-
-        # restart using the continuation token
-        change_feed2 = cf_client.list_changes(results_per_page=56).by_page(continuation_token=token)
-        change_feed_page2 = next(change_feed2)
-        for event in change_feed_page2:
-            print(event)
-
     def list_events_in_live_mode(self):
         # Instantiate a ChangeFeedClient
         cf_client = ChangeFeedClient("https://{}.blob.core.windows.net".format(self.ACCOUNT_NAME),
@@ -105,7 +87,7 @@ def list_events_in_live_mode(self):
         for page in change_feed:
             for event in page:
                 print(event)
-        token = change_feed.continuation_token
+        token = change_feed.continuation_token  # type: ignore [attr-defined]
         sleep(60)
 
         print("continue printing events")
@@ -116,6 +98,4 @@ def list_events_in_live_mode(self):
     sample.list_events_by_page()
     sample.list_all_events()
     sample.list_range_of_events()
-    sample.list_events_using_continuation_token()
     sample.list_events_in_live_mode()
-
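The sample removed above remains the clearest illustration of the public continuation-token flow; a condensed sketch of the same API follows, where the account URL and key are placeholders to fill in:

from azure.storage.blob.changefeed import ChangeFeedClient

cf_client = ChangeFeedClient("https://<account>.blob.core.windows.net", credential="<account-key>")

# Read one page, then capture the opaque continuation token.
pages = cf_client.list_changes(results_per_page=2).by_page()
for event in next(pages):
    print(event)
token = pages.continuation_token

# Resume later from the saved position; the page size may differ.
resumed = cf_client.list_changes(results_per_page=56).by_page(continuation_token=token)
for event in next(resumed):
    print(event)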