More general device conversions #121

Closed · MikeInnes opened this issue Oct 4, 2017 · 6 comments

@MikeInnes (Contributor)

If you want a composite/wrapper type like RowVector{CuArray} to be compatible with CUDA kernels, you need to explicitly overload the cudaconvert function for that type. This means that new wrapper types need to be aware of the CUDA stack.

I just created the adapt function to handle this for conversion to CuArrays; it essentially works like convert, but doesn't have to return a value of the requested type T, so that e.g. adapt(CuArray, ::RowVector{Array})::RowVector{CuArray}. This could be extended to kernels if the conversion used adapt(CuDeviceArray, x) for each input argument.
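
A minimal sketch of the idea (a toy wrapper stands in for RowVector, and the method names are illustrative, not the final Adapt.jl API):

```julia
# Toy wrapper standing in for composite types like RowVector.
struct Wrapper{A<:AbstractArray}
    data::A
end

# Fallback: values with no specific rule pass through unchanged.
adapt(::Type, x) = x

# Leaf rule: plain host arrays are converted to the requested storage,
# e.g. adapt(CuArray, x) moves the data to the GPU.
adapt(::Type{T}, x::Array) where {T<:AbstractArray} = convert(T, x)

# Structural rule: rebuild the wrapper around the adapted storage, so
# adapt(CuArray, Wrapper(rand(3))) isa Wrapper{<:CuArray}.
adapt(::Type{T}, x::Wrapper) where {T} = Wrapper(adapt(T, x.data))
```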

I can put adapt into its own package if we're on board with including it here (depending on NNlib won't cause any real harm, but does seem wrong).

@MikeInnes (Contributor, Author)

FWIW, I put up the adapt package here: https://github.com/MikeInnes/Adapt.jl

At some point I'll get it integrated with the CUDA conversions here.

@maleadt (Member) commented Mar 16, 2018

Do you propose to replace cudaconvert with adapt? I can imagine calling adapt(CuDeviceArray, x) where x<:AbstractArray, but that wouldn't work when people implement other GPU-compatible array derivatives. And what to do with cudaconvert(x::Tuple) = cudaconvert.(x)?
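
For reference, the tuple rule maps over to adapt directly (a sketch; the proposal further down settles on a generator-based definition instead):

```julia
# Recurse elementwise, mirroring cudaconvert(x::Tuple) = cudaconvert.(x).
adapt(T::Type, xs::Tuple) = map(x -> adapt(T, x), xs)
```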

@MikeInnes (Contributor, Author)

Right, so I'm basically imagining that the fallback definition of cudaconvert calls adapt(CuDeviceArray, x), along with the checks you currently have in place that the result is GPU-compatible. I haven't thought about it for a while, but I believe it's a strict addition of functionality.
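
A sketch of what that fallback could look like (check_gpu_compatible is a hypothetical stand-in for the checks CUDAnative already performs):

```julia
# Hypothetical fallback: adapt handles the structural conversion, the
# existing validity checks then reject anything still GPU-incompatible.
function cudaconvert(x)
    y = adapt(CuDeviceArray, x)
    check_gpu_compatible(y)  # stand-in for the current isbits/pointer checks
    return y
end
```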

@maleadt (Member) commented Mar 16, 2018

Why not do e.g. adapt(CUDAnative, x) then? Tying it to CuDeviceArray seems arbitrary.

@MikeInnes (Contributor, Author) commented Mar 16, 2018

On a technical level, because we can't dispatch on that: CUDAnative is a module, not a type.

Design-wise it's not completely arbitrary, since that is usually what we want: e.g. adapt(CuDeviceArray, ::RowVector{CuArray}) should turn the internal CuArray into a CuDeviceArray. The core underlying definition here will be adapt(CuDeviceArray, x::CuArray) = convert(CuDeviceArray, x).

(But I don't have a strong opinion that it should be CuDeviceArray; other libraries shouldn't even need to know about that type, so you could equally introduce a new one for this purpose.)
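
Concretely, the intended layering would be something like this (illustrative):

```julia
# Core rule, defined alongside the CUDA types:
adapt(::Type{CuDeviceArray}, x::CuArray) = convert(CuDeviceArray, x)

# Generic wrapper rules (as sketched above) then compose with it for free:
# adapt(CuDeviceArray, rv::RowVector{<:Any,<:CuArray}) yields a RowVector
# backed by a CuDeviceArray, without RowVector knowing anything about CUDA.
```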

maleadt added the design label on Apr 6, 2018
@maleadt (Member) commented Oct 19, 2018

OK, let's try again.

What about:

```julia
## Adapt.jl

abstract type AbstractAdaptor end

# fallback: values without a specific rule pass through unchanged
adapt(::Type{<:AbstractAdaptor}, x) = x

# Base
adapt(A::Type{<:AbstractAdaptor}, xs::Tuple) = Tuple(adapt(A, x) for x in xs)
@generated adapt(A::Type{<:AbstractAdaptor}, x::NamedTuple) =
    Expr(:tuple, (:($f = adapt(A, x.$f)) for f in fieldnames(x))...)

# LinearAlgebra
import LinearAlgebra: Adjoint, Transpose
adapt(A::Type{<:AbstractAdaptor}, x::Adjoint)   = Adjoint(adapt(A, parent(x)))
adapt(A::Type{<:AbstractAdaptor}, x::Transpose) = Transpose(adapt(A, parent(x)))


## CUDAnative

abstract type CUDAAdaptor <: AbstractAdaptor end


## CuArrays.jl

# placeholder body: the real rule would return a CuDeviceArray wrapping x's pointer
adapt(::Type{<:CUDAAdaptor}, x::Array) = (println("create CuDeviceArray"); x)


## User

adapt(CUDAAdaptor, [1]')
```
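
With these stubs, the user call adapt(CUDAAdaptor, [1]') hits the Adjoint rule, recurses into the wrapped Vector{Int}, prints "create CuDeviceArray", and returns an Adjoint again, so the wrapper structure survives the conversion.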

I was also thinking of calling the function unsafe_adapt, because the CUDAnative adaptors will mostly take a pointer and create wrappers, requiring protection of the original object with GC.@preserve (i.e. something that is typically conveyed by the unsafe_ naming convention). But not all adaptors will do that.
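
Such an adaptor would put the lifetime burden on the caller, roughly like this (a sketch; launch_kernel is a hypothetical stand-in for an actual @cuda launch):

```julia
buf = CuArray(rand(Float32, 16))
dev = adapt(CUDAAdaptor, buf)    # wraps buf's raw pointer in a device-side view
GC.@preserve buf begin           # keep buf alive while its pointer is in use
    launch_kernel(dev)           # hypothetical kernel taking the device view
end
```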
