-
-
Notifications
You must be signed in to change notification settings - Fork 67
/
mlp.jl
48 lines (41 loc) · 2.16 KB
/
mlp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# TODO @theabhirath figure out consistent behaviour for dropout probs - 0.0 vs `nothing`
"""
mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
dropout_prob = 0., activation = gelu)
Feedforward block used in many MLPMixer-like and vision-transformer models.
# Arguments
- `inplanes`: Number of dimensions in the input.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
- `dropout_prob`: Dropout probability.
- `activation`: Activation function to use.
"""
function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
dropout_prob = 0.0, activation = gelu)
return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_prob),
Dense(hidden_planes, outplanes), Dropout(dropout_prob))
end
"""
gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer,
outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu)
Feedforward block based on the implementation in the paper "Pay Attention to MLPs".
([reference](https://arxiv.org/abs/2105.08050))
# Arguments
- `gate_layer`: Layer to use for the gating.
- `inplanes`: Number of dimensions in the input.
- `hidden_planes`: Number of dimensions in the intermediate layer.
- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
- `dropout_prob`: Dropout probability.
- `activation`: Activation function to use.
"""
function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer,
outplanes::Integer = inplanes; dropout_prob = 0.0,
activation = gelu)
@assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP"
return Chain(Dense(inplanes, hidden_planes, activation),
Dropout(dropout_prob),
gate_layer(hidden_planes),
Dense(hidden_planes ÷ 2, outplanes),
Dropout(dropout_prob))
end
gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...)