From 3f91142667d7210b575841edfad77d60ef34e225 Mon Sep 17 00:00:00 2001 From: fjebaker Date: Tue, 1 Aug 2023 13:48:04 +0100 Subject: [PATCH 1/2] fix: sincos ccall types From the Apple developer manual, the sincos functions return the cosine part via a reference argument `&T`. Changed the overrides to now return both the sin and cosine components as a tuple. --- src/device/intrinsics/math.jl | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index cee47792b..1179bff1a 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -125,9 +125,21 @@ using Base: FastMath @device_override Base.sin(x::Float32) = ccall("extern air.sin.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.sin(x::Float16) = ccall("extern air.sin.f16", llvmcall, Float16, (Float16,), x) -@device_override FastMath.sincos_fast(x::Float32) = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.sincos(x::Float32) = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat,), x) -@device_override Base.sincos(x::Float16) = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16,), x) +@device_override function FastMath.sincos_fast(x::Float32) + c = Ref{Cfloat}() + s = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c) + (s, c[]) +end +@device_override function Base.sincos(x::Float32) + c = Ref{Cfloat}() + s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c) + (s, c[]) +end +@device_override function Base.sincos(x::Float16) + c = Ref{Float16}() + s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c) + (s, c[]) +end @device_override FastMath.sinh_fast(x::Float32) = ccall("extern air.fast_sinh.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.sinh(x::Float32) = ccall("extern air.sinh.f32", llvmcall, Cfloat, (Cfloat,), x) From fb5a33c0f7743ce0679a0bcb329335b1213155be Mon Sep 17 00:00:00 2001 From: fjebaker Date: Tue, 1 Aug 2023 14:58:17 +0100 Subject: [PATCH 2/2] test: added test for sincos intrinsic --- test/device/intrinsics.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 5eb4c4296..3d3a166d5 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -121,6 +121,23 @@ end return nothing end @metal intr_test2(bufferA) + synchronize() + + bufferB = MtlArray{eltype(a),length(size(a)),Shared}(a) + vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1) + + function intr_test3(arr_sin, arr_cos) + idx = thread_position_in_grid_1d() + s, c = sincos(arr_cos[idx]) + arr_sin[idx] = s + arr_cos[idx] = c + return nothing + end + + @metal intr_test3(bufferA, bufferB) + synchronize() + @test vecA ≈ sin.(a) + @test vecB ≈ cos.(a) end ############################################################################################