In [None]:
import Flux: Chain, Dense, Dropout, gelu
import Flux.NNlib: gather
import Functors: @functor
import NeuralAttentionlib: GenericAttenMask
import Transformers: Transformer, Layers

In [None]:
"""
    Bert{Stack,Drop}

Container holding a transformer stack together with a dropout layer.

# Fields
- `transformers`: the callable transformer stack used by the forward pass.
- `dropout`: the dropout layer stored alongside the stack.
"""
struct Bert{Stack,Drop}
    transformers::Stack
    dropout::Drop
end

# Register `Bert` with Functors so that its fields are treated as children.
@functor Bert

In [None]:
"""
    Bert(; hidden_size, num_attention_heads, intermediate_size, num_layers,
           activation_fn = gelu, dropout = 0.1, attention_dropout = 0.1)

Build a `Bert` from hyperparameters.

A `Transformer` made of `Layers.PreNormTransformerBlock`s is created from the given
settings and paired with a `Dropout(dropout)` layer.

# Keywords
- `hidden_size`: hidden width passed to `Transformer`.
- `num_attention_heads`: number of attention heads passed to `Transformer`.
- `intermediate_size`: intermediate width passed to `Transformer`.
- `num_layers`: number of blocks passed to `Transformer`.
- `activation_fn`: activation passed to `Transformer` (default `gelu`).
- `dropout`: rate used both for `Transformer` and for the stored `Dropout` layer.
- `attention_dropout`: attention dropout rate passed to `Transformer`.
"""
function Bert(;
    hidden_size,
    num_attention_heads,
    intermediate_size,
    num_layers,
    activation_fn = gelu,
    dropout = 0.1,
    attention_dropout = 0.1,
)
    # Per-head width: integer division of the hidden width by the head count.
    head_size = hidden_size ÷ num_attention_heads
    encoder = Transformer(
        Layers.PreNormTransformerBlock,
        num_layers,
        activation_fn,
        num_attention_heads,
        hidden_size,
        head_size,
        intermediate_size;
        attention_dropout,
        dropout,
        return_score = false,
    )
    return Bert(encoder, Dropout(dropout))
end;

In [None]:
# Forward pass of `Bert`.
#
# Applies `bert.dropout` to `x`, runs the result through `bert.transformers` with
# `mask` wrapped in a `GenericAttenMask`, and returns the `hidden_state` field of
# the transformer output.
#
# Arguments:
# - `x`: input activations (assumed to be the embedded sequence — confirm with caller).
# - `mask`: attention mask data accepted by `GenericAttenMask`.
function (bert::Bert)(x, mask)
    # Feed the dropped-out activations (not the raw `x`) to the stack; otherwise
    # the dropout layer has no effect on the output.
    embedded = bert.dropout(x)
    inputs = (hidden_state = embedded, attention_mask = GenericAttenMask(mask))
    return bert.transformers(inputs).hidden_state
end;

In [None]:
"""
    BiasLayer(b)
    BiasLayer(n::Integer; init = zeros)

Layer that adds a bias `b` to its input by broadcasting (`x .+ b`).

The second form creates the bias as `init(Float32, n)`.

The field is parametric (rather than `::Any`) so that the concrete bias type is
known to the compiler.
"""
struct BiasLayer{B}
    b::B
end
BiasLayer(n::Integer; init = zeros) = BiasLayer(init(Float32, n))
# Forward pass: broadcast-add the stored bias to `x`.
(m::BiasLayer)(x) = x .+ m.b
# Register with Functors so that `b` is treated as a child of the layer.
@functor BiasLayer;

In [None]:
# # for use with the WeightDecayNobias optimizer
# function mark_embedding_matrix(hidden_size)
#     if hidden_size % 2 == 0
#         return hidden_size + 1
#     else
#         return hidden_size
#     end
# end;

In [None]:
"""
    DiscreteEmbed(embedding)
    DiscreteEmbed(size, vocab_size; init = randn)

Lookup-table embedding. Calling it with `x` returns `gather(embedding, x)`.

The second form creates the table as `init(Float32, Int32(size), Int32(vocab_size))`.

The field is parametric (rather than `::Any`) so that the concrete table type is
known to the compiler.
"""
struct DiscreteEmbed{E} <: AbstractEmbed{Float32}
    embedding::E
end
# Register with Functors so that `embedding` is treated as a child of the layer.
@functor DiscreteEmbed
# `size` of the layer is the size of the underlying table.
Base.size(e::DiscreteEmbed, s...) = size(e.embedding, s...)
function DiscreteEmbed(size, vocab_size; init = randn)
    return DiscreteEmbed(init(Float32, Int32(size), Int32(vocab_size)))
end
# Forward pass: look up the entries of the table selected by `x`.
(e::DiscreteEmbed)(x) = gather(e.embedding, x)

In [None]:
"""
    ContinuousEmbed(embedding)
    ContinuousEmbed(hidden_size::Integer)

Embedding for continuous inputs. Calling it with `x` prepends a singleton
dimension to `x` and applies `embedding` to the result.

The second form builds `embedding` as a two-layer `Chain`:
`Dense(1, embed_size, gelu)` followed by `Dense(embed_size, hidden_size)`, where
`embed_size` is `div(hidden_size, 64)` but never less than 1.

The field is parametric (rather than `::Any`) so that the concrete network type
is known to the compiler.
"""
struct ContinuousEmbed{E} <: AbstractEmbed{Float32}
    embedding::E
end
# Register with Functors so that `embedding` is treated as a child of the layer.
@functor ContinuousEmbed
function ContinuousEmbed(hidden_size::Integer)
    # Keep at least one hidden unit: for `hidden_size < 64` the plain integer
    # division is 0, which would give a zero-width layer that ignores its input.
    embed_size = max(1, div(hidden_size, 64))
    return ContinuousEmbed(
        Chain(Dense(1, embed_size, gelu), Dense(embed_size, hidden_size)),
    )
end
# Forward pass: reshape `x` to `(1, size(x)...)` so every value is fed to the
# first `Dense` layer as a length-1 feature vector.
(e::ContinuousEmbed)(x) = e.embedding(reshape(x, (1, size(x)...)))