diff --git a/Manifest.toml b/Manifest.toml index 07caf155..76632ab6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -157,7 +157,7 @@ version = "1.2.0" [[deps.ObjectiveC]] deps = ["Lazy", "MacroTools"] -git-tree-sha1 = "83ef9f60ccd5174ba74281fa4ac1a735a10cefc1" +git-tree-sha1 = "b3b8893ac70755a5642871b9aff554634ce20386" repo-rev = "tb/modernize" repo-url = "https://github.com/JuliaInterop/ObjectiveC.jl" uuid = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9" diff --git a/lib/mtl/MTL.jl b/lib/mtl/MTL.jl index 26d6fe2b..2e21cd6e 100644 --- a/lib/mtl/MTL.jl +++ b/lib/mtl/MTL.jl @@ -3,7 +3,7 @@ module MTL using ..cmt using CEnum -using ObjectiveC, ObjectiveC.Foundation +using ObjectiveC, .Foundation, .Dispatch ## version information diff --git a/lib/mtl/binary_archive.jl b/lib/mtl/binary_archive.jl index 345d7ee2..977ab6da 100644 --- a/lib/mtl/binary_archive.jl +++ b/lib/mtl/binary_archive.jl @@ -82,7 +82,9 @@ Base.:(==)(a::MtlBinaryArchive, b::MtlBinaryArchive) = a.handle == b.handle Base.hash(lib::MtlBinaryArchive, h::UInt) = hash(lib.handle, h) function MtlBinaryArchive(device::MTLDevice, desc::MtlBinaryArchiveDescriptor) - handle = @mtlthrows _errptr mtNewBinaryArchiveWithDescriptor(device, desc, _errptr) + err = Ref{id{NSError}}(nil) + handle = mtNewBinaryArchiveWithDescriptor(device, desc, err) + err[] == nil || throw(NSError(err[])) obj = MtlBinaryArchive(handle, device, desc) finalizer(unsafe_destroy!, obj) @@ -137,9 +139,13 @@ end ## operations function add_functions!(bin::MtlBinaryArchive, desc::MtlComputePipelineDescriptor) - @mtlthrows _errptr mtBinaryArchiveAddComputePipelineFunctions(bin, desc, _errptr) + err = Ref{id{NSError}}(nil) + mtBinaryArchiveAddComputePipelineFunctions(bin, desc, err) + err[] == nil || throw(NSError(err[])) end function Base.write(filename::String, bin::MtlBinaryArchive) - @mtlthrows _errptr mtBinaryArchiveSerialize(bin, filename, _errptr) + err = Ref{id{NSError}}(nil) + mtBinaryArchiveSerialize(bin, filename, err) + err[] == nil || throw(NSError(err[])) end diff --git a/lib/mtl/compute_pipeline.jl b/lib/mtl/compute_pipeline.jl index 8c8f5665..15062087 100644 --- a/lib/mtl/compute_pipeline.jl +++ b/lib/mtl/compute_pipeline.jl @@ -116,7 +116,9 @@ function unsafe_destroy!(cce::MtlComputePipelineState) end function MtlComputePipelineState(d::MTLDevice, f::MtlFunction) - handle = @mtlthrows _errptr mtNewComputePipelineStateWithFunction(d, f, _errptr) + err = Ref{id{NSError}}(nil) + handle = mtNewComputePipelineStateWithFunction(d, f, err) + err[] == nil || throw(NSError(err[])) obj = MtlComputePipelineState(handle, d) finalizer(unsafe_destroy!, obj) diff --git a/lib/mtl/function.jl b/lib/mtl/function.jl index 848e8d4b..2988c448 100644 --- a/lib/mtl/function.jl +++ b/lib/mtl/function.jl @@ -80,7 +80,7 @@ mutable struct MtlFunction handle::MTLFunction # roots (can be nothing if the function was created directly from a handle) - lib::Union{Nothing,MtlLibrary} + lib::Union{Nothing,MTLLibrary} MtlFunction(handle::MTLFunction, lib=nothing) = new(handle, lib) end @@ -95,7 +95,7 @@ Base.:(==)(a::MtlFunction, b::MtlFunction) = a.handle == b.handle Base.hash(fun::MtlFunction, h::UInt) = hash(mod.handle, h) # Get a handle to a kernel function in a Metal Library. -function MtlFunction(lib::MtlLibrary, name::String) +function MtlFunction(lib::MTLLibrary, name::String) handle = mtNewFunctionWithName(lib, name) handle == C_NULL && throw(KeyError(name)) obj = MtlFunction(handle, lib) diff --git a/lib/mtl/helpers.jl b/lib/mtl/helpers.jl index 5e690189..f74d1db8 100644 --- a/lib/mtl/helpers.jl +++ b/lib/mtl/helpers.jl @@ -44,30 +44,3 @@ end ## Base.convert(::Type{MtResourceOptions}, val::UInt32) = MtResourceOptions(val) - -## -""" - @mtlthrows error_var function(..., error_var) - -Marks that this Metal function has an argument error_var which -must be passed by reference in the underlying ccall, and when -the function returns checks that no error has been set. - -Expands roughly to -```julia -error_var = Ref{id}() -result = function(..., error_var) -error[] != nil && throw(NSError(error[])) -``` -""" -macro mtlthrows(error, fun) - expr = quote - $error = Ref{id{NSError}}(nil) - result = $fun - if $error[] != nil - throw(NSError($(error)[])) - end - result - end - return esc(expr) -end diff --git a/lib/mtl/library.jl b/lib/mtl/library.jl index a90b8ac0..58d9cf7b 100644 --- a/lib/mtl/library.jl +++ b/lib/mtl/library.jl @@ -1,96 +1,68 @@ -export MtlLibrary, MtlLibraryFromFile, MtlLibraryFromData +export MTLLibrary, MTLLibraryFromFile, MTLLibraryFromData -const MTLLibrary = Ptr{MtLibrary} +@objcwrapper immutable=false MTLLibrary <: NSObject -""" - MTLDevice(i::Integer) +# compatibility with cmt +Base.unsafe_convert(T::Type{Ptr{MtLibrary}}, obj::MTLLibrary) = + reinterpret(T, Base.unsafe_convert(id, obj)) +MTLLibrary(ptr::Ptr{MtLibrary}) = MTLLibrary(reinterpret(id{MTLLibrary}, ptr)) -Get a handle to a compute device. -""" -mutable struct MtlLibrary - handle::MTLLibrary - device::MTLDevice -end - -Base.unsafe_convert(::Type{MTLLibrary}, lib::MtlLibrary) = lib.handle - -Base.:(==)(a::MtlLibrary, b::MtlLibrary) = a.handle == b.handle -Base.hash(lib::MtlLibrary, h::UInt) = hash(lib.handle, h) +function MTLLibrary(device::MTLDevice, src::String, + opts::MTLCompileOptions=MTLCompileOptions()) + err = Ref{id{NSError}}(nil) + handle = @objc [device::id{MTLDevice} newLibraryWithSource:src::id{NSString} + options:opts::id{MTLCompileOptions} + error:err::Ptr{id{NSError}}]::id{MTLLibrary} + err[] == nil || throw(NSError(err[])) -function MtlLibrary(device::MTLDevice, src::String, opts::MTLCompileOptions=MTLCompileOptions()) - handle = @mtlthrows _errptr mtNewLibraryWithSource(device, src, opts, _errptr) - - obj = MtlLibrary(handle, device) + obj = MTLLibrary(handle) finalizer(unsafe_destroy!, obj) return obj end -function MtlLibraryFromFile(device::MTLDevice, path::String) +function MTLLibraryFromFile(device::MTLDevice, path::String) + err = Ref{id{NSError}}(nil) handle = if macos_version() >= v"13" - @mtlthrows _errptr mtNewLibraryWithURL(device, path, _errptr) + url = NSFileURL(path) + @objc [device::id{MTLDevice} newLibraryWithURL:url::id{NSURL} + error:err::Ptr{id{NSError}}]::id{MTLLibrary} else - @mtlthrows _errptr mtNewLibraryWithFile(device, path, _errptr) + @objc [device::id{MTLDevice} newLibraryWithFile:path::id{NSString} + error:err::Ptr{id{NSError}}]::id{MTLLibrary} end + err[] == nil || throw(NSError(err[])) - obj = MtlLibrary(handle, device) + obj = MTLLibrary(handle) finalizer(unsafe_destroy!, obj) return obj end -function MtlLibraryFromData(device::MTLDevice, data) - GC.@preserve data begin - handle = @mtlthrows _errptr mtNewLibraryWithData(device, pointer(data), sizeof(data), _errptr) +function MTLLibraryFromData(device::MTLDevice, input_data) + err = Ref{id{NSError}}(nil) + GC.@preserve input_data begin + data = dispatch_data(pointer(input_data), sizeof(input_data)) + handle = @objc [device::id{MTLDevice} newLibraryWithData:data::dispatch_data_t + error:err::Ptr{id{NSError}}]::id{MTLLibrary} end + err[] == nil || throw(NSError(err[])) - obj = MtlLibrary(handle, device) + obj = MTLLibrary(handle) finalizer(unsafe_destroy!, obj) return obj end -function unsafe_destroy!(lib::MtlLibrary) - mtRelease(lib.handle) +function unsafe_destroy!(lib::MTLLibrary) + @objc [lib::id{MTLLibrary} release]::Nothing end ## properties -Base.propertynames(::MtlLibrary) = (:device, :label, :functionNames) - -function Base.getproperty(lib::MtlLibrary, f::Symbol) - if f === :label - ptr = mtLibraryLabel(lib) - ptr == C_NULL ? nothing : unsafe_string(ptr) - elseif f === :functionNames - count = Ref{Csize_t}(0) - mtLibraryFunctionNames(lib, count, C_NULL) - names = Vector{Cstring}(undef, count[]) - mtLibraryFunctionNames(lib, count, names) - unsafe_string.(names) - else - getfield(lib, f) - end -end - -function Base.setproperty!(lib::MtlLibrary, f::Symbol, val) - if f === :label - mtLibraryLabelSet(lib, val) - else - setfield!(lib, f, val) - end -end - - -## display - -function Base.show(io::IO, lib::MtlLibrary) - print(io, "MtlLibrary($(lib.device))") -end - -function Base.show(io::IO, ::MIME"text/plain", lib::MtlLibrary) - println(io, "MtlLibrary:") - println(io, " device: ", lib.device) - print(io, " label: ", lib.label) +@objcproperties MTLLibrary begin + @autoproperty device::id{MTLDevice} + @autoproperty label::id{NSString} setter=setLabel + @autoproperty functionNames::id{NSArray} type=Vector{NSString} end diff --git a/src/Metal.jl b/src/Metal.jl index b8a7c8d9..b28e2e64 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -8,7 +8,7 @@ using LLVM using LLVM.Interop using Metal_LLVM_Tools_jll using ExprTools: splitdef, combinedef -using ObjectiveC, .Foundation +using ObjectiveC, .Foundation, .Dispatch # C wrappers include("../lib/cmt/cmt.jl") diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 3257c419..b3e16fa6 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -171,7 +171,7 @@ end function mtlfunction_link(@nospecialize(job::CompilerJob), compiled) dev = current_device() - lib = MtlLibraryFromData(dev, compiled.image) + lib = MTLLibraryFromData(dev, compiled.image) fun = MtlFunction(lib, compiled.entry) pipeline_state = try MtlComputePipelineState(dev, fun) diff --git a/test/metal.jl b/test/metal.jl index 3bde7192..310dcda9 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -55,7 +55,7 @@ end dev = first(devices()) opts = MTLCompileOptions() -let lib = MtlLibrary(dev, "", opts) +let lib = MTLLibrary(dev, "", opts) @test lib.device == dev @test lib.label === nothing lib.label = "MyLibrary" @@ -64,7 +64,7 @@ let lib = MtlLibrary(dev, "", opts) end metal_code = read(joinpath(@__DIR__, "dummy.metal"), String) -let lib = MtlLibrary(dev, metal_code, opts) +let lib = MTLLibrary(dev, metal_code, opts) @test lib.device == dev @test lib.label === nothing fns = lib.functionNames @@ -74,7 +74,7 @@ let lib = MtlLibrary(dev, metal_code, opts) end binary_path = joinpath(@__DIR__, "dummy.metallib") -let lib = MtlLibraryFromFile(dev, binary_path) +let lib = MTLLibraryFromFile(dev, binary_path) @test lib.device == dev @test lib.label === nothing fns = lib.functionNames @@ -84,7 +84,7 @@ let lib = MtlLibraryFromFile(dev, binary_path) end binary_code = read(binary_path) -let lib = MtlLibraryFromData(dev, binary_code) +let lib = MTLLibraryFromData(dev, binary_code) @test lib.device == dev @test lib.label === nothing fns = lib.functionNames @@ -115,7 +115,7 @@ desc.specializedName = "MySpecializedKernel" dev = first(devices()) -lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) +lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) fun = MtlFunction(lib, "kernel_1") compact_str = sprint(io->show(io, fun)) @@ -340,7 +340,7 @@ end @testset "compute pipeline" begin dev = first(devices()) -lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) +lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) fun = MtlFunction(lib, "kernel_1") pipeline = MtlComputePipelineState(dev, fun) @@ -383,7 +383,7 @@ end @testset "binary archive" begin dev = first(devices()) -lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) +lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib")) fun = MtlFunction(lib, "kernel_1") desc = MtlBinaryArchiveDescriptor()