In [1]:
using Flux

│ For performance reasons, it is recommended to upgrade to a driver that supports CUDA 11.2 or higher.
└ @ CUDA C:\Users\anhin\.julia\packages\CUDA\Ozu5O\src\initialization.jl:42


In [3]:
# let's try to incorporate https://github.com/FluxML/Zygote.jl/pull/412

In [2]:
# Shared pullback body, reused by the `get` and `get!` adjoints in this notebook.
"""
    ∇getdictkey(d::AbstractDict, k, ctx, Δ)

Accumulate the incoming cotangent `Δ` into entry `k` of the gradient container that
`Zygote.grad_mut(ctx, d)` returns for `d`. That container is read with
`get(grad, k, nothing)` and written with `grad[k] = ...`.

Return `(nothing, grad, nothing)`: no gradient for the first and third positional
arguments, and the updated gradient container for `d` in the second position.

`grad_mut` and `accum` are Zygote internals that are not exported, so they are
qualified with `Zygote.`. Unqualified, they raise `UndefVarError` when this is
called from `Main`. The names are resolved at call time, i.e. after `using Zygote`.
"""
function ∇getdictkey(d::AbstractDict, k, ctx, Δ)
    grad = Zygote.grad_mut(ctx, d)
    # `accum` merges Δ with any cotangent already stored for `k` (`nothing` if none).
    grad[k] = Zygote.accum(get(grad, k, nothing), Δ)
    return (nothing, grad, nothing)
end

∇getdictkey (generic function with 1 method)

In [6]:
using Zygote

In [8]:
import Zygote: @adjoint!

In [9]:
# Custom adjoint for `get!(f, d, k)`, adapted from Zygote PR #412.
# - If `k` is already present, `get!` never calls `f`. The default pullback routes
#   the cotangent into d's gradient entry for `k` via `∇getdictkey`.
# - If `k` is missing, `get!` calls `∇f`, which runs `f` under the current Zygote
#   context and replaces `back`. The new pullback takes (and removes) whatever was
#   accumulated for `k` in d's gradient container, calls `f`'s pullback, and returns
#   the removed entry as d's gradient.
# `grad_mut` is not exported by Zygote, hence the `Zygote.` qualification.
@adjoint! function get!(f::Function, d::AbstractDict, k)
    # Default pullback; replaced inside `∇f` if `get!` ends up calling `f`.
    back = Δ -> ∇getdictkey(d, k, __context__, Δ)

    function ∇f()
        res, fback = pullback(__context__, f)
        back = function (Δ)
            # NOTE(review): the entry is removed presumably because `k` did not exist
            # in `d` before this call — confirm this is the intended bookkeeping.
            Δd = get(Zygote.grad_mut(__context__, d), k, nothing)
            delete!(Zygote.grad_mut(__context__, d), k)
            # Result discarded. With a zero-argument `f` this is presumably an empty
            # tuple — confirm.
            fback(Δ)
            return (nothing, Δd, nothing)
        end
        return res
    end
    # `get!(∇f, d, k)` is evaluated before `back` is read, so any replacement made
    # inside `∇f` is already in place when `back` is returned.
    return get!(∇f, d, k), back
end

In [10]:
# Custom adjoint for `get(f, d, k)`, following the same pattern as the `get!` adjoint.
# - If `k` is present, `get` never calls `f`. The default pullback routes the
#   cotangent into d's gradient entry for `k` via `∇getdictkey`.
# - If `k` is missing, `get` calls `∇f`, which runs `f` under the current Zygote
#   context and replaces `back`. The new pullback calls `f`'s pullback (result
#   discarded) and reports no gradient for any argument.
@adjoint! function get(f::Function, d::AbstractDict, k)
    # Will be replaced if ∇f is called
    back = Δ -> ∇getdictkey(d, k, __context__, Δ)

    function ∇f()
        res,fback = pullback(__context__,f)
        back = function(Δ)
                fback(Δ) # Result discarded; presumably an empty tuple since f takes no args — confirm
                return (nothing, nothing, nothing)
            end
        return res
    end
    # `get(∇f, d, k)` is evaluated before `back` is read, so any replacement made
    # inside `∇f` (key missing) is already in place when `back` is returned.
    return get(∇f, d, k), back
end

In [11]:
# Define ReLU.
#
# Named `f` rather than `relu` on purpose: a global `relu` would shadow Flux's
# built-in `relu`.
# `zero(x)` replaces the literal `0` so the result keeps x's own type: for example,
# `f(0x03) === 0x03`, whereas `max(0, 0x03)` promotes to `Int`.
f(x) = max(zero(x), x)

f (generic function with 1 method)

In [12]:
# Small test dict with mixed key types (Symbol, String, Int) and Float32 values.
# The element types are written out rather than left to inference; inference on
# these pairs yields the same `Dict{Any, Float32}`.
test_dict = Dict{Any, Float32}(:x => 0f0, "y" => 4f0, 8 => -3f0)

Dict{Any, Float32} with 3 entries:
  "y" => 4.0
  8   => -3.0
  :x  => 0.0

In [13]:
# we did modify it, and it is rather horrible
# let's perhaps try to encapsulate this awful behavior

"""
    my_map(my_f, my_dict::AbstractDict{K, V}) where {K, V}

Return a new `Dict{K, V}` with the same keys as `my_dict`, each value replaced by
`my_f(value)`. `my_dict` itself is not modified.

Each result is converted to `V` on insertion, so `my_f` must return something
convertible to the original value type (e.g. `Float32` for a `Dict{Any, Float32}`).
"""
function my_map(my_f, my_dict::AbstractDict{K, V}) where {K, V}
    new_dict = Dict{K, V}()
    # Read through `my_dict[k]` (the dict `getindex`) rather than iterating pairs,
    # so the same dict access pattern is used when differentiating.
    for k in keys(my_dict)
        new_dict[k] = my_f(my_dict[k])
    end
    return new_dict
end

my_map (generic function with 1 method)

In [14]:
# Apply `my_map` with the ReLU `f` to `test_dict`; `my_map` returns a new dict.
my_map(f, test_dict)

Dict{Any, Float32} with 3 entries:
  :x  => 0.0
  8   => 0.0
  "y" => 4.0

In [15]:
# Scalar objective: sum of the mapped values of `test_dict`.
my_map(f, test_dict) |> values |> sum

4.0f0

In [16]:
# Collect parameters to differentiate against.
# NOTE(review): this yields an empty `Params([])` (see output), presumably because
# `test_dict` holds scalar Float32 values rather than arrays, so the `gradient`
# calls below have nothing to track through `p` — confirm.
p = params(test_dict)

Params([])

In [17]:
# Gradient of the scalar objective with respect to `p`.
# As recorded in the output, this currently fails with
# `MethodError: no method matching getindex(::Dict{Any, Any})`.
grads = gradient(p) do
    sum(values(my_map(f, test_dict)))
end

LoadError: MethodError: no method matching getindex(::Dict{Any, Any})
[0mClosest candidates are:
[0m  getindex(::Dict{K, V}, [91m::Any[39m) where {K, V} at dict.jl:480
[0m  getindex(::AbstractDict, [91m::Any[39m) at abstractdict.jl:494
[0m  getindex(::AbstractDict, [91m::Any[39m, [91m::Any[39m, [91m::Any...[39m) at abstractdict.jl:504

In [18]:
"""
    my_map2(my_f, my_dict::Dict)

Return a copy of `my_dict` with `my_f` applied to every value and the keys unchanged.
The copy's values are overwritten in place with `map!` over its `values` view;
`my_dict` itself is not modified.

Results are stored back into the copy, so `my_f` must return something convertible
to the dict's value type.
"""
function my_map2(my_f, my_dict::Dict)
    # A shallow `copy` is sufficient: every value is overwritten below and the keys
    # are only shared, never mutated. A `deepcopy` would clone every key for nothing.
    new_dict = copy(my_dict)
    map!(my_f, values(new_dict))
    return new_dict
end

my_map2 (generic function with 1 method)

In [19]:
# Apply `my_map2` with the ReLU `f` to `test_dict`; `my_map2` returns a modified copy.
my_map2(f, test_dict)

Dict{Any, Float32} with 3 entries:
  :x  => 0.0
  8   => 0.0
  "y" => 4.0

In [20]:
# Scalar objective built on `my_map2`: sum of the mapped values of `test_dict`.
my_map2(f, test_dict) |> values |> sum

4.0f0

In [21]:
# Gradient of the `my_map2`-based objective with respect to `p`.
# As recorded in the output, this currently fails with
# `MethodError: no method matching getindex(::Dict{Any, Any})`.
grads = gradient(p) do
    sum(values(my_map2(f, test_dict)))
end

LoadError: MethodError: no method matching getindex(::Dict{Any, Any})
[0mClosest candidates are:
[0m  getindex(::Dict{K, V}, [91m::Any[39m) where {K, V} at dict.jl:480
[0m  getindex(::AbstractDict, [91m::Any[39m) at abstractdict.jl:494
[0m  getindex(::AbstractDict, [91m::Any[39m, [91m::Any[39m, [91m::Any...[39m) at abstractdict.jl:504