# Anime Neural Network
* Based on the YouTube recommender system

In [1]:
# Identifier for this model. Presumably read by helpers from Alpha.ipynb — confirm there.
name = "ANN";
# Passed to get_residuals when loading the data splits below. Presumably the names of
# earlier models whose predictions are subtracted from the ratings — confirm in Alpha.ipynb.
residual_alphas = ["UserItemBiases"];

In [2]:
using Flux

In [3]:
using NBInclude
@nbinclude("Alpha.ipynb");

In [4]:
# Let BLAS use as many threads as this Julia session was started with.
BLAS.set_num_threads(Threads.nthreads())

## train on data

In [5]:
# Load the training and validation splits. get_residuals is not defined in this file
# (presumably it comes from Alpha.ipynb); the code below reads the `user`, `item` and
# `rating` fields of what it returns.
training = get_residuals("training", residual_alphas);
const validation = get_residuals("validation", residual_alphas)
# Sparse user × item matrix holding the training ratings.
# Its dimensions are taken from the largest user and item index in the training split only.
# NOTE(review): a validation user/item index above those maxima would be out of bounds
# for R — confirm before indexing R with validation rows.
R = sparse(
    training.user,
    training.item,
    training.rating,
    maximum(training.user),
    maximum(training.item),
);

# Uncomment to restrict training to the first 1024 ratings (e.g. for quick experiments):
# elems = 1:1024
# training =
#     RatingsDataset(training.user[elems], training.item[elems], training.rating[elems]);

In [6]:
"""
    get_data(R, split, j)

Build one training example from the `j`-th rating in `split`.

Returns `(X, Y)`:
- `X` is a dense copy of the rating's user row of `R`, with the entry for the rated item
  set to zero, so the model never sees the value it is asked to predict.
- `Y` is a `Float64` vector of the same length that is zero everywhere except at the rated
  item, where it holds `split.rating[j]`.

`R` itself is not modified.
"""
function get_data(R, split, j)
    user = split.user[j]
    item = split.item[j]

    # Input: every rating of this user (unseen items are zero), minus the held-out item.
    X = collect(R[user, :])
    X[item] = 0

    # Label: only the held-out item carries a value.
    Y = zeros(length(X))
    Y[item] = split.rating[j]

    return (X, Y)
end

"""
    get_batch(R, split, block_size)

Sample `block_size` ratings from `split` uniformly at random (with replacement) and
assemble them into a single minibatch.

Returns a one-element vector `[(X, Y)]`. Column `i` of `X` and of `Y` is the input and
label vector that `get_data` produces for the `i`-th sampled rating, so both matrices
have `size(R, 2)` rows and `block_size` columns.
"""
function get_batch(R, split, block_size)
    items = rand(1:length(split.rating), block_size)
    n_rows = size(R, 2)
    # get_data returns a dense copy of a row of R (eltype(R)) and a Float64 label vector;
    # the output matrices use the same element types.
    X = Matrix{eltype(R)}(undef, n_rows, block_size)
    Y = Matrix{Float64}(undef, n_rows, block_size)
    # Each iteration writes only to its own column, so no per-thread buffers (and no
    # reliance on Threads.threadid()) are needed, and column order follows `items`.
    Threads.@threads for i in eachindex(items)
        x, y = get_data(R, split, items[i])
        X[:, i] = x
        Y[:, i] = y
    end
    return [(X, Y)]
end;

In [14]:
# Model: each input column is one user's ratings over all items (unseen items are zero);
# the output column holds a predicted rating for every item.
n_items = size(R, 2)
l1 = Dense(n_items, 256, relu)
l2 = Dense(256, n_items)
function m(x)
    # Number of non-zero (rated) items per column, floored at 1 to avoid dividing by zero.
    n_rated = max.(sum(x .!= 0, dims = 1), 1)
    # Scale the hidden activations by the number of rated items before the output layer.
    hidden = l1(x) ./ n_rated
    return l2(hidden)
end;

In [15]:
"""
    loss(x, y)

Squared error between the labels `y` and the model output `m(x)`, counted only at the
non-zero entries of `y`, summed and divided by the number of columns in `y`.
"""
function loss(x, y)
    # Predictions are zeroed wherever the label is zero, so those entries contribute nothing.
    # TODO optimize
    observed = y .!= 0
    residual = y - m(x) .* observed
    return sum(residual .^ 2) / size(y, 2)
end
# Trainable parameters of both layers, and the optimiser that updates them.
ps = Flux.params([l1, l2])
opt = ADAM();

In [16]:
"""
    evalcb(R, split; n_batches = 100, batch_size = 128)

Estimate the RMSE on `split` by averaging `loss` over `n_batches` random minibatches of
`batch_size` ratings each (drawn with `get_batch`) and taking the square root of that mean.
A progress bar is shown while the batches are evaluated.

# Throws
- `ArgumentError` if `n_batches` is not positive.
"""
function evalcb(R, split; n_batches = 100, batch_size = 128)
    n_batches > 0 || throw(ArgumentError("n_batches must be positive, got $n_batches"))
    # Typed buffer instead of an untyped `[]` (which would be a Vector{Any}).
    losses = Float64[]
    sizehint!(losses, n_batches)
    @showprogress for batch_index = 1:n_batches
        # get_batch returns a one-element vector holding the (inputs, labels) pair.
        x, y = first(get_batch(R, split, batch_size))
        push!(losses, loss(x, y))
    end
    return sqrt(mean(losses))
end

"""
    evalcb()

Log the sampled training RMSE at debug level.
"""
function evalcb()
    # The message is built inside the logging macro, so the RMSE is only computed when
    # this debug record is actually emitted.
    @debug string("training rmse: ", evalcb(R, training))
    # Validation reporting is currently disabled:
    #    @debug "validation rmse: $(evalcb(R, validation))"
end

# Wrap the callback so that it runs at most once every 60 seconds.
throttled_cb = Flux.throttle(evalcb, 60);

In [17]:
# Run the callback once before the training loop to log a starting RMSE.
throttled_cb()

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220122 23:53:04 training rmse: 1.2964616874516517


In [None]:
# Effectively unbounded training loop: every iteration draws a fresh random minibatch,
# trains on it, and prints a dot as a heartbeat. The throttled callback reports the
# training RMSE from time to time.
for step in 1:999999999
    minibatch = get_batch(R, training, 128)
    Flux.train!(loss, ps, minibatch, opt; cb = throttled_cb)
    print(".")
end

...................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220122 23:54:44 training rmse: 1.3067223125783232


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220122 23:56:24 training rmse: 1.2903109990221644


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220122 23:58:06 training rmse: 1.2694310079384346


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220122 23:59:46 training rmse: 1.3041273868673422


....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:01:26 training rmse: 1.3043592752456548


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:03:08 training rmse: 1.2940464165582861


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:04:49 training rmse: 1.2838777524295735


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:06:31 training rmse: 1.2955429918713288


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:08:12 training rmse: 1.3015181267340343


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:09:53 training rmse: 1.302239594416411


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:11:34 training rmse: 1.283032434554239


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:13:16 training rmse: 1.2930097932262532


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:14:57 training rmse: 1.2837447755600546


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:16:38 training rmse: 1.28075422816069


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:18:19 training rmse: 1.2900725997222164


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:20:01 training rmse: 1.2926867227582959


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:21:41 training rmse: 1.2926529729574772


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:23:23 training rmse: 1.2686332120805135


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:25:04 training rmse: 1.2838505707043268


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:26:45 training rmse: 1.2846220529410612


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:28:25 training rmse: 1.2733112464116074


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:38[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:30:07 training rmse: 1.304303332841987


.....................................

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:37[39m
[38;5;4m[1m[ [22m[39m[38;5;4m[1mDebug: [22m[39m20220123 00:31:47 training rmse: 1.2743327403921723


.......................

In [12]:
# TODO evaluation metrics to check loss