From 65a60ffaa6ee3956de65c50f7ee8aefecd298b56 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 21 Jun 2020 11:30:35 -0400 Subject: [PATCH] CUDA.jl compat --- src/init.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/init.jl b/src/init.jl index e4bb4aba..98a9db7d 100644 --- a/src/init.jl +++ b/src/init.jl @@ -19,6 +19,15 @@ function __init__() Base.convert(::Type{<:CuArrays.CuArray},VA::AbstractVectorOfArray) = CuArrays.CuArray(VA) @adjoint CuArrays.CuArray(xs::AbstractVectorOfArray) = CuArrays.CuArray(xs), ȳ -> (ȳ,) end + + @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin + function CUDA.CuArray(VA::AbstractVectorOfArray) + vecs = vec.(VA.u) + return CUDA.CuArray(reshape(reduce(hcat,vecs),size(VA.u[1])...,length(VA.u))) + end + Base.convert(::Type{<:CUDA.CuArray},VA::AbstractVectorOfArray) = CUDA.CuArray(VA) + @adjoint CUDA.CuArray(xs::AbstractVectorOfArray) = CUDA.CuArray(xs), ȳ -> (ȳ,) + end @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin function recursivecopy!(b::AbstractArray{T,N},a::AbstractArray{T2,N}) where {T<:Tracker.TrackedArray,T2<:Tracker.TrackedArray,N}