In [None]:
# Layout of a single "word", i.e. one entry of a user's sentence.
# Slots, in order: item, rating, timestamp, status, completion, user, position.
const wordtype = Tuple{
    Int32,    # 1: item
    Float32,  # 2: rating
    Float32,  # 3: timestamp
    Int32,    # 4: status
    Float32,  # 5: completion
    Int32,    # 6: user
    Int32     # 7: position
}

"""
    get_wordtype_index(field::Symbol)

Return the 1-based slot that `field` occupies inside a `wordtype` tuple.

Recognised fields, in slot order: `:item`, `:rating`, `:timestamp`, `:status`,
`:completion`, `:user`, `:position`.

# Throws
- `AssertionError` if `field` is not one of the names listed above.
"""
function get_wordtype_index(field::Symbol)
    field === :item && return 1
    field === :rating && return 2
    field === :timestamp && return 3
    field === :status && return 4
    field === :completion && return 5
    field === :user && return 6
    field === :position && return 7
    # Thrown explicitly instead of via `@assert false`: the `@assert` macro is documented
    # as something that may be disabled, which would let an unknown field fall through and
    # return `nothing`. The exception type is kept, and the message names the bad field.
    throw(AssertionError("unknown wordtype field: $(repr(field))"))
end

"""
    extract(word, field::Symbol)

Return the value stored in `word` at the slot named by `field`.

The slot is resolved with `get_wordtype_index`, so `field` must be one of the names that
function recognises.
"""
function extract(word, field::Symbol)
    slot = get_wordtype_index(field)
    return word[slot]
end

"""
    replace(word, fieldname::Symbol, field)

Return a new tuple that equals `word` except that the slot named by `fieldname` holds
`field`. `word` itself is not modified, and `field` is stored as given (no conversion).
"""
# NOTE(review): this shares its name with `Base.replace`; whether it shadows or extends
# that function depends on imports not visible in this cell -- confirm.
function replace(word, fieldname::Symbol, field)
    slot = get_wordtype_index(fieldname)
    # Copy every slot, swapping in the new value at `slot`.
    return ntuple(i -> i == slot ? field : word[i], length(word))
end

"""
    encode_word(item, rating, timestamp, status, completion, user, position)

Pack the seven fields into a `wordtype` tuple, in that slot order.

Each value is converted to the element type of its slot via `convert(wordtype, ...)`.
"""
function encode_word(item, rating, timestamp, status, completion, user, position)
    fields = (item, rating, timestamp, status, completion, user, position)
    return convert(wordtype, fields)
end

"""
    is_ptw(word::wordtype)

Return `true` when the `:status` slot of `word` equals `1`, and `false` otherwise.
"""
function is_ptw(word::wordtype)
    status = extract(word, :status)
    return status == 1
end;

In [None]:
"""
    get_training_data(task, medium, include_ptw, cls_tokens; show_progress_bar = true)

Build one sentence per user from the "training" splits of `task` / `medium`.

The "explicit" and "implicit" splits are always loaded; the "ptw" split is added when
`include_ptw` is true. Every split other than "explicit" has its `rating` column
overwritten with `11`. The combined rows are visited in ascending `timestamp` order.

Each sentence starts with `cls_tokens` whose `:user` slot is replaced by the user id. Every
following word gets as `:position` the sentence length at the moment it is appended, so
the first real word has position 1.

# Keywords
- `show_progress_bar`: passed to the progress meter as `enabled`.

# Returns
A `Dict{Int32,Vector{wordtype}}` mapping user id to that user's sentence.
"""
function get_training_data(task, medium, include_ptw, cls_tokens; show_progress_bar = true)
    # Load one raw split; only the "explicit" split keeps its own ratings.
    function get_df(content)
        df = get_raw_split("training", task, content, medium)
        if content != "explicit"
            df.rating .= 11
        end
        return df
    end

    contents = ["explicit", "implicit"]
    if include_ptw
        push!(contents, "ptw")
    end
    # NOTE(review): `cat` is called pairwise without `dims`, so this presumably relies on a
    # project-level method that combines two splits -- confirm.
    df = reduce(cat, [get_df(content) for content in contents])
    order = sortperm(df.timestamp)

    sentences = Dict{Int32,Vector{wordtype}}()
    p = ProgressMeter.Progress(length(order); enabled = show_progress_bar, showspeed = true)
    for i in order
        user = df.user[i]
        # A single dictionary lookup per row: fetch the user's sentence, creating it with
        # the per-user CLS word on first sight.
        sentence = get!(sentences, user) do
            wordtype[replace(cls_tokens, :user, user)]
        end
        word = encode_word(
            df.item[i],
            df.rating[i],
            df.timestamp[i],
            df.status[i],
            df.completion[i],
            user,
            length(sentence),
        )
        push!(sentence, word)
        ProgressMeter.next!(p)
    end
    ProgressMeter.finish!(p)
    return sentences
end;

In [None]:
# function get_training_data(task, include_ptw, cls_tokens)
#     @info "loading training splits for $task"
#     function get_df(task, content)
#         df = get_raw_split("training", task, content)
#         if content != "explicit"
#             df.rating .= 11
#         end
#         df
#     end
#     contents = ["explicit", "implicit"]
#     if include_ptw
#         push!(contents, "ptw")
#     end
#     df = reduce(cat, [get_df(task, content) for content in contents])

#     # shard the users across threads
#     sharded_users = [[Int32[] for _ = 1:Threads.nthreads()] for _ = 1:Threads.nthreads()]
#     sharded_timestamps =
#         [[Float32[] for _ = 1:Threads.nthreads()] for _ = 1:Threads.nthreads()]
#     @tprogress Threads.@threads for i = 1:length(df.user)
#         key = (df.user[i] % Threads.nthreads()) + 1
#         push!(sharded_users[Threads.threadid()][key], i)
#         push!(sharded_timestamps[Threads.threadid()][key], df.timestamp[i])
#     end
#     users = Vector{Vector{Int32}}(undef, Threads.nthreads())
#     timestamps = Vector{Vector{Float32}}(undef, Threads.nthreads())
#     Threads.@threads for t = 1:Threads.nthreads()
#         users[t] = reduce(vcat, [sharded_users[s][t] for s = 1:Threads.nthreads()])
#         timestamps[t] =
#             reduce(vcat, [sharded_timestamps[s][t] for s = 1:Threads.nthreads()])
#     end
#     users = [x for x in users if length(x) > 0]
#     timestamps = [x for x in timestamps if length(x) > 0]

#     @info "constructing watch histories for $task"
#     sentences = Dict{Int32,Vector{wordtype}}(i => [] for i = 1:num_users())
#     p = ProgressMeter.Progress(length(users[1]); showspeed = true)
#     Threads.@threads for t = 1:length(users)
#         order = users[t][sortperm(timestamps[t])]
#         for i in order
#             if length(sentences[df.user[i]]) == 0
#                 sentences[df.user[i]] = [replace(cls_tokens, :user, df.user[i])]
#             end
#             word = encode_word(
#                 df.item[i],
#                 df.rating[i],
#                 df.timestamp[i],
#                 df.status[i],
#                 df.completion[i],
#                 df.user[i],
#                 length(sentences[df.user[i]]),
#             )
#             push!(sentences[df.user[i]], word)
#             if t == 1
#                 ProgressMeter.next!(p)
#             end
#         end
#     end
#     ProgressMeter.finish!(p)
#     Dict(k => v for (k, v) in sentences if length(v) > 0)
# end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent, rng)

Limit `sentence` to at most `max_seq_length` entries.

A sentence that already fits is returned as is (the same object, not a copy). Otherwise a
contiguous window of exactly `max_seq_length` entries is returned as a new collection.

# Keywords
- `recent`: when `true`, keep the last `max_seq_length` entries; when `false`, pick the
  window start uniformly at random.
- `rng`: random number generator, used only when `recent` is `false` and the sentence is
  too long.
"""
function subset_sentence(sentence, max_seq_length; recent, rng)
    total = length(sentence)
    total > max_seq_length || return sentence

    # Largest start index that still leaves room for a full window.
    last_start = total - max_seq_length + 1
    start = recent ? last_start : rand(rng, 1:last_start)
    return sentence[start:start+max_seq_length-1]
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens; rng)

Turn `sentence` into per-field columns of length `max_seq_length`.

One vector is created per entry of `pad_tokens`, pre-filled with that entry. The sentence
is cut to at most `max_seq_length` words by `subset_sentence` (random contiguous window
drawn from `rng`). Word `i` is then written into slot `i` of every column; slots beyond
the end of the sentence keep their pad value.

The `:position` field is replaced by its remainder modulo `max_position`, with a remainder
of zero mapped to `max_position`. Words whose `:item` equals the `:item` of `cls_tokens`
are exempt and keep their position unchanged.

# Returns
The collection of columns, one per entry of `pad_tokens`.
"""
function pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens; rng)
    columns = fill.(pad_tokens, max_seq_length)
    window = subset_sentence(sentence, max_seq_length; recent = false, rng = rng)
    position_slot = get_wordtype_index(:position)
    cls_item = extract(cls_tokens, :item)
    for (row, word) in enumerate(window)
        wrap_position = extract(word, :item) != cls_item
        for col = 1:length(columns)
            columns[col][row] = if col == position_slot && wrap_position
                wrapped = word[col] % max_position
                wrapped == 0 ? max_position : wrapped
            else
                word[col]
            end
        end
    end
    return columns
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens; rng)

Pad every sentence with `pad_sentence` and stack the results field by field.

# Returns
A tuple with one entry per element of `pad_tokens`. Entry `i` is the horizontal
concatenation of field `i` of every padded sentence, so each column of that matrix
belongs to one sentence. For an empty `sentences` collection every entry is the result of
the zero-argument `hcat()`.
"""
function get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens; rng)
    padded_sentences = [
        pad_sentence(x, max_seq_length, max_position, pad_tokens, cls_tokens; rng = rng)
        for x in sentences
    ]
    # Stack field `i` of every padded sentence into one matrix.
    function stack_field(i)
        columns = [x[i] for x in padded_sentences]
        # `reduce(hcat, ...)` avoids splatting one argument per sentence into `hcat`.
        # It cannot handle an empty batch, so that case keeps the old `hcat()` result.
        return isempty(columns) ? hcat() : reduce(hcat, columns)
    end
    return Tuple(stack_field(i) for i = 1:length(pad_tokens))
end;