In [1]:
# Fine-tuning task for this notebook. `get_batch` below handles the values
# "random", "causal" and "temporal".
task = "random"

"random"

In [2]:
# Output prefix for this run; `checkpoint` writes to "$name/checkpoints/$epoch".
# NOTE(review): the trailing slash makes that path contain "//" — confirm this is intended.
name = "$task/Transformer/";

In [3]:
import NBInclude: @nbinclude
import SparseArrays: AbstractSparseArray, sparse
@nbinclude("../Alpha.ipynb")
@nbinclude("Transformer.ipynb");



# Data

In [4]:
# Bundle of everything needed to fine-tune and evaluate the model on one task.
# Built by `load_pretrained_model`; consumed by `get_batch`, `split_losses`,
# `train_epoch!` and `checkpoint`.
@with_kw struct Trainer
    # task name; `get_batch` branches on "random", "causal" and "temporal"
    task::Any
    # data
    # per-user token sequences, looked up by user id in `get_sentence`
    sentences::Any
    # result of `get_masked_items(task)`: an (implicit, explicit) pair of user => [(item, timestamp)]
    masked_items::Any
    # result of `get_labels(task, explicit_prior)`: (implicit, explicit) sparse item x user matrices
    labels::Any
    # result of `get_weights(task)`: (implicit, explicit) sparse item x user matrices
    weights::Any
    # object with `.user` and `.item` lookups that are subtracted from explicit ratings
    explicit_prior::Any
    # model
    # exposes `embed`, `transformers` and `classifier` (see `lm_preds`)
    model::Any
    # upper bound on the padded sequence length
    max_seq_len::Any
    # per-field vocabulary sizes; entry 2 is compared against rating tokens in `get_inputs`
    vocab_sizes::Any
    # the word inserted at the position to predict
    mask_tokens::Any
    # the word used for padding; entry 1 is compared against item tokens in `get_inputs`
    pad_tokens::Any
    # training
    # number of users per forward/backward pass
    minibatch_size::Any
    # number of users per optimizer step (gradients are averaged over its minibatches)
    batch_size::Any
    # optimizer state passed to `Flux.update!`
    opt::Any
    # random number generator used for shuffling and sampling
    rng::Any
end;

In [5]:
"""
    get_labels(task, content, prior = nothing)
    get_labels(task, explicit_prior)

Build the sparse item x user label matrix for the concatenated "validation" and
"test" splits of `task`/`content`.

When `prior` is given, `prior.user[user] + prior.item[item]` is subtracted from
every rating before the matrix is built, so the labels become residuals.

The two-argument method returns the pair `(implicit, explicit)`, where only the
explicit labels are adjusted by `explicit_prior`.

The matrix has `num_items() + 4` rows and `num_users()` columns.
NOTE(review): the 4 extra rows presumably cover special tokens — confirm.
"""
function get_labels(task, content, prior = nothing)
    df = cat(get_split("validation", task, content), get_split("test", task, content))
    if !isnothing(prior)
        # Each iteration touches only its own rating entry, so the loop is thread-safe.
        @tprogress Threads.@threads for i = 1:length(df.rating)
            df.rating[i] -= prior.user[df.user[i]] + prior.item[df.item[i]]
        end
    end
    return sparse(df.item, df.user, df.rating, num_items() + 4, num_users())
end

get_labels(task, explicit_prior) =
    get_labels(task, "implicit", nothing), get_labels(task, "explicit", explicit_prior);

In [6]:
"""
    get_weights(task, content)
    get_weights(task)

Build the sparse item x user weight matrix that pairs with `get_labels`.

Rows are taken from the "validation" split followed by the "test" split. Each
split's weights are `powerdecay` applied to that split's counts with the
"inverse" weighting scheme. The one-argument method returns the pair
`(implicit, explicit)`.
"""
function get_weights(task, content)
    validation = get_split("validation", task, content)
    test = get_split("test", task, content)
    df = cat(validation, test)
    split_weights(split) =
        powerdecay(get_counts(split, task, content), weighting_scheme("inverse"))
    w = vcat(split_weights("validation"), split_weights("test"))
    return sparse(df.item, df.user, w, num_items() + 4, num_users())
end

function get_weights(task)
    return get_weights(task, "implicit"), get_weights(task, "explicit")
end;

In [7]:
"""
    get_masked_items(task, content)
    get_masked_items(task)

Group the held-out interactions ("validation" followed by "test") by user.

Returns a `Dict` mapping each user to a `Vector{Tuple{Int32,Float32}}` of
`(item, timestamp)` pairs in row order. The one-argument method returns the
pair `(implicit, explicit)`.
"""
function get_masked_items(task, content)
    df = cat(get_split("validation", task, content), get_split("test", task, content))
    user_to_items = Dict()
    @showprogress for i = 1:length(df.user)
        # Create the user's list on first sight, then append to it.
        history = get!(user_to_items, df.user[i]) do
            Tuple{Int32,Float32}[]
        end
        push!(history, (df.item[i], df.timestamp[i]))
    end
    return user_to_items
end

function get_masked_items(task)
    return get_masked_items(task, "implicit"), get_masked_items(task, "explicit")
end;

In [8]:
"""
    get_sentence(sentences, x)

Return a copy of the sentence stored under key `x`, or an empty sentence of the
collection's value type when `x` is absent.

The copy lets callers mutate the result without touching `sentences`. The
lookup uses `get`, so only a missing key triggers the fallback; any other error
propagates to the caller.
"""
function get_sentence(sentences, x)
    sentence = get(sentences, x, nothing)
    if isnothing(sentence)
        # Build the empty fallback from the argument, not from any global state.
        return eltype(values(sentences))(undef, 0)
    end
    return copy(sentence)
end;

In [9]:
"""
    get_users(rng, task, content)
    get_users(rng, task)

Return `(training, test)`: the distinct users of the "validation" split and the
distinct users of the "test" split, in that order.

`rng` is accepted but not used. The two-argument method uses the "implicit"
content.
"""
function get_users(rng, task, content)
    distinct_users(split) = collect(Set(get_split(split, task, content).user))
    training = distinct_users("validation")
    test = distinct_users("test")
    return training, test
end

function get_users(rng, task)
    return get_users(rng, task, "implicit")
end;

# Batching

In [10]:
# Pick one timestamp at random from a collection of (item, timestamp) pairs.
function get_random_timestamp(rng, item_timestamps)
    timestamps = map(entry -> entry[2], item_timestamps)
    return rand(rng, timestamps)
end;
# Return `word` with its third field replaced by `timestamp`; other fields are kept.
function replace_timestamp(word, timestamp)
    head = word[1:2]
    tail = word[4:end]
    return (head..., timestamp, tail...)
end;

In [11]:
# function random_dropout(rng, sentence, p)
#     new_sentence = eltype(sentence)[]
#     for i = 1:length(sentence)
#         if rand(rng) >= p
#             push!(new_sentence, sentence[i])
#         end
#     end
#     new_sentence
# end;

In [12]:
"""
    get_batch(users, training; task, sentences, masked_items, labels, weights,
              explicit_prior, max_seq_len, vocab_sizes, mask_tokens, pad_tokens, rng)

Assemble one batch for `users`.

Each user's sentence is subsampled to `max_seq_len - 1` words and receives one
mask word, whose position is recorded in `batch_positions`:

- "random": random subset, mask word inserted at position 1.
- "causal" / "temporal": most recent subset, mask word appended at the end.
  For "temporal" the mask word's timestamp field is set to 1. For "causal" the
  timestamp is replaced, with probability 0.5 while `training`, by a random
  timestamp from the user's masked items.

Returns `(inputs..., output_labels, output_weights, batch_positions)`, where
`inputs` is whatever `get_inputs` returns and the labels and weights are the
columns of `labels` / `weights` selected by `users`.

Any other `task` triggers an assertion failure.
"""
function get_batch(
    users,
    training;
    task,
    sentences,
    masked_items,
    labels,
    weights,
    explicit_prior,
    max_seq_len,
    vocab_sizes,
    mask_tokens,
    pad_tokens,
    rng,
)
    raw_sentences = [get_sentence(sentences, x) for x in users]
    # TODO test to completion
    # if training
    #     raw_sentences = random_dropout.(rng, raw_sentences, 0.1)
    # end
    processed_sentences = eltype(values(sentences))[]
    batch_positions = Tuple{Int32,Int32}[]
    output_labels = map(x -> x[:, users], labels)
    output_weights = map(x -> x[:, users], weights)

    for i::Int32 = 1:length(raw_sentences)
        s = raw_sentences[i]
        if task == "random"
            s = subset_sentence(s, max_seq_len - 1; recent = false, rng = rng)
            insert!(s, 1, mask_tokens)
            push!(batch_positions, (1, i))
        elseif task in ["causal", "temporal"]
            s = subset_sentence(s, max_seq_len - 1; recent = true, rng = rng)
            masked_word = mask_tokens
            if task == "temporal"
                masked_word = replace_timestamp(mask_tokens, 1)
            else
                # `rand`, not `rang`: the misspelled name was undefined and made
                # this branch throw whenever `training` was true.
                if training && rand(rng) < 0.5
                    # NOTE(review): `Trainer.masked_items` is built as an
                    # (implicit, explicit) tuple of dicts, so indexing it by
                    # user id looks wrong here — confirm which dict is meant.
                    masked_word = replace_timestamp(
                        mask_tokens,
                        get_random_timestamp(rng, masked_items[users[i]]),
                    )
                end
            end
            push!(s, masked_word)
            push!(batch_positions, (Int32(length(s)), i))
        else
            @assert false "unsupported task: $task"
        end
        push!(processed_sentences, s)
    end

    inputs = get_inputs(
        processed_sentences,
        max_seq_len,
        vocab_sizes,
        pad_tokens,
        mask_tokens,
        explicit_prior,
        rng,
    )
    return (inputs..., output_labels, output_weights, batch_positions)
end

# Convenience method: build a batch using the data and settings stored in a `Trainer`.
function get_batch(users, training::Bool, t::Trainer)
    return get_batch(
        users,
        training;
        task = t.task,
        sentences = t.sentences,
        masked_items = t.masked_items,
        labels = t.labels,
        weights = t.weights,
        explicit_prior = t.explicit_prior,
        max_seq_len = t.max_seq_len,
        vocab_sizes = t.vocab_sizes,
        mask_tokens = t.mask_tokens,
        pad_tokens = t.pad_tokens,
        rng = t.rng,
    )
end;

In [13]:
"""
    get_inputs(sentences, max_seq_len, vocab_sizes, pad_tokens, mask_tokens,
               explicit_prior, rng)

Tokenize `sentences` and build the matching attention mask.

Returns `(tokens, attention_mask)`. `tokens` is the result of `get_token_ids`,
with rating entries (field 2) below `vocab_sizes[2]` reduced in place by the
user and item priors. `attention_mask` is a `1 x seq_len x batch` `Float32`
array that is 1 wherever the item token differs from the padding item.

`mask_tokens` is accepted but not used by this function.
"""
function get_inputs(
    sentences,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    mask_tokens,
    explicit_prior,
    rng,
)
    # Pad only up to the longest sentence of this batch (plus one), capped at the maximum.
    seq_len = min(maximum(map(length, sentences)) + 1, max_seq_len)
    batch_size = length(sentences)

    tokens = get_token_ids(sentences, seq_len, pad_tokens; rng = rng)

    # Padding positions get a 0 so they are not attended to.
    is_real_token = tokens[1] .!= pad_tokens[1]
    attention_mask = reshape(convert.(Float32, is_real_token), (1, seq_len, batch_size))

    # Demean explicit ratings.
    for b in 1:batch_size
        # Number of non-padding tokens in column b; the loop below visits the
        # first `n_tokens` positions.
        # NOTE(review): this assumes padding is trailing — confirm in `get_token_ids`.
        n_tokens = Int(sum(attention_mask[:, :, b]))
        for i in 1:n_tokens
            tokens[2][i, b] < vocab_sizes[2] || continue
            user_prior = explicit_prior.user[tokens[6][i, b]]
            item_prior = explicit_prior.item[tokens[1][i, b]]
            tokens[2][i, b] -= user_prior + item_prior
        end
    end

    return tokens, attention_mask
end;

In [14]:
# Default: hand the value to `gpu`.
function device(x)
    return gpu(x)
end
# Sparse arrays: when CUDA is usable, wrap the `gpu` result in a `CuArray`;
# otherwise `collect` the sparse array on the host.
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return collect(x)
end
# Tuples (e.g. a whole batch) are converted element by element.
function device(batch::Tuple)
    return map(device, batch)
end

# Release the device memory held by a batch produced by `get_batch |> device`.
# Does nothing when CUDA is not functional.
function device_free!(x)
    CUDA.functional() || return
    tokens, attention_mask, labels, weights, batch_positions = x
    CUDA.unsafe_free!.(tokens)
    CUDA.unsafe_free!(attention_mask)
    CUDA.unsafe_free!.(labels)
    CUDA.unsafe_free!.(weights)
    CUDA.unsafe_free!(batch_positions)
end;

# Loss

In [15]:
"""
    lm_preds(model, batch)

Run the model on `batch` and return `(item_preds, rating_preds)` for the
positions listed in the batch's `batch_positions`.

Item predictions reuse the item embedding matrix (transposed) as the output
projection and add the item classifier's output bias.
"""
function lm_preds(model, batch)
    tokens, attention_mask, _, _, batch_positions = batch
    # NOTE(review): `position` is fed `tokens[1]`, the same field as `item` —
    # confirm this is intended and not a different token field.
    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        position = tokens[1],
    )
    encoded = model.transformers(embedded, attention_mask)
    selected = gather(encoded, batch_positions)

    item_embedding = model.embed.embeddings.item.embedding
    item_features = model.classifier.item.transform(selected)
    item_preds =
        transpose(item_embedding) * item_features .+ model.classifier.item.output_bias.b
    rating_preds = model.classifier.rating.transform(selected)
    return item_preds, rating_preds
end;

In [16]:
"""
    lm_losses(model, batch)

Return `(item_loss, rating_loss)` for `batch`.

The item loss is the weighted cross-entropy of the log-softmaxed item
predictions. The rating loss is the weighted mean squared error of the rating
predictions. Each loss is `0.0f0` when its weights do not sum to a positive
value.
"""
function lm_losses(model, batch)
    labels, weights = batch[3], batch[4]
    item_preds, rating_preds = lm_preds(model, batch)

    item_loss = if sum(weights[1]) > 0
        -sum(labels[1] .* weights[1] .* logsoftmax(item_preds)) / sum(weights[1])
    else
        0.0f0
    end

    rating_loss = if sum(weights[2]) > 0
        sum((rating_preds - labels[2]) .^ 2 .* weights[2]) / sum(weights[2])
    else
        0.0f0
    end

    return item_loss, rating_loss
end;

In [17]:
"""
    split_losses(users, t::Trainer)

Evaluate the model on `users` in chunks of `t.batch_size` and return a `Dict`
with the weight-averaged "Item Crossentropy Loss" and "Rating MSE Loss".

Each batch's losses are weighted by the sums of that batch's weight matrices.
"""
function split_losses(users, t::Trainer)
    weighted_loss_sums = zeros(2)
    weight_sums = zeros(2)
    user_batches = collect(Iterators.partition(users, t.batch_size))
    @showprogress for user_batch in user_batches
        batch = device(get_batch(user_batch, false, t))
        batch_weights = map(sum, batch[4])
        weight_sums .+= batch_weights
        weighted_loss_sums .+= lm_losses(t.model, batch) .* batch_weights
        device_free!(batch)
    end
    mean_losses = weighted_loss_sums ./ weight_sums
    return Dict(
        "Item Crossentropy Loss" => mean_losses[1],
        "Rating MSE Loss" => mean_losses[2],
    )
end;

# Training

In [18]:
"""
    train_epoch!(users, t::Trainer)

Run one pass over `users` in shuffled order and return the mean minibatch loss.

Users are split into batches of `t.batch_size`, and each batch into minibatches
of `t.minibatch_size`. Gradients are summed over a batch's minibatches, divided
by the number of minibatches, and applied in a single `Flux.update!` call.
"""
function train_epoch!(users, t::Trainer)
    shuffled = Random.shuffle(t.rng, users)
    user_batches = collect(Iterators.partition(shuffled, t.batch_size))
    losses = []
    @showprogress for user_batch in user_batches
        minibatches = collect(Iterators.partition(user_batch, t.minibatch_size))
        accumulated = nothing
        for minibatch in minibatches
            batch = device(get_batch(minibatch, true, t))
            loss, grads = Flux.withgradient(m -> sum(lm_losses(m, batch)), t.model)
            accumulated = tuplesum(accumulated, grads[1])
            push!(losses, loss)
            device_free!(batch)
        end
        averaged = tupledivide(accumulated, length(minibatches))
        Flux.update!(t.opt, t.model, averaged)
    end
    return mean(losses)
end;

In [19]:
# Evaluate the model on `users`, add `training_loss` to the metrics, and write
# the model (moved to the CPU), the epoch and the metrics under the run's
# checkpoints directory. Returns the metrics dictionary.
function checkpoint(users, t::Trainer, training_loss, epoch, name)
    @info "evaluating metrics"
    metrics = split_losses(users, t)
    metrics["training_loss"] = training_loss
    snapshot = Dict("m" => cpu(t.model), "epoch" => epoch, "metrics" => metrics)
    write_params(snapshot, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
    return metrics
end;

# Configuration

In [20]:
"""
    load_pretrained_model(checkpoint, task)

Read the parameters stored at `checkpoint` and build a `Trainer` for
fine-tuning on `task`.

The training data, labels, weights and masked items are loaded for `task`. The
saved model is moved to the GPU and paired with an Adam + weight-decay
optimizer at learning rate 3e-5. The remaining settings come from the
checkpoint's training configuration.
"""
function load_pretrained_model(checkpoint, task)
    params = read_params(checkpoint)
    config = params["training_config"]

    # data
    sentences = get_training_data(
        task,
        config["include_ptw_impressions"];
        show_progress_bar = true,
    )
    explicit_prior = config["explicit_prior"] # TODO task specific priors
    labels = get_labels(task, explicit_prior)
    weights = get_weights(task)
    masked_items = get_masked_items(task)

    # model and optimizer
    model = gpu(params["m"])
    lr = 3e-5
    rule = OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * 1f-2))
    opt = Optimisers.setup(rule, model)

    return Trainer(
        # finetuning domain
        task = task,
        # data
        sentences = sentences,
        masked_items = masked_items,
        labels = labels,
        weights = weights,
        explicit_prior = explicit_prior,
        # model
        model = model,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        mask_tokens = config["mask_tokens"],
        pad_tokens = config["pad_tokens"],
        # training
        minibatch_size = config["minibatch_size"],
        batch_size = config["batch_size"],
        opt = opt,
        rng = Random.Xoshiro(20230102),
    )
end;

# Actually Train Model!

In [21]:
# Checkpoint of the pretrained model to fine-tune; passed to `read_params`
# inside `load_pretrained_model`.
pretrain_checkpoint = "all/Transformer/norm/checkpoints/8"
trainer = load_pretrained_model(pretrain_checkpoint, task);

[32mProgress: 100%|███████████████████████████| Time: 0:03:53 ( 1.32 μs/it)[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:22[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:13[39m


In [22]:
# `get_users` returns the users of the "validation" split first and the users of
# the "test" split second: fine-tuning trains on the former and the training
# loop evaluates on the latter.
training, validation = get_users(trainer.rng, trainer.task);

In [None]:
# Fine-tune until the early stopper reports that the evaluation loss has stopped
# improving, writing one checkpoint per epoch.
stopper = early_stopper(max_iters = 25)
test_loss = Inf
# Explicit epoch counter: `checkpoint` uses it in the checkpoint path, and it
# was never defined in this notebook.
epoch = 0
while (!stop!(stopper, test_loss))
    # `global` keeps the top-level assignments working outside the notebook's
    # soft-scope rules as well (e.g. when run as a script).
    global epoch += 1
    training_loss = train_epoch!(training, trainer)
    metrics = checkpoint(validation, trainer, training_loss, epoch, name)
    global test_loss = metrics["Item Crossentropy Loss"] + metrics["Rating MSE Loss"]
end;

[32mProgress:  42%|█████████████████▏                       |  ETA: 0:46:59[39mm57[39mm

# Save predictions

In [None]:
# # returns a vector that maps a user to the list of items to predict
# function user_to_items(users::Vector, items::Vector)
#     user_to_count = zeros(Int32, num_users(), Threads.nthreads())
#     @tprogress Threads.@threads for u in users
#         user_to_count[u, Threads.threadid()] += 1
#     end
#     user_to_count = convert.(Int32, vec(sum(user_to_count, dims = 2)))

#     utoa = Vector{Vector{Int32}}()
#     @showprogress for u = 1:num_users()
#         push!(utoa, Vector{Int32}(undef, user_to_count[u]))
#     end

#     @showprogress for i = 1:length(users)
#         u = users[i]
#         a = items[i]
#         utoa[u][user_to_count[u]] = a
#         user_to_count[u] -= 1
#     end
#     utoa
# end;

In [None]:
# function evaluate_model(users, items, t::Trainer, content::String)
#     utoa = user_to_items(users, items)

#     activation = content == "implicit" ? softmax : identity
#     out_users = Vector{Int32}(undef, length(users))
#     out_items = Vector{Int32}(undef, length(users))
#     out_ratings = fill(NaN32, length(out_users))
#     out_idx = 1

#     # compute predictions    
#     user_batches = collect(Iterators.partition(Set(users), t.minibatch_size))
#     @showprogress for sampled_users in Set(user_batches)
#         batch = get_batch(sampled_users, false, t) |> device
#         item_preds, rating_preds = lm_preds(model, batch)
#         alpha = (content == "implicit" ? item_preds : rating_preds) |> cpu
#         for j = 1:length(sampled_users)
#             u = sampled_users[j]
#             if length(utoa[u]) > 0
#                 item_mask = utoa[u]
#                 next_idx = out_idx + length(item_mask)
#                 out_users[out_idx:next_idx-1] .= u
#                 out_items[out_idx:next_idx-1] = item_mask
#                 out_ratings[out_idx:next_idx-1] = alpha[item_mask, j]
#                 out_idx = next_idx
#             end
#         end
#     end
#     RatingsDataset(user = out_users, item = out_items, rating = out_ratings)
# end

# function write_alpha(t::Trainer, outdir::String, content::String, task::String)
#     function model(users, items)
#         p = sparse(evaluate_model(users, items, t, content))
#         r = zeros(length(users))
#         @tprogress Threads.@threads for j = 1:length(r)
#             r[j] = p[items[j], users[j]]
#         end
#         r
#     end
#     if content == "explicit"
#         residual_alphas = ["random/ExplicitUserItemBiases"] # TODO task specific
#     else
#         residual_alphas = String[]
#     end
#     write_alpha(
#         model,
#         outdir;
#         task = task,
#         log = true,
#         log_task = task,
#         log_content = content,
#         log_alphas = residual_alphas,
#     )
# end;

In [None]:
# for content in ["implicit", "explicit"]
#     write_alpha(t, name, content, task)
# end;