Skip to content

Commit

Permalink
Add NotImplemented and NotImplementedRule
Browse files Browse the repository at this point in the history
Sometimes you just don't implement something. Maybe it's hard, maybe
you're lazy like me, whatever. For such cases, there is the differential
`NotImplemented` and its associated rule `NotImplementedRule`.

This was born of not being able to figure out the rule for the scalar
multiple parameter in `BLAS.gemm`, but having implemented rules for the
matrix parameters.
  • Loading branch information
ararslan committed May 1, 2019
1 parent c5fd246 commit 6abd7c9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ This way, we don't need to implement promotion/conversion rules between subtypes
of `AbstractDifferential` to resolve potential ambiguities.
=#

const PRECEDENCE_LIST = [:wirtinger, :casted, :zero, :dne, :one, :thunk, :fallback]
const PRECEDENCE_LIST = [:wirtinger, :casted, :zero, :dne, :notimplemented, :one,
:thunk, :fallback]

global defs = Expr(:block)

Expand Down Expand Up @@ -227,6 +228,33 @@ mul_dne(::DNE, ::DNE) = DNE()
mul_dne(::DNE, ::Any) = DNE()
mul_dne(::Any, ::DNE) = DNE()

#####
##### `NotImplemented`
#####

"""
NotImplemented <: AbstractDifferential
A differential type which behaves similar to [`DNE`](@ref) but instead signifies that
the actual differential is not implemented in ChainRules, not that it does not exist.
"""
struct NotImplemented <: AbstractDifferential end

extern(::NotImplemented) = error("`NotImplemented` cannot be converted to an external type.")

Base.Broadcast.broadcastable(::NotImplemented) = Ref(NotImplemented())

Base.iterate(x::NotImplemented) = (x, nothing)
Base.iterate(x::NotImplemented, ::Any) = nothing

add_notimplemented(::NotImplemented, ::NotImplemented) = NotImplemented()
add_notimplemented(::NotImplemented, b) = b
add_notimplemented(a, ::NotImplemented) = a

mul_notimplemented(::NotImplemented, ::NotImplemented) = NotImplemented()
mul_notimplemented(::NotImplemented, ::Any) = NotImplemented()
mul_notimplemented(::Any, ::NotImplemented) = NotImplemented()

#####
##### `One`
#####
Expand Down
14 changes: 14 additions & 0 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ struct DNERule <: AbstractRule end

DNERule(args...) = DNE()

#####
##### `NotImplementedRule`
#####

"""
NotImplementedRule <: AbstractRule
Rule indicating that a particular derivative is not implemented by ChainRules.
Note that this does not imply nondifferentiability; for that, use [`DNERule`](@ref).
"""
struct NotImplementedRule <: AbstractRule end

NotImplementedRule(args...) = NotImplemented()

#####
##### `WirtingerRule`
#####
Expand Down

0 comments on commit 6abd7c9

Please sign in to comment.