In [None]:
# Fine-tuning task for this notebook; `get_batch` supports "random" and "temporal".
task = "temporal"

In [None]:
# Output prefix: checkpoints are written under `$name/checkpoints/<epoch>` and the
# final predictions via `write_alpha(trainer, name, task)`.
name = "$task/Transformer/";

In [None]:
import NBInclude: @nbinclude
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import Random
import SparseArrays: AbstractSparseArray, sparse
import StatsBase: mean, sample
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

# Data

In [None]:
# Everything needed to build batches, train, and evaluate the fine-tuned model.
# Built by keyword (`@with_kw`) in `load_pretrained_model`; the training loop later
# swaps in the best model with `@set trainer.model = ...`.
@with_kw struct Trainer
    # "random" or "temporal"; selects how `get_batch` truncates sentences
    task::Any
    # data
    # per-user token sequences, looked up with `get_sentence`
    # (presumably a Dict keyed by user id — confirm)
    sentences::Any
    # (implicit, explicit) sparse matrices, num_items × num_users(), from `get_labels`
    labels::Any
    # same layout as `labels`; per-entry loss weights from `get_weights`
    weights::Any
    # implicit-feedback timestamps, num_items × num_users(); not read by `get_batch`
    timestamps::Any
    # model
    model::Any
    # maximum tokens per sentence, including the mask token `get_batch` appends
    max_seq_len::Any
    # from config["base_vocab_sizes"]; entry 7 is passed to `get_token_ids`
    vocab_sizes::Any
    mask_tokens::Any
    pad_tokens::Any
    cls_tokens::Any
    # training
    # users per forward/backward pass; gradients are averaged within a batch
    minibatch_size::Any
    # users per optimizer step (and per forward pass in `evaluate_losses`)
    batch_size::Any
    # optimizer state from `Optimisers.setup`
    opt::Any
    rng::Any
end;

In [None]:
"""
    get_labels(task, content, num_items)

Sparse `num_items × num_users()` matrix of ratings for `task`/`content`, combining the
validation and test splits. Note that `sparse` sums values for repeated
(item, user) pairs; presumably each pair appears at most once across the two splits.
"""
function get_labels(task, content, num_items)
    validation = get_split("validation", task, content)
    test = get_split("test", task, content)
    ratings = cat(validation, test)
    return sparse(ratings.item, ratings.user, ratings.rating, num_items, num_users())
end

# Returns the tuple `(implicit_labels, explicit_labels)`.
get_labels(task, num_items) =
    map(content -> get_labels(task, content, num_items), ("implicit", "explicit"));

In [None]:
"""
    get_weights(task, content, num_items)

Sparse `num_items × num_users()` matrix of loss weights aligned with `get_labels`.
Each split's weights come from `powerdecay` applied to that split's `get_counts`,
using the "inverse" weighting scheme. Validation weights are followed by test
weights, matching the row order of the concatenated splits.
"""
function get_weights(task, content, num_items)
    validation = get_split("validation", task, content)
    test = get_split("test", task, content)
    ratings = cat(validation, test)
    validation_w =
        powerdecay(get_counts("validation", task, content), weighting_scheme("inverse"))
    test_w = powerdecay(get_counts("test", task, content), weighting_scheme("inverse"))
    combined = vcat(validation_w, test_w)
    return sparse(ratings.item, ratings.user, combined, num_items, num_users())
end

# Returns the tuple `(implicit_weights, explicit_weights)`.
get_weights(task, num_items) =
    map(content -> get_weights(task, content, num_items), ("implicit", "explicit"));

In [None]:
"""
    get_timestamps(task, content, num_items)

Sparse `num_items × num_users()` matrix of rating timestamps from the validation and
test splits of `task`/`content`.
"""
function get_timestamps(task, content, num_items)
    validation = get_split("validation", task, content)
    test = get_split("test", task, content)
    rows = cat(validation, test)
    return sparse(rows.item, rows.user, rows.timestamp, num_items, num_users())
end

# Only implicit-feedback timestamps are used.
get_timestamps(task, num_items) = get_timestamps(task, "implicit", num_items);

In [None]:
"""
    get_users(rng, task, content)

Return `(validation_users, test_users)`: the distinct users in the "validation" and
"test" splits of `task`/`content`. `rng` is accepted but not used.
"""
function get_users(rng, task, content)
    distinct_users(split) = collect(Set(get_split(split, task, content).user))
    return distinct_users("validation"), distinct_users("test")
end

# Users are always drawn from the implicit-feedback splits.
get_users(rng, task) = get_users(rng, task, "implicit");

In [None]:
"""
    get_sentence(sentences, x, cls_tokens)

Return a copy of `sentences[x]`, safe for the caller to mutate. If `x` has no entry
(a `KeyError`), return the one-token sentence `[cls_tokens]`. Any other exception is
rethrown.
"""
function get_sentence(sentences, x, cls_tokens)
    try
        return copy(sentences[x])
    catch err
        # `catch KeyError` would bind *every* exception to a variable named `KeyError`,
        # hiding unrelated failures; filter on the exception type instead.
        err isa KeyError || rethrow()
        return [cls_tokens]
    end
end;

# Batching

In [None]:
"""
    get_batch(users, training; task, sentences, labels, weights, timestamps,
              max_seq_len, vocab_sizes, mask_tokens, pad_tokens, cls_tokens, rng)

Build one model batch for `users`.

Each user's sentence (from `get_sentence`) is cut to at most `max_seq_len - 1` tokens
with `subset_sentence`. It uses `recent = false` for `task == "random"` and
`recent = true` for `task == "temporal"`. A mask token is then appended, and the model
predicts at that final position.

Returns `(tokens, attention_mask, output_labels, output_weights, batch_positions)`:
- `tokens`, `attention_mask`: from `get_inputs`.
- `output_labels`, `output_weights`: the columns of each matrix in `labels`/`weights`
  selected by `users`.
- `batch_positions`: `(mask position, index of the user in the batch)` per user.

`training` and `timestamps` are accepted for interface compatibility but not used.
Throws `ArgumentError` for any other `task`.
"""
function get_batch(
    users,
    training;
    task,
    sentences,
    labels,
    weights,
    timestamps,
    max_seq_len,
    vocab_sizes,
    mask_tokens,
    pad_tokens,
    cls_tokens,
    rng,
)
    if task == "random"
        recent = false
        masked_word = mask_tokens
    elseif task == "temporal"
        recent = true
        # TODO set pos
        masked_word = replace(mask_tokens, :timestamp, 1)
    else
        throw(ArgumentError("unsupported task: $task"))
    end

    user_sentences = [get_sentence(sentences, x, cls_tokens) for x in users]
    processed_sentences = eltype(user_sentences)[]
    batch_positions = Tuple{Int32,Int32}[]
    output_labels = map(x -> x[:, users], labels)
    output_weights = map(x -> x[:, users], weights)

    for (i, s) in enumerate(user_sentences)
        # leave room for the mask token appended below
        s = subset_sentence(s, max_seq_len - 1; recent = recent, rng = rng)
        push!(s, masked_word)
        push!(batch_positions, (Int32(length(s)), Int32(i)))
        push!(processed_sentences, s)
    end

    inputs = get_inputs(
        processed_sentences,
        max_seq_len,
        vocab_sizes,
        pad_tokens,
        cls_tokens,
        rng,
    )
    (inputs..., output_labels, output_weights, batch_positions)
end

# Trainer convenience method: forwards the trainer's data, tokenizer, and rng fields
# as the keyword arguments of the full `get_batch` method.
function get_batch(users, training::Bool, t::Trainer)
    fields = (
        :task,
        :sentences,
        :labels,
        :weights,
        :timestamps,
        :max_seq_len,
        :vocab_sizes,
        :mask_tokens,
        :pad_tokens,
        :cls_tokens,
        :rng,
    )
    kwargs = NamedTuple{fields}(map(f -> getproperty(t, f), fields))
    return get_batch(users, training; kwargs...)
end;

In [None]:
"""
    get_inputs(sentences, max_seq_len, vocab_sizes, pad_tokens, cls_tokens, rng)

Tokenize `sentences` and build their attention mask.

The batch is padded to the longest sentence, capped at `max_seq_len`. The returned
mask has shape `(seq_len, seq_len, length(sentences))`. Entry `[i, j, b]` is `1f0`
only when neither position `i` nor position `j` of sentence `b` is a padding token.
"""
function get_inputs(sentences, max_seq_len, vocab_sizes, pad_tokens, cls_tokens, rng)
    longest = maximum(length, sentences)
    seq_len = min(longest, max_seq_len)

    tokens =
        get_token_ids(sentences, seq_len, vocab_sizes[7], pad_tokens, cls_tokens; rng = rng)

    is_real = extract(tokens, :item) .!= extract(pad_tokens, :item)
    row_mask = reshape(Float32.(is_real), 1, seq_len, length(sentences))
    # outer product: attend only between two non-padding positions
    pair_mask = row_mask .* permutedims(row_mask, (2, 1, 3))

    return tokens, pair_mask
end;

In [None]:
# Move a value to the GPU (a no-op copy on CPU-only machines, via `gpu`).
device(x) = gpu(x)

# Sparse matrices are densified: as a `CuArray` when CUDA works, otherwise on the CPU.
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return collect(x)
end

# Move a `get_batch` result `(tokens, attention_mask, labels, weights, positions)`.
# Of the 7 token fields, slot 6 is not transferred and is replaced by `nothing`
# (presumably a field the model does not read — confirm). Labels and weights are
# tuples whose elements are moved one by one.
function device(batch::Tuple)
    tokens, attention_mask, labels, weights, positions = batch
    device_tokens = ntuple(k -> k == 6 ? nothing : device(tokens[k]), 7)
    return device_tokens,
    device(attention_mask),
    device.(labels),
    device.(weights),
    device(positions)
end

# Lets `device_free!` apply `CUDA.unsafe_free!` to token tuples that contain
# `nothing` (the slot `device(::Tuple)` leaves empty).
# NOTE(review): neither `CUDA.unsafe_free!` nor `Nothing` is owned here, so this is
# type piracy; a local free helper that skips `nothing` would avoid it.
CUDA.unsafe_free!(::Nothing) = nothing

# Eagerly release the GPU memory of a batch produced by `device(::Tuple)`.
# Does nothing when CUDA is not functional, since then the batch lives on the CPU.
function device_free!(x)
    CUDA.functional() || return
    foreach(CUDA.unsafe_free!, x[1])  # token fields (one may be `nothing`)
    CUDA.unsafe_free!(x[2])            # attention mask
    foreach(CUDA.unsafe_free!, x[3])  # labels
    foreach(CUDA.unsafe_free!, x[4])  # weights
    CUDA.unsafe_free!(x[5])            # batch positions
end;

# Loss

In [None]:
"""
    lm_preds(model, batch)

Run `model` on a `get_batch` result and return `(item_preds, rating_preds)` for the
positions listed in `batch[5]`.

`item_preds` are unnormalized scores: the transformed hidden state is multiplied by
the transposed item embedding matrix, then the output bias is added. Callers apply
`logsoftmax`/`softmax`. `rating_preds` come from the rating head.
"""
function lm_preds(model, batch)
    tokens = batch[1]
    attention_mask = batch[2]
    batch_positions = batch[5]

    hidden = model.embed(
        item = extract(tokens, :item),
        rating = extract(tokens, :rating),
        timestamp = extract(tokens, :timestamp),
        status = extract(tokens, :status),
        completion = extract(tokens, :completion),
        position = extract(tokens, :position),
    )
    hidden = model.transformers(hidden, attention_mask)
    hidden = gather(hidden, batch_positions)

    item_head = model.classifier.item
    item_embeddings = model.embed.embeddings.item.embedding
    item_preds =
        transpose(item_embeddings) * item_head.transform(hidden) .+ item_head.output_bias.b
    rating_preds = model.classifier.rating.transform(hidden)
    return item_preds, rating_preds
end;

In [None]:
"""
    lm_losses(model, batch)

Return `(item_loss, rating_loss)` for a `get_batch` result.

- `item_loss`: weighted cross-entropy of the item labels against `logsoftmax` of the
  item scores, normalized by the total item weight.
- `rating_loss`: weighted squared error of the rating predictions, normalized by the
  total rating weight.

Either loss is `0.0f0` when its weights sum to zero.
"""
function lm_losses(model, batch)
    item_preds, rating_preds = lm_preds(model, batch)
    item_labels, rating_labels = batch[3]
    item_weights, rating_weights = batch[4]

    item_norm = sum(item_weights)
    item_loss = if item_norm > 0
        -sum(item_labels .* item_weights .* logsoftmax(item_preds)) / item_norm
    else
        0.0f0
    end

    rating_norm = sum(rating_weights)
    rating_loss = if rating_norm > 0
        sum((rating_preds - rating_labels) .^ 2 .* rating_weights) / rating_norm
    else
        0.0f0
    end

    return item_loss, rating_loss
end;

In [None]:
"""
    evaluate_losses(users, t::Trainer)

Compute the item cross-entropy and rating MSE over `users`, processed in chunks of
`t.batch_size`. Each chunk's losses are weighted by that chunk's total label weight.
Returns a `Dict` keyed by metric name.
"""
function evaluate_losses(users, t::Trainer)
    weighted_sums = zeros(2)
    total_weights = zeros(2)
    groups = collect(Iterators.partition(users, t.batch_size))
    @showprogress for group in groups
        batch = device(get_batch(group, false, t))
        w = sum.(batch[4])
        total_weights .+= w
        weighted_sums .+= lm_losses(t.model, batch) .* w
        device_free!(batch)
    end
    averages = weighted_sums ./ total_weights
    return Dict("Item Crossentropy Loss" => averages[1], "Rating MSE Loss" => averages[2])
end;

# Training

In [None]:
"""
    train_epoch!(users, t::Trainer)

Run one pass over a shuffled copy of `users`, updating `t.model` in place.

Each optimizer step covers `t.batch_size` users, split into minibatches of
`t.minibatch_size`. The step uses the gradient averaged over those minibatches.
Returns the mean of the per-minibatch losses.
"""
function train_epoch!(users, t::Trainer)
    order = Random.shuffle(t.rng, users)
    losses = []
    steps = collect(Iterators.partition(order, t.batch_size))
    @showprogress for step_users in steps
        chunks = collect(Iterators.partition(step_users, t.minibatch_size))
        grad_sum = nothing
        for chunk in chunks
            batch = device(get_batch(chunk, true, t))
            loss, grads = Flux.withgradient(m -> sum(lm_losses(m, batch)), t.model)
            grad_sum = tuplesum(grad_sum, grads[1])
            push!(losses, loss)
            device_free!(batch)
        end
        Flux.update!(t.opt, t.model, tupledivide(grad_sum, length(chunks)))
    end
    return mean(losses)
end;

In [None]:
"""
    checkpoint(users, t::Trainer, training_loss, epoch, pretrain_checkpoint, name)

Evaluate the losses on `users` and add `training_loss` to the resulting metrics.
Save a CPU copy of the model, the epoch, the metrics and the source pretrain
checkpoint under the directory `name`/checkpoints/`epoch`. Returns the metrics `Dict`.
"""
function checkpoint(users, t::Trainer, training_loss, epoch, pretrain_checkpoint, name)
    @info "evaluating metrics"
    metrics = evaluate_losses(users, t)
    metrics["training_loss"] = training_loss
    payload = Dict(
        "m" => cpu(t.model),
        "epoch" => epoch,
        "metrics" => metrics,
        "pretrain_checkpoint" => pretrain_checkpoint,
    )
    destination = "$name/checkpoints/$epoch"
    write_params(payload, destination)
    @info "saving model after $epoch epochs with metrics $metrics"
    return metrics
end;

# Configuration

In [None]:
"""
    load_pretrained_model(checkpoint, task; lr = 1e-5, weight_decay = 1f-2, seed = 20230102)

Load the pretrained model saved at `checkpoint` onto the GPU and return a `Trainer`
for fine-tuning it on `task`. Sequence length, vocabulary sizes, special tokens and
batch sizes are taken from the checkpoint's `"training_config"`.

# Keywords
- `lr`: Adam learning rate.
- `weight_decay`: `WeightDecay` coefficient applied after `Adam` in the optimiser
  chain, expressed relative to `lr`; the applied value is `lr * weight_decay`.
- `seed`: seed for the trainer's `Random.Xoshiro` RNG.

The defaults reproduce the previously hard-coded settings.
"""
function load_pretrained_model(
    checkpoint,
    task;
    lr = 1e-5,
    weight_decay = 1f-2,
    seed = 20230102,
)
    params = read_params(checkpoint)
    config = params["training_config"]
    use_ptw = config["include_ptw_impressions"]
    model = params["m"] |> gpu
    # columns of the item embedding matrix index items (see `lm_preds`)
    num_items = size(model.embed.embeddings.item.embedding, 2)
    sentences =
        get_training_data(task, use_ptw, config["cls_tokens"]; show_progress_bar = true)
    labels = get_labels(task, num_items)
    weights = get_weights(task, num_items)
    timestamps = get_timestamps(task, num_items)
    opt = Optimisers.setup(
        OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * weight_decay)),
        model,
    )
    return Trainer(
        # finetuning domain
        task = task,
        # data
        sentences = sentences,
        labels = labels,
        weights = weights,
        timestamps = timestamps,
        # model
        model = model,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        mask_tokens = config["mask_tokens"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        # training
        minibatch_size = config["minibatch_size"],
        batch_size = config["batch_size"],
        opt = opt,
        rng = Random.Xoshiro(seed),
    )
end;

# Actually Train Model!

In [None]:
# Resume from the latest saved epoch of the pretrained "all" model.
pretrain_model_tag = "v8"
pretrain_dir = "all/Transformer/$pretrain_model_tag/checkpoints/"
# checkpoint subdirectories are named by epoch number
pretrain_epoch = let epochs = parse.(Int, readdir(get_data_path("alphas/$pretrain_dir")))
    last(sort(epochs))
end
pretrain_checkpoint = joinpath(pretrain_dir, string(pretrain_epoch))
@info "loading pretrained model from $pretrain_checkpoint"
trainer = load_pretrained_model(pretrain_checkpoint, task);

In [None]:
# `get_users` returns (users in the "validation" split, users in the "test" split).
# The first set is fine-tuned on; the second is used below for early stopping.
# NOTE(review): each user's labels combine both splits (see `get_labels`), so a user
# present in both splits also trains on their test-split items — confirm intended.
training, validation = get_users(trainer.rng, trainer.task);

In [None]:
# Fine-tune with early stopping on the held-out users' summed item + rating loss.
# `best_model` is snapshotted (on the CPU) before each epoch. When `stop!` ends the
# loop, it therefore holds the weights from before the epoch that triggered stopping.
# NOTE(review): if the loop ends because `max_iters` is reached, the final epoch's
# weights are discarded even if they improved the loss — confirm `early_stopper`
# semantics.
stopper = early_stopper(max_iters = 20, patience=0)
test_loss = Inf
best_model = nothing
while (!stop!(stopper, test_loss))
    # Declare the globals explicitly. Under non-interactive scope rules (e.g. when
    # run via `include` or `@nbinclude`), assignments in a top-level loop body
    # become loop-locals. `test_loss` would then stay `Inf` and `best_model` would
    # stay `nothing`.
    global test_loss, best_model
    best_model = trainer.model |> cpu
    training_loss = train_epoch!(training, trainer)
    metrics = checkpoint(validation, trainer, training_loss, stopper.iters, pretrain_checkpoint, name)
    test_loss = metrics["Item Crossentropy Loss"] + metrics["Rating MSE Loss"]
end;
trainer = @set trainer.model = best_model |> gpu;

# Save predictions

In [None]:
# Group `items` by user. Returns `utoa`, with one entry per user id in `1:num_users()`:
# `utoa[u]` holds every `items[i]` with `users[i] == u`, filled back to front.
# `users` and `items` must have the same length (`DimensionMismatch` otherwise).
function user_to_items(users::Vector, items::Vector)
    # Count items per user. A serial loop replaces the per-thread buffers indexed by
    # `Threads.threadid()`. Those are unsafe when tasks migrate between threads, and
    # `threadid()` can exceed `Threads.nthreads()` when interactive threads exist.
    user_to_count = zeros(Int32, num_users())
    @showprogress for u in users
        user_to_count[u] += 1
    end

    utoa = [Vector{Int32}(undef, n) for n in user_to_count]

    # Each counter is decremented as its slot is filled, so lists fill back to front.
    @showprogress for i in eachindex(users, items)
        u = users[i]
        utoa[u][user_to_count[u]] = items[i]
        user_to_count[u] -= 1
    end
    utoa
end

"""
    evaluate_model(users, items, t::Trainer)

Predict for every `(users[i], items[i])` pair using `t.model`. Returns two
`RatingsDataset`s with matching `user`/`item` columns:

- implicit: the `softmax` probability of the item at the user's masked position;
- explicit: the rating head's prediction for the item.

Inference runs with TensorFloat32 math. The math mode is reset to BFloat16
afterwards, even if prediction throws.
"""
function evaluate_model(users, items, t::Trainer)
    CUDA.math_mode!(CUDA.FAST_MATH; precision = :TensorFloat32)
    try
        utoa = user_to_items(users, items)
        out_users = Vector{Int32}(undef, length(users))
        out_items = Vector{Int32}(undef, length(users))
        out_implicit_ratings = fill(NaN32, length(out_users))
        out_explicit_ratings = fill(NaN32, length(out_users))
        out_idx = 1

        # one forward pass per chunk of distinct users
        user_batches = collect(Iterators.partition(Set(users), t.minibatch_size))
        @showprogress for sampled_users in user_batches
            batch = get_batch(sampled_users, false, t) |> device
            item_preds, rating_preds = lm_preds(t.model, batch)
            item_preds = softmax(item_preds) |> cpu
            rating_preds = rating_preds |> cpu
            # release the GPU batch, as `evaluate_losses`/`train_epoch!` do
            batch |> device_free!
            for (j, u) in enumerate(sampled_users)
                item_mask = utoa[u]
                isempty(item_mask) && continue
                next_idx = out_idx + length(item_mask)
                span = out_idx:next_idx-1
                out_users[span] .= u
                out_items[span] = item_mask
                out_implicit_ratings[span] = item_preds[item_mask, j]
                out_explicit_ratings[span] = rating_preds[item_mask, j]
                out_idx = next_idx
            end
        end
        return (
            RatingsDataset(user = out_users, item = out_items, rating = out_implicit_ratings),
            RatingsDataset(user = out_users, item = out_items, rating = out_explicit_ratings),
        )
    finally
        CUDA.math_mode!(CUDA.FAST_MATH; precision = :BFloat16)
    end
end;

In [None]:
"""
    write_alpha(t::Trainer, outdir::String, task::String)

Predict with `t.model` for every (user, item) pair in all splits and contents of
`task`. The implicit and explicit predictions are written as alphas under
`outdir`/implicit and `outdir`/explicit.
"""
function write_alpha(t::Trainer, outdir::String, task::String)
    master_dfs = []
    @showprogress for split in ALL_SPLITS
        for content in ALL_CONTENTS
            push!(master_dfs, get_raw_split(split, task, content; fields = [:user, :item]))
        end
    end
    master_df = reduce(cat, master_dfs)
    # Use the trainer passed in (previously the global `trainer` was read here).
    imp_p, exp_p = sparse.(evaluate_model(master_df.user, master_df.item, t))
    # Look up predictions p[item, user] for each requested pair.
    function lookup(p, users, items)
        r = zeros(length(users))
        @tprogress Threads.@threads for j = 1:length(r)
            r[j] = p[items[j], users[j]]
        end
        r
    end
    for (content, p) in [("implicit", imp_p), ("explicit", exp_p)]
        write_alpha(
            (users, items) -> lookup(p, users, items),
            "$outdir/$content";
            task = task,
            log = true,
            log_task = task,
            log_content = content,
            log_alphas = String[],
        )
    end
end;

In [None]:
# Write implicit/explicit predictions for all (user, item) pairs under
# `$name/implicit` and `$name/explicit`, using the early-stopped model.
write_alpha(trainer, name, task)