Skip to content

Commit

Permalink
Comply with Metal documentation by preventing copies between buffer s…
Browse files Browse the repository at this point in the history
…izes that are not divisible by 4
  • Loading branch information
christiangnrd committed May 21, 2024
1 parent 08ae564 commit 0ce4936
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 62 deletions.
2 changes: 1 addition & 1 deletion docs/src/usage/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,4 @@ julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a)
`MPSMatrixRandom` functionality requires Metal.jl > v1.1

!!! warning
Do not use `Random.rand!(::MPS.RNG, args...)` or `Random.randn!(::MPS.RNG, args...)` on views as you will most likely overwrite values outside of the view due to limitations in random number generation in the Metal Performance Shaders framework.
`Random.rand!(::MPS.RNG, args...)` andc `Random.randn!(::MPS.RNG, args...)` have a framework limitation that requires the byte offset and byte size of the destination array to be a multiple of 4.
17 changes: 10 additions & 7 deletions lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,32 @@ synchronizeStateOnCommandBuffer(kern::MPSMatrixRandomMTGP32, cmdbuf::MTLCommandB
@objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdbuf::id{MTLCommandBuffer}]::Nothing



@inline function _mpsmat_rand!(randkern::MPSMatrixRandom, dest::MtlArray{T}, ::Type{T2};
queue::MTLCommandQueue = global_queue(current_device()),
async::Bool=false) where {T,T2}
byteoffset = dest.offset * sizeof(T)
(byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.")
bytesize = sizeof(dest)

srcbytes = sizeof(dest)
# Even though `append_copy`` seems to work with any size or offset values, the documentation at
# https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc
# mentions that both must be multiples of 4 bytes in MacOS so error when they are not
(bytesize % 4 == 0) || error(lazy"Destination buffer bytesize ($(bytesize)) must be a multiple of 4.")
(byteoffset % 4 == 0) || error(lazy"Destination buffer offset ($(byteoffset)) must be a multiple of 4.")

cmdbuf = if srcbytes % 16 == 0 && dest.offset == 0
cmdbuf = if bytesize % 16 == 0 && dest.offset == 0
MTLCommandBuffer(queue) do cmdbuf
vecDesc = MPSVectorDescriptor(srcbytes ÷ sizeof(T2), T2)
vecDesc = MPSVectorDescriptor(bytesize ÷ sizeof(T2), T2)
mpsdest = MPSVector(dest, vecDesc)
encode!(cmdbuf, randkern, mpsdest)
end
else
MTLCommandBuffer(queue) do cmdbuf
len = UInt(ceil(srcbytes / sizeof(T2)) * 4)
len = UInt(ceil(bytesize / sizeof(T2)) * 4)
vecDesc = MPSVectorDescriptor(len, T2)
tempVec = MPSTemporaryVector(cmdbuf, vecDesc)
encode!(cmdbuf, randkern, tempVec)
MTLBlitCommandEncoder(cmdbuf) do enc
MTL.append_copy!(enc, dest.data[], byteoffset, tempVec.data, tempVec.offset, srcbytes)
MTL.append_copy!(enc, dest.data[], byteoffset, tempVec.data, tempVec.offset, bytesize)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions lib/mps/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,5 @@ Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMo
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
Random.rand(rng::RNG, T::UniformType=Float32; storage=Shared) = rand(rng, T, 1; storage)[]
Random.randn(rng::RNG, T::NormalType=Float32; storage=Shared) = randn(rng, T, 1; storage)[]
Random.rand(rng::RNG, T::UniformType=Float32; storage=Shared) = rand(rng, T, 4; storage)[1]
Random.randn(rng::RNG, T::NormalType=Float32; storage=Shared) = randn(rng, T, 4; storage)[1]
44 changes: 22 additions & 22 deletions src/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,37 @@ end

# Use MPS random functionality where possible
function Random.rand!(A::MPS.UniformArray)
if can_use_mpsrandom(A)
@inline Random.rand!(mpsrand_rng(), A)
else
@inline Random.rand!(gpuarrays_rng(), A)
end
return A
rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng()
return Random.rand!(rng, A)
end
function Random.randn!(A::MPS.NormalArray)
if can_use_mpsrandom(A)
@inline Random.randn!(mpsrand_rng(), A)
else
@inline Random.randn!(gpuarrays_rng(), A)
end
return A
rng = can_use_mpsrandom(A) ? mpsrand_rng() : gpuarrays_rng()
return Random.randn!(rng, A)
end

# GPUArrays out-of-place
rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
function rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode)
rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
return Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
end
function randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode)
rng = prod(dims) * sizeof(T) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
return Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
end
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))

# support all dimension specifications
rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
function rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
return Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
end
function randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
rng = (dim1 * prod(dims) * sizeof(T)) % 4 == 0 ? mpsrand_rng() : gpuarrays_rng()
return Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
end

rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
Expand All @@ -59,8 +59,8 @@ randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
rand(T::Type=Float32; storage=Shared) = rand(T, 1; storage)[]
randn(T::Type=Float32; storage=Shared) = randn(T, 1; storage)[]
rand(T::Type=Float32; storage=Shared) = rand(T, 4; storage)[1]
randn(T::Type=Float32; storage=Shared) = randn(T, 4; storage)[1]

# seeding
function seed!(seed=Base.rand(UInt64))
Expand Down
69 changes: 39 additions & 30 deletions test/random.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Random
using Metal: can_use_mpsrandom

const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
UInt64]
Expand Down Expand Up @@ -29,8 +30,12 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
# specified MPS rng
if T != Float16
fill!(A, T(0))
f(rng, A)
@test !iszero(collect(A))
if can_use_mpsrandom(A)
f(rng, A)
@test !iszero(collect(A))
else
@test_throws "Destination buffer" f(rng, A)
end
end
end

Expand All @@ -45,8 +50,12 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];
# specified MPS rng
if T != Float16
fill!(A, T(0))
f(rng, A)
@test Array(A) == fill(1, 0)
if can_use_mpsrandom(A)
f(rng, A)
@test Array(A) == fill(1, 0)
else
@test_throws "Destination buffer" f(rng, A)
end
end
end
end
Expand Down Expand Up @@ -125,32 +134,33 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];

## Offset > 0
fill!(A, T(0))
idx = 4:51
idx = 4:50
view_A = @view A[idx]

# Errors in Julia before crashing whole process
if view_A.offset * sizeof(T) % 4 != 0
@test_throws "Destination buffer offset ($(view_A.offset*sizeof(T)))" f(rng, view_A)
else
if can_use_mpsrandom(view_A)
f(rng, view_A)

cpuA = collect(A)
@test !iszero(cpuA[idx])

@test iszero(cpuA[1:100 .∉ Ref(idx)]) broken=(sizeof(view_A) % 4 != 0)
else
@test_throws "Destination buffer" f(rng, view_A)
end

## Offset == 0
fill!(A, T(0))
idx = 1:51
view_A = @view A[idx]
f(rng, view_A)

cpuA = collect(A)
@test !iszero(cpuA[idx])
if can_use_mpsrandom(view_A)
f(rng, view_A)

# XXX: Why are the 8-bit and 16-bit type tests not broken?
@test iszero(cpuA[1:100 .∉ Ref(idx)])# broken=(sizeof(view_A) % 4 != 0)
cpuA = collect(A)
@test !iszero(cpuA[idx])
@test iszero(cpuA[1:100 .∉ Ref(idx)])
else
@test_throws "Destination buffer" f(rng, view_A)
end
end
end
end
Expand Down Expand Up @@ -196,8 +206,12 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];

# specified MPS rng
if T != Float16
B = fr(rng, args...)
@test eltype(B) == T
if length(zeros(args...)) * sizeof(T) % 4 == 0
B = fr(rng, args...)
@test eltype(B) == T
else
@test_throws "Destination buffer" fr(rng, args...)
end
end
end

Expand All @@ -212,24 +226,19 @@ const OOPLACE_TUPLES = [[(Metal.rand, rand, T) for T in RAND_TYPES];

## CPU Arrays with MPS rng
@testset "CPU Arrays" begin
MPS_TUPLES = filter(INPLACE_TUPLES) do tup
mps_tuples = filter(INPLACE_TUPLES) do tup
tup[2] != Float16
end
rng = Metal.MPS.RNG()
@testset "$f with $T" for (f, T) in MPS_TUPLES

@testset "$f with $T" for (f, T) in mps_tuples
@testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16), (1000,), (1000,1000))
A = zeros(T, d)
f(rng, A)
@test !iszero(collect(A))
end

@testset "0" begin
A = rand(T, 0)
b = rand(T)
fill!(A, b)
@test A isa Array{T,1}
@test Array(A) == fill(b, 0)
if (prod(d) * sizeof(T)) % 4 == 0
f(rng, A)
@test !iszero(collect(A))
else
@test_throws "Destination buffer" f(rng, A)
end
end
end
end
Expand Down

0 comments on commit 0ce4936

Please sign in to comment.