Skip to content

Commit

Permalink
Minimize file usage
Browse files Browse the repository at this point in the history
  • Loading branch information
airenas committed Nov 22, 2023
1 parent 409cfdd commit 377de6c
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 60 deletions.
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ help:
awk '{info=$$0; getline; print " " $$0 ": " info;}' | column -t -s ':'
.PHONY: help
#####################################################################################
generate:
go install github.com/petergtz/pegomock/...@latest
go generate ./...

renew-async-api:
go get github.com/airenas/async-api@$$(cd ../async-api;git rev-parse HEAD)
#####################################################################################
Expand Down
4 changes: 2 additions & 2 deletions internal/pkg/test/mocks/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ func (m *Sender) SendMessage(ctx context.Context, msg amessages.Message, opt *me
// Transcriber is transcription client mock
type Transcriber struct{ mock.Mock }

func (m *Transcriber) Upload(ctx context.Context, audio *api.UploadData) (string, error) {
args := m.Called(ctx, audio)
// Upload records the call (ctx and audioFunc) on the mock and returns the
// first recorded return value as the ID string and the second as the error.
func (m *Transcriber) Upload(ctx context.Context, audioFunc func(context.Context) (*api.UploadData, func(), error)) (string, error) {
args := m.Called(ctx, audioFunc)
return args.String(0), args.Error(1)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/pkg/transcriber/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type FileData struct {

// Transcriber provides transcription
type Transcriber interface {
Upload(ctx context.Context, audio *UploadData) (string, error)
Upload(ctx context.Context, audioFunc func(context.Context) (*UploadData, func(), error)) (string, error)
HookToStatus(ctx context.Context, ID string) (<-chan StatusData, func(), error)
GetStatus(ctx context.Context, ID string) (*StatusData, error)
GetAudio(ctx context.Context, ID string) (*FileData, error)
Expand Down
56 changes: 32 additions & 24 deletions internal/pkg/transcriber/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package transcriber

import (
"bytes"
"context"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -215,34 +214,43 @@ type uploadResponse struct {
}

// Upload uploads audio to transcriber service
func (sp *Client) Upload(ctx context.Context, audio *tapi.UploadData) (string, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
i := 0
for v, k := range audio.Files {
name := getFileParam(i)
part, err := writer.CreateFormFile(name, v)
if err != nil {
return "", fmt.Errorf("can't add file to request: %w", err)
}
_, err = io.Copy(part, k)
func (sp *Client) Upload(ctx context.Context, fileFunc func(context.Context) (*tapi.UploadData, func(), error)) (string, error) {
return goapp.InvokeWithBackoff(ctx, func() (string, bool, error) {
audio, cf, err := fileFunc(ctx)
if err != nil {
return "", fmt.Errorf("can't add file content to request: %w", err)
return "", true, fmt.Errorf("can't get files: %w", err)
}
}
for v, k := range audio.Params {
if err := writer.WriteField(v, k); err != nil {
return "", fmt.Errorf("can't add param: %w", err)
}
}
writer.Close()
defer cf()

bodyReader := bytes.NewReader(body.Bytes())
pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
go func() {
i := 0
for v, k := range audio.Files {
name := getFileParam(i)
part, err := writer.CreateFormFile(name, v)
if err != nil {
pw.CloseWithError(fmt.Errorf("can't add file to request: %w", err))
return
}

if _, err := io.Copy(part, k); err != nil {
pw.CloseWithError(fmt.Errorf("can't add file content to request: %w", err))
return
}
i++
}
for v, k := range audio.Params {
if err := writer.WriteField(v, k); err != nil {
pw.CloseWithError(fmt.Errorf("can't add param: %w", err))
return
}
}
pw.CloseWithError(writer.Close())
}()

return goapp.InvokeWithBackoff(ctx, func() (string, bool, error) {
var respData uploadResponse
_, _ = bodyReader.Seek(0, io.SeekStart)
req, err := http.NewRequest(http.MethodPost, sp.uploadURL, bodyReader)
req, err := http.NewRequest(http.MethodPost, sp.uploadURL, pr)
if err != nil {
return "", false, err
}
Expand Down
40 changes: 26 additions & 14 deletions internal/pkg/transcriber/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package transcriber

import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -177,7 +178,7 @@ func TestResult_WrongMediaType(t *testing.T) {
func TestUpload(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(200, "{\"id\":\"1\"}")})

r, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "name"}})
r, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "name"}}, 0))

assert.Nil(t, err)
assert.Equal(t, r, "1")
Expand All @@ -187,7 +188,7 @@ func TestUpload(t *testing.T) {
func TestUpload_WrongCode_Fails(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/": newTestR(300, "{\"id\":\"1\"}")})

r, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "name"}})
r, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "name"}}, 0))

assert.NotNil(t, err)
assert.Equal(t, "", r)
Expand All @@ -197,7 +198,7 @@ func TestUpload_WrongCode_Fails(t *testing.T) {
func TestUpload_WrongJSON_Fails(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(300, "olia")})

r, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "name"}})
r, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "name"}}, 0))

assert.NotNil(t, err)
assert.Equal(t, r, "")
Expand All @@ -207,7 +208,7 @@ func TestUpload_WrongJSON_Fails(t *testing.T) {
func TestUpload_PassParams(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(200, "{\"id\":\"1\"}")})

r, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "__olia__"}})
r, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"}}, 0))

assert.Nil(t, err)
assert.Equal(t, "1", r)
Expand All @@ -220,8 +221,8 @@ func TestUpload_PassParams(t *testing.T) {
func TestUpload_PassFile(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(200, "{\"id\":\"1\"}")})

r, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}})
r, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}}, 0))

assert.Nil(t, err)
assert.Equal(t, r, "1")
Expand All @@ -235,8 +236,8 @@ func TestUpload_Backoff(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(http.StatusTooManyRequests, "{\"id\":\"1\"}")})
client.backoff = newSimpleBackoff

_, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}})
_, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}}, 0))

assert.NotNil(t, err)
assert.Equal(t, 4, len((*tReq)))
Expand All @@ -246,8 +247,8 @@ func TestUpload_NoBackoff(t *testing.T) {
client, _, tReq := initTestServer(t, map[string]testResp{"/upload": newTestR(http.StatusBadRequest, "{\"id\":\"1\"}")})
client.backoff = newSimpleBackoff

_, err := client.Upload(test.Ctx(t), &api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}})
_, err := client.Upload(test.Ctx(t), newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}}, 0))

assert.NotNil(t, err)
assert.Equal(t, 1, len((*tReq)))
Expand All @@ -260,8 +261,8 @@ func TestUpload_NoBackoff_Deadline(t *testing.T) {
ctx, cf := context.WithDeadline(context.Background(), time.Now())
defer cf()

_, err := client.Upload(ctx, &api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}})
_, err := client.Upload(ctx, newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}}, 0))

assert.NotNil(t, err)
assert.Equal(t, 0, len((*tReq)))
Expand All @@ -274,8 +275,8 @@ func TestUpload_NoBackoff_Canceled(t *testing.T) {
ctx, cf := context.WithCancel(context.Background())
cf()

_, err := client.Upload(ctx, &api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}})
_, err := client.Upload(ctx, newTestUploadFunc(&api.UploadData{Params: map[string]string{"name": "__olia__"},
Files: map[string]io.Reader{"file.wav": strings.NewReader("__file_olia__")}}, 0))

assert.NotNil(t, err)
assert.Equal(t, 0, len((*tReq)))
Expand Down Expand Up @@ -332,3 +333,14 @@ func TestNewClient(t *testing.T) {
})
}
}

// newTestUploadFunc returns an upload-data provider for tests. The first
// failCount invocations return an error; every later invocation returns
// uploadData together with a no-op cleanup function. With failCount == 0 the
// provider never fails.
func newTestUploadFunc(uploadData *api.UploadData, failCount int) func(context.Context) (*api.UploadData, func(), error) {
	calls := 0
	return func(context.Context) (*api.UploadData, func(), error) {
		// Count every invocation, not only the successful ones; otherwise a
		// non-zero failCount would make the provider fail forever and a retry
		// could never observe a success.
		calls++
		if calls <= failCount {
			return nil, func() {}, fmt.Errorf("fail")
		}
		return uploadData, func() {}, nil
	}
}
34 changes: 19 additions & 15 deletions internal/pkg/worker/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,25 +565,29 @@ func processStatus(ctx context.Context, statusData *tapi.StatusData, wd *persist
}

func upload(ctx context.Context, req *persistence.ReqData, transcriber tapi.Transcriber, data *ServiceData) (string, error) {
goapp.Log.Info().Str("ID", req.ID).Msg("load file")
filesMap := map[string]io.Reader{}
files := []io.ReadCloser{}
defer func() {
for _, r := range files {
_ = r.Close()
fileLoadFunc := func(ctxInt context.Context) (*tapi.UploadData, func(), error) {
goapp.Log.Info().Str("ID", req.ID).Msg("get files")
filesMap := map[string]io.Reader{}
files := []io.ReadCloser{}
closeFunc := func() {
for _, r := range files {
_ = r.Close()
}
}
}()
for _, f := range req.FileNames {
file, err := data.Filer.LoadFile(ctx, utils.MakeFileName(req.ID, f))
if err != nil {
return "", fmt.Errorf("can't load file: %w", err)
for _, f := range req.FileNames {
file, err := data.Filer.LoadFile(ctxInt, utils.MakeFileName(req.ID, f))
if err != nil {
closeFunc()
return nil, nil, fmt.Errorf("can't load file: %w", err)
}
files = append(files, file)
filesMap[f] = file
goapp.Log.Info().Str("ID", req.ID).Msg("loaded")
}
files = append(files, file)
filesMap[f] = file
goapp.Log.Info().Str("ID", req.ID).Msg("loaded")
return &tapi.UploadData{Params: prepareParams(req.Params), Files: filesMap}, closeFunc, nil
}
goapp.Log.Info().Str("ID", req.ID).Msg("uploading")
extID, err := transcriber.Upload(ctx, &tapi.UploadData{Params: prepareParams(req.Params), Files: filesMap})
extID, err := transcriber.Upload(ctx, fileLoadFunc)
if err != nil {
return "", &errTranscriber{err: fmt.Errorf("can't upload: %w", err)}
}
Expand Down

0 comments on commit 377de6c

Please sign in to comment.