In [17]:
"""
    ADV(a, b)

Forward-mode dual number used for automatic differentiation.

- `a`: the primal value.
- `b`: the derivative (tangent) carried along with `a`.

Both fields are untyped, so any values can be stored. The struct is mutable so
that the derivative seed `b` can be overwritten in place between passes.
"""
mutable struct ADV
    a::Any  # value
    b::Any  # derivative
end

# overload operations
import Base.+
import Base.*
import Base.-
import Base./
import Base.sin
import Base.cos

# Sum rule for two dual numbers: values add and derivatives add.
+(u::ADV, v::ADV) = ADV(u.a + v.a, u.b + v.b)
# Mixed scalar/dual addition: a constant shifts the value and leaves the
# derivative untouched. The mirrored `ADV + Number` method is defined as well,
# so expressions such as `x + 1.0` work, not only `1.0 + x`.
function +(A::Number, B::ADV)
    return ADV(A + B.a, B.b)
end
function +(A::ADV, B::Number)
    return ADV(A.a + B, A.b)
end
# Difference rule for two dual numbers: subtract values and derivatives.
-(u::ADV, v::ADV) = ADV(u.a - v.a, u.b - v.b)
# Product rule: (u v)' = u v' + u' v.
*(u::ADV, v::ADV) = ADV(u.a * v.a, u.a * v.b + u.b * v.a)
# Scaling a dual number by a constant scales both the value and the derivative.
# Both argument orders are defined so `x * 2.0` works as well as `2.0 * x`.
function *(A::Number, B::ADV)
    return ADV(A * B.a, A * B.b)
end
function *(A::ADV, B::Number)
    return ADV(A.a * B, A.b * B)
end
# Quotient rule: (u / v)' = (u' v - v' u) / v^2.
/(u::ADV, v::ADV) = ADV(u.a / v.a, (u.b * v.a - v.b * u.a) / (v.a)^2)
# Chain rule: sin(u)' = cos(u) * u'.
sin(u::ADV) = ADV(sin(u.a), cos(u.a) * u.b)
# Chain rule: cos(u)' = -sin(u) * u'.
cos(u::ADV) = ADV(cos(u.a), -sin(u.a) * u.b)

cos (generic function with 15 methods)

# AD_forward in special case

In [18]:
"""
    df(x, lambda)

Hand-coded forward-mode derivative of the recurrence

    a ← 0.3 sin(a) + 0.4 b
    b ← 0.1 a + 0.3 cos(b) + x[i]

started at `a = b = 1` and applied once per element of `x`. Tangents
`(da, db)` start at zero, and `lambda[i]` is the tangent (seed) of `x[i]`.

Returns `[a; b; da; db]`: the final state, followed by its directional
derivative with respect to `x` in the direction `lambda`.

`lambda` must have at least as many entries as `x`, because it is indexed with
the indices of `x`.
"""
function df(x, lambda)
    # Float initial values keep all four state variables Float64 throughout the
    # loop (the Int literals used before changed type on the first iteration).
    a, b = 1.0, 1.0
    da, db = 0.0, 0.0
    for i in eachindex(x)
        # The tangent update must use the *previous* a and b, so it runs before
        # the primal update. Each tuple right-hand side is evaluated in full
        # before any assignment happens.
        da, db = 0.3 * cos(a) * da + 0.4 * db,
                 0.1 * da + 0.3 * (-sin(b)) * db + lambda[i]
        a, b = 0.3 * sin(a) + 0.4 * b, 0.1 * a + 0.3 * cos(b) + x[i]
    end
    return [a; b; da; db]
end

df (generic function with 1 method)

In [19]:
"""
    f(x)

Iterate the two-variable recurrence

    a ← 0.3 sin(a) + 0.4 b
    b ← 0.1 a + 0.3 cos(b) + x[i]

starting from `a = b = 1`, once per element of `x`, and return `[a; b]`.
"""
function f(x)
    a, b = 1, 1
    for xi in x
        # Both new values are computed from the old (a, b) before assignment.
        a, b = 0.3 * sin(a) + 0.4 * b, 0.1 * a + 0.3 * cos(b) + xi
    end
    return [a; b]
end

f (generic function with 1 method)

In [20]:
# Time the hand-derived tangent recurrence on N unit inputs with unit seeds.
# The first @time of a freshly defined function also includes JIT compilation.
N = 2020
x = ones(1, N)
lambda = ones(1, N)
@time y = df(x, lambda);
println("The result is:")
println(y[3:4])  # the tangent part (da, db)

  0.066320 seconds (109.88 k allocations: 6.218 MiB)
The result is:
[0.42856225295947, 0.816416595203433]


It is worth noting that forward-mode AD gives one partial derivative per pass. To differentiate with respect to input i, we set the derivative part b of that input's ADV(a, b) to its seed d[i] and set b to zero for every other input.

# AD_forward general

In [21]:
"""
    AD_forward_normal(f, x, d)

Directional derivative `J(x) * d` of `f` at `x`, computed by forward-mode AD
one input at a time.

For each input `i`, only that input is seeded with `d[i]`. `f` is then
re-evaluated on `ADV` numbers, and the derivative parts of the outputs form
column `i` of `J` scaled by `d[i]`. Those columns are summed at the end. This
costs `length(x) + 1` evaluations of `f`.

# Arguments
- `f`: function taking an indexable collection of inputs and returning an
  indexable collection of outputs. It must accept both `x` and a vector of
  `ADV`, and every output must be an `ADV` when called on `ADV` inputs.
- `x`: the point at which to differentiate (the primal values).
- `d`: the direction/seed, with the same number of entries as `x`.

# Returns
An `n×1` matrix, with `n = length(f(x))`, whose `j`-th entry is
`∑ᵢ ∂fⱼ/∂xᵢ ⋅ d[i]`.

# Throws
`DimensionMismatch` if `x` and `d` have different lengths.
"""
function AD_forward_normal(f, x, d)
    N = length(d)
    length(x) == N ||
        throw(DimensionMismatch("x has $(length(x)) entries but d has $N"))
    n = length(f(x))  # plain evaluation, only to learn the number of outputs

    # Primal values come from x (previously hard-coded to 1, which silently
    # differentiated at the all-ones point); every derivative part starts at 0.
    X = [ADV(x[i], 0) for i in 1:N]
    A = zeros(n, N)
    for i in 1:N
        X[i].b = d[i]          # seed input i only
        y = f(X)
        for j in 1:n
            A[j, i] = y[j].b   # ∂f_j/∂x_i ⋅ d[i]
        end
        X[i].b = 0             # un-seed before the next input
    end
    return sum(A; dims=2)
end

AD_forward_normal (generic function with 1 method)

In [22]:
"""
    f(x)

Fold the two-variable recurrence

    a ← 0.3 sin(a) + 0.4 b
    b ← 0.1 a + 0.3 cos(b) + xᵢ

over the elements of `x`, starting from `(a, b) = (1, 1)`, and return `[a; b]`.
"""
function f(x)
    # One step of the recurrence; both components use the incoming state.
    advance((a, b), xi) = (0.3 * sin(a) + 0.4 * b, 0.1 * a + 0.3 * cos(b) + xi)
    a, b = foldl(advance, x; init=(1, 1))
    return [a; b]
end

f (generic function with 1 method)

In [23]:
# Column-by-column forward AD: one pass of f per input, 2020 inputs,
# unit direction. The first @time also includes JIT compilation.
x, d = ones(2020, 1), ones(2020, 1)
@time y = AD_forward_normal(f, x, d);
println("The result is:")
println(y)

 13.422898 seconds (155.15 M allocations: 2.865 GiB, 6.39% gc time)
The result is:
[0.4285622529594699; 0.8164165952034328]


# AD_forward_fast

In [24]:
"""
    AD_forward_fast(f, x, d)

Directional derivative `J(x) * d` of `f` at `x`, computed by forward-mode AD
in a single pass.

Every input `i` becomes `ADV(x[i], d[i])`, so one evaluation of `f` carries the
whole directional derivative.

# Arguments
- `f`: function taking a vector of `ADV` inputs and returning an indexable
  collection of `ADV` outputs.
- `x`: the point at which to differentiate (the primal values).
- `d`: the direction/seed, with the same number of entries as `x`.

# Returns
A vector whose `j`-th entry is the derivative part of output `j`, i.e.
`∑ᵢ ∂fⱼ/∂xᵢ ⋅ d[i]`.

# Throws
`DimensionMismatch` if `x` and `d` have different lengths.
"""
function AD_forward_fast(f, x, d)
    length(x) == length(d) ||
        throw(DimensionMismatch("x has $(length(x)) entries but d has $(length(d))"))
    # Seed all inputs at once. The primal value comes from x (previously it
    # was hard-coded to 1, which ignored x).
    X = [ADV(x[i], d[i]) for i in 1:length(x)]
    y = f(X)
    # A typed comprehension instead of push! into an untyped `[]`.
    return [yj.b for yj in y]
end

AD_forward_fast (generic function with 1 method)

In [25]:
"""
    f(x)

Run the recurrence `a ← 0.3 sin(a) + 0.4 b`, `b ← 0.1 a + 0.3 cos(b) + x[i]`
from `a = b = 1` for every `i`, and return the final `[a; b]`.
"""
function f(x)
    state = (1, 1)
    for i in 1:length(x)
        a, b = state
        state = (0.3 * sin(a) + 0.4 * b, 0.1 * a + 0.3 * cos(b) + x[i])
    end
    a, b = state
    return [a; b]
end

f (generic function with 1 method)

In [26]:
# Single-pass forward AD on N unit inputs with a unit direction.
# The first @time also includes JIT compilation.
N = 2020
x = ones(N, 1)
d = ones(N, 1)
@time Answer = AD_forward_fast(f, x, d);
println("The result is:")
print(Answer)

  0.081122 seconds (179.63 k allocations: 7.154 MiB)
The result is:
Any[0.42856225295947, 0.816416595203433]

# AD_forward with list

In [27]:
# trying a different data structure: plain 2-element vectors [value, derivative] instead of ADV
import Base.+
import Base.*
import Base.-
import Base./
import Base.sin
import Base.cos
# List-based dual numbers: a 2-element vector [value, derivative].
# Adding a scalar constant shifts the value and keeps the derivative.
# Restricted to `Vector` (was `Array`): with `Array`, `Float64 + Matrix` silently
# returned a 2-vector built from the matrix's first two linear entries.
# NOTE(review): this is still type piracy on Base.+ (neither Float64 nor Vector
# is owned here); the ADV struct is the piracy-free approach.
function +(A::Float64, B::Vector)
    return [A + B[1], B[2]]
end
# Quotient rule for list-based dual numbers [value, derivative]:
# (u / v)' = (u' v - v' u) / v^2.
# Restricted to `Vector` (was `Array`) so that ordinary matrix right-division
# `A / B` is no longer intercepted and answered with a 2-vector.
# NOTE(review): Base's `Vector / Vector` is still shadowed by this method, so it
# remains type piracy on Base./.
function /(A::Vector, B::Vector)
    return [A[1] / B[1], (A[2] * B[1] - B[2] * A[1]) / (B[1])^2]
end
# Chain rule for list-based dual numbers [value, derivative]:
# sin(u)' = cos(u) * u'.
# Restricted to `Vector` (was `Array`): `sin(::Array)` also matches matrices,
# where it can collide with LinearAlgebra's matrix sine for `AbstractMatrix`.
# NOTE(review): still type piracy on Base.sin.
function sin(A::Vector)
    return [sin(A[1]), cos(A[1]) * A[2]]
end
# Chain rule for list-based dual numbers [value, derivative]:
# cos(u)' = -sin(u) * u'.
# Restricted to `Vector` (was `Array`): `cos(::Array)` also matches matrices,
# where it can collide with LinearAlgebra's matrix cosine for `AbstractMatrix`.
# NOTE(review): still type piracy on Base.cos.
function cos(A::Vector)
    return [cos(A[1]), -sin(A[1]) * A[2]]
end

cos (generic function with 15 methods)

In [28]:
"""
    f(x)

Starting from `a = b = 1`, apply for each index `k` of `x`

    a ← 0.3 sin(a) + 0.4 b
    b ← 0.1 a + 0.3 cos(b) + x[k]

and return `[a; b]`.
"""
function f(x)
    a = 1
    b = 1
    for k in eachindex(x)
        a_next = 0.3 * sin(a) + 0.4 * b
        b_next = 0.1 * a + 0.3 * cos(b) + x[k]
        a, b = a_next, b_next
    end
    return [a; b]
end

f (generic function with 1 method)

In [29]:
"""
    AD_forward_list(f, x, d)

Forward-mode AD using plain 2-element vectors `[value, derivative]` as dual
numbers instead of `ADV`.

Each input is seeded in turn: input `i` gets derivative `d[i]` and all others
get 0. This costs `length(x) + 1` evaluations of `f`.

`f` is called on a vector of such pairs and must return their concatenation
`[o₁; o₂; …]`, where each output `oⱼ` is itself a 2-vector. Element `2j` of the
result is therefore the derivative part of output `j`.

NOTE(review): for the recurrence `f` in this notebook, an output that is still a
plain number after the loop (e.g. `a` when `length(x) == 1`) breaks this layout.

# Returns
`[y, Answer]`, where:
- `Answer[j, i] = ∂fⱼ/∂xᵢ ⋅ d[i]`;
- `y = sum(Answer; dims=2)` is the `n×1` directional derivative.

# Throws
`DimensionMismatch` if `x` and `d` have different lengths.
"""
function AD_forward_list(f, x, d)
    N = length(d)
    length(x) == N ||
        throw(DimensionMismatch("x has $(length(x)) entries but d has $N"))
    n = length(f(x))  # plain evaluation, only to learn the number of outputs

    # One [value, derivative] pair per input. Values are taken from x rather
    # than fixed at 1, and the pairs are Float64 so that any seed d[i] fits
    # (Int pairs could not store a non-integer seed).
    X = [Float64[x[i], 0] for i in 1:N]
    Answer = zeros(n, N)
    for i in 1:N
        X[i][2] = d[i]                # seed input i only, with its direction
        out = f(X)
        for j in 1:n
            Answer[j, i] = out[2j]    # derivative part of output j
        end
        X[i][2] = 0                   # un-seed before the next pass
    end
    y = sum(Answer; dims=2)
    return [y, Answer]
end

AD_forward_list (generic function with 1 method)

In [30]:
# List-based forward AD: one pass of f per input, 2020 inputs, unit direction.
# The first @time also includes JIT compilation.
x, d = ones(2020, 1), ones(2020, 1)
@time Answer = AD_forward_list(f, x, d);
println("The result is:")
println(Answer[1])  # the summed directional derivative

  4.910053 seconds (32.94 M allocations: 2.933 GiB, 9.39% gc time)
The result is:
[0.4285622529594699; 0.8164165952034328]


In [33]:
# Row 1 of the per-input matrix: each input's contribution to the derivative of output 1.
Answer[2][1, :] |> println

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

-1.1361239823959357e-276, 3.2848554364860546e-276, -9.49744517834817e-276, 2.7459797443086562e-275, -7.939403296945273e-275, 2.2955058150807738e-274, -6.636955889489916e-274, 1.918931469902845e-273, -5.548173059300694e-273, 1.6041335909462385e-272, -4.6380034474382804e-272, 1.3409778399915222e-271, -3.877145818728422e-271, 1.1209924020652188e-270, -3.2411057624345016e-270, 9.37095250952017e-270, -2.7094071397942287e-269, 7.833661563976711e-269, -2.2649328924255655e-268, 6.548560931941835e-268, -1.8933739901419317e-267, 5.474279164236104e-267, -1.5827687780660337e-266, 4.5762317369326576e-266, -1.3231178931706217e-265, 3.825507666274082e-265, -1.1060623531930873e-264, 3.1979387727709223e-264, -9.246144545890975e-264, 2.6733216311528587e-263, -7.729328162802504e-263, 2.234767158283464e-262, -6.461343271433254e-262, 1.8681569002187838e-261, -5.401369432367124e-261, 1.5616885145724756e-260, -4.5152827391011514e-260, 1.30549581582894e-259, -3.7745572616924224e-259, 1.0913311516627507e-258, 

, -2.226627492382067e-102, 6.437809197549416e-102, -1.8613525345084433e-101, 5.381696088617291e-101, -1.55600039502511e-100, 4.498836778398216e-100, -1.3007408239068633e-99, 3.760809236605323e-99, -1.0873562091544434e-98, 3.143854025123625e-98, -9.08977025903555e-98, 2.628109406637312e-97, -7.598606847175804e-97, 2.196971932559209e-96, -6.352066594918763e-96, 1.836561925567864e-95, -5.3100194335284854e-95, 1.535276648946493e-94, -4.4389185728897585e-94, 1.283416777833205e-93, -3.710720524354721e-93, 1.0728741473120102e-92, -3.101982292249512e-92, 8.968707249692094e-92, -2.5931066697834354e-91, 7.497404046171315e-91, -2.1677113430517257e-90, 6.267465964358354e-90, -1.8121014925966488e-89, 5.23929741183176e-89, -1.5148289139050866e-88, 4.3797983947393087e-88, -1.2663234635385876e-87, 3.6612989244278354e-87, -1.0585849659824998e-86, 3.0606682330946534e-86, -8.849256627223473e-86, 2.5585701212982675e-85, -7.397549121695756e-85, 2.138840464612259e-84, -6.183992092471032e-84, 1.7879668393949

In [34]:
# Row 2 of the per-input matrix: each input's contribution to the derivative of output 2.
Answer[2][2, :] |> println

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

 -3.4417144303372465e-286, 9.95096884891377e-286, -2.8771062514431012e-285, 8.318527077890066e-285, -2.4051212120123096e-284, 6.953884973033933e-284, -2.0105646225509092e-283, 5.8131103938718465e-283, -1.6807344599780575e-282, 4.859478202814961e-282, -1.4050124493753768e-281, 4.062287966959658e-281, -1.174522228172538e-280, 3.395875614164906e-280, -9.818435879943026e-280, 2.8387872254932142e-279, -8.207735947113222e-279, 2.373088365783743e-278, -6.861268963944743e-278, 1.9837867175267863e-277, -5.735687904549248e-277, 1.6583494308000543e-276, -4.794756758737885e-276, 1.3862996512364392e-275, -4.008183980378083e-275, 1.1588792369839128e-274, -3.350647306827801e-274, 9.687668064509678e-274, -2.8009785553040896e-273, 8.098420399037965e-273, -2.3414821522056976e-272, 6.769886470390092e-272, -1.9573654566957625e-271, 5.659296574356062e-271, -1.6362625388609989e-270, 4.730897313655066e-270, -1.367836081361882e-269, 3.954800583126422e-269, -1.1434445885302813e-268, 3.3060213772032525e-268, -9

, -1.124106222024508e-102, 3.2501086957255223e-102, -9.396982533369643e-102, 2.7169331551502563e-101, -7.855421400783512e-101, 2.271224276056739e-100, -6.566751099595609e-100, 1.8986332815627702e-99, -5.4894852616993295e-99, 1.5871652904822622e-98, -4.5889432965212436e-98, 1.32679316420615e-97, -3.8361339132491215e-97, 1.109134701433264e-96, -3.2068217995926595e-96, 9.271827886297582e-96, -2.6807474105177047e-95, 7.750798189078863e-95, -2.2409747494885287e-94, 6.479291171551868e-94, -1.8733461452358776e-93, 5.416373005932498e-93, -1.566026471592513e-92, 4.527824998544421e-92, -1.30912213739287e-91, 3.785041981967413e-91, -1.0943625805180335e-90, 3.164111424329902e-90, -9.148340123731295e-90, 2.6450436093645823e-89, -7.647568411907136e-89, 2.211128104291126e-88, -6.392996087131424e-88, 1.8483957981941098e-87, -5.344234502879261e-87, 1.5451691922120368e-86, -4.467520710551581e-86, 1.291686463957935e-85, -3.7346305234973086e-85, 1.079787203617597e-84, -3.121969891280433e-84, 9.02649704577