-
Notifications
You must be signed in to change notification settings - Fork 1
/
block.jl
42 lines (34 loc) · 1.19 KB
/
block.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# A transformer block: a self-attention sublayer and a position-wise
# feed-forward sublayer, each fed through its own pre-LayerNorm and
# combined with the input via residual connections (see the call
# method defined on `Block` below).
struct Block{S,F,L1,L2}
sa::S    # self-attention sublayer (a MultiheadAttention instance)
ffwd::F  # position-wise feed-forward sublayer (a FeedForward instance)
ln1::L1  # LayerNorm applied to the input of `sa`
ln2::L2  # LayerNorm applied to the input of `ffwd`
end
# Register `Block` with Functors so parameter collection / `fmap` (e.g.
# Flux training and device movement) can recurse into its fields.
Functors.@functor Block
"""
    Block(input_dim; num_heads=1, head_size=(input_dim÷num_heads), dropout=0)

Initializes an instance of the **`Block`** type, representing a transformer block.

A **`Block`** instance accepts an input array **`x`** of dimensions (C, T, B) and
outputs an array of the same dimensions (C, T, B). "C" is the channel size
(embedding dimension). "T" is the block size (number of input tokens). "B" is
the batch size.

# Keywords
- `num_heads=1`: number of attention heads; must be positive.
- `head_size=(input_dim÷num_heads)`: per-head width; must equal
  `input_dim ÷ num_heads` so the concatenated heads match `input_dim`.
- `dropout=0`: dropout rate forwarded to the attention and feed-forward
  sublayers.

The call method additionally accepts `mask` (defaults to `nothing`; must be of
dimensions (T, T)).

# Throws
- `ArgumentError` if `input_dim` or `num_heads` is not positive, or if
  `head_size != input_dim ÷ num_heads`.

## Examples:
```julia
C,T,B = 8,3,4
block = Block(C)
@assert size(block(rand(Float32, C,T,B))) == (C,T,B)
```
"""
function Block(input_dim; num_heads=1, head_size=(input_dim÷num_heads), dropout=0)
    # Validate public-API inputs with ArgumentError rather than @assert:
    # @assert may be disabled at higher optimization levels and is meant
    # for internal invariants, not caller-supplied values.
    input_dim > 0 || throw(ArgumentError("input_dim must be positive, got $input_dim"))
    num_heads > 0 || throw(ArgumentError("num_heads must be positive, got $num_heads"))
    # Heads are concatenated back to the embedding width, so head_size must
    # tile input_dim exactly.
    head_size == input_dim ÷ num_heads || throw(ArgumentError(
        "head_size must equal input_dim ÷ num_heads = $(input_dim ÷ num_heads), got $head_size"))
    return Block(
        MultiheadAttention(input_dim, num_heads; dropout=dropout),
        FeedForward(input_dim; dropout=dropout),
        LayerNorm(input_dim),
        LayerNorm(input_dim),
    )
end
"""
    (m::Block)(x; mask=nothing)

Apply the transformer block to `x`, returning an array of the same size.
`mask`, when given, is forwarded to the self-attention sublayer.
"""
function (m::Block)(x; mask=nothing)
    # Pre-norm residual branches. NOTE(review): both branches read the
    # ORIGINAL input `x` (a "parallel" attention/feed-forward formulation);
    # the feed-forward branch does not see the attention output, unlike the
    # sequential formulation — confirm this is intentional.
    attn_branch = m.sa(m.ln1(x); mask=mask)
    ff_branch = m.ffwd(m.ln2(x))
    return x + attn_branch + ff_branch
end