From 1c8f88518fc91af0d4fff2f0157ba84674c082cf Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Tue, 18 Feb 2020 19:29:23 -0500 Subject: [PATCH] fix dict `x == x` to return missing if `x` contains it closes #34744 use `isequal` to compare keys in `ImmutableDict` --- base/abstractdict.jl | 10 ++++++---- base/dict.jl | 8 ++++---- test/dict.jl | 4 ++++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/base/abstractdict.jl b/base/abstractdict.jl index 845dcc54f8d66..f4acbf7b4c301 100644 --- a/base/abstractdict.jl +++ b/base/abstractdict.jl @@ -17,9 +17,9 @@ const secret_table_token = :__c782dbf1cf4d6a2e5e3865d7e95634f2e09b5902__ haskey(d::AbstractDict, k) = in(k, keys(d)) function in(p::Pair, a::AbstractDict, valcmp=(==)) - v = get(a,p[1],secret_table_token) + v = get(a, p.first, secret_table_token) if v !== secret_table_token - return valcmp(v, p[2]) + return valcmp(v, p.second) end return false end @@ -474,14 +474,16 @@ function isequal(l::AbstractDict, r::AbstractDict) end function ==(l::AbstractDict, r::AbstractDict) - l === r && return true + if l === r + return any(ismissing, values(l)) ? missing : true + end if isa(l,IdDict) != isa(r,IdDict) return false end length(l) != length(r) && return false anymissing = false for pair in l - isin = in(pair, r, ==) + isin = in(pair, r) if ismissing(isin) anymissing = true elseif !isin diff --git a/base/dict.jl b/base/dict.jl index 872c9ca0e3188..950cd3eee22b9 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -746,7 +746,7 @@ ImmutableDict(KV::Pair, rest::Pair...) = ImmutableDict(ImmutableDict(rest...), K function in(key_value::Pair, dict::ImmutableDict, valcmp=(==)) key, value = key_value while isdefined(dict, :parent) - if dict.key == key + if isequal(dict.key, key) valcmp(value, dict.value) && return true end dict = dict.parent @@ -756,7 +756,7 @@ end function haskey(dict::ImmutableDict, key) while isdefined(dict, :parent) - dict.key == key && return true + isequal(dict.key, key) && return true dict = dict.parent end return false @@ -764,14 +764,14 @@ end function getindex(dict::ImmutableDict, key) while isdefined(dict, :parent) - dict.key == key && return dict.value + isequal(dict.key, key) && return dict.value dict = dict.parent end throw(KeyError(key)) end function get(dict::ImmutableDict, key, default) while isdefined(dict, :parent) - dict.key == key && return dict.value + isequal(dict.key, key) && return dict.value dict = dict.parent end return default diff --git a/test/dict.jl b/test/dict.jl index eaf258a1f04a2..3c1121c118c38 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -277,6 +277,8 @@ end @test ismissing(Dict(1=>missing) == Dict(1=>missing)) @test isequal(Dict(1=>missing), Dict(1=>missing)) + d = Dict(1=>missing) + @test ismissing(d == d) @test Dict(missing=>1) == Dict(missing=>1) @test isequal(Dict(missing=>1), Dict(missing=>1)) @@ -716,6 +718,8 @@ import Base.ImmutableDict d5 = ImmutableDict(v...) @test d5 == d2 @test collect(d5) == v + + @test !haskey(ImmutableDict(-0.0=>1), 0.0) end @testset "filtering" begin