Skip to content

Commit

Permalink
Support rand! and rand using MPS where appropriate
Browse files Browse the repository at this point in the history
Also add tests
  • Loading branch information
christiangnrd committed Apr 1, 2024
1 parent 57d1935 commit 16da0c0
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 55 deletions.
1 change: 1 addition & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("vector.jl")
include("matrixrandom.jl")

# integrations
include("random.jl")
include("linalg.jl")

# decompositions
Expand Down
7 changes: 7 additions & 0 deletions lib/mps/matrixrandom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ end
synchronizeStateOnCommandBuffer(kern::MPSMatrixRandomMTGP32, cmdBuf::MTLCommandBuffer) =
@objc [obj::id{MPSMatrixRandomMTGP32} synchronizeStateOnCommandBuffer:cmdBuf::id{MTLCommandBuffer}]::Nothing


# For rand! and randn!
function _mpsvector_rand(arr::MtlArray{T}, ::Type{T2}) where {T,T2}
len = UInt(ceil(length(arr) * sizeof(T) / sizeof(T2) / 4) * 4)
return mpsvector(arr, T2, len)
end

@inline function _mpsmat_rand!(mpsvecormat::Union{MPSMatrix,MPSVector};
desc::MPSMatrixRandomDistributionDescriptor = MPSMatrixRandomDistributionDescriptor(),
cmdBuf::MTLCommandBuffer = MTLCommandBuffer(global_queue(current_device())),
Expand Down
69 changes: 69 additions & 0 deletions lib/mps/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Random

"""
MPS.RNG()
A random number generator using `rand()` in a device kernel.
"""
mutable struct RNG <: AbstractRNG
seed::UInt
counter::UInt32

function RNG(seed::Integer)
new(seed%UInt, 0)
end
RNG(seed::UInt, counter::UInt32) = new(seed, counter)
end

make_seed() = Base.rand(RandomDevice(), UInt)

RNG() = RNG(make_seed())

Base.copy(rng::RNG) = RNG(rng.seed, rng.counter)
Base.hash(rng::RNG, h::UInt) = hash(rng.seed, hash(rng.counter, h))
Base.:(==)(a::RNG, b::RNG) = (a.seed == b.seed) && (a.counter == b.counter)

function Random.seed!(rng::RNG, seed::Integer)
rng.seed = seed % UInt
rng.counter = 0
end

Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())

@inline function update_state!(rng::RNG, len)
new_counter = Int64(rng.counter) + len
overflow, remainder = fldmod(new_counter, typemax(UInt32))
rng.seed += overflow # XXX: is this OK?
rng.counter = remainder
return rng
end

const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}()
function default_rng()
dev = current_device()
get!(GLOBAL_RNGs, dev) do
RNG()
end
end

function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
mpsvecormat = _mpsvector_rand(A, UInt32)
_mpsmat_rand!(mpsvecormat, seed = rng.seed + rng.counter)

update_state!(rng,length(A))
return A
end
function Random.rand!(rng::RNG, A::MtlArray{Float32})
mpsvecormat = _mpsvector_rand(A, Float32)
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomUniformDistributionDescriptor(0, 1), seed = rng.seed + rng.counter)

update_state!(rng,length(A))
return A
end
function Random.randn!(rng::RNG, A::MtlArray{Float32})
mpsvecormat = _mpsvector_rand(A, Float32)
_mpsmat_rand!(mpsvecormat; desc=MPSMatrixRandomNormalDistributionDescriptor(0, 1), seed = rng.seed + rng.counter)

update_state!(rng,length(A))
return A
end
7 changes: 4 additions & 3 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ end
Metal vector representation used in Performance Shaders.
"""
function MPSVector(arr::MtlVector{T}) where T
len = length(arr)
desc = MPSVectorDescriptor(len, T)
MPSVector(arr::MtlVector{T}) where T = mpsvector(arr, T, length(arr))

@inline function mpsvector(arr::MtlArray{T}, ::Type{T2}, len) where {T,T2}
desc = MPSVectorDescriptor(len, T2)
vec = @objc [MPSVector alloc]::id{MPSVector}
obj = MPSVector(vec)
offset = arr.offset * sizeof(T)
Expand Down
12 changes: 8 additions & 4 deletions lib/mtl/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ end


## allocation
const BUFFER_ALIGNMENT_FOR_RAND::Int = 16
@inline bufferbytesize(bytesize::T) where {T <: Integer} = ceil(T, bytesize / BUFFER_ALIGNMENT_FOR_RAND) * T(BUFFER_ALIGNMENT_FOR_RAND)

function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
storage=Private, hazard_tracking=DefaultTracking,
cache_mode=DefaultCPUCache)
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode

@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
ptr = alloc_buffer(dev, bytesize, opts)
realbytesize = bufferbytesize(bytesize)
@assert 0 < realbytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
ptr = alloc_buffer(dev, realbytesize, opts)

return MTLBuffer(ptr)
end
Expand All @@ -39,8 +42,9 @@ function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr;
storage == Private && error("Can't create a Private copy-allocated buffer.")
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode

@assert 0 < bytesize <= dev.maxBufferLength
ptr = alloc_buffer(dev, bytesize, opts, ptr)
realbytesize = bufferbytesize(bytesize)
@assert 0 < realbytesize <= dev.maxBufferLength
ptr = alloc_buffer(dev, realbytesize, opts, ptr)

return MTLBuffer(ptr)
end
Expand Down
8 changes: 4 additions & 4 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ include("compiler/compilation.jl")
include("compiler/execution.jl")
include("compiler/reflection.jl")

# libraries
include("../lib/mps/MPS.jl")
export MPS

# array implementation
include("utilities.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("random.jl")
include("gpuarrays.jl")

# libraries
include("../lib/mps/MPS.jl")
export MPS

# KernelAbstractions
include("MetalKernels.jl")
import .MetalKernels: MetalBackend
Expand Down
66 changes: 58 additions & 8 deletions src/random.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,74 @@
using Random
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor,
MPSMatrixRandomNormalDistributionDescriptor

gpuarrays_rng() = GPUArrays.default_rng(MtlArray)
mpsrand_rng() = MPS.default_rng()

# GPUArrays in-place
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A)
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A)

@inline function usempsrandom(A::MtlArray{T}) where {T}
return (A.offset == 0 &&
(length(A) * sizeof(T) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0))
end

# Use MPS random functionality where possible
function Random.rand!(A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
if usempsrandom(A)
@inline Random.rand!(gpuarrays_rng(), A)
else
@inline Random.rand!(gpuarrays_rng(), A)
end
return A
end
function Random.rand!(A::MtlArray{Float32})
if usempsrandom(A)
@inline Random.rand!(mpsrand_rng(), A)
else
@inline Random.rand!(gpuarrays_rng(), A)
end
return A
end
function Random.randn!(A::MtlArray{Float32})
if usempsrandom(A)
@inline Random.randn!(mpsrand_rng(), A)
else
@inline Random.randn!(gpuarrays_rng(), A)
end
return A
end

# GPUArrays out-of-place
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...)
rand(::Type{T}, dims::Dims; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(::Type{Float32}, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims),storage}(undef, dims...))
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))

# support all dimension specifications
rand(::Type{T}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}} =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(::Type{Float32}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) =
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

# untyped out-of-place
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# seeding
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed)
function seed!(seed=Base.rand(UInt64))
Random.seed!(gpuarrays_rng(), seed)
Random.seed!(mpsrand_rng(), seed)
end
4 changes: 2 additions & 2 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ dev = first(devices())

buf = MTLBuffer(dev, 8; storage=Shared)

@test buf.length == 8
@test sizeof(buf) == 8
@test buf.length == 16
@test sizeof(buf) == 16

# MTLResource properties
@test buf.device == dev
Expand Down
2 changes: 1 addition & 1 deletion test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ end
buf_a = MtlArray{input_jl_type}(arr_a)
buf_b = MtlArray{input_jl_type}(arr_b)
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))

truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
for i in 1:batch_size
@views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
Expand Down
129 changes: 96 additions & 33 deletions test/random.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,102 @@
using Random

const RAND_TYPES = [Float16, Float32, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64,
UInt64]
const RANDN_TYPES = [Float16, Float32]
const INPLACE_TUPLES = [[(rand!, T) for T in RAND_TYPES];
[(randn!, T) for T in RANDN_TYPES]]
const OOPLACE_TUPLES = [[(Metal.rand, T) for T in RAND_TYPES];
[(Metal.randn, T) for T in RANDN_TYPES];
[(rand, T) for T in RAND_TYPES];
[(randn, T) for T in RANDN_TYPES]]

@testset "rand" begin
# in-place
@testset "in-place" begin
@testset "$f with $T" for (f, T) in INPLACE_TUPLES
@testset "$d" for d in (1, 3, (3, 3), (3, 3, 3), 16, (16, 16), (16, 16, 16))
A = MtlArray{T}(undef, d)
fill!(A, T(0))
f(A)
@test Metal.usempsrandom(A) ==
((prod(d) * sizeof(T)) % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0)
@test !iszero(collect(A))
end
end
end

# in-place contiguous views
@testset "in-place for views" begin
@testset "$f with $T" for (f, T) in INPLACE_TUPLES
alen = 100
A = MtlArray{T}(undef, alen)
function test_view!(X::MtlArray{T}, idx; shouldusemps) where {T}
fill!(X, T(0))
view_X = @view X[idx]
f(view_X)
cpuX = collect(X)
@test Metal.usempsrandom(view_X) == shouldusemps
@test !iszero(cpuX[idx])
@test iszero(cpuX[1:alen .∉ Ref(idx)])
return
end

# Test when view offset is 0 and buffer size not multiple of 16
@testset "Off == 0, buf % 16 != 0" begin
test_view!(A, 1:51; shouldusemps=false)
end

# Test when view offset is 0 and buffer size is multiple of 16
@testset "Off == 0, buf % 16 == 0" begin
test_view!(A, 1:32; shouldusemps=true)
end

# Test when view offset is not 0 nor multiple of 16 and buffer size not multiple of 16
@testset "Off != 0, buf % 16 != 0" begin
test_view!(A, 3:51; shouldusemps=false)
end

# Test when view offset is multiple of 16 and buffer size not multiple of 16
@testset "Off % 16 == 0, buf % 16 != 0" begin
test_view!(A, 17:51; shouldusemps=false)
end

# in-place
for (f,T) in ((rand!,Float16),
(rand!,Float32),
(randn!,Float16),
(randn!,Float32)),
d in (2, (2,2), (2,2,2), 3, (3,3), (3,3,3))
A = MtlArray{T}(undef, d)
fill!(A, T(0))
f(A)
@test !iszero(collect(A))
end

# out-of-place, with implicit type
for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32)),
args in ((2,), (2, 2), (3,), (3, 3))
A = f(args...)
@test eltype(A) == T
end

# out-of-place, with type specified
for (f,T) in ((Metal.rand,Float32), (Metal.randn,Float32),
(rand,Float32), (randn,Float32)),
args in ((T, 2), (T, 2, 2), (T, (2, 2)), (T, 3), (T, 3, 3), (T, (3, 3)))
A = f(args...)
@test eltype(A) == T
end

## seeding
Metal.seed!(1)
a = Metal.rand(Int32, 1)
Metal.seed!(1)
b = Metal.rand(Int32, 1)
@test iszero(collect(a) - collect(b))
# Test when view offset is multiple of 16 and buffer size multiple of 16
@testset "Off % 16 == 0, buf % 16 == 0" begin
test_view!(A, 17:32; shouldusemps=false)
end
end
end
# out-of-place, with implicit type
@testset "out-of-place" begin
@testset "$f with implicit type" for (f, T) in
((Metal.rand, Float32), (Metal.randn, Float32))
@testset "args" for args in ((1,), (3,), (3, 3), (16,), (16, 16))
A = f(args...)
@test eltype(A) == T
end
end

# out-of-place, with type specified
@testset "$f with $T" for (f, T) in OOPLACE_TUPLES
@testset "$args" for args in ((T, 1),
(T, 3),
(T, 3, 3),
(T, (3, 3)),
(T, 16),
(T, 16, 16),
(T, (16, 16)))
A = f(args...)
@test eltype(A) == T
end
end
end
## seeding
@testset "Seeding" begin
Metal.seed!(1)
a = Metal.rand(Int32, 1)
Metal.seed!(1)
b = Metal.rand(Int32, 1)
@test iszero(collect(a) - collect(b))
end
end # testset

0 comments on commit 16da0c0

Please sign in to comment.