Skip to content

Commit

Permalink
Support rand! and rand using MPS where appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Apr 16, 2024
1 parent aedc0c2 commit cef60f4
Show file tree
Hide file tree
Showing 5 changed files with 454 additions and 43 deletions.
44 changes: 44 additions & 0 deletions docs/src/usage/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
```@meta
DocTestSetup = quote
using Metal
import Random
Random.seed!(0)
Metal.seed!(0)
end
```

Expand Down Expand Up @@ -106,3 +111,42 @@ julia> Base.mapreducedim!(identity, +, b, a)
1×1 MtlMatrix{Float32, Private}:
6.0
```

## Random numbers

Base's convenience functions for generating random numbers are available in Metal as well:

```jldoctest
julia> Metal.rand(2)
2-element MtlVector{Float32, Private}:
0.39904642
0.8805201
julia> Metal.randn(Float32, 2, 1)
2×1 MtlMatrix{Float32, Private}:
-0.18797699
-0.006818078
```

Behind the scenes, these random numbers come from two different generators: one backed by
[Metal Performance Shaders](https://developer.apple.com/documentation/metalperformanceshaders/mpsmatrixrandom?language=objc),
another by using the GPUArrays.jl random methods. Operations on these generators are implemented using methods from the Random
standard library:

```jldoctest
julia> using Random, GPUArrays
julia> a = Random.rand(MPS.default_rng(), Float32, 1)
1-element MtlVector{Float32, Private}:
0.39904642
julia> a = Random.rand!(GPUArrays.default_rng(MtlArray), a)
1-element MtlVector{Float32, Private}:
0.13394515
```

!!! note
`MPSMatrixRandom` functionality requires Metal.jl > v1.1

!!! warning
Do not use `Random.rand!(::MPS.RNG, args...)` or `Random.randn!(::MPS.RNG, args...)` on views as you will most likely overwrite values outside of the view due to limitations in random number generation in the Metal Performance Shaders framework.
1 change: 1 addition & 0 deletions lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("vector.jl")
include("matrixrandom.jl")

# integrations
include("random.jl")
include("linalg.jl")

# decompositions
Expand Down
109 changes: 109 additions & 0 deletions lib/mps/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using Random
using Metal: DefaultStorageMode

"""
MPS.RNG()
A random number generator using `rand()` in a device kernel.
"""
mutable struct RNG <: AbstractRNG
device::MTLDevice
uniformInteger::MPSMatrixRandomPhilox
uniformFloat32::MPSMatrixRandomPhilox
normalFloat32::MPSMatrixRandomPhilox
end


make_seed() = Base.rand(RandomDevice(), UInt)

function RNG(device::MTLDevice, seed::Integer)
seed = seed%UInt
RNG(device,
MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()),
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)),
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),)
end
@autoreleasepool RNG(seed::Integer) = RNG(current_device(), seed)
RNG(device::MTLDevice) = RNG(device, make_seed())

@autoreleasepool RNG() = RNG(current_device(), make_seed())

Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32))

@autoreleasepool function Random.seed!(rng::RNG, seed::Integer)
rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor())
rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1))
rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1))
return rng
end

Random.seed!(rng::RNG) = Random.seed!(rng, make_seed())

const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}()
@autoreleasepool function default_rng()
dev = current_device()
get!(GLOBAL_RNGs, dev) do
RNG(dev)
end
end

const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64]
const UniformType = Union{[Type{T} for T in UniformTypes]...}
const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}}
isempty(A) && return A
_mpsmat_rand!(rng.uniformInteger, A, UInt32)
return A
end

@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32})
isempty(A) && return A
_mpsmat_rand!(rng.uniformFloat32, A, Float32)
return A
end

const NormalType = Type{Float32}
const NormalArray = MtlArray{<:Float32}
@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32})
isempty(A) && return A
_mpsmat_rand!(rng.normalFloat32, A, Float32)
return A
end

# CPU arrays
function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N}
isempty(A) && return A
B = MtlArray{T,N,Shared}(undef, size(A))
rand!(rng, B)
copyto!(A, unsafe_wrap(Array{T},B))
return A
end
function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N}
isempty(A) && return A
B = MtlArray{T,N,Shared}(undef, size(A))
randn!(rng, B)
copyto!(A, unsafe_wrap(Array{T},B))
return A
end

# Out of place
Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...))
Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...))

# support all dimension specifications
Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

# untyped out-of-place
Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
Random.rand(rng::RNG, T::UniformType=Float32; storage=Shared) = rand(rng, T, 1; storage)[]
Random.randn(rng::RNG, T::NormalType=Float32; storage=Shared) = randn(rng, T, 1; storage)[]
61 changes: 53 additions & 8 deletions src/random.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,69 @@
using Random
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor,
MPSMatrixRandomNormalDistributionDescriptor

gpuarrays_rng() = GPUArrays.default_rng(MtlArray)
mpsrand_rng() = MPS.default_rng()

# GPUArrays in-place
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A)
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A)

@inline function can_use_mpsrandom(A::MtlArray{T}) where {T}
return A.offset * sizeof(T) % 4 == 0 && sizeof(A) % 4 == 0
end

# Use MPS random functionality where possible
function Random.rand!(A::MPS.UniformArray)
if can_use_mpsrandom(A)
@inline Random.rand!(mpsrand_rng(), A)
else
@inline Random.rand!(gpuarrays_rng(), A)
end
return A
end
function Random.randn!(A::MPS.NormalArray)
if can_use_mpsrandom(A)
@inline Random.randn!(mpsrand_rng(), A)
else
@inline Random.randn!(gpuarrays_rng(), A)
end
return A
end

# GPUArrays out-of-place
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...)
rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
rand(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))
randn(T::Type, dims::Dims; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...))

# support all dimension specifications
rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) =
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...))

# untyped out-of-place
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...)
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) =
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...))

# scalars
rand(T::Type=Float32; storage=Shared) = rand(T, 1; storage)[]
randn(T::Type=Float32; storage=Shared) = randn(T, 1; storage)[]

# seeding
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed)
function seed!(seed=Base.rand(UInt64))
Random.seed!(gpuarrays_rng(), seed)
Random.seed!(mpsrand_rng(), seed)
end
Loading

0 comments on commit cef60f4

Please sign in to comment.