Skip to content

Commit

Permalink
Fix wildcard handling when a file of the same name as the wildcard input exists (#2062)
Browse files Browse the repository at this point in the history

* Fix stgexp bug

* Fix test compilation
  • Loading branch information
adreed-msft authored and nakulkar-msft committed Mar 30, 2023
1 parent a3749ee commit a334832
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 20 deletions.
6 changes: 4 additions & 2 deletions cmd/copyEnumeratorInit.go
Expand Up @@ -71,7 +71,8 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrde
traverser, err = InitResourceTraverser(cca.Source, cca.FromTo.From(), &ctx, &srcCredInfo,
cca.SymlinkHandling, cca.ListOfFilesChannel, cca.Recursive, getRemoteProperties,
cca.IncludeDirectoryStubs, cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs,
cca.S2sPreserveBlobTags, common.ESyncHashType.None(), cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */)
cca.S2sPreserveBlobTags, common.ESyncHashType.None(), cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(),
cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir)

if err != nil {
return nil, err
Expand Down Expand Up @@ -338,7 +339,8 @@ func (cca *CookedCopyCmdArgs) isDestDirectory(dst common.ResourceString, ctx *co

rt, err := InitResourceTraverser(dst, cca.FromTo.To(), ctx, &dstCredInfo, common.ESymlinkHandlingType.Skip(),
nil, false, false, false, common.EPermanentDeleteOption.None(),
func(common.EntityType) {}, cca.ListOfVersionIDs, false, common.ESyncHashType.None(), cca.preservePermissions, pipeline.LogNone, cca.CpkOptions, nil /* errorChannel */)
func(common.EntityType) {}, cca.ListOfVersionIDs, false, common.ESyncHashType.None(), cca.preservePermissions, pipeline.LogNone,
cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir)

if err != nil {
return false
Expand Down
3 changes: 2 additions & 1 deletion cmd/list.go
Expand Up @@ -224,7 +224,8 @@ func (cooked cookedListCmdArgs) HandleListContainerCommand() (err error) {

traverser, err := InitResourceTraverser(source, cooked.location, &ctx, &credentialInfo, common.ESymlinkHandlingType.Skip(), nil,
true, false, false, common.EPermanentDeleteOption.None(), func(common.EntityType) {},
nil, false, common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), pipeline.LogNone, common.CpkOptions{}, nil /* errorChannel */)
nil, false, common.ESyncHashType.None(), common.EPreservePermissionsOption.None(),
pipeline.LogNone, common.CpkOptions{}, nil /* errorChannel */, false)

if err != nil {
return fmt.Errorf("failed to initialize traverser: %s", err.Error())
Expand Down
3 changes: 2 additions & 1 deletion cmd/removeEnumerator.go
Expand Up @@ -51,7 +51,8 @@ func newRemoveEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator, er
sourceTraverser, err = InitResourceTraverser(cca.Source, cca.FromTo.From(), &ctx, &cca.credentialInfo,
common.ESymlinkHandlingType.Skip(), cca.ListOfFilesChannel, cca.Recursive, false, cca.IncludeDirectoryStubs,
cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs, false,
common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */)
common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(),
cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir)

// report failure to create traverser
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion cmd/setPropertiesEnumerator.go
Expand Up @@ -51,7 +51,8 @@ func setPropertiesEnumerator(cca *CookedCopyCmdArgs) (enumerator *CopyEnumerator
common.ESymlinkHandlingType.Preserve(), // preserve because we want to index all blobs, including symlink blobs
cca.ListOfFilesChannel, cca.Recursive, false, cca.IncludeDirectoryStubs,
cca.permanentDeleteOption, func(common.EntityType) {}, cca.ListOfVersionIDs, false,
common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(), cca.CpkOptions, nil /* errorChannel */)
common.ESyncHashType.None(), common.EPreservePermissionsOption.None(), azcopyLogVerbosity.ToPipelineLogLevel(),
cca.CpkOptions, nil /* errorChannel */, cca.StripTopDir)

// report failure to create traverser
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions cmd/syncEnumerator.go
Expand Up @@ -65,7 +65,8 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
if entityType == common.EEntityType.File() {
atomic.AddUint64(&cca.atomicSourceFilesScanned, 1)
}
}, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.cpkOptions, nil /* errorChannel */)
}, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(),
cca.cpkOptions, nil /* errorChannel */, false)

if err != nil {
return nil, err
Expand All @@ -86,7 +87,8 @@ func (cca *cookedSyncCmdArgs) initEnumerator(ctx context.Context) (enumerator *s
if entityType == common.EEntityType.File() {
atomic.AddUint64(&cca.atomicDestinationFilesScanned, 1)
}
}, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(), cca.cpkOptions, nil /* errorChannel */)
}, nil, cca.s2sPreserveBlobTags, cca.compareHash, cca.preservePermissions, azcopyLogVerbosity.ToPipelineLogLevel(),
cca.cpkOptions, nil /* errorChannel */, false)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions cmd/zc_enumerator.go
Expand Up @@ -332,7 +332,7 @@ type enumerationCounterFunc func(entityType common.EntityType)
func InitResourceTraverser(resource common.ResourceString, location common.Location, ctx *context.Context,
credential *common.CredentialInfo, symlinkHandling common.SymlinkHandlingType, listOfFilesChannel chan string, recursive, getProperties,
includeDirectoryStubs bool, permanentDeleteOption common.PermanentDeleteOption, incrementEnumerationCounter enumerationCounterFunc, listOfVersionIds chan string,
s2sPreserveBlobTags bool, syncHashType common.SyncHashType, preservePermissions common.PreservePermissionsOption, logLevel pipeline.LogLevel, cpkOptions common.CpkOptions, errorChannel chan ErrorFileInfo) (ResourceTraverser, error) {
s2sPreserveBlobTags bool, syncHashType common.SyncHashType, preservePermissions common.PreservePermissionsOption, logLevel pipeline.LogLevel, cpkOptions common.CpkOptions, errorChannel chan ErrorFileInfo, stripTopDir bool) (ResourceTraverser, error) {
var output ResourceTraverser
var p *pipeline.Pipeline

Expand Down Expand Up @@ -389,7 +389,7 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat
_, err := common.OSStat(resource.ValueLocal())

// If a wildcard is present and either stripTopDir is set or this isn't an existing file/folder, glob and feed the globbed list into a list enum.
if strings.Contains(resource.ValueLocal(), "*") && err != nil {
if strings.Contains(resource.ValueLocal(), "*") && (stripTopDir || err != nil) {
basePath := getPathBeforeFirstWildcard(resource.ValueLocal())
matches, err := filepath.Glob(resource.ValueLocal())

Expand All @@ -411,9 +411,9 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat
globChan, includeDirectoryStubs, incrementEnumerationCounter, s2sPreserveBlobTags, logLevel, cpkOptions, syncHashType, preservePermissions)
} else {
if ctx != nil {
output = newLocalTraverser(*ctx, resource.ValueLocal(), recursive, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel)
output = newLocalTraverser(*ctx, resource.ValueLocal(), recursive, stripTopDir, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel)
} else {
output = newLocalTraverser(context.TODO(), resource.ValueLocal(), recursive, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel)
output = newLocalTraverser(context.TODO(), resource.ValueLocal(), recursive, stripTopDir, symlinkHandling, syncHashType, incrementEnumerationCounter, errorChannel)
}
}
case common.ELocation.Benchmark():
Expand Down
2 changes: 1 addition & 1 deletion cmd/zc_traverser_list.go
Expand Up @@ -108,7 +108,7 @@ func newListTraverser(parent common.ResourceString, parentType common.Location,
// Construct a traverser that goes through the child
traverser, err := InitResourceTraverser(source, parentType, ctx, credential, handleSymlinks,
nil, recursive, getProperties, includeDirectoryStubs, common.EPermanentDeleteOption.None(), incrementEnumerationCounter,
nil, s2sPreserveBlobTags, syncHashType, preservePermissions, logLevel, cpkOptions, nil /* errorChannel */)
nil, s2sPreserveBlobTags, syncHashType, preservePermissions, logLevel, cpkOptions, nil /* errorChannel */, false)
if err != nil {
return nil, err
}
Expand Down
11 changes: 9 additions & 2 deletions cmd/zc_traverser_local.go
Expand Up @@ -45,6 +45,7 @@ const MAX_SYMLINKS_TO_FOLLOW = 40
type localTraverser struct {
fullPath string
recursive bool
stripTopDir bool
symlinkHandling common.SymlinkHandlingType
appCtx context.Context
// a generic function to notify that a new stored object has been enumerated
Expand All @@ -71,6 +72,10 @@ func (t *localTraverser) IsDirectory(bool) (bool, error) {
}

func (t *localTraverser) getInfoIfSingleFile() (os.FileInfo, bool, error) {
if t.stripTopDir {
return nil, false, nil // StripTopDir can NEVER be a single file. If a user wants to target a single file, they must escape the *.
}

fileInfo, err := common.OSStat(t.fullPath)

if err != nil {
Expand Down Expand Up @@ -793,15 +798,17 @@ func (t *localTraverser) Traverse(preprocessor objectMorpher, processor objectPr
return finalizer(err)
}

// newLocalTraverser builds a localTraverser for fullPath and returns a pointer to it.
// The path is passed through cleanLocalPath before being stored; every other
// argument is copied onto the matching struct field unchanged (syncHashType is
// stored as targetHashType).
//
// When stripTopDir is true, getInfoIfSingleFile never treats the path as a
// single file.
//
// NOTE(review): this span comes from a commit view without +/- markers, so the
// pre-change signature and the pre-change closing field line appear directly
// above their replacements. The variant that accepts stripTopDir matches the
// call sites updated in the same commit.
func newLocalTraverser(ctx context.Context, fullPath string, recursive bool, symlinkHandling common.SymlinkHandlingType, syncHashType common.SyncHashType, incrementEnumerationCounter enumerationCounterFunc, errorChannel chan ErrorFileInfo) *localTraverser {
func newLocalTraverser(ctx context.Context, fullPath string, recursive bool, stripTopDir bool, symlinkHandling common.SymlinkHandlingType, syncHashType common.SyncHashType, incrementEnumerationCounter enumerationCounterFunc, errorChannel chan ErrorFileInfo) *localTraverser {
traverser := localTraverser{
fullPath: cleanLocalPath(fullPath), // normalize the path once, at construction time
recursive: recursive,
symlinkHandling: symlinkHandling,
appCtx: ctx,
incrementEnumerationCounter: incrementEnumerationCounter,
errorChannel: errorChannel,
targetHashType: syncHashType}
targetHashType: syncHashType,
stripTopDir: stripTopDir,
}
return &traverser
}

Expand Down
6 changes: 3 additions & 3 deletions cmd/zt_generic_service_traverser_test.go
Expand Up @@ -58,7 +58,7 @@ func (s *genericTraverserSuite) TestBlobFSServiceTraverserWithManyObjects(c *chk
scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList)

// Create a local traversal
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// Invoke the traversal with an indexer so the results are indexed for easy validation
localIndexer := newObjectIndexer()
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *genericTraverserSuite) TestServiceTraverserWithManyObjects(c *chk.C) {
scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList)

// Create a local traversal
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// Invoke the traversal with an indexer so the results are indexed for easy validation
localIndexer := newObjectIndexer()
Expand Down Expand Up @@ -358,7 +358,7 @@ func (s *genericTraverserSuite) TestServiceTraverserWithWildcards(c *chk.C) {
scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, objectList)

// Create a local traversal
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), dstDirName, true, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// Invoke the traversal with an indexer so the results are indexed for easy validation
localIndexer := newObjectIndexer()
Expand Down
68 changes: 65 additions & 3 deletions cmd/zt_generic_traverser_test.go
Expand Up @@ -22,9 +22,11 @@ package cmd

import (
"context"
"github.com/Azure/azure-pipeline-go/pipeline"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"time"

Expand Down Expand Up @@ -55,6 +57,66 @@ func trySymlink(src, dst string, c *chk.C) {
}
}

// TestLocalWildcardOverlap checks that a wildcard source is still expanded when
// a file whose literal name equals the wildcard pattern exists alongside the
// files that the pattern matches.
func (s *genericTraverserSuite) TestLocalWildcardOverlap(c *chk.C) {
	// One fixture has a '*' in its file name; the test is skipped on Windows
	// because that name is invalid there.
	if runtime.GOOS == "windows" {
		c.Skip("invalid filename used")
		return
	}

	// Wildcard expansion is not implemented by the local traverser itself.
	// InitResourceTraverser short-circuits to a list traverser driven by
	// filepath.Glob, and that list traverser creates local traversers for the
	// matches, so the test has to enter through InitResourceTraverser.
	workDir := scenarioHelper{}.generateLocalDirectory(c)
	defer func() { _ = os.RemoveAll(workDir) }()

	fixtures := []string{
		"test.txt",
		"tes*t.txt",
		"foobarbaz/test.txt",
	}
	scenarioHelper{}.generateLocalFilesFromList(c, workDir, fixtures)

	pattern := filepath.Join(workDir, "tes*t.txt")
	src, err := SplitResourceString(pattern, common.ELocation.Local())
	c.Assert(err, chk.IsNil)

	wildcardTraverser, err := InitResourceTraverser(
		src,
		common.ELocation.Local(),
		nil, // ctx
		nil, // credential
		common.ESymlinkHandlingType.Follow(),
		nil,   // listOfFilesChannel
		true,  // recursive
		false, // getProperties
		false, // includeDirectoryStubs
		common.EPermanentDeleteOption.None(),
		nil,   // incrementEnumerationCounter
		nil,   // listOfVersionIds
		false, // s2sPreserveBlobTags
		common.ESyncHashType.None(),
		common.EPreservePermissionsOption.None(),
		pipeline.LogInfo,
		common.CpkOptions{},
		nil,  // errorChannel
		true, // stripTopDir
	)
	c.Assert(err, chk.IsNil)

	// Record the relative path of every object the traverser hands back.
	visited := map[string]bool{}
	record := func(obj StoredObject) error {
		visited[obj.relativePath] = true
		return nil
	}

	err = wildcardTraverser.Traverse(nil, record, []ObjectFilter{})
	c.Assert(err, chk.IsNil)

	// Both the pattern match and the literally-named file must be seen; the
	// file under foobarbaz/ must not.
	expected := map[string]bool{
		"test.txt":  true,
		"tes*t.txt": true,
	}
	c.Assert(visited, chk.DeepEquals, expected)
}

// GetProperties tests.
// GetProperties does not exist on Blob, as the properties come in the list call.
// While BlobFS could get properties in the future, it's currently disabled as BFS source S2S isn't set up right now, and likely won't be.
Expand Down Expand Up @@ -483,7 +545,7 @@ func (s *genericTraverserSuite) TestTraverserWithSingleObject(c *chk.C) {
scenarioHelper{}.generateLocalFilesFromList(c, dstDirName, blobList)

// construct a local traverser
localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, dstFileName), false, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, dstFileName), false, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// invoke the local traversal with a dummy processor
localDummyProcessor := dummyProcessor{}
Expand Down Expand Up @@ -643,7 +705,7 @@ func (s *genericTraverserSuite) TestTraverserContainerAndLocalDirectory(c *chk.C
// test two scenarios, either recursive or not
for _, isRecursiveOn := range []bool{true, false} {
// construct a local traverser
localTraverser := newLocalTraverser(context.TODO(), dstDirName, isRecursiveOn, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), dstDirName, isRecursiveOn, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// invoke the local traversal with an indexer
// so that the results are indexed for easy validation
Expand Down Expand Up @@ -804,7 +866,7 @@ func (s *genericTraverserSuite) TestTraverserWithVirtualAndLocalDirectory(c *chk
// test two scenarios, either recursive or not
for _, isRecursiveOn := range []bool{true, false} {
// construct a local traverser
localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, virDirName), isRecursiveOn, common.ESymlinkHandlingType.Skip(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)
localTraverser := newLocalTraverser(context.TODO(), filepath.Join(dstDirName, virDirName), isRecursiveOn, false, common.ESymlinkHandlingType.Follow(), common.ESyncHashType.None(), func(common.EntityType) {}, nil)

// invoke the local traversal with an indexer
// so that the results are indexed for easy validation
Expand Down

0 comments on commit a334832

Please sign in to comment.