In [None]:
# Run configuration. Sentences and the pretrained encoder come from
# `source_medium`; labels and the classifier head come from `medium`.
source_medium = "anime"
medium = "manga"
# name of the evaluation task; selects the validation/test splits and the masking scheme
task = "temporal"

In [None]:
# relative path used as the prefix for this run's checkpoints and written alphas
name = "$medium/$task/$source_medium/Transformer/";

In [None]:
import DataStructures: DefaultDict
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import NBInclude: @nbinclude
import Random
import SparseArrays: AbstractSparseArray, sparse
import StatsBase: mean, sample
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

# Data

In [None]:
# Bundle of everything needed to build batches, evaluate, and train the model.
# The struct is immutable; fields are replaced with `@set`, and several of them
# start out as `nothing` and are filled in later (see `load_pretrained_model`).
@kwdef struct Trainer
    # name of the task; controls how sentences are subset and masked in `get_batch`
    task::Any
    # data
    # lookup from user to that user's token sentence (source medium)
    sentences::Any
    # (implicit, explicit) sparse label matrices, indexed [item, user]
    labels::Any
    # (implicit, explicit) sparse weight matrices, indexed [item, user]
    weights::Any
    # sparse timestamp matrix, indexed [item, user]
    timestamps::Any
    # bias baselines used to demean input ratings (source medium)
    source_explicit_baseline::Any
    # bias baselines used to demean output ratings (target medium)
    target_explicit_baseline::Any
    # model
    model::Any
    max_seq_len::Any
    vocab_sizes::Any
    mask_tokens::Any
    pad_tokens::Any
    cls_tokens::Any
    # training
    # number of users per forward/backward pass
    minibatch_size::Any
    # number of users per optimizer step (gradients are accumulated over minibatches)
    batch_size::Any
    opt::Any
    rng::Any
end;

In [None]:
# Build the sparse [item, user] rating matrix for one content type.
# When `pretrain` is true the full training split is used; otherwise the
# validation and test splits of `task` are concatenated.
function get_labels(task, content, pretrain, num_items)
    df = if pretrain
        get_split("training", "all", content, medium)
    else
        cat(
            get_split("validation", task, content, medium),
            get_split("test", task, content, medium),
        )
    end
    return sparse(df.item, df.user, df.rating, num_items, num_users(medium))
end

# Convenience method: returns the (implicit, explicit) pair of label matrices.
function get_labels(task, pretrain, num_items)
    implicit_labels = get_labels(task, "implicit", pretrain, num_items)
    explicit_labels = get_labels(task, "explicit", pretrain, num_items)
    return implicit_labels, explicit_labels
end;

In [None]:
# Build the sparse [item, user] matrix of per-entry loss weights for one content
# type. Weights are derived from per-split counts via `powerdecay` with the
# "inverse" weighting scheme. The split selection mirrors `get_labels`.
function get_weights(task, content, pretrain, num_items)
    split_weights(split, split_task) = powerdecay(
        get_counts(split, split_task, content, medium),
        weighting_scheme("inverse"),
    )
    if pretrain
        df = get_split("training", "all", content, medium)
        w = split_weights("training", "all")
    else
        df = cat(
            get_split("validation", task, content, medium),
            get_split("test", task, content, medium),
        )
        # same order as the concatenated splits above
        w = vcat(split_weights("validation", task), split_weights("test", task))
    end
    return sparse(df.item, df.user, w, num_items, num_users(medium))
end

# Convenience method: returns the (implicit, explicit) pair of weight matrices.
function get_weights(task, pretrain, num_items)
    implicit_weights = get_weights(task, "implicit", pretrain, num_items)
    explicit_weights = get_weights(task, "explicit", pretrain, num_items)
    return implicit_weights, explicit_weights
end;

In [None]:
# Build the sparse [item, user] matrix of timestamps for one content type,
# using the same split selection as `get_labels`.
function get_timestamps(task, content, pretrain, num_items)
    df = if pretrain
        get_split("training", "all", content, medium)
    else
        cat(
            get_split("validation", task, content, medium),
            get_split("test", task, content, medium),
        )
    end
    return sparse(df.item, df.user, df.timestamp, num_items, num_users(medium))
end

# Convenience method: timestamps of the implicit content only.
function get_timestamps(task, pretrain, num_items)
    return get_timestamps(task, "implicit", pretrain, num_items)
end;

In [None]:
# Return the (training, test) user lists.
# - pretrain: distinct users of the training split, divided by user id at 99%
#   of `num_users(medium)`.
# - otherwise: distinct users of the validation split and of the test split.
# `rng` is accepted but not used by this function.
function get_users(rng, task, pretrain, content)
    distinct_users(split, split_task) =
        collect(Set(get_split(split, split_task, content, medium).user))

    if !pretrain
        training = distinct_users("validation", task)
        test = distinct_users("test", task)
        return training, test
    end

    users = distinct_users("training", "all")
    training_frac = 0.99
    cutoff = training_frac * num_users(medium)
    training = filter(x -> x < cutoff, users)
    test = filter(x -> x >= cutoff, users)
    return training, test
end

# Convenience method: user lists derived from the implicit content.
function get_users(rng, task, pretrain)
    return get_users(rng, task, pretrain, "implicit")
end;

In [None]:
# Return a copy of the sentence stored for user `x`.
# If `x` has no sentence, return a one-token sentence made from `cls_tokens`
# with the user, status (task id) and rating (user bias) fields filled in.
#
# The lookup uses `haskey` rather than `try ... catch KeyError`: in Julia
# `catch KeyError` binds *any* exception to a variable named `KeyError`, so the
# previous form silently turned unrelated errors into the fallback sentence.
function get_sentence(sentences, x, cls_tokens, explicit_baseline)
    if haskey(sentences, x)
        # copy so that callers can push! onto the result without mutating the store
        return copy(sentences[x])
    end
    tokens = replace(cls_tokens, :user, x)
    tokens = replace(tokens, :status, explicit_baseline["task"])
    tokens = replace(tokens, :rating, explicit_baseline["user_biases"][x])
    return [tokens]
end;

# Batching

In [None]:
# Assemble one model batch for `users`.
#
# For every user the stored sentence is fetched (or synthesized by
# `get_sentence`), truncated to `max_seq_len - 1` tokens and terminated with a
# mask token whose fields depend on `task`:
#   "random"          - random subset, plain `mask_tokens`
#   "temporal"        - most recent subset, mask timestamp set to 1
#   "temporal_causal" - as "temporal", and mask position set to the length of
#                       the truncated sentence
#
# Returns `(get_inputs(...)..., output_labels, output_weights, batch_positions)`
# where `batch_positions[i] == (index of the mask token, i)`.
#
# `training` and `timestamps` are accepted for interface compatibility but are
# not used here. Throws `ArgumentError` for an unknown `task` (previously an
# `@assert false`, which is not reliable for input validation).
function get_batch(;
    users,
    training,
    task,
    sentences,
    labels,
    weights,
    timestamps,
    max_seq_len,
    vocab_sizes,
    mask_tokens,
    pad_tokens,
    cls_tokens,
    explicit_baseline,
    rng,
)
    task in ("random", "temporal", "temporal_causal") ||
        throw(ArgumentError("unsupported task: $task"))

    user_sentences =
        [get_sentence(sentences, x, cls_tokens, explicit_baseline) for x in users]
    processed_sentences = eltype(user_sentences)[]
    batch_positions = Tuple{Int32,Int32}[]
    output_labels = map(x -> x[:, users], labels)
    output_weights = map(x -> x[:, users], weights)

    # only the "random" task samples arbitrary history; the others keep recent items
    recent = task != "random"
    for (i, sentence) in enumerate(user_sentences)
        s = subset_sentence(
            sentence,
            max_seq_len - 1;
            recent = recent,
            keep_first = true,
            rng = rng,
        )
        masked_word = mask_tokens
        if task != "random"
            masked_word = replace(masked_word, :timestamp, 1)
        end
        if task == "temporal_causal"
            masked_word = replace(masked_word, :position, length(s))
        end
        push!(s, masked_word)
        # the mask token is the last token of the sentence
        push!(batch_positions, (Int32(length(s)), Int32(i)))
        push!(processed_sentences, s)
    end

    inputs =
        get_inputs(processed_sentences, max_seq_len, vocab_sizes, pad_tokens, cls_tokens)
    return (inputs..., output_labels, output_weights, batch_positions)
end

# Convenience method: build a batch from the fields of a `Trainer`.
# Sentences are resolved against the source-medium baseline.
get_batch(users, training::Bool, t::Trainer) = get_batch(
    users = users,
    training = training,
    task = t.task,
    sentences = t.sentences,
    labels = t.labels,
    weights = t.weights,
    timestamps = t.timestamps,
    max_seq_len = t.max_seq_len,
    vocab_sizes = t.vocab_sizes,
    mask_tokens = t.mask_tokens,
    pad_tokens = t.pad_tokens,
    cls_tokens = t.cls_tokens,
    explicit_baseline = t.source_explicit_baseline,
    rng = t.rng,
);

In [None]:
# Convert sentences into model inputs: token ids plus a pairwise attention mask.
# Sequences are padded to the longest sentence in the batch, capped at
# `max_seq_len`.
function get_inputs(sentences, max_seq_len, vocab_sizes, pad_tokens, cls_tokens)
    longest = maximum(length, sentences)
    seq_len = min(longest, max_seq_len)

    tokens = get_token_ids(
        sentences,
        seq_len,
        extract(vocab_sizes, :position),
        pad_tokens,
        cls_tokens,
    )

    # 1.0 for real tokens, 0.0 for padding
    is_real_token = extract(tokens, :item) .!= extract(pad_tokens, :item)
    key_mask = reshape(Float32.(is_real_token), (1, seq_len, length(sentences)))
    # outer product per sentence: a pair attends only if both tokens are real
    attention_mask = key_mask .* permutedims(key_mask, (2, 1, 3))

    return tokens, attention_mask
end;

In [None]:
# Move a value to the compute device (falls through to Flux's `gpu`).
device(x) = gpu(x)
# Sparse arrays are densified: into a CuArray when CUDA is functional, otherwise
# into a dense host array via `collect`.
device(x::AbstractSparseArray) = CUDA.functional() ? CUDA.CuArray(gpu(x)) : collect(x)
# Move a whole batch, element by element. The layout is positional and matches
# what `output_embedding`/`lm_losses` read: x[1] tokens, x[2] attention mask,
# x[3] labels, x[4] weights, x[5] batch positions.
# NOTE(review): x[1] is treated as a 7-element token tuple whose 6th entry is
# dropped (replaced by `nothing`) — presumably a field the model does not
# consume; confirm against `get_token_ids`.
function device(x::Tuple)
    (
        device(x[1][1]),
        device(x[1][2]),
        device(x[1][3]),
        device(x[1][4]),
        device(x[1][5]),
        nothing,
        device(x[1][7]),
    ),
    device(x[2]),
    device.(x[3]),
    device.(x[4]),
    device(x[5])
end

# lets `unsafe_free!` be applied to token tuples that contain a `nothing` slot
CUDA.unsafe_free!(::Nothing) = nothing

# Eagerly release the device memory held by a batch produced by `device`.
# Does nothing when CUDA is not functional.
function device_free!(x)
    CUDA.functional() || return
    foreach(CUDA.unsafe_free!, x[1])
    CUDA.unsafe_free!(x[2])
    foreach(CUDA.unsafe_free!, x[3])
    foreach(CUDA.unsafe_free!, x[4])
    return CUDA.unsafe_free!(x[5])
end;

# Loss

In [None]:
# Run the embedding layer and the transformer stack on a batch, then pick out
# the hidden state at each sentence's mask position.
function output_embedding(model, batch)
    tokens = batch[1]
    attention_mask = batch[2]
    batch_positions = batch[5]
    embedded = model.embed(
        item = extract(tokens, :item),
        rating = extract(tokens, :rating),
        timestamp = extract(tokens, :timestamp),
        status = extract(tokens, :status),
        completion = extract(tokens, :completion),
        position = extract(tokens, :position),
    )
    encoded = model.transformers(embedded, attention_mask)
    return gather(encoded, batch_positions)
end;

In [None]:
# Compute the (item, rating) predictions for a batch.
# Item scores are the transformed embedding projected through the item decoder
# matrix, plus an output bias; rating predictions use the rating head directly.
function lm_preds(model, batch)
    hidden = output_embedding(model, batch)
    item_head = model.classifier.item
    item_preds =
        transpose(item_head.decoder) * item_head.transform(hidden) .+
        item_head.output_bias.b
    rating_preds = model.classifier.rating.transform(hidden)
    return item_preds, rating_preds
end;

In [None]:
# Weighted training losses for a batch: (item cross-entropy, rating MSE).
# Each loss is normalized by its total weight, and is 0 when that weight is 0.
function lm_losses(model, batch)
    item_preds, rating_preds = lm_preds(model, batch)
    labels = batch[3]
    weights = batch[4]

    item_weight = sum(weights[1])
    item_loss = if item_weight > 0
        -sum(labels[1] .* weights[1] .* logsoftmax(item_preds)) / item_weight
    else
        0.0f0
    end

    rating_weight = sum(weights[2])
    rating_loss = if rating_weight > 0
        sum((rating_preds - labels[2]) .^ 2 .* weights[2]) / rating_weight
    else
        0.0f0
    end

    return item_loss, rating_loss
end;

In [None]:
# function evaluate_losses(users, t::Trainer)
#     losses = zeros(2)
#     loss_weights = zeros(2)
#     user_batches = collect(Iterators.partition(users, t.batch_size))
#     @showprogress for user_batch in user_batches
#         batch = get_batch(user_batch, false, t) |> device
#         weights = sum.(batch[4])
#         loss_weights .+= weights
#         losses .+= lm_losses(t.model, batch) .* weights
#         batch |> device_free!
#     end
#     losses = losses ./ loss_weights
#     Dict("Item Crossentropy Loss" => losses[1], "Rating MSE Loss" => losses[2])
# end;

In [None]:
# Unnormalized (summed) losses for a batch, used by `evaluate_losses`:
# the item cross-entropy and the rating squared error evaluated with the
# rating predictions scaled by +1, 0 and -1.
function lm_correlation_losses(model, batch)
    item_preds, rating_preds = lm_preds(model, batch)
    labels = batch[3]
    weights = batch[4]

    item_loss = if sum(weights[1]) > 0
        -sum(labels[1] .* weights[1] .* logsoftmax(item_preds))
    else
        0.0f0
    end

    if sum(weights[2]) <= 0
        return item_loss, 0.0f0, 0.0f0, 0.0f0
    end
    # weighted squared error of `scale * prediction` against the labels
    scaled_error(scale) = sum((scale .* rating_preds - labels[2]) .^ 2 .* weights[2])
    r1_loss = sum((rating_preds - labels[2]) .^ 2 .* weights[2])
    r0_loss = scaled_error(0)
    r_n1_loss = scaled_error(-1)
    return item_loss, r1_loss, r0_loss, r_n1_loss
end;

# Evaluate the model on `users` and return a Dict of weighted-average metrics.
# The "Correlation MSE Loss" is the rating MSE under the best scalar rescaling
# of the predictions: the rating loss is a quadratic in the scale factor, which
# is reconstructed from its values at +1, 0 and -1 and then minimized.
function evaluate_losses(users, t::Trainer)
    loss_sums = zeros(4)
    weight_sums = zeros(2)
    @showprogress for user_batch in collect(Iterators.partition(users, t.batch_size))
        batch = device(get_batch(user_batch, false, t))
        weight_sums .+= sum.(batch[4])
        loss_sums .+= lm_correlation_losses(t.model, batch)
        device_free!(batch)
    end

    item_weight, rating_weight = weight_sums
    i_loss = loss_sums[1] / item_weight
    r1_loss = loss_sums[2] / rating_weight
    r0_loss = loss_sums[3] / rating_weight
    r_n1_loss = loss_sums[4] / rating_weight

    # loss(s) = quadratic * s^2 + linear * s + constant; minimum over s
    quadratic = (r1_loss + r_n1_loss) / 2 - r0_loss
    linear = (r1_loss - r_n1_loss) / 2
    constant = r0_loss
    r_loss = constant - linear^2 / (4 * quadratic)

    return Dict(
        "Item Crossentropy Loss" => i_loss,
        "Correlation MSE Loss" => r_loss,
        "Rating MSE Loss" => r1_loss,
    )
end;

# Training

In [None]:
# Train for one pass over `users` and return the mean minibatch loss.
# Users are shuffled with the trainer's rng. Each optimizer step covers
# `batch_size` users, with gradients accumulated over minibatches of
# `minibatch_size` users and then averaged.
function train_epoch!(users, t::Trainer)
    shuffled = Random.shuffle(t.rng, users)
    losses = []
    @showprogress for user_batch in collect(Iterators.partition(shuffled, t.batch_size))
        minibatches = collect(Iterators.partition(user_batch, t.minibatch_size))
        accumulated = nothing
        for minibatch in minibatches
            batch = device(get_batch(minibatch, true, t))
            loss, grads = Flux.withgradient(m -> sum(lm_losses(m, batch)), t.model)
            accumulated = tuplesum(accumulated, grads[1])
            push!(losses, loss)
            device_free!(batch)
        end
        averaged = tupledivide(accumulated, length(minibatches))
        Flux.update!(t.opt, t.model, averaged)
    end
    return mean(losses)
end;

In [None]:
# Evaluate the model on `users`, write a checkpoint for `epoch` under
# "$name/checkpoints/", and return the metrics (including the training loss).
function checkpoint(
    users,
    t::Trainer,
    training_loss,
    epoch,
    source_checkpoint,
    target_checkpoint,
    name,
)
    @info "evaluating metrics"
    metrics = evaluate_losses(users, t)
    metrics["training_loss"] = training_loss
    payload = Dict(
        "m" => cpu(t.model),
        "epoch" => epoch,
        "metrics" => metrics,
        "source_checkpoint" => source_checkpoint,
        "target_checkpoint" => target_checkpoint,
    )
    write_params(payload, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
    return metrics
end;

# Configuration

In [None]:
# Combine two pretrained models: the embedding layer and transformer stack of
# the source checkpoint with the classifier heads of the target checkpoint.
# The item decoder is the target model's item embedding matrix.
function create_cross_media_model(source_checkpoint, target_checkpoint)
    source_model = read_params(source_checkpoint)["m"]
    target_model = read_params(target_checkpoint)["m"]
    target_item = target_model.classifier.item
    classifier = (
        item = (
            transform = target_item.transform,
            output_bias = target_item.output_bias,
            decoder = target_model.embed.embeddings.item.embedding,
        ),
        rating = target_model.classifier.rating,
    )
    return TransformerModel(source_model.embed, source_model.transformers, classifier)
end;

In [None]:
# Load the explicit user/item bias baseline of `medium` for the global `task`
# and add lookup tables to it:
#   "user_biases" / "item_biases" - index => bias, defaulting to the mean bias
#   "task"                        - position of the global `task` in ALL_TASKS
function get_explicit_baseline(medium)
    baseline = read_params("$medium/$task/ExplicitUserItemBiases")
    # index => value lookup that falls back to the mean for unknown indices
    with_mean_default(biases) =
        DefaultDict(mean(biases), Dict(keys(biases) .=> biases))
    baseline["user_biases"] = with_mean_default(baseline["u"])
    baseline["item_biases"] = with_mean_default(baseline["a"])
    baseline["task"] = findfirst(==(task), ALL_TASKS)
    return baseline
end;

In [None]:
# Load the source-medium sentences for every task into the trainer and demean
# their ratings in place. Returns the updated trainer.
#
# The dataset options (`include_ptw_impressions`, `cls_tokens`) are taken from
# the training config stored in `source_checkpoint`.
function load_input_data(t::Trainer, source_checkpoint)
    config = read_params(source_checkpoint)["training_config"]
    use_ptw = config["include_ptw_impressions"]
    # merge the per-task sentence collections into a single lookup
    sentences = reduce(
        merge,
        [
            get_training_data(task, source_medium, use_ptw, config["cls_tokens"]) for
            task in ALL_TASKS
        ],
    )
    t = @set t.sentences = sentences
    GC.gc()

    # demean inputs
    # NOTE(review): the baseline lookups below run on several threads. If the
    # `DefaultDict`s insert an entry when a missing key is read, that would be
    # a concurrent mutation — confirm all users/items are present or that the
    # lookup is read-only.
    @tprogress Threads.@threads for u in collect(keys(t.sentences))
        for i = 1:length(t.sentences[u])
            tokens = t.sentences[u][i]
            # ratings at or above the rating vocab size are treated as "no explicit rating"
            has_explicit_rating =
                (extract(tokens, :rating) .< extract(t.vocab_sizes, :rating))
            if has_explicit_rating
                # subtract the baseline prediction (user bias + item bias)
                blp =
                    t.source_explicit_baseline["user_biases"][u] +
                    t.source_explicit_baseline["item_biases"][extract(tokens, :item)]
                demeaned_rating = extract(tokens, :rating) - blp
                t.sentences[u][i] = replace(tokens, :rating, demeaned_rating)
            end
            # cls tokens carry the task id as status and the user bias as rating;
            # this is built from the original `tokens` and overwrites the slot
            if extract(tokens, :item) == extract(t.cls_tokens, :item)
                tokens = replace(tokens, :status, t.source_explicit_baseline["task"])
                tokens =
                    replace(tokens, :rating, t.source_explicit_baseline["user_biases"][u])
                t.sentences[u][i] = tokens
            end
        end
    end
    t
end

# Load the target-medium labels, weights and timestamps into the trainer,
# demean the explicit labels, and return `(trainer, training_users, validation_users)`.
# `pretrain` selects between the training split and the validation/test splits.
# `source_checkpoint` is accepted but not used by this function.
function load_output_data(t::Trainer, source_checkpoint, pretrain)
    # number of target items, taken from the second dimension of the item decoder
    N = size(t.model.classifier.item.decoder)[2]
    t = @set t.labels = get_labels(t.task, pretrain, N)
    t = @set t.weights = get_weights(t.task, pretrain, N)
    t = @set t.timestamps = get_timestamps(t.task, pretrain, N)

    # demean outputs
    # subtract (user bias + item bias) from every stored explicit label
    # NOTE(review): this mutates a sparse matrix and reads `DefaultDict`s from
    # several threads — confirm both are safe for concurrent use here.
    @tprogress Threads.@threads for (a, u, _) in
                                    collect(zip(SparseArrays.findnz(t.labels[2])...))
        t.labels[2][a, u] -=
            t.target_explicit_baseline["user_biases"][u] +
            t.target_explicit_baseline["item_biases"][a]
    end
    training, validation = get_users(t.rng, t.task, pretrain)
    t, training, validation
end;

In [None]:
# Build a `Trainer` around the cross-media model and load its input sentences.
# Model hyperparameters and batch sizes come from the source checkpoint's
# training config. Labels, weights, timestamps and the optimizer are left as
# `nothing` for the caller to fill in later.
function load_pretrained_model(source_checkpoint, target_checkpoint, task)
    model = gpu(create_cross_media_model(source_checkpoint, target_checkpoint))
    config = read_params(source_checkpoint)["training_config"]
    source_baseline = get_explicit_baseline(source_medium)
    target_baseline = get_explicit_baseline(medium)
    skeleton = Trainer(
        # finetuning domain
        task = task,
        # data (populated by load_input_data / load_output_data)
        sentences = nothing,
        labels = nothing,
        weights = nothing,
        timestamps = nothing,
        source_explicit_baseline = source_baseline,
        target_explicit_baseline = target_baseline,
        # model
        model = model,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        mask_tokens = config["mask_tokens"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        # training
        minibatch_size = config["minibatch_size"],
        batch_size = config["batch_size"],
        opt = nothing,
        rng = Random.Xoshiro(20230102),
    )
    return load_input_data(skeleton, source_checkpoint)
end;

# Load pretrained model

In [None]:
# Return the relative path of the highest-numbered checkpoint of the pretrained
# transformer for `medium` and version `tag`. Checkpoint entries are expected
# to be named by their integer epoch.
function get_pretrain_checkpoint(medium, tag)
    pretrain_dir = "$medium/all/Transformer/$tag/checkpoints/"
    epochs = parse.(Int, readdir(get_data_path("alphas/$pretrain_dir")))
    latest_epoch = last(sort(epochs))
    pretrain_checkpoint = joinpath(pretrain_dir, string(latest_epoch))
    @info "using pretrained model from $pretrain_checkpoint"
    return pretrain_checkpoint
end;

In [None]:
# locate the latest "v9" pretrained checkpoints for both media and build the
# trainer (cross-media model + demeaned source sentences)
source_checkpoint = get_pretrain_checkpoint(source_medium, "v9")
target_checkpoint = get_pretrain_checkpoint(medium, "v9")
trainer = load_pretrained_model(source_checkpoint, target_checkpoint, task);

# Optionally pretrain model

In [None]:
# Compute the mask-position output embedding of every user in `users`.
# Returns a Float32 matrix with one column per user, in the order of `users`.
function get_embeddings(users, t::Trainer)
    embedding_dim = size(t.model.classifier.item.decoder, 1)
    embedding = zeros(Float32, embedding_dim, length(users))
    column_batches = collect(Iterators.partition(1:length(users), t.minibatch_size))
    @showprogress for columns in column_batches
        batch = device(get_batch(users[columns], false, t))
        y = cpu(output_embedding(t.model, batch))
        device_free!(batch)
        embedding[:, columns] = y
    end
    return embedding
end;

In [None]:
# Train only the classifier heads for one epoch on precomputed embeddings.
# `embeddings[:, i]` must be the embedding of `users[i]`. Returns the mean
# minibatch loss.
#
# Changes relative to the previous version, all for consistency with the rest
# of the file:
# - minibatch order is drawn from `t.rng` (as in `train_epoch!`) so runs are
#   reproducible from the trainer's seed;
# - a loss whose total weight is zero contributes 0 instead of NaN (as in
#   `lm_losses`), so an unlucky minibatch cannot poison the parameters;
# - device buffers are only freed when CUDA is functional (as in `device_free!`).
function pretrain_epoch!(users, embeddings, t::Trainer)
    losses = []
    order = Random.shuffle(t.rng, 1:length(users))
    @showprogress for minibatch in collect(Iterators.partition(order, t.minibatch_size))
        X = embeddings[:, minibatch] |> device
        y = device.(z[:, users[minibatch]] for z in t.labels)
        w = device.(z[:, users[minibatch]] for z in t.weights)
        # the weights do not depend on the model, so sum them outside the gradient
        item_weight = sum(w[1])
        rating_weight = sum(w[2])
        tloss, grads = Flux.withgradient(t.model.classifier) do m
            if item_weight > 0
                item_preds =
                    transpose(m.item.decoder) * m.item.transform(X) .+
                    m.item.output_bias.b
                iloss = -sum(y[1] .* w[1] .* logsoftmax(item_preds)) / item_weight
            else
                iloss = 0.0f0
            end
            if rating_weight > 0
                rating_preds = m.rating.transform(X)
                rloss = sum((rating_preds - y[2]) .^ 2 .* w[2]) / rating_weight
            else
                rloss = 0.0f0
            end
            iloss + rloss
        end
        # with no weight at all the loss is constant and there is nothing to apply
        if !isnothing(grads[1])
            Flux.update!(t.opt, t.model.classifier, grads[1])
        end
        push!(losses, tloss)
        if CUDA.functional()
            CUDA.unsafe_free!(X)
            CUDA.unsafe_free!.(y)
            CUDA.unsafe_free!.(w)
        end
    end
    return mean(losses)
end;

In [None]:
# Evaluate the classifier heads on precomputed embeddings.
# `embeddings[:, i]` must be the embedding of `users[i]`.
# Returns `(item cross-entropy, correlation MSE, rating MSE)`, each a
# weighted average over all users.
#
# The classifier is read from the trainer passed in as `t`. The previous
# version read the global `trainer`, which only worked when both happened to
# refer to the same model.
function pretrain_loss!(users, embeddings, t::Trainer)
    i_losses = []
    i_weight = 0.0f0
    r1_losses = []
    r0_losses = []
    r_n1_losses = []
    r_weight = 0.0f0
    m = t.model.classifier
    @showprogress for minibatch in
                      collect(Iterators.partition(1:length(users), t.minibatch_size))
        X = embeddings[:, minibatch] |> device
        y = device.(z[:, users[minibatch]] for z in t.labels)
        w = device.(z[:, users[minibatch]] for z in t.weights)

        rating_preds = m.rating.transform(X)
        item_preds = transpose(m.item.decoder) * m.item.transform(X) .+ m.item.output_bias.b
        iloss = -sum(y[1] .* w[1] .* logsoftmax(item_preds))
        i_weight += sum(w[1])
        # rating squared error with the predictions scaled by +1, 0 and -1
        r1_loss = sum((1 .* rating_preds - y[2]) .^ 2 .* w[2])
        r0_loss = sum((0 .* rating_preds - y[2]) .^ 2 .* w[2])
        r_n1_loss = sum((-1 .* rating_preds - y[2]) .^ 2 .* w[2])
        r_weight += sum(w[2])

        push!(i_losses, iloss)
        push!(r1_losses, r1_loss)
        push!(r0_losses, r0_loss)
        push!(r_n1_losses, r_n1_loss)

        # mirror `device_free!`: only release buffers that live on the GPU
        if CUDA.functional()
            CUDA.unsafe_free!(X)
            CUDA.unsafe_free!.(y)
            CUDA.unsafe_free!.(w)
        end
    end

    i_loss = sum(i_losses) / i_weight
    # get the correlation loss by finding the minimum of the rating loss quadratic
    # loss(s) = a * s^2 + b * s + c, reconstructed from its values at s = 1, 0, -1
    r1 = sum(r1_losses) / r_weight
    r0 = sum(r0_losses) / r_weight
    r_n1 = sum(r_n1_losses) / r_weight
    a = (r1 + r_n1) / 2 - r0
    b = (r1 - r_n1) / 2
    c = r0
    r_loss = c - b^2 / (4 * a)
    return i_loss, r_loss, r1
end;

In [None]:
# Fit only the classifier heads on the target medium's training split while the
# encoder stays fixed: embeddings are computed once up front and reused every
# epoch. Writes the result to "$name/checkpoints/pretrain" and returns the
# updated trainer. Reads the globals `source_checkpoint` and `name`.
function pretrain_model!(t::Trainer)
    t, training, validation = load_output_data(t, source_checkpoint, true)
    training_embeddings = get_embeddings(training, t)
    validation_embeddings = get_embeddings(validation, t)

    lr = 1e-4
    # the optimizer state covers the classifier only
    t = @set t.opt = Optimisers.setup(
        OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * 1f-2)),
        t.model.classifier,
    )

    stopper = early_stopper(max_iters = 10, patience = 0)
    test_loss = Inf
    best_model = nothing
    while (!stop!(stopper, test_loss))
        # snapshot taken before the epoch: if this epoch makes the loss worse,
        # the snapshot is the model that is kept
        best_model = t.model |> cpu
        training_loss = pretrain_epoch!(training, training_embeddings, t)
        item_loss, corr_loss, rating_loss =
            pretrain_loss!(validation, validation_embeddings, t)
        # early stopping uses the item loss plus the correlation (not raw MSE) loss
        test_loss = item_loss + corr_loss
        @info "$training_loss $item_loss $corr_loss $rating_loss"
    end
    t = @set t.model = best_model |> gpu
    write_params(
        Dict("m" => t.model |> cpu, "epoch" => stopper.iters, "loss" => stopper.loss),
        "$name/checkpoints/pretrain",
    )
    t
end;

In [None]:
# the classifier-only pretraining pass is run only when transferring across media
if source_medium != medium
    trainer = pretrain_model!(trainer)
end

# Finetune model

In [None]:
# pretrain = false: labels come from the task's validation and test splits;
# the model is trained on the validation users and evaluated on the test users
trainer, training, validation = load_output_data(trainer, source_checkpoint, false);

In [None]:
# finetuning optimizer over the whole model
lr = 1e-5
opt = Optimisers.setup(
    OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * 1f-2)),
    trainer.model,
)
trainer = @set trainer.opt = opt;

In [None]:
# Finetune with early stopping; a checkpoint is written after every epoch.
stopper = early_stopper(max_iters = 100, patience = 0)
test_loss = Inf
best_model = nothing
while (!stop!(stopper, test_loss))
    # snapshot taken before the epoch, so that the model from before a
    # loss-increasing epoch is the one restored below
    best_model = trainer.model |> cpu
    training_loss = train_epoch!(training, trainer)
    metrics = checkpoint(
        validation,
        trainer,
        training_loss,
        stopper.iters,
        source_checkpoint,
        target_checkpoint,
        name,
    )
    # NOTE(review): stopping here uses the raw rating MSE, whereas
    # `pretrain_model!` uses the correlation MSE — confirm this is intended
    test_loss = metrics["Item Crossentropy Loss"] + metrics["Rating MSE Loss"]
end;
trainer = @set trainer.model = best_model |> gpu;

# Save predictions

In [None]:
# returns a vector that maps a user to the list of items to predict
function user_to_items(users::Vector, items::Vector)
    user_to_count = zeros(Int32, num_users(medium), Threads.nthreads())
    @tprogress Threads.@threads for u in users
        user_to_count[u, Threads.threadid()] += 1
    end
    user_to_count = convert.(Int32, vec(sum(user_to_count, dims = 2)))

    utoa = Vector{Vector{Int32}}()
    @showprogress for u = 1:num_users(medium)
        push!(utoa, Vector{Int32}(undef, user_to_count[u]))
    end

    @showprogress for i = 1:length(users)
        u = users[i]
        a = items[i]
        utoa[u][user_to_count[u]] = a
        user_to_count[u] -= 1
    end
    utoa
end

# Predict implicit (softmaxed item score) and explicit (rating) values for every
# (users[i], items[i]) pair. Returns two `RatingsDataset`s with identical
# user/item columns: (implicit predictions, explicit predictions). Rows are
# grouped by user; the order of users is unspecified.
#
# Changes relative to the previous version:
# - the CUDA math mode is restored in a `finally`, so an error mid-evaluation
#   no longer leaves the process in TensorFloat32 mode;
# - each batch is released with `device_free!`, like every other batch loop in
#   this file;
# - the batches are iterated directly instead of being wrapped in a `Set`,
#   which only hashed every batch without removing anything (users are already
#   distinct).
function evaluate_model(users, items, t::Trainer)
    utoa = user_to_items(users, items)
    out_users = Vector{Int32}(undef, length(users))
    out_items = Vector{Int32}(undef, length(users))
    out_implicit_ratings = fill(NaN32, length(out_users))
    out_explicit_ratings = fill(NaN32, length(out_users))
    out_idx = 1

    # compute predictions
    user_batches = collect(Iterators.partition(Set(users), t.minibatch_size))
    CUDA.math_mode!(CUDA.FAST_MATH; precision = :TensorFloat32)
    try
        @showprogress for sampled_users in user_batches
            batch = get_batch(sampled_users, false, t) |> device
            item_preds, rating_preds = lm_preds(t.model, batch)
            item_preds = softmax(item_preds) |> cpu
            rating_preds = rating_preds |> cpu
            batch |> device_free!
            for j in eachindex(sampled_users)
                u = sampled_users[j]
                item_mask = utoa[u]
                isempty(item_mask) && continue
                # write this user's rows into the next free output range
                rows = out_idx:(out_idx + length(item_mask) - 1)
                out_users[rows] .= u
                out_items[rows] = item_mask
                out_implicit_ratings[rows] = item_preds[item_mask, j]
                out_explicit_ratings[rows] = rating_preds[item_mask, j]
                out_idx += length(item_mask)
            end
        end
    finally
        CUDA.math_mode!(CUDA.FAST_MATH; precision = :BFloat16)
    end
    implicit = RatingsDataset(
        user = out_users,
        item = out_items,
        rating = out_implicit_ratings,
        medium = medium,
    )
    explicit = RatingsDataset(
        user = out_users,
        item = out_items,
        rating = out_explicit_ratings,
        medium = medium,
    )
    return implicit, explicit
end;

In [None]:
# Write the model's implicit and explicit predictions as alphas under
# "$outdir/implicit" and "$outdir/explicit".
#
# Predictions are computed once for every (user, item) pair that appears in any
# split and content type of `task`, then served from sparse lookups.
#
# Predictions come from the trainer passed in as `t`. The previous version
# evaluated the global `trainer`, ignoring the argument.
function write_alpha(t::Trainer, outdir::String, task::String)
    master_dfs = []
    @showprogress for split in ALL_SPLITS
        for content in ALL_CONTENTS
            push!(
                master_dfs,
                get_raw_split(split, task, content, medium; fields = [:user, :item]),
            )
        end
    end
    master_df = reduce(cat, master_dfs)
    imp_p, exp_p = sparse.(evaluate_model(master_df.user, master_df.item, t))
    # look up predictions for the requested pairs; `p` is indexed [item, user]
    function model(content, p, users, items)
        r = zeros(length(users))
        @tprogress Threads.@threads for j = 1:length(r)
            r[j] = p[items[j], users[j]]
            if content == "explicit"
                # explicit predictions are demeaned, so add the baseline back
                blp =
                    t.target_explicit_baseline["user_biases"][users[j]] +
                    t.target_explicit_baseline["item_biases"][items[j]]
                r[j] += blp
            end
        end
        r
    end
    for (content, p) in [("implicit", imp_p), ("explicit", exp_p)]
        write_alpha(
            (users, items) -> model(content, p, users, items),
            medium,
            "$outdir/$content";
            task = task,
            log = true,
            log_task = task,
            log_content = content,
            log_alphas = String[],
        )
    end
end;

In [None]:
# write the finetuned model's predictions under the run's output path
write_alpha(trainer, name, task)