diff --git a/.gitignore b/.gitignore index e1dab21..da0810c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ __pycache__ .mypy_cache .coverage htmlcov +*.egg_info diff --git a/README.md b/README.md index c4876ce..4f912dd 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,6 @@ To install from Github: ````bash git clone https://github.com/awslabs/amazon-transcribe-streaming-sdk.git cd amazon-transcribe-streaming-sdk -git submodule update --init python -m pip install . ```` @@ -59,7 +58,7 @@ handler will simply print the text out to your interpreter. """ class MyEventHandler(TranscriptResultStreamHandler): async def handle_transcript_event(self, transcript_event: TranscriptEvent): - # This handler can be implemented to handle audio as needed. + # This handler can be implemented to handle transcriptions as needed. # Here's an example to get started. results = transcript_event.transcript.results for result in results: @@ -69,7 +68,7 @@ class MyEventHandler(TranscriptResultStreamHandler): async def basic_transcribe(): # Setup up our client with our chosen AWS region - client = TranscribeStreamingClient("us-west-2") + client = TranscribeStreamingClient(region="us-west-2") # Start transcription to generate our async stream stream = await client.start_stream_transcription( @@ -86,11 +85,9 @@ async def basic_transcribe(): await stream.input_stream.send_audio_event(audio_chunk=chunk) await stream.input_stream.end_stream() - asyncio.ensure_future(write_chunks()) - - # Instantiae our handler and start processing events + # Instantiate our handler and start processing events handler = MyEventHandler(stream.output_stream) - await handler.handle_events() + await asyncio.gather(write_chunks(), handler.handle_events()) loop = asyncio.get_event_loop() loop.run_until_complete(basic_transcribe()) diff --git a/amazon_transcribe/__init__.py b/amazon_transcribe/__init__.py index 30eae91..a1ed56b 100644 --- a/amazon_transcribe/__init__.py +++ b/amazon_transcribe/__init__.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. -__version__ = "0.1.0dev" +__version__ = "0.1.0" from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup diff --git a/amazon_transcribe/client.py b/amazon_transcribe/client.py index da8c611..fc4d81a 100644 --- a/amazon_transcribe/client.py +++ b/amazon_transcribe/client.py @@ -14,6 +14,7 @@ import re from binascii import unhexlify +from typing import Optional from amazon_transcribe import AWSCRTEventLoop from amazon_transcribe.auth import AwsCrtCredentialResolver, CredentialResolver @@ -40,7 +41,7 @@ def create_client(region="us-east-2", endpoint_resolver=None): """Helper function for easy default client setup""" - return TranscribeStreamingClient(region, endpoint_resolver) + return TranscribeStreamingClient(region=region, endpoint_resolver=endpoint_resolver) class TranscribeStreamingClient: @@ -48,7 +49,7 @@ class TranscribeStreamingClient: streams to Amazon TranscribeStreaming service. """ - def __init__(self, region, endpoint_resolver=None, credential_resolver=None): + def __init__(self, *, region, endpoint_resolver=None, credential_resolver=None): if endpoint_resolver is None: endpoint_resolver = _TranscribeRegionEndpointResolver() self._endpoint_resolver: BaseEndpointResolver = endpoint_resolver @@ -63,12 +64,13 @@ def __init__(self, region, endpoint_resolver=None, credential_resolver=None): async def start_stream_transcription( self, - language_code: str = None, - media_sample_rate_hz: int = None, - media_encoding: str = None, - vocabulary_name: str = None, - session_id: str = None, - vocab_filter_method: str = None, + *, + language_code: str, + media_sample_rate_hz: int, + media_encoding: str, + vocabulary_name: Optional[str] = None, + session_id: Optional[str] = None, + vocab_filter_method: Optional[str] = None, ) -> StartStreamTranscriptionEventStream: """Coordinate transcription settings and start stream.""" transcribe_streaming_request = StartStreamTranscriptionRequest( diff --git a/requirements.txt b/requirements.txt index 5f28e94..381c121 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ pytest<5.5 pytest-asyncio<0.15.0 mypy<1.0 -black<20.0 +black==19.10b0 diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index b095b2e..344a776 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -11,7 +11,7 @@ class TestClientStreaming: @pytest.fixture def client(self): - return TranscribeStreamingClient("us-west-2") + return TranscribeStreamingClient(region="us-west-2") @pytest.fixture def wav_bytes(self): diff --git a/tests/integration/test_handlers.py b/tests/integration/test_handlers.py index 70f87f7..35f3a30 100644 --- a/tests/integration/test_handlers.py +++ b/tests/integration/test_handlers.py @@ -33,7 +33,7 @@ def chunks(self): @pytest.mark.asyncio async def test_base_transcribe_handler(self, chunks): - client = TranscribeStreamingClient("us-west-2") + client = TranscribeStreamingClient(region="us-west-2") stream = await client.start_stream_transcription( language_code="en-US", media_sample_rate_hz=16000, media_encoding="pcm", @@ -44,14 +44,13 @@ async def write_chunks(): await stream.input_stream.send_audio_event(audio_chunk=chunk) await stream.input_stream.end_stream() - write_task = asyncio.ensure_future(write_chunks()) handler = TranscriptResultStreamHandler(stream.output_stream) with pytest.raises(NotImplementedError): - await asyncio.gather(handler.handle_events(), write_task) + await asyncio.gather(write_chunks(), handler.handle_events()) @pytest.mark.asyncio async def test_extended_transcribe_handler(self, chunks): - client = TranscribeStreamingClient("us-west-2") + client = TranscribeStreamingClient(region="us-west-2") stream = await client.start_stream_transcription( language_code="en-US", media_sample_rate_hz=16000, media_encoding="pcm", @@ -62,7 +61,6 @@ async def write_chunks(): await stream.input_stream.send_audio_event(audio_chunk=chunk) await stream.input_stream.end_stream() - write_task = asyncio.ensure_future(write_chunks()) handler = self.ExampleStreamHandler(stream.output_stream) - await asyncio.gather(handler.handle_events(), write_task) + await asyncio.gather(write_chunks(), handler.handle_events()) assert len(handler.result_holder) > 0 diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 03e706e..7bc5b1d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -3,7 +3,7 @@ class TestClientSetup: def test_basic_client_setup(self): - client = TranscribeStreamingClient("us-west-2") + client = TranscribeStreamingClient(region="us-west-2") assert client.service_name == "transcribe" assert client.region == "us-west-2" assert client._endpoint_resolver is not None