Skip to content

Commit

Permalink
Merge pull request #231 from Wordcab/230-add-wordcabtranscript-as-source
Browse files Browse the repository at this point in the history
implement WordcabTranscriptSource
  • Loading branch information
Thomas Chaigneau committed Mar 17, 2023
2 parents d7d5816 + 46a1089 commit 031f5f0
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 34 deletions.
9 changes: 5 additions & 4 deletions src/wordcab/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ListTranscripts,
Stats,
SummarizeJob,
WordcabTranscriptSource,
)


Expand Down Expand Up @@ -94,7 +95,7 @@ def get_stats(

@no_type_check
def start_extract(
source_object: Union[BaseSource, InMemorySource],
source_object: Union[BaseSource, InMemorySource, WordcabTranscriptSource],
display_name: str,
ephemeral_data: Optional[bool] = False,
only_api: Optional[bool] = True,
Expand All @@ -113,7 +114,7 @@ def start_extract(
Parameters
----------
source_object : BaseSource or InMemorySource
source_object : BaseSource, InMemorySource or WordcabTranscriptSource
The source object to use for the extraction job.
display_name : str
The display name of the extraction job. This is useful for retrieving the job later.
Expand Down Expand Up @@ -154,7 +155,7 @@ def start_extract(

@no_type_check
def start_summary(
source_object: Union[BaseSource, InMemorySource],
source_object: Union[BaseSource, InMemorySource, WordcabTranscriptSource],
display_name: str,
summary_type: str,
context: Optional[Union[str, List[str]]] = None,
Expand All @@ -173,7 +174,7 @@ def start_summary(
Parameters
----------
source_object : BaseSource or InMemorySource
source_object : BaseSource, InMemorySource or WordcabTranscriptSource
The source object to summarize.
display_name : str
The display name of the summary. This is useful for retrieving the job later.
Expand Down
9 changes: 5 additions & 4 deletions src/wordcab/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
StructuredSummary,
SummarizeJob,
TranscriptUtterance,
WordcabTranscriptSource,
)
from .login import get_token
from .utils import (
Expand Down Expand Up @@ -147,7 +148,7 @@ def get_stats(

def start_extract( # noqa: C901
self,
source_object: Union[BaseSource, InMemorySource],
source_object: Union[BaseSource, InMemorySource, WordcabTranscriptSource],
display_name: str,
ephemeral_data: Optional[bool] = False,
only_api: Optional[bool] = True,
Expand Down Expand Up @@ -207,7 +208,7 @@ def start_extract( # noqa: C901
headers["Authorization"] = f"Bearer {self.api_key}"

pipelines = _format_pipelines(pipelines)
params: Dict[str, str] = {
params: Dict[str, Union[str, None]] = {
"source": source,
"display_name": display_name,
"ephemeral_data": str(ephemeral_data).lower(),
Expand Down Expand Up @@ -258,7 +259,7 @@ def start_extract( # noqa: C901

def start_summary( # noqa: C901
self,
source_object: Union[BaseSource, InMemorySource],
source_object: Union[BaseSource, InMemorySource, WordcabTranscriptSource],
display_name: str,
summary_type: str,
context: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -381,7 +382,7 @@ def start_summary( # noqa: C901
headers["Authorization"] = f"Bearer {self.api_key}"

pipelines = _format_pipelines(pipelines)
params: Dict[str, str] = {
params: Dict[str, Union[str, None]] = {
"source": source,
"display_name": display_name,
"ephemeral_data": str(ephemeral_data).lower(),
Expand Down
20 changes: 16 additions & 4 deletions src/wordcab/core_objects/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@dataclass
class BaseSource:
"""Base class for all source objects except for InMemorySource. It is not meant to be used directly.
"""Base class for AudioSource and GenericSource objects. It is not meant to be used directly.
Parameters
----------
Expand Down Expand Up @@ -377,20 +377,32 @@ def prepare_headers(self) -> dict:


@dataclass
class WordcabTranscriptSource(BaseSource):
class WordcabTranscriptSource:
"""Wordcab transcript source object."""

transcript_id: Optional[str] = field(default=None)
source: str = field(init=False)

def __post_init__(self) -> None:
"""Post-init method."""
super().__post_init__()
if self.transcript_id is None:
raise ValueError(
"Please provide a `transcript_id` to initialize a WordcabTranscriptSource object."
)
self.source = "wordcab_transcript"
raise NotImplementedError("Wordcab transcript source is not implemented yet.")

def __repr__(self) -> str:
"""Representation method."""
return f"WordcabTranscriptSource(transcript_id={self.transcript_id})"

def prepare_payload(self) -> None:
"""Prepare payload for API request."""
return None

def prepare_headers(self) -> Dict[str, str]:
"""Prepare headers for API request."""
self.headers = {"Accept": "application/json"}
return self.headers


@dataclass
Expand Down
17 changes: 9 additions & 8 deletions tests/core_objects/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,20 +305,21 @@ def test_in_memory_source() -> None:
}


def test_wordcab_transcript_source() -> None:
"""Test the WordcabTranscriptSource object."""
source_obj = WordcabTranscriptSource(transcript_id="test")

assert source_obj.transcript_id == "test"
assert source_obj.source == "wordcab_transcript"
assert source_obj.__repr__() == "WordcabTranscriptSource(transcript_id=test)"


def test_signed_url_source() -> None:
"""Test the SignedURLSource object."""
with pytest.raises(NotImplementedError):
SignedURLSource(url="https://example.com")


def test_wordcab_transcript_source() -> None:
"""Test the WordcabTranscriptSource object."""
with pytest.raises(NotImplementedError):
WordcabTranscriptSource(url="https://example.com", transcript_id="123456")
with pytest.raises(ValueError):
WordcabTranscriptSource(url="https://example.com")


def test_rev_source() -> None:
"""Test the RevSource object."""
with pytest.raises(NotImplementedError):
Expand Down
73 changes: 59 additions & 14 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
StructuredSummary,
SummarizeJob,
TranscriptUtterance,
WordcabTranscriptSource,
)


Expand Down Expand Up @@ -221,15 +222,9 @@ def test_start_extract(
# )


def test_start_summary(
def test_start_summary_errors(
base_source: BaseSource,
generic_source_txt: GenericSource,
generic_source_json: GenericSource,
generic_url_json: GenericSource,
audio_source: AudioSource,
audio_url_source: AudioSource,
in_memory_source: InMemorySource,
context_elements: List[str],
api_key: str,
) -> None:
"""Test client start_summary method."""
Expand Down Expand Up @@ -278,7 +273,10 @@ def test_start_summary(
summary_lens=3,
)

# Test in memory source

def test_summary_in_memory(in_memory_source: InMemorySource, api_key: str) -> None:
"""Test client start_summary method with in-memory source."""
with Client(api_key=api_key) as client:
in_memory_job = client.start_summary(
source_object=in_memory_source,
display_name="test-sdk-in-memory",
Expand All @@ -296,7 +294,10 @@ def test_start_summary(
only_api=True,
)

# Test generic source with txt file

def test_summary_generic_txt(generic_source_txt: GenericSource, api_key: str) -> None:
"""Test client start_summary method with generic text source."""
with Client(api_key=api_key) as client:
txt_job = client.start_summary(
source_object=generic_source_txt,
display_name="test-sdk-txt",
Expand All @@ -313,7 +314,12 @@ def test_start_summary(
only_api=True,
)

# Test generic source with txt file and context

def test_start_summary_generic_context(
context_elements: List[str], generic_source_txt: GenericSource, api_key: str
) -> None:
"""Test client start_summary method with generic text source."""
with Client(api_key=api_key) as client:
txt_job = client.start_summary(
source_object=generic_source_txt,
display_name="test-sdk-txt",
Expand All @@ -331,7 +337,12 @@ def test_start_summary(
only_api=True,
)

# Test generic source with json file

def test_start_summary_generic_json(
generic_source_json: GenericSource, api_key: str
) -> None:
"""Test client start_summary method with generic json source."""
with Client(api_key=api_key) as client:
json_job = client.start_summary(
source_object=generic_source_json,
display_name="test-sdk-json",
Expand All @@ -349,7 +360,12 @@ def test_start_summary(
only_api=True,
)

# Test generic source with url json file

def test_start_summary_generic_url(
generic_url_json: GenericSource, api_key: str
) -> None:
"""Test client start_summary method with generic url source."""
with Client(api_key=api_key) as client:
json_job = client.start_summary(
source_object=generic_url_json,
display_name="test-sdk-json-url",
Expand All @@ -367,7 +383,10 @@ def test_start_summary(
only_api=True,
)

# Test audio source

def test_start_summary_audio(audio_source: AudioSource, api_key: str) -> None:
"""Test client start_summary method with audio source."""
with Client(api_key=api_key) as client:
audio_job = client.start_summary(
source_object=audio_source,
display_name="test-sdk-audio",
Expand All @@ -385,7 +404,10 @@ def test_start_summary(
only_api=True,
)

# Test audio url source

def test_start_summary_audio_url(audio_url_source: AudioSource, api_key: str) -> None:
"""Test client start_summary method with audio url source."""
with Client(api_key=api_key) as client:
audio_job = client.start_summary(
source_object=audio_url_source,
display_name="test-sdk-audio-url",
Expand All @@ -404,6 +426,29 @@ def test_start_summary(
)


def test_start_summary_wordcab_transcript(api_key: str) -> None:
"""Test client start_summary method with WordcabTranscriptSource."""
with Client(api_key=api_key) as client:
wordcab_transcript_job = client.start_summary(
source_object=WordcabTranscriptSource(
transcript_id="generic_transcript_MXzewRcYCnJXKFTLewMYC53uTNyWCEeo"
),
display_name="test-sdk-wordcab-transcript",
summary_type="narrative",
summary_lens=1,
)
assert isinstance(wordcab_transcript_job, SummarizeJob)
assert wordcab_transcript_job.display_name == "test-sdk-wordcab-transcript"
assert wordcab_transcript_job.job_name is not None
assert wordcab_transcript_job.source == "wordcab_transcript"
assert wordcab_transcript_job.settings == JobSettings(
ephemeral_data=False,
pipeline="transcribe,summarize",
split_long_utterances=False,
only_api=True,
)


def test_list_jobs(api_key: str) -> None:
"""Test client list_jobs method."""
with Client(api_key=api_key) as client:
Expand Down

0 comments on commit 031f5f0

Please sign in to comment.