# ステップ07

In [1]:
# Abstract supertype of every function node (`Square`, `Exp`, ...) in the
# computation graph. Declared before `Variable` so that `Variable.creator`
# can be typed as `Union{AbstractFunction,Nothing}`.
abstract type AbstractFunction end

Variableに、自身を生成した関数（creator）を保持させるようにする。

In [2]:
"""
    Variable(data)

A value node of the computation graph.

Fields:
- `data`    — the wrapped value (stored as a `Float64` array).
- `grad`    — gradient with respect to `data`; `nothing` until a backward
  pass assigns it.
- `creator` — the function node whose `call!` produced this variable;
  `nothing` for variables built directly by the user.
"""
mutable struct Variable
    data::Array{Float64}
    grad::Union{Array{Float64},Nothing}
    creator::Union{AbstractFunction,Nothing}

    function Variable(data)
        return new(data, nothing, nothing)
    end
end

Variableのインスタンスにcreatorをセットする関数set_creator!を定義する。引数を変更する関数なので、Juliaの慣例に従って名前の末尾に「!」を付けることを忘れない。

In [3]:
"""
    set_creator!(this::Variable, func)

Record `func` as the function node that produced `this`. Mutates `this`
and returns `func`.
"""
set_creator!(this::Variable, func) = (this.creator = func)

set_creator! (generic function with 1 method)

In [4]:
"""
    call!(this::AbstractFunction, input::Variable)

Apply the function node `this` to `input`:
1. run `forward` on the raw data;
2. wrap the result in a new `Variable` whose `creator` is `this`;
3. remember both ends of the edge in `this.input` / `this.output`.

Mutates `this` (and the new output); returns the new output `Variable`.
"""
function call!(this::AbstractFunction, input::Variable)
    result = Variable(forward(this, input.data))
    set_creator!(result, this)           # output -> function link
    this.input, this.output = input, result  # function -> input/output links
    return result
end

call! (generic function with 1 method)

In [5]:
"""
    Square()

Function node for the element-wise square. Both `input` and `output` start
as `nothing` and are filled in by `call!`.
"""
mutable struct Square <: AbstractFunction
    input::Union{Variable,Nothing}   # argument of the most recent call!
    output::Union{Variable,Nothing}  # result of the most recent call!

    function Square()
        return new(nothing, nothing)
    end
end

"""
    forward(this::Square, x::Array{Float64})

Forward pass of `Square`: the element-wise square of `x`.
"""
forward(this::Square, x::Array{Float64}) = x .^ 2

"""
    backward(this::Square, gy::Array{Float64})

Backward pass of `Square`. Given the upstream gradient `gy`, return
`2 .* x .* gy`, where `x` is the data of the input recorded by the most
recent `call!` on `this`.

Throws an `ErrorException` if `this` has not been applied yet, i.e. its
`input` is still `nothing`.
"""
function backward(this::Square, gy::Array{Float64})
    input = this.input
    # `input` is only populated by `call!`; report that clearly instead of the
    # opaque "type Nothing has no field data".
    isnothing(input) && error("backward(::Square) called before call! set its input")
    x = input.data
    gx = 2 .* x .* gy
    return gx
end

backward (generic function with 1 method)

In [6]:
"""
    Exp()

Function node for the element-wise exponential. Both `input` and `output`
start as `nothing` and are filled in by `call!`.
"""
mutable struct Exp <: AbstractFunction
    input::Union{Variable,Nothing}   # argument of the most recent call!
    output::Union{Variable,Nothing}  # result of the most recent call!

    function Exp()
        return new(nothing, nothing)
    end
end

"""
    forward(this::Exp, x::Array{Float64})

Forward pass of `Exp`: the element-wise exponential of `x`.
"""
forward(this::Exp, x::Array{Float64}) = exp.(x)

"""
    backward(this::Exp, gy::Array{Float64})

Backward pass of `Exp`. Given the upstream gradient `gy`, return
`exp.(x) .* gy`, where `x` is the data of the input recorded by the most
recent `call!` on `this`.

Throws an `ErrorException` if `this` has not been applied yet, i.e. its
`input` is still `nothing`.
"""
function backward(this::Exp, gy::Array{Float64})
    input = this.input
    # `input` is only populated by `call!`; report that clearly instead of the
    # opaque "type Nothing has no field data".
    isnothing(input) && error("backward(::Exp) called before call! set its input")
    x = input.data
    gx = exp.(x) .* gy
    return gx
end

backward (generic function with 2 methods)

In [7]:
# The three function nodes of y = (exp(x^2))^2: A = square, B = exp, C = square.
A, B, C = Square(), Exp(), Square();

In [8]:
# Input x = 0.5, wrapped in a 1-element vector because `Variable.data` is an array.
x = Variable([0.5])

Variable([0.5], nothing, nothing)

In [9]:
# a = x^2; call! also links a.creator = A and A.input = x.
a = call!(A, x)

Variable([0.25], nothing, Square(Variable([0.5], nothing, nothing), Variable(#= circular reference @-2 =#)))

In [10]:
# b = exp(a); links b.creator = B and B.input = a.
b = call!(B, a)

Variable([1.2840254166877414], nothing, Exp(Variable([0.25], nothing, Square(Variable([0.5], nothing, nothing), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#)))

In [11]:
# y = b^2; the graph is now x -> A -> a -> B -> b -> C -> y.
y = call!(C, b)

Variable([1.648721270700128], nothing, Square(Variable([1.2840254166877414], nothing, Exp(Variable([0.25], nothing, Square(Variable([0.5], nothing, nothing), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#)))

計算グラフの繋がり方をあらかじめ知らなくても、変数から creator と input を辿っていけば逆向きに進めるようになっていることを確認する。

In [12]:
# Starting from y alone, following creator -> input links recovers every node
# of the graph. These are identity checks on mutable objects, so use `===`
# (egal) instead of relying on the `==` fallback.
@assert y.creator === C
@assert y.creator.input === b
@assert y.creator.input.creator === B
@assert y.creator.input.creator.input === a
@assert y.creator.input.creator.input.creator === A
@assert y.creator.input.creator.input.creator.input === x

In [13]:
# Seed the output gradient: dy/dy = 1.
y.grad = ones(1)

1-element Array{Float64,1}:
 1.0

In [14]:
# Hand-unrolled backward pass. Each line walks one edge back: take the
# creator, find its input, and write that input's gradient.
C = y.creator; b = C.input; b.grad = backward(C, y.grad)
B = b.creator; a = B.input; a.grad = backward(B, b.grad)
A = a.creator; x = A.input; x.grad = backward(A, a.grad)
x.grad

1-element Array{Float64,1}:
 3.297442541400256

再帰的に遡る backward! を実装する。経路上の各変数の grad を書き換える副作用があることに注意。

In [15]:
"""
    backward!(this::Variable)

Propagate gradients from `this` back through its computation graph. The
function follows `creator` -> `input` links recursively; each function node
has a single `input`.

Side effects:
- overwrites the `grad` field of every variable along the chain;
- prints each gradient before it is passed to the creator's `backward`.

If `this` has a creator but no gradient yet, `this.grad` is seeded with ones
(the derivative of a variable with respect to itself). Previously this case
reached `backward(f, nothing)` and failed with a MethodError.

Returns `nothing`.
"""
function backward!(this::Variable)
    f = this.creator
    if !isnothing(f)
        if isnothing(this.grad)
            this.grad = ones(size(this.data))
        end
        x = f.input
        println(this.grad)  # trace of the gradient being propagated
        x.grad = backward(f, this.grad)
        backward!(x)
    end
    return nothing
end

backward! (generic function with 1 method)

In [16]:
# Fresh function nodes for a second run of y = (exp(x^2))^2.
A, B, C = Square(), Exp(), Square();

In [17]:
# Rebuild the graph x -> A -> a -> B -> b -> C -> y with the fresh nodes.
x = Variable([.5])
a = call!(A, x)
b = call!(B, a)
y = call!(C, b)

Variable([1.648721270700128], nothing, Square(Variable([1.2840254166877414], nothing, Exp(Variable([0.25], nothing, Square(Variable([0.5], nothing, nothing), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#))), Variable(#= circular reference @-2 =#)))

In [18]:
# Seed dy/dy = 1, then let backward! walk the graph back to x.
y.grad = ones(1)
backward!(y)

[1.0]
[2.568050833375483]
[3.297442541400256]


In [19]:
# dy/dx filled in by backward!(y); matches the hand-unrolled result above.
print(x.grad)

[3.297442541400256]