# Pretrains a transformer encoder model on watch histories

In [None]:
# Dataset location, resolved through `get_data_path`; the notebook reads the
# `training/` shards and `training/config.json` from underneath it.
datapath = "alphas/all/Transformer/v0";

In [None]:
# Name of this run; checkpoints are written under "$name/checkpoints/<epoch>".
name = "all/Transformer/v21";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

In [None]:
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import JSON
import HDF5
import Random
import StatsBase: mean, sample

# Structs

In [None]:
"""
    Trainer(model, opt, weightdecay, lr_schedule, training_config, model_config, rng)

Everything a training run needs to carry from step to step.

# Fields
- `model`: the network being trained.
- `opt`: optimiser state applied to `model` after every gradient step.
- `weightdecay`: a second optimiser state, applied right after `opt`.
- `lr_schedule`: schedule handed to `schedule_learning_rate!` before each step.
- `training_config`: settings looked up by string key (e.g. "batch_size").
- `model_config`: the configuration the model was built from; saved in checkpoints.
- `rng`: random number generator saved alongside the checkpoint.

All fields are left untyped, i.e. they accept `Any` value.
"""
struct Trainer
    model
    opt
    weightdecay
    lr_schedule
    training_config
    model_config
    rng
end;

In [None]:
"""
    device(x::NamedTuple)
    device(batch)

Move data to the accelerator with `gpu`.

The `NamedTuple` method applies `gpu` to every value and keeps the field names.
The generic method expects a six-element batch: every element of `batch[1]` and
`batch[2]` itself go through `gpu`; `batch[3]` to `batch[6]` are moved with `device`.
"""
function device(x::NamedTuple)
    return NamedTuple{keys(x)}(map(gpu, values(x)))
end

function device(batch)
    tokens = gpu.(batch[1])
    attention_mask = gpu(batch[2])
    # Slots 3..6 are moved recursively, in order.
    targets = map(k -> device(batch[k]), (3, 4, 5, 6))
    return (tokens, attention_mask, targets...)
end

# Allow `nothing` entries to be "freed" without error.
CUDA.unsafe_free!(::Nothing) = nothing

"""
    device_free!(x)
    device_free!(x::NamedTuple)
    device_free!(batch::Tuple)

Eagerly release device memory through `CUDA.unsafe_free!`.

A single value is freed directly; a `NamedTuple` has each of its values freed in turn.
The `Tuple` method expects the six-element batch layout produced by `device` and does
nothing when `CUDA.functional()` is false.
"""
device_free!(x) = CUDA.unsafe_free!(x)

function device_free!(x::NamedTuple)
    foreach(device_free!, values(x))
end

function device_free!(batch::Tuple)
    CUDA.functional() || return
    foreach(CUDA.unsafe_free!, batch[1])
    CUDA.unsafe_free!(batch[2])
    for slot = 3:5
        device_free!(batch[slot])
    end
    return device_free!(batch[6])
end;

# Create model

In [None]:
"""
    create_bert(config)

Assemble a `TransformerModel` from the string-keyed `config`.

The model has three parts: a `CompositeEmbedding` over the token fields (anime, manga,
rating, timestamp, status, completion, position) followed by layer norm and dropout, a
`Bert` encoder stack, and one `item` and one `rating` `Dense` head per medium. Each head
maps `hidden_size` to that medium's vocabulary size.
"""
function create_bert(config)
    hidden = config["hidden_size"]
    vocab = config["vocab_sizes"]

    bert = Bert(;
        hidden_size = hidden,
        num_attention_heads = config["num_attention_heads"],
        intermediate_size = config["intermediate_size"],
        num_layers = config["num_hidden_layers"],
        activation_fn = config["hidden_act"],
        dropout = config["dropout"],
        attention_dropout = config["attention_dropout"],
    )

    # Layers are constructed in a fixed order (encoder, embeddings, heads) so that any
    # random initialisation draws happen in the same sequence on every run.
    emb = CompositeEmbedding(
        anime = DiscreteEmbed(hidden, extract(vocab, :anime)),
        manga = DiscreteEmbed(hidden, extract(vocab, :manga)),
        rating = ContinuousEmbed(hidden),
        timestamp = ContinuousEmbed(hidden),
        status = DiscreteEmbed(hidden, extract(vocab, :status)),
        completion = ContinuousEmbed(hidden),
        position = DiscreteEmbed(hidden, extract(vocab, :position)),
        postprocessor = Chain(LayerNorm(hidden), Dropout(config["dropout"])),
    )

    # One item head and one rating head per medium, both sized to that medium's vocabulary.
    heads(medium) = (
        item = Dense(hidden, extract(vocab, medium)),
        rating = Dense(hidden, extract(vocab, medium)),
    )
    clf = (anime = heads(:anime), manga = heads(:manga))

    return TransformerModel(emb, bert, clf)
end;

# Loss metrics

In [None]:
# Weighted negative log-likelihood of the selected items; 0 when nothing is selected.
function _masked_item_loss(head, hidden, batch_pos, item_pos, w)
    length(item_pos) > 0 || return 0.0f0
    logp = logsoftmax(head(gather(hidden, batch_pos)))
    return -(w' * gather(logp, item_pos)) / sum(w)
end

# Weighted mean squared error of the selected ratings; 0 when nothing is selected.
function _masked_rating_loss(head, hidden, batch_pos, item_pos, target, w)
    length(item_pos) > 0 || return 0.0f0
    pred = head(gather(hidden, batch_pos))
    # `.^` binds tighter than `*`: the residuals are squared before being weighted.
    return (w' * (gather(pred, item_pos) - target) .^ 2) / sum(w)
end

"""
    masklm_losses(model, batch)

Run `model` on `batch` and return the four training losses as the tuple
`(anime_item, anime_rating, manga_item, manga_rating)`.

`batch` is the six-tuple `(tokens, attention_mask, batch_positions, item_positions,
labels, weights)`. Item losses are weighted negative log-likelihoods and rating losses
are weighted squared errors, each normalised by the sum of its weights. A task with no
positions contributes `0.0f0`. Only the rating labels are read; item targets are encoded
in `item_positions`.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    embedded = model.embed(
        anime = extract(tokens, :anime),
        manga = extract(tokens, :manga),
        rating = extract(tokens, :rating),
        timestamp = extract(tokens, :timestamp),
        status = extract(tokens, :status),
        completion = extract(tokens, :completion),
        position = extract(tokens, :position),
    )
    hidden = model.transformers(embedded, attention_mask)

    anime_item_loss = _masked_item_loss(
        model.classifier.anime.item,
        hidden,
        batch_positions[:anime][:item],
        item_positions[:anime][:item],
        weights[:anime][:item],
    )
    anime_rating_loss = _masked_rating_loss(
        model.classifier.anime.rating,
        hidden,
        batch_positions[:anime][:rating],
        item_positions[:anime][:rating],
        labels[:anime][:rating],
        weights[:anime][:rating],
    )
    manga_item_loss = _masked_item_loss(
        model.classifier.manga.item,
        hidden,
        batch_positions[:manga][:item],
        item_positions[:manga][:item],
        weights[:manga][:item],
    )
    manga_rating_loss = _masked_rating_loss(
        model.classifier.manga.rating,
        hidden,
        batch_positions[:manga][:rating],
        item_positions[:manga][:rating],
        labels[:manga][:rating],
        weights[:manga][:rating],
    )

    return anime_item_loss, anime_rating_loss, manga_item_loss, manga_rating_loss
end;

# Data

In [None]:
"""
    Dataloader(embeds, positions, labels, weights)
    Dataloader(filename::String)

In-memory copy of one data shard.

# Fields
- `embeds`: one array per token field, in the order anime, manga, rating, timestamp,
  status, completion, user, position.
- `positions`, `labels`, `weights`: one array per task, in the order anime/item,
  anime/rating, manga/item, manga/rating.

The `filename` constructor reads every dataset from an HDF5 file and closes the file
before returning.
"""
struct Dataloader
    embeds::Any
    positions::Any
    labels::Any
    weights::Any
end

function Dataloader(filename::String)
    # The do-block form closes the handle even when a read throws. The caller deletes
    # the shard right after loading it, so the handle must not outlive this call.
    HDF5.h5open(filename, "r") do f
        embeds = [
            HDF5.read(f, x) for x in [
                "anime",
                "manga",
                "rating",
                "timestamp",
                "status",
                "completion",
                "user",
                "position",
            ]
        ]
        # Datasets are named "<prefix>_<medium>_<task>"; medium is the outer loop.
        per_task(prefix) = [
            HDF5.read(f, "$(prefix)_$(medium)_$(task)") for
            medium in ["anime", "manga"] for task in ["item", "rating"]
        ]
        positions = per_task("positions")
        labels = per_task("labels")
        weights = per_task("weights")
        Dataloader(embeds, positions, labels, weights)
    end
end

# Number of columns in the first token array.
Base.length(d::Dataloader) = size(d.embeds[1])[2];

In [None]:
"""
    get_dataloader(outdir, split, batch_size, num_workers)

Load the next data shard for `split` from `outdir/training` and consume it.

Shards are visited round-robin over `1:num_workers`. The index of the last shard read
is kept in the file `dataloader.<split>`; when that file is absent the first shard is 1.
The call blocks, polling once per second, until the shard's `.complete` marker exists.
It then loads the shard, records the new index, and deletes both marker and shard.

`batch_size` is accepted for interface compatibility and is not used here.
"""
function get_dataloader(outdir, split, batch_size, num_workers)
    # find the next data shard
    dataloader_file = joinpath(outdir, "training", "dataloader.$split")
    previous = isfile(dataloader_file) ? parse(Int, readline(dataloader_file)) : 0
    worker = mod1(previous + 1, num_workers)

    # wait for the data shard to be written
    data_file = joinpath(outdir, "training", "$split.$worker.h5")
    completion_file = data_file * ".complete"
    while !isfile(completion_file)
        sleep(1)
    end

    # read the data shard
    dataloader = Dataloader(data_file)

    # sync and update disk
    write(dataloader_file, "$worker")
    rm(completion_file)
    rm(data_file)
    return dataloader
end;

# Convenience method: load the next `split` shard from the notebook's data directory,
# taking the batch size from `config` and assuming 4 shard writers.
function get_dataloader(split, config)
    outdir = get_data_path(datapath)
    return get_dataloader(outdir, split, config["batch_size"], 4)
end;

In [None]:
"""
    get_batch(d::Dataloader, batch_size, i)

Cut the `i`-th batch of at most `batch_size` columns out of `d`.

Returns the tuple
`(tokens, attention_mask, batch_positions, item_positions, labels, weights)`:

- `tokens`: the selected columns of each token array.
- `attention_mask`: `mask[k, j] == mask[l, j]` for every pair of rows `k`, `l` of the
  `:user` token array, stored as a 3-d boolean array.
- `batch_positions`: for every entry with non-zero weight, its `(row, column)` index
  within the batch.
- `item_positions`: for the same entries, `(stored position, running index)`, where the
  running index counts the selected entries of that task from 1.
- `labels`, `weights`: the stored label and weight of each selected entry.

The last four are nested named tuples indexed by medium (`:anime`, `:manga`) and then
by task (`:item`, `:rating`). Entries are visited column by column.
"""
function get_batch(d::Dataloader, batch_size, i)
    idx = (batch_size*(i-1)+1):min(batch_size * i, length(d))
    tokens = Tuple(d.embeds[j][:, idx] for j = 1:length(fieldnames(wordtype)))

    mask = extract(tokens, :user)
    nrows, ncols = size(mask)
    attention_mask = reshape(mask, (1, nrows, ncols)) .== reshape(mask, (nrows, 1, ncols))

    # Fresh, empty per-medium / per-task containers with the given element types.
    empty_targets(TI, TR) = (
        anime = (item = Vector{TI}(), rating = Vector{TR}()),
        manga = (item = Vector{TI}(), rating = Vector{TR}()),
    )
    batch_positions = empty_targets(Tuple{Int32,Int32}, Tuple{Int32,Int32})
    item_positions = empty_targets(Tuple{Int32,Int32}, Tuple{Int32,Int32})
    labels = empty_targets(Int32, Float32)
    weights = empty_targets(Float32, Float32)

    # A single pass fills all four outputs, so they are aligned entry for entry.
    # `taskidx` follows the storage order anime/item, anime/rating, manga/item,
    # manga/rating.
    taskidx = 0
    for medium in (:anime, :manga), task in (:item, :rating)
        taskidx += 1
        p = @view d.positions[taskidx][:, idx]
        l = @view d.labels[taskidx][:, idx]
        w = @view d.weights[taskidx][:, idx]
        bpos = batch_positions[medium][task]
        ipos = item_positions[medium][task]
        for col = 1:ncols, row = 1:nrows
            w[row, col] == 0 && continue
            push!(bpos, (row, col))
            push!(ipos, (p[row, col], length(ipos) + 1))
            push!(labels[medium][task], l[row, col])
            push!(weights[medium][task], w[row, col])
        end
    end

    tokens, attention_mask, batch_positions, item_positions, labels, weights
end;

In [None]:
"""
    read_json(file)

Read the whole of `file` as a string and return the result of `JSON.parse` on it.
"""
function read_json(file)
    contents = open(io -> read(io, String), file, "r")
    return JSON.parse(contents)
end;

# Training

In [None]:
"""
    train_epoch!(t::Trainer)

Train `t.model` until the "tokens_per_epoch" budget in `t.training_config` is used up,
and return the mean per-step loss.

Shards are pulled with `get_dataloader("training", ...)` for as long as budget remains.
Each batch goes through one step: update the learning rate, move the batch to the
device, take the gradient of the summed `masklm_losses`, free the batch, then apply
`t.opt` followed by `t.weightdecay`. The budget is charged the full size of the first
token array of each batch.
"""
function train_epoch!(t::Trainer)
    losses = 0.0f0
    steps = 0
    remaining_tokens = t.training_config["tokens_per_epoch"]
    batch_size = t.training_config["batch_size"]
    p = ProgressMeter.Progress(remaining_tokens)
    while remaining_tokens > 0
        dataloader = get_dataloader("training", t.training_config)
        for i = 1:div(length(dataloader), batch_size, RoundUp)
            schedule_learning_rate!(
                t.opt,
                t.weightdecay,
                t.lr_schedule,
                t.training_config["weight_decay"],
            )
            batch = get_batch(dataloader, batch_size, i) |> device
            num_tokens = size(batch[1][1])[1] * size(batch[1][1])[2]
            tloss, grads = Flux.withgradient(t.model) do m
                sum(masklm_losses(m, batch))
            end
            # The batch is no longer needed once the gradient exists; free it early.
            batch |> device_free!
            losses += tloss
            steps += 1
            Flux.update!(t.opt, t.model, grads[1])
            Flux.update!(t.weightdecay, t.model, grads[1])
            remaining_tokens -= num_tokens
            ProgressMeter.next!(p; step = num_tokens)
            # Same stopping rule as the outer loop: a budget that lands exactly on zero
            # must end the epoch instead of training the rest of the shard.
            if remaining_tokens <= 0
                break
            end
        end
    end
    ProgressMeter.finish!(p)
    losses / steps
end;

In [None]:
"""
    evaluate_metrics(t::Trainer)

Evaluate `t.model` on validation shards until "num_validation_sentences" sentences have
been seen, and return a `Dict` from metric name to weighted-average loss.

Each batch's four `masklm_losses` are multiplied by that batch's summed weights and
accumulated; the totals are divided by the accumulated weights at the end. A task that
never receives any weight therefore ends up as `0 / 0`.
"""
function evaluate_metrics(t::Trainer)
    losses = zeros(Float32, 4)
    weights = zeros(Float32, 4)
    remaining_sentences = t.training_config["num_validation_sentences"]
    batch_size = t.training_config["batch_size"]
    p = ProgressMeter.Progress(remaining_sentences)
    while remaining_sentences > 0
        dataloader = get_dataloader("validation", t.training_config)
        for i = 1:div(length(dataloader), batch_size, RoundUp)
            batch = get_batch(dataloader, batch_size, i) |> device
            # Total weight per task, in the same order as the losses.
            w =
                sum.([
                    batch[6][:anime][:item],
                    batch[6][:anime][:rating],
                    batch[6][:manga][:item],
                    batch[6][:manga][:rating],
                ])
            losses .+= masklm_losses(t.model, batch) .* w
            weights .+= w
            num_sentences = size(batch[1][1])[2]
            device_free!(batch)
            remaining_sentences -= num_sentences
            ProgressMeter.next!(p; step = num_sentences)
            # Same stopping rule as the outer loop: reaching exactly zero ends the
            # evaluation instead of scoring the rest of the shard.
            if remaining_sentences <= 0
                break
            end
        end
    end
    ProgressMeter.finish!(p)
    names = [
        "Anime Crossentropy Loss",
        "Anime Rating Loss",
        "Manga Crossentropy Loss",
        "Manga Rating Loss",
    ]
    Dict(names .=> losses ./ weights)
end;

In [None]:
"""
    checkpoint(t::Trainer, training_loss, epoch, name)

Evaluate the validation metrics and write a checkpoint for `epoch`.

"Validation Loss" (the sum of the four validation metrics) and "Training Loss" are
added to the metrics. The model and both optimiser states are moved to the CPU, and
everything is handed to `write_params` under `name`/checkpoints/`epoch`.
"""
function checkpoint(t::Trainer, training_loss, epoch, name)
    @info "evaluating metrics"
    metrics = evaluate_metrics(t)
    # The total must be taken before "Training Loss" is added to the dictionary.
    metrics["Validation Loss"] = sum(values(metrics))
    metrics["Training Loss"] = training_loss

    state = Dict(
        "m" => cpu(t.model),
        "opt" => cpu(t.opt),
        "weightdecay" => cpu(t.weightdecay),
        "lr_schedule" => t.lr_schedule,
        "epoch" => epoch,
        "metrics" => metrics,
        "training_config" => t.training_config,
        "model_config" => t.model_config,
        "rng" => t.rng,
    )
    write_params(state, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro` generator from `seed`, use it to seed the global generator
and, when CUDA is functional, both CUDA default generators, then return it.

Each target gets its own `UInt64` drawn from the new generator, in that order.
"""
function set_rngs(seed)
    rng = Random.Xoshiro(seed)
    host_seed = rand(rng, UInt64)
    Random.seed!(host_seed)
    CUDA.functional() || return rng
    Random.seed!(CUDA.default_rng(), rand(rng, UInt64))
    Random.seed!(CUDA.CURAND.default_rng(), rand(rng, UInt64))
    return rng
end;

In [None]:
"""
    create_training_config()

Return the string-keyed dictionary of tokenization, training, data and model settings.

"base_vocab_sizes" has one entry per token field, in the order anime, manga, rating,
timestamp, status, completion, user, position. "vocab_sizes" is the same tuple with 4
added to every entry. "max_sequence_length" is the position entry of the base sizes.
"""
function create_training_config()
    media = ["anime", "manga"]
    anime_items = Int32(num_items("anime"))
    manga_items = Int32(num_items("manga"))
    users = Int32(maximum(num_users, media))
    base_vocab_sizes = (
        anime_items,
        manga_items,
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        users,
        Int32(512),
    )
    vocab_sizes = base_vocab_sizes .+ Int32(4)
    max_sequence_length = extract(base_vocab_sizes, :position)
    return Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "vocab_sizes" => vocab_sizes,
        # training
        "batch_size" => 16,
        "peak_learning_rate" => 3f-4,
        "weight_decay" => 1f-2,
        # data
        "media" => media,
        "num_epochs" => 1,
        # model
        "num_layers" => 4,
        "hidden_size" => 512,
        "max_sequence_length" => max_sequence_length,
    )
end;

In [None]:
"""
    create_model_config(training_config)

Derive the model hyperparameters from `training_config`.

The sizing follows the recipe in Section 5 of [Well-Read Students Learn Better: On the
Importance of Pre-training Compact Models](https://arxiv.org/pdf/1908.08962.pdf): one
attention head per 64 hidden units and an intermediate size of four times the hidden
size. The head count is converted with `Int`, which throws when the hidden size is not
a multiple of 64.
"""
function create_model_config(training_config)
    hidden = training_config["hidden_size"]
    return Dict(
        # copied from the training configuration
        "num_hidden_layers" => training_config["num_layers"],
        "hidden_size" => hidden,
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
        # derived from the hidden size
        "num_attention_heads" => Int(hidden / 64),
        "intermediate_size" => hidden * 4,
        # fixed choices
        "hidden_act" => gelu,
        "dropout" => 0.1,
        "attention_dropout" => 0.1,
    )
end;

In [None]:
"""
    load_from_checkpoint(training_config, rng)

Build a `Trainer` for `training_config`.

Despite its name, this function reads no saved state. It creates a new model on the
device, an Adam optimiser state at the peak learning rate, a `PureWeightDecay` state
scaled by `lr * wd`, and the learning-rate schedule, and stores `rng` in the result.

Every setting is read from the `training_config` argument, not from a notebook global.
"""
function load_from_checkpoint(training_config, rng)
    model_config = create_model_config(training_config)
    model = create_bert(model_config) |> gpu
    lr = Float32(training_config["peak_learning_rate"])
    wd = Float32(training_config["weight_decay"])
    opt = Optimisers.setup(Adam(lr, (0.9f0, 0.999f0)), model)
    weightdecay = Optimisers.setup(PureWeightDecay(lr * wd), model)
    initialize_weight_decay!(weightdecay, model)
    lr_schedule = get_lr_schedule(training_config)
    Trainer(model, opt, weightdecay, lr_schedule, training_config, model_config, rng)
end;

# Actually Train Model!

In [None]:
# Run-level globals.
# NOTE(review): `config_checkpoint` and `config_epoch` are not read anywhere in the
# visible notebook; presumably placeholders for resuming a run. TODO confirm.
config_checkpoint = nothing
config_epoch = nothing
# Seeds the global (and, if available, CUDA) generators and keeps the master generator.
config_rng = set_rngs(20221221)
config = create_training_config();

In [None]:
# Copy the dataset-dependent sizes from the data directory's config.json into the
# training configuration.
data_config = read_json(joinpath(get_data_path(datapath), "training", "config.json"))
config["tokens_per_epoch"] = data_config["tokens_per_epoch"]
config["num_validation_sentences"] = data_config["num_validation_sentences"]
# Tokens per epoch divided by the maximum sequence length, rounded up.
config["iters_per_epoch"] =
    Int(ceil(config["tokens_per_epoch"] / config["max_sequence_length"]));

In [None]:
# Build the trainer and log parameter counts for the whole model and for each part.
trainer = load_from_checkpoint(config, config_rng)
@info "Num epochs: $(config["num_epochs"])"
@info "Training model with $(sum(length, Flux.params(trainer.model))) total parameters"
@info "Embedding parameters: $(sum(length, Flux.params(trainer.model.embed)))"
@info "Transformer parameters: $(sum(length, Flux.params(trainer.model.transformers)))"
@info "Classifier parameters: $(sum(length, Flux.params(trainer.model.classifier)))"

In [None]:
# Main loop: train for one epoch, then evaluate and checkpoint under `name`.
training_loss = Inf
for i = 1:trainer.training_config["num_epochs"]
    # Run a garbage collection pass before each epoch.
    GC.gc()
    training_loss = train_epoch!(trainer)
    checkpoint(trainer, training_loss, i, name)
end;