## Models
* To define a new model:
  * write a constructor for its architecture and add a branch for it in `build_model`
  * add its regularization term to `regularization_loss`
  * fill in any model-specific asserts

### Basic Layers

In [None]:
# Join: a Flux `Parallel` layer that runs each of `paths` and merges their outputs
# with `combine` (e.g. `vcat`). The paths may be given either as one collection or
# as separate trailing arguments.
function Join(combine, paths)
    return Parallel(combine, paths)
end

function Join(combine, paths...)
    return Join(combine, paths)
end

In [None]:
# A layer that stores a parameter `β` and returns its input unchanged.
# It is used for residualization: the layer contributes no computation, it only
# carries `β` (initialized to ones) along with the model's other parameters.
# The field is parametric (not `::Any`) so the stored array's type is concrete.
struct StorageLayer{T}
    β::T
end
StorageLayer(n::Integer; init = ones) = StorageLayer(init(Float32, n))
(m::StorageLayer)(x) = x
Flux.@functor StorageLayer

In [None]:
# A layer that multiplies its input elementwise by a learned vector `γ`.
# `γ` is initialized to ones, so the layer starts as the identity; there is no bias.
# NOTE(review): assumes `x` is (n, batch) so that `γ[i]` scales row i — confirm.
struct ScalarLayer{T}
    γ::T
end
ScalarLayer(n::Integer; init = ones) = ScalarLayer(init(Float32, n))
(m::ScalarLayer)(x) = x .* m.γ
Flux.@functor ScalarLayer

In [None]:
# A layer that adds a learned vector `b` (initialized to zeros) to its input.
# Broadcasting adds `b` to every column of a (length(b), batch) input.
struct BiasLayer{T}
    b::T
end
BiasLayer(n::Integer; init = zeros) = BiasLayer(init(Float32, n))
(m::BiasLayer)(x) = x .+ m.b
Flux.@functor BiasLayer

### Custom Models

In [None]:
# Baseline predictor R[user, item] = a[item]: every user receives the same learned
# per-item bias. `zero_reshape` turns the batch into a 1 × batch row of zeros, and the
# BiasLayer broadcasts `a` onto it. `rng` is unused (BiasLayer initializes to zeros).
function item_biases(rng)
    return Chain(zero_reshape, BiasLayer(num_items()))
end

# Map a batch `x` to a 1 × length(x) row of zeros (for finite inputs). The element
# type is the promotion of eltype(x) with Float32. The result is built by
# multiplication rather than `zeros`, presumably so it keeps the array type of `x`
# (e.g. a GPU array) — confirm.
function zero_reshape(x)
    row = reshape(x, 1, :)
    return zero(Float32) .* row
end;

In [None]:
# Matrix factorization: R[user, item] = dot(U[:, user], A[item, :]).
# The user factors U are a K-dimensional Flux.Embedding over G.num_users. The item
# factors are the rows of a bias-free Dense(K, num_items()) weight matrix.
# Both are Glorot-initialized from `rng` (users first, then items).
function matrix_factorization(rng, K)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    user_factors = Flux.Embedding(G.num_users => K, init = glorot)
    item_factors = Dense(K, num_items(), bias = false, init = glorot)
    return Chain(user_factors, item_factors)
end

# Squared-L2 penalty: λ_u ‖U[:, x]‖² + λ_a ‖A‖², with (λ_u, λ_a) taken from
# G.regularization_params. Only the user embeddings looked up for the batch `x`
# are penalized, while the full item-factor matrix is.
# Assumes `m[1]` is the user Embedding and `m[2]` the item Dense layer, as built by
# `matrix_factorization` (extra layers appended after them are ignored).
function matrix_factorization_regularization(m, x)
    λ_u = G.regularization_params[1]
    λ_a = G.regularization_params[2]
    # sum(abs2, ·) squares on the fly instead of materializing `x .^ 2`
    return sum(abs2, m[1](x)) * λ_u + sum(abs2, m[2].weight) * λ_a
end;

In [None]:
# R[user, item] = weighted average of the other items the user has rated. The weights
# come from a learned item-item similarity matrix W. W's diagonal is masked out, so an
# item's own entry never contributes to its prediction.
# NOTE(review): assumes `x` is (num_items, batch) with 0 meaning "not rated" — confirm.
struct ItemCFLayer{TW, TM}
    W::TW     # item-item weights; the only functor child (see `@functor` below)
    mask::TM  # 0 on the diagonal, 1 elsewhere; not a functor child, so not trained
end

# Build the layer from a weight matrix. The mask is sized from `W` itself, so it
# matches `W` for any `n` (not only n == num_items()).
function ItemCFLayer(W)
    rows, cols = size(W)
    mask = [i == j ? 0.0f0 : 1.0f0 for i in 1:rows, j in 1:cols]
    return ItemCFLayer(W, mask)
end
ItemCFLayer(n::Integer; init = Flux.glorot_uniform) = ItemCFLayer(init(n, n))

function (m::ItemCFLayer)(x)
    W = m.W .* m.mask
    # Denominator: total |weight| over the nonzero entries of `x`. The `+ 1` keeps it
    # positive when there are none.
    return (W * x) ./ (abs.(W) * nonzero.(x) .+ 1.0f0)
end
Flux.@functor ItemCFLayer (W,)

# Item-based collaborative filtering model: a single ItemCFLayer over all items,
# with its weights Glorot-initialized from `rng`.
function item_based_collaborative_filtering(rng)
    seeded_glorot(dims...) = Flux.glorot_uniform(rng, dims...)
    layer = ItemCFLayer(num_items(); init = seeded_glorot)
    return Chain(layer)
end

# Dense num_items() × num_items() Float32 mask with 0 on the diagonal and 1
# elsewhere. Multiplying an item-item matrix by it removes self-similarity.
function item_cf_mask()
    n = num_items()
    return Float32[i == j ? 0 : 1 for i in 1:n, j in 1:n]
end;

# Indicator: one(T) when `x` is nonzero, zero(T) otherwise, where T = eltype(x).
# Broadcast it (`nonzero.(x)`) to mark which entries of an array are nonzero.
function nonzero(x)
    T = eltype(x)
    return x == zero(T) ? zero(T) : one(T)
end

# Move both the weights and the (non-functor) mask with `device`.
device(x::ItemCFLayer) = ItemCFLayer(device(x.W), device(x.mask))

# Bring both the weights and the (non-functor) mask back to the CPU.
Flux.cpu(x::ItemCFLayer) = ItemCFLayer(cpu(x.W), cpu(x.mask));

In [None]:
# Dense autoencoder with layer widths 3·num_items() → 1024 → 512 → 256 → num_items().
# The hidden layers use ReLU; the output layer is linear.
# NOTE(review): the inputs are described as concatenated implicit and explicit
# ratings, but the input width is 3 × num_items() — confirm what the third block is.
function autoencoder(rng)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    n = num_items()
    widths = (3n, 1024, 512, 256)
    # layers are built in order, which fixes the sequence of draws from `rng`
    hidden = [
        Dense(fan_in, fan_out, relu, init = glorot) for
        (fan_in, fan_out) in zip(widths[1:end-1], widths[2:end])
    ]
    return Chain(hidden..., Dense(last(widths), n, init = glorot))
end;

In [None]:
# Embarrassingly shallow autoencoder (EASE): output = W * x. The diagonal of the
# item-item matrix W is forced to zero, so an item cannot reconstruct itself.
# NOTE(review): assumes `x` is (num_items, batch) — confirm.
struct EaseLayer{TW, TM}
    W::TW     # item-item weights; the only functor child (see `@functor` below)
    mask::TM  # 0 on the diagonal, 1 elsewhere; not a functor child, so not trained
end

# Build the layer from a weight matrix. The mask is sized from `W` itself, so it
# matches `W` for any `n` (not only n == num_items()).
function EaseLayer(W)
    rows, cols = size(W)
    mask = [i == j ? 0.0f0 : 1.0f0 for i in 1:rows, j in 1:cols]
    return EaseLayer(W, mask)
end
EaseLayer(n::Integer; init = Flux.glorot_uniform) = EaseLayer(init(n, n))

function (m::EaseLayer)(x)
    W = m.W .* m.mask
    return W * x
end
Flux.@functor EaseLayer (W,)

# EASE model: a single EaseLayer over all items, Glorot-initialized from `rng`.
function ease(rng)
    seeded_glorot(dims...) = Flux.glorot_uniform(rng, dims...)
    layer = EaseLayer(num_items(); init = seeded_glorot)
    return Chain(layer)
end

# Move both the weights and the (non-functor) mask with `device`.
device(x::EaseLayer) = EaseLayer(device(x.W), device(x.mask))

# Bring both the weights and the (non-functor) mask back to the CPU.
Flux.cpu(x::EaseLayer) = EaseLayer(cpu(x.W), cpu(x.mask));

In [None]:
# Double embedding scorer.
# - One shared Dense encoder (num_items() → embed_width(num_items())) is applied
#   four times: three times in the first branch, once in the second.
# - Each width in `feature_dims` gets its own Dense encoder in the second branch.
# All encodings are concatenated with `vcat`, then passed through a K-unit ReLU
# layer and a single linear output.
# NOTE(review): the meaning of the feature widths and the expected input layout
# (nested tuples for the `Join`s) are not visible here — confirm with the data loader.
function double_embedding(rng; K = 256)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    # encoder width grows logarithmically with the input width
    embed_width(n) = Int(16 * ceil(log(n)))
    E = 256 * 3
    feature_dims = (29, 76, 1010, 892, E, E)
    # construction order fixes the sequence of draws from `rng`: features, shared, head
    feature_encoders =
        [Dense(d, embed_width(d), relu, init = glorot) for d in feature_dims]
    n = num_items()
    shared = Dense(n, embed_width(n), relu, init = glorot)
    item_branch = Join(vcat, shared, shared, shared)
    mixed_branch = Join(vcat, shared, feature_encoders...)
    concat_width = 4 * embed_width(n) + sum(embed_width, feature_dims)
    return Chain(
        Join(vcat, item_branch, mixed_branch),
        Dense(concat_width, K, relu, init = glorot),
        Dense(K, 1, init = glorot),
    )
end;

In [10]:
# Metadata embedding: two per-item weight tensors applied in sequence.
# E1 has size (N, c1, c2, 1) and E2 has size (N, c2, c3, 1), so each of the N items
# has its own c1 → c2 and c2 → c3 projection (see `tensor_product`).
# The fields are parametric (not `::Any`) so the stored array types are concrete.
struct MetadataEmbedding{T1, T2}
    E1::T1
    E2::T2
end

MetadataEmbedding(
    N::Integer,
    c1::Integer,
    c2::Integer,
    c3::Integer;
    init = Flux.glorot_uniform,
) = MetadataEmbedding(init(N, c1, c2, 1), init(N, c2, c3, 1))

# Per-item contraction of X with E along dimension 2:
#   out[n, k, 1, b] = Σ_j X[n, j, 1, b] * E[n, j, k, 1]
# For X of size (N, J, 1, B) and E of size (N, J, K, 1), the result has size (N, K, 1, B).
function tensor_product(E, X)
    contracted = sum(X .* E; dims = 2)   # (N, 1, K, B)
    out_shape = (size(X)[1], size(E)[3], 1, last(size(X)))
    return reshape(contracted, out_shape)
end

# Forward pass: reshape X to (N, c1, 1, B), apply E1 (with ReLU) and then E2 per
# item, sum the per-item c3 features over items, and divide by the per-sample count
# of items whose channel 2 is nonzero (at least 1). The output has size (c3, B).
# NOTE(review): channel 2 is presumably a "seen" indicator — confirm with the loader.
function (m::MetadataEmbedding)(X)
    n_items, n_channels = size(m.E1, 1), size(m.E1, 2)
    batch = last(size(X))
    X1 = reshape(X, (n_items, n_channels, 1, batch))
    hidden = relu.(tensor_product(m.E1, X1))     # (N, c2, 1, B)
    projected = tensor_product(m.E2, hidden)     # (N, c3, 1, B)
    pooled = reshape(sum(projected; dims = 1), (size(projected, 2), size(projected, 4)))
    seen = sum(X1[:, 2, :, :] .!= 0; dims = 1)   # (1, 1, B)
    denom = reshape(max.(1, seen), (1, batch))
    return pooled ./ denom
end

Flux.@functor MetadataEmbedding


# Metadata model: a per-item MetadataEmbedding (4 → 32 → 64 features, averaged over
# items), followed by a 64 → 256 → num_items() MLP.
# Every layer is Glorot-initialized from `rng`. Previously the two Dense layers fell
# back to the global RNG, so `rng` did not fully determine the initial weights.
function metadata_embedding(rng)
    init = (x...) -> Flux.glorot_uniform(rng, x...)
    N = num_items()
    embed_dim = 64  # c3 of the embedding, and therefore the input width of the MLP
    return Chain(
        MetadataEmbedding(N, 4, 32, embed_dim; init = init),
        Dense(embed_dim, 256, relu; init = init),
        Dense(256, N; init = init),
    )
end;

### Dispatch

In [None]:
# Build the model selected by `G.model` and append a `StorageLayer(1)`, which carries
# the residualization parameter `β` alongside the model's layers.
#
# Keywords:
#   rng — random number generator passed to the selected model constructor.
# Throws an `ArgumentError` naming `G.model` when it is not a known model.
#
# NOTE(review): `matrix_factorization` exists but has no branch here (it also needs a
# latent size `K`) — confirm whether it should be selectable.
function build_model(; rng = Random.GLOBAL_RNG)
    name = G.model
    if name == "item_biases"
        m = item_biases(rng)
    elseif name == "item_based_collaborative_filtering"
        m = item_based_collaborative_filtering(rng)
    elseif startswith(name, "autoencoder")
        m = autoencoder(rng)
    elseif name == "ease"
        m = ease(rng)
    elseif startswith(name, "double_embedding")
        m = double_embedding(rng)
    elseif startswith(name, "metadata_embedding")
        m = metadata_embedding(rng)
    else
        # explicit error instead of `@assert false`, so the message names the bad model
        throw(ArgumentError("unknown model: $(name)"))
    end
    return Chain(m..., StorageLayer(1))
end;