From a33483247df327d448cfbb9a734cf03ba63bc458 Mon Sep 17 00:00:00 2001 From: adreed-msft <49764384+adreed-msft@users.noreply.github.com> Date: Fri, 10 Mar 2023 10:47:10 -0800 Subject: [PATCH] Fix wildcard handling when a file of the same name as the wildcard input exists (#2062) * Fix stgexp bug * Fix test compilation --- cmd/copyEnumeratorInit.go | 6 ++- cmd/list.go | 3 +- cmd/removeEnumerator.go | 3 +- cmd/setPropertiesEnumerator.go | 3 +- cmd/syncEnumerator.go | 6 ++- cmd/zc_enumerator.go | 8 +-- cmd/zc_traverser_list.go | 2 +- cmd/zc_traverser_local.go | 11 +++- cmd/zt_generic_service_traverser_test.go | 6 +-- cmd/zt_generic_traverser_test.go | 68 ++++++++++++++++++++++-- 10 files changed, 96 insertions(+), 20 deletions(-) diff --git a/cmd/copyEnumeratorInit.go b/cmd/copyEnumeratorInit.go index 61a391ca0..1a39d94f0 100755 --- a/cmd/copyEnumeratorInit.go +++ b/cmd/copyEnumeratorInit.go @@ -71,7 +71,8 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrde traverser, err = InitResourceTraverser(cca.Source, cca.FromTo.From(), &ctx, &srcCredInfo, cca.SymlinkHandling, cca.ListOfFilesChannel, cca.Recursive, getRemoteProperties, cca.IncludeDirectoryStubs, cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs, - cca.S2sPreserveBlobTags, common.ESyncHashType.None(), cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */) + cca.S2sPreserveBlobTags, common.ESyncHashType.None(), cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), + cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir) if err != nil { return nil, err @@ -338,7 +339,8 @@ func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *co rt, err := InitResourceTraverser(dst, cca.FromTo.To(), ctx, &dstCredInfo, common.ESymlinkHandlingType.Skip(), nil, false, false, false, common.EPermanentDeleteOption.None(), - func(common.EntityType) {}, cca.ListOfVersionIDs, false, common.ESyncHashType.None(), cca.preservePermissions, pipeline.LogNone, cca.CpkOptions, nil /* errorChannel */) + func(common.EntityType) {}, cca.ListOfVersionIDs, false, common.ESyncHashType.None(), cca.preservePermissions, pipeline.LogNone, + cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir) if err != nil { return false diff --git a/cmd/list.go b/cmd/list.go index e7d08bbf3..478030399 100755 --- a/cmd/list.go +++ b/cmd/list.go @@ -224,7 +224,8 @@ func (cooked cookedListCmdArgs) HandleListContainerCommand() (err error) { traverser, err := InitResourceTraverser(source, cooked.location, &ctx, &credentialInfo, common.ESymlinkHandlingType.Skip(), nil, true, false, false, common.EPermanentDeleteOption.None(), func(common.EntityType) {}, - nil, false, common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), pipeline.LogNone, common.CpkOptions{}, nil /* errorChannel */) + nil, false, common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), + pipeline.LogNone, common.CpkOptions{}, nil /* errorChannel */, false) if err != nil { return fmt.Errorf("failed to initialize traverser: %s", err.Error()) diff --git a/cmd/removeEnumerator.go b/cmd/removeEnumerator.go index 24ae693da..3894186d4 100755 --- a/cmd/removeEnumerator.go +++ b/cmd/removeEnumerator.go @@ -51,7 +51,8 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er sourceTraverser, err = InitResourceTraverser(cca.Source, cca.FromTo.From(), &ctx, &cca.credentialInfo, common.ESymlinkHandlingType.Skip(), cca.ListOfFilesChannel, cca.Recursive, false, cca.IncludeDirectoryStubs, cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs, false, - common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */) + common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), + cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir) // report failure to create traverser if err != nil { diff --git a/cmd/setPropertiesEnumerator.go b/cmd/setPropertiesEnumerator.go index 608c195da..29058c878 100755 --- a/cmd/setPropertiesEnumerator.go +++ b/cmd/setPropertiesEnumerator.go @@ -51,7 +51,8 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator common.ESymlinkHandlingType.Preserve(), // preserve because we want to index all blobs, including symlink blobs cca.ListOfFilesChannel, cca.Recursive, false, cca.IncludeDirectoryStubs, cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs, false, - common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */) + common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), + cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir) // report failure to create traverser if err != nil { diff --git a/cmd/syncEnumerator.go b/cmd/syncEnumerator.go index 1aaff0e66..bc947d520 100644 --- a/cmd/syncEnumerator.go +++ b/cmd/syncEnumerator.go @@ -65,7 +65,8 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s if entityType == common.EEntityType.File() { atomic.AddUint64(&cca.atomicSourceFilesScanned, 1) } - }, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.cpkOptions, nil /* errorChannel */) + }, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), + cca.cpkOptions, nil /* errorChannel */, false) if err != nil { return nil, err @@ -86,7 +87,8 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s if entityType == common.EEntityType.File() { atomic.AddUint64(&cca.atomicDestinationFilesScanned, 1) } - }, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.cpkOptions, nil /* errorChannel */) + }, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), + cca.cpkOptions, nil /* errorChannel */, false) if err != nil { return nil, err } diff --git a/cmd/zc_enumerator.go b/cmd/zc_enumerator.go index aab4dce26..6c123889d 100755 --- a/cmd/zc_enumerator.go +++ b/cmd/zc_enumerator.go @@ -332,7 +332,7 @@ type enumerationCounterFunc func(entityType common.EntityType) func InitResourceTraverser(resource common.ResourceString, location common.Location, ctx *context.Context, credential *common.CredentialInfo, symlinkHandling common.SymlinkHandlingType, listOfFilesChannel chan string, recursive, getProperties, includeDirectoryStubs bool, permanentDeleteOption common.PermanentDeleteOption, incrementEnumerationCounter enumerationCounterFunc, listOfVersionIds chan string, - s2sPreserveBlobTags bool, syncHashType common.SyncHashType, preservePermissions common.PreservePermissionsOption, logLevel pipeline.LogLevel, cpkOptions common.CpkOptions, errorChannel chan ErrorFileInfo) (ResourceTraverser, error) { + s2sPreserveBlobTags bool, syncHashType common.SyncHashType, preservePermissions common.PreservePermissionsOption, logLevel pipeline.LogLevel, cpkOptions common.CpkOptions, errorChannel chan ErrorFileInfo, stripTopDir bool) (ResourceTraverser, error) { var output ResourceTraverser var p *pipeline.Pipeline @@ -389,7 +389,7 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat _, err := common.OSStat(resource.ValueLocal()) // If wildcard is present and this isn't an existing file/folder, glob and feed the globbed list into a list enum. - if strings.Contains(resource.ValueLocal(), "*") && err != nil { + if strings.Contains(resource.ValueLocal(), "*") && (stripTopDir || err != nil) { basePath := getPathBeforeFirstWildcard(resource.ValueLocal()) matches, err := filepath.Glob(resource.ValueLocal()) @@ -411,9 +411,9 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat globChan, includeDirectoryStubs, incrementEnumerationCounter, s2sPreserveBlobTags, logLevel, cpkOptions, syncHashType, preservePermissions) } else { if ctx != nil { - output = newLocalTraverser(*ctx, resource.ValueLocal(), recursive, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel) + output = newLocalTraverser(*ctx, resource.ValueLocal(), recursive, stripTopDir, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel) } else { - output = newLocalTraverser(context.TODO(), resource.ValueLocal(), recursive, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel) + output = newLocalTraverser(context.TODO(), resource.ValueLocal(), recursive, stripTopDir, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel) } } case common.ELocation.Benchmark(): diff --git a/cmd/zc_traverser_list.go b/cmd/zc_traverser_list.go index 9f222d409..c25ca15a1 100755 --- a/cmd/zc_traverser_list.go +++ b/cmd/zc_traverser_list.go @@ -108,7 +108,7 @@ func newListTraverser(parent common.ResourceString, parentType common.Location, // Construct a traverser that goes through the child traverser, err := InitResourceTraverser(source, parentType, ctx, credential, handleSymlinks, nil, recursive, getProperties, includeDirectoryStubs, common.EPermanentDeleteOption.None(), incrementEnumerationCounter, - nil, s2sPreserveBlobTags, syncHashType, preservePermissions, logLevel, cpkOptions, nil /* errorChannel */) + nil, s2sPreserveBlobTags, syncHashType, preservePermissions, logLevel, cpkOptions, nil /* errorChannel */, false) if err != nil { return nil, err } diff --git a/cmd/zc_traverser_local.go b/cmd/zc_traverser_local.go index c338d72f3..9629abeb9 100755 --- a/cmd/zc_traverser_local.go +++ b/cmd/zc_traverser_local.go @@ -45,6 +45,7 @@ const MAX_SYMLINKS_TO_FOLLOW = 40 type localTraverser struct { fullPath string recursive bool + stripTopDir bool symlinkHandling common.SymlinkHandlingType appCtx context.Context // a generic function to notify that a new stored object has been enumerated @@ -71,6 +72,10 @@ func (t *localTraverser) IsDirectory(bool) (bool, error) { } func (t *localTraverser) getInfoIfSingleFile() (os.FileInfo, bool, error) { + if t.stripTopDir { + return nil, false, nil // StripTopDir can NEVER be a single file. If a user wants to target a single file, they must escape the *. + } + fileInfo, err := common.OSStat(t.fullPath) if err != nil { @@ -793,7 +798,7 @@ func (t *localTraverser) Traverse(preprocessor objectMorpher, processor objectPr return finalizer(err) } -func newLocalTraverser(ctx context.Context, fullPath string, recursive bool, symlinkHandling common.SymlinkHandlingType, syncHashType common.SyncHashType, incrementEnumerationCounter enumerationCounterFunc, errorChannel chan ErrorFileInfo) *localTraverser { +func newLocalTraverser(ctx context.Context, fullPath string, recursive bool, stripTopDir bool, symlinkHandling common.SymlinkHandlingType, syncHashType common.SyncHashType, incrementEnumerationCounter enumerationCounterFunc, errorChannel chan ErrorFileInfo) *localTraverser { traverser := localTraverser{ fullPath: cleanLocalPath(fullPath), recursive: recursive, @@ -801,7 +806,9 @@ func newLocalTraverser(ctx context.Context, fullPath string, recursive bool, sym appCtx: ctx, incrementEnumerationCounter: incrementEnumerationCounter, errorChannel: errorChannel, - targetHashType: syncHashType} + targetHashType: syncHashType, + stripTopDir: stripTopDir, + } return &traverser } diff --git a/cmd/zt_generic_service_traverser_test.go b/cmd/zt_generic_service_traverser_test.go index 3641a8142..09c5427e2 100644 --- a/cmd/zt_generic_service_traverser_test.go +++ b/cmd/zt_generic_service_traverser_test.go @@ -58,7 +58,7 @@ func (s *genericTraverserSuite) TestBlobFSServiceTraverserWithManyObjects(c *chk scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList) // Create a local traversal - localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // Invoke the traversal with an indexer so the results are indexed for easy validation localIndexer := newObjectIndexer() @@ -174,7 +174,7 @@ func (s *genericTraverserSuite) TestServiceTraverserWithManyObjects(c *chk.C) { scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList) // Create a local traversal - localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // Invoke the traversal with an indexer so the results are indexed for easy validation localIndexer := newObjectIndexer() @@ -358,7 +358,7 @@ func (s *genericTraverserSuite) TestServiceTraverserWithWildcards(c *chk.C) { scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList) // Create a local traversal - localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // Invoke the traversal with an indexer so the results are indexed for easy validation localIndexer := newObjectIndexer() diff --git a/cmd/zt_generic_traverser_test.go b/cmd/zt_generic_traverser_test.go index 3db10c8c4..c094a0590 100644 --- a/cmd/zt_generic_traverser_test.go +++ b/cmd/zt_generic_traverser_test.go @@ -22,9 +22,11 @@ package cmd import ( "context" + "github.com/Azure/azure-pipeline-go/pipeline" "io" "os" "path/filepath" + "runtime" "strings" "time" @@ -55,6 +57,66 @@ func trySymlink(src, dst string, c *chk.C) { } } +func (s *genericTraverserSuite) TestLocalWildcardOverlap(c *chk.C) { + if runtime.GOOS == "windows" { + c.Skip("invalid filename used") + return + } + + /* + Wildcard support is not actually a part of the local traverser, believe it or not. + It's instead implemented in InitResourceTraverser as a short-circuit to a list traverser + utilizing the filepath.Glob function, which then initializes local traversers to achieve the same effect. + */ + tmpDir := scenarioHelper{}.generateLocalDirectory(c) + defer func(path string) { _ = os.RemoveAll(path) }(tmpDir) + + scenarioHelper{}.generateLocalFilesFromList(c, tmpDir, []string{ + "test.txt", + "tes*t.txt", + "foobarbaz/test.txt", + }) + + resource, err := SplitResourceString(filepath.Join(tmpDir, "tes*t.txt"), common.ELocation.Local()) + c.Assert(err, chk.IsNil) + + traverser, err := InitResourceTraverser( + resource, + common.ELocation.Local(), + nil, + nil, + common.ESymlinkHandlingType.Follow(), + nil, + true, + false, + false, + common.EPermanentDeleteOption.None(), + nil, + nil, + false, + common.ESyncHashType.None(), + common.EPreservePermissionsOption.None(), + pipeline.LogInfo, + common.CpkOptions{}, + nil, + true, + ) + c.Assert(err, chk.IsNil) + + seenFiles := make(map[string]bool) + + err = traverser.Traverse(nil, func(storedObject StoredObject) error { + seenFiles[storedObject.relativePath] = true + return nil + }, []ObjectFilter{}) + c.Assert(err, chk.IsNil) + + c.Assert(seenFiles, chk.DeepEquals, map[string]bool{ + "test.txt": true, + "tes*t.txt": true, + }) +} + // GetProperties tests. // GetProperties does not exist on Blob, as the properties come in the list call. // While BlobFS could get properties in the future, it's currently disabled as BFS source S2S isn't set up right now, and likely won't be. @@ -483,7 +545,7 @@ func (s *genericTraverserSuite) TestTraverserWithSingleObject(c *chk.C) { scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, blobList) // construct a local traverser - localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, dstFileName), false, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, dstFileName), false, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // invoke the local traversal with a dummy processor localDummyProcessor := dummyProcessor{} @@ -643,7 +705,7 @@ func (s *genericTraverserSuite) TestTraverserContainerAndLocalDirectory(c *chk.C // test two scenarios, either recursive or not for _, isRecursiveOn := range []bool{true, false} { // construct a local traverser - localTraverser := newLocalTraverser(context.TODO(), dstDirName, isRecursiveOn, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), dstDirName, isRecursiveOn, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // invoke the local traversal with an indexer // so that the results are indexed for easy validation @@ -804,7 +866,7 @@ func (s *genericTraverserSuite) TestTraverserWithVirtualAndLocalDirectory(c *chk // test two scenarios, either recursive or not for _, isRecursiveOn := range []bool{true, false} { // construct a local traverser - localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, virDirName), isRecursiveOn, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) + localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, virDirName), isRecursiveOn, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil) // invoke the local traversal with an indexer // so that the results are indexed for easy validation