From 0d5df1bbb59c11a5028fd1c457f6dec2a2dafdc6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 14 Oct 2019 13:12:41 +0100 Subject: [PATCH] namedtuple getfield, for iteration Co-Authored-By: dhairyagandhi96 --- src/lib/lib.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 7588a0718..fcecda2f1 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -105,6 +105,9 @@ end @adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) +@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} = + (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) + @adjoint function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) first(xs), Δ -> ((Δ, drest...),)