From 482391b16630e83ab4f310ec75938bafbf7c9745 Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Fri, 2 Jun 2023 10:26:55 -0400 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: e6895bd6ffd174f8dfb4272f419941a1f2e3da05 --- README.md | 44 +++++++++- assemblyai/transcriber.py | 6 ++ assemblyai/types.py | 13 +++ setup.py | 2 +- tests/unit/factories.py | 2 +- tests/unit/test_summarization.py | 137 +++++++++++++++++++++++++++++++ 6 files changed, 199 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_summarization.py diff --git a/README.md b/README.md index a6a03cc..754eb2a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ --- + [![CI Passing](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml/badge.svg)](https://github.com/AssemblyAI/assemblyai-python-sdk/actions/workflows/test.yml) [![GitHub License](https://img.shields.io/github/license/AssemblyAI/assemblyai-python-sdk)](https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/LICENSE) [![PyPI version](https://badge.fury.io/py/assemblyai.svg)](https://badge.fury.io/py/assemblyai) @@ -26,14 +27,12 @@ With a single API call, get access to AI models built on the latest AI breakthro - [Playgrounds](#playgrounds) - [Advanced](#advanced-todo) - # Documentation Visit our [AssemblyAI API Documentation](https://www.assemblyai.com/docs) to get an overview of our models! # Quick Start - ## Installation ```bash @@ -66,6 +65,7 @@ transcript = transcriber.transcribe("./my-local-audio-file.wav") print(transcript.text) ``` +
@@ -79,6 +79,7 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3") print(transcript.text) ``` +
@@ -96,6 +97,7 @@ print(transcript.export_subtitles_srt()) # in VTT format print(transcript.export_subtitles_vtt()) ``` +
@@ -115,6 +117,7 @@ paragraphs = transcript.get_paragraphs() for paragraph in paragraphs: print(paragraph.text) ``` +
@@ -131,6 +134,7 @@ matches = transcript.word_search(["price", "product"]) for match in matches: print(f"Found '{match.text}' {match.count} times in the transcript") ``` +
@@ -152,9 +156,40 @@ transcript = transcriber.transcribe("https://example.org/audio.mp3", config) print(transcript.text) ``` + +
+ +
+ Summarize the content of a transcript + +```python +import assemblyai as aai + +transcriber = aai.Transcriber() +transcript = transcriber.transcribe( + "https://example.org/audio.mp3", + config=aai.TranscriptionConfig(summarize=True) +) + +print(transcript.summary) +``` + +By default, the summarization model will be `informative` and the summarization type will be `bullets`. [Read more about summarization models and types here](https://www.assemblyai.com/docs/Models/summarization#types-and-models). + +To change the model and/or type, pass additional parameters to the `TranscriptionConfig`: + +```python +config=aai.TranscriptionConfig( + summarize=True, + summary_model=aai.SummarizationModel.catchy, + summary_type=aai.Summarizationtype.headline +) +``` +
--- + ### **LeMUR Examples**
@@ -175,6 +210,7 @@ summary = transcript_group.lemur.summarize(context="Customers asking for cars", print(summary) ``` +
@@ -195,6 +231,7 @@ feedback = transcript_group.lemur.ask_coach(context="Who was the best interviewe print(feedback) ``` +
@@ -218,6 +255,7 @@ for result in result: print(f"Question: {result.question}") print(f"Answer: {result.answer}") ``` +
--- @@ -247,8 +285,8 @@ config.set_pii_redact( transcriber = aai.Transcriber() transcript = transcriber.transcribe("https://example.org/audio.mp3", config) ``` - + --- diff --git a/assemblyai/transcriber.py b/assemblyai/transcriber.py index 4d37be4..5a5ad23 100644 --- a/assemblyai/transcriber.py +++ b/assemblyai/transcriber.py @@ -204,6 +204,12 @@ def text(self) -> Optional[str]: return self._impl.transcript.text + @property + def summary(self) -> Optional[str]: + "The summarization of the transcript" + + return self._impl.transcript.summary + @property def status(self) -> types.TranscriptStatus: "The current status of the transcript" diff --git a/assemblyai/types.py b/assemblyai/types.py index 7f64907..647947f 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -1004,6 +1004,16 @@ def set_summarize( return self + # Validate that required parameters are also set + if self._raw_transcription_config.punctuate == False: + raise ValueError( + "If `summarization` is enabled, then `punctuate` must not be disabled" + ) + if self._raw_transcription_config.format_text == False: + raise ValueError( + "If `summarization` is enabled, then `format_text` must not be disabled" + ) + self._raw_transcription_config.summarization = True self._raw_transcription_config.summary_model = model self._raw_transcription_config.summary_type = type @@ -1379,6 +1389,9 @@ class TranscriptResponse(BaseTranscript): webhook_auth: Optional[bool] "Whether the webhook was sent with an HTTP authentication header" + summary: Optional[str] + "The summarization of the transcript" + # auto_highlights_result: Optional[AutohighlightResponse] = None # "The list of results when enabling Automatic Transcript Highlights" diff --git a/setup.py b/setup.py index 0edeff5..6cffa7b 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="assemblyai", - version="0.5.1", + version="0.6.0", description="AssemblyAI Python SDK", author="AssemblyAI", author_email="engineering.sdk@assemblyai.com", diff --git a/tests/unit/factories.py b/tests/unit/factories.py index c6e200f..d7fbdd8 100644 --- a/tests/unit/factories.py +++ b/tests/unit/factories.py @@ -200,7 +200,7 @@ class Meta: audio_duration = factory.Faker("pyint") -def generate_dict_factory(f: factory.Factory) -> Callable[[None], Dict[str, Any]]: +def generate_dict_factory(f: factory.Factory) -> Callable[[], Dict[str, Any]]: """ Creates a dict factory from the given *Factory class. diff --git a/tests/unit/test_summarization.py b/tests/unit/test_summarization.py new file mode 100644 index 0000000..6713e1b --- /dev/null +++ b/tests/unit/test_summarization.py @@ -0,0 +1,137 @@ +import json +from typing import Any, Dict + +import httpx +import pytest +from pytest_httpx import HTTPXMock + +import assemblyai as aai +from tests.unit import factories + +aai.settings.api_key = "test" + + +def __submit_request(httpx_mock: HTTPXMock, **params) -> Dict[str, Any]: + """ + Helper function to abstract calling transcriber with given parameters, + and perform some common assertions. + + Returns the body (dictionary) of the initial submission request. + """ + summary = "example summary" + + mock_transcript_response = factories.generate_dict_factory( + factories.TranscriptCompletedResponseFactory + )() + + # Mock initial submission response + httpx_mock.add_response( + url=f"{aai.settings.base_url}/transcript", + status_code=httpx.codes.OK, + method="POST", + json=mock_transcript_response, + ) + + # Mock polling-for-completeness response, with mock summary result + httpx_mock.add_response( + url=f"{aai.settings.base_url}/transcript/{mock_transcript_response['id']}", + status_code=httpx.codes.OK, + method="GET", + json={**mock_transcript_response, "summary": summary}, + ) + + # == Make API request via SDK == + transcript = aai.Transcriber().transcribe( + data="https://example.org/audio.wav", + config=aai.TranscriptionConfig( + **params, + ), + ) + + # Check that submission and polling requests were made + assert len(httpx_mock.get_requests()) == 2 + + # Check that summary field from response was traced back through SDK classes + assert transcript.summary == summary + + # Extract and return body of initial submission request + request = httpx_mock.get_requests()[0] + return json.loads(request.content.decode()) + + +@pytest.mark.parametrize("required_field", ["punctuate", "format_text"]) +def test_summarization_fails_without_required_field( + httpx_mock: HTTPXMock, required_field: str +): + """ + Tests whether the SDK raises an error before making a request + if `summarization` is enabled and the given required field is disabled + """ + with pytest.raises(ValueError) as error: + __submit_request(httpx_mock, summarization=True, **{required_field: False}) + + # Check that the error message informs the user of the invalid parameter + assert required_field in str(error) + + # Check that the error was raised before any requests were made + assert len(httpx_mock.get_requests()) == 0 + + # Inform httpx_mock that it's okay we didn't make any requests + httpx_mock.reset(False) + + +def test_summarization_disabled_by_default(httpx_mock: HTTPXMock): + """ + Tests that excluding `summarization` from the `TranscriptionConfig` will + result in the default behavior of it being excluded from the request body + """ + request_body = __submit_request(httpx_mock) + assert request_body.get("summarization") is None + + +def test_default_summarization_params(httpx_mock: HTTPXMock): + """ + Tests that including `summarization=True` in the `TranscriptionConfig` + will result in `summarization=True` in the request body. + """ + request_body = __submit_request(httpx_mock, summarization=True) + assert request_body.get("summarization") == True + + +def test_summarization_with_params(httpx_mock: HTTPXMock): + """ + Tests that including additional summarization parameters along with + `summarization=True` in the `TranscriptionConfig` will result in all + parameters being included in the request as well. + """ + + summary_model = aai.SummarizationModel.conversational + summary_type = aai.SummarizationType.bullets + + request_body = __submit_request( + httpx_mock, + summarization=True, + summary_model=summary_model, + summary_type=summary_type, + ) + + assert request_body.get("summarization") == True + assert request_body.get("summary_model") == summary_model + assert request_body.get("summary_type") == summary_type + + +def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock): + """ + Tests that additional summarization parameters are excluded from the submission + request body if `summarization` itself is not enabled. + """ + request_body = __submit_request( + httpx_mock, + summarization=False, + summary_model=aai.SummarizationModel.conversational, + summary_type=aai.SummarizationType.bullets, + ) + + assert request_body.get("summarization") is None + assert request_body.get("summary_model") is None + assert request_body.get("summary_type") is None