# Anime Neural Network
* Based off of the candidate generator YouTube recommender system
* See Section 3 of [Deep Neural Networks for YouTube Recommendations](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45530.pdf) by Covington et al.

In [1]:
name = "ANN";
residual_alphas = ["UserItemBiases"];

In [2]:
using Flux # TODO add to readme
import BSON

In [3]:
using NBInclude
@nbinclude("Alpha.ipynb");

In [4]:
BLAS.set_num_threads(Threads.nthreads())

In [5]:
device = gpu;

## train on data

In [6]:
training = get_residuals("training", residual_alphas);
const validation = get_residuals("validation", residual_alphas)
# column accesses are faster than row accesses, so we make this an (item, user) matrix instead of a (user, item) matrix
R = sparse(
    training.item,
    training.user,
    convert.(Float32, training.rating),
    maximum(training.item) + 1, # leave room for unseen users and items
    maximum(training.user) + 1,
);
n_items, n_users = size(R)

(16981, 452578)

In [7]:
counts = zeros(Float32, n_users)
@tprogress Threads.@threads for j = 1:length(training.user)
    counts[training.user[j]] += 1
end

[32mProgress: 100%|███████████████████████████| Time: 0:00:01 ( 0.44 μs/it)[39m


In [8]:
function get_data(R, split, j)
    # inputs are the average user's ratings for all shows (unseen shows get mapped to zero) + avg implicit ratings + heterogenous features
    # outputs are the one-hot encoding of a heldout series that the user has seen
    weight = max(counts[min(split.user[j], n_users)] - 1, 1)

    X1 = collect(R[:, min(split.user[j], n_users)])
    X1[min(split.item[j], n_items)] = 0
    X1 = X1 ./ weight

    X2 = copy(X1)
    X2[X2.!=0] .= 1 / weight

    Y = zeros(Float32, length(X1))
    Y[split.item[j]] = 1

    # add heterogeneous features
    norm_weight = weight / n_items
    X3 = [norm_weight, sqrt(norm_weight), norm_weight^2]
    return (X1, X2, X3, Y)
end

function get_batch(R, split, block_size)
    idxs = rand(1:length(split.rating), block_size)
    data = [[] for j = 1:Threads.nthreads()]
    Threads.@threads for i in idxs
        push!(data[Threads.threadid()], get_data(R, split, i))
    end
    X1 = Flux.batch([data[t][i][1] for t = 1:Threads.nthreads() for i = 1:length(data[t])])
    X2 = Flux.batch([data[t][i][2] for t = 1:Threads.nthreads() for i = 1:length(data[t])])
    X3 = Flux.batch([data[t][i][3] for t = 1:Threads.nthreads() for i = 1:length(data[t])])
    Y = Flux.batch([data[t][i][4] for t = 1:Threads.nthreads() for i = 1:length(data[t])])
    [((X1, X2, X3), Y)] |> device
end;

function loss(x, y)
    Flux.logitcrossentropy(m(x), y)
end

function evalcb(R, split)
    losses = []
    @showprogress for epoch = 1:100
        push!(losses, val_loss(get_batch(R, split, 128)[1]...))
    end
    mean(losses)
end;

In [26]:
function val_loss(x, y)
    mask = (x[2] .!= 0)# zero out entries for shows you've already seen
    prelogits = m(x)
    prelogits[mask] .= -1e3
    Flux.logitcrossentropy(prelogits, y)
end

function val_evalcb(R, split)
    losses = []
    @showprogress for epoch = 1:100
        push!(losses, val_loss(get_batch(R, split, 128)[1]...))
    end
    mean(losses)
end;

In [28]:
# see get_data for documentation on inputs, outputs
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths)
rating_embedding = Dense(n_items, 256, bias = false)
implicit_embedding = Dense(n_items, 256, bias = false)
het_embedding = Dense(3, 3, bias = false)
m =
    Chain(
        Join(vcat, rating_embedding, implicit_embedding, het_embedding),
        Dense(256 + 256 + 3, 1024, relu),    
        Dense(1024, 512, relu),
        Dense(512, 256, relu),
        Dense(256, n_items),
    ) |> device
ps = Flux.params(m);

In [29]:
loss(x, y) = Flux.logitcrossentropy(m(x), y)
opt = ADAM();

In [35]:
best_loss = Inf
patience = 5
iters = 0
iters_without_improvement = 0
continue_training = true

function evalcb()
    # print losses and perform early stopping
    @debug "iteration: $iters"
    @debug "training loss: $(evalcb(R, training))"
    loss = val_evalcb(R, validation)
    @debug "validation loss: $(loss)"
    if loss < best_loss
        global best_loss = loss
        global iters_without_improvement = 0
        BSON.@save "../../data/alphas/$name/model.bson" m
    else
        global iters_without_improvement += 1
        if iters_without_improvement >= patience
            global continue_training = false
        end
    end
end

throttled_cb = Flux.throttle(evalcb, 60);

In [None]:
while continue_training
    batch = get_batch(R, training, 128)
    Flux.train!(loss, ps, batch, opt, cb = throttled_cb)
end

[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:11:40 iteration: 0
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:11:46 training rmse: 7.370386
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:11:53 validation rmse: 7.3511214
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:12:57 iteration: 0
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:13:05 training rmse: 7.3693876
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:07[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:13:12 validation rmse: 7.3729053
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220127 23:14:12 iteration: 0
[32mProgress: 100%

In [14]:
# [ Debug: 20220123 14:44:04 validation rmse: 7.43