In [None]:
"""
    get_wordtype_index(field::Symbol)

Return the 1-based slot that `field` occupies in a word tuple, or `nothing`
when `field` is not one of the known field names.
"""
function get_wordtype_index(field::Symbol)
    # Slot order of a word; kept in the same order as the `wordtype` tuple.
    names = (
        :itemid,
        :rating,
        :updated_at,
        :status,
        :source,
        :created_at,
        :started_at,
        :finished_at,
        :progress,
        :repeat_count,
        :priority,
        :sentiment,
        :sentiment_score,
        :position,
        :userid,
    )
    for (slot, name) in enumerate(names)
        if name == field
            return slot
        end
    end
    return nothing
end

# Concrete tuple type of a single word. Each row below lists the slots it
# declares, in the same order as the names in `get_wordtype_index`.
const wordtype = Tuple{
    Int32, Float32, Float32,    # itemid, rating, updated_at
    Int32, Int32,               # status, source
    Float32, Float32, Float32,  # created_at, started_at, finished_at
    Float32,                    # progress
    Float32,                    # 1 - 1 / (repeat_count + 1)
    Float32,                    # 1 - 1 / (priority + 1)
    Int32, Float32,             # sentiment, sentiment_score
    Int32, Int32,               # position, userid
}

"""
    extract(word, field::Symbol)

Return the element of `word` found at the slot that `get_wordtype_index`
reports for `field`.
"""
function extract(word, field::Symbol)
    slot = get_wordtype_index(field)
    return word[slot]
end

"""
    replace(word, fieldname::Symbol, field)

Return a new tuple equal to `word` except that the slot named `fieldname`
holds `field`. `word` itself is not modified.
"""
# NOTE(review): this name coincides with `Base.replace`; since it is defined
# without a `Base.` qualifier, confirm that shadowing is what is intended.
function replace(word, fieldname::Symbol, field)
    slot = get_wordtype_index(fieldname)
    head = word[1:slot-1]
    tail = word[slot+1:end]
    return (head..., field, tail...)
end;

In [None]:
"""
    get_training_data(cls_tokens, partition, users_to_finetune)

Load the "training" split of every medium in `ALL_MEDIUMS`, optionally restrict
it to a subset of users, and build one sentence per user.

# Arguments
- `cls_tokens`: word passed through to the per-dataset `get_training_data` method.
- `partition`: `nothing`, or an indexable pair; only rows with
  `userid % partition[2] == partition[1]` are kept.
- `users_to_finetune`: `nothing`, or a collection of user ids; only those users
  are kept and the first output of `training_test_split(df, 1)` is used.

# Returns
The merge of the per-chunk `Dict{Int32,Vector{wordtype}}` results.
"""
function get_training_data(cls_tokens, partition, users_to_finetune)
    # Load one medium, apply the user filters, and shift its item ids so they
    # come after the ids of every medium listed earlier in ALL_MEDIUMS.
    function load_medium(medium)
        columns = [
            :itemid,
            :rating,
            :updated_at,
            :status,
            :source,
            :created_at,
            :started_at,
            :finished_at,
            :progress,
            :repeat_count,
            :priority,
            :sentiment,
            :sentiment_score,
            :userid,
            :medium,
            :update_order,
        ]
        data = get_raw_split("training", medium, columns, nothing)
        if partition !== nothing
            data = filter(data, data.userid .% partition[2] .== partition[1])
        end
        if users_to_finetune !== nothing
            data = filter(data, in.(data.userid, Ref(Set(users_to_finetune))))
            data = training_test_split(data, 1)[1]
        end
        offset::Int32 = 0
        medium_idx = findfirst(==(medium), ALL_MEDIUMS)
        for earlier in 1:(medium_idx-1)
            offset += num_items(ALL_MEDIUMS[earlier])
        end
        data.itemid .+= offset
        return data
    end

    @info "loading training splits"
    dfs = [load_medium(medium) for medium in ALL_MEDIUMS]
    GC.gc()

    @info "processing training splits"
    chunks = 128
    sentences = Vector{Dict{Int32,Vector{wordtype}}}(undef, chunks)
    # Users are bucketed by `userid % chunks`; each worker call fills only its
    # own slot of `sentences`.
    tforeach(1:chunks, 4) do t
        parts = [filter(d, @. (d.userid % chunks) + 1 == t) for d in dfs]
        sentences[t] = get_training_data(reduce(cat, parts), cls_tokens)
    end
    return merge(sentences...)
end;

In [None]:
"""
    get_training_data(df::RatingsDataset, cls_tokens)

Group the rows of `df` into per-user sentences.

Rows are visited in ascending `updated_at` order, with ties broken by descending
`update_order`. Each user's sentence starts with `cls_tokens` carrying that
user's id in its `:userid` slot, followed by one `wordtype` per row. The
`position` slot of a row's word is the number of row words already in the
sentence, so the first one gets 0.

Returns a `Dict{Int32,Vector{wordtype}}` keyed by user id.
"""
function get_training_data(df::RatingsDataset, cls_tokens)
    sentences = Dict{Int32,Vector{wordtype}}()
    # rows from several media are mixed together, so impose a global time order
    sort_keys = collect(zip(df.updated_at, -df.update_order))
    for i in sortperm(sort_keys)
        uid = df.userid[i]
        sentence = get!(sentences, uid) do
            [replace(cls_tokens, :userid, uid)]
        end
        repeat_feature = Float32(1 - 1 / (df.repeat_count[i] + 1))
        priority_feature = Float32(1 - 1 / (df.priority[i] + 1))
        fields = (
            df.itemid[i],
            df.rating[i],
            df.updated_at[i],
            df.status[i],
            df.source[i],
            df.created_at[i],
            df.started_at[i],
            df.finished_at[i],
            df.progress[i],
            repeat_feature,
            priority_feature,
            df.sentiment[i],
            df.sentiment_score[i],
            length(sentence) - 1,
            uid,
        )
        push!(sentence, convert(wordtype, fields))
    end
    return sentences
end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent)

Return at most `max_seq_length` contiguous entries of `sentence`.

A sentence that already fits is returned as-is. Otherwise a window of exactly
`max_seq_length` entries is sliced out: the trailing window when `recent` is
`true`, or a window with a uniformly random start when it is `false`.
"""
function subset_sentence(sentence, max_seq_length; recent)
    total = length(sentence)
    total <= max_seq_length && return sentence
    # largest start index that still leaves room for a full window
    last_start = total - max_seq_length + 1
    start = recent ? last_start : rand(1:last_start)
    return sentence[start:start+max_seq_length-1]
end;

In [None]:
"""
    get_token_ids(sentence, max_seq_length, pad_tokens, cls_tokens)

Lay the words of `sentence` out as one padded array per word field.

# Arguments
- `sentence`: indexable collection of words; must hold at most `max_seq_length`
  entries.
- `max_seq_length`: length of every output array.
- `pad_tokens`: one padding value per field; `fill.(pad_tokens, max_seq_length)`
  provides the initial outputs.
- `cls_tokens`: reference word; words whose `:itemid` equals its `:itemid` keep
  their `:position` value untouched.

# Returns
The padded arrays, where entry `i` of array `j` is field `j` of `sentence[i]`.
For words whose `:itemid` differs from that of `cls_tokens`, the `:position`
field is stored modulo `max_seq_length`. Entries past `length(sentence)` keep
their padding value.
"""
function get_token_ids(sentence, max_seq_length, pad_tokens, cls_tokens)
    outputs = fill.(pad_tokens, max_seq_length)
    @assert length(sentence) <= max_seq_length
    # These lookups do not depend on the word or field being processed, so they
    # are resolved once here instead of once per (word, field) pair.
    position_slot = get_wordtype_index(:position)
    itemid_slot = get_wordtype_index(:itemid)
    cls_itemid = extract(cls_tokens, :itemid)
    for i = 1:length(sentence)
        word = sentence[i]
        for j = 1:length(outputs)
            value = word[j]
            # positions can exceed max_seq_length when the sentence was cut
            # from a longer one, so wrap them for every non-CLS word
            if j == position_slot && word[itemid_slot] != cls_itemid
                value = value % max_seq_length
            end
            outputs[j][i] = value
        end
    end
    return outputs
end;