# Anime Neural Network
* Based on the candidate generation network of the YouTube recommender system
* See Section 3 of [Deep Neural Networks for YouTube Recommendations](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45530.pdf) by Covington et al.

In [1]:
# Name of this alpha; it is interpolated into the path the model is saved to.
name = "ANN";
# Alphas handed to `get_residuals` when the training/validation splits are built.
residual_alphas = ["UserItemBiases"];

In [2]:
using Flux # TODO add to readme
import BSON

In [3]:
using NBInclude
@nbinclude("Alpha.ipynb");

In [4]:
# Give BLAS the same number of threads that Julia itself was started with.
BLAS.set_num_threads(Threads.nthreads())

In [5]:
# The model and every batch are piped through `device` before use.
device = gpu

gpu (generic function with 1 method)

## train on data

In [6]:
# Load both splits. `get_residuals` is not defined in this notebook (presumably it
# comes from Alpha.ipynb — confirm); the code below reads `.user`, `.item` and
# `.rating` from its result.
training = get_residuals("training", residual_alphas);
const validation = get_residuals("validation", residual_alphas)
# Sparse user-by-item matrix holding the training values.
# NOTE(review): the dimensions are taken from the training split only, so a
# validation row whose user or item index exceeds the training maximum would be
# out of bounds in `get_data` — confirm both splits share the same index range.
R = sparse(
    training.user,
    training.item,
    training.rating,
    maximum(training.user),
    maximum(training.item),
);

In [7]:
"""
    get_data(R, split, j)

Build one implicit-feedback training example from observation `j` of `split`.

# Arguments
- `R`: user-by-item matrix; row `split.user[j]` is read and `R` is not modified.
- `split`: object with `user` and `item` index vectors.
- `j`: index of the observation to turn into an example.

# Returns
A tuple `(X, Y)`:
- `X` has `size(R, 2) + 3` entries. The first `size(R, 2)` entries mark every item with
  a nonzero value in the user's row, except the held-out item `split.item[j]`, and are
  divided by the number of marked items (by 1 when there are none). The last three
  entries are `w`, `sqrt(w)` and `w^2`, where `w` is that divisor over `size(R, 2)`.
- `Y` is a one-hot vector of length `size(R, 2)` with a 1 at `split.item[j]`.
"""
function get_data(R, split, j)
    user = split.user[j]
    item = split.item[j]
    n_items = size(R, 2)

    # Dense copy of the user's row with the held-out item hidden from the input.
    X = collect(R[user, :])
    X[item] = 0

    # Implicit feedback: only whether an item has a value matters, not the value.
    seen = X .!= 0
    X[seen] .= 1

    # The label is the held-out item alone.
    Y = zeros(length(X))
    Y[item] = 1

    # Normalize so the input sums to one; the `max` guards against an empty profile.
    weight = max(count(seen), 1)
    X = X ./ weight

    # Append features describing how much of the catalogue the user has interacted with.
    norm_weight = weight / n_items
    push!(X, norm_weight, sqrt(norm_weight), norm_weight^2)
    return (X, Y)
end

"""
    get_batch(R, split, block_size)

Sample `block_size` observations from `split` uniformly with replacement and stack
their `get_data` examples into one batch.

# Returns
A one-element vector `[(X, Y)]` passed through `device`, where column `k` of `X` and
of `Y` belongs to the `k`-th sampled observation. The training loop hands this value
to `Flux.train!` as its data argument.
"""
function get_batch(R, split, block_size)
    items = rand(1:length(split.rating), block_size)

    # Each iteration writes only its own slot, so no buffer is shared between tasks.
    # Buffers indexed by `Threads.threadid()` are unsafe because a task may migrate
    # between threads while it runs.
    samples = Vector{Any}(undef, length(items))
    Threads.@threads for i in eachindex(items)
        samples[i] = get_data(R, split, items[i])
    end

    X = Flux.batch(first.(samples))
    Y = Flux.batch(last.(samples))
    return [(X, Y)] |> device
end;

In [8]:
# The network consumes the vectors built by `get_data`: one entry per item plus the
# three appended profile-size features. It emits one raw score (logit) per item.
n_items = size(R, 2)
m = device(
    Chain(
        # first layer: no bias and no activation
        Dense(n_items + 3, 256; bias = false),
        Dense(256, 512, relu),
        Dense(512, 256, relu),
        # output layer: no activation
        Dense(256, n_items),
    ),
)
ps = Flux.params(m);

In [9]:
# Cross-entropy between the model's raw scores for `x` and the targets `y`.
function loss(x, y)
    return Flux.logitcrossentropy(m(x), y)
end
opt = ADAM();

In [10]:
"""
    evalcb(R, split; n_batches = 100, batch_size = 128)

Estimate the mean `loss` on `split` by averaging over `n_batches` randomly sampled
batches of `batch_size` observations each, showing a progress bar while it runs.

# Returns
The mean of the per-batch losses as a `Float64`.
"""
function evalcb(R, split; n_batches = 100, batch_size = 128)
    losses = Float64[]
    sizehint!(losses, n_batches)
    @showprogress for batch_index = 1:n_batches
        # `get_batch` returns a one-element vector holding the `(X, Y)` pair.
        x, y = first(get_batch(R, split, batch_size))
        push!(losses, loss(x, y))
    end
    return mean(losses)
end

# Early-stopping state shared between `evalcb()` and the training loop.
best_loss = Inf # lowest validation loss seen so far
patience = 5 # training stops once this many evaluations in a row fail to improve
iters_without_improvement = 0 # evaluations since `best_loss` last improved
continue_training = true # set to false by `evalcb()` to end the training loop

"""
    evalcb()

Training callback: log the current losses, checkpoint the model whenever the
validation loss improves, and request early stopping once it has failed to improve
`patience` evaluations in a row.

Mutates the globals `best_loss`, `iters_without_improvement` and `continue_training`,
and writes the model `m` to `../../data/alphas/<name>/model2.bson`.
"""
function evalcb()
    global best_loss, iters_without_improvement, continue_training

    # The reported values are cross-entropy losses (see `loss`), not RMSE.
    # The training loss is computed inside the message so that it is only evaluated
    # when debug logging is enabled.
    @debug "training loss: $(evalcb(R, training))"
    validation_loss = evalcb(R, validation)
    @debug "validation loss: $(validation_loss)"

    if validation_loss < best_loss
        best_loss = validation_loss
        iters_without_improvement = 0
        # NOTE(review): `m` was passed through `device` when it was built; confirm
        # that whoever loads this file expects that, or save a CPU copy instead.
        BSON.@save "../../data/alphas/$name/model2.bson" m
    else
        iters_without_improvement += 1
        if iters_without_improvement >= patience
            continue_training = false
        end
    end
    return nothing
end

# Wrap `evalcb` so that it runs at most once every 600 seconds.
throttled_cb = Flux.throttle(evalcb, 600);

In [None]:
# Train on one freshly sampled batch per step until early stopping clears the flag.
while continue_training
    Flux.train!(loss, ps, get_batch(R, training, 128), opt; cb = throttled_cb)
end

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:45[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 23:58:06 training rmse: 9.739583796213987
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:44[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 23:58:51 validation rmse: 9.739576502287571
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:42[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220124 00:09:43 training rmse: 7.998284901785853
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:42[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220124 00:10:26 validation rmse: 8.004316457330004
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:42[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220124 00:21:14 training rmse: 7.9076553461137955
[32mProgress: 100%|███████████████████████████████████

LoadError: InterruptException:

In [None]:
# [ Debug: 20220123 14:44:04 validation rmse: 7.43