More general device conversions #121
FWIW, I put up the adapt package here: https://github.com/MikeInnes/Adapt.jl. At some point I'll integrate it with the CUDA conversions here.
Do you propose to replace
Right, so I'm basically imagining that the fallback definition for
Why not do, e.g.,
On a technical level, because we can't dispatch on that. Design-wise it's not completely arbitrary, since that is usually what we want, e.g. (But I don't have a strong opinion that it should be `CuDeviceArray`: other libraries shouldn't even need to know about that type, so you could equally introduce a new one for this purpose.)
OK, let's try again. What about:

```julia
## Adapt.jl
abstract type AbstractAdaptor end
adapt(::Type{<:AbstractAdaptor}, x) = x

# Base
adapt(A::Type{<:AbstractAdaptor}, xs::Tuple) = Tuple(adapt(A, x) for x in xs)
@generated adapt(A::Type{<:AbstractAdaptor}, x::NamedTuple) =
  Expr(:tuple, (:($f=adapt(A, x.$f)) for f in fieldnames(x))...)

# LinearAlgebra
import LinearAlgebra: Adjoint, Transpose
adapt(A::Type{<:AbstractAdaptor}, x::Adjoint) = Adjoint(adapt(A, parent(x)))
adapt(A::Type{<:AbstractAdaptor}, x::Transpose) = Transpose(adapt(A, parent(x)))

## CUDAnative
abstract type CUDAAdaptor <: AbstractAdaptor end

## CuArrays.jl
adapt(::Type{<:CUDAAdaptor}, x::Array) = (println("create CuDeviceArray"); x)

## User
adapt(CUDAAdaptor, [1]')
```

I was also thinking of calling the function
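To illustrate the claim that new types shouldn't need to know about the CUDA stack, here is a minimal sketch under the `AbstractAdaptor` design proposed above. `FakeDeviceAdaptor` and `Scaled` are hypothetical names invented for this example, and the "device" conversion just turns arrays into `Float32` arrays so the sketch runs without a GPU.

```julia
# Minimal restatement of the AbstractAdaptor interface proposed above.
abstract type AbstractAdaptor end
adapt(::Type{<:AbstractAdaptor}, x) = x  # fallback: leave values unchanged

# Hypothetical stand-in for a GPU backend such as CUDAAdaptor: it "moves"
# arrays by converting them to Float32, so the sketch runs without a GPU.
abstract type FakeDeviceAdaptor <: AbstractAdaptor end
adapt(::Type{<:FakeDeviceAdaptor}, x::Array) = Float32.(x)

# Hypothetical user-defined wrapper. Its author depends only on
# AbstractAdaptor and never mentions FakeDeviceAdaptor or any GPU type.
struct Scaled{A<:AbstractArray}
    factor::Float64
    data::A
end
adapt(T::Type{<:AbstractAdaptor}, s::Scaled) = Scaled(s.factor, adapt(T, s.data))

# Usage: the wrapper is preserved and only its payload is converted.
s = adapt(FakeDeviceAdaptor, Scaled(2.0, [1, 2, 3]))
```

The wrapper's author writes one `adapt` method against the abstract adaptor; the backend supplies the `Array` method, much as `CuArrays.jl` does in the proposal.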
If you want a composite or wrapper type like `RowVector{CuArray}` to be compatible with CUDA kernels, you need to explicitly overload the `cudaconvert` function for that type. This means that new types need to be aware of the CUDA stack.

I just created the `adapt` function to handle this for conversion to CuArrays; it essentially works like `convert` but doesn't have to return a `T`, so that, e.g., `adapt(CuArray, ::RowVector{Array})::RowVector{CuArray}`. This could be extended to kernels if the conversion used `adapt(CuDeviceArray, x)` for each input argument.

I can put `adapt` into its own package if we're on board with including it here (depending on NNlib won't cause any real harm, but does seem wrong).
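A small sketch of the two ideas in this description, assuming a hypothetical marker type `Float32Storage` in place of `CuArray` (so no GPU is needed) and `Adjoint` in place of the older `RowVector`. First, `adapt` rebuilds a wrapper around its converted parent instead of returning the target type as `convert` must. Second, a hypothetical `launch` helper adapts every kernel argument before the call.

```julia
using LinearAlgebra: Adjoint

# Hypothetical stand-in for `CuArray` as an adapt target: "device" storage
# here is just a Float32 array, so the sketch runs on the CPU.
struct Float32Storage end

# Leaves: convert plain arrays to the target storage.
adapt(::Type{Float32Storage}, x::Array) = Float32.(x)
# Fallback: anything else (scalars, etc.) passes through unchanged.
adapt(::Type, x) = x
# Wrappers: re-wrap the adapted parent, so the result is still an Adjoint
# rather than a value of the target type (unlike `convert`).
adapt(T::Type, x::Adjoint) = adjoint(adapt(T, parent(x)))

# Hypothetical kernel launcher: adapt each input argument, then call the
# kernel. A real launcher would compile and run `kernel` on the device.
launch(kernel, args...) = kernel(map(a -> adapt(Float32Storage, a), args)...)

row = adapt(Float32Storage, [1, 2, 3]')
```

Here `row` is an `Adjoint` wrapping a `Vector{Float32}`, and a kernel passed to `launch` sees the adapted arguments without the wrapper type needing any kernel-specific overload.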