-
-
Notifications
You must be signed in to change notification settings - Fork 300
/
interp.jl
170 lines (134 loc) · 4.3 KB
/
interp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Linear interpolation in one dimension

##### Fields

- `breaks::AbstractVector` : A sorted array of grid points on which to interpolate
- `vals::AbstractArray` : The function values associated with each of the grid
  points. Either a vector (a single function) or a matrix whose rows correspond
  to the entries of `breaks` and whose columns are separate functions.

Construction throws a `DimensionMismatch` if `breaks` and `vals` disagree in
their first dimension, and an `ArgumentError` if `breaks` is empty or unsorted.

##### Examples

```julia
breaks = cumsum(0.1 .* rand(20))
vals = 0.1 .* sin.(breaks)
li = LinInterp(breaks, vals)

# do interpolation by calling the LinInterp object like a function
li(0.2)

# use broadcasting to evaluate at multiple points
li.([0.1, 0.2, 0.3])
```
"""
immutable LinInterp{TV<:AbstractArray,TB<:AbstractVector}
    breaks::TB
    vals::TV
    _n::Int      # cached length(breaks)
    _ncol::Int   # cached size(vals, 2); 1 when vals is a vector

    function (::Type{LinInterp{TV,TB}}){TB,TV}(b::TB, v::TV)
        if size(b, 1) != size(v, 1)
            m = "breaks and vals must have same number of elements"
            throw(DimensionMismatch(m))
        end
        if isempty(b)
            # The evaluation methods index vals[1] / vals[end] inside @inbounds,
            # so an empty grid would turn every evaluation into an unchecked
            # out-of-bounds read. Reject it up front.
            m = "breaks and vals must contain at least one point"
            throw(ArgumentError(m))
        end
        if !issorted(b)
            m = "breaks must be sorted"
            throw(ArgumentError(m))
        end
        new{TV,TB}(b, v, length(b), size(v, 2))
    end
end
# Two interpolants are equal when every stored field (breaks, vals and the
# cached sizes) compares equal with `==`.
function Base.:(==)(li1::LinInterp, li2::LinInterp)
    return all(getfield(li1, f) == getfield(li2, f) for f in fieldnames(typeof(li1)))
end

# `hash` must agree with `==` so equal interpolants behave correctly as Dict/Set
# keys. `_n` and `_ncol` are derived from `breaks` and `vals`, so hashing those
# two fields is sufficient.
Base.hash(li::LinInterp, h::UInt) = hash(li.vals, hash(li.breaks, hash(:LinInterp, h)))
# Convenience outer constructor: infer the type parameters from the arguments
# and forward to the validating inner constructor.
LinInterp{TV<:AbstractArray,TB<:AbstractVector}(b::TB, v::TV) = LinInterp{TV,TB}(b, v)
# Evaluate a vector-valued LinInterp at the scalar point `xp`.
#
# Points left of breaks[1] return vals[1] and points right of breaks[end]
# return vals[end] (flat extrapolation). The corner values are converted to
# `_out_eltype(li)`, the element type the multi-column matrix method stores its
# results in. This keeps the return type from flipping between the corner and
# interior branches (e.g. Int vals with Float64 breaks).
@compat function (li::LinInterp{<:AbstractVector})(xp::Number)
    # first index with breaks[ix] >= xp, or li._n + 1 if every break is < xp
    ix = searchsortedfirst(li.breaks, xp)
    T = _out_eltype(li)
    # NOTE(review): indexing below assumes li.breaks is non-empty.
    @inbounds begin
        # handle corner cases
        ix == 1 && return convert(T, li.vals[1])
        ix == li._n + 1 && return convert(T, li.vals[end])
        # now breaks[ix-1] < xp <= breaks[ix], so the denominator is positive;
        # z is the weight placed on the left neighbour vals[ix-1]
        z = (li.breaks[ix] - xp)/(li.breaks[ix] - li.breaks[ix-1])
        return (1-z) * li.vals[ix] + z * li.vals[ix-1]
    end
end
# Evaluate column `col` of a matrix-valued LinInterp at the scalar point `xp`.
#
# Throws a BoundsError (unless bounds checking is elided) when `col` is outside
# 1:li._ncol. Points outside [breaks[1], breaks[end]] return the first/last row
# of that column. Corner values are converted to `_out_eltype(li)`, matching
# the multi-column method, so the return type does not depend on which branch
# is taken.
@compat function (li::LinInterp{<:AbstractMatrix})(xp::Number, col::Int)
    # first index with breaks[ix] >= xp, or li._n + 1 if every break is < xp
    ix = searchsortedfirst(li.breaks, xp)
    @boundscheck begin
        if col > li._ncol || col < 1
            msg = "col must be between 1 and $(li._ncol), found $col"
            throw(BoundsError(msg))
        end
    end
    T = _out_eltype(li)
    # NOTE(review): indexing below assumes li.breaks is non-empty.
    @inbounds begin
        # handle corner cases
        ix == 1 && return convert(T, li.vals[1, col])
        ix == li._n + 1 && return convert(T, li.vals[end, col])
        # now breaks[ix-1] < xp <= breaks[ix]; z weights the row above (ix-1)
        z = (li.breaks[ix] - xp)/(li.breaks[ix] - li.breaks[ix-1])
        return (1-z) * li.vals[ix, col] + z * li.vals[ix-1, col]
    end
end
# Element type used for interpolated results (e.g. the preallocated output of
# the multi-column method). Interpolation divides differences of `breaks`, so
# integer breaks produce floating-point weights. That float type is folded in
# here; otherwise Int breaks with Int vals would yield an Int output array and
# an InexactError on any fractional interpolated value.
function _out_eltype{TV,TB}(li::LinInterp{TV,TB})
    TBe = eltype(TB)
    Tweight = TBe <: Integer ? typeof(one(TBe) / one(TBe)) : TBe
    return promote_type(eltype(TV), Tweight)
end
# Evaluate the columns `cols` of a matrix-valued LinInterp at the scalar point
# `xp`, returning a freshly allocated vector with one entry per element of
# `cols` (in the same order), of element type `_out_eltype(li)`.
#
# Throws a BoundsError (unless bounds checking is elided) if any entry of
# `cols` is outside 1:li._ncol. Points outside [breaks[1], breaks[end]] return
# the first/last row of each requested column.
@compat function (li::LinInterp{<:AbstractMatrix})(
        xp::Number, cols::AbstractVector{<:Integer}
    )
    # first index with breaks[ix] >= xp, or li._n + 1 if every break is < xp
    ix = searchsortedfirst(li.breaks, xp)
    @boundscheck begin
        for col in cols
            if col > li._ncol || col < 1
                msg = "all cols must be between 1 and $(li._ncol), found $col"
                throw(BoundsError(msg))
            end
        end
    end
    out = Array{_out_eltype(li)}(length(cols))
    # NOTE(review): indexing below assumes li.breaks is non-empty.
    @inbounds begin
        # handle corner cases: copy the first or last row for the chosen columns
        if ix == 1
            for (ind, col) in enumerate(cols)
                out[ind] = li.vals[1, col]
            end
            return out
        end
        if ix == li._n + 1
            for (ind, col) in enumerate(cols)
                out[ind] = li.vals[end, col]
            end
            return out
        end
        # now breaks[ix-1] < xp <= breaks[ix]; the weight z is shared by all
        # columns, so compute it once outside the loop
        z = (li.breaks[ix] - xp)/(li.breaks[ix] - li.breaks[ix-1])
        for (ind, col) in enumerate(cols)
            out[ind] = (1-z) * li.vals[ix, col] + z * li.vals[ix-1, col]
        end
        return out
    end
end
# Evaluate every column of a matrix-valued LinInterp at `xp`, returning one
# value per column by delegating to the multi-column method.
@compat function (li::LinInterp{<:AbstractMatrix})(xp::Number)
    all_cols = 1:li._ncol
    return li(xp, all_cols)
end
"""
    interp(grid::AbstractVector, function_vals::AbstractVector)

Linear interpolation in one dimension.

If `grid` is not sorted, `grid` and `function_vals` are reordered together
(using `sortperm(grid)`) before the `LinInterp` is constructed, so the inputs
need not be sorted. Throws a `DimensionMismatch` if `grid` and `function_vals`
have different lengths.

##### Examples

```julia
breaks = cumsum(0.1 .* rand(20))
vals = 0.1 .* sin.(breaks)
li = interp(breaks, vals)

# Do interpolation by treating `li` as a function you can pass scalars to
li(0.2)

# use broadcasting to evaluate at multiple points
li.([0.1, 0.2, 0.3])
```
"""
function interp(grid::AbstractVector, function_vals::AbstractVector)
    # Check lengths before reordering. Otherwise `function_vals[inds]` would
    # silently drop trailing values when `function_vals` is longer than `grid`.
    if size(grid, 1) != size(function_vals, 1)
        m = "breaks and vals must have same number of elements"
        throw(DimensionMismatch(m))
    end
    issorted(grid) && return LinInterp(grid, function_vals)
    inds = sortperm(grid)
    return LinInterp(grid[inds], function_vals[inds])
end
"""
    interp(grid::AbstractVector, function_vals::AbstractMatrix)

Linear interpolation of several functions that share one grid. Each row of
`function_vals` corresponds to an entry of `grid`, and each column is a
separate function. If `grid` is unsorted, the grid and the rows of
`function_vals` are reordered together before constructing the `LinInterp`.
Throws a `DimensionMismatch` if the number of rows differs from `length(grid)`.
"""
function interp(grid::AbstractVector, function_vals::AbstractMatrix)
    # Check row count before reordering. Otherwise `function_vals[inds, :]`
    # would silently drop trailing rows when there are more rows than grid points.
    if size(grid, 1) != size(function_vals, 1)
        m = "breaks and vals must have same number of elements"
        throw(DimensionMismatch(m))
    end
    issorted(grid) && return LinInterp(grid, function_vals)
    inds = sortperm(grid)
    return LinInterp(grid[inds], function_vals[inds, :])
end