Skip to content

Commit

Permalink
Support rand! and rand using MPS where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Apr 1, 2024
1 parent 298ded2 commit 40e640e
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 55 deletions.
1 change: 1 addition & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("vector.jl")
include("matrixrandom.jl")

# integrations
include("random.jl")
include("linalg.jl")

# decompositions
Expand Down
69 changes: 69 additions & 0 deletions lib/mps/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using Random

"""
MPS.RNG()
A random number generator that fills `MtlArray`s through the MPS random kernels (`_mpsmat_rand!`).
"""
mutable struct RNG <: AbstractRNG
    # Base seed of the generator.
    seed::UInt
    # Running counter; reset to zero whenever the generator is (re)seeded.
    counter::UInt32

    # Build from any integer seed (wrapped into `UInt`), starting the counter at zero.
    RNG(seed::Integer) = new(seed % UInt, 0)

    # Rebuild a generator from an explicit `(seed, counter)` state.
    function RNG(seed::UInt, counter::UInt32)
        return new(seed, counter)
    end
end

# Draw a fresh `UInt` seed from a `RandomDevice`.
function make_seed()
    return Base.rand(RandomDevice(), UInt)
end

# Default constructor: start from a freshly drawn seed.
function RNG()
    return RNG(make_seed())
end

# A copy carries the same (seed, counter) state in a new object.
function Base.copy(rng::RNG)
    return RNG(rng.seed, rng.counter)
end

# Hash and equality are both defined over the (seed, counter) pair so they stay consistent.
function Base.hash(rng::RNG, h::UInt)
    counter_hash = hash(rng.counter, h)
    return hash(rng.seed, counter_hash)
end

function Base.:(==)(a::RNG, b::RNG)
    return a.seed == b.seed && a.counter == b.counter
end

"""
    Random.seed!(rng::RNG[, seed::Integer])

Reseed `rng`: store `seed` (wrapped into `UInt`) and reset the counter to zero.
When `seed` is omitted, a fresh one is drawn with `make_seed()`.

Returns `rng`, following the `Random.seed!` convention.
"""
function Random.seed!(rng::RNG, seed::Integer)
    rng.seed = seed % UInt
    rng.counter = 0
    # Return the generator itself rather than the value of the last assignment.
    return rng
end

Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())

# Advance the generator state by `len` and return `rng`.
@inline function update_state!(rng::RNG, len)
    # Widen the counter before adding so the sum is not truncated to UInt32.
    advanced = Int64(rng.counter) + len
    # Split into the part that spills past the counter range and what remains.
    carry, wrapped = fldmod(advanced, typemax(UInt32))
    rng.seed += carry # XXX: is this OK?
    rng.counter = wrapped
    return rng
end

# One lazily created generator per Metal device.
const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}()

# Return the generator for the current device, creating it on first use.
function default_rng()
    dev = current_device()
    return get!(() -> RNG(), GLOBAL_RNGs, dev)
end

# Fill an integer `MtlArray` in place with random data and return it.
function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
    # Wrap the array's storage as an MPS vector of UInt32 elements.
    words = MPSVector(A, UInt32)
    call_seed = rng.seed + rng.counter
    _mpsmat_rand!(words; seed=call_seed)

    # Advance the state by the number of elements requested.
    update_state!(rng, length(A))
    return A
end
# Fill a Float32 `MtlArray` in place using a uniform (0, 1) distribution descriptor.
function Random.rand!(rng::RNG, A::MtlArray{Float32})
    target = MPSVector(A, Float32)
    uniform = MPSMatrixRandomUniformDistributionDescriptor(0, 1)
    call_seed = rng.seed + rng.counter
    _mpsmat_rand!(target; desc=uniform, seed=call_seed)

    # Advance the state by the number of elements requested.
    update_state!(rng, length(A))
    return A
end
# Fill a Float32 `MtlArray` in place using a normal (0, 1) distribution descriptor.
function Random.randn!(rng::RNG, A::MtlArray{Float32})
    target = MPSVector(A, Float32)
    normal = MPSMatrixRandomNormalDistributionDescriptor(0, 1)
    call_seed = rng.seed + rng.counter
    _mpsmat_rand!(target; desc=normal, seed=call_seed)

    # Advance the state by the number of elements requested.
    update_state!(rng, length(A))
    return A
end
23 changes: 20 additions & 3 deletions lib/mps/vector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ function MPSVectorDescriptor(length, vectors, vectorBytes, dataType)
return obj
end

# Call the Objective-C class method `vectorBytesForLength:dataType:` on
# `MPSVectorDescriptor`, passing `length` as an `NSUInteger` and `dataType` as an
# `MPSDataType`; the result is returned as an `NSUInteger`.
function vectorBytesForLength(length, dataType)
@objc [MPSVectorDescriptor vectorBytesForLength:length::NSUInteger
dataType:dataType::MPSDataType]::NSUInteger
end

export MPSVector

Expand All @@ -48,9 +52,16 @@ end
Metal vector representation used in Performance Shaders.
"""
function MPSVector(arr::MtlVector{T}) where T
len = length(arr)
desc = MPSVectorDescriptor(len, T)
MPSVector(arr::MtlVector{T}) where T = mpsvector(arr, T, length(arr))

# For rand! and randn!
# Wrap `arr` as an `MPSVector` whose elements are of type `T2` rather than `T`.
# The vector length is the number of `T2` elements needed to cover `arr`'s bytes,
# rounded up to a multiple of 4 elements.
function MPSVector(arr::MtlArray{T}, ::Type{T2}) where {T,T2}
    nbytes = length(arr) * sizeof(T)
    # Exact integer ceiling division; avoids a Float64 round-trip that can lose precision.
    len = UInt(cld(nbytes, 4 * sizeof(T2)) * 4)
    return mpsvector(arr, T2, len)
end

@inline function mpsvector(arr::MtlArray{T}, ::Type{T2}, len) where {T,T2}
desc = MPSVectorDescriptor(len, T2)
vec = @objc [MPSVector alloc]::id{MPSVector}
obj = MPSVector(vec)
offset = arr.offset * sizeof(T)
Expand All @@ -61,6 +72,12 @@ function MPSVector(arr::MtlVector{T}) where T
return obj
end

# Send the Objective-C `resourceSize` message to an `MPSVector` or `MPSMatrix` and
# return the result as an `NSUInteger`.
resourceSize(vecormat::M) where {M<:Union{MPSVector,MPSMatrix}} =
@objc [vecormat::id{M} resourceSize]::NSUInteger

# Send the Objective-C `synchronizeOnCommandBuffer:` message to an `MPSVector` or
# `MPSMatrix`, passing `cmdBuf`; the call returns nothing.
# NOTE(review): presumably this encodes a CPU/GPU synchronization of the underlying
# resource on `cmdBuf` — confirm against the MPS documentation.
synchronizeOnCommandBuffer(vecormat::M, cmdBuf::MTLCommandBuffer) where {M<:Union{MPSVector,MPSMatrix}} =
@objc [vecormat::id{M} synchronizeOnCommandBuffer:cmdBuf::id{MTLCommandBuffer}]::Nothing

#
# matrix vector multiplication
#
Expand Down
12 changes: 8 additions & 4 deletions lib/mtl/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ end


## allocation
# Buffer sizes are rounded up to a multiple of this many bytes.
const BUFFER_ALIGNMENT_FOR_RAND::Int = 16

"""
    bufferbytesize(bytesize::Integer)

Round `bytesize` up to the next multiple of `BUFFER_ALIGNMENT_FOR_RAND`. The result has
the same integer type as `bytesize`.
"""
@inline function bufferbytesize(bytesize::T) where {T <: Integer}
    alignment = T(BUFFER_ALIGNMENT_FOR_RAND)
    # Integer ceiling division: exact for every size, unlike dividing in Float64,
    # which rounds large sizes and can yield a result smaller than `bytesize`.
    return cld(bytesize, alignment) * alignment
end

function MTLBuffer(dev::Union{MTLDevice,MTLHeap}, bytesize::Integer;
storage=Private, hazard_tracking=DefaultTracking,
cache_mode=DefaultCPUCache)
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode

@assert 0 < bytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
ptr = alloc_buffer(dev, bytesize, opts)
realbytesize = bufferbytesize(bytesize)
@assert 0 < realbytesize <= dev.maxBufferLength # XXX: not supported by MTLHeap
ptr = alloc_buffer(dev, realbytesize, opts)

return MTLBuffer(ptr)
end
Expand All @@ -39,8 +42,9 @@ function MTLBuffer(dev::MTLDevice, bytesize::Integer, ptr::Ptr;
storage == Private && error("Can't create a Private copy-allocated buffer.")
opts = convert(MTLResourceOptions, storage) | hazard_tracking | cache_mode

@assert 0 < bytesize <= dev.maxBufferLength
ptr = alloc_buffer(dev, bytesize, opts, ptr)
realbytesize = bufferbytesize(bytesize)
@assert 0 < realbytesize <= dev.maxBufferLength
ptr = alloc_buffer(dev, realbytesize, opts, ptr)

return MTLBuffer(ptr)
end
Expand Down
8 changes: 4 additions & 4 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,17 @@ include("compiler/compilation.jl")
include("compiler/execution.jl")
include("compiler/reflection.jl")

# libraries
include("../lib/mps/MPS.jl")
export MPS

# array implementation
include("utilities.jl")
include("broadcast.jl")
include("mapreduce.jl")
include("random.jl")
include("gpuarrays.jl")

# libraries
include("../lib/mps/MPS.jl")
export MPS

# KernelAbstractions
include("MetalKernels.jl")
import .MetalKernels: MetalBackend
Expand Down
66 changes: 58 additions & 8 deletions src/random.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,74 @@
using Random
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor,
MPSMatrixRandomNormalDistributionDescriptor

# Generator provided by GPUArrays for `MtlArray`.
function gpuarrays_rng()
    return GPUArrays.default_rng(MtlArray)
end

# Default generator of the MPS submodule.
function mpsrand_rng()
    return MPS.default_rng()
end

# GPUArrays in-place
# Generic fallbacks: any `MtlArray` is filled through the GPUArrays generator.
function Random.rand!(A::MtlArray)
    return Random.rand!(gpuarrays_rng(), A)
end

function Random.randn!(A::MtlArray)
    return Random.randn!(gpuarrays_rng(), A)
end

# Decide whether `A` can be filled through MPS: it must start at offset zero and its
# byte size must be a multiple of `MTL.BUFFER_ALIGNMENT_FOR_RAND`.
@inline function usempsrandom(A::MtlArray{T}) where {T}
    A.offset == 0 || return false
    nbytes = length(A) * sizeof(T)
    return nbytes % MTL.BUFFER_ALIGNMENT_FOR_RAND == 0
end

# Use MPS random functionality where possible
# Fill an integer `MtlArray` in place with random values and return it.
# The MPS generator is used when `usempsrandom(A)` holds; otherwise the GPUArrays
# generator is used, matching the Float32 `rand!`/`randn!` methods.
function Random.rand!(A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
    if usempsrandom(A)
        # Previously this branch also used `gpuarrays_rng()`, so MPS was never selected.
        @inline Random.rand!(mpsrand_rng(), A)
    else
        @inline Random.rand!(gpuarrays_rng(), A)
    end
    return A
end
# Fill a Float32 `MtlArray` in place with uniform random values and return it.
# Falls back to the GPUArrays generator when the array does not qualify for MPS.
function Random.rand!(A::MtlArray{Float32})
    if !usempsrandom(A)
        @inline Random.rand!(gpuarrays_rng(), A)
        return A
    end
    @inline Random.rand!(mpsrand_rng(), A)
    return A
end
# Fill a Float32 `MtlArray` in place with normally distributed values and return it.
# Falls back to the GPUArrays generator when the array does not qualify for MPS.
function Random.randn!(A::MtlArray{Float32})
    if !usempsrandom(A)
        @inline Random.randn!(gpuarrays_rng(), A)
        return A
    end
    @inline Random.randn!(mpsrand_rng(), A)
    return A
end

# GPUArrays out-of-place
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...)
# Element types handled by the MPS generator: allocate, then fill through MPS.
function rand(::Type{T}, dims::Dims; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}}
    rng = mpsrand_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.rand!(rng, out)
end

function randn(::Type{Float32}, dims::Dims; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    out = MtlArray{Float32,length(dims),storage}(undef, dims...)
    return Random.randn!(rng, out)
end

# Every other element type goes through the GPUArrays generator.
function rand(T::Type, dims::Dims; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.rand!(rng, out)
end

function randn(T::Type, dims::Dims; storage=DefaultStorageMode)
    rng = gpuarrays_rng()
    out = MtlArray{T,length(dims),storage}(undef, dims...)
    return Random.randn!(rng, out)
end

# support all dimension specifications
# Same as the `Dims` methods, but with the dimensions given as separate integers.
function rand(::Type{T}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64,Float32}}
    rng = mpsrand_rng()
    out = MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.rand!(rng, out)
end

function randn(::Type{Float32}, dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    out = MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.randn!(rng, out)
end

rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) =
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

# untyped out-of-place
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
# Without an element type the result is Float32, filled through the MPS generator.
function rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    out = MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.rand!(rng, out)
end

function randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode)
    rng = mpsrand_rng()
    out = MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)
    return Random.randn!(rng, out)
end

# seeding
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed)
# Reseed both generators (GPUArrays and MPS) with the same `seed`; a random UInt64 is
# drawn when none is given. The result of seeding the MPS generator is returned.
function seed!(seed=Base.rand(UInt64))
    Random.seed!(gpuarrays_rng(), seed)
    return Random.seed!(mpsrand_rng(), seed)
end
4 changes: 2 additions & 2 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ dev = first(devices())

buf = MTLBuffer(dev, 8; storage=Shared)

@test buf.length == 8
@test sizeof(buf) == 8
@test buf.length == 16
@test sizeof(buf) == 16

# MTLResource properties
@test buf.device == dev
Expand Down
2 changes: 1 addition & 1 deletion test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ end
buf_a = MtlArray{input_jl_type}(arr_a)
buf_b = MtlArray{input_jl_type}(arr_b)
buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))

truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
for i in 1:batch_size
@views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
Expand Down
Loading

0 comments on commit 40e640e

Please sign in to comment.