In [4]:
using Random
import LinearAlgebra.qr

In [5]:
function limit_cycle(rng::AbstractRNG, dims...;σ=1.5f0, θ₁=π/21,θ₂=π/3.8)
    @assert isequal(dims[1], dims[2]) && iseven(dims[1]) && iseven(dims[2])
    W = zeros(Float32, dims...)
    rand_block = θ ->  [cos(θ) -sin(θ) ; sin(θ) cos(θ)]
    for i in range(1, dims[1], step=2)
        W[i:i+1, i:i+1] = Float32(σ)*rand_block( Float32(θ₁)+rand(rng,Float32)*Float32(θ₂ - θ₁))
    end
    return W
end

limit_cycle (generic function with 1 method)

In [8]:
limit_cycle(dims...;kwargs...) = limit_cycle(Random.GLOBAL_RNG, dims...; kwargs...)
limit_cycle(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> limit_cycle(rng, dims...; kwargs...)

function orthogonal_init(rng::AbstractRNG, dims...; σ=1.f0)
    Q,R = qr(randn(rng, Float32, dims...))
    return Float32(σ).*Q*sign.(Diagonal(R))
end

orthogonal_init(dims...;kwargs...) = orthogonal_init(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal_init(rng::AbstractRNG; kwargs...) = (dims...; kwargs...) -> orthogonal_init(rng, dims...; kwargs...)


orthogonal_init (generic function with 3 methods)

In [15]:
abstract type AbstractRNNDELayer <: Function end
using Flux
import Flux: functor, gate
import Flux: he_normal
"""
Vanilla RNN
"""
mutable struct ∂RNNCell{F,A,V} <:AbstractRNNDELayer
  σ::F
  Wᵢ::A
  Wᵣ::A
  b::V
end

∂RNNCell(in::Integer, out::Integer, σ = tanh;
        init = he_normal, initWᵣ=limit_cycle ) =
  ∂RNNCell(σ, init(out, in), initWᵣ(out, out),
          zeros(out))

function (m::∂RNNCell)(h, x)
  σ, Wᵢ, Wᵣ, b = m.σ, m.Wᵢ, m.Wᵣ, m.b
  ḣ = σ.(Wi*x .+ Wh*h .+ b).-h
  return ḣ
end

@Flux.functor ∂RNNCell

"""
GRU
"""
struct ∂GRUCell{A,V} <:AbstractRNNDELayer
  Wᵢ::A
  Wᵣ::A
  b::V
end

function ∂GRUCell(in, out; initWᵢ = he_normal, initWᵣ = limit_cycle, initWᵣᵤ = zeros, initb = zeros)
  ∂GRUCell(
  initWᵢ(out * 3, in),
  vcat( initWᵣᵤ(2*out,in), initWᵣ(out,in) ),
  initb(out * 3)
  )
end

∂GRUCell(in, out, init) = ∂GRUCell(in, out; initWᵢ = init, initWᵣ = init, initWᵣᵤ = init, initb = init)

function (m::∂GRUCell)(h, x)
  b, o = m.b, size(h, 1)
  gx, gh = m.Wi*x, m.Wh*h
  r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
  z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
  h̃ = tanh.(gate(gx, o, 3) .+ r.* gate(gh, o, 3) .+ gate(b, o, 3))
  ḣ =  (z .- 1).* (h .- h̃)
  return ḣ
end

@Flux.functor ∂GRUCell




In [13]:
∂GRUCell(1,10)

LoadError: UndefVarError: he_normal not defined