Skip to content

Commit

Permalink
Merge pull request #1 from awslabs/v0.1.0
Browse files Browse the repository at this point in the history
v0.1.0
  • Loading branch information
nateprewitt committed Jul 28, 2020
2 parents 3b48a5c + bab5729 commit b099801
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -3,3 +3,4 @@ __pycache__
.mypy_cache
.coverage
htmlcov
*.egg-info
11 changes: 4 additions & 7 deletions README.md
Expand Up @@ -21,7 +21,6 @@ To install from Github:
````bash
git clone https://github.com/awslabs/amazon-transcribe-streaming-sdk.git
cd amazon-transcribe-streaming-sdk
git submodule update --init
python -m pip install .
````

Expand Down Expand Up @@ -59,7 +58,7 @@ handler will simply print the text out to your interpreter.
"""
class MyEventHandler(TranscriptResultStreamHandler):
async def handle_transcript_event(self, transcript_event: TranscriptEvent):
# This handler can be implemented to handle audio as needed.
# This handler can be implemented to handle transcriptions as needed.
# Here's an example to get started.
results = transcript_event.transcript.results
for result in results:
Expand All @@ -69,7 +68,7 @@ class MyEventHandler(TranscriptResultStreamHandler):

async def basic_transcribe():
# Set up our client with our chosen AWS region
client = TranscribeStreamingClient("us-west-2")
client = TranscribeStreamingClient(region="us-west-2")

# Start transcription to generate our async stream
stream = await client.start_stream_transcription(
Expand All @@ -86,11 +85,9 @@ async def basic_transcribe():
await stream.input_stream.send_audio_event(audio_chunk=chunk)
await stream.input_stream.end_stream()

asyncio.ensure_future(write_chunks())

# Instantiae our handler and start processing events
# Instantiate our handler and start processing events
handler = MyEventHandler(stream.output_stream)
await handler.handle_events()
await asyncio.gather(write_chunks(), handler.handle_events())

loop = asyncio.get_event_loop()
loop.run_until_complete(basic_transcribe())
Expand Down
2 changes: 1 addition & 1 deletion amazon_transcribe/__init__.py
Expand Up @@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.


__version__ = "0.1.0dev"
__version__ = "0.1.0"

from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup

Expand Down
18 changes: 10 additions & 8 deletions amazon_transcribe/client.py
Expand Up @@ -14,6 +14,7 @@

import re
from binascii import unhexlify
from typing import Optional

from amazon_transcribe import AWSCRTEventLoop
from amazon_transcribe.auth import AwsCrtCredentialResolver, CredentialResolver
Expand All @@ -40,15 +41,15 @@

def create_client(region="us-east-2", endpoint_resolver=None):
"""Helper function for easy default client setup"""
return TranscribeStreamingClient(region, endpoint_resolver)
return TranscribeStreamingClient(region=region, endpoint_resolver=endpoint_resolver)


class TranscribeStreamingClient:
"""High level client for orchestrating setup and transmission of audio
streams to Amazon TranscribeStreaming service.
"""

def __init__(self, region, endpoint_resolver=None, credential_resolver=None):
def __init__(self, *, region, endpoint_resolver=None, credential_resolver=None):
if endpoint_resolver is None:
endpoint_resolver = _TranscribeRegionEndpointResolver()
self._endpoint_resolver: BaseEndpointResolver = endpoint_resolver
Expand All @@ -63,12 +64,13 @@ def __init__(self, region, endpoint_resolver=None, credential_resolver=None):

async def start_stream_transcription(
self,
language_code: str = None,
media_sample_rate_hz: int = None,
media_encoding: str = None,
vocabulary_name: str = None,
session_id: str = None,
vocab_filter_method: str = None,
*,
language_code: str,
media_sample_rate_hz: int,
media_encoding: str,
vocabulary_name: Optional[str] = None,
session_id: Optional[str] = None,
vocab_filter_method: Optional[str] = None,
) -> StartStreamTranscriptionEventStream:
"""Coordinate transcription settings and start stream."""
transcribe_streaming_request = StartStreamTranscriptionRequest(
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
pytest<5.5
pytest-asyncio<0.15.0
mypy<1.0
black<20.0
black==19.10b0
2 changes: 1 addition & 1 deletion tests/integration/test_client.py
Expand Up @@ -11,7 +11,7 @@
class TestClientStreaming:
@pytest.fixture
def client(self):
return TranscribeStreamingClient("us-west-2")
return TranscribeStreamingClient(region="us-west-2")

@pytest.fixture
def wav_bytes(self):
Expand Down
10 changes: 4 additions & 6 deletions tests/integration/test_handlers.py
Expand Up @@ -33,7 +33,7 @@ def chunks(self):

@pytest.mark.asyncio
async def test_base_transcribe_handler(self, chunks):
client = TranscribeStreamingClient("us-west-2")
client = TranscribeStreamingClient(region="us-west-2")

stream = await client.start_stream_transcription(
language_code="en-US", media_sample_rate_hz=16000, media_encoding="pcm",
Expand All @@ -44,14 +44,13 @@ async def write_chunks():
await stream.input_stream.send_audio_event(audio_chunk=chunk)
await stream.input_stream.end_stream()

write_task = asyncio.ensure_future(write_chunks())
handler = TranscriptResultStreamHandler(stream.output_stream)
with pytest.raises(NotImplementedError):
await asyncio.gather(handler.handle_events(), write_task)
await asyncio.gather(write_chunks(), handler.handle_events())

@pytest.mark.asyncio
async def test_extended_transcribe_handler(self, chunks):
client = TranscribeStreamingClient("us-west-2")
client = TranscribeStreamingClient(region="us-west-2")

stream = await client.start_stream_transcription(
language_code="en-US", media_sample_rate_hz=16000, media_encoding="pcm",
Expand All @@ -62,7 +61,6 @@ async def write_chunks():
await stream.input_stream.send_audio_event(audio_chunk=chunk)
await stream.input_stream.end_stream()

write_task = asyncio.ensure_future(write_chunks())
handler = self.ExampleStreamHandler(stream.output_stream)
await asyncio.gather(handler.handle_events(), write_task)
await asyncio.gather(write_chunks(), handler.handle_events())
assert len(handler.result_holder) > 0
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Expand Up @@ -3,7 +3,7 @@

class TestClientSetup:
def test_basic_client_setup(self):
client = TranscribeStreamingClient("us-west-2")
client = TranscribeStreamingClient(region="us-west-2")
assert client.service_name == "transcribe"
assert client.region == "us-west-2"
assert client._endpoint_resolver is not None
Expand Down

0 comments on commit b099801

Please sign in to comment.