In [None]:
import CUDA
import Flux
import Flux: Chain, Dense, Dropout, LayerNorm, cpu, gelu, gpu, logsoftmax
import Flux.NNlib: gather
import Functors: @functor
import Optimisers
import Optimisers: OptimiserChain, Adam, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful
import Random
import StatsBase: mean, sample
import Transformers: Bert, Positionwise, TransformerModel
import Transformers.Basic: CompositeEmbedding, Embed, PositionEmbedding

# Model Utils

In [None]:
"""
    BiasLayer(b)
    BiasLayer(n::Integer; init = zeros)

Layer that adds a stored bias `b` to its input by broadcasting (`x .+ b`).

The second form builds the bias by calling `init(Float32, n)`; with the default
`init = zeros` this is a length-`n` vector of `Float32` zeros.
"""
struct BiasLayer
    b::Any
end

function BiasLayer(n::Integer; init = zeros)
    return BiasLayer(init(Float32, n))
end

# Forward pass: broadcast-add the bias onto the input.
function (layer::BiasLayer)(x)
    return x .+ layer.b
end

# Register the struct with Functors so its field is visible to functor-based traversal.
@functor BiasLayer;

In [None]:
# Call overload for integer-valued arrays: only the size of the first dimension of `x`
# is used, and the call is forwarded to `pe(size(x, 1))`.
# NOTE(review): presumably dimension 1 is the sequence length -- confirm against callers.
# NOTE(review): this adds a method to a type imported from Transformers.Basic.
function (pe::PositionEmbedding)(x::AbstractArray{<:Integer})
    return pe(size(x, 1))
end;

# Data Utils

In [None]:
"""
    encode_raw_timestamp(timestamp)

Map a raw timestamp to a positive integer bucket id.

A `timestamp` equal to `-1` maps to bucket `1`. Any other value is converted with
`timestamp_to_date`, and the bucket is `1 + 4 * (year - 2004) + season`, where
`season = div(month - 1, 4) + 1`.

NOTE(review): `div(month - 1, 4) + 1` only takes the values 1, 2 and 3 while the
year stride is 4, so one id per year is never produced. If quarterly buckets were
intended the divisor would be 3 -- confirm before changing, since that would alter
every encoded token.
"""
function encode_raw_timestamp(timestamp)
    # Sentinel value: gets the dedicated first bucket.
    timestamp == -1 && return 1

    date = timestamp_to_date(timestamp)
    years_since_2004 = Dates.value(Dates.Year(date)) - 2004
    zero_based_month = Dates.value(Dates.Month(date)) - 1
    season = zero_based_month ÷ 4 + 1
    return 1 + years_since_2004 * 4 + season
end

"""
    encode_word(item, rating, timestamp, status, completion, user)

Pack one interaction into a 6-tuple of `Int32` tokens, in the order
`(item, rating, timestamp, status, completion, user)`.

- `rating` is rounded to the nearest integer and shifted up by one.
- `timestamp` is bucketed with `encode_raw_timestamp`.
- `completion` is multiplied by 10, rounded, and shifted up by one.
- `item`, `status` and `user` are passed through unchanged apart from the final
  conversion to `Int32`.
"""
function encode_word(item, rating, timestamp, status, completion, user)
    timestamp_token = encode_raw_timestamp(timestamp)
    rating_token = Int32(round(rating)) + 1
    completion_token = Int32(round(10 * completion)) + 1
    fields = (item, rating_token, timestamp_token, status, completion_token, user)
    return map(field -> convert(Int32, field), fields)
end

"""
    get_training_data(task, include_ptw; show_progress_bar = false)

Build per-user "sentences" from the training splits of `task`.

The `"explicit"` and `"implicit"` splits are always loaded; `"ptw"` is added when
`include_ptw` is true. The splits are combined with `reduce(cat, ...)`, the rows are
visited in the order given by `sortperm` of the timestamp column, and each row is
encoded with `encode_word` and appended to its user's list.

Returns a `Dict{Int32,Vector{NTuple{6,Int32}}}` keyed by user.

# Keywords
- `show_progress_bar`: forwarded as `enabled` to the `ProgressMeter.Progress` bar.
"""
function get_training_data(task, include_ptw; show_progress_bar = false)
    # Load one training split. Every split other than "explicit" has its rating column
    # overwritten in place with 11.
    # NOTE(review): 11 presumably marks "no rating" -- confirm.
    function load_split(content)
        frame = get_raw_split("training", task, content)
        if content != "explicit"
            frame.rating .= 11
        end
        return frame
    end

    contents = include_ptw ? ["explicit", "implicit", "ptw"] : ["explicit", "implicit"]
    df = reduce(cat, map(load_split, contents))
    row_order = sortperm(df.timestamp)

    sentences = Dict{Int32,Vector{NTuple{6,Int32}}}()
    progress = ProgressMeter.Progress(
        length(row_order);
        enabled = show_progress_bar,
        showspeed = true,
    )
    for row in row_order
        user = df.user[row]
        # Create the user's list on first sight, then append in timestamp order.
        history = get!(() -> NTuple{6,Int32}[], sentences, user)
        word = encode_word(
            df.item[row],
            df.rating[row],
            df.timestamp[row],
            df.status[row],
            df.completion[row],
            user,
        )
        push!(history, word)
        ProgressMeter.next!(progress)
    end
    ProgressMeter.finish!(progress)
    return sentences
end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent, rng)

Return at most `max_seq_length` contiguous entries of `sentence`.

If `sentence` already fits it is returned as-is (the same object, not a copy).
Otherwise a contiguous window of length `max_seq_length` is copied out:

- `recent = true`: the window is the last `max_seq_length` entries;
- `recent = false`: the window start is drawn uniformly with `rng` from all valid
  start positions.

`rng` is only consulted on the random-window path.
"""
function subset_sentence(sentence, max_seq_length; recent, rng)
    n = length(sentence)
    n > max_seq_length || return sentence

    # Largest start index that still leaves room for a full window.
    last_start = n - max_seq_length + 1
    start = recent ? last_start : rand(rng, 1:last_start)
    return sentence[start:(start + max_seq_length - 1)]
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, pad_tokens; rng)

Convert a sentence (a sequence of words, each indexable by field position) into one
fixed-length column per field.

One vector of length `max_seq_length` is created per entry of `pad_tokens`,
pre-filled with that pad token. The sentence is first cut down with
`subset_sentence(...; recent = false)`, i.e. a random contiguous window when it is too
long. Field `j` of word `i` is then written to position `i` of column `j`, so padding
remains at the tail of each column.

Returns the collection of columns, in the same order as `pad_tokens`.
"""
function pad_sentence(sentence, max_seq_length, pad_tokens; rng)
    outputs = fill.(pad_tokens, max_seq_length)
    sentence = subset_sentence(sentence, max_seq_length; recent = false, rng = rng)
    for (position, word) in enumerate(sentence)
        for (field, column) in enumerate(outputs)
            column[position] = word[field]
        end
    end
    return outputs
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, pad_tokens; rng)

Pad every sentence with `pad_sentence` and stack the results into one matrix per field.

Returns a tuple with `length(pad_tokens)` entries. Entry `i` has one column per
sentence and `max_seq_length` rows, holding field `i` of each padded sentence.

When `sentences` is empty, each entry is a `max_seq_length × 0` matrix whose element
type is that of the corresponding pad token, so the result keeps a consistent shape
and element type.
"""
function get_token_ids(sentences, max_seq_length, pad_tokens; rng)
    padded_sentences = [
        pad_sentence(x, max_seq_length, pad_tokens; rng = rng) for x in sentences
    ]
    if isempty(padded_sentences)
        # Nothing to concatenate: return correctly typed, zero-column matrices.
        return Tuple(fill(token, max_seq_length, 0) for token in pad_tokens)
    end
    # `reduce(hcat, ...)` avoids splatting an arbitrarily long collection into varargs.
    return Tuple(
        reduce(hcat, [padded[i] for padded in padded_sentences]) for
        i in eachindex(pad_tokens)
    )
end;