Skip to content

Commit

Permalink
fix(sdk/go): clean up api for disabling partial transcripts (#4499) (#19
Browse files Browse the repository at this point in the history
)

GitOrigin-RevId: db8fdf49aaaaf2442b6aa9508751ff63a686b2fc
  • Loading branch information
marcusolsson committed Apr 16, 2024
1 parent f7cf945 commit 56d6173
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 15 deletions.
2 changes: 0 additions & 2 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ func TestIntegration_RealTime_WithoutPartialTranscripts(t *testing.T) {
},
}),
WithRealTimeSampleRate(sampleRate),
WithRealTimeDisablePartialTranscripts(true),
)

ctx := context.Background()
Expand Down Expand Up @@ -217,7 +216,6 @@ func TestIntegration_RealTime_WithExtraSessionInfo(t *testing.T) {
},
}),
WithRealTimeSampleRate(sampleRate),
WithRealTimeDisablePartialTranscripts(true),
)

ctx := context.Background()
Expand Down
16 changes: 4 additions & 12 deletions realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,9 @@ type RealTimeClient struct {

transcriber *RealTimeTranscriber

sampleRate int
encoding RealTimeEncoding
wordBoost []string
disablePartialTranscripts bool
sampleRate int
encoding RealTimeEncoding
wordBoost []string
}

func (c *RealTimeClient) isSessionOpen() bool {
Expand Down Expand Up @@ -222,13 +221,6 @@ func WithRealTimeWordBoost(wordBoost []string) RealTimeClientOption {
}
}

// WithRealTimeDisablePartialTranscripts disables partial transcripts during real-time transcription.
func WithRealTimeDisablePartialTranscripts(disable bool) RealTimeClientOption {
return func(rtc *RealTimeClient) {
rtc.disablePartialTranscripts = disable
}
}

// RealTimeEncoding is the encoding format for the audio data.
type RealTimeEncoding string

Expand Down Expand Up @@ -474,7 +466,7 @@ func (c *RealTimeClient) queryFromOptions() string {
}

// Disable partial transcripts
if c.disablePartialTranscripts {
if c.transcriber.OnPartialTranscript == nil {
values.Set("disable_partial_transcripts", "true")
}

Expand Down
80 changes: 79 additions & 1 deletion realtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ func TestRealTime_Send(t *testing.T) {
WithRealTimeWordBoost([]string{"foo", "bar"}),
WithRealTimeEncoding(RealTimeEncodingPCMMulaw),
WithRealTimeSampleRate(8_000),
WithRealTimeDisablePartialTranscripts(true),
)

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
Expand Down Expand Up @@ -479,3 +478,82 @@ func upgradeRequest(w http.ResponseWriter, r *http.Request) (*websocket.Conn, fu
return conn.Close(websocket.StatusInternalError, "websocket closed unexpectedly")
}
}

func TestRealTime_DisablePartialTranscriptsIfNoCallback(t *testing.T) {
t.Parallel()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

conn, teardown := upgradeRequest(w, r)
defer teardown()

disablePartialTranscripts := r.URL.Query().Get("disable_partial_transcripts")
require.Equal(t, "true", disablePartialTranscripts)

var err error

err = beginSession(ctx, conn)
require.NoError(t, err)

err = terminateSession(ctx, conn)
require.NoError(t, err)
}))
defer ts.Close()

client := NewRealTimeClientWithOptions(
WithRealTimeBaseURL(ts.URL),
WithRealTimeTranscriber(&RealTimeTranscriber{}),
)

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

var err error

err = client.Connect(ctx)
require.NoError(t, err)

err = client.Disconnect(ctx, true)
require.NoError(t, err)
}

func TestRealTime_EnablePartialTranscriptsIfCallback(t *testing.T) {
t.Parallel()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

conn, teardown := upgradeRequest(w, r)
defer teardown()

require.False(t, r.URL.Query().Has("disable_partial_transcripts"))

var err error

err = beginSession(ctx, conn)
require.NoError(t, err)

err = terminateSession(ctx, conn)
require.NoError(t, err)
}))
defer ts.Close()

client := NewRealTimeClientWithOptions(
WithRealTimeBaseURL(ts.URL),
WithRealTimeTranscriber(&RealTimeTranscriber{
OnPartialTranscript: func(_ PartialTranscript) {},
}),
)

ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()

var err error

err = client.Connect(ctx)
require.NoError(t, err)

err = client.Disconnect(ctx, true)
require.NoError(t, err)
}

0 comments on commit 56d6173

Please sign in to comment.