-
Notifications
You must be signed in to change notification settings - Fork 60
/
differential_arithmetic.jl
80 lines (63 loc) · 2.67 KB
/
differential_arithmetic.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#==
All differentials need to define + and *.
That happens here.
We just use @eval to define all the combinations for AbstractDifferential
subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.
Notice:
The precedence goes:
`Zero, DoesNotExist, One, AbstractThunk, Composite, Any`
Thus each of the @eval loops creating definitions of + and *
defines the combination of this type with all types of lower precedence.
This means each eval loop is 1 item smaller than the previous.
==#
# `Zero` is the additive identity and the absorbing element of multiplication:
# adding it returns the other operand unchanged, multiplying by it returns `Zero()`.
Base.:+(::Zero, ::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
# `Zero` has the highest precedence, so it is paired with every other type.
# `:Composite` must be listed explicitly: without a dedicated method,
# `Zero() * comp` matches both `*(::Zero, ::Any)` and `*(::Any, ::Composite)`
# equally well and raises a method ambiguity error.
for T in (:DoesNotExist, :One, :AbstractThunk, :Composite, :Any)
    @eval Base.:+(::Zero, b::$T) = b
    @eval Base.:+(a::$T, ::Zero) = a
    @eval Base.:*(::Zero, ::$T) = Zero()
    @eval Base.:*(::$T, ::Zero) = Zero()
end
# `DoesNotExist` behaves like `Zero` against everything of lower precedence:
# addition returns the other operand, multiplication returns `DoesNotExist()`.
Base.:+(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
Base.:*(::DoesNotExist, ::DoesNotExist) = DoesNotExist()
# `:Composite` must be listed explicitly: without a dedicated method,
# `DoesNotExist() * comp` matches both `*(::DoesNotExist, ::Any)` and
# `*(::Any, ::Composite)` equally well and raises a method ambiguity error.
for T in (:One, :AbstractThunk, :Composite, :Any)
    @eval Base.:+(::DoesNotExist, b::$T) = b
    @eval Base.:+(a::$T, ::DoesNotExist) = a
    @eval Base.:*(::DoesNotExist, ::$T) = DoesNotExist()
    @eval Base.:*(::$T, ::DoesNotExist) = DoesNotExist()
end
# `One` is the multiplicative identity: multiplying by it returns the other
# operand unchanged. For addition it is first converted with `extern`.
Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
# `:Composite` must be listed explicitly: without a dedicated method,
# `One() * comp` matches both `*(::One, ::Any)` and `*(::Any, ::Composite)`
# equally well and raises a method ambiguity error.
for T in (:AbstractThunk, :Composite, :Any)
    @eval Base.:+(a::One, b::$T) = extern(a) + b
    @eval Base.:+(a::$T, b::One) = a + extern(b)
    @eval Base.:*(::One, b::$T) = b
    @eval Base.:*(a::$T, ::One) = a
end
# Thunks are unwrapped with `unthunk` before the operation is applied, so the
# arithmetic is delegated to whatever value the thunk wraps.
Base.:+(a::AbstractThunk, b::AbstractThunk) = unthunk(a) + unthunk(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = unthunk(a) * unthunk(b)
# `:Composite` must be listed explicitly: without a dedicated method,
# `thunk * comp` matches both `*(::AbstractThunk, ::Any)` and
# `*(::Any, ::Composite)` equally well and raises a method ambiguity error.
for T in (:Composite, :Any)
    @eval Base.:+(a::AbstractThunk, b::$T) = unthunk(a) + b
    @eval Base.:+(a::$T, b::AbstractThunk) = a + unthunk(b)
    @eval Base.:*(a::AbstractThunk, b::$T) = unthunk(a) * b
    @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b)
end
################## Composite ##############################################################
# `Base.*(::Composite, ::Composite)` is deliberately left undefined: the product of
# two differentials is not meaningful in general. Only scaling a differential by a
# factor (generally a `Real`) is supported, and it is applied to every element via `map`.

# Left scaling: each element `x` becomes `s * x`.
function Base.:*(s::Any, comp::Composite)
    return map(comp) do x
        s * x
    end
end

# Right scaling: each element `x` becomes `x * s` (operand order is preserved).
function Base.:*(comp::Composite, s::Any)
    return map(comp) do x
        x * s
    end
end
# Add two composites for the same primal type `P`: their backing data is combined
# with `elementwise_add` and re-wrapped in a `Composite` typed on the result.
function Base.:+(a::Composite{P}, b::Composite{P}) where P
    summed = elementwise_add(backing(a), backing(b))
    T = typeof(summed)
    return Composite{P, T}(summed)
end
# Add a composite differential `d` to a primal value `a` of type `P`.
# The backing data of both is added elementwise and a new `P` is built from the
# result with `construct`. Any error raised along the way is rethrown wrapped in a
# `PrimalAdditionFailedException` carrying both operands and the original error.
function Base.:+(a::P, d::Composite{P}) where P
    try
        summed = elementwise_add(backing(a), backing(d))
        return construct(P, summed)
    catch err
        throw(PrimalAdditionFailedException(a, d, err))
    end
end
# Composite on the left, primal on the right: swap the operands and re-dispatch
# on `b + a`.
function Base.:+(a::Composite{P}, b::P) where P
    return b + a
end