-
Notifications
You must be signed in to change notification settings - Fork 105
/
gpu_support.jl
80 lines (68 loc) · 3.58 KB
/
gpu_support.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import Adapt: adapt_structure
using Adapt: adapt
# Adapt a `BSplineInterpolation` to the storage target `to`.
# Only the coefficient array is passed through `adapt`; the parent axes and the
# interpolation type are reused unchanged. The element-type parameter is recomputed
# with `update_eltype`, because adapting may change the coefficients' eltype.
function adapt_structure(to, itp::BSplineInterpolation{T,N}) where {T,N}
    adapted_coefs = adapt(to, itp.coefs)
    Tnew = update_eltype(T, adapted_coefs, itp.coefs)
    return BSplineInterpolation{Tnew,N}(adapted_coefs, itp.parentaxes, itp.it)
end
"""
    update_eltype(T, coefs′, coefs)

Return the element-type parameter to use for an interpolation whose coefficient array
`coefs` has been replaced by the adapted array `coefs′`.

If adaptation left the coefficient eltype unchanged, `T` is returned as-is. Otherwise the
result is the inferred type of multiplying a `tweight(coefs′)` value (helper defined
elsewhere in the package) by an element of `coefs′`. If that inferred type is not concrete
and `coefs` is non-empty, one such product is evaluated on the first original coefficient
to obtain a concrete type.
"""
function update_eltype(T, coefs′, coefs)
    newcoef_t = eltype(coefs′)
    if newcoef_t === eltype(coefs)
        # Adaptation did not touch the eltype: keep the existing parameter.
        return T
    end
    weight_t = tweight(coefs′)
    inferred = Base.promote_op(*, weight_t, newcoef_t)
    if isconcretetype(inferred) || isempty(coefs)
        return inferred
    end
    # Inference was inconclusive; evaluate one representative product instead.
    return typeof(zero(weight_t) * convert(newcoef_t, first(coefs)))
end
# Adapt a `LanczosInterpolation` to the storage target `to`.
# Both the coefficient array and the parent axes go through `adapt`, and the
# element-type parameter is taken directly from the adapted coefficients' eltype.
function adapt_structure(to, itp::LanczosInterpolation{T,N}) where {T,N}
    new_coefs = adapt(to, itp.coefs)
    new_axes = adapt(to, itp.parentaxes)
    Tnew = eltype(new_coefs)
    return LanczosInterpolation{Tnew,N}(new_coefs, new_axes, itp.it)
end
# Adapt a `GriddedInterpolation` to the storage target `to`.
# Coefficients and knots are both passed through `adapt`. The element-type parameter is
# recomputed with `update_eltype`, and the remaining type parameters are rebuilt from the
# adapted pieces so the new object's type matches its contents.
function adapt_structure(to, itp::GriddedInterpolation{T,N}) where {T,N}
    new_coefs = adapt(to, itp.coefs)
    new_knots = adapt(to, itp.knots)
    Tnew = update_eltype(T, new_coefs, itp.coefs)
    IT = itptype(itp)
    Rebuilt = GriddedInterpolation{Tnew,N,typeof(new_coefs),IT,typeof(new_knots)}
    return Rebuilt(new_knots, new_coefs, itp.it)
end
# Adapt a `ScaledInterpolation` whose scaling ranges are a tuple of `AbstractRange`s.
# Only the wrapped interpolation is adapted; the ranges are reused as-is, and the
# element type is taken from the adapted inner interpolation.
function adapt_structure(to, itp::ScaledInterpolation{T,N,<:Any,IT,RT}) where {T,N,IT,RT<:NTuple{N,AbstractRange}}
    inner = adapt(to, itp.itp)
    Rebuilt = ScaledInterpolation{eltype(inner),N,typeof(inner),IT,RT}
    return Rebuilt(inner, itp.ranges)
end
# Adapt an `Extrapolation` wrapper: the wrapped interpolation goes through `adapt`, while
# the extrapolation-behavior object (`itp.et`) is reused unchanged.
function adapt_structure(to, itp::Extrapolation{T,N}) where {T,N}
    inner = adapt(to, itp.itp)
    behavior = itp.et
    Rebuilt = Extrapolation{eltype(inner),N,typeof(inner),itptype(itp),typeof(behavior)}
    return Rebuilt(inner, behavior)
end
# Adapt a `FilledExtrapolation` wrapper: the wrapped interpolation goes through `adapt`,
# while the fill value is reused unchanged (it is not passed to `adapt`).
function adapt_structure(to, itp::FilledExtrapolation{T,N}) where {T,N}
    inner = adapt(to, itp.itp)
    fill = itp.fillvalue
    Rebuilt = FilledExtrapolation{eltype(inner),N,typeof(inner),itptype(itp),typeof(fill)}
    return Rebuilt(inner, fill)
end
import Base.Broadcast: broadcasted, BroadcastStyle
using Base.Broadcast: broadcastable, combine_styles, AbstractArrayStyle
# Broadcast an interpolation object as the callee (`itp.(args...)`).
# The broadcast style is chosen from the arguments AND from a `Ref`-wrapped `itp`, so
# the interpolant's own storage can influence which broadcast implementation is used
# (intended to let GPU-backed interpolants broadcast on the GPU where possible).
function broadcasted(itp::AbstractInterpolation, args...)
    wrapped = map(broadcastable, args)
    chosen_style = combine_styles(Ref(itp), wrapped...)
    return broadcasted(chosen_style, itp, wrapped...)
end
"""
    Interpolations.root_storage_type(::Type{<:AbstractInterpolation}) -> Type{<:AbstractArray}

Return the type of the innermost coefficient array backing an interpolation type.

Interpolation wrappers are unwrapped recursively down to their coefficient storage, and
array wrappers such as `OffsetArray` are skipped. Any other `AbstractArray` type is its own
root storage type. Interpolation types without a dedicated method fall back to
`Array{eltype(T),ndims(T)}`.
"""
root_storage_type(::Type{I}) where {I<:AbstractInterpolation} = Array{eltype(I),ndims(I)} # default: plain `Array`
# Interpolation wrappers holding the wrapped interpolation in field 1.
root_storage_type(::Type{W}) where {W<:Extrapolation} = root_storage_type(fieldtype(W, 1))
root_storage_type(::Type{W}) where {W<:FilledExtrapolation} = root_storage_type(fieldtype(W, 1))
root_storage_type(::Type{W}) where {W<:ScaledInterpolation} = root_storage_type(fieldtype(W, 1))
# Interpolations holding their coefficient array in field 1.
root_storage_type(::Type{W}) where {W<:BSplineInterpolation} = root_storage_type(fieldtype(W, 1))
root_storage_type(::Type{W}) where {W<:LanczosInterpolation} = root_storage_type(fieldtype(W, 1))
# `GriddedInterpolation` holds the knots in field 1 and the coefficients in field 2.
root_storage_type(::Type{W}) where {W<:GriddedInterpolation} = root_storage_type(fieldtype(W, 2))
# Array wrapper to look through: the parent array is field 1.
root_storage_type(::Type{W}) where {W<:OffsetArray} = root_storage_type(fieldtype(W, 1))
# Leaf case: any other array type is the root storage itself.
root_storage_type(::Type{A}) where {A<:AbstractArray} = A
# A `Ref`-wrapped interpolant broadcasts as a 0-dimensional value that keeps the
# style family of its storage (via `_to_scalar_style`).
function BroadcastStyle(::Type{<:Ref{T}}) where {T<:AbstractInterpolation}
    storage_style = BroadcastStyle(T)
    return _to_scalar_style(storage_style)
end
# An interpolant takes the broadcast style of its root coefficient storage type.
function BroadcastStyle(::Type{T}) where {T<:AbstractInterpolation}
    return BroadcastStyle(root_storage_type(T))
end
# Convert a broadcast style to its 0-dimensional counterpart.
# - An `AbstractArrayStyle` is rebuilt with dimension 0 via the `Style(Val(N))`
#   constructor convention of `AbstractArrayStyle` subtypes.
# - A style whose dimensionality is `Any` has no fixed dimension, so it is returned unchanged.
# - Any other style is also returned unchanged.
function _to_scalar_style(style::AbstractArrayStyle)
    return typeof(style)(Val(0))
end
function _to_scalar_style(style::AbstractArrayStyle{Any})
    return style
end
function _to_scalar_style(style)
    return style
end