Skip to content

Commit

Permalink
Merge pull request #19 from PoisotLab/tests/SimpleSDMLayers
Browse files Browse the repository at this point in the history
Fix SimpleSDMLayers tests
  • Loading branch information
tpoisot authored Nov 13, 2022
2 parents 7dbd546 + bcc40bf commit 7b065aa
Show file tree
Hide file tree
Showing 13 changed files with 318 additions and 246 deletions.
1 change: 1 addition & 0 deletions SimpleSDMLayers/src/SimpleSDMLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ include("lib/overloads.jl")

# Raster clipping
include("lib/clip.jl")
export clip

include("lib/generated.jl")

Expand Down
162 changes: 101 additions & 61 deletions SimpleSDMLayers/src/lib/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: similar
import Base: copy
import Base: eltype
import Base: convert
import Base: collect
import Base: values
import Base: hcat
import Base: vcat
import Base: show
Expand All @@ -24,15 +24,21 @@ Shows a textual representation of the layer.
function Base.show(io::IO, ::MIME"text/plain", layer::T) where {T <: SimpleSDMLayer}
itype = eltype(layer)
otype = T <: SimpleSDMPredictor ? "predictor" : "response"
print(io, """SDM $(otype)$(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells
\x20\x20Latitudes\t$(layer.bottom)$(layer.top)
\x20\x20Longitudes\t$(layer.left)$(layer.right)""")
return print(
io,
"""SDM $(otype)$(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells
\x20\x20Latitudes\t$(layer.bottom)$(layer.top)
\x20\x20Longitudes\t$(layer.left)$(layer.right)""",
)
end

function Base.show(io::IO, layer::T) where {T <: SimpleSDMLayer}
itype = eltype(layer)
otype = T <: SimpleSDMPredictor ? "predictor" : "response"
print(io, "SDM $(otype)$(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells")
return print(
io,
"SDM $(otype)$(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells",
)
end

"""
Expand All @@ -41,7 +47,9 @@ end
Returns a response with the same grid and bounding box as the predictor.
"""
function Base.convert(::Type{SimpleSDMResponse}, layer::T) where {T <: SimpleSDMPredictor}
return copy(SimpleSDMResponse(layer.grid, layer.left, layer.right, layer.bottom, layer.top))
return copy(
SimpleSDMResponse(layer.grid, layer.left, layer.right, layer.bottom, layer.top),
)
end

"""
Expand All @@ -50,7 +58,9 @@ end
Returns a predictor with the same grid and bounding box as the response.
"""
function Base.convert(::Type{SimpleSDMPredictor}, layer::T) where {T <: SimpleSDMResponse}
return copy(SimpleSDMPredictor(layer.grid, layer.left, layer.right, layer.bottom, layer.top))
return copy(
SimpleSDMPredictor(layer.grid, layer.left, layer.right, layer.bottom, layer.top),
)
end

"""
Expand Down Expand Up @@ -92,14 +102,17 @@ Returns the stride, *i.e.* half the length, of cell dimensions, possibly
alongside a side of the grid. The first position is the length of the
*longitude* cells, the second the *latitude*.
"""
function Base.stride(layer::T; dims::Union{Nothing,Integer}=nothing) where {T <: SimpleSDMLayer}
lon_stride = (layer.right-layer.left)/2.0size(layer, 2)
lat_stride = (layer.top-layer.bottom)/2.0size(layer, 1)
function Base.stride(
layer::T;
dims::Union{Nothing, Integer} = nothing,
) where {T <: SimpleSDMLayer}
lon_stride = (layer.right - layer.left) / 2.0size(layer, 2)
lat_stride = (layer.top - layer.bottom) / 2.0size(layer, 1)
isnothing(dims) && return (lon_stride, lat_stride)
dims == 1 && return lon_stride
dims == 2 && return lat_stride
end
Base.stride(layer::T, i::Int) where {T<:SimpleSDMLayer} = stride(layer; dims=i)
Base.stride(layer::T, i::Int) where {T <: SimpleSDMLayer} = stride(layer; dims = i)

"""
Base.eachindex(layer::T) where {T <: SimpleSDMLayer}
Expand All @@ -118,12 +131,11 @@ the type. If not, the same result can always be achieved through the use of
`copy`, manual update, and `convert`.
"""
function Base.similar(layer::T, ::Type{TC}) where {TC <: Any, T <: SimpleSDMLayer}
emptygrid = convert(Matrix{Union{Nothing,TC}}, zeros(TC, size(layer)))
emptygrid = convert(Matrix{Union{Nothing, TC}}, zeros(TC, size(layer)))
emptygrid[findall(isnothing, layer.grid)] .= nothing
return SimpleSDMResponse(emptygrid, layer.left, layer.right, layer.bottom, layer.top)
end


"""
Base.similar(layer::T) where {T <: SimpleSDMLayer}
Expand All @@ -145,32 +157,47 @@ Returns a new copy of the layer, which has the same type.
function Base.copy(layer::T) where {T <: SimpleSDMLayer}
copygrid = copy(layer.grid)
RT = T <: SimpleSDMResponse ? SimpleSDMResponse : SimpleSDMPredictor
return RT(copygrid, copy(layer.left), copy(layer.right), copy(layer.bottom), copy(layer.top))
return RT(
copygrid,
copy(layer.left),
copy(layer.right),
copy(layer.bottom),
copy(layer.top),
)
end

"""
Base.collect(l::T) where {T <: SimpleSDMLayer}
Base.values(l::T) where {T <: SimpleSDMLayer}
Returns the non-`nothing` values of a layer.
"""
function Base.collect(l::T) where {T <: SimpleSDMLayer}
function Base.values(l::T) where {T <: SimpleSDMLayer}
v = filter(!isnothing, l.grid)
return convert(Vector{typeof(v[1])}, v)
return convert(Vector{typeof(v[1])}, v)
end

"""
Base.vcat(l1::T, l2::T) where {T <: SimpleSDMLayers}
Adds the second layer *under* the first one (according to coordinates),
Adds the second layer *under* the first one (according to coordinates),
assuming the strides and left/right coordinates match. This will automatically
re-order the layers if the second is above the first.
"""
function Base.vcat(l1::T, l2::T) where {T <: SimpleSDMLayer}
(l1.left == l2.left) || throw(ArgumentError("The two layers passed to vcat must have the same left coordinate"))
(l1.right == l2.right) || throw(ArgumentError("The two layers passed to vcat must have the same right coordinate"))
all(stride(l1) .≈ stride(l2)) || throw(ArgumentError("The two layers passed to vcat must have the same stride"))
(l1.left == l2.left) || throw(
ArgumentError("The two layers passed to vcat must have the same left coordinate"),
)
(l1.right == l2.right) || throw(
ArgumentError("The two layers passed to vcat must have the same right coordinate"),
)
all(stride(l1) .≈ stride(l2)) ||
throw(ArgumentError("The two layers passed to vcat must have the same stride"))
(l1.top == l2.bottom) && return vcat(l2, l1)
(l2.top == l1.bottom) || throw(ArgumentError("The two layers passed to vcat must have contiguous bottom and top coordinates"))
(l2.top == l1.bottom) || throw(
ArgumentError(
"The two layers passed to vcat must have contiguous bottom and top coordinates",
),
)
new_grid = vcat(l2.grid, l1.grid)
RT = T <: SimpleSDMPredictor ? SimpleSDMPredictor : SimpleSDMResponse
return RT(new_grid, l1.left, l1.right, l2.bottom, l1.top)
Expand All @@ -180,15 +207,24 @@ end
Base.hcat(l1::T, l2::T) where {T <: SimpleSDMLayers}
Adds the second layer *to the right of* the first one (according to coordinates),
assuming the strides and left/right coordinates match. This will automatically
assuming the strides and left/right coordinates match. This will automatically
re-order the layers if the second is to the left the first.
"""
function Base.hcat(l1::T, l2::T) where {T <: SimpleSDMLayer}
(l1.top == l2.top) || throw(ArgumentError("The two layers passed to hcat must have the same top coordinate"))
(l1.bottom == l2.bottom) || throw(ArgumentError("The two layers passed to hcat must have the same bottom coordinate"))
all(stride(l1) .≈ stride(l2)) || throw(ArgumentError("The two layers passed to hcat must have the same stride"))
(l1.top == l2.top) || throw(
ArgumentError("The two layers passed to hcat must have the same top coordinate"),
)
(l1.bottom == l2.bottom) || throw(
ArgumentError("The two layers passed to hcat must have the same bottom coordinate"),
)
all(stride(l1) .≈ stride(l2)) ||
throw(ArgumentError("The two layers passed to hcat must have the same stride"))
(l2.right == l1.left) && return hcat(l2, l1)
(l1.right == l2.left) || throw(ArgumentError("The two layers passed to hcat must have contiguous left and right coordinates"))
(l1.right == l2.left) || throw(
ArgumentError(
"The two layers passed to hcat must have contiguous left and right coordinates",
),
)
new_grid = hcat(l1.grid, l2.grid)
RT = T <: SimpleSDMPredictor ? SimpleSDMPredictor : SimpleSDMResponse
return RT(new_grid, l1.left, l2.right, l1.bottom, l1.top)
Expand All @@ -202,7 +238,11 @@ possible for `SimpleSDMResponse` elements (which are mutable) and will throw an
error if called on a `SimpleSDMPredictor` element (which is not mutable).
"""
function Base.replace!(layer::T, old_new::Pair...) where {T <: SimpleSDMLayer}
layer isa SimpleSDMResponse || throw(ArgumentError("`SimpleSDMPredictor` elements are immutable. Convert to a `SimpleSDMResponse` first or call `replace!` directly on the grid element."))
layer isa SimpleSDMResponse || throw(
ArgumentError(
"`SimpleSDMPredictor` elements are immutable. Convert to a `SimpleSDMResponse` first or call `replace!` directly on the grid element.",
),
)
replace!(layer.grid, old_new...)
return layer
end
Expand Down Expand Up @@ -236,25 +276,25 @@ end
Returns the quantiles of `layer` at `p`, using `Statistics.quantile`.
"""
function Statistics.quantile(layer::T, p) where {T <: SimpleSDMLayer}
return quantile(collect(layer), p)
return quantile(values(layer), p)
end

"""
==(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
verified with `==` (e.g., `layer1.grid == layer2.grid`).
"""
function Base.:(==)(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
return all(
[
layer1.grid == layer2.grid,
layer1.left == layer2.left,
layer1.right == layer2.right,
layer1.bottom == layer2.bottom,
layer1.top == layer2.top,
]
layer1.grid == layer2.grid,
layer1.left == layer2.left,
layer1.right == layer2.right,
layer1.bottom == layer2.bottom,
layer1.top == layer2.top,
]
)
end

Expand All @@ -265,45 +305,45 @@ end
"""
isequal(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
verified with `isequal` (e.g., `isequal(layer1.grid, layer2.grid)`).
"""
function Base.isequal(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
return all(
[
isequal(layer1.grid, layer2.grid),
isequal(layer1.left, layer2.left),
isequal(layer1.right, layer2.right),
isequal(layer1.bottom, layer2.bottom),
isequal(layer1.top, layer2.top),
]
isequal(layer1.grid, layer2.grid),
isequal(layer1.left, layer2.left),
isequal(layer1.right, layer2.right),
isequal(layer1.bottom, layer2.bottom),
isequal(layer1.top, layer2.top),
]
)
end

Base.:*(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n*x, layer)
Base.:*(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x*n, layer)
Base.:/(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n/x, layer)
Base.:/(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x/n, layer)
Base.:-(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n-x, layer)
Base.:-(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x-n, layer)
Base.:+(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n+x, layer)
Base.:+(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x+n, layer)
Base.://(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n//x, layer)
Base.://(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x//n, layer)
Base.:%(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n%x, layer)
Base.:%(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x%n, layer)
Base.:*(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n * x, layer)
Base.:*(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x * n, layer)
Base.:/(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n / x, layer)
Base.:/(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x / n, layer)
Base.:-(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n - x, layer)
Base.:-(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x - n, layer)
Base.:+(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n + x, layer)
Base.:+(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x + n, layer)
Base.://(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n // x, layer)
Base.://(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x // n, layer)
Base.:%(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n % x, layer)
Base.:%(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x % n, layer)

function Base.findmax(layer::T) where {T <: SimpleSDMLayer}
val, pos = findmax(collect(layer))
val, pos = findmax(values(layer))
return (val, keys(layer)[pos])
end

function Base.findmin(layer::T) where {T <: SimpleSDMLayer}
val, pos = findmin(collect(layer))
val, pos = findmin(values(layer))
return (val, keys(layer)[pos])
end

function Base.findall(f::Function, layer::T) where {T <: SimpleSDMLayer}
return keys(layer)[findall(f, collect(layer))]
end
return keys(layer)[findall(f, values(layer))]
end
Loading

0 comments on commit 7b065aa

Please sign in to comment.