diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bbaa8e5..5a2c77e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,4 +34,5 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Maintenance - input: Fixes `List` input not managing S3 "folders" [#35](https://github.com/AdRoll/baker/pull/35) +- input: with [#35](https://github.com/AdRoll/baker/pull/35) we introduced a regression that has been fixed with [#39](https://github.com/AdRoll/baker/pull/39) - upload: fixes a severe concurrency issue in the uploader [#38](https://github.com/AdRoll/baker/pull/38) diff --git a/input/list.go b/input/list.go index ee11cc10..51795c25 100644 --- a/input/list.go +++ b/input/list.go @@ -42,7 +42,7 @@ var ListDesc = baker.InputDesc{ " and each line will be read and parsed as a \"file specifier\"\n" + " * \"@\" followed by a S3 URL pointing to a file: the text file pointed by the URL will be\n" + " downloaded, and each line will be read and parsed as a \"file specifier\"\n" + - " * \"@\" followed by a local path pointing to a directory: the directory will be recursively\n" + + " * \"@\" followed by a local path pointing to a directory (must end with a slash): the directory will be recursively\n" + " walked, and all files matching the \"MatchPath\" option regexp will be processed as logfiles\n" + " * \"@\" followed by a S3 URL pointing to a directory: the directory on S3 will be recursively\n" + " walked, and all files matching the \"MatchPath\" option regexp will be processed as logfiles\n" + @@ -282,58 +282,72 @@ func (s *List) processList(fn string) error { } case "s3": - // ListObjectsV2Input prefix must not start with / - prefix := strings.TrimLeft(u.Path, "/") + if u.Path[len(u.Path)-1:] == "/" { + // ListObjectsV2Input prefix must not start with / + prefix := strings.TrimLeft(u.Path, "/") - paths := make(chan string) - errCh := make(chan error) + paths := make(chan string) + errCh := make(chan error) - go func() { - defer close(paths) + go func() { + defer close(paths) - var nextToken *string - input := &s3.ListObjectsV2Input{ - Bucket: aws.String(u.Host), - Prefix: aws.String(prefix), - MaxKeys: aws.Int64(1000), // 1000 is the max value - } - for { - if nextToken != nil { - input.ContinuationToken = nextToken + var nextToken *string + input := &s3.ListObjectsV2Input{ + Bucket: aws.String(u.Host), + Prefix: aws.String(prefix), + MaxKeys: aws.Int64(1000), // 1000 is the max value } + for { + if nextToken != nil { + input.ContinuationToken = nextToken + } - resp, err := s.svc.ListObjectsV2(input) - if err != nil { - errCh <- err - return - } + resp, err := s.svc.ListObjectsV2(input) + if err != nil { + errCh <- err + return + } - for _, obj := range resp.Contents { - path := *obj.Key - if s.matchPath.MatchString(path) { - paths <- path + for _, obj := range resp.Contents { + path := *obj.Key + if s.matchPath.MatchString(path) { + paths <- path + } } - } - if *(resp.IsTruncated) == false { - return + if *(resp.IsTruncated) == false { + return + } + nextToken = resp.NextContinuationToken } - nextToken = resp.NextContinuationToken - } - }() + }() - for { - select { - case err := <-errCh: - return err - case line, ok := <-paths: - if !ok { + for { + select { + case err := <-errCh: + return err + case line, ok := <-paths: + if !ok { + return nil + } + s.ci.ProcessFile(fmt.Sprintf("s3://%s/%s", u.Host, line)) + case <-s.ci.Done: return nil } - s.ci.ProcessFile(fmt.Sprintf("s3://%s/%s", u.Host, line)) - case <-s.ci.Done: - return nil } + } else { + resp, err := s.svc.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(u.Host), + Key: aws.String(u.Path), + }) + if err != nil { + return err + } + + s.processListFile(resp.Body) + resp.Body.Close() + return nil } case "http", "https": diff --git a/input/list_test.go b/input/list_test.go index 579cbb56..7c241c76 100644 --- a/input/list_test.go +++ b/input/list_test.go @@ -219,7 +219,7 @@ func TestListInvalidStdin(t *testing.T) { } } -func TestListS3(t *testing.T) { +func TestListS3Folder(t *testing.T) { defer testutil.DisableLogging()() ch := make(chan *baker.Data) @@ -256,7 +256,7 @@ func TestListS3(t *testing.T) { return } - svc, recordsLen, getObjCounter := mockS3Service(t, generatedFiles, generatedRecords) + svc, recordsLen, getObjCounter := mockS3Service(t, generatedFiles, generatedRecords, false) list.(*List).svc = svc if err := list.Run(ch); err != nil { @@ -278,7 +278,58 @@ func TestListS3(t *testing.T) { } } -func mockS3Service(t *testing.T, generatedFiles, generatedRecords int) (*s3.S3, int, *int64) { +func TestListS3Manifest(t *testing.T) { + defer testutil.DisableLogging()() + + ch := make(chan *baker.Data) + + var receivedFilesContent int64 + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for data := range ch { + if len(data.Bytes) > 0 { + atomic.AddInt64(&receivedFilesContent, 1) + } + } + }() + + cfg := baker.InputParams{ + ComponentParams: baker.ComponentParams{ + DecodedConfig: &ListConfig{ + Files: []string{"@s3://bucket-name/path-prefix/manifest"}, + MatchPath: ".*\\.log\\.zst", + }, + }, + } + + list, err := NewList(cfg) + if err != nil { + t.Error("Error creating List:", err) + return + } + + svc, _, getObjCounter := mockS3Service(t, 0, 2, true) + list.(*List).svc = svc + + if err := list.Run(ch); err != nil { + log.Fatalf("unexpected error %v", err) + } + close(ch) + wg.Wait() + + if int(*getObjCounter) != 3 { // 1 manifest + 2 files into it + t.Errorf("getObjCounter want: %d, got: %d", 3, int(*getObjCounter)) + } + + if int(receivedFilesContent) != 2 { + t.Fatalf("receivedFilesContent want: %d, got: %d", 2, int(receivedFilesContent)) + } +} + +func mockS3Service(t *testing.T, generatedFiles, generatedRecords int, getManifest bool) (*s3.S3, int, *int64) { t.Helper() var counter int64 var buf []byte @@ -290,6 +341,7 @@ func mockS3Service(t *testing.T, generatedFiles, generatedRecords int) (*s3.S3, compressedRecord := zstd.Compress(nil, buf) lastModified := aws.Time(time.Now()) + var manifestCalled bool svc := s3.New(unit.Session) svc.Handlers.Unmarshal.Clear() @@ -321,9 +373,20 @@ func mockS3Service(t *testing.T, generatedFiles, generatedRecords int) (*s3.S3, case *s3.GetObjectOutput: atomic.AddInt64(&counter, 1) - data.ContentLength = aws.Int64(int64(len(compressedRecord))) - data.LastModified = lastModified - data.Body = ioutil.NopCloser(bytes.NewReader(compressedRecord)) + if getManifest && !manifestCalled { + var manifest []byte + for i := 0; i < generatedRecords; i++ { + manifest = append(manifest, []byte(fmt.Sprintf("s3://bucket-name/file-%d-from_manifest.log.zst\n", i))...) + } + manifestCalled = true + data.ContentLength = aws.Int64(int64(len(manifest))) + data.LastModified = lastModified + data.Body = ioutil.NopCloser(bytes.NewReader(manifest)) + } else { + data.ContentLength = aws.Int64(int64(len(compressedRecord))) + data.LastModified = lastModified + data.Body = ioutil.NopCloser(bytes.NewReader(compressedRecord)) + } } })