# NeuralUserItemBiases
* See the corresponding file in `../TrainingAlphas` for more details

In [1]:
# Name of this notebook's alpha. It is set before `Alpha.ipynb` is included,
# presumably because the included notebooks read it — confirm.
# NOTE: `compute_alpha(source)` takes its own `source` argument, which shadows this global.
source = "NeuralUserItemBiases";

In [2]:
import NBInclude: @nbinclude
@nbinclude("Alpha.ipynb");

In [3]:
@nbinclude("../TrainingAlphas/NeuralNetworkBase.ipynb");

## Data Preprocessing

In [4]:
# Override methods in NeuralNetworkBase to use recommendee splits

# Recommendee-only override of the NeuralNetworkBase method: the epoch outputs are
# the recommendee's split for `implicit`, converted with `sparse`. Only the
# "training" split is supported. `num_users` is unused, but it is part of the
# signature and therefore part of the memoization key.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_outputs(split, implicit, num_users)
    @assert split == "training"
    recommendee_split = get_recommendee_split(implicit)
    return sparse(recommendee_split)
end

# Recommendee-only override of the NeuralNetworkBase method: the residuals are the
# recommendee's values of `residual_alphas` (via `read_recommendee_alpha`),
# converted with `sparse`. Only the "training" split is supported. `num_users` is
# unused, but it is part of the signature and therefore the memoization key.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_residuals(
    split,
    residual_alphas,
    implicit,
    num_users,
)
    @assert split == "training"
    recommendee_residuals = read_recommendee_alpha(residual_alphas, implicit)
    return sparse(recommendee_residuals)
end

# Recommendee-only override of the NeuralNetworkBase method: build per-rating
# weights for the recommendee's split.
#
# Each rating's weight is the product of two factors:
#   * a user factor, `expdecay` of the recommendee's total number of ratings;
#   * an item factor, `expdecay` of the item counts from `get_counts(split, ...)`.
# Only the "training" split is supported. `num_users` is unused, but it is part of
# the signature and therefore the memoization key.
@memoize LRU{Any,Any}(maxsize = 2) function get_epoch_weights(
    split,
    user_weight_decay,
    item_weight_decay,
    implicit,
    num_users,
)
    @assert split == "training"
    recommendee = get_recommendee_split(implicit)

    # Every entry holds the same count (the recommendee's rating total), one per
    # item, so it can be multiplied element-wise with the per-item factor.
    user_counts = fill(length(recommendee.rating), num_items())
    user_factor = expdecay(user_counts, user_weight_decay)
    item_counts = get_counts(split, implicit; by_item = true, per_rating = false)
    item_factor = expdecay(item_counts, item_weight_decay)
    per_item_weight = user_factor .* item_factor

    rating_weights = per_item_weight[recommendee.item]
    return sparse(RatingsDataset(recommendee.user, recommendee.item, rating_weights))
end;

## Retrain user embeddings

In [5]:
# Fine-tune a model for the recommendee alone.
#
# `params` is read for:
#   * "retrain_hyp": hyperparameters; `num_users` is overridden to 1;
#   * "m": passed to `build_retrain_model` along with the hyperparameters;
#   * "epochs": number of training epochs.
# Only `Flux.params(m[1])` is optimised (presumably the user-embedding layer;
# confirm in NeuralNetworkBase). Returns the trained model moved to the CPU.
#
# The training helpers are not passed the hyperparameters, so they appear to read
# the global `G`. `G` is set for the duration of training and reset to `nothing`
# in a `finally`, so a failed run cannot leave stale hyperparameters behind for
# the next call.
function retrain_user_embeddings(params)
    hyp = params["retrain_hyp"]
    global G = @set hyp.num_users = 1
    model = try
        m = build_retrain_model(G, params["m"]) |> device
        ps = Flux.params(m[1])
        opt = get_optimizer(G.optimizer, G.learning_rate, G.regularization_params)

        @showprogress for _ = 1:params["epochs"]
            train_epoch!(m, ps, opt)
            apply_zero_gradient!(m, ps, opt, true)
        end
        m
    finally
        # Always clear the global, including when training throws.
        global G = nothing
    end
    return model |> cpu
end;

## Write alpha

In [6]:
# Retrain the model stored under `source` on the recommendee and write its
# predictions as the recommendee alpha named `source`. Returns whatever
# `write_recommendee_alpha` returns.
function compute_alpha(source)
    @info "computing alpha $source"
    retrained = retrain_user_embeddings(read_params(source))
    # The retrained model is built with `num_users = 1`, so user index 1 is the
    # only user (the recommendee).
    recommendee_predictions = retrained(1)
    return write_recommendee_alpha(recommendee_predictions, source)
end;

In [7]:
# Compute and write the recommendee alpha for each neural model, in order.
# Returns the result of the last `compute_alpha(source)` call.
function compute_alpha()
    neural_sources = (
        "NeuralExplicitUserItemBiases",
        "NeuralImplicitUserItemBiases",
        "NeuralExplicitMatrixFactorization",
        "NeuralImplicitMatrixFactorization",
    )
    local last_result = nothing
    for alpha_source in neural_sources
        last_result = compute_alpha(alpha_source)
    end
    return last_result
end;

In [8]:
# Retrain and write every neural recommendee alpha listed in `compute_alpha()`.
compute_alpha();

[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220618 13:56:12 computing alpha NeuralExplicitUserItemBiases
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:07[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220618 13:57:24 computing alpha NeuralImplicitUserItemBiases


Iter     Function value   Gradient norm 
     0     9.849667e+00     0.000000e+00
 * time: 0.008617162704467773


[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220618 13:57:35 computing alpha NeuralExplicitMatrixFactorization
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20220618 13:57:39 computing alpha NeuralImplicitMatrixFactorization


Iter     Function value   Gradient norm 
     0     7.866368e+00     3.492808e-01
 * time: 2.7179718017578125e-5
     1     7.534700e+00     1.322128e-01
 * time: 0.6021480560302734
     2     7.442744e+00     4.278617e-02
 * time: 1.122122049331665
     3     7.415466e+00     1.126432e-02
 * time: 1.7214510440826416
     4     7.409862e+00     3.949660e-03
 * time: 2.316575050354004
     5     7.407974e+00     1.255889e-03
 * time: 2.835313081741333
     6     7.407395e+00     3.691816e-04
 * time: 3.4461541175842285
     7     7.407246e+00     1.358968e-04
 * time: 4.045407056808472
     8     7.407190e+00     5.112632e-05
 * time: 4.561211109161377
     9     7.407169e+00     2.053520e-05
 * time: 5.167686223983765
    10     7.407159e+00     8.916975e-06
 * time: 5.763728141784668
    11     7.407155e+00     3.442694e-06
 * time: 6.279336214065552
    12     7.407153e+00     1.426109e-06
 * time: 6.885868072509766
    13     7.407153e+00     6.137319e-07
 * time: 7.4780731201171875

Progress: 100%|█████████████████████████████████████████| Time: 0:00:15
