Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SimpleSDMLayers tests #19

Merged
merged 2 commits into from
Nov 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions SimpleSDMLayers/src/SimpleSDMLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ include("lib/overloads.jl")

# Raster clipping
include("lib/clip.jl")
export clip

include("lib/generated.jl")

Expand Down
162 changes: 101 additions & 61 deletions SimpleSDMLayers/src/lib/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: similar
import Base: copy
import Base: eltype
import Base: convert
import Base: collect
import Base: values
import Base: hcat
import Base: vcat
import Base: show
Expand All @@ -24,15 +24,21 @@ Shows a textual representation of the layer.
function Base.show(io::IO, ::MIME"text/plain", layer::T) where {T <: SimpleSDMLayer}
itype = eltype(layer)
otype = T <: SimpleSDMPredictor ? "predictor" : "response"
print(io, """SDM $(otype) → $(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells
\x20\x20Latitudes\t$(layer.bottom) ⇢ $(layer.top)
\x20\x20Longitudes\t$(layer.left) ⇢ $(layer.right)""")
return print(
io,
"""SDM $(otype) → $(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells
\x20\x20Latitudes\t$(layer.bottom) ⇢ $(layer.top)
\x20\x20Longitudes\t$(layer.left) ⇢ $(layer.right)""",
)
end

function Base.show(io::IO, layer::T) where {T <: SimpleSDMLayer}
itype = eltype(layer)
otype = T <: SimpleSDMPredictor ? "predictor" : "response"
print(io, "SDM $(otype) → $(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells")
return print(
io,
"SDM $(otype) → $(size(layer,1))×$(size(layer,2)) grid with $(length(layer)) $(itype)-valued cells",
)
end

"""
Expand All @@ -41,7 +47,9 @@ end
Returns a response with the same grid and bounding box as the predictor.
"""
function Base.convert(::Type{SimpleSDMResponse}, layer::T) where {T <: SimpleSDMPredictor}
return copy(SimpleSDMResponse(layer.grid, layer.left, layer.right, layer.bottom, layer.top))
return copy(
SimpleSDMResponse(layer.grid, layer.left, layer.right, layer.bottom, layer.top),
)
end

"""
Expand All @@ -50,7 +58,9 @@ end
Returns a predictor with the same grid and bounding box as the response.
"""
function Base.convert(::Type{SimpleSDMPredictor}, layer::T) where {T <: SimpleSDMResponse}
return copy(SimpleSDMPredictor(layer.grid, layer.left, layer.right, layer.bottom, layer.top))
return copy(
SimpleSDMPredictor(layer.grid, layer.left, layer.right, layer.bottom, layer.top),
)
end

"""
Expand Down Expand Up @@ -92,14 +102,17 @@ Returns the stride, *i.e.* half the length, of cell dimensions, possibly
alongside a side of the grid. The first position is the length of the
*longitude* cells, the second the *latitude*.
"""
function Base.stride(layer::T; dims::Union{Nothing,Integer}=nothing) where {T <: SimpleSDMLayer}
lon_stride = (layer.right-layer.left)/2.0size(layer, 2)
lat_stride = (layer.top-layer.bottom)/2.0size(layer, 1)
function Base.stride(
layer::T;
dims::Union{Nothing, Integer} = nothing,
) where {T <: SimpleSDMLayer}
lon_stride = (layer.right - layer.left) / 2.0size(layer, 2)
lat_stride = (layer.top - layer.bottom) / 2.0size(layer, 1)
isnothing(dims) && return (lon_stride, lat_stride)
dims == 1 && return lon_stride
dims == 2 && return lat_stride
end
Base.stride(layer::T, i::Int) where {T<:SimpleSDMLayer} = stride(layer; dims=i)
Base.stride(layer::T, i::Int) where {T <: SimpleSDMLayer} = stride(layer; dims = i)

"""
Base.eachindex(layer::T) where {T <: SimpleSDMLayer}
Expand All @@ -118,12 +131,11 @@ the type. If not, the same result can always be achieved through the use of
`copy`, manual update, and `convert`.
"""
function Base.similar(layer::T, ::Type{TC}) where {TC <: Any, T <: SimpleSDMLayer}
emptygrid = convert(Matrix{Union{Nothing,TC}}, zeros(TC, size(layer)))
emptygrid = convert(Matrix{Union{Nothing, TC}}, zeros(TC, size(layer)))
emptygrid[findall(isnothing, layer.grid)] .= nothing
return SimpleSDMResponse(emptygrid, layer.left, layer.right, layer.bottom, layer.top)
end


"""
Base.similar(layer::T) where {T <: SimpleSDMLayer}

Expand All @@ -145,32 +157,47 @@ Returns a new copy of the layer, which has the same type.
function Base.copy(layer::T) where {T <: SimpleSDMLayer}
copygrid = copy(layer.grid)
RT = T <: SimpleSDMResponse ? SimpleSDMResponse : SimpleSDMPredictor
return RT(copygrid, copy(layer.left), copy(layer.right), copy(layer.bottom), copy(layer.top))
return RT(
copygrid,
copy(layer.left),
copy(layer.right),
copy(layer.bottom),
copy(layer.top),
)
end

"""
Base.collect(l::T) where {T <: SimpleSDMLayer}
Base.values(l::T) where {T <: SimpleSDMLayer}

Returns the non-`nothing` values of a layer.
"""
function Base.collect(l::T) where {T <: SimpleSDMLayer}
function Base.values(l::T) where {T <: SimpleSDMLayer}
v = filter(!isnothing, l.grid)
return convert(Vector{typeof(v[1])}, v)
return convert(Vector{typeof(v[1])}, v)
end

"""
Base.vcat(l1::T, l2::T) where {T <: SimpleSDMLayers}

Adds the second layer *under* the first one (according to coordinates),
Adds the second layer *under* the first one (according to coordinates),
assuming the strides and left/right coordinates match. This will automatically
re-order the layers if the second is above the first.
"""
function Base.vcat(l1::T, l2::T) where {T <: SimpleSDMLayer}
(l1.left == l2.left) || throw(ArgumentError("The two layers passed to vcat must have the same left coordinate"))
(l1.right == l2.right) || throw(ArgumentError("The two layers passed to vcat must have the same right coordinate"))
all(stride(l1) .≈ stride(l2)) || throw(ArgumentError("The two layers passed to vcat must have the same stride"))
(l1.left == l2.left) || throw(
ArgumentError("The two layers passed to vcat must have the same left coordinate"),
)
(l1.right == l2.right) || throw(
ArgumentError("The two layers passed to vcat must have the same right coordinate"),
)
all(stride(l1) .≈ stride(l2)) ||
throw(ArgumentError("The two layers passed to vcat must have the same stride"))
(l1.top == l2.bottom) && return vcat(l2, l1)
(l2.top == l1.bottom) || throw(ArgumentError("The two layers passed to vcat must have contiguous bottom and top coordinates"))
(l2.top == l1.bottom) || throw(
ArgumentError(
"The two layers passed to vcat must have contiguous bottom and top coordinates",
),
)
new_grid = vcat(l2.grid, l1.grid)
RT = T <: SimpleSDMPredictor ? SimpleSDMPredictor : SimpleSDMResponse
return RT(new_grid, l1.left, l1.right, l2.bottom, l1.top)
Expand All @@ -180,15 +207,24 @@ end
Base.hcat(l1::T, l2::T) where {T <: SimpleSDMLayers}

Adds the second layer *to the right of* the first one (according to coordinates),
assuming the strides and left/right coordinates match. This will automatically
assuming the strides and left/right coordinates match. This will automatically
re-order the layers if the second is to the left the first.
"""
function Base.hcat(l1::T, l2::T) where {T <: SimpleSDMLayer}
(l1.top == l2.top) || throw(ArgumentError("The two layers passed to hcat must have the same top coordinate"))
(l1.bottom == l2.bottom) || throw(ArgumentError("The two layers passed to hcat must have the same bottom coordinate"))
all(stride(l1) .≈ stride(l2)) || throw(ArgumentError("The two layers passed to hcat must have the same stride"))
(l1.top == l2.top) || throw(
ArgumentError("The two layers passed to hcat must have the same top coordinate"),
)
(l1.bottom == l2.bottom) || throw(
ArgumentError("The two layers passed to hcat must have the same bottom coordinate"),
)
all(stride(l1) .≈ stride(l2)) ||
throw(ArgumentError("The two layers passed to hcat must have the same stride"))
(l2.right == l1.left) && return hcat(l2, l1)
(l1.right == l2.left) || throw(ArgumentError("The two layers passed to hcat must have contiguous left and right coordinates"))
(l1.right == l2.left) || throw(
ArgumentError(
"The two layers passed to hcat must have contiguous left and right coordinates",
),
)
new_grid = hcat(l1.grid, l2.grid)
RT = T <: SimpleSDMPredictor ? SimpleSDMPredictor : SimpleSDMResponse
return RT(new_grid, l1.left, l2.right, l1.bottom, l1.top)
Expand All @@ -202,7 +238,11 @@ possible for `SimpleSDMResponse` elements (which are mutable) and will throw an
error if called on a `SimpleSDMPredictor` element (which is not mutable).
"""
function Base.replace!(layer::T, old_new::Pair...) where {T <: SimpleSDMLayer}
layer isa SimpleSDMResponse || throw(ArgumentError("`SimpleSDMPredictor` elements are immutable. Convert to a `SimpleSDMResponse` first or call `replace!` directly on the grid element."))
layer isa SimpleSDMResponse || throw(
ArgumentError(
"`SimpleSDMPredictor` elements are immutable. Convert to a `SimpleSDMResponse` first or call `replace!` directly on the grid element.",
),
)
replace!(layer.grid, old_new...)
return layer
end
Expand Down Expand Up @@ -236,25 +276,25 @@ end
Returns the quantiles of `layer` at `p`, using `Statistics.quantile`.
"""
function Statistics.quantile(layer::T, p) where {T <: SimpleSDMLayer}
return quantile(collect(layer), p)
return quantile(values(layer), p)
end

"""
==(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)

Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
verified with `==` (e.g., `layer1.grid == layer2.grid`).
"""
function Base.:(==)(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
return all(
[
layer1.grid == layer2.grid,
layer1.left == layer2.left,
layer1.right == layer2.right,
layer1.bottom == layer2.bottom,
layer1.top == layer2.top,
]
layer1.grid == layer2.grid,
layer1.left == layer2.left,
layer1.right == layer2.right,
layer1.bottom == layer2.bottom,
layer1.top == layer2.top,
]
)
end

Expand All @@ -265,45 +305,45 @@ end
"""
isequal(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)

Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
Tests whether two `SimpleSDMLayer` elements are equal. The layers are equal if
all their fields (`grid`, `left`, `right`, `bottom`, `top`) are equal, as
verified with `isequal` (e.g., `isequal(layer1.grid, layer2.grid)`).
"""
function Base.isequal(layer1::SimpleSDMLayer, layer2::SimpleSDMLayer)
return all(
[
isequal(layer1.grid, layer2.grid),
isequal(layer1.left, layer2.left),
isequal(layer1.right, layer2.right),
isequal(layer1.bottom, layer2.bottom),
isequal(layer1.top, layer2.top),
]
isequal(layer1.grid, layer2.grid),
isequal(layer1.left, layer2.left),
isequal(layer1.right, layer2.right),
isequal(layer1.bottom, layer2.bottom),
isequal(layer1.top, layer2.top),
]
)
end

Base.:*(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n*x, layer)
Base.:*(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x*n, layer)
Base.:/(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n/x, layer)
Base.:/(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x/n, layer)
Base.:-(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n-x, layer)
Base.:-(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x-n, layer)
Base.:+(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n+x, layer)
Base.:+(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x+n, layer)
Base.://(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n//x, layer)
Base.://(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x//n, layer)
Base.:%(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n%x, layer)
Base.:%(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x%n, layer)
Base.:*(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n * x, layer)
Base.:*(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x * n, layer)
Base.:/(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n / x, layer)
Base.:/(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x / n, layer)
Base.:-(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n - x, layer)
Base.:-(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x - n, layer)
Base.:+(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n + x, layer)
Base.:+(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x + n, layer)
Base.://(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n // x, layer)
Base.://(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x // n, layer)
Base.:%(n::Number, layer::T) where {T <: SimpleSDMLayer} = broadcast(x -> n % x, layer)
Base.:%(layer::T, n::Number) where {T <: SimpleSDMLayer} = broadcast(x -> x % n, layer)

function Base.findmax(layer::T) where {T <: SimpleSDMLayer}
val, pos = findmax(collect(layer))
val, pos = findmax(values(layer))
return (val, keys(layer)[pos])
end

function Base.findmin(layer::T) where {T <: SimpleSDMLayer}
val, pos = findmin(collect(layer))
val, pos = findmin(values(layer))
return (val, keys(layer)[pos])
end

function Base.findall(f::Function, layer::T) where {T <: SimpleSDMLayer}
return keys(layer)[findall(f, collect(layer))]
end
return keys(layer)[findall(f, values(layer))]
end
Loading