## Models

In [None]:
# Identity layer that carries a parameter `β` alongside the network: calling
# the layer returns its input unchanged, so `β` never affects the forward pass.
# This is used for residualization.
struct StorageLayer
    β
end

# Convenience constructor: `β` is whatever `init(Float32, n)` returns
# (with the default `init = ones`, a length-`n` Float32 vector of ones).
function StorageLayer(n::Integer; init = ones)
    return StorageLayer(init(Float32, n))
end

# Forward pass: pass the input through untouched.
function (layer::StorageLayer)(x)
    return x
end

@functor StorageLayer

In [None]:
# A layer that rescales its input by a stored parameter `γ`.
#
# The forward pass computes `x .* m.γ`, i.e. a broadcasted elementwise product.
# With the integer constructor, `γ` is `init(Float32, n)`; for `n == 1` and the
# default `init = ones` this is a one-element vector that acts as a single
# scalar multiplier.
#
# The field is parametric rather than `::Any` so that the type of `γ` is known
# to the compiler and the broadcast in the forward pass can be inferred.
struct ScalarLayer{T}
    γ::T
end

# Convenience constructor: build `γ` from `init(Float32, n)`.
ScalarLayer(n::Integer; init = ones) = ScalarLayer(init(Float32, n))

# Forward pass: broadcasted multiplication of the input by `γ`.
function (m::ScalarLayer)(x)
    return x .* m.γ
end

# NOTE(review): `@functor` is presumably the Functors.jl macro re-exported by
# Flux; it is not defined in this section of the file.
@functor ScalarLayer

In [None]:
"""
    autoencoder(rng, num_inputs, num_outputs)

Build a four-layer dense network `num_inputs -> 1024 -> 512 -> 256 -> num_outputs`.

The three hidden layers use `relu`; the output layer has no activation argument.
Every layer's `init` draws from `Flux.glorot_uniform` using the supplied `rng`,
and layers are constructed from input to output.
"""
function autoencoder(rng, num_inputs, num_outputs)
    glorot(dims...) = Flux.glorot_uniform(rng, dims...)

    # Consecutive pairs of widths define the relu-activated hidden layers.
    widths = (num_inputs, 1024, 512, 256)
    hidden = [
        Dense(nin, nout, relu, init = glorot) for
        (nin, nout) in zip(widths[1:end-1], widths[2:end])
    ]
    head = Dense(widths[end], num_outputs, init = glorot)

    return Chain(hidden..., head)
end;

In [None]:
"""
    build_model(; rng = Random.GLOBAL_RNG)

Build the network selected by the global `G.model` and append a
`StorageLayer(1)` to it.

The input width depends on the prefix of `G.model`:
- "autoencoder": `num_items(G.medium) * 3`
- "universal": twice the sum of `num_items(x)` over `ALL_MEDIUMS`

The output width is `num_items(G.medium)` in both cases. `rng` is forwarded to
`autoencoder`, which uses it for weight initialization.

Throws an `ErrorException` when `G.model` starts with neither prefix.
"""
function build_model(; rng = Random.GLOBAL_RNG)
    if startswith(G.model, "autoencoder")
        num_inputs = num_items(G.medium) * 3
    elseif startswith(G.model, "universal")
        num_inputs = sum(num_items(x) for x in ALL_MEDIUMS) * 2
    else
        # Explicit error instead of `@assert false`: assertions carry no
        # message here and are not meant for validating configuration.
        error(
            "build_model: unsupported model $(repr(G.model)); expected a name " *
            "starting with \"autoencoder\" or \"universal\"",
        )
    end
    m = autoencoder(rng, num_inputs, num_items(G.medium))
    # The trailing StorageLayer leaves the network output unchanged.
    return Chain(m..., StorageLayer(1))
end;