Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.44.3"
__version__ = "0.45.0"
23 changes: 23 additions & 0 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,11 @@ class RawTranscriptionConfig(BaseModel):
The speech model to use for the transcription.
"""

speech_models: Optional[List[str]] = None
"""
The list of speech models to use for the transcription in priority order.
"""

prompt: Optional[str] = None
"The prompt used to generate the transcript with the Slam-1 speech model. Can't be used together with `keyterms_prompt`."

Expand Down Expand Up @@ -708,6 +713,7 @@ def __init__(
speech_threshold: Optional[float] = None,
raw_transcription_config: Optional[RawTranscriptionConfig] = None,
speech_model: Optional[SpeechModel] = None,
speech_models: Optional[List[str]] = None,
prompt: Optional[str] = None,
keyterms_prompt: Optional[List[str]] = None,
) -> None:
Expand Down Expand Up @@ -801,6 +807,7 @@ def __init__(
self.language_detection_options = language_detection_options
self.speech_threshold = speech_threshold
self.speech_model = speech_model
self.speech_models = speech_models
self.prompt = prompt
self.keyterms_prompt = keyterms_prompt

Expand Down Expand Up @@ -831,6 +838,16 @@ def speech_model(self, speech_model: Optional[SpeechModel]) -> None:
"Sets the speech model to use for the transcription."
self._raw_transcription_config.speech_model = speech_model

@property
def speech_models(self) -> Optional[List[str]]:
"The list of speech models to use for the transcription in priority order."
return self._raw_transcription_config.speech_models

@speech_models.setter
def speech_models(self, speech_models: Optional[List[str]]) -> None:
"Sets the list of speech models to use for the transcription in priority order."
self._raw_transcription_config.speech_models = speech_models

@property
def prompt(self) -> Optional[str]:
"The prompt to use for the transcription."
Expand Down Expand Up @@ -1902,6 +1919,9 @@ class BaseTranscript(BaseModel):
speech_model: Optional[SpeechModel] = None
"The speech model to use for the transcription."

speech_models: Optional[List[str]] = None
"The list of speech models to use for the transcription in priority order."

prompt: Optional[str] = None
"The prompt used to generate the transcript with the Slam-1 speech model. Can't be used together with `keyterms_prompt`."

Expand Down Expand Up @@ -1973,6 +1993,9 @@ class TranscriptResponse(BaseTranscript):
speech_model: Optional[SpeechModel] = None
"The speech model used for the transcription"

speech_model_used: Optional[str] = None
"The actual speech model that was used for the transcription"

prompt: Optional[str] = None
"When Slam-1 is enabled, the prompt used to generate the transcript"

Expand Down
37 changes: 36 additions & 1 deletion tests/unit/test_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import assemblyai as aai
from assemblyai.api import ENDPOINT_TRANSCRIPT
from tests.unit import factories
from assemblyai.types import SpeechModel

aai.settings.api_key = "test"

Expand Down Expand Up @@ -451,3 +450,39 @@ def test_delete_by_id_async(httpx_mock: HTTPXMock):
assert transcript.error is None
assert transcript.text == mock_transcript_response["text"]
assert transcript.audio_url == mock_transcript_response["audio_url"]


def test_speech_model_used_field_deserialization():
"""
Tests that the speech_model_used field can be properly deserialized.
"""
mock_transcript_response = factories.generate_dict_factory(
factories.TranscriptCompletedResponseFactory
)()

# Add speech_model_used to the mock response
mock_transcript_response["speech_model_used"] = "best"

transcript_response = aai.types.TranscriptResponse(**mock_transcript_response)

assert transcript_response.speech_model_used == "best"


def test_speech_model_used_field_missing():
"""
Tests that the speech_model_used field being missing does not break deserialization.
This is important because the field has not yet been added to the API for all users.
"""
mock_transcript_response = factories.generate_dict_factory(
factories.TranscriptCompletedResponseFactory
)()

# Explicitly ensure speech_model_used is not in the response
if "speech_model_used" in mock_transcript_response:
del mock_transcript_response["speech_model_used"]

# This should not raise an exception
transcript_response = aai.types.TranscriptResponse(**mock_transcript_response)

# The field should be None when not present
assert transcript_response.speech_model_used is None