Permalink
Browse files

Dynamically loaded functions with pointer argument/return types now g…

…enerate an additional unsafe method. The signature matches the JNI method, except the function address argument is missing; it is passed automatically like standard methods. The call sequence is now: normal/alternative -> unsafe -> JNI, which lets us have the function address check in one place. Credit for the idea goes to JGO's Riven.
  • Loading branch information...
1 parent 143950d commit aa1f9978951835922cc7d0118b05c356bfe92ea4 @Spasi Spasi committed Nov 23, 2013
@@ -145,11 +145,7 @@ public static ALCCapabilities getCapabilities() {
* @param token the information to query. One of:<p/>{@link ALC11#ALC_ALL_DEVICES_SPECIFIER}, {@link ALC11#ALC_CAPTURE_DEVICE_SPECIFIER}
*/
public static List<String> getStringList(long deviceHandle, int token) {
- long alcGetString = functionProvider.getFunctionAddress("alcGetString");
- if ( LWJGLUtil.CHECKS )
- checkFunctionAddress(alcGetString);
-
- long __result = nalcGetString(deviceHandle, token, alcGetString);
+ long __result = nalcGetString(deviceHandle, token);
if ( __result == NULL )
return null;
@@ -35,16 +35,14 @@ public static CLDevice create(long cl_device_id, CLPlatform platform) {
}
private static CLCapabilities createCapabilities(long cl_device_id, CLPlatform platform) {
- long clGetDeviceInfo = CL10.getInstance().GetDeviceInfo;
-
Set<String> supportedExtensions = new HashSet<>(32);
// Parse DEVICE_EXTENSIONS string
- String extensionsString = getDeviceInfo(cl_device_id, CL_DEVICE_EXTENSIONS, clGetDeviceInfo);
+ String extensionsString = getDeviceInfo(cl_device_id, CL_DEVICE_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
// Parse DEVICE_VERSION string
- String version = getDeviceInfo(cl_device_id, CL_DEVICE_VERSION, clGetDeviceInfo);
+ String version = getDeviceInfo(cl_device_id, CL_DEVICE_VERSION);
int majorVersion;
int minorVersion;
try {
@@ -60,18 +58,18 @@ private static CLCapabilities createCapabilities(long cl_device_id, CLPlatform p
return new CLCapabilities(majorVersion, minorVersion, supportedExtensions, platform.getCapabilities());
}
- static String getDeviceInfo(long device_id, int param_name, long clGetDeviceInfo) {
+ static String getDeviceInfo(long device_id, int param_name) {
APIBuffer __buffer = apiBuffer();
__buffer.intParam(0);
- int errcode = nclGetDeviceInfo(device_id, param_name, 0L, NULL, __buffer.address(), clGetDeviceInfo);
+ int errcode = nclGetDeviceInfo(device_id, param_name, 0L, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query size of OpenCL device information.");
int bytes = __buffer.intValue(0);
__buffer.bufferParam(bytes);
- errcode = nclGetDeviceInfo(device_id, param_name, bytes, __buffer.address(), NULL, clGetDeviceInfo);
+ errcode = nclGetDeviceInfo(device_id, param_name, bytes, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL device information.");
@@ -38,22 +38,16 @@ public static CLPlatform create(long id) {
}
private static CLCapabilities createCapabilities(long platform, FunctionProvider functionProvider) {
- long clGetPlatformInfo = functionProvider.getFunctionAddress("clGetPlatformInfo");
- long clGetDeviceIDs = functionProvider.getFunctionAddress("clGetDeviceIDs");
- long clGetDeviceInfo = functionProvider.getFunctionAddress("clGetDeviceInfo");
- if ( clGetPlatformInfo == NULL || clGetDeviceIDs == NULL || clGetDeviceInfo == NULL )
- throw new OpenCLException("A core OpenCL function is missing. Make sure that OpenCL is available.");
-
Set<String> supportedExtensions = new HashSet<>(32);
// Parse PLATFORM_EXTENSIONS string
- String extensionsString = getPlatformInfo(platform, CL_PLATFORM_EXTENSIONS, clGetPlatformInfo);
+ String extensionsString = getPlatformInfo(platform, CL_PLATFORM_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
// Enumerate devices
{
APIBuffer __buffer = apiBuffer();
- int errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, __buffer.address(), clGetDeviceIDs);
+ int errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query number of OpenCL platform devices.");
@@ -63,7 +57,7 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
__buffer.bufferParam(num_devices << POINTER_SHIFT);
- errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, __buffer.address(), NULL, clGetDeviceIDs);
+ errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL platform devices.");
@@ -73,13 +67,13 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
// Add device extensions to the set
for ( int i = 0; i < num_devices; i++ ) {
- extensionsString = CLDevice.getDeviceInfo(devices[i], CL_DEVICE_EXTENSIONS, clGetDeviceInfo);
+ extensionsString = CLDevice.getDeviceInfo(devices[i], CL_DEVICE_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
}
}
// Parse PLATFORM_VERSION string
- String version = getPlatformInfo(platform, CL_PLATFORM_VERSION, clGetPlatformInfo);
+ String version = getPlatformInfo(platform, CL_PLATFORM_VERSION);
int majorVersion;
int minorVersion;
try {
@@ -95,18 +89,18 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
return new CLCapabilities(majorVersion, minorVersion, supportedExtensions, CL.getICD());
}
- private static String getPlatformInfo(long platform, int param_name, long clGetPlatformInfo) {
+ private static String getPlatformInfo(long platform, int param_name) {
APIBuffer __buffer = apiBuffer();
__buffer.intParam(0);
- int errcode = nclGetPlatformInfo(platform, param_name, 0L, NULL, __buffer.address(), clGetPlatformInfo);
+ int errcode = nclGetPlatformInfo(platform, param_name, 0L, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query size of OpenCL platform information.");
int bytes = __buffer.intValue(0);
__buffer.bufferParam(bytes);
- errcode = nclGetPlatformInfo(platform, param_name, bytes, __buffer.address(), NULL, clGetPlatformInfo);
+ errcode = nclGetPlatformInfo(platform, param_name, bytes, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL platform information.");
@@ -214,22 +214,25 @@ public class NativeClassFunction(
return builder.toString()
}
- val isSimpleFunction: Boolean
+ private val isSimpleFunction: Boolean
get() = nativeClass.functionProvider == null && !(isSpecial || returns.isSpecial || hasParam { it.isSpecial })
- val ReturnValue.isStructValue: Boolean
+ private val hasUnsafeMethod: Boolean
+ get() = nativeClass.functionProvider != null && (returns.isBufferPointer || hasParam { it.isBufferPointer }) && !has(Capabilities)
+
+ private val ReturnValue.isStructValue: Boolean
get() = nativeType is StructType && !nativeType.includesPointer
- val returnsStructValue: Boolean
+ internal val returnsStructValue: Boolean
get() = returns.isStructValue && !hasParam { it has autoSizeResult }
- val returnsJavaMethodType: String
+ private val returnsJavaMethodType: String
get() = if ( returnsStructValue ) "void" else returns.javaMethodType
- val returnsNativeMethodType: String
+ private val returnsNativeMethodType: String
get() = if ( returnsStructValue ) "void" else returns.nativeMethodType
- val returnsJniFunctionType: String
+ private val returnsJniFunctionType: String
get() = if ( returnsStructValue ) "void" else returns.jniFunctionType
private fun Parameter.error(msg: String) {
@@ -304,7 +307,7 @@ public class NativeClassFunction(
val checks = ArrayList<String>()
// Validate function address
- if ( nativeClass.functionProvider != null )
+ if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
checks add "checkFunctionAddress($FUNCTION_ADDRESS);"
// We convert multi-byte-per-element buffers to ByteBuffer for NORMAL generation.
@@ -445,6 +448,9 @@ public class NativeClassFunction(
writer.generateNativeMethod(simpleFunction)
if ( !simpleFunction ) {
+ if ( nativeClass.functionProvider != null && hasUnsafeMethod )
+ writer.generateUnsafeMethod()
+
// This the only special case where we don't generate a "normal" Java method. If we did,
// we'd need to add a postfix to either this or the alternative method, since we're
// changing the return type. It looks ugly and LWJGL didn't do it pre-3.0 either.
@@ -488,6 +494,41 @@ public class NativeClassFunction(
println(");\n")
}
+ private fun PrintWriter.generateUnsafeMethod() {
+ generateJavaDocLink("Unsafe version of", this@NativeClassFunction)
+ println("\t@JavadocExclude")
+ print("\t${accessModifier}static ${returnsNativeMethodType} n$name(")
+ printList(getNativeParams()) {
+ it.asNativeMethodParam
+ }
+
+ if ( returnsStructValue ) {
+ if ( this@NativeClassFunction.hasNativeParams ) print(", ")
+ print("long $RESULT")
+ }
+ println(") {")
+
+ // Get and validate function address
+ nativeClass.functionProvider!!.generateFunctionAddress(this, this@NativeClassFunction)
+ println("\t\tif ( LWJGLUtil.CHECKS )")
+ println("\t\t\tcheckFunctionAddress($FUNCTION_ADDRESS);")
+
+ generateNativeMethodCall {
+ printList(getNativeParams()) {
+ it.name
+ }
+
+ if ( returnsStructValue ) {
+ if ( hasNativeParams ) print(", ")
+ print("memAddress($RESULT)")
+ }
+
+ if ( hasNativeParams ) print(", ")
+ print("$FUNCTION_ADDRESS")
+ }
+ println("\t}\n")
+ }
+
private fun PrintWriter.generateJavaMethod() {
// Step 0: JavaDoc
@@ -523,7 +564,7 @@ public class NativeClassFunction(
// Step 2: Get function address
- if ( nativeClass.functionProvider != null )
+ if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
nativeClass.functionProvider.generateFunctionAddress(this, this@NativeClassFunction)
// Step 3.a: Generate checks
@@ -635,10 +676,16 @@ public class NativeClassFunction(
}
}
- private fun PrintWriter.generateNativeMethodCall(returnLater: Boolean = false, printParams: PrintWriter.() -> Unit) {
+ private fun PrintWriter.generateNativeMethodCall(
+ // false: check return type
+ // true: force later
+ // null: force immediate
+ returnLater: Boolean? = null,
+ printParams: PrintWriter.() -> Unit
+ ) {
print("\t\t")
if ( !(returns.isVoid || returnsStructValue) ) {
- if ( returns.isBufferPointer || returnLater ) {
+ if ( returnLater != null && (returns.isBufferPointer || true.equals(returnLater)) ) {
print(
if ( returns.nativeType is ObjectType )
"${returns.nativeType.className} $RESULT = ${returns.nativeType.className}.create("
@@ -647,21 +694,21 @@ public class NativeClassFunction(
)
} else {
print("return ")
- if ( returns.nativeType is ObjectType )
+ if ( returnLater != null && returns.nativeType is ObjectType )
print("${returns.nativeType.className}.create(")
}
}
if ( has(Reuse) ) print("${get(Reuse).reference}.")
print("n$name(")
printParams()
- if ( nativeClass.functionProvider != null ) {
+ if ( nativeClass.functionProvider != null && !hasUnsafeMethod ) {
if ( hasNativeParams ) print(", ")
print("$FUNCTION_ADDRESS")
}
print(")")
- if ( returns.nativeType is ObjectType ) {
+ if ( returnLater != null && returns.nativeType is ObjectType ) {
if ( returns has Construct ) {
val construct = returns[Construct]
print(", ${construct.firstArg}")
@@ -993,7 +1040,7 @@ public class NativeClassFunction(
// Step 2: Get function address
- if ( nativeClass.functionProvider != null )
+ if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
nativeClass.functionProvider.generateFunctionAddress(this, this@NativeClassFunction)
// Step 3.A: Generate checks

0 comments on commit aa1f997

Please sign in to comment.