In [3]:
import Base.+
import Base.-
import Base.*
import Base./
import Base.cos
import Base.sin

In [4]:
"""
    ADV(value, children, grad_value)

A node of the reverse-mode automatic-differentiation tape.

- `value`: the primal (ordinary) value carried by this node.
- `children`: a vector of `(weight, node)` tuples, one per recorded operation
  that consumed this node. `weight` is the local partial derivative of that
  result `node` with respect to this one. Nodes are created with `[]`.
- `grad_value`: the cached adjoint. It stays `nothing` until `gradient` fills it
  in or a caller seeds it directly on an output node.
"""
mutable struct ADV
    value::Any       # primal value (a number)
    children::Any    # list of (weight, dependent node) edges
    grad_value::Any  # `nothing` or the accumulated adjoint
end

# Operator overloads

In [5]:
# Addition. The local partial derivative of a sum with respect to each ADV
# operand is 1. Every ADV operand therefore appends the edge (1.0, result)
# to its `children`.
function Base.:+(a::ADV, b::ADV)
    result = ADV(a.value + b.value, [], nothing)
    push!(a.children, (1.0, result))
    push!(b.children, (1.0, result))
    return result
end

# Constant + node: only the ADV operand gets an edge.
function Base.:+(a::Number, b::ADV)
    result = ADV(a + b.value, [], nothing)
    push!(b.children, (1.0, result))
    return result
end

# Node + constant: only the ADV operand gets an edge.
function Base.:+(a::ADV, b::Number)
    result = ADV(a.value + b, [], nothing)
    push!(a.children, (1.0, result))
    return result
end

+ (generic function with 187 methods)

In [6]:
# Subtraction. The local partial derivative is +1 with respect to the left
# operand and -1 with respect to the right one. Each ADV operand records
# that weight, together with the result node, in its `children`.
function -(a::ADV, b::ADV)
    z = ADV(a.value - b.value, [], nothing)
    push!(a.children, (1.0, z))
    push!(b.children, (-1.0, z))
    return z
end
function -(a::Number, b::ADV)
    z = ADV(a - b.value, [], nothing)
    push!(b.children, (-1.0, z))
    return z
end
function -(a::ADV, b::Number)
    z = ADV(a.value - b, [], nothing)
    push!(a.children, (1.0, z))
    return z
end
# Unary negation `-a`, where d(-a)/da = -1. There was no one-argument method
# for ADV, so negating a node (e.g. `-sin(x)`) could not be recorded on the
# tape.
function -(a::ADV)
    z = ADV(-a.value, [], nothing)
    push!(a.children, (-1.0, z))
    return z
end

- (generic function with 195 methods)

In [7]:
# Multiplication. d(a*b)/da = b and d(a*b)/db = a. Each ADV operand stores
# the value of the *other* factor as its edge weight to the product node.
function Base.:*(a::ADV, b::ADV)
    prod_node = ADV(a.value * b.value, [], nothing)
    push!(a.children, (b.value, prod_node))
    push!(b.children, (a.value, prod_node))
    return prod_node
end

# Constant * node: the constant itself is the local derivative.
function Base.:*(a::Number, b::ADV)
    prod_node = ADV(a * b.value, [], nothing)
    push!(b.children, (a, prod_node))
    return prod_node
end

# Node * constant: the constant itself is the local derivative.
function Base.:*(a::ADV, b::Number)
    prod_node = ADV(a.value * b, [], nothing)
    push!(a.children, (b, prod_node))
    return prod_node
end

* (generic function with 367 methods)

In [8]:
# Division q = a / b, with local partials dq/da = 1/b and dq/db = -a/b^2.
# Each ADV operand appends (partial, q) to its `children`.
function Base.:/(a::ADV, b::ADV)
    quotient = ADV(a.value / b.value, [], nothing)
    push!(a.children, (1 / b.value, quotient))
    push!(b.children, (-a.value/(b.value)^2, quotient))
    return quotient
end

# Constant / node: only the denominator is differentiated.
function Base.:/(a::Number, b::ADV)
    quotient = ADV(a / b.value, [], nothing)
    push!(b.children, (-a/(b.value)^2, quotient))
    return quotient
end

# Node / constant: only the numerator is differentiated.
function Base.:/(a::ADV, b::Number)
    quotient = ADV(a.value / b, [], nothing)
    push!(a.children, (1 / b, quotient))
    return quotient
end

/ (generic function with 121 methods)

In [9]:
# Sine. d sin(u)/du = cos(u), stored on `a` as the edge weight to the result.
function Base.sin(a::ADV)
    node = ADV(sin(a.value), [], nothing)
    push!(a.children, (cos(a.value), node))
    return node
end

sin (generic function with 14 methods)

In [10]:
# Cosine. d cos(u)/du = -sin(u), stored on `a` as the edge weight to the result.
function Base.cos(a::ADV)
    node = ADV(cos(a.value), [], nothing)
    push!(a.children, (-sin(a.value), node))
    return node
end

cos (generic function with 14 methods)

# Gradient computation

In [11]:
"""
    gradient(Z)

Return the adjoint stored in `Z.grad_value`, computing and caching it first
if it is still `nothing`.

The adjoint of a node is the sum, over its `children` entries
`(weight, child)`, of `weight * gradient(child)`. Summation runs in the order
the edges were recorded, starting from `0`. A node whose `grad_value` is
already set (for example an output seeded by the caller) is used as-is. A
node with no children and no stored value gets `0`.

Every node resolved along the way keeps its cached `grad_value`. As a result:
- several inputs of one graph can be queried after a single seeding;
- a graph must be rebuilt before it can be used with a different seed.

An explicit work stack is used instead of recursion. The graph depth grows
with every recorded operation (e.g. several per iteration of an unrolled
recurrence), and one stack frame per level could overflow the call stack on
long inputs.
"""
function gradient(Z)
    isnothing(Z.grad_value) || return Z.grad_value
    # Iterative post-order traversal. Each entry is (node, expanded).
    # `expanded == true` means the node's unresolved children were already
    # pushed above it, so they are all resolved by the time it is popped again.
    stack = Tuple{Any,Bool}[(Z, false)]
    while !isempty(stack)
        node, expanded = pop!(stack)
        # A node reachable along several paths may be queued more than once.
        isnothing(node.grad_value) || continue
        if expanded
            s = 0
            for (weight, child) in node.children
                s += weight * child.grad_value
            end
            node.grad_value = s
        else
            push!(stack, (node, true))
            for (_, child) in node.children
                isnothing(child.grad_value) && push!(stack, (child, false))
            end
        end
    end
    return Z.grad_value
end

gradient (generic function with 1 method)

# Test case

In [12]:
x = ADV(0.5, [], nothing)
y = ADV(4.2, [], nothing)

# z = x*y + sin(x). Analytically dz/dx = y + cos(x) and dz/dy = x.
z = x * y + sin(x)

# Seed the output adjoint, then pull it back to both inputs (x first, then y).
z.grad_value = 1.0
for input in (x, y)
    println(gradient(input))
end

5.077582561890373
0.5


In [13]:
# Analytic value of dz/dx = y + cos(x), to compare with gradient(x) above.
cos(x.value) + y.value

5.077582561890373

# Project case

In [14]:
"""
    f(x)

Project function: a two-state recurrence driven by the entries of `x`.

Starting from `a = b = 1`, each entry `xi` of `x` (in linear, column-major
order) updates both states simultaneously:

    a ← 0.3 sin(a) + 0.4 b
    b ← 0.1 a + 0.3 cos(b) + xi

The final state is returned as the 2-element vector `[a; b]`. The entries of
`x` may be plain numbers or `ADV` nodes, in which case the operations are
recorded on the tape.
"""
function f(x)
    a = 1
    b = 1
    for xi in x
        # Both right-hand sides use the previous a and b. They are evaluated
        # left to right, so operations are recorded in the same order as
        # before.
        a, b = 0.3*sin(a) + 0.4*b, 0.1*a + 0.3*cos(b) + xi
    end
    return [a; b]
end

f (generic function with 1 method)

In [19]:
"""
    AD(f, x, d)

Reverse-mode derivatives of `f` at the point `x`, using the `ADV` tape.

For each output index `i` of `f`:
1. A fresh graph is built by calling `f` on `ADV` nodes that hold the
   entries of `x`, taken in linear-index order.
2. Output `i` is seeded with the adjoint `d[i]`.
3. `gradient` pulls that seed back to every input.

Returns `[sol, sol_2]`, where for each output `i`:
- `sol_2[i]` is the vector of input adjoints `d[i] * ∂f_i/∂x_k` for
  `k = 1:length(x)`;
- `sol[i]` is `sum(sol_2[i])`.

Only `d[1:length(f(x))]` is read. Extra entries of `d` are ignored.
"""
function AD(f, x, d)
    # Plain evaluation, used only to find out how many outputs `f` has.
    n = length(f(x))
    N = length(x)
    sol = []
    sol_2 = []
    for i = 1:n
        # A new graph per output. `gradient` caches adjoints inside the nodes,
        # so a graph seeded for one output cannot be reused for another.
        # The input nodes carry the actual values of `x`. They were
        # previously hard-coded to 1, which was only right for x = ones(...).
        X = [ADV(x[k], [], nothing) for k = 1:N]
        outputs = f(X)
        # Seed only output i. The other outputs are deliberately not forced
        # to zero: their adjoints follow from the graph (zero unless they feed
        # into output i). Forcing 0.0 would cut such a dependency.
        outputs[i].grad_value = d[i]
        tmp = [gradient(X[k]) for k = 1:N]
        push!(sol, sum(tmp))
        push!(sol_2, tmp)
    end
    return [sol, sol_2]
end

AD (generic function with 1 method)

In [20]:
# Differentiate the project function on a 2020×1 input of ones.
# AD reads only the first length(f(x)) entries of the seed vector `d`.
x = ones(2020, 1)
d = [1, 1, 1]
@time sol = AD(f, x, d)
println("The result is:")
println(first(sol))

  0.983487 seconds (1.11 M allocations: 48.297 MiB, 3.17% gc time)
The result is:
Any[0.4285622529594699, 0.8164165952034328]


In [23]:
# Per-input adjoints for output 1 (sol[2] holds one vector per output of f).
sol[2][1] |> println

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

 -6.817341476622069e-286, 1.971086039803325e-285, -5.698966656768713e-285, 1.647732280535163e-284, -4.76405957752526e-284, 1.3774242288218502e-283, -3.9825226264929285e-283, 1.1514598145332506e-282, -3.3291956602203174e-282, 9.625645293164278e-282, -2.783046019700869e-281, 8.046572371904524e-281, -2.3264914226339837e-280, 6.726544035679144e-280, -1.944833934212578e-279, 5.6230644021687226e-279, -1.6257867941685673e-278, 4.700608975905515e-278, -1.3590788671440352e-277, 3.929481002536897e-277, -1.1361239823959357e-276, 3.2848554364860554e-276, -9.497445178348165e-276, 2.745979744308657e-275, -7.939403296945269e-275, 2.295505815080775e-274, -6.636955889489912e-274, 1.9189314699028463e-273, -5.54817305930069e-273, 1.6041335909462395e-272, -4.638003447438276e-272, 1.340977839991523e-271, -3.8771458187284196e-271, 1.1209924020652192e-270, -3.241105762434498e-270, 9.37095250952017e-270, -2.7094071397942256e-269, 7.833661563976712e-269, -2.2649328924255633e-268, 6.548560931941835e-268, -1.893

, 8.125484511000639e-127, -2.3493071516504325e-126, 6.792510754680279e-126, -1.963906776516352e-125, 5.678209378158723e-125, -1.641730764807523e-124, 4.746707499873117e-124, -1.3724072529025418e-123, 3.968017131601077e-123, -1.1472658661180188e-122, 3.3170697703830567e-122, -9.590585919564466e-122, 2.7729093642167492e-121, -8.01726443686299e-121, 2.318017670539686e-120, -6.702044024178546e-120, 1.9377502886582753e-119, -5.602583581439489e-119, 1.6198652102255193e-118, -4.683488003626256e-118, 1.3541287103179017e-117, -3.915168700522117e-117, 1.1319858915001606e-116, -3.272891046519101e-116, 9.462852746505852e-116, -2.735978094879116e-115, 7.910485702554458e-115, -2.287144921496741e-114, 6.612782183831923e-114, -1.9119421685864844e-113, 5.527965014421847e-113, -1.5982908742078872e-112, 4.62111050253352e-112, -1.3360936123213243e-111, 3.863024136533777e-111, -1.1169094247444392e-110, 3.2293007213874864e-110, -9.336820800419452e-110, 2.699538698329534e-109, -7.805129111439816e-109, 2.2566

In [24]:
# Per-input adjoints for output 2 (sol[2] holds one vector per output of f).
sol[2][2] |> println

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

-1.3678360813618804e-269, 3.954800583126418e-269, -1.1434445885302801e-268, 3.306021377203249e-268, -9.558641893240666e-268, 2.7636734436517327e-267, -7.990560780969224e-267, 2.310296889128754e-266, -6.679721063670558e-266, 1.9312960900566534e-265, -5.583922669696788e-265, 1.6144698133903983e-264, -4.667888386947098e-264, 1.3496184203802573e-263, -3.9021281779639605e-263, 1.1282155079781893e-262, -3.261989802463737e-262, 9.431334169874851e-262, -2.7268651838416586e-261, 7.884137701958314e-261, -2.2795269700817725e-260, 6.59075653389401e-260, -1.9055739308716473e-259, 5.50955263988821e-259, -1.5929673365028715e-258, 4.605718650900374e-258, -1.3316433931295067e-257, 3.8501572954718096e-257, -1.1131892574510862e-256, 3.218544666634578e-256, -9.30572200709282e-256, 2.690547158503159e-255, -7.779132029316828e-255, 2.2491668632639542e-254, -6.502976887061371e-254, 1.8801943548237103e-253, -5.436173114723216e-253, 1.571751242493776e-252, -4.544376928670547e-252, 1.3139077680681352e-251, -3.79

-5.48948526169932e-99, 1.587165290482259e-98, -4.588943296521235e-98, 1.3267931642061474e-97, -3.836133913249114e-97, 1.1091347014332619e-96, -3.206821799592654e-96, 9.271827886297565e-96, -2.6807474105177e-95, 7.750798189078848e-95, -2.2409747494885247e-94, 6.479291171551856e-94, -1.8733461452358746e-93, 5.416373005932489e-93, -1.5660264715925103e-92, 4.5278249985444126e-92, -1.3091221373928677e-91, 3.7850419819674064e-91, -1.0943625805180317e-90, 3.1641114243298967e-90, -9.148340123731279e-90, 2.6450436093645777e-89, -7.647568411907122e-89, 2.2111281042911224e-88, -6.392996087131414e-88, 1.8483957981941069e-87, -5.344234502879252e-87, 1.5451691922120343e-86, -4.4675207105515734e-86, 1.291686463957933e-85, -3.734630523497303e-85, 1.0797872036175953e-84, -3.1219698912804284e-84, 9.026497045776428e-84, -2.609815332659167e-83, 7.545713511388027e-83, -2.1816789744152233e-82, 6.307850333010923e-82, -1.82377775462712e-81, 5.273056784328701e-81, -1.524589702364093e-80, 4.4080195919335576e-80