In [None]:
# A "word" is one fixed-layout tuple; its slots are, in order:
#   itemid, rating, updated_at, status, position, userid
# The slot order here must stay in sync with `get_wordtype_index`, which maps
# each of these field names to its 1-based position in the tuple.
const wordtype = Tuple{Int32,Float32,Float32,Int32,Int32,Int32}

"""
    get_wordtype_index(field::Symbol)

Return the 1-based slot that `field` occupies inside a `wordtype` tuple.

The recognised fields, in slot order, are `:itemid`, `:rating`, `:updated_at`,
`:status`, `:position` and `:userid`.

# Throws
- `ArgumentError` if `field` is not one of the names above.
"""
function get_wordtype_index(field::Symbol)
    field === :itemid && return 1
    field === :rating && return 2
    field === :updated_at && return 3
    field === :status && return 4
    field === :position && return 5
    field === :userid && return 6
    # An explicit throw (rather than `@assert false`) cannot be compiled out and
    # tells the caller which name was rejected.
    throw(ArgumentError("unknown wordtype field: $(repr(field))"))
end

"""
    extract(word, field::Symbol)

Return the entry of `word` stored in the slot named `field`, where the slot
number is looked up with `get_wordtype_index`.
"""
function extract(word, field::Symbol)
    slot = get_wordtype_index(field)
    return word[slot]
end

"""
    replace(word, fieldname::Symbol, field)

Return a new tuple equal to `word` except that the slot named `fieldname`
holds `field`. `word` itself is not modified, and `field` is stored as given
(no conversion to the slot's previous type is performed here).
"""
function replace(word, fieldname::Symbol, field)
    slot = get_wordtype_index(fieldname)
    before = word[1:slot-1]
    after = word[slot+1:end]
    return (before..., field, after...)
end

"""
    encode_word(itemid, rating, updated_at, status, position, userid)

Pack the six fields, in `wordtype` slot order, into a tuple and convert it to
`wordtype`. Any conversion failure raised by `convert` propagates to the caller.
"""
function encode_word(itemid, rating, updated_at, status, position, userid)
    return convert(wordtype, (itemid, rating, updated_at, status, position, userid))
end;

In [None]:
"""
    get_training_data(cls_tokens, userids = nothing)

Load the "training" split of every medium in `ALL_MEDIUMS` and build one
sentence (a `Vector{wordtype}`) per user, returned as a
`Dict{Int32,Vector{wordtype}}` keyed by user id.

# Arguments
- `cls_tokens`: forwarded to the per-dataset `get_training_data` method, which
  uses it as the template for the first word of every sentence.
- `userids`: when not `nothing`, only rows whose `userid` is in this collection
  are kept.
"""
function get_training_data(cls_tokens, userids = nothing)
    function get_df(medium, userids)
        raw = get_raw_split(
            "training",
            medium,
            [:medium, :userid, :itemid, :status, :rating, :update_order, :updated_at],
            nothing,
        )
        # Bind `df` exactly once. The threaded loop below captures `df` in a
        # closure; a variable that is assigned more than once gets boxed when
        # captured, which makes every `df.*` access in the loop dynamically
        # dispatched.
        df = isnothing(userids) ? raw : filter(raw, raw.userid .∈ (Set(userids),))
        # The offset loop below relies on this exact ordering of mediums.
        @assert ALL_MEDIUMS == ["manga", "anime"]
        # Shift each item id by the item counts of the first `df.medium[i]`
        # mediums. NOTE(review): presumably this gives every medium its own
        # disjoint id range, with `df.medium` being a 0-based medium index —
        # confirm against `get_raw_split`.
        @tprogress Threads.@threads for i = 1:length(df.userid)
            for j = 1:df.medium[i]
                df.itemid[i] += num_items(ALL_MEDIUMS[j])
            end
        end
        df
    end
    @info "loading training splits"
    dfs = [get_df(medium, userids) for medium in ALL_MEDIUMS]

    @info "processing training splits"
    # Rows are partitioned by `userid % chunks`, so each user lands in exactly
    # one chunk and the per-chunk dictionaries have disjoint keys.
    chunks = 128
    sentences = Vector{Dict{Int32,Vector{wordtype}}}(undef, chunks)
    function save_sentence(t)
        df = reduce(cat, [filter(x, @. (x.userid % chunks) + 1 == t) for x in dfs])
        sentences[t] = get_training_data(df, cls_tokens)
    end
    # NOTE(review): the trailing `8` is passed straight to `tforeach`; its
    # meaning (likely a thread/task count) is defined elsewhere — confirm.
    tforeach(save_sentence, 1:chunks, 8)
    merge(sentences...)
end;

In [None]:
"""
    get_training_data(df::RatingsDataset, cls_tokens)

Group the rows of `df` into per-user sentences and return them as a
`Dict{Int32,Vector{wordtype}}` keyed by user id.

Rows are visited in ascending `(updated_at, update_order)` order. Each
sentence starts with a copy of `cls_tokens` whose `:userid` slot is set to the
user's id; every later word stores, in its position slot, the number of
non-CLS words already in that sentence (so the first real word has position 0).
"""
function get_training_data(df::RatingsDataset, cls_tokens)
    sentences = Dict{Int32,Vector{wordtype}}()
    sort_keys = collect(zip(df.updated_at, df.update_order))
    for i in sortperm(sort_keys)
        uid = df.userid[i]
        # Fetch the user's sentence, creating it with its CLS word on first sight.
        sentence = get!(sentences, uid) do
            [replace(cls_tokens, :userid, uid)]
        end
        # The CLS word occupies index 1, hence the `- 1`.
        position = length(sentence) - 1
        word = encode_word(
            df.itemid[i],
            df.rating[i],
            df.updated_at[i],
            df.status[i],
            position,
            uid,
        )
        push!(sentence, word)
    end
    sentences
end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent)

Return at most `max_seq_length` contiguous entries of `sentence`.

If `sentence` already fits, it is returned as-is (the same object, not a
copy). Otherwise a slice of exactly `max_seq_length` entries is returned: the
trailing entries when `recent` is `true`, or a window starting at a uniformly
random valid offset when `recent` is `false`.
"""
function subset_sentence(sentence, max_seq_length; recent)
    n = length(sentence)
    n <= max_seq_length && return sentence
    # Largest start index that still leaves room for a full window.
    last_start = n - max_seq_length + 1
    start = recent ? last_start : rand(1:last_start)
    return sentence[start:start+max_seq_length-1]
end;

In [None]:
"""
    get_token_ids(sentence, max_seq_length, pad_tokens, cls_tokens)

Transpose `sentence` (a sequence of words) into one padded column per word
slot.

A column of length `max_seq_length` is created for every entry of
`pad_tokens`, pre-filled with that entry. Column `j`, index `i` is then
overwritten with slot `j` of `sentence[i]`. The position slot is stored modulo
`max_seq_length`, except for words whose item id equals the item id of
`cls_tokens`, which are copied unchanged.

`sentence` must not be longer than `max_seq_length` (checked with `@assert`).
"""
function get_token_ids(sentence, max_seq_length, pad_tokens, cls_tokens)
    outputs = fill.(pad_tokens, max_seq_length)
    @assert length(sentence) <= max_seq_length
    position_slot = get_wordtype_index(:position)
    for i in 1:length(sentence), j in 1:length(outputs)
        word = sentence[i]
        value = word[j]
        # Positions recorded on longer, un-truncated sentences are wrapped so
        # they stay below max_seq_length; the CLS word is left as-is.
        if j == position_slot && extract(word, :itemid) != extract(cls_tokens, :itemid)
            value = value % max_seq_length
        end
        outputs[j][i] = value
    end
    outputs
end;