In [None]:
import CUDA
import Flux
import Flux: Chain, Dense, Dropout, LayerNorm, cpu, gelu, gpu, logsoftmax
import Flux.NNlib: gather
import Functors: @functor
import NNlibCUDA
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful
import Random
import StatsBase: mean, sample
import Transformers: Bert, Positionwise, TransformerModel
import Transformers.Basic: AbstractEmbed, CompositeEmbedding, Embed, Transformer

# CUDA Performance

In [None]:
# Switch CUDA.jl to its FAST_MATH math mode; `precision = :BFloat16` is the
# precision setting handed to that mode.
# NOTE(review): this presumably trades numerical accuracy for speed in GPU library
# calls — confirm the exact scope against the CUDA.jl `math_mode!` documentation.
CUDA.math_mode!(CUDA.FAST_MATH; precision = :BFloat16)
# Redefine NNlibCUDA's `softmaxalgo` so that it always returns
# `CUDNN_SOFTMAX_ACCURATE`.
# NOTE(review): presumably this keeps softmax on the accurate cuDNN algorithm even
# though FAST_MATH is enabled above — verify against the installed NNlibCUDA version.
NNlibCUDA.softmaxalgo() = NNlibCUDA.CUDNN_SOFTMAX_ACCURATE

# Transformer Architecture Overrides

In [None]:
# Forward-pass override for `Bert`: apply the model's dropout to the input, run the
# transformer stack with `mask`, and return only the first value the stack produces
# (the remaining output is discarded).
function (bert::Bert)(x::T, mask) where {T}
    dropped = bert.drop(x)
    hidden, _ = bert.ts(dropped, mask)
    return hidden
end;

In [None]:
# Pre-layer-normalization transformer block: every sub-layer normalizes its input
# first, and the residual connection adds the un-normalized input back afterwards.
function (t::Transformer)(
    x::A,
    mask = nothing,
) where {T,N,A<:AbstractArray{T,N}}
    # Attention sub-layer: norm -> attention (query = key = value) -> dropout -> residual.
    normed = t.mhn(x)
    attended = t.drop(t.mh(normed, normed, normed; mask = mask))
    after_attention = x + attended

    # Position-wise sub-layer: norm -> feed-forward -> dropout -> residual.
    fed_forward = t.drop(t.pw(t.pwn(after_attention)))
    return after_attention + fed_forward
end;

# Model Utils

In [None]:
"""
    BiasLayer(b)
    BiasLayer(n::Integer; init = zeros)

Layer that adds its stored bias `b` to the input via broadcasting.

The second form builds the bias by calling `init(Float32, n)`; with the default
`init = zeros` that is a zero vector of length `n`. The type is registered with
`@functor` so that its field is visible to Functors-based traversal.
"""
struct BiasLayer
    b
end

function BiasLayer(n::Integer; init = zeros)
    return BiasLayer(init(Float32, n))
end

function (layer::BiasLayer)(x)
    return x .+ layer.b
end

@functor BiasLayer;

In [None]:
"""
    ContinuousEmbed(embedding)
    ContinuousEmbed(hidden_size::Int)

Embedding for scalar (continuous) inputs, wrapping a callable `embedding`.

The `hidden_size` constructor builds a two-layer network: a `Dense` layer from 1
feature to `div(hidden_size, 64)` units with `gelu`, followed by a `Dense` layer up
to `hidden_size`. Calling the embedding reshapes the input to add a leading
dimension of size 1 and feeds the result to `embedding`.
"""
struct ContinuousEmbed <: AbstractEmbed{Float32}
    embedding
end
@functor ContinuousEmbed

function ContinuousEmbed(hidden_size::Int)
    # Width of the intermediate layer; integer division, so it is 0 below 64.
    bottleneck = div(hidden_size, 64)
    network = Chain(Dense(1, bottleneck, gelu), Dense(bottleneck, hidden_size))
    return ContinuousEmbed(network)
end

function (e::ContinuousEmbed)(x)
    # Prepend a singleton dimension so each value becomes a length-1 feature column.
    return e.embedding(reshape(x, (1, size(x)...)))
end

In [None]:
"""
    tuplesum(a, b)

Recursively add two identically structured values.

- `NamedTuple`s are combined field by field, using the field names (and order) of
  `a`; `b` must provide every field of `a`.
- `Tuple`s are combined position by position over the indices of `a`.
- `nothing` acts as the additive identity on either side, so
  `tuplesum(nothing, x)` and `tuplesum(x, nothing)` both return `x`, and
  `tuplesum(nothing, nothing)` returns `nothing`.
- Any other pair of leaves is combined with `+`.
"""
function tuplesum(a::NamedTuple, b::NamedTuple)
    names = keys(a)
    return NamedTuple{names}(map(k -> tuplesum(a[k], b[k]), names))
end
tuplesum(a::Tuple, b::Tuple) = ntuple(k -> tuplesum(a[k], b[k]), length(a))
# The explicit (nothing, nothing) method resolves the ambiguity between the two
# one-sided identity methods below.
tuplesum(::Nothing, ::Nothing) = nothing
tuplesum(::Nothing, b) = b
tuplesum(a, ::Nothing) = a
tuplesum(a, b) = a + b;

"""
    tupledivide(a, d)

Recursively divide every leaf of `a` by `d`.

`NamedTuple`s and `Tuple`s are traversed element by element and rebuilt with the
same shape (and field names), `nothing` is returned unchanged, and any other leaf
is divided with broadcasting (`a ./ d`).
"""
function tupledivide(a::NamedTuple, d)
    return NamedTuple{keys(a)}(map(v -> tupledivide(v, d), values(a)))
end
tupledivide(a::Tuple, d) = map(v -> tupledivide(v, d), a)
tupledivide(::Nothing, d) = nothing
tupledivide(a, d) = a ./ d;

# Data Utils

In [None]:
# TODO named tuples
# A word is a 7-tuple whose slots are, in order:
# item, rating, timestamp, status, completion, user, position
const word_type = Tuple{Int32, Float32, Float32, Int32, Float32, Int32, Int32}

# Copy of `word` with the user slot (6) replaced by `user`.
function replace_user(word::word_type, user::Int32)
    return (word[1], word[2], word[3], word[4], word[5], user, word[7])
end

# Copy of `word` with the position slot (7) replaced by `position`.
function replace_position(word::word_type, position::Int32)
    return (word[1], word[2], word[3], word[4], word[5], word[6], position)
end

# Copy of `word` with the timestamp slot (3) replaced by `timestamp`. Untyped, so
# it accepts any indexable word and does not convert `timestamp`.
function replace_timestamp(word, timestamp)
    return (word[1], word[2], timestamp, word[4:end]...)
end

# True when the status slot (4) equals 1.
# NOTE(review): 1 is presumably the "ptw" status code — confirm against the encoder.
is_ptw(word::word_type) = 1 == word[4];

In [None]:
"""
    encode_word(item, rating, timestamp, status, completion, user, position)

Pack the seven fields, in this order, into a tuple and convert it to `word_type`,
so every field ends up with the element type that `word_type` prescribes.
"""
encode_word(item, rating, timestamp, status, completion, user, position) =
    convert(word_type, (item, rating, timestamp, status, completion, user, position))

"""
    get_training_data(task, include_ptw, cls_tokens; show_progress_bar = false)

Build one sentence (a `Vector{word_type}`) per user from the "training" splits.

The "explicit" and "implicit" contents are always loaded via `get_raw_split`;
"ptw" is added when `include_ptw` is true. Every non-explicit split has its
`rating` column overwritten with 11. Rows from all splits are then visited in
ascending `timestamp` order.

Each sentence starts with `cls_tokens` whose user slot is replaced by the
sentence's user. Every following word is encoded with `encode_word`, and its
position is the number of entries already in the sentence at the time it is added.

# Arguments
- `task`: forwarded unchanged to `get_raw_split`.
- `include_ptw`: whether to include the "ptw" content.
- `cls_tokens`: the word used as the first entry of every sentence.

# Keywords
- `show_progress_bar`: enables the `ProgressMeter` progress display.

# Returns
A `Dict{Int32,Vector{word_type}}` mapping each user to its sentence.
"""
function get_training_data(task, include_ptw, cls_tokens; show_progress_bar = false)
    # Load a single content split. The local is deliberately NOT called `df`: sharing
    # a name with the enclosing function's `df` would make this closure capture (and
    # box) that variable.
    function get_df(task, content)
        split = get_raw_split("training", task, content)
        if content != "explicit"
            # NOTE(review): 11 looks like a placeholder rating for rows without an
            # explicit rating — confirm against the rating scale.
            split.rating .= 11
        end
        return split
    end

    contents = ["explicit", "implicit"]
    if include_ptw
        push!(contents, "ptw")
    end
    sentences = Dict{Int32,Vector{word_type}}()
    df = reduce(cat, [get_df(task, content) for content in contents])
    order = sortperm(df.timestamp)
    p = ProgressMeter.Progress(length(order); enabled = show_progress_bar, showspeed = true)
    for i in order
        user = df.user[i]
        # One hash lookup: fetch the user's sentence, creating it (seeded with the
        # cls word) on first sight.
        sentence = get!(sentences, user) do
            [replace_user(cls_tokens, user)]
        end
        word = encode_word(
            df.item[i],
            df.rating[i],
            df.timestamp[i],
            df.status[i],
            df.completion[i],
            user,
            length(sentence),
        )
        push!(sentence, word)
        ProgressMeter.next!(p)
    end
    ProgressMeter.finish!(p)
    return sentences
end;

In [None]:
"""
    subset_sentence(sentence, max_seq_length; recent, rng)

Return at most `max_seq_length` contiguous entries of `sentence`.

A sentence that already fits is returned as is (no copy). Otherwise a window of
exactly `max_seq_length` entries is sliced out: the trailing window when `recent`
is true, or a window whose start is drawn uniformly with `rng` when it is false.
`rng` is only used in the latter case.
"""
function subset_sentence(sentence, max_seq_length; recent, rng)
    n = length(sentence)
    n > max_seq_length || return sentence
    # Largest start index that still leaves room for a full window.
    last_start = n - max_seq_length + 1
    start = recent ? last_start : rand(rng, 1:last_start)
    return sentence[start:start+max_seq_length-1]
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens; rng)

Turn a sentence into one column per word field, each padded to `max_seq_length`.

One output array is created per entry of `pad_tokens`, pre-filled with that pad
value. The sentence is first cut down with `subset_sentence` (random window,
`recent = false`), then field `j` of word `i` is written to `outputs[j][i]`.

Field 7 gets special treatment for every word whose first field differs from
`cls_tokens[1]`: the value is reduced modulo `max_position`, and a result of 0 is
replaced by 1. Words matching the cls item keep field 7 unchanged.
NOTE(review): the remap of 0 to 1 presumably keeps 0 free for another use — confirm.
"""
function pad_sentence(sentence, max_seq_length, max_position, pad_tokens, cls_tokens; rng)
    outputs = fill.(pad_tokens, max_seq_length)
    sentence = subset_sentence(sentence, max_seq_length; recent = false, rng = rng)
    for (i, word) in enumerate(sentence)
        for j in 1:length(outputs)
            value = word[j]
            if j == 7 && word[1] != cls_tokens[1]
                wrapped = value % max_position
                value = wrapped == 0 ? 1 : wrapped
            end
            outputs[j][i] = value
        end
    end
    return outputs
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens; rng)

Pad every sentence with `pad_sentence` and stack the results field by field.

Returns a tuple with one entry per element of `pad_tokens`; entry `i` is the
horizontal concatenation (`hcat`) of field `i` of every padded sentence, so each
sentence contributes one column.
"""
function get_token_ids(sentences, max_seq_length, max_position, pad_tokens, cls_tokens; rng)
    padded = [
        pad_sentence(s, max_seq_length, max_position, pad_tokens, cls_tokens; rng = rng)
        for s in sentences
    ]
    # Gather field `i` from every padded sentence and place them side by side.
    stack_field(i) = hcat((p[i] for p in padded)...)
    return ntuple(stack_field, length(pad_tokens))
end;