Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ jobs:
fail-fast: false
matrix:
version:
- "lts" # Long-term support version
- "1" # Latest Release
- "min" # Oldest supported Julia release
- "pre" # Pre-release/nightly
os:
- ubuntu-latest
- macOS-13 # Intel
Expand Down Expand Up @@ -70,7 +71,7 @@ jobs:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1.6'
version: '1.10'
- uses: julia-actions/cache@v2
- run: |
julia --project=docs -e '
Expand Down
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
BenchmarkTools = "0.5"
Compat = "3.40, 4"
FiniteDifferences = "0.10"
BenchmarkTools = "1"
Compat = "4"
FiniteDifferences = "0.12"
OffsetArrays = "1"
StaticArrays = "0.11, 0.12, 1"
julia = "1.6"
StaticArrays = "1"
julia = "1.10"

[extensions]
ChainRulesCoreSparseArraysExt = "SparseArrays"
Expand Down
46 changes: 21 additions & 25 deletions test/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,9 @@ struct NoSuperType end

@testset "Base: Tuple" begin
pt1 = ProjectTo((1.0,))
if VERSION >= v"1.6"
@test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0)
@test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
end
@test @inferred(pt1((1 + im,))) == Tangent{Tuple{Float64}}(1.0)
@test @inferred(pt1(pt1((1,)))) == pt1(pt1((1,))) # accepts correct Tangent
@test @inferred(pt1(Tangent{Any}(1))) == pt1((1,)) # accepts Tangent{Any}
@test pt1([1,]) == Tangent{Tuple{Float64}}(1.0,) # accepts Vector
@test @inferred(pt1(NoTangent())) === NoTangent()
@test @inferred(pt1(ZeroTangent())) === ZeroTangent()
Expand Down Expand Up @@ -240,25 +238,23 @@ struct NoSuperType end
@test padj_complex(adjoint([4, 5, 6 + 7im])) == [4 5 6 - 7im]

# evil test case
if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called
xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]])
pvecvec3 = ProjectTo(xs)
@test pvecvec3(xs)[1] == [1 2 3]
@test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im])
@test pvecvec3(xs)[2] isa LinearAlgebra.AdjOrTransAbsMat{ComplexF64,<:Vector}
@test pvecvec3(collect(xs))[1] == [1 2 3]
ys = permutedims([[1 2 3 + im], Any[4 5 6 7 + 8im]])
@test pvecvec3(ys)[1] == [1 2 3]
@test pvecvec3(ys)[2] == [4 5 6 7 + 8im]
@test pvecvec3(xs)[2] isa LinearAlgebra.AdjOrTransAbsMat{ComplexF64,<:Vector}
@test pvecvec3(ys) isa LinearAlgebra.AdjOrTransAbsVec

zs = adj([[1 2; 3 4], [5 6; 7 8+im]'])
pvecmat = ProjectTo(zs)
@test pvecmat(zs) == zs
@test pvecmat(collect.(zs)) == zs
@test pvecmat(collect.(zs)) isa LinearAlgebra.AdjOrTransAbsVec
end
xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]])
pvecvec3 = ProjectTo(xs)
@test pvecvec3(xs)[1] == [1 2 3]
@test pvecvec3(xs)[2] == adj.([4 + im 5 - im 6 + im 7 - im])
@test pvecvec3(xs)[2] isa LinearAlgebra.AdjOrTransAbsMat{ComplexF64,<:Vector}
@test pvecvec3(collect(xs))[1] == [1 2 3]
ys = permutedims([[1 2 3 + im], Any[4 5 6 7 + 8im]])
@test pvecvec3(ys)[1] == [1 2 3]
@test pvecvec3(ys)[2] == [4 5 6 7 + 8im]
@test pvecvec3(xs)[2] isa LinearAlgebra.AdjOrTransAbsMat{ComplexF64,<:Vector}
@test pvecvec3(ys) isa LinearAlgebra.AdjOrTransAbsVec

zs = adj([[1 2; 3 4], [5 6; 7 8+im]'])
pvecmat = ProjectTo(zs)
@test pvecmat(zs) == zs
@test pvecmat(collect.(zs)) == zs
@test pvecmat(collect.(zs)) isa LinearAlgebra.AdjOrTransAbsVec

# issue #410
@test padj([NoTangent() NoTangent() NoTangent()]) === NoTangent()
Expand Down Expand Up @@ -440,7 +436,7 @@ struct NoSuperType end
@test eval(Meta.parse(str))(ones(1, 3)) isa Adjoint{Float64,Vector{Float64}}
end

VERSION > v"1.1" && @testset "allocation tests" begin
@testset "allocation tests" begin
# For sure these fail on Julia 1.0, not sure about 1.3 etc.
# We only really care about current stable anyway
# Each "@test 33 > ..." is zero on nightly, 32 on 1.5.
Expand Down
11 changes: 2 additions & 9 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,9 @@ macro test_macro_throws(err_expr, expr)
err = nothing
try
@macroexpand($(esc(expr)))
catch _err
catch err
# https://github.com/JuliaLang/julia/pull/38379
if VERSION >= v"1.7.0-DEV.937"
err = _err
else
# until Julia v1.7
# all errors thrown at macro expansion time are LoadErrors, we need to unwrap
@assert _err isa LoadError
err = _err.error
end
# Since Julia 1.7, errors are not wrapped in LoadError
end
# Reuse `@test_throws` logic
if err !== nothing
Expand Down
22 changes: 5 additions & 17 deletions test/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,27 +142,16 @@ end

# Test getproperty is inferrable
_unpacknamedtuple = tangent -> (tangent.x, tangent.y)
if VERSION ≥ v"1.2"
@inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0))
@inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0))
end
@inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0))
@inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0))
end

@testset "reverse" begin
c = Tangent{Tuple{Int,Int,String}}(1, 2, "something")
cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1)
@test reverse(c) === cr

if VERSION < v"1.9-"
# can't reverse a named tuple or a dict
@test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0))

d = Dict(:x => 1, :y => 2.0)
cdict = Tangent{typeof(d),typeof(d)}(d)
@test_throws MethodError reverse(Tangent{Foo}())
else
# These now work but do we care?
end
# On Julia 1.9+ these work but we don't test them
end

@testset "unset properties" begin
Expand Down Expand Up @@ -440,11 +429,10 @@ end

@testset "Internals don't allocate a ton" begin
bk = (; x=1.0, y=2.0)
VERSION >= v"1.5" &&
@test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32
@test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 80

# weaker version of the above (which should pass on all versions)
@test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48
@test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 80
@test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48
end
end
Expand Down
30 changes: 12 additions & 18 deletions test/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,8 @@
m = rand(3, 3)
tm = @thunk(m)

if VERSION >= v"1.2"
@test 3 == mapreduce(_ -> 1, +, t)
@test 3 == mapreduce((_, _) -> 1, +, v, t)
end
@test 3 == mapreduce(_ -> 1, +, t)
@test 3 == mapreduce((_, _) -> 1, +, v, t)
@test 10 == sum(@thunk([1 2; 3 4]))
@test [4 6] == sum!([1 1], @thunk([1 2; 3 4]))

Expand Down Expand Up @@ -156,14 +154,12 @@
@test Symmetric(a) == Symmetric(t)
@test Hermitian(a) == Hermitian(t)

if VERSION >= v"1.2"
@test diagm(0 => v) == diagm(0 => tv)
@test diagm(3, 4, 0 => v) == diagm(3, 4, 0 => tv)
# Check against accidential type piracy
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472
@test Base.which(diagm, Tuple{}()).module != ChainRulesCore
@test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore
end
@test diagm(0 => v) == diagm(0 => tv)
@test diagm(3, 4, 0 => v) == diagm(3, 4, 0 => tv)
# Check against accidential type piracy
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/472
@test Base.which(diagm, Tuple{}()).module != ChainRulesCore
@test Base.which(diagm, Tuple{Int,Int}).module != ChainRulesCore
@test tril(a) == tril(t)
@test tril(a, 1) == tril(t, 1)
@test triu(a) == triu(t)
Expand All @@ -176,12 +172,10 @@
@test dot(v, v) == dot(tv, v)
@test dot(v, v) == dot(tv, tv)

if VERSION >= v"1.2"
@test_throws MutateThunkException ldiv!(2.0, deepcopy(t)) ==
ldiv!(2.0, deepcopy(a))
@test_throws MutateThunkException rdiv!(deepcopy(t), 2.0) ==
rdiv!(deepcopy(a), 2.0)
end
@test_throws MutateThunkException ldiv!(2.0, deepcopy(t)) ==
ldiv!(2.0, deepcopy(a))
@test_throws MutateThunkException rdiv!(deepcopy(t), 2.0) ==
rdiv!(deepcopy(a), 2.0)

@test mul!(deepcopy(a), a, a) == mul!(deepcopy(a), t, a)

Expand Down
Loading