## Models
* To define a new model:
  * add a branch to `build_model` that constructs its architecture
  * add any checks needed to reject invalid configurations of the new model

### Basic Layers

In [None]:
# Identity layer: carries a parameter vector `β` (exposed to Flux via `@functor`)
# but returns its input unchanged in the forward pass.
# This is used for residualization — NOTE(review): presumably `β` is read directly
# by code elsewhere; confirm.
struct StorageLayer
    β
end

function StorageLayer(n::Integer; init = ones)
    β = init(Float32, n)
    return StorageLayer(β)
end

(::StorageLayer)(x) = x

@functor StorageLayer

In [None]:
# Layer that scales its input elementwise by the stored length-n vector `γ`
# (computed as `x .* γ`, so `γ` must broadcast against `x`; with n = 1 this is a
# multiplication by a single scalar).
struct ScalarLayer
    γ
end

function ScalarLayer(n::Integer; init = ones)
    γ = init(Float32, n)
    return ScalarLayer(γ)
end

(layer::ScalarLayer)(x) = x .* layer.γ

@functor ScalarLayer

In [None]:
# Layer that adds the stored length-n vector `b` to its input by broadcasting
# (`x .+ b`), e.g. adding `b` to every column of an n×k matrix.
struct BiasLayer
    b
end

function BiasLayer(n::Integer; init = zeros)
    b = init(Float32, n)
    return BiasLayer(b)
end

function (layer::BiasLayer)(x)
    return x .+ layer.b
end

@functor BiasLayer

### Custom Models

In [None]:
# Baseline predictor R[user, item] = a[item]: a learned per-item bias that ignores
# the values in the user's input. `rng` is unused because BiasLayer initializes
# with zeros; presumably it is kept so the signature matches the other model
# constructors.
function item_biases(rng, num_items)
    return Chain(zero_reshape, BiasLayer(num_items))
end

# Turn the input into an all-zero 1×length(x) row by scaling it by zero, so that the
# following BiasLayer broadcasts its bias vector across length(x) columns.
# NOTE(review): the output therefore has length(x) columns rather than size(x, 2);
# confirm downstream code only reads the columns it needs. Non-finite entries in x
# would become NaN rather than zero.
function zero_reshape(x)
    row = reshape(x, 1, length(x))
    return row * zero(Float32)
end;

In [None]:
# Item-based collaborative filtering: R[user, item] is a weighted average of the
# other items the user has seen, with weights given by the learned item–item
# similarity matrix W:
#
#   output = (W * x) ./ (|W| * nonzero.(x) .+ eps(Float32))
#
# The denominator sums |W[i, j]| over the items j with a nonzero entry in x, which
# normalizes by the similarity mass of the seen items. `eps` keeps the denominator
# nonzero when that sum is zero.
struct ItemCFLayer
    W
end

# Build an n×n weight matrix with `init` and zero its diagonal so that no item
# starts with weight on itself.
# NOTE(review): nothing in this layer keeps the diagonal at zero once training
# updates W — confirm that is intended.
function ItemCFLayer(n::Integer; init = Flux.glorot_uniform)
    W = init(n, n)
    for i in 1:n
        W[i, i] = 0
    end
    return ItemCFLayer(W)
end

function (layer::ItemCFLayer)(x)
    W = layer.W
    weighted_sum = W * x
    total_weight = abs.(W) * nonzero.(x) .+ eps(Float32)
    return weighted_sum ./ total_weight
end

@functor ItemCFLayer

# Item-based CF model: an ItemCFLayer whose weights are Glorot-uniform initialized
# from `rng`, followed by a zero-initialized per-item BiasLayer.
function item_based_collaborative_filtering(rng, num_items)
    seeded_init(dims...) = Flux.glorot_uniform(rng, dims...)
    cf = ItemCFLayer(num_items; init = seeded_init)
    return Chain(cf, BiasLayer(num_items))
end

# Indicator of a nonzero value: returns `zero` of `eltype(x)` when `x` equals zero
# and `one` of it otherwise. Used elementwise (`nonzero.(x)`) to mark the entries
# of an input that are present.
function nonzero(x)
    T = eltype(x)
    return x == zero(T) ? zero(T) : one(T)
end;

In [None]:
# Embarrassingly shallow autoencoder (EASE): a single item–item weight matrix W,
# applied to the input as the linear map W * x (no normalization).
struct EaseLayer
    W
end

# Build an n×n weight matrix with `init` and zero its diagonal so that no item
# starts with weight on itself.
# NOTE(review): the diagonal is only zeroed here, not re-imposed during training —
# confirm that is intended.
function EaseLayer(n::Integer; init = Flux.glorot_uniform)
    W = init(n, n)
    for i in 1:n
        W[i, i] = 0
    end
    return EaseLayer(W)
end

(layer::EaseLayer)(x) = layer.W * x

@functor EaseLayer

# EASE model: an EaseLayer whose weights are Glorot-uniform initialized from `rng`,
# followed by a zero-initialized per-item BiasLayer.
function ease(rng, num_items)
    seeded_init(dims...) = Flux.glorot_uniform(rng, dims...)
    weights = EaseLayer(num_items; init = seeded_init)
    return Chain(weights, BiasLayer(num_items))
end;

In [None]:
# Fully connected network: a stack of `relu` Dense layers followed by a linear
# output layer, all Glorot-uniform initialized from `rng`.
# Inputs are the concatenation of implicit and explicit ratings.
# NOTE(review): build_model passes 3 × (number of items) inputs — confirm what each
# segment holds.
#
# Arguments
#   rng          RNG used to draw every layer's initial weights
#   num_inputs   width of the input vector
#   num_outputs  width of the output (the final layer has no activation)
# Keywords
#   hidden_dims  widths of the relu hidden layers, in order; the default
#                (1024, 512, 256) is the original fixed architecture
function autoencoder(rng, num_inputs, num_outputs; hidden_dims = (1024, 512, 256))
    init = (x...) -> Flux.glorot_uniform(rng, x...)
    widths = (num_inputs, hidden_dims...)
    # zip stops at the shorter iterator, pairing each hidden width with the width
    # that feeds into it. Layers are built in order, so the RNG stream is consumed
    # in the same order as a hand-written Chain.
    hidden = [
        Dense(n_in, n_out, relu, init = init) for (n_in, n_out) in zip(widths, hidden_dims)
    ]
    Chain(hidden..., Dense(last(widths), num_outputs, init = init))
end;

### Dispatch

In [None]:
# Build the model selected by the global configuration `G`:
#   * `G.model`  – the model name (see the branches below)
#   * `G.medium` – the medium whose items the model predicts
# The chosen network's layers are splatted into a new `Chain`, with a trailing
# `StorageLayer(1)` (an identity layer that only stores a parameter) appended.
#
# Supported model names:
#   "item_biases", "item_based_collaborative_filtering", "ease",
#   any name starting with "autoencoder", and any name starting with "cross_media"
#   whose "."-separated parts include the source medium ("anime" or "manga"). The
#   source medium must differ from `G.medium`.
#
# Throws `ArgumentError` with a descriptive message for an unknown model name or an
# invalid cross-media specification. Explicit `throw`s are used rather than
# `@assert`: assertions may be disabled at higher optimization levels, which would
# leave `m` undefined, and `@assert false` reports nothing about the cause.
function build_model(; rng = Random.GLOBAL_RNG)
    model = G.model
    target = G.medium
    if model == "item_biases"
        m = item_biases(rng, num_items(target))
    elseif model == "item_based_collaborative_filtering"
        m = item_based_collaborative_filtering(rng, num_items(target))
    elseif startswith(model, "autoencoder")
        # input width is 3 × the item count of the target medium
        m = autoencoder(rng, num_items(target) * 3, num_items(target))
    elseif model == "ease"
        m = ease(rng, num_items(target))
    elseif startswith(model, "cross_media")
        # read the source medium's items, predict the target medium's items
        source = _cross_media_source(model)
        if source == target
            throw(ArgumentError(
                "cross-media model $(repr(model)) must read from a medium other " *
                "than the target medium $(repr(target))",
            ))
        end
        m = autoencoder(rng, num_items(source) * 3, num_items(target))
    else
        throw(ArgumentError("unknown model $(repr(model))"))
    end
    return Chain(m..., StorageLayer(1))
end

# Return the source medium named in a cross-media model name: "anime" if any of its
# "."-separated parts is "anime", otherwise "manga" if one is "manga".
# Throws `ArgumentError` if neither appears.
function _cross_media_source(model)
    parts = split(model, ".")
    "anime" in parts && return "anime"
    "manga" in parts && return "manga"
    throw(ArgumentError(
        "cross-media model $(repr(model)) must name a source medium " *
        "(\"anime\" or \"manga\") as a \".\"-separated part",
    ))
end;