In [None]:
# Task identifier: interpolated into `name` below and passed to the data-loading
# helpers (`get_split`, `get_counts`, `read_alpha`, `get_training_data`).
task = "random";

In [None]:
# Run name, prefixed with the task. It is the `name` argument handed to `checkpoint`
# in the (currently commented-out) training loop at the end of the notebook.
name = "$task/Transformer/test";

In [None]:
import NBInclude: @nbinclude
import SparseArrays: AbstractSparseArray, sparse
@nbinclude("../Alpha.ipynb")
@nbinclude("Transformer.ipynb");

# Data

In [None]:
# Bundle of everything the fine-tuning loop needs. Built with keyword arguments in
# `load_pretrained_model`; all fields are untyped (`Any`).
@with_kw struct Trainer
    # model exposing `embed`, `transformers` and `classifier` (see `lm_losses`)
    model::Any
    # per-user token sequences; indexed by user id in `get_batch`
    sentences::Any
    # tuple of (implicit, explicit) sparse matrices built by `get_labels`
    labels::Any
    # tuple of (implicit, explicit) sparse matrices built by `get_weights`
    weights::Any
    # tuple of (implicit, explicit) sparse matrices built by `get_priors`
    priors::Any
    # number of users per batch (chunk size for `Iterators.partition`)
    batch_size::Any
    # upper bound on the padded sequence length used in `get_inputs`
    max_seq_len::Any
    # forwarded to `get_token_ids` — presumably the CLS token ids; TODO confirm
    cls_tokens::Any
    # forwarded to `get_token_ids`; `pad_tokens[1]` is compared against `tokens[1]`
    # to build the attention mask
    pad_tokens::Any
    # optimiser state from `Optimisers.setup`, passed to `Flux.update!`
    opt::Any
    # random number generator used for shuffling users and forwarded to `get_token_ids`
    rng::Any
end;

In [None]:
"""
    get_labels(content)

Read the "validation" split of the global `task` for the given `content` and return
its ratings as a sparse matrix with items as rows and users as columns.
"""
function get_labels(content)
    ratings = get_split("validation", task, content)
    return sparse(ratings.item, ratings.user, ratings.rating)
end;

In [None]:
"""
    get_weights(content)

Return a sparse item-by-user matrix of per-entry loss weights for the "validation"
split of the global `task`.

The weights come from `powerdecay` applied to the split's counts with the "inverse"
weighting scheme, and are placed at the same (item, user) coordinates as the split.
"""
function get_weights(content)
    entries = get_split("validation", task, content)
    counts = get_counts("validation", task, content)
    scheme = weighting_scheme("inverse")
    return sparse(entries.item, entries.user, powerdecay(counts, scheme))
end;

In [None]:
"""
    get_priors(content)

Read the `ExplicitUserItemBiases` alpha for the "validation" split of the global
`task` and return its `rating` column as a sparse item-by-user matrix.
"""
function get_priors(content)
    alpha_name = "$task/ExplicitUserItemBiases"
    preds = read_alpha(alpha_name, "validation", task, content)
    return sparse(preds.item, preds.user, preds.rating)
end;

# Batching

In [None]:
"""
    get_batch(users, sentences, labels, weights, priors, max_seq_len, cls_tokens,
              pad_tokens, rng)
    get_batch(users, t::Trainer)

Assemble one batch for the given `users`.

Returns a flat tuple: the outputs of `get_inputs` (tokens and attention mask) followed
by the `labels`, `weights` and `priors` collections, each restricted to the columns
listed in `users`.
"""
function get_batch(
    users,
    sentences,
    labels,
    weights,
    priors,
    max_seq_len,
    cls_tokens,
    pad_tokens,
    rng,
)
    user_sentences = [sentences[u] for u in users]
    inputs = get_inputs(user_sentences, max_seq_len, cls_tokens, pad_tokens, rng)

    # keep only the columns that belong to this batch's users
    select_users(m) = m[:, users]
    batch_labels = map(select_users, labels)
    batch_weights = map(select_users, weights)
    batch_priors = map(select_users, priors)

    return (inputs..., batch_labels, batch_weights, batch_priors)
end

# Convenience method: pull every argument except `users` from the trainer.
function get_batch(users, t::Trainer)
    return get_batch(
        users,
        t.sentences,
        t.labels,
        t.weights,
        t.priors,
        t.max_seq_len,
        t.cls_tokens,
        t.pad_tokens,
        t.rng,
    )
end;

In [None]:
"""
    get_inputs(sentences, max_seq_len, cls_tokens, pad_tokens, rng)

Tokenize `sentences` and build the matching attention mask.

Returns `(tokens, attention_mask)`, where `tokens` is whatever `get_token_ids` returns
and `attention_mask` is a `Float32` array of size `(1, seq_len, length(sentences))`
that is zero wherever `tokens[1]` equals `pad_tokens[1]`.
"""
function get_inputs(sentences, max_seq_len, cls_tokens, pad_tokens, rng)
    # pad only up to the longest sentence in this batch, capped at max_seq_len
    # (the extra slot is presumably for the CLS token — TODO confirm)
    longest = maximum(map(length, sentences))
    seq_len = min(longest + 1, max_seq_len)

    tokens = get_token_ids(sentences, seq_len, cls_tokens, pad_tokens; rng = rng)

    # padding positions get a zero so they are not attended to
    not_padding = tokens[1] .!= pad_tokens[1]
    attention_mask = reshape(Float32.(not_padding), 1, seq_len, length(sentences))

    return tokens, attention_mask
end;

In [None]:
"""
    device(x::AbstractSparseArray)

Densify a sparse array for the compute device: a `CuArray` when CUDA is functional,
otherwise a plain dense array via `collect`.
"""
function device(x::AbstractSparseArray)
    if CUDA.functional()
        return CUDA.CuArray(gpu(x))
    end
    return collect(x)
end

"""
    device(batch)

Move a batch (as produced by `get_batch`) to the compute device and return it as a
5-tuple in the same order.
"""
function device(batch)
    tokens, attention_mask, labels, weights, priors = batch
    return (
        gpu.(tokens),
        gpu(attention_mask),
        device.(labels),
        device.(weights),
        device.(priors),
    )
end

"""
    device_free!(batch)

Eagerly release the GPU memory held by every array in `batch`. Does nothing when CUDA
is not functional.
"""
function device_free!(batch)
    CUDA.functional() || return
    tokens, attention_mask, labels, weights, priors = batch
    CUDA.unsafe_free!.(tokens)
    CUDA.unsafe_free!(attention_mask)
    CUDA.unsafe_free!.(labels)
    CUDA.unsafe_free!.(weights)
    return CUDA.unsafe_free!.(priors)
end;

# Model

In [None]:
"""
    create_model(checkpoint)

Build the model to fine-tune from a pretrained checkpoint.

The `embed` and `transformers` components are reused from the checkpoint's stored
model (`params["m"]`). The classifier is replaced by two newly constructed `Dense`
heads, `item` and `rating`, each mapping `hidden_size` features to `num_items()`
outputs.

# Arguments
- `checkpoint`: identifier passed to `read_params`; the loaded params must contain the
  keys `"model_config"` (with `"hidden_size"`) and `"m"`.

# Returns
A `TransformerModel` with pretrained `embed`/`transformers` and fresh heads.
"""
function create_model(checkpoint)
    params = read_params(checkpoint)
    hidden_size = params["model_config"]["hidden_size"]
    pretrained = params["m"]
    # the pretrained heads are discarded; both new heads predict one value per item
    classifier = (
        item = Dense(hidden_size, num_items()),
        rating = Dense(hidden_size, num_items()),
    )
    return TransformerModel(pretrained.embed, pretrained.transformers, classifier)
end;

# Loss

In [None]:
"""
    lm_losses(model, batch)

Compute the two (unnormalised) fine-tuning losses for one batch.

Returns `(item_loss, rating_loss)`:
- `item_loss`: weighted negative log-likelihood of the first label set under the
  log-softmax of the `item` head;
- `rating_loss`: weighted squared error between the `rating` head output and the
  second label set.

Only the first four elements of `batch` are read; the priors that `get_batch` appends
as a fifth element are not used here.
"""
function lm_losses(model, batch)
    tokens, attention_mask, labels, weights = batch

    # NOTE(review): `position` is fed tokens[1], the same ids as `item` — confirm
    # that this is intended.
    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        position = tokens[1],
    )
    hidden = model.transformers(embedded, attention_mask)
    # keep only the first sequence position (presumably the CLS token — TODO confirm)
    pooled = hidden[:, 1, :]

    item_logprobs = logsoftmax(model.classifier.item(pooled))
    item_loss = sum(weights[1] .* -labels[1] .* item_logprobs)

    rating_preds = model.classifier.rating(pooled)
    rating_loss = sum(weights[2] .* (rating_preds - labels[2]) .^ 2)

    return item_loss, rating_loss
end;

In [None]:
"""
    split_losses(users, t::Trainer)

Evaluate the model on `users` in batches of `t.batch_size` and return a 2-element
vector with the item and rating losses, each divided by the total weight of its
label set.
"""
function split_losses(users, t::Trainer)
    total_losses = zeros(2)
    total_weights = zeros(2)
    batches = collect(Iterators.partition(users, t.batch_size))
    @showprogress for ids in batches
        batch = device(get_batch(ids, t))
        # batch[4] holds the per-entry weights for both label sets
        total_weights .+= sum.(batch[4])
        total_losses .+= lm_losses(t.model, batch)
        device_free!(batch)
    end
    return total_losses ./ total_weights
end;

In [None]:
"""
    train_epoch!(users, t::Trainer)

Run one training pass over `users`: shuffle them with `t.rng`, split them into batches
of `t.batch_size`, and for each batch take a gradient step on the sum of the two
losses from `lm_losses`, updating `t.model` through `t.opt`.
"""
function train_epoch!(users, t::Trainer)
    shuffled = Random.shuffle(t.rng, users)
    batches = collect(Iterators.partition(shuffled, t.batch_size))
    @showprogress for ids in batches
        batch = device(get_batch(ids, t))
        grads = Flux.gradient(m -> sum(lm_losses(m, batch)), t.model)
        # the batch is no longer needed once the gradients exist
        device_free!(batch)
        Flux.update!(t.opt, t.model, grads[1])
    end
end;

In [None]:
"""
    checkpoint(users, t::Trainer, epoch, name)

Evaluate the model on `users` with `split_losses`, then write the model (moved to the
CPU), the epoch number and the metrics through `write_params` to the checkpoint path
for this run name and epoch.
"""
function checkpoint(users, t::Trainer, epoch, name)
    @info "evaluating metrics"
    metrics = split_losses(users, t)
    state = Dict("m" => cpu(t.model), "epoch" => epoch, "metrics" => metrics)
    write_params(state, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    load_pretrained_model(checkpoint)

Build a ready-to-train `Trainer` from a pretrained checkpoint.

Batch size, maximum sequence length, CLS/pad tokens and the `include_ptw_impressions`
flag are taken from the checkpoint's `"training_config"`. Training sentences and the
implicit/explicit labels, weights and priors are loaded for the global `task`. The
model from `create_model` is moved to the GPU and paired with an Adam + weight-decay
optimiser and a fixed-seed `Xoshiro` RNG.
"""
function load_pretrained_model(checkpoint)
    config = read_params(checkpoint)["training_config"]
    sentences = get_training_data(
        task,
        config["include_ptw_impressions"];
        show_progress_bar = true,
    )

    # each of these is an (implicit, explicit) pair
    contents = ("implicit", "explicit")
    labels = map(get_labels, contents)
    weights = map(get_weights, contents)
    priors = map(get_priors, contents)

    model = gpu(create_model(checkpoint))
    rule = OptimiserChain(Adam(1f-4, (0.9f0, 0.999f0)), WeightDecay(1f-6))
    opt = Optimisers.setup(rule, model)

    return Trainer(;
        model,
        sentences,
        labels,
        weights,
        priors,
        batch_size = config["batch_size"],
        max_seq_len = config["max_sequence_length"],
        cls_tokens = config["cls_tokens"],
        pad_tokens = config["pad_tokens"],
        opt,
        rng = Random.Xoshiro(20230102),
    )
end;

In [None]:
"""
    get_users(rng)

Collect the distinct users of the implicit "validation" split for the global `task`,
shuffle them in place with `rng`, and return `(first_part, rest)` where `first_part`
holds the first 95% (rounded) of the shuffled users.
"""
function get_users(rng)
    split_users = get_split("validation", task, "implicit").user
    users = collect(Set(split_users))
    Random.shuffle!(rng, users)
    n_first = round(Int, 0.95 * length(users))
    return users[1:n_first], users[(n_first + 1):end]
end;

# Actually Train Model!

In [None]:
# Checkpoint identifier given to `load_pretrained_model`, which reads the pretrained
# model and its training config from it via `read_params`.
pretrain_checkpoint = "all/Transformer/small/checkpoints/8";

In [None]:
# Loads the data, the pretrained model, the optimiser state and the RNG into a Trainer.
trainer = load_pretrained_model(pretrain_checkpoint);

In [None]:
# 95% / 5% split of the users, shuffled with the trainer's RNG.
training, validation = get_users(trainer.rng);

In [None]:
# split_losses(validation, trainer)

In [None]:
# for epoch = 1:100
#     train_epoch!(training, trainer)        
#     checkpoint(validation, trainer, epoch, name)            
# end;