-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support rand! and rand using MPS where appropriate
- Loading branch information
1 parent
aedc0c2
commit cef60f4
Showing
5 changed files
with
454 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
using Random | ||
using Metal: DefaultStorageMode | ||
|
||
""" | ||
MPS.RNG() | ||
A random number generator using `rand()` in a device kernel. | ||
""" | ||
mutable struct RNG <: AbstractRNG | ||
device::MTLDevice | ||
uniformInteger::MPSMatrixRandomPhilox | ||
uniformFloat32::MPSMatrixRandomPhilox | ||
normalFloat32::MPSMatrixRandomPhilox | ||
end | ||
|
||
|
||
make_seed() = Base.rand(RandomDevice(), UInt) | ||
|
||
function RNG(device::MTLDevice, seed::Integer) | ||
seed = seed%UInt | ||
RNG(device, | ||
MPSMatrixRandomPhilox(device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()), | ||
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)), | ||
MPSMatrixRandomPhilox(device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)),) | ||
end | ||
@autoreleasepool RNG(seed::Integer) = RNG(current_device(), seed) | ||
RNG(device::MTLDevice) = RNG(device, make_seed()) | ||
|
||
@autoreleasepool RNG() = RNG(current_device(), make_seed()) | ||
|
||
Base.copy(rng::RNG) = RNG(copy(rng.device), copy(rng.uniformInteger), copy(rng.uniformFloat32), copy(rng.normalFloat32)) | ||
|
||
@autoreleasepool function Random.seed!(rng::RNG, seed::Integer) | ||
rng.uniformInteger = MPSMatrixRandomPhilox(rng.device, UInt32, seed, MPSMatrixRandomDefaultDistributionDescriptor()) | ||
rng.uniformFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomUniformDistributionDescriptor(0, 1)) | ||
rng.normalFloat32 = MPSMatrixRandomPhilox(rng.device, Float32, seed, MPSMatrixRandomNormalDistributionDescriptor(0, 1)) | ||
return rng | ||
end | ||
|
||
Random.seed!(rng::RNG) = Random.seed!(rng, make_seed()) | ||
|
||
const GLOBAL_RNGs = Dict{MTLDevice,MPS.RNG}() | ||
@autoreleasepool function default_rng() | ||
dev = current_device() | ||
get!(GLOBAL_RNGs, dev) do | ||
RNG(dev) | ||
end | ||
end | ||
|
||
const UniformTypes = [Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64] | ||
const UniformType = Union{[Type{T} for T in UniformTypes]...} | ||
const UniformArray = MtlArray{<:Union{Float32,UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{T}) where {T<:Union{UInt8,Int8,UInt16,Int16,UInt32,Int32,UInt64,Int64}} | ||
isempty(A) && return A | ||
_mpsmat_rand!(rng.uniformInteger, A, UInt32) | ||
return A | ||
end | ||
|
||
@autoreleasepool function Random.rand!(rng::RNG, A::MtlArray{Float32}) | ||
isempty(A) && return A | ||
_mpsmat_rand!(rng.uniformFloat32, A, Float32) | ||
return A | ||
end | ||
|
||
const NormalType = Type{Float32} | ||
const NormalArray = MtlArray{<:Float32} | ||
@autoreleasepool function Random.randn!(rng::RNG, A::MtlArray{Float32}) | ||
isempty(A) && return A | ||
_mpsmat_rand!(rng.normalFloat32, A, Float32) | ||
return A | ||
end | ||
|
||
# CPU arrays | ||
function Random.rand!(rng::RNG, A::AbstractArray{T,N}) where {T <: Union{UniformTypes...}, N} | ||
isempty(A) && return A | ||
B = MtlArray{T,N,Shared}(undef, size(A)) | ||
rand!(rng, B) | ||
copyto!(A, unsafe_wrap(Array{T},B)) | ||
return A | ||
end | ||
function Random.randn!(rng::RNG, A::AbstractArray{T,N}) where {T <: Float32, N} | ||
isempty(A) && return A | ||
B = MtlArray{T,N,Shared}(undef, size(A)) | ||
randn!(rng, B) | ||
copyto!(A, unsafe_wrap(Array{T},B)) | ||
return A | ||
end | ||
|
||
# Out of place | ||
Random.rand(rng::RNG, T::UniformType, dims::Dims; storage=DefaultStorageMode) = | ||
Random.rand!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) | ||
Random.randn(rng::RNG, T::NormalType, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(rng, MtlArray{T,length(dims),storage}(undef, dims...)) | ||
|
||
# support all dimension specifications | ||
Random.rand(rng::RNG, T::UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
Random.randn(rng::RNG, T::NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(rng, MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# untyped out-of-place | ||
Random.rand(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
Random.randn(rng::RNG, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(rng, MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# scalars | ||
Random.rand(rng::RNG, T::UniformType=Float32; storage=Shared) = rand(rng, T, 1; storage)[] | ||
Random.randn(rng::RNG, T::NormalType=Float32; storage=Shared) = randn(rng, T, 1; storage)[] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,69 @@ | ||
using Random | ||
using ..MPS: MPSVector, _mpsmat_rand!, MPSMatrixRandomUniformDistributionDescriptor, | ||
MPSMatrixRandomNormalDistributionDescriptor | ||
|
||
gpuarrays_rng() = GPUArrays.default_rng(MtlArray) | ||
mpsrand_rng() = MPS.default_rng() | ||
|
||
# GPUArrays in-place | ||
Random.rand!(A::MtlArray) = Random.rand!(gpuarrays_rng(), A) | ||
Random.randn!(A::MtlArray) = Random.randn!(gpuarrays_rng(), A) | ||
|
||
@inline function can_use_mpsrandom(A::MtlArray{T}) where {T} | ||
return A.offset * sizeof(T) % 4 == 0 && sizeof(A) % 4 == 0 | ||
end | ||
|
||
# Use MPS random functionality where possible | ||
function Random.rand!(A::MPS.UniformArray) | ||
if can_use_mpsrandom(A) | ||
@inline Random.rand!(mpsrand_rng(), A) | ||
else | ||
@inline Random.rand!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
function Random.randn!(A::MPS.NormalArray) | ||
if can_use_mpsrandom(A) | ||
@inline Random.randn!(mpsrand_rng(), A) | ||
else | ||
@inline Random.randn!(gpuarrays_rng(), A) | ||
end | ||
return A | ||
end | ||
|
||
# GPUArrays out-of-place | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = Random.rand!(MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{T,length(dims),storage}(undef, dims...); kwargs...) | ||
rand(T::MPS.UniformType, dims::Dims; storage=DefaultStorageMode) = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::MPS.NormalType, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
rand(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
randn(T::Type, dims::Dims; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims),storage}(undef, dims...)) | ||
|
||
# support all dimension specifications | ||
rand(T::MPS.UniformType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(T::MPS.NormalType, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
rand(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = | ||
Random.randn!(MtlArray{T,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
Random.rand!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(T::Type, dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(gpuarrays_rng(), MtlArray{T,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# untyped out-of-place | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = Random.rand!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode, kwargs...) = Random.randn!(MtlArray{Float32,length(dims)+1,storage}(undef, dim1, dims...); kwargs...) | ||
rand(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.rand!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
randn(dim1::Integer, dims::Integer...; storage=DefaultStorageMode) = | ||
Random.randn!(mpsrand_rng(), MtlArray{Float32,length(dims) + 1,storage}(undef, dim1, dims...)) | ||
|
||
# scalars | ||
rand(T::Type=Float32; storage=Shared) = rand(T, 1; storage)[] | ||
randn(T::Type=Float32; storage=Shared) = randn(T, 1; storage)[] | ||
|
||
# seeding | ||
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed) | ||
function seed!(seed=Base.rand(UInt64)) | ||
Random.seed!(gpuarrays_rng(), seed) | ||
Random.seed!(mpsrand_rng(), seed) | ||
end |
Oops, something went wrong.