diff --git a/pkg/file/path_set.go b/pkg/file/path_set.go index a46f342a..f4a8084d 100644 --- a/pkg/file/path_set.go +++ b/pkg/file/path_set.go @@ -75,3 +75,39 @@ func (s PathSet) ContainsAny(ids ...Path) bool { } return false } + +type PathCountSet map[Path]int + +func NewPathCountSet(is ...Path) PathCountSet { + s := make(PathCountSet) + s.Add(is...) + return s +} + +func (s PathCountSet) Add(ids ...Path) { + for _, i := range ids { + if _, ok := s[i]; !ok { + s[i] = 1 + continue + } + s[i]++ + } +} + +func (s PathCountSet) Remove(ids ...Path) { + for _, i := range ids { + if _, ok := s[i]; !ok { + continue + } + + s[i]-- + if s[i] <= 0 { + delete(s, i) + } + } +} + +func (s PathCountSet) Contains(i Path) bool { + count, ok := s[i] + return ok && count > 0 +} diff --git a/pkg/file/path_set_test.go b/pkg/file/path_set_test.go index 5d296649..1e578911 100644 --- a/pkg/file/path_set_test.go +++ b/pkg/file/path_set_test.go @@ -224,3 +224,45 @@ func TestPathSet_ContainsAny(t *testing.T) { }) } } + +func TestPathCountSet(t *testing.T) { + s := NewPathCountSet() + + s.Add("a", "b") // {a: 1, b: 1} + assert.True(t, s.Contains("a")) + assert.True(t, s.Contains("b")) + assert.False(t, s.Contains("c")) + + s.Add("a", "c") // {a: 2, b: 1, c: 1} + assert.True(t, s.Contains("a")) + assert.True(t, s.Contains("b")) + assert.True(t, s.Contains("c")) + + s.Remove("a") // {a: 1, b: 1, c: 1} + assert.True(t, s.Contains("a")) + assert.True(t, s.Contains("b")) + assert.True(t, s.Contains("c")) + + s.Remove("a", "b") // {c: 1} + assert.False(t, s.Contains("a")) + assert.False(t, s.Contains("b")) + assert.True(t, s.Contains("c")) + + s.Remove("a", "b", "v", "c") // {} + assert.False(t, s.Contains("a")) + assert.False(t, s.Contains("b")) + assert.False(t, s.Contains("c")) + + s.Add("a", "a", "a", "a") // {a: 4} + assert.True(t, s.Contains("a")) + assert.Equal(t, 4, s["a"]) + + s.Remove("a", "a", "a") // {a: 1} + assert.True(t, s.Contains("a")) + + s.Remove("a", "a", "a", "a") // {} + assert.False(t, s.Contains("a")) + + s.Remove("a", "a", "a", "a", "a", "a", "a", "a") // {} + assert.False(t, s.Contains("a")) +} diff --git a/pkg/filetree/filetree.go b/pkg/filetree/filetree.go index 9352e8d9..141040c8 100644 --- a/pkg/filetree/filetree.go +++ b/pkg/filetree/filetree.go @@ -250,7 +250,7 @@ func (t *FileTree) node(p file.Path, strategy linkResolutionStrategy) (*nodeAcce // return FileNode of the basename in the given path (no resolution is done at or past the basename). Note: it is // assumed that the given path has already been normalized. -func (t *FileTree) resolveAncestorLinks(path file.Path, attemptedPaths file.PathSet) (*nodeAccess, error) { +func (t *FileTree) resolveAncestorLinks(path file.Path, currentlyResolvingLinkPaths file.PathCountSet) (*nodeAccess, error) { // performance optimization... see if there is a node at the path (as if it is a real path). If so, // use it, otherwise, continue with ancestor resolution currentNodeAccess, err := t.node(path, linkResolutionStrategy{}) @@ -306,7 +306,7 @@ func (t *FileTree) resolveAncestorLinks(path file.Path, attemptedPaths file.Path // links until the next Node is resolved (or not). isLastPart := idx == len(pathParts)-1 if !isLastPart && currentNodeAccess.FileNode.IsLink() { - currentNodeAccess, err = t.resolveNodeLinks(currentNodeAccess, true, attemptedPaths) + currentNodeAccess, err = t.resolveNodeLinks(currentNodeAccess, true, currentlyResolvingLinkPaths) if err != nil { // only expected to happen on cycles return currentNodeAccess, err @@ -325,14 +325,16 @@ func (t *FileTree) resolveAncestorLinks(path file.Path, attemptedPaths file.Path // resolveNodeLinks takes the given FileNode and resolves all links at the base of the real path for the node (this implies // that NO ancestors are considered). // nolint: funlen -func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, attemptedPaths file.PathSet) (*nodeAccess, error) { +func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, currentlyResolvingLinkPaths file.PathCountSet) (*nodeAccess, error) { if n == nil { return nil, fmt.Errorf("cannot resolve links with nil Node given") } - // we need to short-circuit link resolution that never resolves (cycles) due to a cycle referencing nodes that do not exist - if attemptedPaths == nil { - attemptedPaths = file.NewPathSet() + // we need to short-circuit link resolution that never resolves (cycles) due to a cycle referencing nodes that do not exist. + // this represents current link resolution requests that are in progress. This set is pruned once the resolution + // has been completed. + if currentlyResolvingLinkPaths == nil { + currentlyResolvingLinkPaths = file.NewPathCountSet() } // note: this assumes that callers are passing paths in which the constituent parts are NOT symlinks @@ -342,8 +344,10 @@ func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, currentNodeAccess := n - // keep resolving links until a regular file or directory is found - alreadySeen := strset.New() + // keep resolving links until a regular file or directory is found. + // Note: this is NOT redundant relative to the 'currentlyResolvingLinkPaths' set. This set is used to short-circuit + // real paths that have been revisited through potentially different links (or really anyway). + realPathsVisited := strset.New() var err error for { nodePath = append(nodePath, *currentNodeAccess) @@ -357,7 +361,7 @@ func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, break } - if alreadySeen.Has(string(currentNodeAccess.FileNode.RealPath)) { + if realPathsVisited.Has(string(currentNodeAccess.FileNode.RealPath)) { return nil, ErrLinkCycleDetected } @@ -368,7 +372,8 @@ func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, } // prepare for the next iteration - alreadySeen.Add(string(currentNodeAccess.FileNode.RealPath)) + // already seen is important for the context of this loop + realPathsVisited.Add(string(currentNodeAccess.FileNode.RealPath)) nextPath = currentNodeAccess.FileNode.RenderLinkDestination() @@ -381,13 +386,14 @@ func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, lastNode = currentNodeAccess // break any cycles with non-existent paths (before attempting to look the path up again) - if attemptedPaths.Contains(nextPath) { + if currentlyResolvingLinkPaths.Contains(nextPath) { return nil, ErrLinkCycleDetected } - // get the next Node (based on the next path) - attemptedPaths.Add(nextPath) - currentNodeAccess, err = t.resolveAncestorLinks(nextPath, attemptedPaths) + // get the next Node (based on the next path)a + // attempted paths maintains state across calls to resolveAncestorLinks + currentlyResolvingLinkPaths.Add(nextPath) + currentNodeAccess, err = t.resolveAncestorLinks(nextPath, currentlyResolvingLinkPaths) if err != nil { if currentNodeAccess != nil { currentNodeAccess.LeafLinkResolution = append(currentNodeAccess.LeafLinkResolution, nodePath...) @@ -396,6 +402,7 @@ func (t *FileTree) resolveNodeLinks(n *nodeAccess, followDeadBasenameLinks bool, // only expected to occur upon cycle detection return currentNodeAccess, err } + currentlyResolvingLinkPaths.Remove(nextPath) } if !currentNodeAccess.HasFileNode() && !followDeadBasenameLinks { diff --git a/pkg/filetree/filetree_test.go b/pkg/filetree/filetree_test.go index e2d592cc..fff09810 100644 --- a/pkg/filetree/filetree_test.go +++ b/pkg/filetree/filetree_test.go @@ -1079,6 +1079,8 @@ func TestFileTree_File_DeadCycleDetection(t *testing.T) { // the test.... do we stop when a cycle is detected? exists, _, err := tr.File("/somewhere/acorn", FollowBasenameLinks) + require.Error(t, err, "should have gotten an error on resolution of a dead cycle") + // TODO: check this case if err != ErrLinkCycleDetected { t.Fatalf("should have gotten an error on resolving a file") } @@ -1089,6 +1091,56 @@ func TestFileTree_File_DeadCycleDetection(t *testing.T) { } +func TestFileTree_File_ShortCircuitDeadBasenameLinkCycles(t *testing.T) { + tr := New() + _, err := tr.AddFile("/usr/bin/ksh93") + require.NoError(t, err) + + linkPath, err := tr.AddSymLink("/usr/local/bin/ksh", "/bin/ksh") + require.NoError(t, err) + + _, err = tr.AddSymLink("/bin", "/usr/bin/ksh93") + require.NoError(t, err) + + // note: we follow dead basename links + exists, resolution, err := tr.File("/usr/local/bin/ksh", FollowBasenameLinks) + require.NoError(t, err) + assert.False(t, exists) + assert.False(t, resolution.HasReference()) + + // note: we don't follow dead basename links + exists, resolution, err = tr.File("/usr/local/bin/ksh", FollowBasenameLinks, DoNotFollowDeadBasenameLinks) + require.NoError(t, err) + assert.True(t, exists) + assert.True(t, resolution.HasReference()) + assert.Equal(t, *linkPath, *resolution.Reference) +} + +// regression: Syft issue https://github.com/anchore/syft/issues/1586 +func TestFileTree_File_ResolutionWithMultipleAncestorResolutionsForSameNode(t *testing.T) { + tr := New() + actualRef, err := tr.AddFile("/usr/bin/ksh93") + require.NoError(t, err) + + _, err = tr.AddSymLink("/usr/local/bin/ksh", "/bin/ksh") + require.NoError(t, err) + + _, err = tr.AddSymLink("/bin", "/usr/bin") + require.NoError(t, err) + + _, err = tr.AddSymLink("/etc/alternatives/ksh", "/bin/ksh93") + require.NoError(t, err) + + _, err = tr.AddSymLink("/usr/bin/ksh", "/etc/alternatives/ksh") + require.NoError(t, err) + + exists, resolution, err := tr.File("/usr/local/bin/ksh", FollowBasenameLinks) + require.NoError(t, err) + assert.True(t, exists) + assert.True(t, resolution.HasReference()) + assert.Equal(t, *actualRef, *resolution.Reference) +} + func TestFileTree_AllFiles(t *testing.T) { tr := New()