# ステップ08

In [1]:
# Supertype of every function node in the computation graph (e.g. Square, Exp).
# Concrete subtypes are expected to be mutable with `input` and `output` fields
# (assigned by `call!`) and to provide `forward` and `backward` methods.
abstract type AbstractFunction end

In [2]:
"""
    Variable(data)

A node of the computation graph holding a value (`data`), its gradient (`grad`,
`nothing` until `backward!` fills it in) and the function that produced it
(`creator`, `nothing` for a leaf variable created directly by the user).
"""
mutable struct Variable
    data::Array{Float64}
    grad::Union{Array{Float64},Nothing}
    creator::Union{AbstractFunction,Nothing}

    function Variable(data)
        # Gradient and creator are attached later (by backward! / call!).
        return new(data, nothing, nothing)
    end
end

In [3]:
"""
    set_creator!(this::Variable, func)

Record `func` as the function node that produced `this`; returns `func`.
"""
set_creator!(this::Variable, func) = (this.creator = func)

set_creator! (generic function with 1 method)

In [4]:
"""
    call!(this::AbstractFunction, input::Variable)

Apply the function node `this` to `input`. `forward` runs on the raw data and the
result is wrapped in a new `Variable`. The graph is then linked in both
directions: the output records `this` as its creator, and `this` stores its input
and output for use by `backward`. Returns the new output `Variable`.
"""
function call!(this::AbstractFunction, input::Variable)
    result = Variable(forward(this, input.data))
    set_creator!(result, this)
    # Keep references so backward! can walk from output back to input.
    this.input = input
    this.output = result
    return result
end

call! (generic function with 1 method)

In [5]:
"""
    Square()

Function node for the element-wise square. `input` and `output` start as
`nothing` and are filled in when the node is applied with `call!`.
"""
mutable struct Square <: AbstractFunction
    input::Union{Variable,Nothing}
    output::Union{Variable,Nothing}

    function Square()
        return new(nothing, nothing)
    end
end

"""
    forward(this::Square, x::Array{Float64})

Element-wise square of `x`.
"""
forward(this::Square, x::Array{Float64}) = x .^ 2

"""
    backward(this::Square, gy::Array{Float64})

Chain rule for y = x^2: returns `2 .* x .* gy`, where `x` is the input data that
`call!` stored on the node.
"""
backward(this::Square, gy::Array{Float64}) = 2 .* this.input.data .* gy

backward (generic function with 1 method)

In [6]:
"""
    Exp()

Function node for the element-wise exponential. `input` and `output` start as
`nothing` and are filled in when the node is applied with `call!`.
"""
mutable struct Exp <: AbstractFunction
    input::Union{Variable,Nothing}
    output::Union{Variable,Nothing}

    function Exp()
        return new(nothing, nothing)
    end
end

"""
    forward(this::Exp, x::Array{Float64})

Element-wise exponential of `x`.
"""
forward(this::Exp, x::Array{Float64}) = exp.(x)

"""
    backward(this::Exp, gy::Array{Float64})

Chain rule for y = exp(x): returns `exp.(x) .* gy`, where `x` is the input data
that `call!` stored on the node.
"""
backward(this::Exp, gy::Array{Float64}) = exp.(this.input.data) .* gy

backward (generic function with 2 methods)

変更点はここだけ:backward! をループで実装する(関数をスタックに積み、入力側へ順にたどる)。

In [7]:
"""
    backward!(this::Variable)

Backpropagate from `this`, storing gradients in the `grad` field of the variables
feeding into it.

A stack of function nodes is processed, starting from `this.creator`. Each popped
node `f` passes its output's gradient to `backward(f, gy)` and stores the result
in its input's `grad`. If that input has a creator, the creator is pushed so the
walk continues. The walk stops at a leaf variable (`creator === nothing`).

- If `this.grad` is `nothing`, it is first seeded with ones of the same shape as
  `this.data` (dy/dy = 1) instead of failing inside `backward`.
- If `this` has no creator (a leaf), there is nothing to propagate and the
  function returns immediately instead of pushing `nothing` onto the stack.

Returns `nothing`.
"""
function backward!(this::Variable)
    if isnothing(this.grad)
        # Seed the output gradient: the derivative of y with respect to itself.
        this.grad = ones(Float64, size(this.data))
    end
    # A leaf variable was not produced by any function node.
    isnothing(this.creator) && return nothing

    funcs = AbstractFunction[this.creator]
    while !isempty(funcs)
        f = pop!(funcs)
        x, y = f.input, f.output
        x.grad = backward(f, y.grad)

        if !isnothing(x.creator)
            push!(funcs, x.creator)
        end
    end
    return nothing
end

backward! (generic function with 1 method)

In [8]:
# Function nodes for the chain y = C(B(A(x))) = (exp(x^2))^2.
A = Square();
B = Exp();
C = Square();

In [9]:
# Forward pass: x -> A (square) -> a -> B (exp) -> b -> C (square) -> y.
# Each call! links its output back to the node that created it.
x = Variable([.5])
a = call!(A, x)
b = call!(B, a)
y = call!(C, b)

Variable([1.648721270700128], nothing, Square(Variable([1.2840254166877414], nothing, Exp(Variable([0.25], nothing, Square(Variable([0.5], nothing, nothing), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#)))

In [10]:
# Seed the output gradient dy/dy = 1, then propagate it back to x.
y.grad = [1.]
backward!(y)

In [11]:
# dy/dx for y = exp(2x^2) is 4x*exp(2x^2), i.e. 2*exp(0.5) ~ 3.2974 at x = 0.5.
print(x.grad)

[3.297442541400256]