In [1]:
# Fine-tuning objective; `get_batch` handles "random", "causal" and "temporal".
task = "random"
# Target type; `get_batch` handles "explicit" and "implicit" (`lm_loss` currently only
# has an implementation for "implicit").
content = "implicit"

"implicit"

In [2]:
# Output directory for this run; `checkpoint` writes to "$name/checkpoints/$epoch".
name = "$task/Transformer/test";

In [3]:
import NBInclude: @nbinclude
import SparseArrays: AbstractSparseArray, sparse
@nbinclude("../Alpha.ipynb")
@nbinclude("Transformer.ipynb");



# Data

In [4]:
# Everything needed to fine-tune and evaluate a pretrained transformer. Fields are left
# untyped (equivalent to `::Any`); `@with_kw` supplies the keyword constructor used by
# `load_pretrained_model`.
@with_kw struct Trainer
    # finetuning domain: task ∈ {"random", "causal", "temporal"},
    # content ∈ {"explicit", "implicit"}
    task
    content
    # data: per-user sentences, per-user held-out items, sparse [item, user] targets
    sentences
    masked_items
    labels
    weights
    priors
    # model and tokenization settings
    model
    max_seq_len
    mask_tokens
    pad_tokens
    # training: users per batch, optimiser state, random number generator
    batch_size
    opt
    rng
end;

In [5]:
"""
    get_labels(task, content)

Return a sparse `(num_items() + 4) × num_users()` matrix, indexed `[item, user]`,
holding the ratings from the validation and test splits. Duplicate `(item, user)`
entries are summed, as `SparseArrays.sparse` does by default.
"""
function get_labels(task, content)
    # NOTE(review): the 4 extra item rows presumably reserve ids for special tokens — confirm.
    splits = (get_split("validation", task, content), get_split("test", task, content))
    combined = cat(splits...)
    return sparse(combined.item, combined.user, combined.rating, num_items() + 4, num_users())
end;

In [6]:
"""
    get_weights(task, content)

Return a sparse `(num_items() + 4) × num_users()` matrix, indexed `[item, user]`,
of per-rating weights for the validation and test splits. Each split's weights are
`powerdecay` of its `get_counts` under the `"inverse"` weighting scheme.
"""
function get_weights(task, content)
    combined = cat(get_split("validation", task, content), get_split("test", task, content))
    split_weights(split) =
        powerdecay(get_counts(split, task, content), weighting_scheme("inverse"))
    # assumes `get_counts` yields one value per row, in `get_split` order — TODO confirm
    w = vcat(split_weights("validation"), split_weights("test"))
    return sparse(combined.item, combined.user, w, num_items() + 4, num_users())
end;

In [7]:
"""
    get_priors(task, content)

Return a sparse `(num_items() + 4) × num_users()` matrix, indexed `[item, user]`,
of prior ratings read from the `"\$task/ExplicitUserItemBiases"` alpha for the
validation and test splits.
"""
function get_priors(task, content)
    source = "$task/ExplicitUserItemBiases"
    combined = cat(
        read_alpha(source, "validation", task, content),
        read_alpha(source, "test", task, content),
    )
    return sparse(combined.item, combined.user, combined.rating, num_items() + 4, num_users())
end;

In [8]:
"""
    get_masked_items(task, content)

Return a `Dict` mapping each user in the combined validation + test splits to the
`Vector{Int32}` of items that appear with that user in those splits, in row order
of the concatenated splits.
"""
function get_masked_items(task, content)
    df = cat(get_split("validation", task, content), get_split("test", task, content))
    # Concretely typed so lookups and `rand(rng, masked_items[user])` are type-stable.
    user_to_items = Dict{eltype(df.user),Vector{Int32}}()
    @showprogress for i in eachindex(df.user)
        push!(get!(Vector{Int32}, user_to_items, df.user[i]), df.item[i])
    end
    return user_to_items
end;

In [9]:
"""
    get_sentence(sentences, x)

Return a copy of the sentence stored under key `x` in `sentences`, so callers can
mutate it (as `get_batch` does with `insert!`/`push!`) without altering the store.
If `x` has no entry, return an empty vector of the store's value type.
"""
function get_sentence(sentences, x)
    # Explicit `haskey` rather than try/catch: only a missing key takes the fallback,
    # and any other error propagates to the caller.
    haskey(sentences, x) || return eltype(values(sentences))(undef, 0)
    return copy(sentences[x])
end;

# Batching

In [10]:
# Return `word` as a tuple whose first field (the item) is replaced by `item`;
# the remaining fields are kept in order.
function replace_item(word, item)
    rest = word[2:end]
    return (item, rest...)
end

# Return `word` as a tuple whose third field (the timestamp) is replaced by
# `encode_raw_timestamp(1)`; every other field is kept.
function replace_timestamp(word)
    new_timestamp = encode_raw_timestamp(1)
    return (word[1], word[2], new_timestamp, word[4:end]...)
end;

In [11]:
"""
    get_batch(users, task, content, sentences, masked_items, labels, weights, priors,
              max_seq_len, mask_tokens, pad_tokens, rng)

Build one fine-tuning batch for `users`.

Each user's sentence is fetched with `get_sentence`, shortened with `subset_sentence`
to `max_seq_len - 1` words, and given masked word(s) according to `task`:

- `"random"`: `mask_tokens` is inserted at position 1 and every word equal to
  `mask_tokens` is treated as masked.
- `"causal"` / `"temporal"`: a masked word is appended at the end; for `"temporal"`
  its timestamp field is rewritten by `replace_timestamp`.

Targets depend on `content`:

- `"implicit"`: the user columns `labels[:, users]`, `weights[:, users]` and
  `priors[:, users]`.
- `"explicit"`: for each masked word an item is drawn with `rng` from
  `masked_items[user]` and written into the masked word; `labels[item, user]`,
  `priors[item, user]` and weight `1` are collected into `1 × N` `Float32` matrices.

Returns `(inputs..., labels, weights, priors, masked_token_positions, masked_groups)`.
Here `inputs` is the result of `get_inputs`, and `masked_token_positions` holds
`(position, sentence_index)` pairs. `masked_groups[i]` lists the indices into
`masked_token_positions` that belong to sentence `i`.

Throws `ArgumentError` for an unrecognised `task` or `content`.
"""
function get_batch(
    users,
    task,
    content,
    sentences,
    masked_items,
    labels,
    weights,
    priors,
    max_seq_len,
    mask_tokens,
    pad_tokens,
    rng,
)
    task in ("random", "causal", "temporal") ||
        throw(ArgumentError("unsupported task: $(repr(task))"))
    content in ("explicit", "implicit") ||
        throw(ArgumentError("unsupported content: $(repr(content))"))

    # TODO potentially do maskout during training
    raw_sentences = [get_sentence(sentences, x) for x in users]
    processed_sentences = eltype(values(sentences))[]
    masked_token_positions = Tuple{Int32,Int32}[]
    masked_groups = Vector{Int32}[]
    if content == "explicit"
        output_labels = Float32[]
        output_weights = Float32[]
        output_priors = Float32[]
    else
        output_labels = labels[:, users]
        output_weights = weights[:, users]
        output_priors = priors[:, users]
    end

    # Explicit only: draw one of `user`'s held-out items and record its targets.
    # The matrices are indexed (item, user), matching `labels[:, users]` above.
    function sample_item(user, weight)
        item = rand(rng, masked_items[user])
        push!(output_labels, labels[item, user])
        push!(output_priors, priors[item, user])
        push!(output_weights, weight)
        return item
    end

    # Record a masked word at position `j` of sentence `i`. Its group entry is the
    # word's 1-based index in `masked_token_positions`, i.e. its column after `gather`.
    function register_masked_item!(group, i, j)
        push!(masked_token_positions, (j, i))
        push!(group, length(masked_token_positions))
        return group
    end

    for i::Int32 = 1:length(raw_sentences)
        s = raw_sentences[i]
        group = Int32[]
        if task == "random"
            s = subset_sentence(s, max_seq_len - 1; recent = false, rng = rng)
            insert!(s, 1, mask_tokens)
            # TODO: optionally mask ~15% of positions instead of one prepended mask.
            for j::Int32 = 1:length(s)
                s[j] == mask_tokens || continue
                register_masked_item!(group, i, j)
                if content == "explicit"
                    s[j] = replace_item(s[j], sample_item(users[i], 1))
                end
            end
        else  # "causal" or "temporal" (validated above)
            s = subset_sentence(s, max_seq_len - 1; recent = true, rng = rng)
            masked_word = task == "temporal" ? replace_timestamp(mask_tokens) : mask_tokens
            if content == "explicit"
                masked_word = replace_item(masked_word, sample_item(users[i], 1))
            end
            push!(s, masked_word)
            register_masked_item!(group, i, Int32(length(s)))
        end
        push!(processed_sentences, s)
        push!(masked_groups, group)
    end

    inputs = get_inputs(processed_sentences, max_seq_len, pad_tokens, rng)
    if content == "explicit"
        # The collected vectors become 1 × N Float32 row matrices.
        as_row = x -> convert.(Float32, collect(x'))
        output_labels = as_row(output_labels)
        output_weights = as_row(output_weights)
        output_priors = as_row(output_priors)
    end
    return (
        inputs...,
        output_labels,
        output_weights,
        output_priors,
        masked_token_positions,
        masked_groups,
    )
end

# Convenience method: build a batch using the settings stored in a `Trainer`.
function get_batch(users, t::Trainer)
    return get_batch(
        users, t.task, t.content, t.sentences, t.masked_items, t.labels, t.weights,
        t.priors, t.max_seq_len, t.mask_tokens, t.pad_tokens, t.rng,
    )
end;

In [12]:
"""
    get_inputs(sentences, max_seq_len, pad_tokens, rng)

Tokenize `sentences` and build their attention mask.

The batch is padded dynamically: `seq_len` is one more than the longest sentence,
capped at `max_seq_len`. Returns `(tokens, attention_mask)`, where `attention_mask`
is a `1 × seq_len × length(sentences)` `Float32` array. It is `0` wherever
`tokens[1]` equals `pad_tokens[1]` and `1` elsewhere.
"""
function get_inputs(sentences, max_seq_len, pad_tokens, rng)
    # NOTE(review): the `+ 1` presumably leaves room for an extra token added during
    # tokenization — confirm against `get_token_ids`.
    longest = maximum(length, sentences)
    seq_len = min(longest + 1, max_seq_len)
    tokens = get_token_ids(sentences, seq_len, pad_tokens; rng = rng)
    is_real_token = tokens[1] .!= pad_tokens[1]
    mask_shape = (1, seq_len, length(sentences))
    attention_mask = reshape(convert.(Float32, is_real_token), mask_shape)
    return tokens, attention_mask
end;

In [13]:
# Move a value to the accelerator with Flux's `gpu`.
device(x) = gpu(x)

# Sparse arrays go through `gpu` and then `CUDA.CuArray` when CUDA is functional;
# otherwise they are densified on the CPU with `collect`.
function device(x::AbstractSparseArray)
    CUDA.functional() || return collect(x)
    return CUDA.CuArray(gpu(x))
end

# Move every element of a batch tuple, returning a tuple.
device(batch::Tuple) = map(device, batch)

"""
    device_free!(x)

Eagerly release GPU memory held by a device batch.

This frees every array in `x[1]` (the token tuple) and `x[2]` through `x[5]`
(attention mask, labels, weights, priors); the remaining entries are left to the GC.
It does nothing when CUDA is not functional.
"""
function device_free!(x)
    CUDA.functional() || return
    foreach(CUDA.unsafe_free!, x[1])
    foreach(CUDA.unsafe_free!, x[2:5])
    return nothing
end;

# Loss

In [14]:
# For each index group, `gather` the matching slices of `a` (NNlib's `gather` indexes
# the last dimension), reduce them with `agg(...; dims = 2)`, then batch and flatten.
function gatheragg(a, groups, agg)
    pooled = map(g -> agg(gather(a, g); dims = 2), groups)
    return Flux.flatten(Flux.batch(pooled))
end

"""
    lm_loss(model, batch, content)

Return the weighted masked-item loss of `model` on `batch`, which has the layout
produced by `get_batch`.

The transformer outputs at the masked positions are gathered and averaged within each
user's group. They are then projected onto the transposed item embedding matrix, the
classifier bias is added, and the result goes through `logsoftmax` (default first
dimension). The loss is `-sum(labels .* weights .* item_pred)`.

Only `content == "implicit"` is implemented. `"explicit"` throws an `ErrorException`
and any other value throws an `ArgumentError`.
"""
function lm_loss(model, batch, content)
    tokens, attention_mask, labels, weights, priors, masked_token_positions, masked_groups =
        batch
    X = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        # NOTE(review): item ids are passed as `position`; presumably the position
        # embedding only uses this for its shape — confirm.
        position = tokens[1],
    )
    X = model.transformers(X, attention_mask)
    X = gather(X, masked_token_positions)

    if content == "implicit"
        X = gatheragg(X, masked_groups, mean)
        item_pred = logsoftmax(
            transpose(model.embed.embeddings.item.embedding) *
            model.classifier.item.transform(X) .+ model.classifier.item.output_bias.b,
        )
        return -sum(labels .* weights .* item_pred)
    elseif content == "explicit"
        # Fail loudly instead of returning `nothing`, which only errors later in
        # `split_losses` / `Flux.gradient` with a far less clear message.
        error("lm_loss: explicit content is not implemented")
    end
    throw(ArgumentError("unsupported content: $(repr(content))"))
end;

In [15]:
"""
    split_losses(users, t::Trainer)

Evaluate `t.model` on `users` in batches of `t.batch_size` and return the summed
loss divided by the summed target weights (`batch[4]`).

Batches are built with `get_batch(…, t)`, so evaluation also draws from `t.rng`.
"""
function split_losses(users, t::Trainer)
    total_loss = 0.0
    total_weight = 0.0
    chunks = collect(Iterators.partition(users, t.batch_size))
    @showprogress for chunk in chunks
        batch = device(get_batch(chunk, t))
        total_weight += sum(batch[4])
        total_loss += lm_loss(t.model, batch, t.content)
        device_free!(batch)
    end
    return total_loss / total_weight
end;

In [16]:
"""
    train_epoch!(users, t::Trainer)

Run one training pass over `users`, shuffled with `t.rng` and split into batches of
`t.batch_size`. After each batch, `t.model` is updated through `Flux.update!` using
the optimiser state `t.opt`.
"""
function train_epoch!(users, t::Trainer)
    shuffled = Random.shuffle(t.rng, users)
    @showprogress for chunk in collect(Iterators.partition(shuffled, t.batch_size))
        batch = device(get_batch(chunk, t))
        grads = Flux.gradient(m -> lm_loss(m, batch, t.content), t.model)
        device_free!(batch)
        Flux.update!(t.opt, t.model, grads[1])
    end
end;

In [17]:
"""
    checkpoint(users, t::Trainer, epoch, name)

Compute `split_losses` on `users` and write the CPU copy of the model, the epoch
number and the metric to `"\$name/checkpoints/\$epoch"` via `write_params`.
"""
function checkpoint(users, t::Trainer, epoch, name)
    @info "evaluating metrics"
    metrics = split_losses(users, t)
    snapshot = Dict("m" => cpu(t.model), "epoch" => epoch, "metrics" => metrics)
    write_params(snapshot, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [18]:
"""
    load_pretrained_model(checkpoint, task, content)

Load the pretrained parameters at `checkpoint` and return a `Trainer` for fine-tuning
on (`task`, `content`).

The model is passed through `gpu`. Sequence length, mask/pad tokens and batch size
come from the checkpoint's `"training_config"`. The optimiser is
`Adam(1f-4, (0.9f0, 0.999f0))` chained with `WeightDecay(1f-6)`, and the RNG is
`Random.Xoshiro(20230102)`.
"""
function load_pretrained_model(checkpoint, task, content)
    params = read_params(checkpoint)
    config = params["training_config"]
    sentences = get_training_data(
        task,
        config["include_ptw_impressions"];
        show_progress_bar = true,
    )
    labels = get_labels(task, content)
    weights = get_weights(task, content)
    priors = get_priors(task, content)
    masked_items = get_masked_items(task, content)
    model = gpu(params["m"])
    adam = Adam(1f-4, (0.9f0, 0.999f0))
    opt = Optimisers.setup(OptimiserChain(adam, WeightDecay(1f-6)), model)
    return Trainer(
        task = task,
        content = content,
        sentences = sentences,
        masked_items = masked_items,
        labels = labels,
        weights = weights,
        priors = priors,
        model = model,
        max_seq_len = config["max_sequence_length"],
        mask_tokens = config["mask_tokens"],
        pad_tokens = config["pad_tokens"],
        batch_size = config["batch_size"],
        opt = opt,
        rng = Random.Xoshiro(20230102),
    )
end;

In [19]:
"""
    get_users(rng, task, content)

Return `(validation_users, test_users)`: the distinct users of the validation split
and of the test split. `rng` is currently unused.
"""
function get_users(rng, task, content)
    distinct_users(split) = collect(Set(get_split(split, task, content).user))
    return distinct_users("validation"), distinct_users("test")
end;

# Train the Model

In [20]:
# Pretrained checkpoint to fine-tune from (loaded with `read_params` in `load_pretrained_model`).
pretrain_checkpoint = "all/Transformer/mask/checkpoints/8";

In [21]:
# Build the trainer: pretrained model/config plus the fine-tuning targets for task/content.
trainer = load_pretrained_model(pretrain_checkpoint, task, content);

[32mProgress: 100%|███████████████████████████| Time: 0:03:55 ( 1.33 μs/it)[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:15[39m


In [22]:
# NOTE: `get_users` returns (validation-split users, test-split users). The model is
# therefore fine-tuned on validation-split users, and `validation` here holds the
# test-split users whose loss is reported by `checkpoint` each epoch.
training, validation = get_users(trainer.rng, trainer.task, trainer.content);

In [None]:
# Fine-tune for 100 epochs, evaluating and writing a checkpoint after every epoch.
for epoch in 1:100
    train_epoch!(training, trainer)
    checkpoint(validation, trainer, epoch, name)
end;

[32mProgress: 100%|█████████████████████████████████████████| Time: 1:08:46[39mm51[39mm
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20230107 05:02:56 evaluating metrics
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:18:51[39mm43[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20230107 05:21:51 saving model after 1 epochs with metrics 5.692785302353296
[32mProgress: 100%|█████████████████████████████████████████| Time: 1:07:49[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20230107 06:29:41 evaluating metrics
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:19:00[39m
[38;5;6m[1m[ [22m[39m[38;5;6m[1mInfo: [22m[39m20230107 06:48:42 saving model after 2 epochs with metrics 5.627291729718357
[32mProgress:  79%|████████████████████████████████▎        |  ETA: 0:14:30[39m