Skip to content

Commit

Permalink
Migrate MTLLibrary.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 6, 2023
1 parent fabf905 commit 598ec4e
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ version = "1.2.0"

[[deps.ObjectiveC]]
deps = ["Lazy", "MacroTools"]
git-tree-sha1 = "83ef9f60ccd5174ba74281fa4ac1a735a10cefc1"
git-tree-sha1 = "b3b8893ac70755a5642871b9aff554634ce20386"
repo-rev = "tb/modernize"
repo-url = "https://github.com/JuliaInterop/ObjectiveC.jl"
uuid = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
Expand Down
2 changes: 1 addition & 1 deletion lib/mtl/MTL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MTL
using ..cmt

using CEnum
using ObjectiveC, ObjectiveC.Foundation
using ObjectiveC, .Foundation, .Dispatch


## version information
Expand Down
12 changes: 9 additions & 3 deletions lib/mtl/binary_archive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ Base.:(==)(a::MtlBinaryArchive, b::MtlBinaryArchive) = a.handle == b.handle
Base.hash(lib::MtlBinaryArchive, h::UInt) = hash(lib.handle, h)

function MtlBinaryArchive(device::MTLDevice, desc::MtlBinaryArchiveDescriptor)
handle = @mtlthrows _errptr mtNewBinaryArchiveWithDescriptor(device, desc, _errptr)
err = Ref{id{NSError}}(nil)
handle = mtNewBinaryArchiveWithDescriptor(device, desc, err)
err[] == nil || throw(NSError(err[]))

obj = MtlBinaryArchive(handle, device, desc)
finalizer(unsafe_destroy!, obj)
Expand Down Expand Up @@ -137,9 +139,13 @@ end
## operations

function add_functions!(bin::MtlBinaryArchive, desc::MtlComputePipelineDescriptor)
@mtlthrows _errptr mtBinaryArchiveAddComputePipelineFunctions(bin, desc, _errptr)
err = Ref{id{NSError}}(nil)
mtBinaryArchiveAddComputePipelineFunctions(bin, desc, err)
err[] == nil || throw(NSError(err[]))
end

function Base.write(filename::String, bin::MtlBinaryArchive)
@mtlthrows _errptr mtBinaryArchiveSerialize(bin, filename, _errptr)
err = Ref{id{NSError}}(nil)
mtBinaryArchiveSerialize(bin, filename, err)
err[] == nil || throw(NSError(err[]))
end
4 changes: 3 additions & 1 deletion lib/mtl/compute_pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ function unsafe_destroy!(cce::MtlComputePipelineState)
end

function MtlComputePipelineState(d::MTLDevice, f::MtlFunction)
handle = @mtlthrows _errptr mtNewComputePipelineStateWithFunction(d, f, _errptr)
err = Ref{id{NSError}}(nil)
handle = mtNewComputePipelineStateWithFunction(d, f, err)
err[] == nil || throw(NSError(err[]))

obj = MtlComputePipelineState(handle, d)
finalizer(unsafe_destroy!, obj)
Expand Down
4 changes: 2 additions & 2 deletions lib/mtl/function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ mutable struct MtlFunction
handle::MTLFunction

# roots (can be nothing if the function was created directly from a handle)
lib::Union{Nothing,MtlLibrary}
lib::Union{Nothing,MTLLibrary}

MtlFunction(handle::MTLFunction, lib=nothing) = new(handle, lib)
end
Expand All @@ -95,7 +95,7 @@ Base.:(==)(a::MtlFunction, b::MtlFunction) = a.handle == b.handle
Base.hash(fun::MtlFunction, h::UInt) = hash(mod.handle, h)

# Get a handle to a kernel function in a Metal Library.
function MtlFunction(lib::MtlLibrary, name::String)
function MtlFunction(lib::MTLLibrary, name::String)
handle = mtNewFunctionWithName(lib, name)
handle == C_NULL && throw(KeyError(name))
obj = MtlFunction(handle, lib)
Expand Down
27 changes: 0 additions & 27 deletions lib/mtl/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,3 @@ end
##
Base.convert(::Type{MtResourceOptions}, val::UInt32) =
MtResourceOptions(val)

##
"""
@mtlthrows error_var function(..., error_var)
Marks that this Metal function has an argument error_var which
must be passed by reference in the underlying ccall, and when
the function returns checks that no error has been set.
Expands roughly to
```julia
error_var = Ref{id}()
result = function(..., error_var)
error[] != nil && throw(NSError(error[]))
```
"""
macro mtlthrows(error, fun)
expr = quote
$error = Ref{id{NSError}}(nil)
result = $fun
if $error[] != nil
throw(NSError($(error)[]))
end
result
end
return esc(expr)
end
102 changes: 37 additions & 65 deletions lib/mtl/library.jl
Original file line number Diff line number Diff line change
@@ -1,96 +1,68 @@
export MtlLibrary, MtlLibraryFromFile, MtlLibraryFromData
export MTLLibrary, MTLLibraryFromFile, MTLLibraryFromData

const MTLLibrary = Ptr{MtLibrary}
@objcwrapper immutable=false MTLLibrary <: NSObject

"""
MTLDevice(i::Integer)
# compatibility with cmt
Base.unsafe_convert(T::Type{Ptr{MtLibrary}}, obj::MTLLibrary) =
reinterpret(T, Base.unsafe_convert(id, obj))
MTLLibrary(ptr::Ptr{MtLibrary}) = MTLLibrary(reinterpret(id{MTLLibrary}, ptr))

Get a handle to a compute device.
"""
mutable struct MtlLibrary
handle::MTLLibrary
device::MTLDevice
end

Base.unsafe_convert(::Type{MTLLibrary}, lib::MtlLibrary) = lib.handle

Base.:(==)(a::MtlLibrary, b::MtlLibrary) = a.handle == b.handle
Base.hash(lib::MtlLibrary, h::UInt) = hash(lib.handle, h)
function MTLLibrary(device::MTLDevice, src::String,
opts::MTLCompileOptions=MTLCompileOptions())
err = Ref{id{NSError}}(nil)
handle = @objc [device::id{MTLDevice} newLibraryWithSource:src::id{NSString}
options:opts::id{MTLCompileOptions}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
err[] == nil || throw(NSError(err[]))

function MtlLibrary(device::MTLDevice, src::String, opts::MTLCompileOptions=MTLCompileOptions())
handle = @mtlthrows _errptr mtNewLibraryWithSource(device, src, opts, _errptr)

obj = MtlLibrary(handle, device)
obj = MTLLibrary(handle)
finalizer(unsafe_destroy!, obj)

return obj
end

function MtlLibraryFromFile(device::MTLDevice, path::String)
function MTLLibraryFromFile(device::MTLDevice, path::String)
err = Ref{id{NSError}}(nil)
handle = if macos_version() >= v"13"
@mtlthrows _errptr mtNewLibraryWithURL(device, path, _errptr)
url = NSFileURL(path)
@objc [device::id{MTLDevice} newLibraryWithURL:url::id{NSURL}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
else
@mtlthrows _errptr mtNewLibraryWithFile(device, path, _errptr)
@objc [device::id{MTLDevice} newLibraryWithFile:path::id{NSString}
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))

obj = MtlLibrary(handle, device)
obj = MTLLibrary(handle)
finalizer(unsafe_destroy!, obj)

return obj
end

function MtlLibraryFromData(device::MTLDevice, data)
GC.@preserve data begin
handle = @mtlthrows _errptr mtNewLibraryWithData(device, pointer(data), sizeof(data), _errptr)
function MTLLibraryFromData(device::MTLDevice, input_data)
err = Ref{id{NSError}}(nil)
GC.@preserve input_data begin
data = dispatch_data(pointer(input_data), sizeof(input_data))
handle = @objc [device::id{MTLDevice} newLibraryWithData:data::dispatch_data_t
error:err::Ptr{id{NSError}}]::id{MTLLibrary}
end
err[] == nil || throw(NSError(err[]))

obj = MtlLibrary(handle, device)
obj = MTLLibrary(handle)
finalizer(unsafe_destroy!, obj)

return obj
end

function unsafe_destroy!(lib::MtlLibrary)
mtRelease(lib.handle)
function unsafe_destroy!(lib::MTLLibrary)
@objc [lib::id{MTLLibrary} release]::Nothing
end


## properties

Base.propertynames(::MtlLibrary) = (:device, :label, :functionNames)

function Base.getproperty(lib::MtlLibrary, f::Symbol)
if f === :label
ptr = mtLibraryLabel(lib)
ptr == C_NULL ? nothing : unsafe_string(ptr)
elseif f === :functionNames
count = Ref{Csize_t}(0)
mtLibraryFunctionNames(lib, count, C_NULL)
names = Vector{Cstring}(undef, count[])
mtLibraryFunctionNames(lib, count, names)
unsafe_string.(names)
else
getfield(lib, f)
end
end

function Base.setproperty!(lib::MtlLibrary, f::Symbol, val)
if f === :label
mtLibraryLabelSet(lib, val)
else
setfield!(lib, f, val)
end
end


## display

function Base.show(io::IO, lib::MtlLibrary)
print(io, "MtlLibrary($(lib.device))")
end

function Base.show(io::IO, ::MIME"text/plain", lib::MtlLibrary)
println(io, "MtlLibrary:")
println(io, " device: ", lib.device)
print(io, " label: ", lib.label)
@objcproperties MTLLibrary begin
@autoproperty device::id{MTLDevice}
@autoproperty label::id{NSString} setter=setLabel
@autoproperty functionNames::id{NSArray} type=Vector{NSString}
end
2 changes: 1 addition & 1 deletion src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LLVM
using LLVM.Interop
using Metal_LLVM_Tools_jll
using ExprTools: splitdef, combinedef
using ObjectiveC, .Foundation
using ObjectiveC, .Foundation, .Dispatch

# C wrappers
include("../lib/cmt/cmt.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end

function mtlfunction_link(@nospecialize(job::CompilerJob), compiled)
dev = current_device()
lib = MtlLibraryFromData(dev, compiled.image)
lib = MTLLibraryFromData(dev, compiled.image)
fun = MtlFunction(lib, compiled.entry)
pipeline_state = try
MtlComputePipelineState(dev, fun)
Expand Down
14 changes: 7 additions & 7 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end
dev = first(devices())
opts = MTLCompileOptions()

let lib = MtlLibrary(dev, "", opts)
let lib = MTLLibrary(dev, "", opts)
@test lib.device == dev
@test lib.label === nothing
lib.label = "MyLibrary"
Expand All @@ -64,7 +64,7 @@ let lib = MtlLibrary(dev, "", opts)
end

metal_code = read(joinpath(@__DIR__, "dummy.metal"), String)
let lib = MtlLibrary(dev, metal_code, opts)
let lib = MTLLibrary(dev, metal_code, opts)
@test lib.device == dev
@test lib.label === nothing
fns = lib.functionNames
Expand All @@ -74,7 +74,7 @@ let lib = MtlLibrary(dev, metal_code, opts)
end

binary_path = joinpath(@__DIR__, "dummy.metallib")
let lib = MtlLibraryFromFile(dev, binary_path)
let lib = MTLLibraryFromFile(dev, binary_path)
@test lib.device == dev
@test lib.label === nothing
fns = lib.functionNames
Expand All @@ -84,7 +84,7 @@ let lib = MtlLibraryFromFile(dev, binary_path)
end

binary_code = read(binary_path)
let lib = MtlLibraryFromData(dev, binary_code)
let lib = MTLLibraryFromData(dev, binary_code)
@test lib.device == dev
@test lib.label === nothing
fns = lib.functionNames
Expand Down Expand Up @@ -115,7 +115,7 @@ desc.specializedName = "MySpecializedKernel"


dev = first(devices())
lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
fun = MtlFunction(lib, "kernel_1")

compact_str = sprint(io->show(io, fun))
Expand Down Expand Up @@ -340,7 +340,7 @@ end
@testset "compute pipeline" begin

dev = first(devices())
lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
fun = MtlFunction(lib, "kernel_1")

pipeline = MtlComputePipelineState(dev, fun)
Expand Down Expand Up @@ -383,7 +383,7 @@ end
@testset "binary archive" begin

dev = first(devices())
lib = MtlLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
lib = MTLLibraryFromFile(dev, joinpath(@__DIR__, "dummy.metallib"))
fun = MtlFunction(lib, "kernel_1")

desc = MtlBinaryArchiveDescriptor()
Expand Down

0 comments on commit 598ec4e

Please sign in to comment.