diff --git a/src/onehot.jl b/src/onehot.jl index 2f1eb36560..b87707fae8 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -28,6 +28,19 @@ Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...]) cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(cudaconvert(x.data)) end +@require CLArrays begin + import CLArrays.Shorthands: cl + using CLArrays: CLArray, GlobalArray, GlobalPointer, PreDeviceArray + cl(xs::OneHotMatrix) = OneHotMatrix(cl(xs.data)) + # the on device conversions are still a bit complicated... + CLArrays.kernel_convert(x::OneHotMatrix{T}) where T <: CLArray = OneHotMatrix(CLArrays.kernel_convert(x.data)) + CLArrays.predevice_type(::Type{OneHotMatrix{T}}) where T <: GlobalArray = OneHotMatrix{CLArrays.predevice_type(T)} + CLArrays.device_type(x::OneHotMatrix{T}) where T <: CLArray = OneHotMatrix{CLArrays.device_type(x.data)} + CLArrays.reconstruct(x::OneHotMatrix{T}, ptr::GlobalPointer) where T <: PreDeviceArray = OneHotMatrix(CLArrays.reconstruct(x.data, ptr)) + + CLArrays.GPUArrays.arg_length(x::OneHotMatrix{T}) where T <: CLArrays.GPUArrays.GPUArray = UInt32.(size(x)) +end + onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels)) onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls]) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 74fcb2b89b..7b3a145833 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -78,4 +78,10 @@ using Requires cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs)))) end +@require CLArrays begin + import CLArrays.Shorthands: cl + cl(xs::TrackedArray) = TrackedArray(xs.f, cl(xs.data), Base.RefValue(cl(grad(xs)))) +end + + end