Skip to content

Commit

Permalink
Sync from internal repo (2024-03-18) (#15)
Browse files Browse the repository at this point in the history
* feat(sdk/go): add support for temporary tokens (#4127)

GitOrigin-RevId: 3f1b1e5f61d0a4193424b7b54c00b87ad3d1ea81

* fix(sdk/go): add lemur model enums (#4147)

GitOrigin-RevId: bf5eb9e70ce2ff5bdb7faadf97af01cf0a18b576

* fix(sdk): add conformer-2 enum to go and node sdks (#4146)

GitOrigin-RevId: f5c7342796ba568e6ab8572042bb01e861c0b589

* fix(sdk/go): tidy up go.sum
  • Loading branch information
marcusolsson committed Mar 18, 2024
1 parent 5b819ff commit 4f91eb8
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 152 deletions.
4 changes: 3 additions & 1 deletion assemblyai.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
version = "1.3.0"
version = "1.4.0"
defaultBaseURLScheme = "https"
defaultBaseURLHost = "api.assemblyai.com"
defaultUserAgent = "assemblyai-go/" + version
Expand All @@ -27,6 +27,7 @@ type Client struct {

Transcripts *TranscriptService
LeMUR *LeMURService
RealTime *RealTimeService
}

// NewClientWithOptions returns a new configurable AssemblyAI client. If you provide client
Expand All @@ -51,6 +52,7 @@ func NewClientWithOptions(opts ...ClientOption) *Client {

c.Transcripts = &TranscriptService{client: c}
c.LeMUR = &LeMURService{client: c}
c.RealTime = &RealTimeService{client: c}

return c
}
Expand Down
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ require (
github.com/cenkalti/backoff v2.2.1+incompatible
github.com/google/go-cmp v0.5.9
github.com/google/go-querystring v1.1.0
github.com/stretchr/testify v1.9.0
nhooyr.io/websocket v1.8.7
)

require github.com/klauspost/compress v1.10.3 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/klauspost/compress v1.10.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
9 changes: 7 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
Expand Down Expand Up @@ -29,8 +30,6 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
Expand All @@ -45,10 +44,13 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs=
Expand All @@ -59,9 +61,12 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g=
nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0=
20 changes: 20 additions & 0 deletions lemur.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,26 @@ import (
"context"
)

const (
// LeMURModelDefault selects LeMUR Default, which is best at complex
// reasoning. It offers more nuanced responses and improved contextual
// comprehension.
LeMURModelDefault LeMURModel = "default"

// LeMURModelBasic selects LeMUR Basic, a simplified model optimized for
// speed and cost. LeMUR Basic can complete requests up to 20% faster than
// Default.
LeMURModelBasic LeMURModel = "basic"

// LeMURModelAssemblyAIMistral7B selects LeMUR Mistral 7B, an LLM self-hosted
// by AssemblyAI. It's the fastest and cheapest of the LLM options. We
// recommend it for use cases like basic summaries and factual Q&A.
LeMURModelAssemblyAIMistral7B LeMURModel = "assemblyai/mistral-7b"

// LeMURModelAnthropicClaude2_1 selects Claude 2.1, which is similar to
// Default, with key improvements: it minimizes model hallucination and
// system prompts, has a larger context window, and performs better in
// citations.
LeMURModelAnthropicClaude2_1 LeMURModel = "anthropic/claude-2-1"
)

// LeMURService groups the operations related to LeMUR.
type LeMURService struct {
client *Client
Expand Down
118 changes: 90 additions & 28 deletions realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
Expand All @@ -15,10 +14,12 @@ import (
)

var (
// ErrSessionClosed is returned when attempting to write to a closed
// session.
ErrSessionClosed = errors.New("session closed")

// ErrDisconnected is returned when attempting to write to a disconnected
// client.
ErrDisconnected = errors.New("client is disconnected")
)

Expand Down Expand Up @@ -72,8 +73,10 @@ type RealTimeBaseTranscript struct {
// The partial transcript for your audio
Text string `json:"text"`

// An array of objects, with the information for each word in the transcription text.
// Includes the start and end time of the word in milliseconds, the confidence score of the word, and the text, which is the word itself.
// An array of objects, with the information for each word in the
// transcription text. Includes the start and end time of the word in
// milliseconds, the confidence score of the word, and the text, which is
// the word itself.
Words []Word `json:"words"`
}

Expand Down Expand Up @@ -116,8 +119,10 @@ var DefaultSampleRate = 16_000
type RealTimeClient struct {
baseURL *url.URL
apiKey string
token string

conn *websocket.Conn
conn *websocket.Conn
httpClient *http.Client

mtx sync.RWMutex
sessionOpen bool
Expand All @@ -126,6 +131,10 @@ type RealTimeClient struct {
done chan bool

handler RealTimeHandler

sampleRate int
encoding RealTimeEncoding
wordBoost []string
}

func (c *RealTimeClient) isSessionOpen() bool {
Expand All @@ -148,7 +157,8 @@ type RealTimeError struct {

// RealTimeClientOption is a function that configures a RealTimeClient.
// NewRealTimeClientWithOptions applies the options it is given in order.
type RealTimeClientOption func(*RealTimeClient)

// WithRealTimeBaseURL sets the API endpoint used by the client. Mainly used for testing.
// WithRealTimeBaseURL sets the API endpoint used by the client. Mainly used for
// testing.
func WithRealTimeBaseURL(rawurl string) RealTimeClientOption {
return func(c *RealTimeClient) {
if u, err := url.Parse(rawurl); err == nil {
Expand All @@ -157,12 +167,22 @@ func WithRealTimeBaseURL(rawurl string) RealTimeClientOption {
}
}

// WithRealTimeAPIKey returns an option that makes the client authenticate
// with the given AssemblyAI API key.
func WithRealTimeAPIKey(apiKey string) RealTimeClientOption {
	setKey := func(c *RealTimeClient) {
		c.apiKey = apiKey
	}
	return setKey
}

// WithRealTimeAuthToken returns an option that makes the client authenticate
// with a temporary token, such as one obtained from
// [RealTimeService.CreateTemporaryToken].
func WithRealTimeAuthToken(token string) RealTimeClientOption {
	setToken := func(c *RealTimeClient) {
		c.token = token
	}
	return setToken
}

func WithHandler(handler RealTimeHandler) RealTimeClientOption {
return func(rtc *RealTimeClient) {
rtc.handler = handler
Expand All @@ -171,24 +191,13 @@ func WithHandler(handler RealTimeHandler) RealTimeClientOption {

// WithRealTimeSampleRate returns an option that sets the audio sample rate
// for the session.
//
// The value is only recorded on the client; queryFromOptions later turns it
// into the "sample_rate" query parameter and omits it when the value is not
// positive. The option deliberately does not touch baseURL, so it is safe on a
// client without a base URL and does not depend on the order in which options
// are applied.
func WithRealTimeSampleRate(sampleRate int) RealTimeClientOption {
	return func(rtc *RealTimeClient) {
		rtc.sampleRate = sampleRate
	}
}

// WithRealTimeWordBoost returns an option that sets the list of words to
// boost for the session.
//
// The list is only recorded on the client; queryFromOptions later encodes it
// as a JSON array in the "word_boost" query parameter and omits it when the
// list is empty. The option deliberately does not touch baseURL, so it is safe
// on a client without a base URL and does not depend on the order in which
// options are applied.
func WithRealTimeWordBoost(wordBoost []string) RealTimeClientOption {
	return func(rtc *RealTimeClient) {
		rtc.wordBoost = wordBoost
	}
}

Expand All @@ -205,26 +214,26 @@ const (

// WithRealTimeEncoding returns an option that sets the audio encoding for the
// session.
//
// The value is only recorded on the client; queryFromOptions later turns it
// into the "encoding" query parameter and omits it when it is empty. The
// option deliberately does not touch baseURL, so it is safe on a client
// without a base URL and does not depend on the order in which options are
// applied.
func WithRealTimeEncoding(encoding RealTimeEncoding) RealTimeClientOption {
	return func(rtc *RealTimeClient) {
		rtc.encoding = encoding
	}
}

// NewRealTimeClientWithOptions returns a RealTimeClient configured by the
// given options.
//
// The client starts out pointing at wss://api.assemblyai.com/v2/realtime/ws
// with a fresh http.Client and DefaultSampleRate as its sample rate. Options
// are applied in the order given. Only after all of them have run is the query
// string derived from the collected settings, so the result does not depend on
// option order.
func NewRealTimeClientWithOptions(options ...RealTimeClientOption) *RealTimeClient {
	client := &RealTimeClient{
		baseURL: &url.URL{
			Scheme: "wss",
			Host:   "api.assemblyai.com",
			Path:   "/v2/realtime/ws",
		},
		httpClient: &http.Client{},

		// Keep sending the default sample rate unless an option overrides it;
		// the query string is rebuilt from this field below.
		sampleRate: DefaultSampleRate,
	}

	for _, option := range options {
		option(client)
	}

	// Replaces any query already present on baseURL.
	client.baseURL.RawQuery = client.queryFromOptions()

	return client
}

Expand Down Expand Up @@ -261,7 +270,6 @@ func NewRealTimeClient(apiKey string, handler RealTimeHandler) *RealTimeClient {
// Closes any open WebSocket connection in case of errors.
func (c *RealTimeClient) Connect(ctx context.Context) error {
header := make(http.Header)
header.Set("Authorization", c.apiKey)

opts := &websocket.DialOptions{
HTTPHeader: header,
Expand Down Expand Up @@ -360,6 +368,33 @@ func (c *RealTimeClient) Connect(ctx context.Context) error {
return nil
}

// queryFromOptions renders the settings collected on the client as a
// URL-encoded query string. A setting is left out when it is unset: an empty
// token or encoding, a non-positive sample rate, or an empty word boost list.
func (c *RealTimeClient) queryFromOptions() string {
	query := make(url.Values)

	if words := c.wordBoost; len(words) > 0 {
		// The list is sent as a JSON array. Marshalling a []string cannot
		// fail, so the error is intentionally discarded.
		encoded, _ := json.Marshal(words)
		query.Set("word_boost", string(encoded))
	}

	if enc := string(c.encoding); enc != "" {
		query.Set("encoding", enc)
	}

	if rate := c.sampleRate; rate > 0 {
		query.Set("sample_rate", strconv.Itoa(rate))
	}

	// Temporary token used for authentication.
	if tok := c.token; tok != "" {
		query.Set("token", tok)
	}

	return query.Encode()
}

// Disconnect sends the terminate_session message and waits for the server to
// send a SessionTerminated message before closing the connection.
func (c *RealTimeClient) Disconnect(ctx context.Context, waitForSessionTermination bool) error {
Expand Down Expand Up @@ -405,3 +440,30 @@ func (c *RealTimeClient) SetEndUtteranceSilenceThreshold(ctx context.Context, th
EndUtteranceSilenceThreshold: threshold,
})
}

// RealTimeService groups the operations related to real-time transcription.
// Its methods issue requests through the Client it holds.
type RealTimeService struct {
client *Client
}

// CreateTemporaryToken asks the API for a temporary token that a real-time
// client can authenticate with (see WithRealTimeAuthToken).
//
// expiresIn is forwarded as the ExpiresIn field of the request sent with POST
// to /v2/realtime/token. On failure the returned response is nil and the
// underlying error is returned unchanged.
func (svc *RealTimeService) CreateTemporaryToken(ctx context.Context, expiresIn int64) (*RealtimeTemporaryTokenResponse, error) {
	req, err := svc.client.newJSONRequest("POST", "/v2/realtime/token", &CreateRealtimeTemporaryTokenParams{
		ExpiresIn: Int64(expiresIn),
	})
	if err != nil {
		return nil, err
	}

	token := new(RealtimeTemporaryTokenResponse)

	resp, err := svc.client.do(ctx, req, token)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	return token, nil
}

0 comments on commit 4f91eb8

Please sign in to comment.