## Models
* To define a new model:
  * add its architecture to `build_model`
  * add its regularization term to `regularization_loss`
  * update any asserts or branches that enumerate the supported model names

In [None]:
# `Join(combine, branches...)` wraps `Parallel(combine, branches)`: the branches run side
# by side and `combine` (e.g. `vcat`) merges their outputs. The branches may be given
# either as separate arguments or as a single tuple.
function Join(combine, paths)
    return Parallel(combine, paths)
end
Join(combine, paths...) = Join(combine, paths)

In [None]:
# Identity layer that carries a parameter vector `β` (visible to Flux via @functor) and
# returns its input unchanged; the stored parameter is used for residualization.
struct StorageLayer
    β::Any
end

# `n`-element Float32 parameter, filled by `init` (ones by default).
StorageLayer(n::Integer; init = ones) = StorageLayer(init(Float32, n))

(m::StorageLayer)(x) = x

Flux.@functor StorageLayer

In [None]:
# Elementwise affine layer: y = x .* γ .+ c, with a per-feature scale `γ` and offset `c`.
struct ScalarLayer
    γ::Any
    c::Any
end

# `γ` is filled by `init` (ones by default); the offset `c` always starts at zero.
function ScalarLayer(n::Integer; init = ones)
    scale = init(Float32, n)
    shift = zeros(Float32, n)
    return ScalarLayer(scale, shift)
end

function (layer::ScalarLayer)(x)
    scale, shift = layer.γ, layer.c
    return x .* scale .+ shift
end

Flux.@functor ScalarLayer

In [None]:
# Adds a learnable 1-D bias vector `b` to its input (via broadcasting `.+`).
struct BiasLayer
    b::Any
end

# `n`-element Float32 bias, filled by `init` (zeros by default).
BiasLayer(n::Integer; init = zeros) = BiasLayer(init(Float32, n))

function (layer::BiasLayer)(x)
    return x .+ layer.b
end

Flux.@functor BiasLayer

In [None]:
# Baseline predictor R[user, item] = u[user] + a[item] + b, built as a Chain of:
#   1. a size-1 user embedding u (zero-initialised),
#   2. a per-item bias vector a (BiasLayer, zero-initialised),
#   3. a single global offset b. It is left unregularized so that u and a can be
#      centered at 0.
# Every parameter starts at zero, so `rng` is not used; it is accepted only to match
# the signature of the other model builders.
function user_item_biases(rng)
    zero_init = (dims...) -> zeros(Float32, dims...)
    user_bias = Flux.Embedding(G.num_users => 1; init = zero_init)
    item_bias = BiasLayer(num_items())
    global_bias = BiasLayer(1)
    return Chain(user_bias, item_bias, global_bias)
end

# Squared-L2 penalty λ_u * ||u_x||² + λ_a * ||a||² for the user/item bias model, where:
#   - u_x are the user biases looked up for the batch `x` (only those users are penalised),
#   - a is the full item-bias vector,
#   - λ_u and λ_a are G.regularization_params[1] and [2].
# The global offset (layer 3) is intentionally not penalised.
function user_item_biases_regularization(m, x)
    λ_u, λ_a = G.regularization_params[1], G.regularization_params[2]
    user_penalty = sum(abs2, m[1](x)) * λ_u
    item_penalty = sum(abs2, m[2].b) * λ_a
    return user_penalty + item_penalty
end;

In [None]:
# Matrix factorization with K latent factors: R[user, item] = dot(W[item, :], U[:, user]).
# U is a K-dimensional user embedding. The item factors W are the weight of a bias-free
# Dense(K => num_items()) layer, so the model emits one score per item.
# Both are Glorot-uniform initialised from `rng` (users first, then items).
function matrix_factorization(rng, K)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    users = Flux.Embedding(G.num_users => K; init = glorot)
    items = Dense(K, num_items(); bias = false, init = glorot)
    return Chain(users, items)
end

# Squared-L2 penalty λ_u * ||U_x||² + λ_a * ||W||² for matrix factorization, where:
#   - U_x are the embeddings of the users in batch `x`,
#   - W is the item-factor weight of the Dense layer,
#   - λ_u and λ_a are G.regularization_params[1] and [2].
function matrix_factorization_regularization(m, x)
    λ_u, λ_a = G.regularization_params[1], G.regularization_params[2]
    user_penalty = sum(abs2, m[1](x)) * λ_u
    item_penalty = sum(abs2, m[2].weight) * λ_a
    return user_penalty + item_penalty
end;

In [None]:
# Item-based collaborative filtering layer. R[user, item] is a weighted average of the
# user's entries for *other* items, with weights taken from the learned item-item
# matrix W.
# `mask` (from item_cf_mask) zeroes the diagonal of W so an item never scores itself.
# Only W is exposed through @functor, so the mask is neither trained nor moved by fmap.
struct ItemCFLayer
    W::Any
    mask::Any
end

ItemCFLayer(W) = ItemCFLayer(W, item_cf_mask())
ItemCFLayer(n::Integer; init = Flux.glorot_uniform) = ItemCFLayer(init(n, n))

function (layer::ItemCFLayer)(x)
    weights = layer.W .* layer.mask
    weighted_sum = weights * x
    # Normalise by the total |weight| over the items with a nonzero entry in x. The +1
    # keeps the denominator >= 1 when none of an item's neighbours are present.
    total_weight = abs.(weights) * nonzero.(x) .+ 1.0f0
    return weighted_sum ./ total_weight
end

Flux.@functor ItemCFLayer (W,)

# Item-based CF model: a single ItemCFLayer over all items whose num_items × num_items
# weight matrix is Glorot-uniform initialised from `rng`.
function item_based_collaborative_filtering(rng)
    seeded_glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    layer = ItemCFLayer(num_items(); init = seeded_glorot)
    return Chain(layer)
end

# num_items × num_items Float32 matrix of ones with a zero diagonal. It is multiplied
# elementwise into the item-item weight matrices to remove self-similarity.
function item_cf_mask()
    n = num_items()
    mask = ones(Float32, n, n)
    mask[LinearAlgebra.diagind(mask)] .= 0.0f0
    return mask
end;

# 0/1 indicator of a nonzero value, typed like the input: returns zero(T) when
# x == zero(T) and one(T) otherwise, where T = eltype(x).
# Broadcast it (`nonzero.(x)`) to build a mask of present entries.
function nonzero(x)
    T = eltype(x)
    return x == zero(T) ? zero(T) : one(T)
end

# Move an ItemCFLayer to the target device, moving W first and then the mask.
# The mask must be moved explicitly because `@functor ItemCFLayer (W,)` exposes only W.
# NOTE(review): this method only runs when `device` is called on the layer itself;
# confirm how `device` traverses a Chain containing it.
function device(layer::ItemCFLayer)
    return ItemCFLayer(device(layer.W), device(layer.mask))
end

# CPU counterpart of `device(::ItemCFLayer)`: applies `cpu` to both W and the mask,
# since the mask is not reachable through @functor.
function Flux.cpu(layer::ItemCFLayer)
    return ItemCFLayer(cpu(layer.W), cpu(layer.mask))
end;

In [None]:
# MLP predictor: 3 * num_items() input features -> 1024 -> 512 -> 256 (relu) ->
# num_items() linear outputs. All weights are Glorot-uniform initialised from `rng`.
# The original note says the input concatenates implicit and explicit ratings.
# NOTE(review): the first layer expects three num_items()-sized blocks; confirm what
# the third block is.
function autoencoder(rng)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    n = num_items()
    encoder1 = Dense(3n, 1024, relu; init = glorot)
    encoder2 = Dense(1024, 512, relu; init = glorot)
    encoder3 = Dense(512, 256, relu; init = glorot)
    decoder = Dense(256, n; init = glorot)
    return Chain(encoder1, encoder2, encoder3, decoder)
end;

In [None]:
# EASE ("embarrassingly shallow autoencoder") layer. Scores are (W .* mask) * x, where
# `mask` (from item_cf_mask) zeroes the diagonal so an item cannot reconstruct itself.
# Only W is exposed through @functor, so the mask is not trained.
struct EaseLayer
    W::Any
    mask::Any
end

EaseLayer(W) = EaseLayer(W, item_cf_mask())
EaseLayer(n::Integer; init = Flux.glorot_uniform) = EaseLayer(init(n, n))

function (layer::EaseLayer)(x)
    return (layer.W .* layer.mask) * x
end

Flux.@functor EaseLayer (W,)

# EASE model: a single EaseLayer over all items whose weight matrix is Glorot-uniform
# initialised from `rng`.
function ease(rng)
    seeded_glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    layer = EaseLayer(num_items(); init = seeded_glorot)
    return Chain(layer)
end

# Move an EaseLayer to the target device, moving W first and then the mask.
# The mask must be moved explicitly because `@functor EaseLayer (W,)` exposes only W.
# NOTE(review): this method only runs when `device` is called on the layer itself;
# confirm how `device` traverses a Chain containing it.
function device(layer::EaseLayer)
    return EaseLayer(device(layer.W), device(layer.mask))
end

# CPU counterpart of `device(::EaseLayer)`: applies `cpu` to both W and the mask,
# since the mask is not reachable through @functor.
function Flux.cpu(layer::EaseLayer)
    return EaseLayer(cpu(layer.W), cpu(layer.mask))
end;

In [None]:
# "Double embedding" model, structured as follows:
#   - Three parallel relu branches, each producing K features, with input widths
#     3 * num_items(), num_items() and M respectively.
#   - The branch outputs are vcat-ed into 3K features.
#   - A K-wide relu layer then maps them to a single output.
# The input is presumably a 3-tuple matching the branches (see Join/Parallel).
# NOTE(review): M = 1115 is a hard-coded feature width; confirm where it comes from.
# Weights are Glorot-uniform initialised from `rng`: branches first, then the head.
function double_embedding(rng; K = 256, M = 1115)
    glorot = (dims...) -> Flux.glorot_uniform(rng, dims...)
    n = num_items()
    branches = Join(
        vcat,
        Dense(3n, K, relu; init = glorot),
        Dense(n, K, relu; init = glorot),
        Dense(M, K, relu; init = glorot),
    )
    hidden = Dense(3K, K, relu; init = glorot)
    output = Dense(K, 1; init = glorot)
    return Chain(branches, hidden, output)
end

In [None]:
"""
    build_model(; rng = Random.GLOBAL_RNG)

Construct the model named by `G.model` and append a `StorageLayer(1)` to it.

Recognised names are:
- `"user_item_biases"`
- `"matrix_factorization_<K>"` (K latent factors, parsed from the text after the last
  underscore)
- `"item_based_collaborative_filtering"`
- `"autoencoder"`
- `"ease"`
- `"double_embedding"`

`rng` is forwarded to the chosen model constructor.

Throws an `ArgumentError` if the name is not recognised, or if the
matrix-factorization suffix is not an integer.
"""
function build_model(; rng = Random.GLOBAL_RNG)
    name = G.model
    if name == "user_item_biases"
        m = user_item_biases(rng)
    elseif startswith(name, "matrix_factorization")
        K = tryparse(Int, split(name, "_")[end])
        if K === nothing
            throw(ArgumentError(
                "expected a model name of the form matrix_factorization_<K>, got $(repr(name))",
            ))
        end
        m = matrix_factorization(rng, K)
    elseif name == "item_based_collaborative_filtering"
        m = item_based_collaborative_filtering(rng)
    elseif name == "autoencoder"
        m = autoencoder(rng)
    elseif name == "ease"
        m = ease(rng)
    elseif name == "double_embedding"
        m = double_embedding(rng)
    else
        # An explicit throw, unlike `@assert false`, cannot be disabled and reports
        # which model name was rejected.
        throw(ArgumentError("unknown model: $(repr(name))"))
    end
    # Every model ends in an identity StorageLayer holding the residualization parameter.
    return Chain(m..., StorageLayer(1))
end

"""
    regularization_loss(m, x)

Return the regularization penalty for model `m` on input batch `x`, dispatching on
`G.model`.

Models without an explicit penalty contribute `0`. Throws an `ArgumentError` for an
unrecognised model name.
"""
function regularization_loss(m, x)
    name = G.model
    if name == "user_item_biases"
        return user_item_biases_regularization(m, x)
    elseif startswith(name, "matrix_factorization")
        return matrix_factorization_regularization(m, x)
    elseif name in (
        "item_based_collaborative_filtering",
        "autoencoder",
        "ease",
        "double_embedding",
    )
        return 0
    else
        # An explicit throw, unlike `@assert false`, cannot be disabled and reports
        # which model name was rejected.
        throw(ArgumentError("unknown model: $(repr(name))"))
    end
end;