In [None]:
# Fixed-width tuple type used for every word (token) stored in a sentence.
# Slots, in order: animeid, mangaid, rating, timestamp, status, completion, user, position
# This is the same 1-based order that `get_wordtype_index` maps field names onto.
const wordtype = Tuple{Int32,Int32,Float32,Float32,Int32,Float32,Int32,Int32}

"""
    get_wordtype_index(field::Symbol)

Return the 1-based slot that `field` occupies inside a word tuple.

Recognised fields, in slot order: `:anime`, `:manga`, `:rating`, `:timestamp`,
`:status`, `:completion`, `:user`, `:position`.

# Throws
- `ArgumentError` if `field` is not one of the names listed above.
"""
function get_wordtype_index(field::Symbol)
    if field === :anime
        return 1
    elseif field === :manga
        return 2
    elseif field === :rating
        return 3
    elseif field === :timestamp
        return 4
    elseif field === :status
        return 5
    elseif field === :completion
        return 6
    elseif field === :user
        return 7
    elseif field === :position
        return 8
    end
    # Thrown explicitly instead of via `@assert false`: assertions may be disabled at
    # some optimization levels, and the message tells the caller which name was rejected.
    throw(ArgumentError("unknown wordtype field: $(repr(field))"))
end

"""
    extract(word, field::Symbol)

Return the value stored in `word` at the slot named by `field`.

The slot number is obtained from `get_wordtype_index(field)`.
"""
function extract(word, field::Symbol)
    slot = get_wordtype_index(field)
    return word[slot]
end

"""
    replace(word, fieldname::Symbol, field)

Return a new tuple equal to `word` except that the slot named by `fieldname` holds
`field`.

`word` is not modified. `field` is spliced in as-is; no conversion to the type of the
slot it replaces is performed here.

NOTE(review): this name coincides with `Base.replace`; confirm the surrounding code does
not also rely on the Base function under the same unqualified name.
"""
function replace(word, fieldname::Symbol, field)
    slot = get_wordtype_index(fieldname)
    leading = word[1:slot-1]
    trailing = word[slot+1:end]
    return (leading..., field, trailing...)
end

"""
    encode_word(animeid, mangaid, rating, timestamp, status, completion, user, position)

Pack the eight fields, in the order given, into a tuple and convert it to `wordtype`.
"""
function encode_word(
    animeid,
    mangaid,
    rating,
    timestamp,
    status,
    completion,
    user,
    position,
)
    return convert(
        wordtype,
        (animeid, mangaid, rating, timestamp, status, completion, user, position),
    )
end

In [None]:
"""
    get_training_data(df::RatingsDataset, media, cls_tokens, empty_tokens)

Group the rows of `df` into one sentence (vector of words) per user.

Rows are visited in ascending `df.timestamp` order. The first time a user is seen, their
sentence is started with `cls_tokens` whose `:user` slot is replaced by the user id.
Every row then appends one word built by `encode_word`; its position is the number of
words already in that user's sentence, so the first appended word gets position 1.

# Arguments
- `df`: dataset whose `source`, `item`, `rating`, `timestamp`, `status`, `completion`
  and `user` columns are read by row index.
- `media`: indexed by `df.source[i]` to obtain the medium name of a row.
- `cls_tokens`: template word used as the first word of every sentence.
- `empty_tokens`: supplies the `:anime` slot for manga rows and the `:manga` slot for
  anime rows.

# Returns
A `Dict{Int32,Vector{wordtype}}` mapping each user id to its sentence.

# Throws
- `ArgumentError` if a row's medium is neither `"manga"` nor `"anime"`.
"""
function get_training_data(df::RatingsDataset, media, cls_tokens, empty_tokens)
    # (animeid, mangaid) for row `i`; the slot of the other medium comes from
    # `empty_tokens`.
    function itemids(i)
        medium = media[df.source[i]]
        if medium == "manga"
            return (extract(empty_tokens, :anime), df.item[i])
        elseif medium == "anime"
            return (df.item[i], extract(empty_tokens, :manga))
        end
        # Explicit throw rather than `@assert false`, which may be disabled at some
        # optimization levels and carries no diagnostic.
        throw(ArgumentError("unsupported medium $(repr(medium)) in row $i"))
    end

    sentences = Dict{Int32,Vector{wordtype}}()
    for i in sortperm(df.timestamp)
        user = df.user[i]
        # One dictionary lookup per row: fetch the user's sentence, creating it with the
        # user-specific cls word on first sight.
        sentence = get!(sentences, user) do
            [replace(cls_tokens, :user, user)]
        end
        animeid, mangaid = itemids(i)
        word = encode_word(
            animeid,
            mangaid,
            df.rating[i],
            df.timestamp[i],
            df.status[i],
            df.completion[i],
            user,
            length(sentence),
        )
        push!(sentence, word)
    end
    return sentences
end;

In [None]:
# Build per-user training sentences for `task` from all of its raw training splits.
#
# - `task`: forwarded to `get_raw_split` and interpolated into the log messages.
# - `media`: collection of medium names; one split is loaded per (content, medium) pair.
# - `include_ptw`: when true, the "ptw" content type is loaded alongside "explicit" and
#   "implicit".
# - `cls_tokens`, `empty_tokens`: passed through unchanged to the four-argument
#   `get_training_data` call made for each chunk.
# - `explicit_baseline`: optional; when not `nothing` it is indexed as
#   `baseline[medium]["u"][user]` and `baseline[medium]["a"][item]`, and the sum of the
#   two is subtracted from every explicit rating.
#
# Returns the `merge` of the per-chunk `Dict{Int32,Vector{wordtype}}` dictionaries.
function get_training_data(
    task,
    media,
    include_ptw,
    cls_tokens,
    empty_tokens;
    explicit_baseline = nothing,
)
    # Loads one raw split and rewrites its rating, source and timestamp columns.
    function get_df(task, content, medium, baseline)
        df = get_raw_split("training", task, content, medium)
        if content == "explicit"
            if !isnothing(baseline)
                # subtract the per-user ("u") and per-item ("a") baseline terms
                Threads.@threads for i = 1:length(df.rating)
                    df.rating[i] -=
                        baseline[medium]["u"][df.user[i]] +
                        baseline[medium]["a"][df.item[i]]
                end
            end
        else
            # every non-explicit row gets the constant rating 11
            # NOTE(review): presumably a sentinel outside the real rating scale - confirm
            df.rating .= 11
        end
        # source := index of this split's medium within `media`
        df.source .= findfirst(x -> x == df.medium, media)
        # each timestamp is mapped through `universal_timestamp` together with the medium
        Threads.@threads for i = 1:length(df.timestamp)
            df.timestamp[i] = universal_timestamp(df.timestamp[i], df.medium)
        end
        # The value of this `@set` expression is what get_df returns.
        # NOTE(review): presumably a copy of df with `medium` blanked so that splits of
        # different media can be concatenated later - confirm against `@set` and `cat`.
        @set df.medium = ""
    end
    contents = ["explicit", "implicit"]
    if include_ptw
        push!(contents, "ptw")
    end
    @info "loading training splits for $task"
    # one split per (content, medium) pair, flattened into a single vector
    dfs = [
            get_df(task, content, medium, explicit_baseline) for content in contents for
            medium in media
    ]

    @info "processing training splits for $task"
    # Rows are bucketed by `user % chunks`, so all rows of one user fall into the same
    # chunk; each loop iteration below writes only to its own `sentences[t]`.
    chunks = 128
    sentences = [Dict{Int32,Vector{wordtype}}() for _ in 1:chunks]
    @tprogress Threads.@threads for t = 1:chunks
        # gather this chunk's rows from every split and concatenate them
        df = reduce(
            cat,
            [
                 filter(x, @. (x.user % chunks) + 1 == t) for x in dfs
            ]
        )        
        partition = get_training_data(
            df,
            media,
            cls_tokens,
            empty_tokens,
        )
        for (k, v) in partition
            sentences[t][k] = v
        end
    end
    # combine the per-chunk dictionaries into one
    merge(sentences...)
end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent, keep_first, rng)

Return at most `max_seq_length` contiguous entries of `sentence`.

# Keywords
- `recent`: when `true`, an over-long sentence keeps its rightmost entries; otherwise a
  contiguous window is chosen with a start index drawn from `rng`.
- `keep_first`: when `true`, the first entry (usually a cls token that embeds user
  metadata) is set aside, the remaining entries are trimmed to `max_seq_length - 1`, and
  the first entry is put back at the front.
- `rng`: random number generator used only for the random-window case.

When no trimming is needed and `keep_first` is `false`, `sentence` itself is returned.
"""
function subset_sentence(sentence, max_seq_length; recent, keep_first, rng)
    # Split off the leading token (if requested) and shrink the budget accordingly.
    head = keep_first ? sentence[1] : nothing
    body = keep_first ? sentence[2:end] : sentence
    budget = keep_first ? max_seq_length - 1 : max_seq_length

    surplus = length(body) - budget
    if surplus > 0
        # Rightmost window starts at surplus + 1; a random window starts anywhere up to it.
        start = recent ? surplus + 1 : rand(rng, 1:surplus+1)
        body = body[start:start+budget-1]
    end

    keep_first && pushfirst!(body, head)
    return body
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens)

Split `sentence` into one column per word field, each padded to `max_seq_length`.

Column `j` starts out filled with `pad_tokens[j]`; entry `i` is then overwritten with
field `j` of the `i`-th word. In the `:position` column, a word whose `:anime` slot
differs from that of `cls_tokens` has its position reduced modulo `max_position`, with a
remainder of 0 stored as `max_position`. Words whose `:anime` slot matches `cls_tokens`
keep their stored position.

Asserts that `sentence` has at most `max_seq_length` words.

Returns the collection of columns, one per entry of `pad_tokens`.
"""
function pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens)
    columns = fill.(pad_tokens, max_seq_length)
    @assert length(sentence) <= max_seq_length "$(length(sentence)) $max_seq_length"
    position_slot = get_wordtype_index(:position)
    for (i, word) in enumerate(sentence)
        for j in 1:length(columns)
            if j == position_slot && extract(word, :anime) != extract(cls_tokens, :anime)
                # wrap the position, using max_position in place of a zero remainder
                wrapped = word[j] % max_position
                columns[j][i] = wrapped == 0 ? max_position : wrapped
            else
                columns[j][i] = word[j]
            end
        end
    end
    return columns
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens)

Pad every sentence with `pad_sentence` and batch the results field by field.

Returns a tuple with one entry per element of `pad_tokens`; entry `i` is the `hcat` of
column `i` of every padded sentence, in the order the sentences were given.
"""
function get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens)
    padded = [
        pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens) for
        sentence in sentences
    ]
    return ntuple(length(pad_tokens)) do slot
        hcat((columns[slot] for columns in padded)...)
    end
end;