diff --git a/ArchUnitNET/Loader/ArchLoader.cs b/ArchUnitNET/Loader/ArchLoader.cs index 49895a72..93afeb37 100644 --- a/ArchUnitNET/Loader/ArchLoader.cs +++ b/ArchUnitNET/Loader/ArchLoader.cs @@ -38,9 +38,14 @@ public ArchLoader LoadAssemblies(params Assembly[] assemblies) } public ArchLoader LoadAssembliesIncludingDependencies(params Assembly[] assemblies) + { + return LoadAssembliesIncludingDependencies(assemblies, false); + } + + public ArchLoader LoadAssembliesIncludingDependencies(IEnumerable assemblies, bool recursive) { var assemblySet = new HashSet(assemblies); - assemblySet.ForEach(assembly => LoadAssemblyIncludingDependencies(assembly)); + assemblySet.ForEach(assembly => LoadAssemblyIncludingDependencies(assembly, recursive)); return this; } @@ -53,10 +58,11 @@ public ArchLoader LoadFilteredDirectory(string directory, string filter, var result = this; return assemblies.Aggregate(result, - (current, assembly) => current.LoadAssembly(assembly, false)); + (current, assembly) => current.LoadAssembly(assembly, false, false)); } public ArchLoader LoadFilteredDirectoryIncludingDependencies(string directory, string filter, + bool recursive = false, SearchOption searchOption = TopDirectoryOnly) { var path = Path.GetFullPath(directory); @@ -65,67 +71,74 @@ public ArchLoader LoadFilteredDirectoryIncludingDependencies(string directory, s var result = this; return assemblies.Aggregate(result, - (current, assembly) => current.LoadAssembly(assembly, true)); + (current, assembly) => current.LoadAssembly(assembly, true, recursive)); } public ArchLoader LoadNamespacesWithinAssembly(Assembly assembly, params string[] namespc) { var nameSpaces = new HashSet(namespc); - nameSpaces.ForEach(nameSpace => { LoadModule(assembly.Location, nameSpace, false); }); + nameSpaces.ForEach(nameSpace => { LoadModule(assembly.Location, nameSpace, false, false); }); return this; } public ArchLoader LoadAssembly(Assembly assembly) { - return LoadAssembly(assembly.Location, false); + return LoadAssembly(assembly.Location, false, false); } - public ArchLoader LoadAssemblyIncludingDependencies(Assembly assembly) + public ArchLoader LoadAssemblyIncludingDependencies(Assembly assembly, bool recursive = false) { - return LoadAssembly(assembly.Location, true); + return LoadAssembly(assembly.Location, true, recursive); } - private ArchLoader LoadAssembly(string fileName, bool includeDependencies) + private ArchLoader LoadAssembly(string fileName, bool includeDependencies, bool recursive) { - LoadModule(fileName, null, includeDependencies); + LoadModule(fileName, null, includeDependencies, recursive); return this; } - private void LoadModule(string fileName, string nameSpace, bool includeDependencies) + private void LoadModule(string fileName, string nameSpace, bool includeDependencies, bool recursive) { try { var module = ModuleDefinition.ReadModule(fileName, - new ReaderParameters {AssemblyResolver = _assemblyResolver}); + new ReaderParameters { AssemblyResolver = _assemblyResolver }); + var processedAssemblies = new List { module.Assembly.Name }; + var resolvedModules = new List(); _assemblyResolver.AddLib(module.Assembly); _archBuilder.AddAssembly(module.Assembly, false); foreach (var assemblyReference in module.AssemblyReferences) { - try + if (includeDependencies && recursive) { - _assemblyResolver.AddLib(assemblyReference); - if (includeDependencies) - { - _archBuilder.AddAssembly( - _assemblyResolver.Resolve(assemblyReference) ?? - throw new AssemblyResolutionException(assemblyReference), false); - } + AddReferencedAssembliesRecursively(assemblyReference, processedAssemblies, resolvedModules); } - catch (AssemblyResolutionException) + else { - //Failed to resolve assembly, skip it + try + { + processedAssemblies.Add(assemblyReference); + _assemblyResolver.AddLib(assemblyReference); + if (includeDependencies) + { + var assemblyDefinition = _assemblyResolver.Resolve(assemblyReference) ?? + throw new AssemblyResolutionException(assemblyReference); + _archBuilder.AddAssembly(assemblyDefinition, false); + resolvedModules.AddRange(assemblyDefinition.Modules); + } + } + catch (AssemblyResolutionException) + { + //Failed to resolve assembly, skip it + } } } _archBuilder.LoadTypesForModule(module, nameSpace); - if (includeDependencies) + foreach (var moduleDefinition in resolvedModules) { - foreach (var moduleDefinition in module.AssemblyReferences.SelectMany(reference => - _assemblyResolver.Resolve(reference)?.Modules)) - { - _archBuilder.LoadTypesForModule(moduleDefinition, null); - } + _archBuilder.LoadTypesForModule(moduleDefinition, null); } } catch (BadImageFormatException) @@ -133,5 +146,32 @@ private void LoadModule(string fileName, string nameSpace, bool includeDependenc // invalid file format of DLL or executable, therefore ignored } } + + private void AddReferencedAssembliesRecursively(AssemblyNameReference currentAssemblyReference, + ICollection processedAssemblies, List resolvedModules) + { + if (processedAssemblies.Contains(currentAssemblyReference)) + { + return; + } + + processedAssemblies.Add(currentAssemblyReference); + try + { + _assemblyResolver.AddLib(currentAssemblyReference); + var assemblyDefinition = _assemblyResolver.Resolve(currentAssemblyReference) ?? + throw new AssemblyResolutionException(currentAssemblyReference); + _archBuilder.AddAssembly(assemblyDefinition, false); + resolvedModules.AddRange(assemblyDefinition.Modules); + foreach (var reference in assemblyDefinition.Modules.SelectMany(m => m.AssemblyReferences)) + { + AddReferencedAssembliesRecursively(reference, processedAssemblies, resolvedModules); + } + } + catch (AssemblyResolutionException) + { + //Failed to resolve assembly, skip it + } + } } } \ No newline at end of file diff --git a/ArchUnitNETTests/Loader/ArchLoaderTests.cs b/ArchUnitNETTests/Loader/ArchLoaderTests.cs index 7a4b6f10..e67b8b11 100644 --- a/ArchUnitNETTests/Loader/ArchLoaderTests.cs +++ b/ArchUnitNETTests/Loader/ArchLoaderTests.cs @@ -5,6 +5,8 @@ // SPDX-License-Identifier: Apache-2.0 using System.Linq; +using ArchUnitNET.Loader; +using ArchUnitNETTests.Domain.Dependencies.Members; using Xunit; using static ArchUnitNETTests.StaticTestArchitectures; @@ -22,5 +24,15 @@ public void LoadAssemblies() Assert.NotEmpty(ArchUnitNETTestArchitectureWithDependencies.Assemblies); Assert.NotEmpty(ArchUnitNETTestAssemblyArchitectureWithDependencies.Assemblies); } + + [Fact(Skip = "This takes very long.")] + public void LoadAssembliesIncludingRecursiveDependencies() + { + var archUnitNetTestArchitectureWithRecursiveDependencies = + new ArchLoader().LoadAssembliesIncludingDependencies(new[] { typeof(BaseClass).Assembly }, true) + .Build(); + + Assert.True(archUnitNetTestArchitectureWithRecursiveDependencies.Assemblies.Count() > 100); + } } } \ No newline at end of file