Skip to content

Make specifying just DllImportSearchPath.AssemblyDirectory only search in assembly directory #114756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ internal static unsafe void FixupModuleCell(ModuleFixupCell* pCell)

hModule = NativeLibrary.LoadBySearch(
callingAssembly,
hasDllImportSearchPath,
searchAssemblyDirectory: (dllImportSearchPath & (uint)DllImportSearchPath.AssemblyDirectory) != 0,
dllImportSearchPathFlags: (int)(dllImportSearchPath & ~(uint)DllImportSearchPath.AssemblyDirectory),
ref loadLibErrorTracker,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal static bool TryLoad(string libraryName, Assembly assembly, DllImportSea
{
handle = LoadLibraryByName(libraryName,
assembly,
userSpecifiedSearchFlags: true,
searchPath,
throwOnError: false);
return handle != IntPtr.Zero;
Expand All @@ -26,21 +27,21 @@ internal static IntPtr LoadLibraryByName(string libraryName, Assembly assembly,
// First checks if a default dllImportSearchPathFlags was passed in, if so, use that value.
// Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute.
// If so, use that value.

if (!searchPath.HasValue)
bool userSpecifiedSearchFlags = searchPath.HasValue;
if (!userSpecifiedSearchFlags)
{
searchPath = GetDllImportSearchPath(assembly);
searchPath = GetDllImportSearchPath(assembly, out userSpecifiedSearchFlags);
}
return LoadLibraryByName(libraryName, assembly, searchPath.Value, throwOnError);
return LoadLibraryByName(libraryName, assembly, userSpecifiedSearchFlags, searchPath!.Value, throwOnError);
}

internal static IntPtr LoadLibraryByName(string libraryName, Assembly assembly, DllImportSearchPath searchPath, bool throwOnError)
private static IntPtr LoadLibraryByName(string libraryName, Assembly assembly, bool userSpecifiedSearchFlags, DllImportSearchPath searchPath, bool throwOnError)
{
int searchPathFlags = (int)(searchPath & ~DllImportSearchPath.AssemblyDirectory);
bool searchAssemblyDirectory = (searchPath & DllImportSearchPath.AssemblyDirectory) != 0;

LoadLibErrorTracker errorTracker = default;
IntPtr ret = LoadBySearch(assembly, searchAssemblyDirectory, searchPathFlags, ref errorTracker, libraryName);
IntPtr ret = LoadBySearch(assembly, userSpecifiedSearchFlags, searchAssemblyDirectory, searchPathFlags, ref errorTracker, libraryName);
if (throwOnError && ret == IntPtr.Zero)
{
errorTracker.Throw(libraryName);
Expand All @@ -49,20 +50,22 @@ internal static IntPtr LoadLibraryByName(string libraryName, Assembly assembly,
return ret;
}

internal static DllImportSearchPath GetDllImportSearchPath(Assembly callingAssembly)
private static DllImportSearchPath GetDllImportSearchPath(Assembly callingAssembly, out bool userSpecifiedSearchFlags)
{
foreach (CustomAttributeData cad in callingAssembly.CustomAttributes)
{
if (cad.AttributeType == typeof(DefaultDllImportSearchPathsAttribute))
{
userSpecifiedSearchFlags = true;
return (DllImportSearchPath)cad.ConstructorArguments[0].Value!;
}
}

userSpecifiedSearchFlags = false;
return DllImportSearchPath.AssemblyDirectory;
}

internal static IntPtr LoadBySearch(Assembly callingAssembly, bool searchAssemblyDirectory, int dllImportSearchPathFlags, ref LoadLibErrorTracker errorTracker, string libraryName)
internal static IntPtr LoadBySearch(Assembly callingAssembly, bool userSpecifiedSearchFlags, bool searchAssemblyDirectory, int dllImportSearchPathFlags, ref LoadLibErrorTracker errorTracker, string libraryName)
{
IntPtr ret;

Expand Down Expand Up @@ -107,10 +110,20 @@ internal static IntPtr LoadBySearch(Assembly callingAssembly, bool searchAssembl
}
}

ret = LoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, ref errorTracker);
if (ret != IntPtr.Zero)
// Internally, search path flags and whether or not to search the assembly directory are
// tracked separately. However, on the API level, DllImportSearchPath represents them both.
// When unspecified, the default is to search the assembly directory and all OS defaults,
// which maps to searchAssemblyDirectory being true and dllImportSearchPathFlags being 0.
// When a user specifies DllImportSearchPath.AssemblyDirectory, searchAssemblyDirectory is
// true, dllImportSearchPathFlags is 0, and the desired logic is to only search the assembly
// directory (handled above), so we avoid doing any additional load search in that case.
if (!userSpecifiedSearchFlags || !searchAssemblyDirectory || dllImportSearchPathFlags != 0)
{
return ret;
ret = LoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, ref errorTracker);
if (ret != IntPtr.Zero)
{
return ret;
}
}
}

Expand Down
26 changes: 18 additions & 8 deletions src/coreclr/vm/nativelibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ namespace
#endif // TARGET_UNIX

// Search for the library and variants of its name in probing directories.
NATIVE_LIBRARY_HANDLE LoadNativeLibraryBySearch(Assembly *callingAssembly,
NATIVE_LIBRARY_HANDLE LoadNativeLibraryBySearch(Assembly *callingAssembly, BOOL userSpecifiedSearchFlags,
BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags,
LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName)
{
Expand Down Expand Up @@ -719,10 +719,20 @@ namespace
}
}

hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
// Internally, search path flags and whether or not to search the assembly directory are
// tracked separately. However, on the API level, DllImportSearchPath represents them both.
// When unspecified, the default is to search the assembly directory and all OS defaults,
// which maps to searchAssemblyDirectory being true and dllImportSearchPathFlags being 0.
// When a user specifies DllImportSearchPath.AssemblyDirectory, searchAssemblyDirectory is
// true, dllImportSearchPathFlags is 0, and the desired logic is to only search the assembly
// directory (handled above), so we avoid doing any additional load search in that case.
if (!userSpecifiedSearchFlags || !searchAssemblyDirectory || dllImportSearchPathFlags != 0)
{
return hmod;
hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
}
}
}

Expand All @@ -736,10 +746,10 @@ namespace
BOOL searchAssemblyDirectory;
DWORD dllImportSearchPathFlags;

GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory);
BOOL userSpecifiedSearchFlags = GetDllImportSearchPathFlags(pMD, &dllImportSearchPathFlags, &searchAssemblyDirectory);

Assembly *pAssembly = pMD->GetMethodTable()->GetAssembly();
return LoadNativeLibraryBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName);
return LoadNativeLibraryBySearch(pAssembly, userSpecifiedSearchFlags, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName);
}
}

Expand Down Expand Up @@ -776,12 +786,12 @@ NATIVE_LIBRARY_HANDLE NativeLibrary::LoadLibraryByName(LPCWSTR libraryName, Asse
}
else
{
GetDllImportSearchPathFlags(callingAssembly->GetModule(),
hasDllImportSearchFlags = GetDllImportSearchPathFlags(callingAssembly->GetModule(),
&dllImportSearchPathFlags, &searchAssemblyDirectory);
}

LoadLibErrorTracker errorTracker;
hmod = LoadNativeLibraryBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName);
hmod = LoadNativeLibraryBySearch(callingAssembly, hasDllImportSearchFlags, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName);
if (hmod != nullptr)
return hmod;

Expand Down
22 changes: 12 additions & 10 deletions src/mono/mono/metadata/native-library.c
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ netcore_probe_for_module_variations (const char *mdirname, const char *file_name
}

static MonoDl *
netcore_probe_for_module (MonoImage *image, const char *file_name, int flags, MonoError *error)
netcore_probe_for_module (MonoImage *image, const char *file_name, gboolean user_specified_flags, int flags, MonoError *error)
{
MonoDl *module = NULL;
int lflags = convert_dllimport_flags (flags);
Expand Down Expand Up @@ -372,8 +372,9 @@ netcore_probe_for_module (MonoImage *image, const char *file_name, int flags, Mo
g_free (mdirname);
}

// Try without any path additions, if we didn't try it already
if (module == NULL && !probe_first_without_prepend)
// Try without any path additions, if we didn't try it already and the user did not
// explicitly specify to only look in the assembly directory.
if (module == NULL && !probe_first_without_prepend && (!user_specified_flags || flags != DLLIMPORTSEARCHPATH_ASSEMBLY_DIRECTORY))
{
module = netcore_probe_for_module_variations (NULL, file_name, lflags, error);
if (!module && !is_ok (error) && mono_error_get_error_code (error) == MONO_ERROR_BAD_IMAGE)
Expand All @@ -393,12 +394,12 @@ netcore_probe_for_module (MonoImage *image, const char *file_name, int flags, Mo
}

static MonoDl *
netcore_probe_for_module_nofail (MonoImage *image, const char *file_name, int flags)
netcore_probe_for_module_nofail (MonoImage *image, const char *file_name, gboolean user_specified_flags, int flags)
{
MonoDl *result = NULL;

ERROR_DECL (error);
result = netcore_probe_for_module (image, file_name, flags, error);
result = netcore_probe_for_module (image, file_name, user_specified_flags, flags, error);
mono_error_cleanup (error);

return result;
Expand Down Expand Up @@ -661,7 +662,7 @@ netcore_check_alc_cache (MonoAssemblyLoadContext *alc, const char *scope)
}

static MonoDl *
netcore_lookup_native_library (MonoAssemblyLoadContext *alc, MonoImage *image, const char *scope, guint32 flags)
netcore_lookup_native_library (MonoAssemblyLoadContext *alc, MonoImage *image, const char *scope, gboolean user_specified_flags, guint32 flags)
{
MonoDl *module = NULL;
MonoDl *cached;
Expand Down Expand Up @@ -726,7 +727,7 @@ netcore_lookup_native_library (MonoAssemblyLoadContext *alc, MonoImage *image, c
goto add_to_alc_cache;
}

module = netcore_probe_for_module_nofail (image, scope, flags);
module = netcore_probe_for_module_nofail (image, scope, user_specified_flags, flags);
if (module) {
mono_trace (G_LOG_LEVEL_DEBUG, MONO_TRACE_DLLIMPORT, "Native library found via filesystem probing: '%s'.", scope);
goto add_to_global_cache;
Expand Down Expand Up @@ -899,9 +900,10 @@ lookup_pinvoke_call_impl (MonoMethod *method, MonoLookupPInvokeStatus *status_ou
if (cinfo && !cinfo->cached)
mono_custom_attrs_free (cinfo);
}
if (flags < 0)
gboolean user_specified_flags = flags >= 0;
if (!user_specified_flags)
flags = DLLIMPORTSEARCHPATH_ASSEMBLY_DIRECTORY;
module = netcore_lookup_native_library (alc, image, new_scope, flags);
module = netcore_lookup_native_library (alc, image, new_scope, user_specified_flags, flags);

if (!module) {
mono_trace (G_LOG_LEVEL_WARNING, MONO_TRACE_DLLIMPORT,
Expand Down Expand Up @@ -1167,7 +1169,7 @@ ves_icall_System_Runtime_InteropServices_NativeLibrary_LoadByName (MonoStringHan
// FIXME: implement search flag defaults properly
{
ERROR_DECL (load_error);
module = netcore_probe_for_module (image, lib_name, has_search_flag ? search_flag : DLLIMPORTSEARCHPATH_ASSEMBLY_DIRECTORY, load_error);
module = netcore_probe_for_module (image, lib_name, has_search_flag, has_search_flag ? search_flag : DLLIMPORTSEARCHPATH_ASSEMBLY_DIRECTORY, load_error);
if (!module) {
if (mono_error_get_error_code (load_error) == MONO_ERROR_BAD_IMAGE)
mono_error_set_generic_error (error, "System", "BadImageFormatException", "%s", lib_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ public static void AssemblyDirectoryAot_Found()

[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public static void AssemblyDirectory_Fallback_Found()
public static void AssemblyDirectory_NoFallback_NotFound()
{
string currentDirectory = Environment.CurrentDirectory;
try
{
Environment.CurrentDirectory = Subdirectory;

// Library should not be found in the assembly directory, but should fall back to the default OS search which includes CWD on Windows
int sum = NativeLibraryPInvoke.Sum_Copy(1, 2);
Assert.Equal(3, sum);
// Library should not be found in the assembly directory and should not fall back to the default OS search
Assert.Throws<DllNotFoundException>(() => NativeLibraryPInvoke.Sum_Copy(1, 2));
}
finally
{
Expand Down Expand Up @@ -149,10 +148,8 @@ public static int Sum(int a, int b)
=> NativeSum(a, b);

// For NativeAOT, validate the case where the native library is next to the AOT application.
// The passing of DllImportSearchPath.System32 is done to ensure on Windows the runtime won't fallback
// and try to search the application directory by default.
[DllImport(NativeLibraryToLoad.Name + "-in-native")]
[DefaultDllImportSearchPaths(DllImportSearchPath.AssemblyDirectory | DllImportSearchPath.System32)]
[DefaultDllImportSearchPaths(DllImportSearchPath.AssemblyDirectory)]
static extern int NativeSum(int arg1, int arg2);
}

Expand Down
16 changes: 10 additions & 6 deletions src/tests/Interop/NativeLibrary/API/NativeLibraryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void LoadLibraryOnInvalidFile_NameOnly()
[Fact]
public void LoadLibraryRelativePaths_NameOnly()
{

{
string libName = Path.Combine("..", NativeLibraryToLoad.InvalidName, NativeLibraryToLoad.GetLibraryFileName(NativeLibraryToLoad.InvalidName));
EXPECT(LoadLibrary_NameOnly(libName), TestResult.DllNotFound);
Expand Down Expand Up @@ -142,9 +142,9 @@ public void LoadSystemLibrary_WithSearchPath()
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory | DllImportSearchPath.System32));
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory | DllImportSearchPath.System32));

// Library should not be found in the assembly directory, but should fall back to the default OS search which includes CWD on Windows
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory));
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory));
// Library should not be found in the assembly directory and should not fall back to the default OS search which includes CWD on Windows
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory), TestResult.DllNotFound);
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory), TestResult.ReturnFailure);

// Library should not be found in application directory
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.ApplicationDirectory), TestResult.DllNotFound);
Expand Down Expand Up @@ -208,8 +208,12 @@ public void LoadLibrary_AssemblyDirectory()
Environment.CurrentDirectory = subdirectory;

// Library should not be found in the assembly directory, but should fall back to the default OS search which includes CWD on Windows
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory));
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory));
EXPECT(LoadLibrary_WithAssembly(libName, assembly, null));
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, null));

// Library should not be found in the assembly directory and should not fall back to the default OS search which includes CWD on Windows
EXPECT(LoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory), TestResult.DllNotFound);
EXPECT(TryLoadLibrary_WithAssembly(libName, assembly, DllImportSearchPath.AssemblyDirectory), TestResult.ReturnFailure);
}
finally
{
Expand Down
Loading