In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [None]:
using Flux
using Transformers
using Transformers.Basic
import CUDA
import Random: shuffle!
import StatsBase: mean, sample

# Tokenize training data

In [None]:
# Vocabulary layout: the item ids counted by `num_items()` come first, followed by
# three special tokens. Every special id is derived from BASE_VOCAB_SIZE so the
# layout has a single source of truth and `num_items()` is evaluated only once.
const BASE_VOCAB_SIZE = num_items()
const CLS_TOKEN = BASE_VOCAB_SIZE + 1   # marks the start of every sequence
const PAD_TOKEN = BASE_VOCAB_SIZE + 2   # fills the unused tail of short sequences
const MASK_TOKEN = BASE_VOCAB_SIZE + 3  # substituted for tokens the model must predict
const VOCAB_SIZE = BASE_VOCAB_SIZE + 3; # items plus the three special tokens

In [None]:
"""
    get_training_data()

Read the "training" / "implicit" split of every task in `ALL_TASKS` and group the
items by user.

Returns a vector with one `Vector{Int32}` of items per user; items appear in the
order they were encountered while reading the splits. The order of the users in
the returned vector follows the dictionary's iteration order.
"""
function get_training_data()
    items_by_user = Dict{Int32,Vector{Int32}}()
    for task in ALL_TASKS
        data = get_split("training", task, "implicit")
        @showprogress for row in eachindex(data.user)
            # fetch this user's history, creating it on first sight
            history = get!(() -> Int32[], items_by_user, data.user[row])
            push!(history, data.item[row])
        end
    end
    collect(values(items_by_user))
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, cls_token, pad_token)

Return a vector of length `max_seq_length` laid out as
`[cls_token, sentence..., pad_token...]`.

If `sentence` has more than `max_seq_length - 1` elements, a uniformly random
contiguous window of exactly `max_seq_length - 1` elements is used instead of the
whole sentence. Shorter sentences are right-padded with `pad_token`.

The element type of the result is the type of `pad_token`.
"""
function pad_sentence(sentence, max_seq_length, cls_token, pad_token)
    output = fill(pad_token, max_seq_length)
    output[1] = cls_token
    seq_len = max_seq_length - 1
    if length(sentence) > seq_len
        # take a random contiguous subset; valid window starts run from 1 through
        # length(sentence) - seq_len + 1 inclusive, otherwise the final window (and
        # therefore the last element of the sentence) could never be selected
        idx = rand(1:length(sentence)-seq_len+1)
        sentence = sentence[idx:idx+seq_len-1]
    end
    output[2:1+length(sentence)] .= sentence
    output
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, batch_size, cls_token, pad_token)

Draw `batch_size` sentences uniformly at random (with replacement) from
`sentences`, pad each one with `pad_sentence`, and concatenate them column-wise.

Returns a `max_seq_length × batch_size` matrix in which each column is one padded
sentence.
"""
function get_token_ids(sentences, max_seq_length, batch_size, cls_token, pad_token)
    columns = map(1:batch_size) do _
        pad_sentence(rand(sentences), max_seq_length, cls_token, pad_token)
    end
    hcat(columns...)
end;

In [None]:
"""
    get_batch(sentences, max_seq_len, batch_size; vocab_size, cls_token, pad_token, mask_token)

Sample a random batch of padded sentences and apply BERT-style masking to it.

# Keywords
- `vocab_size`: upper bound of the range `1:vocab_size` from which random
  replacement tokens are drawn.
- `cls_token`, `pad_token`: forwarded to `get_token_ids`; `pad_token` also
  determines which positions are excluded from attention.
- `mask_token`: token written over most of the selected positions.

# Returns
A tuple `(tokens, attention_mask, masked_token_positions, labels)`:
- `tokens`: `max_seq_len × batch_size` matrix of token ids after masking.
- `attention_mask`: `1 × max_seq_len × batch_size` `Float32` array, `1` for
  non-padding positions and `0` for padding.
- `masked_token_positions`: vector of `(position, batch_index)` tuples, one per
  selected position.
- `labels`: vector of `(original_token, k)` tuples, where `k` is the running index
  of the selected position across the whole batch.

If no position was selected in the entire batch, a fresh batch is drawn.
"""
function get_batch(
    sentences,
    max_seq_len,
    batch_size;
    vocab_size = BASE_VOCAB_SIZE,
    cls_token = CLS_TOKEN,
    pad_token = PAD_TOKEN,
    mask_token = MASK_TOKEN,
)
    # get tokenized sentences
    tokens = get_token_ids(sentences, max_seq_len, batch_size, cls_token, pad_token)

    # don't attend to padding tokens; compare against the `pad_token` keyword (not the
    # global constant) so a caller-supplied padding id is honoured
    attention_mask =
        reshape(convert.(Float32, tokens .!= pad_token), (1, max_seq_len, batch_size))

    # apply BERT masking
    # NOTE(review): these are deliberately left as untyped vectors; downstream code
    # (`device`, `gather`) may depend on the container type - confirm before typing them.
    masked_token_positions = []
    labels = []
    for b = 1:batch_size
        # padding only ever follows the sentence, so the non-padding positions are
        # exactly 1:seq_len
        seq_len = Int(sum(attention_mask[:, :, b]))
        # select 15% of the positions, rounding stochastically so the expected number
        # of selected positions is exactly 0.15 * seq_len
        nsamples = seq_len * 0.15
        nsamples = Int(floor(nsamples) + (rand() < nsamples - floor(nsamples)))
        # NOTE(review): position 1 (the CLS token) is eligible for selection - confirm
        # that this is intended.
        mask_samples = sample(1:seq_len, nsamples, replace = false)
        for m in mask_samples
            # remember the original token before it is (possibly) overwritten
            push!(labels, (tokens[m, b], length(labels) + 1))
            r = rand(Float32)
            if r < 0.8
                # replace with <mask>
                tokens[m, b] = mask_token
            elseif r < 0.9
                # replace with <random>
                tokens[m, b] = rand(1:vocab_size)
            end
            # otherwise keep the token unchanged
            push!(masked_token_positions, (m, b))
        end
    end

    # a batch without any selected position has nothing to predict; draw another one
    if length(labels) == 0
        return get_batch(
            sentences,
            max_seq_len,
            batch_size;
            vocab_size = vocab_size,
            cls_token = cls_token,
            pad_token = pad_token,
            mask_token = mask_token,
        )
    end

    tokens, attention_mask, masked_token_positions, labels
end;

In [None]:
# Apply `gpu` to each of the four components of a batch and return them as a tuple,
# in the same order.
function device(batch)
    ntuple(i -> gpu(batch[i]), 4)
end

# Eagerly release the first two components of a batch (the token matrix and the
# attention mask as returned by `get_batch`) with `CUDA.unsafe_free!`.
# Does nothing when CUDA is not functional.
function device_free!(batch)
    CUDA.functional() || return
    tokens, attention_mask = batch[1], batch[2]
    CUDA.unsafe_free!(tokens)
    CUDA.unsafe_free!(attention_mask)
end;

# Create Model

In [None]:
# Model hyperparameters consumed by `create_bert`.
const config = Dict{String,Any}(
    # architecture
    "num_hidden_layers" => 6, # halving the number of hidden layers
    "hidden_size" => 768,
    "num_attention_heads" => 12,
    "intermediate_size" => 3072,
    "hidden_act" => gelu,
    # input
    "vocab_size" => VOCAB_SIZE,
    "max_sequence_length" => 512,
    # regularization
    "hidden_dropout_prob" => 0.1,
    "attention_probs_dropout_prob" => 0.1,
);

In [None]:
"""
    create_bert(config)

Assemble a `TransformerModel` from the hyperparameters in `config`: a composite
(token + trainable position) embedding followed by layer norm and dropout, a
`Bert` encoder, and a classifier that maps every hidden state to
log-probabilities over the vocabulary.

`config` must provide the keys `"hidden_size"`, `"num_attention_heads"`,
`"intermediate_size"`, `"num_hidden_layers"`, `"hidden_act"`,
`"hidden_dropout_prob"`, `"attention_probs_dropout_prob"`, `"vocab_size"` and
`"max_sequence_length"`.
"""
function create_bert(config)
    hidden = config["hidden_size"]
    vocab = config["vocab_size"]
    dropout = config["hidden_dropout_prob"]

    encoder = Bert(
        hidden,
        config["num_attention_heads"],
        config["intermediate_size"],
        config["num_hidden_layers"];
        act = config["hidden_act"],
        pdrop = dropout,
        attn_pdrop = config["attention_probs_dropout_prob"],
    )

    token_embedding = Embed(hidden, vocab)
    position_embedding =
        PositionEmbedding(hidden, config["max_sequence_length"]; trainable = true)
    embedding_postprocessor = Positionwise(LayerNorm(hidden), Dropout(dropout))

    # dense projection to vocabulary size followed by logsoftmax
    classifier = Positionwise(Dense(hidden, vocab), logsoftmax)

    embedding = CompositeEmbedding(
        tok = token_embedding,
        pe = position_embedding,
        postprocessor = embedding_postprocessor,
    )

    # the classifier layer itself (not a named tuple wrapping it) is stored on the
    # model, so it can be invoked directly as `model.classifier(x)`
    TransformerModel(embedding, encoder, classifier)
end;

# Loss metrics

In [None]:
"""
    masklm_loss(model, batch)

Masked-language-model loss for one batch as produced by `get_batch`.

The model is run on the (masked) tokens, the hidden states at the masked positions
are passed through the classifier, and the result is the negated mean of the
classifier outputs gathered at the label indices.
"""
function masklm_loss(model, batch)
    tokens, attention_mask, positions, labels = batch
    embedded = model.embed(tok = tokens, pe = tokens)
    hidden = model.transformers(embedded, attention_mask)
    # classifier output restricted to the masked positions
    scores = model.classifier(gather(hidden, positions))
    -mean(gather(scores, labels))
end;

In [None]:
"""
    accuracy(model, batch)

Count how many masked positions of `batch` are predicted correctly.

A position counts as correct when no vocabulary entry receives a strictly higher
classifier score than the true label. Returns `(num_correct, num_masked)`.
"""
function accuracy(model, batch)
    tokens, attention_mask, positions, labels = batch
    hidden = model.transformers(model.embed(tok = tokens, pe = tokens), attention_mask)
    scores = model.classifier(gather(hidden, positions))
    # score of the true label for every masked position, as a row vector
    label_scores = reshape(gather(scores, labels), (1, size(scores, 2)))
    # per masked position: number of entries scored strictly above the true label
    num_better = sum(scores .> label_scores, dims = 1)
    sum(num_better .== 0), length(num_better)
end

"""
    accuracy(model, sentences, max_seq_length, batch_size)

Estimate masked-token accuracy of `model` on `sentences`.

Evaluates `ceil(length(sentences) / batch_size)` randomly sampled batches and
returns the fraction of masked positions predicted correctly across all of them.
"""
function accuracy(model, sentences, max_seq_length, batch_size)
    correct, total = 0, 0
    num_batches = ceil(Int, length(sentences) / batch_size)
    @showprogress for step = 1:num_batches
        batch = device(get_batch(sentences, max_seq_length, batch_size))
        hits, count = accuracy(model, batch)
        correct += hits
        total += count
        device_free!(batch)
    end
    correct / total
end;

In [None]:
"""
    train_epoch!(model, opt, sentences, max_seq_length, batch_size, iters)

Run `iters` optimisation steps on `model` with optimiser `opt`.

Each step samples a fresh random batch from `sentences`, performs a single
`Flux.train!` update on the masked-language-model loss, and then frees the batch.
"""
function train_epoch!(model, opt, sentences, max_seq_length, batch_size, iters)
    parameters = Flux.params(model)
    loss = batch -> masklm_loss(model, batch)
    @showprogress for step = 1:iters
        batch = device(get_batch(sentences, max_seq_length, batch_size))
        # one-element dataset: exactly one gradient step on this batch
        Flux.train!(loss, parameters, [(batch,)], opt)
        device_free!(batch)
    end
end;

In [None]:
"""
    checkpoint(model, sentences, max_seq_length, batch_size, iters)

Save a CPU copy of `model` together with the iteration count under
`"Transformer/\$iters"`, then evaluate accuracy on `sentences` and log it.
"""
function checkpoint(model, sentences, max_seq_length, batch_size, iters)
    snapshot = Dict("m" => cpu(model), "iters" => iters)
    write_params(snapshot, "Transformer/$iters")
    acc = accuracy(model, sentences, max_seq_length, batch_size)
    @info "saving model after $iters iters with accuracy $acc"
end;

In [None]:
# Shuffle the per-user sentences, then hold out the last 5% for validation.
sentences = shuffle!(get_training_data())
cutoff = round(Int, 0.95 * length(sentences))
training_sentences = sentences[1:cutoff]
validation_sentences = sentences[(cutoff+1):end];

In [None]:
# Batches use the same sequence length the model's position embedding is built for.
max_seq_length = config["max_sequence_length"]
batch_size = 8
# number of training iterations between two checkpoints
checkpoint_iters = 20_000;

In [None]:
# Build the model from `config` and pass it through `gpu`.
ryouko = gpu(create_bert(config));

In [None]:
# defaults taken from the BERT paper; the decay handed to ADAMW is the weight decay
# scaled by the learning rate
opt = let lr = 1e-4, betas = (0.9, 0.999), weight_decay = 0.01
    ADAMW(lr, betas, lr * weight_decay)
end
# todo learning rate scheduling decay

In [None]:
# Train indefinitely, saving a checkpoint and reporting validation accuracy every
# `checkpoint_iters` iterations. Stop by interrupting the kernel.
iters = 0
while true
    train_epoch!(ryouko, opt, training_sentences, max_seq_length, batch_size, checkpoint_iters)
    # `global` keeps the counter update valid outside the notebook's soft scope too
    # (e.g. when this code is run as a plain script)
    global iters += checkpoint_iters
    checkpoint(ryouko, validation_sentences, max_seq_length, batch_size, iters)
end