In [1]:
# Finetuning domain. Both values are interpolated into the run name and passed to
# `load_pretrained_model` further down in this notebook.
task = "causal"
content = "implicit";

In [2]:
# Run identifier; `checkpoint` writes under "$name/checkpoints/<epoch>".
name = "$task/Transformer/$content";

In [3]:
import NBInclude: @nbinclude
import SparseArrays: AbstractSparseArray, sparse
@nbinclude("../Alpha.ipynb")
@nbinclude("Transformer.ipynb");



# Data

In [4]:
# Bundle of everything the finetuning loop needs: the finetuning domain, the data
# used to build batches, the model, and the optimisation state.
# Constructed by keyword in `load_pretrained_model`.
@with_kw struct Trainer
    # finetuning domain
    # `get_batch` accepts "random", "causal" or "temporal"
    task::Any
    # `get_batch` and `lm_loss` accept "explicit" or "implicit"
    content::Any
    # data
    # per-user sentences, looked up by user id in `get_sentence`
    sentences::Any
    # user id => vector of (item, timestamp) tuples, see `get_masked_items`
    masked_items::Any
    # sparse matrices with (num_items() + 4) rows and num_users() columns
    labels::Any
    weights::Any
    priors::Any
    # model
    # `lm_loss` uses its `embed`, `transformers` and `classifier` fields
    model::Any
    # taken from the pretraining checkpoint's "training_config"
    max_seq_len::Any
    mask_tokens::Any
    pad_tokens::Any
    # training
    # number of users per batch
    batch_size::Any
    # optimiser state passed to `Flux.update!`
    opt::Any
    # random number generator used for shuffling users and building batches
    rng::Any
end;

In [5]:
"""
    get_labels(task, content)

Return a sparse matrix holding the ratings of the combined "validation" and "test"
splits, indexed as `[item, user]`.

The matrix has `num_items() + 4` rows and `num_users()` columns.
NOTE(review): the 4 extra rows presumably cover special tokens; confirm.
"""
function get_labels(task, content)
    validation = get_split("validation", task, content)
    test = get_split("test", task, content)
    combined = cat(validation, test)
    n_rows = num_items() + 4
    return sparse(combined.item, combined.user, combined.rating, n_rows, num_users())
end;

In [6]:
"""
    get_weights(task, content)

Return a sparse `[item, user]` matrix of per-interaction loss weights for the
combined "validation" and "test" splits.

Each split's weights are `powerdecay` applied to that split's counts with the
"inverse" weighting scheme. The two weight vectors are concatenated in the same
order (validation, then test) as the rows of the combined split.
"""
function get_weights(task, content)
    combined = cat(get_split("validation", task, content), get_split("test", task, content))
    split_weights(split) =
        powerdecay(get_counts(split, task, content), weighting_scheme("inverse"))
    stacked = vcat(split_weights("validation"), split_weights("test"))
    return sparse(combined.item, combined.user, stacked, num_items() + 4, num_users())
end;

In [7]:
"""
    get_priors(task, content)

Return a sparse `[item, user]` matrix of prior predictions read from the
"<task>/ExplicitUserItemBiases" alpha, for the combined "validation" and "test"
splits. The matrix has `num_items() + 4` rows and `num_users()` columns.
"""
function get_priors(task, content)
    alpha = "$task/ExplicitUserItemBiases"
    validation = read_alpha(alpha, "validation", task, content)
    test = read_alpha(alpha, "test", task, content)
    combined = cat(validation, test)
    return sparse(combined.item, combined.user, combined.rating, num_items() + 4, num_users())
end;

In [8]:
"""
    get_masked_items(task, content)

Group the combined "validation" and "test" splits by user.

Returns a `Dict` mapping each user to a `Vector{Tuple{Int32,Float32}}` of
`(item, timestamp)` pairs, in the order the rows appear in the combined split.
"""
function get_masked_items(task, content)
    df = cat(get_split("validation", task, content), get_split("test", task, content))
    grouped = Dict()
    @showprogress for row = 1:length(df.user)
        user = df.user[row]
        # create the user's bucket on first sight, then append to it
        bucket = get!(() -> Tuple{Int32,Float32}[], grouped, user)
        push!(bucket, (df.item[row], df.timestamp[row]))
    end
    return grouped
end;

In [9]:
"""
    get_sentence(sentences, x)

Return a copy of the sentence stored for key `x`, or an empty sentence of the
collection's value type when `x` has no entry.

The returned vector is always safe to mutate: it is either a copy or freshly
allocated.
"""
function get_sentence(sentences, x)
    # Explicit membership test instead of try/catch: `catch KeyError` would bind *any*
    # exception to a local variable named `KeyError` rather than filter by type.
    if haskey(sentences, x)
        return copy(sentences[x])
    end
    # Derive the empty sentence from the argument, not from a global trainer.
    return eltype(values(sentences))(undef, 0)
end;

# Batching

In [10]:
"""
    replace_item(word, item_timestamp)

Return a copy of the token tuple `word` whose first element is replaced by the
item of `item_timestamp`, an `(item, timestamp)` pair. All other elements of
`word` are kept unchanged.
"""
function replace_item(word, item_timestamp)
    # TODO also replace the ts
    # The timestamp is currently ignored, so it is not encoded here.
    item = first(item_timestamp)
    return (item, word[2:end]...)
end

"""
    replace_timestamp(word, timestamp)

Return a copy of the token tuple `word` whose third element is replaced by
`encode_raw_timestamp(timestamp)`. All other elements are kept unchanged.
"""
function replace_timestamp(word, timestamp)
    encoded = encode_raw_timestamp(timestamp)
    before = word[1:2]
    after = word[4:end]
    return (before..., encoded, after...)
end;

In [11]:
"""
    get_timestamp(item_timestamps)

Return the smallest timestamp (second element) among the `(item, timestamp)`
entries of `item_timestamps`.
"""
function get_timestamp(item_timestamps)
    timestamps = [entry[2] for entry in item_timestamps]
    return minimum(timestamps)
end;

get_timestamp (generic function with 1 method)

In [12]:
# function random_dropout(rng, sentence, p)
#     new_sentence = eltype(sentence)[]
#     for i = 1:length(sentence)
#         if rand(rng) >= p
#             push!(new_sentence, sentence[i])
#         end
#     end
#     new_sentence
# end;

In [13]:
"""
    get_batch(users, training, task, content, sentences, masked_items, labels,
              weights, priors, max_seq_len, mask_tokens, pad_tokens, rng)

Assemble one finetuning batch for `users`.

Each user's sentence is truncated with `subset_sentence` to `max_seq_len - 1` words
and a mask word is added:

- `task == "random"`: a random subset is kept and the mask word is inserted at
  position 1.
- `task in ["causal", "temporal"]`: the most recent words are kept and the mask word
  is appended. For "temporal" its timestamp is set from the constant `1`; for
  "causal" it is set from the earliest held-out timestamp of the user.

For `content == "explicit"`, one held-out item per user is sampled and written into
the mask word. The outputs are 1 x N `Float32` matrices holding the sampled item's
label and prior, and a weight of 1. For `content == "implicit"`, the outputs are the
columns of `labels`, `weights` and `priors` selected by `users`.

`training` is currently unused (it only appears in the disabled dropout code).

Returns `(inputs..., output_labels, output_weights, output_priors,
masked_token_positions)`, where `inputs` comes from `get_inputs` and
`masked_token_positions` holds `(position in sentence, index in batch)` pairs.
"""
function get_batch(
    users,
    training,
    task,
    content,
    sentences,
    masked_items,
    labels,
    weights,
    priors,
    max_seq_len,
    mask_tokens,
    pad_tokens,
    rng,
)
    raw_sentences = [get_sentence(sentences, x) for x in users]
    # TODO test to completion
    # if training
    #     raw_sentences = random_dropout.(rng, raw_sentences, 0.1)
    # end
    processed_sentences = eltype(values(sentences))[]
    masked_token_positions = Tuple{Int32,Int32}[]
    if content == "explicit"
        output_labels = Float32[]
        output_weights = Float32[]
        output_priors = Float32[]
    elseif content == "implicit"
        output_labels = labels[:, users]
        output_weights = weights[:, users]
        output_priors = priors[:, users]
    else
        @assert false
    end

    # Sample one held-out (item, timestamp) for `user` and record its target values.
    # Only called on the explicit path, where the outputs are growable vectors.
    function sample_item(user, weight)
        item, timestamp = rand(rng, masked_items[user])
        push!(output_labels, labels[item, user])
        push!(output_priors, priors[item, user])
        push!(output_weights, weight)
        item, timestamp
    end

    for i::Int32 = 1:length(raw_sentences)
        s = raw_sentences[i]
        if task == "random"
            s = subset_sentence(s, max_seq_len - 1; recent = false, rng = rng)
            insert!(s, 1, mask_tokens)
            push!(masked_token_positions, (1, i))
            if content == "explicit"
                s[1] = replace_item(s[1], sample_item(users[i], 1))
            end
        elseif task in ["causal", "temporal"]
            s = subset_sentence(s, max_seq_len - 1; recent = true, rng = rng)
            masked_word = mask_tokens
            if task == "temporal"
                masked_word = replace_timestamp(mask_tokens, 1)
            else
                masked_word =
                    replace_timestamp(mask_tokens, get_timestamp(masked_items[users[i]]))
            end
            if content == "explicit"
                # Start from the timestamped mask word; the previous code referenced an
                # undefined variable `word` here and threw an UndefVarError.
                masked_word = replace_item(masked_word, sample_item(users[i], 1))
            end
            push!(s, masked_word)
            push!(masked_token_positions, (Int32(length(s)), i))
        else
            @assert false
        end
        push!(processed_sentences, s)
    end


    inputs = get_inputs(processed_sentences, max_seq_len, pad_tokens, rng)
    if content == "explicit"
        # turn each length-N vector into a 1 x N Float32 matrix
        to_explicit_output(x) = convert.(Float32, collect(x'))
        output_labels = to_explicit_output(output_labels)
        output_weights = to_explicit_output(output_weights)
        output_priors = to_explicit_output(output_priors)
    end
    (inputs..., output_labels, output_weights, output_priors, masked_token_positions)
end

"""
    get_batch(users, t::Trainer, training::Bool)

Convenience method: build a batch for `users`, taking every other argument of the
full `get_batch` method from the fields of `t`.
"""
function get_batch(users, t::Trainer, training::Bool)
    return get_batch(
        users,
        training,
        t.task,
        t.content,
        t.sentences,
        t.masked_items,
        t.labels,
        t.weights,
        t.priors,
        t.max_seq_len,
        t.mask_tokens,
        t.pad_tokens,
        t.rng,
    )
end;

In [14]:
"""
    get_inputs(sentences, max_seq_len, pad_tokens, rng)

Tokenize `sentences` and build the matching attention mask.

The sequence length is the longest sentence length plus one, capped at
`max_seq_len`. Returns `(tokens, attention_mask)`: `tokens` is the result of
`get_token_ids`, and `attention_mask` is a `Float32` array of size
`(1, seq_len, length(sentences))` that is 1 where the first token field differs
from the first pad token and 0 elsewhere.
"""
function get_inputs(sentences, max_seq_len, pad_tokens, rng)
    # pad only up to the longest sentence in this batch
    sentence_lengths = length.(sentences)
    seq_len = min(maximum(sentence_lengths) + 1, max_seq_len)

    tokens = get_token_ids(sentences, seq_len, pad_tokens; rng = rng)

    # padding positions get a zero so they are not attended to
    is_real_token = tokens[1] .!= pad_tokens[1]
    mask_shape = (1, seq_len, length(sentences))
    attention_mask = reshape(Float32.(is_real_token), mask_shape)

    return tokens, attention_mask
end;

In [15]:
"""
    device(x)

Move `x` to the compute device.

- Generic values go through `gpu`.
- Sparse arrays become a `CUDA.CuArray` of `gpu(x)` when CUDA is functional, and
  are densified with `collect` otherwise.
- Tuples are converted element by element.
"""
function device(x)
    return gpu(x)
end
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return collect(x)
end
function device(batch::Tuple)
    return map(device, batch)
end

"""
    device_free!(x)

Eagerly release the GPU memory held by batch `x`. Does nothing when CUDA is not
functional.

Frees every element of `x[1]` and then `x[2]` through `x[5]`.
NOTE(review): a batch from `get_batch` has a sixth element
(`masked_token_positions`) that is not freed here; confirm this is intentional.
"""
function device_free!(x)
    CUDA.functional() || return
    foreach(CUDA.unsafe_free!, x[1])
    foreach(i -> CUDA.unsafe_free!(x[i]), 2:5)
end;

# Loss

In [16]:
"""
    lm_loss(model, batch, content)

Compute the summed, weighted finetuning loss of `model` on `batch`.

`batch` is the tuple produced by `get_batch`. The tokens are embedded, run through
the transformer stack, and the hidden states at the masked positions are gathered.

- `content == "explicit"`: weighted squared error between the rating head's output
  and `labels`.
- `content == "implicit"`: weighted cross-entropy between `labels` and the
  log-softmax over items. The item logits reuse the item embedding matrix
  (transposed) on top of the item head's transform, plus an output bias.

`priors` is unpacked from the batch but not used by the loss.
"""
function lm_loss(model, batch, content)
    tokens, attention_mask, labels, weights, priors, masked_token_positions = batch
    # NOTE(review): `position` is fed `tokens[1]`, the same field as `item`;
    # confirm this is intended.
    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        position = tokens[1],
    )
    hidden = model.transformers(embedded, attention_mask)
    masked_hidden = gather(hidden, masked_token_positions)

    # TODO avg losses instead of sum loss

    if content == "explicit"
        rating_pred = model.classifier.rating(masked_hidden)
        squared_error = (rating_pred - labels) .^ 2
        return sum(squared_error .* weights)
    elseif content == "implicit"
        item_head = model.classifier.item
        item_embedding = model.embed.embeddings.item.embedding
        logits =
            transpose(item_embedding) * item_head.transform(masked_hidden) .+
            item_head.output_bias.b
        item_pred = logsoftmax(logits)
        return -sum(labels .* weights .* item_pred)
    else
        @assert false
    end
end;

In [17]:
"""
    split_losses(users, t::Trainer)

Evaluate the model on `users` and return the weighted average loss: the sum of the
per-batch `lm_loss` values divided by the sum of all batch weights.

Users are processed in chunks of `t.batch_size`. Each batch is moved to the device
and freed again once its loss has been accumulated.
"""
function split_losses(users, t::Trainer)
    total_loss = 0.0
    total_weight = 0.0
    chunks = collect(Iterators.partition(users, t.batch_size))
    @showprogress for chunk in chunks
        batch = device(get_batch(chunk, t, false))
        # batch[4] holds the weights
        total_weight += sum(batch[4])
        total_loss += lm_loss(t.model, batch, t.content)
        device_free!(batch)
    end
    return total_loss / total_weight
end;

In [18]:
"""
    train_epoch!(users, t::Trainer)

Run one finetuning epoch over `users`, updating `t.model` and `t.opt` in place.

Users are shuffled with `t.rng` and split into chunks of `t.batch_size`. For each
chunk the gradient of `lm_loss` with respect to the model is taken, the batch is
freed, and the optimiser step is applied.
"""
function train_epoch!(users, t::Trainer)
    shuffled = Random.shuffle(t.rng, users)
    chunks = collect(Iterators.partition(shuffled, t.batch_size))
    @showprogress for chunk in chunks
        batch = device(get_batch(chunk, t, true))
        grads = Flux.gradient(m -> lm_loss(m, batch, t.content), t.model)
        # the batch is no longer needed once the gradient has been computed
        device_free!(batch)
        Flux.update!(t.opt, t.model, first(grads))
    end
end;

In [19]:
"""
    checkpoint(users, t::Trainer, epoch, name)

Evaluate the model on `users` and save a checkpoint for `epoch`.

The CPU copy of the model, the epoch number and the evaluation loss are written
with `write_params` to "<name>/checkpoints/<epoch>".
"""
function checkpoint(users, t::Trainer, epoch, name)
    @info "evaluating metrics"
    metrics = split_losses(users, t)
    payload = Dict("m" => cpu(t.model), "epoch" => epoch, "metrics" => metrics)
    destination = "$name/checkpoints/$epoch"
    write_params(payload, destination)
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [20]:
"""
    get_model(checkpoint)

Read the parameters stored at `checkpoint` and return the model saved under the
"m" key, unchanged.
"""
function get_model(checkpoint)
    saved = read_params(checkpoint)
    pretrained = saved["m"]
    # Disabled alternative, kept for reference: rebuild the classifier heads from
    # scratch and reuse only the pretrained embedding and transformer stack.
    # config = saved["model_config"]
    # item_cls = (
    #     transform = Chain(
    #         Dense(config["hidden_size"], config["hidden_size"], tanh),
    #         Dropout(0.1),
    #         Dense(config["hidden_size"], config["hidden_size"]),
    #         LayerNorm(config["hidden_size"]),
    #     ),
    #     output_bias = BiasLayer(config["vocab_sizes"][1]),
    # )
    # rating_cls = Dense(config["hidden_size"], 1)
    # clf = (item = item_cls, rating = rating_cls)
    # TransformerModel(pretrained.embed, pretrained.transformers, clf)
    return pretrained
end;

In [21]:
"""
    load_pretrained_model(checkpoint, task, content)

Build a `Trainer` that finetunes the model stored at `checkpoint` on the given
`task` and `content`.

Sequence length, mask/pad tokens, batch size and the `include_ptw_impressions`
flag come from the checkpoint's "training_config". The model is moved to the GPU.
It is optimised with Adam (learning rate 3e-5, betas 0.9 / 0.999) chained with
weight decay of `lr * 1f-2`. The trainer's RNG is a `Xoshiro` seeded with 20230102.
"""
function load_pretrained_model(checkpoint, task, content)
    training_config = read_params(checkpoint)["training_config"]

    # data
    include_ptw = training_config["include_ptw_impressions"]
    sentences = get_training_data(task, include_ptw; show_progress_bar = true)
    labels = get_labels(task, content)
    weights = get_weights(task, content)
    priors = get_priors(task, content)
    masked_items = get_masked_items(task, content)

    # model and optimiser
    model = gpu(get_model(checkpoint))
    lr = 3e-5
    rule = OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * 1f-2))
    opt = Optimisers.setup(rule, model)

    return Trainer(
        # finetuning domain
        task = task,
        content = content,
        # data
        sentences = sentences,
        masked_items = masked_items,
        labels = labels,
        weights = weights,
        priors = priors,
        # model
        model = model,
        max_seq_len = training_config["max_sequence_length"],
        mask_tokens = training_config["mask_tokens"],
        pad_tokens = training_config["pad_tokens"],
        # training
        batch_size = training_config["batch_size"],
        opt = opt,
        rng = Random.Xoshiro(20230102),
    )
end;

In [22]:
"""
    get_users(rng, task, content)

Return `(training, test)`: the distinct users of the "validation" split, which the
finetuning loop trains on, and the distinct users of the "test" split.

`rng` is accepted but not used.
"""
function get_users(rng, task, content)
    # collect(Set(...)) is kept on purpose: the resulting order feeds the seeded shuffle
    distinct_users(split) = collect(Set(get_split(split, task, content).user))
    training = distinct_users("validation")
    test = distinct_users("test")
    return training, test
end;

# Actually Train Model!

In [23]:
# Pretraining checkpoint to finetune from; its "training_config" also supplies the
# sequence length, mask/pad tokens and batch size used by the trainer.
pretrain_checkpoint = "all/Transformer/mask/checkpoints/8"
trainer = load_pretrained_model(pretrain_checkpoint, task, content);

[32mProgress: 100%|███████████████████████████| Time: 0:04:18 ( 1.33 μs/it)[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:06[39m


In [24]:
# `get_users` returns (users of the "validation" split, users of the "test" split):
# the model is finetuned on the former and evaluated on the latter.
training, validation = get_users(trainer.rng, trainer.task, trainer.content);

In [None]:
# Finetune for a fixed 100 epochs, evaluating and writing a checkpoint after each one.
for epoch = 1:100
    train_epoch!(training, trainer)        
    checkpoint(validation, trainer, epoch, name)            
end;

[32mProgress:  74%|██████████████████████████████▍          |  ETA: 0:18:39[39mm55[39mm