Permalink
Browse files

Dynamically loaded functions with pointer argument/return types now g…

…enerate an additional unsafe method. The signature matches the JNI method, except the function address argument is missing; it is passed automatically like standard methods. The call sequence is now: normal/alternative -> unsafe -> JNI, which lets us have the function address check in one place. Credit for the idea goes to JGO's Riven.
  • Loading branch information...
Spasi committed Nov 23, 2013
1 parent 143950d commit aa1f9978951835922cc7d0118b05c356bfe92ea4
@@ -145,11 +145,7 @@ public static ALCCapabilities getCapabilities() {
* @param token the information to query. One of:<p/>{@link ALC11#ALC_ALL_DEVICES_SPECIFIER}, {@link ALC11#ALC_CAPTURE_DEVICE_SPECIFIER}
*/
public static List<String> getStringList(long deviceHandle, int token) {
long alcGetString = functionProvider.getFunctionAddress("alcGetString");
if ( LWJGLUtil.CHECKS )
checkFunctionAddress(alcGetString);
long __result = nalcGetString(deviceHandle, token, alcGetString);
long __result = nalcGetString(deviceHandle, token);
if ( __result == NULL )
return null;
@@ -35,16 +35,14 @@ public static CLDevice create(long cl_device_id, CLPlatform platform) {
}
private static CLCapabilities createCapabilities(long cl_device_id, CLPlatform platform) {
long clGetDeviceInfo = CL10.getInstance().GetDeviceInfo;
Set<String> supportedExtensions = new HashSet<>(32);
// Parse DEVICE_EXTENSIONS string
String extensionsString = getDeviceInfo(cl_device_id, CL_DEVICE_EXTENSIONS, clGetDeviceInfo);
String extensionsString = getDeviceInfo(cl_device_id, CL_DEVICE_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
// Parse DEVICE_VERSION string
String version = getDeviceInfo(cl_device_id, CL_DEVICE_VERSION, clGetDeviceInfo);
String version = getDeviceInfo(cl_device_id, CL_DEVICE_VERSION);
int majorVersion;
int minorVersion;
try {
@@ -60,18 +58,18 @@ private static CLCapabilities createCapabilities(long cl_device_id, CLPlatform p
return new CLCapabilities(majorVersion, minorVersion, supportedExtensions, platform.getCapabilities());
}
static String getDeviceInfo(long device_id, int param_name, long clGetDeviceInfo) {
static String getDeviceInfo(long device_id, int param_name) {
APIBuffer __buffer = apiBuffer();
__buffer.intParam(0);
int errcode = nclGetDeviceInfo(device_id, param_name, 0L, NULL, __buffer.address(), clGetDeviceInfo);
int errcode = nclGetDeviceInfo(device_id, param_name, 0L, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query size of OpenCL device information.");
int bytes = __buffer.intValue(0);
__buffer.bufferParam(bytes);
errcode = nclGetDeviceInfo(device_id, param_name, bytes, __buffer.address(), NULL, clGetDeviceInfo);
errcode = nclGetDeviceInfo(device_id, param_name, bytes, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL device information.");
@@ -38,22 +38,16 @@ public static CLPlatform create(long id) {
}
private static CLCapabilities createCapabilities(long platform, FunctionProvider functionProvider) {
long clGetPlatformInfo = functionProvider.getFunctionAddress("clGetPlatformInfo");
long clGetDeviceIDs = functionProvider.getFunctionAddress("clGetDeviceIDs");
long clGetDeviceInfo = functionProvider.getFunctionAddress("clGetDeviceInfo");
if ( clGetPlatformInfo == NULL || clGetDeviceIDs == NULL || clGetDeviceInfo == NULL )
throw new OpenCLException("A core OpenCL function is missing. Make sure that OpenCL is available.");
Set<String> supportedExtensions = new HashSet<>(32);
// Parse PLATFORM_EXTENSIONS string
String extensionsString = getPlatformInfo(platform, CL_PLATFORM_EXTENSIONS, clGetPlatformInfo);
String extensionsString = getPlatformInfo(platform, CL_PLATFORM_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
// Enumerate devices
{
APIBuffer __buffer = apiBuffer();
int errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, __buffer.address(), clGetDeviceIDs);
int errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query number of OpenCL platform devices.");
@@ -63,7 +57,7 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
__buffer.bufferParam(num_devices << POINTER_SHIFT);
errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, __buffer.address(), NULL, clGetDeviceIDs);
errcode = nclGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL platform devices.");
@@ -73,13 +67,13 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
// Add device extensions to the set
for ( int i = 0; i < num_devices; i++ ) {
extensionsString = CLDevice.getDeviceInfo(devices[i], CL_DEVICE_EXTENSIONS, clGetDeviceInfo);
extensionsString = CLDevice.getDeviceInfo(devices[i], CL_DEVICE_EXTENSIONS);
CL.addExtensions(extensionsString, supportedExtensions);
}
}
// Parse PLATFORM_VERSION string
String version = getPlatformInfo(platform, CL_PLATFORM_VERSION, clGetPlatformInfo);
String version = getPlatformInfo(platform, CL_PLATFORM_VERSION);
int majorVersion;
int minorVersion;
try {
@@ -95,18 +89,18 @@ private static CLCapabilities createCapabilities(long platform, FunctionProvider
return new CLCapabilities(majorVersion, minorVersion, supportedExtensions, CL.getICD());
}
private static String getPlatformInfo(long platform, int param_name, long clGetPlatformInfo) {
private static String getPlatformInfo(long platform, int param_name) {
APIBuffer __buffer = apiBuffer();
__buffer.intParam(0);
int errcode = nclGetPlatformInfo(platform, param_name, 0L, NULL, __buffer.address(), clGetPlatformInfo);
int errcode = nclGetPlatformInfo(platform, param_name, 0L, NULL, __buffer.address());
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query size of OpenCL platform information.");
int bytes = __buffer.intValue(0);
__buffer.bufferParam(bytes);
errcode = nclGetPlatformInfo(platform, param_name, bytes, __buffer.address(), NULL, clGetPlatformInfo);
errcode = nclGetPlatformInfo(platform, param_name, bytes, __buffer.address(), NULL);
if ( LWJGLUtil.DEBUG && errcode != CL_SUCCESS )
throw new OpenCLException("Failed to query OpenCL platform information.");
@@ -214,22 +214,25 @@ public class NativeClassFunction(
return builder.toString()
}
val isSimpleFunction: Boolean
private val isSimpleFunction: Boolean
get() = nativeClass.functionProvider == null && !(isSpecial || returns.isSpecial || hasParam { it.isSpecial })
val ReturnValue.isStructValue: Boolean
private val hasUnsafeMethod: Boolean
get() = nativeClass.functionProvider != null && (returns.isBufferPointer || hasParam { it.isBufferPointer }) && !has(Capabilities)
private val ReturnValue.isStructValue: Boolean
get() = nativeType is StructType && !nativeType.includesPointer
val returnsStructValue: Boolean
internal val returnsStructValue: Boolean
get() = returns.isStructValue && !hasParam { it has autoSizeResult }
val returnsJavaMethodType: String
private val returnsJavaMethodType: String
get() = if ( returnsStructValue ) "void" else returns.javaMethodType
val returnsNativeMethodType: String
private val returnsNativeMethodType: String
get() = if ( returnsStructValue ) "void" else returns.nativeMethodType
val returnsJniFunctionType: String
private val returnsJniFunctionType: String
get() = if ( returnsStructValue ) "void" else returns.jniFunctionType
private fun Parameter.error(msg: String) {
@@ -304,7 +307,7 @@ public class NativeClassFunction(
val checks = ArrayList<String>()
// Validate function address
if ( nativeClass.functionProvider != null )
if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
checks add "checkFunctionAddress($FUNCTION_ADDRESS);"
// We convert multi-byte-per-element buffers to ByteBuffer for NORMAL generation.
@@ -445,6 +448,9 @@ public class NativeClassFunction(
writer.generateNativeMethod(simpleFunction)
if ( !simpleFunction ) {
if ( nativeClass.functionProvider != null && hasUnsafeMethod )
writer.generateUnsafeMethod()
// This the only special case where we don't generate a "normal" Java method. If we did,
// we'd need to add a postfix to either this or the alternative method, since we're
// changing the return type. It looks ugly and LWJGL didn't do it pre-3.0 either.
@@ -488,6 +494,41 @@ public class NativeClassFunction(
println(");\n")
}
private fun PrintWriter.generateUnsafeMethod() {
generateJavaDocLink("Unsafe version of", this@NativeClassFunction)
println("\t@JavadocExclude")
print("\t${accessModifier}static ${returnsNativeMethodType} n$name(")
printList(getNativeParams()) {
it.asNativeMethodParam
}
if ( returnsStructValue ) {
if ( this@NativeClassFunction.hasNativeParams ) print(", ")
print("long $RESULT")
}
println(") {")
// Get and validate function address
nativeClass.functionProvider!!.generateFunctionAddress(this, this@NativeClassFunction)
println("\t\tif ( LWJGLUtil.CHECKS )")
println("\t\t\tcheckFunctionAddress($FUNCTION_ADDRESS);")
generateNativeMethodCall {
printList(getNativeParams()) {
it.name
}
if ( returnsStructValue ) {
if ( hasNativeParams ) print(", ")
print("memAddress($RESULT)")
}
if ( hasNativeParams ) print(", ")
print("$FUNCTION_ADDRESS")
}
println("\t}\n")
}
private fun PrintWriter.generateJavaMethod() {
// Step 0: JavaDoc
@@ -523,7 +564,7 @@ public class NativeClassFunction(
// Step 2: Get function address
if ( nativeClass.functionProvider != null )
if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
nativeClass.functionProvider.generateFunctionAddress(this, this@NativeClassFunction)
// Step 3.a: Generate checks
@@ -635,10 +676,16 @@ public class NativeClassFunction(
}
}
private fun PrintWriter.generateNativeMethodCall(returnLater: Boolean = false, printParams: PrintWriter.() -> Unit) {
private fun PrintWriter.generateNativeMethodCall(
// false: check return type
// true: force later
// null: force immediate
returnLater: Boolean? = null,
printParams: PrintWriter.() -> Unit
) {
print("\t\t")
if ( !(returns.isVoid || returnsStructValue) ) {
if ( returns.isBufferPointer || returnLater ) {
if ( returnLater != null && (returns.isBufferPointer || true.equals(returnLater)) ) {
print(
if ( returns.nativeType is ObjectType )
"${returns.nativeType.className} $RESULT = ${returns.nativeType.className}.create("
@@ -647,21 +694,21 @@ public class NativeClassFunction(
)
} else {
print("return ")
if ( returns.nativeType is ObjectType )
if ( returnLater != null && returns.nativeType is ObjectType )
print("${returns.nativeType.className}.create(")
}
}
if ( has(Reuse) ) print("${get(Reuse).reference}.")
print("n$name(")
printParams()
if ( nativeClass.functionProvider != null ) {
if ( nativeClass.functionProvider != null && !hasUnsafeMethod ) {
if ( hasNativeParams ) print(", ")
print("$FUNCTION_ADDRESS")
}
print(")")
if ( returns.nativeType is ObjectType ) {
if ( returnLater != null && returns.nativeType is ObjectType ) {
if ( returns has Construct ) {
val construct = returns[Construct]
print(", ${construct.firstArg}")
@@ -993,7 +1040,7 @@ public class NativeClassFunction(
// Step 2: Get function address
if ( nativeClass.functionProvider != null )
if ( nativeClass.functionProvider != null && !hasUnsafeMethod )
nativeClass.functionProvider.generateFunctionAddress(this, this@NativeClassFunction)
// Step 3.A: Generate checks

0 comments on commit aa1f997

Please sign in to comment.