In [1]:
using Flux
using Statistics

In [2]:
###Complex Weight Initialization
function c_uniform(in::Int64, out::Int64)
    return Flux.glorot_uniform(in,out)/sqrt(2), Flux.glorot_uniform(in,out)/sqrt(2)
end

struct CDense{F, M<:AbstractMatrix, M<:AbstractMatrix, B}
  weight_real::M
  weight_imag::M
  bias::B
  σ::F
  function CDense(weight_real::M,weight_imag::M,bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
    W = [weight_real -weight_imag; weight_imag weight_real]
    b = Flux.create_bias(W, bias, size(W,1))
    new{F,M,M,typeof(b)}(weight_real,weight_imag, b, σ)
  end
end

function CDense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
               init = c_uniform, bias = true)
  CDense(init(out, in)..., bias, σ)
end

function (a::CDense)(x::AbstractVecOrMat)
  Flux._size_check(a, x, 1 => size(a.weight_real, 2)+size(a.weight_imag, 2))
  σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
  xT = Flux._match_eltype(a, x)  # fixes Float64 input, etc.
  return σ.([a.weight_real -a.weight_imag; a.weight_imag a.weight_real] * xT .+ a.bias)
end

function (a::CDense)(x::AbstractArray)
  Flux._size_check(a, x, 1 => size(a.weight_real, 2)+size(a.weight_imag, 2))
  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
end

In [3]:
m=CDense(1=>1)

CDense{typeof(identity), Matrix{Float64}, Matrix{Float64}, Vector{Float64}}([-0.26438817115977825;;], [0.7523289166458874;;], [0.0, 0.0], identity)

In [5]:
m([1,1])

LoadError: UndefVarError: `_match_eltype` not defined