-
-
Notifications
You must be signed in to change notification settings - Fork 67
/
alexnet.jl
68 lines (55 loc) · 2.19 KB
/
alexnet.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
    alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)

Build the layers of an AlexNet model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

The result is a two-element `Chain(backbone, classifier)`: a convolutional feature
extractor followed by a fully-connected classification head.

# Arguments

  - `dropout_prob`: dropout probability applied before each of the first two `Dense`
    layers of the classifier
  - `inchannels`: the number of input channels
  - `nclasses`: the number of output classes
"""
function alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
    # Feature extractor: five ReLU convolutions interleaved with three 3×3 max-pools.
    features = Chain(
        Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
        MaxPool((3, 3); stride = 2),
        Conv((5, 5), 64 => 192, relu; pad = 2),
        MaxPool((3, 3); stride = 2),
        Conv((3, 3), 192 => 384, relu; pad = 1),
        Conv((3, 3), 384 => 256, relu; pad = 1),
        Conv((3, 3), 256 => 256, relu; pad = 1),
        MaxPool((3, 3); stride = 2),
    )

    # The head pools the 256-channel feature map to a fixed 6×6 grid, so the
    # flattened vector always has this many entries for the first `Dense` layer.
    flat_width = 256 * 6 * 6
    head = Chain(
        AdaptiveMeanPool((6, 6)),
        MLUtils.flatten,
        Dropout(dropout_prob),
        Dense(flat_width, 4096, relu),
        Dropout(dropout_prob),
        Dense(4096, 4096, relu),
        Dense(4096, nclasses),
    )

    return Chain(features, head)
end
"""
    AlexNet(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
            dropout_prob = 0.5)

Create an `AlexNet` model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

# Arguments

  - `pretrain`: set to `true` to load pre-trained weights for ImageNet
  - `inchannels`: the number of input channels
  - `nclasses`: the number of output classes
  - `dropout_prob`: dropout probability used in the classifier head; forwarded to
    [`alexnet`](@ref) (default `0.5`, matching `alexnet`)

!!! warning

    `AlexNet` does not currently support pretrained weights.

See also [`alexnet`](@ref).
"""
struct AlexNet
    # The wrapped model; the keyword constructor stores `alexnet(...)` here,
    # i.e. a `Chain(backbone, classifier)`.
    layers::Any
end
@functor AlexNet

function AlexNet(; pretrain::Bool = false, inchannels::Integer = 3,
                 nclasses::Integer = 1000, dropout_prob = 0.5)
    # `dropout_prob` was previously not forwarded, so AlexNet users were stuck with
    # `alexnet`'s default; the default here is identical, keeping old calls unchanged.
    layers = alexnet(; dropout_prob, inchannels, nclasses)
    model = AlexNet(layers)
    if pretrain
        # NOTE(review): the docstring warns that no AlexNet weights are available;
        # presumably `loadpretrain!` (defined elsewhere) errors in that case — confirm.
        loadpretrain!(model, "alexnet")
    end
    return model
end
# Calling an `AlexNet` runs its wrapped model on the input.
function (model::AlexNet)(x)
    return model.layers(x)
end

# When built by the keyword constructor, `layers` is `Chain(backbone, classifier)`,
# so the feature extractor sits at index 1 and the classification head at index 2.
backbone(model::AlexNet) = model.layers[1]
classifier(model::AlexNet) = model.layers[2]