Skip to content

Commit

Permalink
Sync from internal repo (2024-03-06) (#13)
Browse files Browse the repository at this point in the history
* feat(sdk/go): add support for word search (#3970)

GitOrigin-RevId: fcb7b107da43b42d9cc031a1d3a3bfcb98b9d547

* feat(sdk/go): send real-time audio as binary data (#3990)

GitOrigin-RevId: 78c7ed9a0e1a46666622880e93ef783e3a0c2d25

* fix(sdk/go): change lemur context from json.RawMessage to interface{} (#3986)

GitOrigin-RevId: 1fd06e1e9cb192a9e9e6e75162180c665f30c741

* feat(sdk/go): add support for purging lemur request data (#3989)

GitOrigin-RevId: 636299f860ea5f3fc3436d497e0d1c2fbe5cd1c3
  • Loading branch information
marcusolsson committed Mar 6, 2024
1 parent 8e723bf commit 5b819ff
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 13 deletions.
2 changes: 1 addition & 1 deletion assemblyai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
version = "1.2.0"
version = "1.3.0"
defaultBaseURLScheme = "https"
defaultBaseURLHost = "api.assemblyai.com"
defaultUserAgent = "assemblyai-go/" + version
Expand Down
15 changes: 15 additions & 0 deletions lemur.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,18 @@ func (s *LeMURService) Task(ctx context.Context, params LeMURTaskParams) (LeMURT

return response, nil
}

// PurgeRequestData issues a DELETE request for /lemur/v3/{requestID} and
// returns the response produced for that request.
//
// If the request cannot be built or executed, the zero
// PurgeLeMURRequestDataResponse is returned together with the error.
func (s *LeMURService) PurgeRequestData(ctx context.Context, requestID string) (PurgeLeMURRequestDataResponse, error) {
	var purged PurgeLeMURRequestDataResponse

	endpoint := "/lemur/v3/" + requestID

	req, err := s.client.newJSONRequest("DELETE", endpoint, nil)
	if err != nil {
		return PurgeLeMURRequestDataResponse{}, err
	}

	// Only the populated result matters here; the raw HTTP response is discarded.
	_, err = s.client.do(ctx, req, &purged)
	if err != nil {
		return PurgeLeMURRequestDataResponse{}, err
	}

	return purged, nil
}
67 changes: 67 additions & 0 deletions lemur_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestLeMUR_Summarize(t *testing.T) {
want := LeMURSummaryParams{
LeMURBaseParams: LeMURBaseParams{
TranscriptIDs: []string{"transcript_id"},
Context: "Additional context",
},
}

Expand All @@ -37,6 +38,50 @@ func TestLeMUR_Summarize(t *testing.T) {
response, err := client.LeMUR.Summarize(ctx, LeMURSummaryParams{
LeMURBaseParams: LeMURBaseParams{
TranscriptIDs: []string{"transcript_id"},
Context: "Additional context",
},
})
if err != nil {
t.Errorf("Submit returned error: %v", err)
}

want := lemurSummaryWildfires

if *response.Response != want {
t.Errorf("LeMUR.Summarize = %v, want = %v", response, want)
}
}

func TestLeMUR_SummarizeWithStructContext(t *testing.T) {
client, handler, teardown := setup()
defer teardown()

handler.HandleFunc("/lemur/v3/generate/summary", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")

var body LeMURSummaryParams
json.NewDecoder(r.Body).Decode(&body)

want := LeMURSummaryParams{
LeMURBaseParams: LeMURBaseParams{
TranscriptIDs: []string{"transcript_id"},
Context: map[string]interface{}{"key": "value"},
},
}

if !cmp.Equal(body, want) {
t.Errorf("Request body = %+v, want = %+v", body, want)
}

writeFileResponse(t, w, "testdata/lemur/summarize.json")
})

ctx := context.Background()

response, err := client.LeMUR.Summarize(ctx, LeMURSummaryParams{
LeMURBaseParams: LeMURBaseParams{
TranscriptIDs: []string{"transcript_id"},
Context: map[string]interface{}{"key": "value"},
},
})
if err != nil {
Expand Down Expand Up @@ -193,3 +238,25 @@ then get into the examples with feedback.
t.Errorf("LeMUR.ActionItems = %v, want = %v", response, want)
}
}

// TestLeMUR_PurgeRequestData checks that PurgeRequestData sends a DELETE
// request to the request-specific LeMUR path and that the fixture response is
// reported as deleted.
func TestLeMUR_PurgeRequestData(t *testing.T) {
	const requestID = "23f1485d-b3ba-4bba-8910-c16085e1afa5"

	client, handler, teardown := setup()
	defer teardown()

	handler.HandleFunc("/lemur/v3/"+requestID, func(w http.ResponseWriter, r *http.Request) {
		testMethod(t, r, "DELETE")
		writeFileResponse(t, w, "testdata/lemur/purge-request-data.json")
	})

	response, err := client.LeMUR.PurgeRequestData(context.Background(), requestID)
	if err != nil {
		t.Errorf("PurgeRequestData returned error: %v", err)
	}

	if deleted := ToBool(response.Deleted); !deleted {
		t.Errorf("LeMUR request was not deleted")
	}
}
7 changes: 1 addition & 6 deletions realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package assemblyai

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -389,11 +388,7 @@ func (c *RealTimeClient) Send(ctx context.Context, samples []byte) error {
return ErrSessionClosed
}

data := AudioData{
AudioData: base64.StdEncoding.EncodeToString(samples),
}

return wsjson.Write(ctx, c.conn, data)
return c.conn.Write(ctx, websocket.MessageBinary, samples)
}

// ForceEndUtterance manually ends an utterance.
Expand Down
8 changes: 4 additions & 4 deletions realtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)
Expand All @@ -25,12 +26,11 @@ func TestRealTime_Send(t *testing.T) {
t.Error(err)
}

_, b, _ := conn.Read(ctx)
_, got, _ := conn.Read(ctx)

got := strings.TrimSpace(string(b))
want := `{"audio_data":"Zm9v"}`
want := []byte("foo")

if got != want {
if !cmp.Equal(got, want) {
t.Errorf("message = %v, want %v", got, want)
}

Expand Down
5 changes: 5 additions & 0 deletions testdata/lemur/purge-request-data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"request_id": "2af1485d-b3ba-4bba-8910-c16085e1afa5",
"request_id_to_purge": "23f1485d-b3ba-4bba-8910-c16085e1afa5",
"deleted": true
}
26 changes: 26 additions & 0 deletions testdata/transcript/word-search.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"id": "bfc3622e-8c69-4497-9a84-fb65b30dcb07",
"total_count": 6,
"matches": [
{
"text": "hopkins",
"count": 2,
"timestamps": [
[24298, 24714],
[273498, 274090]
],
"indexes": [68, 835]
},
{
"text": "wildfires",
"count": 4,
"timestamps": [
[1668, 2346],
[33852, 34546],
[50118, 51110],
[231356, 232354]
],
"indexes": [4, 90, 140, 716]
}
]
}
23 changes: 23 additions & 0 deletions transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/url"
"strings"
"time"

"github.com/cenkalti/backoff"
Expand Down Expand Up @@ -255,3 +256,25 @@ func (s *TranscriptService) TranscribeFromReader(ctx context.Context, reader io.
}
return s.Wait(ctx, *transcript.ID)
}

// WordSearch searches a transcript for any occurrences of the provided words.
//
// The words are joined with commas and sent as the "words" query parameter of
// a GET request to /v2/transcript/{transcriptID}/word-search. If the request
// cannot be built or executed, the zero WordSearchResponse is returned
// together with the error.
func (s *TranscriptService) WordSearch(ctx context.Context, transcriptID string, words []string) (WordSearchResponse, error) {
	var matches WordSearchResponse

	query := url.Values{"words": {strings.Join(words, ",")}}
	endpoint := fmt.Sprintf("/v2/transcript/%s/word-search?%s", transcriptID, query.Encode())

	req, err := s.client.newJSONRequest("GET", endpoint, nil)
	if err != nil {
		return WordSearchResponse{}, err
	}

	resp, err := s.client.do(ctx, req, &matches)
	if err != nil {
		return WordSearchResponse{}, err
	}
	defer resp.Body.Close()

	return matches, nil
}
42 changes: 42 additions & 0 deletions transcript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,45 @@ func TestTranscripts_List(t *testing.T) {
t.Errorf(cmp.Diff(want, results))
}
}

// TestTranscripts_SearchWords checks that WordSearch issues a GET request with
// the words joined by a comma in the "words" query parameter, and that the
// fixture response is decoded into the expected WordSearchResponse.
func TestTranscripts_SearchWords(t *testing.T) {
	client, handler, teardown := setup()
	defer teardown()

	handler.HandleFunc("/v2/transcript/"+fakeTranscriptID+"/word-search", func(w http.ResponseWriter, r *http.Request) {
		testMethod(t, r, "GET")
		testQuery(t, r, "words=hopkins%2Cwildfires")

		writeFileResponse(t, w, "testdata/transcript/word-search.json")
	})

	ctx := context.Background()

	results, err := client.Transcripts.WordSearch(ctx, fakeTranscriptID, []string{"hopkins", "wildfires"})
	if err != nil {
		t.Errorf("Transcripts.WordSearch returned error: %v", err)
	}

	want := WordSearchResponse{
		ID: String("bfc3622e-8c69-4497-9a84-fb65b30dcb07"),
		Matches: []WordSearchMatch{
			{
				Count:      Int64(2),
				Indexes:    []int64{68, 835},
				Text:       String("hopkins"),
				Timestamps: []WordSearchTimestamp{{24298, 24714}, {273498, 274090}},
			},
			{
				Count:      Int64(4),
				Indexes:    []int64{4, 90, 140, 716},
				Text:       String("wildfires"),
				Timestamps: []WordSearchTimestamp{{1668, 2346}, {33852, 34546}, {50118, 51110}, {231356, 232354}},
			},
		},
		TotalCount: Int64(6),
	}

	if !cmp.Equal(results, want) {
		// Pass the diff as an argument, never as the format string: a '%' in
		// the diff would otherwise be interpreted as a formatting verb.
		t.Errorf("Transcripts.WordSearch mismatch (-want +got):\n%s", cmp.Diff(want, results))
	}
}
4 changes: 2 additions & 2 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ type LeMURActionItemsResponse struct {
}

type LeMURBaseParams struct {
// Context to provide the model. This can be a string or a free-form JSON value.
Context json.RawMessage `json:"context,omitempty"`
// Context to provide the model. This can be a string or any JSON-serializable value, such as a struct or map.
Context interface{} `json:"context,omitempty"`

// The model that is used for the final prompt after compression is performed.
// Defaults to "default".
Expand Down

0 comments on commit 5b819ff

Please sign in to comment.