In [49]:
begin
    using StaticArrays
	import Base: +, *, >, <, ==, ^
	using Distributions , Random, Plots, StatsPlots
	plotly()
end

Plots.PlotlyBackend()

In [2]:
struct MultiDual{N,T}
    val::T
    derivs::SVector{N,T}
end

In [3]:
begin
	function Base.:+(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    	return MultiDual{N,T}(f.val + g.val, f.derivs + g.derivs)
	end

	function Base.:*(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    	return MultiDual{N,T}(f.val * g.val, f.val .* g.derivs + g.val .* f.derivs)
	end
	
	Base.:^(f::MultiDual{N, T}, n::Integer) where {N,T} = Base.power_by_squaring(f, n)
	

	function Base.:*(f::MultiDual{N,T}, α::Number) where {N,T}
    	return MultiDual{N,T}(f.val * α, f.derivs .* α)
	end

	Base.:*(α::Number, f::MultiDual{N, T}) where {N, T} = f*α	

end

In [4]:
x = MultiDual(3,SVector(1,0))
y = MultiDual(4,SVector(0,1))
x*x + y*y

MultiDual{2, Int64}(25, [6, 8])

In [5]:
a = SVector(1,2)
length(a)

2

# Show

In [16]:
function show(f::MultiDual{N,T}) where {N,T}
    print(f.val)
    print("  ")
    print(f.derivs)
end

show (generic function with 1 method)

In [17]:
show(MultiDual(3, SVector(1,0)))

3  [1, 0]

# Jacobian

# Log

In [6]:
function MultiDual_Log(f::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(log(f.val), (f*(1/f.val)).derivs)
end

MultiDual_Log (generic function with 1 method)

In [12]:
x = MultiDual(2.0, SVector(1.0,0.0))
MultiDual_Log(x)

MultiDual{2, Float64}(0.6931471805599453, [0.5, 0.0])

# Exp

In [21]:
function Exp(f::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(exp(f.val), (f*exp(f.val)).derivs)
end

Exp (generic function with 1 method)

In [22]:
x = MultiDual(2.0, SVector(1.0,0.0))
show(Exp(x))

7.38905609893065  [7.38905609893065, 0.0]

In [18]:
exp(2)

7.38905609893065

# Sin

In [24]:
function Sin(f::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(sin(f.val), (f*cos(f.val)).derivs)
end

Sin (generic function with 1 method)

In [25]:
x = MultiDual(2.0 , SVector(1.0,0.0,0.0))
Sin(x)

MultiDual{3, Float64}(0.9092974268256817, [-0.4161468365471424, -0.0, -0.0])

In [23]:
sin(2)

0.9092974268256817

# Cos

In [26]:
function Cos(f::MultiDual{N,T}) where {N,T}
    return MultiDual{N,T}(cos(f.val), (f*(-1)*sin(f.val)).derivs)
end

Cos (generic function with 1 method)

In [27]:
x = MultiDual(2.0, SVector(1.0,0.0))
show(Cos(x))

-0.4161468365471424  [-0.9092974268256817, -0.0]

# Comparison between MultiDuals

In [40]:
function Base.:>(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    if f.val>g.val
        return true
    end
    if g.val>=f.val
        return false
    end
end

In [43]:
function Base.:<(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    if f.val<g.val
        return true
    end
    if g.val<=f.val
        return false
    end
end

In [42]:
function Base.:(==)(f::MultiDual{N,T}, g::MultiDual{N,T}) where {N,T}
    if f.val==g.val && f.derivs==g.derivs
        return true
    else
        return false
    end
end

In [46]:
MultiDual(2.0, SVector(1.0,0.0)) == MultiDual(2.0, SVector(1.0,0.0))

true

In [34]:
>(MultiDual(2.0, SVector(1.0,0.0)), MultiDual(3.0, SVector(0.0,1.0)))

false

In [38]:
SVector(1,0)==SVector(1,0)

true

# Powers

In [50]:
function Base.:^(f::MultiDual{N,T}, a) where {N,T}
    return MultiDual(f.val^a, (f*a*(f.val^(a-1))).derivs)
end

In [51]:
MultiDual(3.0, SVector(1.0,0.0))^2

MultiDual{2, Float64}(9.0, [6.0, 0.0])