## Understanding Autodiff

(Alan Edelman, 2017)


  Suppose you knew that sqrt(x) is the limit of the iteration $$ t \leftarrow (t+x/t)/2 $$ started from, say, t = 1 (the Babylonian method).
  How would you obtain the derivative of sqrt(x)?
  
  The entire Julia code needed to do this is in the three cells below.  Notice there is no cheating by calling any packages.
  
   We then use `@code_native` to show the assembler for the derivative, which is remarkably short.
   
   We contrast with what a symbolic calculator must do.  Think of all the memory being used because of the explosion of terms.

In [1]:
# Dual number: a function value paired with the value of its derivative.
# It subtypes Number so that the promotion rule defined for Dual can apply when a
# Dual is combined with an ordinary number.
struct Dual <: Number
    f::Float64   # function value
    f′::Float64  # derivative value carried alongside `f`
end

In [2]:
import Base: +, -, *, /, convert, promote_rule

# Differentiation rules, one per arithmetic operator: each method combines the values
# stored in `f` and propagates the derivatives stored in `f′`.
# (The Babylonian iteration itself only needs "+" and "/".)
function +(a::Dual, b::Dual)  # sum rule
    return Dual(a.f + b.f, a.f′ + b.f′)
end

function -(a::Dual, b::Dual)  # difference rule
    return Dual(a.f - b.f, a.f′ - b.f′)
end

function /(a::Dual, b::Dual)  # quotient rule
    value = a.f / b.f
    slope = (b.f * a.f′ - a.f * b.f′) / (b.f)^2
    return Dual(value, slope)
end

function *(a::Dual, b::Dual)  # product rule
    value = a.f * b.f
    slope = a.f * b.f′ + a.f′ * b.f
    return Dual(value, slope)
end

# A plain real number acts as a constant, so its derivative part is zero.
function convert(::Type{Dual}, r::Real)
    return Dual(r, 0.0)
end

# Combining a Dual with any other Number promotes both operands to Dual.
function promote_rule(::Type{Dual}, ::Type{<:Number})
    return Dual
end

promote_rule (generic function with 122 methods)

In [3]:
# Approximate sqrt(x) with N steps of the Babylonian iteration t ← (t + x/t)/2,
# starting from t = 1. Prints the iterate after every step and returns the last one.
# `x` can be any number that supports "+" and "/" together with Float64; the result
# has the type of `x / 1.0`.
function root(x,N=10) # Babylonian method
    # Seed `t` in the type the iterates will have. With a literal `t = 1.0` the
    # variable changed type after the first step whenever `x` was not a Float64,
    # and N = 0 returned a bare Float64 regardless of `x`.
    t = oftype(x / 1.0, 1.0)
    for _ in 1:N
        t = (t + x / t) / 2  # one add, and two divides
        println(t)
    end
    return t
end

root (generic function with 2 methods)

In [4]:
# Evaluation point; later cells read this global `x` again.
x = 100.0
# Left: `root` run on the dual number with value x and derivative 1.
# Right: reference values sqrt(x) and its exact derivative x^(-1/2)/2.
root(Dual(x,1.0)),  (sqrt(x), x^(-.5)/2)

Dual(50.5, 0.5)
Dual(26.24009900990099, 0.2500980296049407)
Dual(15.025530119986813, 0.12594242045311826)
Dual(10.840434673026925, 0.06835572803193413)
Dual(10.032578510960604, 0.05121765066161276)
Dual(10.000052895642693, 0.050003684003552644)
Dual(10.000000000139897, 0.05000000001878714)
Dual(10.0, 0.05)
Dual(10.0, 0.05)
Dual(10.0, 0.05)


(Dual(10.0, 0.05), (10.0, 0.05))

## The assembler

In [5]:
# NOTE(review): the listing produced here is for the one-argument method `root(x)`:
# it loads the default 10 and calls `root`, so the iteration loop itself is not shown.
# `@code_native root(Dual(x,1.0), 10)` would presumably display the loop — confirm.
@code_native(root(Dual(x,1.0)))

	.section	__TEXT,__text,regular,pure_instructions
Filename: In[3]
	pushq	%rbp
	movq	%rsp, %rbp
	pushq	%r15
	pushq	%r14
	pushq	%rbx
	subq	$40, %rsp
	movq	%rsi, %r15
	movq	%rdi, %rbx
	movabsq	$jl_get_ptls_states_fast, %rax
	callq	*%rax
	movq	%rax, %r14
	movq	$0, -32(%rbp)
	movq	$2, -48(%rbp)
	movq	(%r14), %rax
	movq	%rax, -40(%rbp)
	leaq	-48(%rbp), %rax
	movq	%rax, (%r14)
Source line: 2
	movabsq	$root, %rax
	leaq	-64(%rbp), %rdi
	movl	$10, %edx
	movq	%r15, %rsi
	callq	*%rax
	testb	%dl, %dl
	movl	$0, %ecx
	cmovsq	%rax, %rcx
	movq	%rcx, -32(%rbp)
	movb	%dl, %cl
	jns	L109
	xorl	%ecx, %ecx
L109:
	andb	$127, %cl
	cmpb	$1, %cl
	je	L138
	cmpb	$2, %cl
	jne	L144
	movq	(%rax), %rcx
	movq	8(%rax), %rsi
	movq	%rsi, 8(%rbx)
	movq	%rcx, (%rbx)
	jmp	L144
L138:
	movq	(%rax), %rcx
	movq	%rcx, (%rbx)
L144:
	testb	%dl, %dl
	cmovsq	%rax, %rbx
	movq	-40(%rbp), %rax
	movq	%rax, (%r14)
	movq	%rbx, %rax
	addq	$40, %rsp
	popq	%rbx
	popq	%r14
	popq	%r15
	popq	%rbp
	retq
	nopl	(%rax,%rax)


## Symbolically

In [6]:
#Pkg.add("SymPy")
using SymPy                    



In [7]:
# Symbolic counterpart of `root`: runs N Babylonian steps on the symbol "x" and,
# after each step, displays the step number, the simplified derivative of the
# current iterate, and that derivative evaluated at the numeric argument `x`.
function symroot(x,N=5) # Babylonian method
    xx = symbols("x")
    t = 1
    for step in 1:N
        t = (t + xx/t) / 2
        slope = diff(t, xx)  # derivative of the current iterate, reused below
        display(step)

        display(simplify(slope))
        display(subs(slope, xx, x))
    end
    return nothing
end

symroot (generic function with 2 methods)

In [17]:
# Show the symbolic derivatives and their values at x = 100.0 for the default 5 steps.
symroot(100.0);

1

1/2

1/2

2

1      1    
- + --------
4          2
    (x + 1) 

0.250098029604941

3

 6       5        4        3        2             
x  + 14*x  + 147*x  + 340*x  + 375*x  + 126*x + 21
--------------------------------------------------
  / 6       5       4        3       2           \
8*\x  + 14*x  + 63*x  + 100*x  + 63*x  + 14*x + 1/

0.125942420453118

4

 14       13         12          11           10            9            8    
x   + 70*x   + 3199*x   + 52364*x   + 438945*x   + 2014506*x  + 5430215*x  + 8
------------------------------------------------------------------------------
      / 14       13         12          11           10           9           
   16*\x   + 70*x   + 1771*x   + 20540*x   + 126009*x   + 440986*x  + 920795*x

        7            6            5            4           3          2       
836200*x  + 8842635*x  + 5425210*x  + 2017509*x  + 437580*x  + 52819*x  + 3094
------------------------------------------------------------------------------
8            7           6           5           4          3         2       
  + 1173960*x  + 920795*x  + 440986*x  + 126009*x  + 20540*x  + 1771*x  + 70*x

       
*x + 85
-------
    \  
 + 1/  

0.0683557280319341

5

 30        29          28            27              26               25      
x   + 310*x   + 59799*x   + 4851004*x   + 215176549*x   + 5809257090*x   + 102
------------------------------------------------------------------------------
                     / 30        29          28            27             26  
                  32*\x   + 310*x   + 36611*x   + 2161196*x   + 73961629*x   +

           24                  23                   22                   21   
632077611*x   + 1246240871640*x   + 10776333438765*x   + 68124037776390*x   + 
------------------------------------------------------------------------------
             25                24                 23                  22      
 1603620018*x   + 23367042639*x   + 238538538360*x   + 1758637118685*x   + 957

                 20                     19                     18             
321156247784955*x   + 1146261110726340*x   + 3133113888931089*x   + 6614351291
--------------------------------------------------

0.0512176506616128

## Conclusion.  If you can first see that the function droot below is a "derivative" of the function root, and then realize you don't have to type it by hand, then you are starting to get it!

In [9]:
# Approximate sqrt(x) with N steps of the Babylonian iteration t ← (t + x/t)/2,
# starting from t = 1. Prints the iterate after every step and returns the last one.
# `x` can be any number that supports "+" and "/" together with Float64; the result
# has the type of `x / 1.0`.
function root(x,N=10) # Babylonian method
    # Seed `t` in the type the iterates will have. With a literal `t = 1.0` the
    # variable changed type after the first step whenever `x` was not a Float64,
    # and N = 0 returned a bare Float64 regardless of `x`.
    t = oftype(x / 1.0, 1.0)
    for _ in 1:N
        t = (t + x / t) / 2  # one add, and two divides
        println(t)
    end
    return t
end

root (generic function with 2 methods)

In [10]:
# Hand-written derivative of `root`: runs the same Babylonian iteration while
# updating `dt`, the derivative of the iterate `t` with respect to `x`, by hand.
# Prints "t dt" after every step and returns the final derivative estimate
# (previously the function fell off the end and returned `nothing`).
function droot(x,N=10) # Babylonian method
    dt = 0.0
    t = 1.0
    for _ in 1:N
        # `dt` is updated first because its formula needs the previous `t`.
        dt = (dt + (t*1 - x*dt)/t^2) / 2  # Note how x/t expands with the quotient rule
        t = (t + x/t) / 2
        println(t," ",dt)
    end
    return dt
end

droot (generic function with 2 methods)

In [11]:
# Seed: value 100 with derivative 1; the integer arguments are stored in Float64 fields.
root(Dual(100,1)) # Done with Julia's type system

Dual(50.5, 0.5)
Dual(26.24009900990099, 0.2500980296049407)
Dual(15.025530119986813, 0.12594242045311826)
Dual(10.840434673026925, 0.06835572803193413)
Dual(10.032578510960604, 0.05121765066161276)
Dual(10.000052895642693, 0.050003684003552644)
Dual(10.000000000139897, 0.05000000001878714)
Dual(10.0, 0.05)
Dual(10.0, 0.05)
Dual(10.0, 0.05)


Dual(10.0, 0.05)

In [12]:
# Prints the same value/derivative pairs as the Dual-based run of `root`.
droot(100) # Done the old fashioned way

50.5 0.5
26.24009900990099 0.2500980296049407
15.025530119986813 0.12594242045311826
10.840434673026925 0.06835572803193413
10.032578510960604 0.05121765066161276
10.000052895642693 0.050003684003552644
10.000000000139897 0.05000000001878714
10.0 0.05
10.0 0.05
10.0 0.05


In [13]:
# `x` is still the plain Float64 global 100.0 here, so this is ordinary arithmetic.
2 + 3*x

302.0

In [14]:
# The dual unit: value 0, derivative 1.
ϵ  = Dual(0,1)

Dual(0.0, 1.0)

In [15]:
# Product rule on ϵ: value 0*0 = 0, derivative 0*1 + 1*0 = 0, i.e. ϵ² = 0.
ϵ * ϵ 

Dual(0.0, 0.0)

In [18]:
# Render a Dual as "<value> + <derivative> ϵ".
function Base.show(io::IO, x::Dual)
    return print(io, x.f, " + ", x.f′, " ϵ")
end

In [19]:
# Same product as before, now printed through the custom `show` method for Dual.
ϵ * ϵ 

0.0 + 0.0 ϵ