/
pointnet.jl
87 lines (70 loc) · 1.81 KB
/
pointnet.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
export stnKD, PointNet
"""
    stnKD(K::Int)

Build a PointNet-style transform network ("T-Net") as a flat `Chain`.

The network takes per-point features laid out as `(N, K, B)` (points × channels ×
batch, Flux 1-D conv layout) and returns a batch of `K×K` matrices `(K, K, B)`.
Each matrix is returned transposed, as a lazy `PermutedDimsArray` view.
"""
function stnKD(K::Int)
    # Shared per-point MLP: 1×1 convolutions lifting K channels up to 1024.
    pointwise = (
        Conv((1,), K => 64, relu), BatchNorm(64),
        Conv((1,), 64 => 128, relu), BatchNorm(128),
        Conv((1,), 128 => 1024, relu), BatchNorm(1024),
    )
    # Max over the point axis, then drop the singleton dim -> (1024, B).
    pooling = (
        x -> maximum(x; dims = 1),
        x -> reshape(x, :, size(x, 3)),
    )
    # Fully connected regressor producing K*K entries per sample.
    # (A BatchNorm(512) after the first Dense layer is intentionally absent.)
    regressor = (
        Dense(1024, 512, relu),
        Dense(512, 256, relu),
        BatchNorm(256),
        Dense(256, K * K),
    )
    # Fold the K*K vector into a K×K matrix per sample and transpose each one.
    # TODO: add the K×K identity (as in the reference T-Net) in a GPU-compatible way.
    as_matrices = (
        x -> reshape(x, K, K, size(x, 2)),
        x -> PermutedDimsArray(x, (2, 1, 3)),
    )
    return Chain(pointwise..., pooling..., regressor..., as_matrices...)
end
"""
    PointNet(num_classes::Int = 10, K::Int = 64)

Flux implementation of the PointNet classification model.

# Arguments
- `num_classes` - Number of classes in the dataset.
- `K` - Size of the feature-space transform predicted by `stnKD(K)`.

# Fields
- `stn` - Input transform network (`stnKD(3)`), applied to the raw coordinates.
- `fstn` - Feature transform network (`stnKD(K)`), applied after `conv_block1`.
- `conv_block1` - Per-point feature block applied after the input transform.
- `feat` - Per-point MLP, max-pool over points, and fully connected layers.
- `cls` - Final classification layer.

Calling the model on an array shaped `[3, N, B]` returns a `[num_classes, B]`
matrix of class probabilities (softmax over the first dimension).
"""
struct PointNet{S, F, C, T, L}
    # Fields are type parameters rather than `::Any` so the model is a concrete
    # type and calls through its layers can be specialized.
    stn::S
    fstn::F
    conv_block1::C
    feat::T
    cls::L
end
# Build a PointNet classifier.
#
# - `num_classes`: number of output classes (width of the final Dense layer).
# - `K`: width of the per-point feature space on which the feature transform
#   `stnKD(K)` operates; the default 64 reproduces the original architecture.
#
# Throws `ArgumentError` if either argument is not positive.
function PointNet(num_classes::Int = 10, K::Int = 64)
    num_classes > 0 ||
        throw(ArgumentError("num_classes must be positive, got $num_classes"))
    K > 0 || throw(ArgumentError("K must be positive, got $K"))

    # Input transform: predicts a 3×3 matrix per point cloud.
    stn = stnKD(3)
    # Feature transform: predicts a K×K matrix per point cloud.
    fstn = stnKD(K)
    # Lift xyz to K channels so fstn's K×K transform lines up with the features
    # (previously hard-coded to 64, which broke any K != 64).
    # NOTE(review): assumes conv_bn_block(in, out, kernel) — confirm its signature.
    conv_block1 = conv_bn_block(3, K, (1,))
    feat = Chain(
        Conv((1,), K => 128, relu),
        BatchNorm(128),
        Conv((1,), 128 => 1024),
        BatchNorm(1024),
        # Max-pool over the point axis: order-invariant global feature.
        x -> maximum(x, dims = 1),
        x -> reshape(x, 1024, :),
        Dense(1024, 512, relu),
        BatchNorm(512),
        Dense(512, 256, relu),
        Dropout(0.4),
        BatchNorm(256),
    )
    # Linear head: the forward pass applies softmax to these logits, so a relu
    # here would clamp every negative logit to 0 (tied probabilities, no gradient).
    cls = Dense(256, num_classes)
    return PointNet(stn, fstn, conv_block1, feat, cls)
end
# Forward pass: classify a batch of point clouds.
#
# `X` is expected as `[3, N, B]` (coordinates × points × batch). It is permuted to
# `[N, 3, B]` so that points lie along the width axis and coordinates along the
# channel axis of Flux's 1-D conv layout `(width, channels, batch)`.
# Returns a `[num_classes, B]` matrix of class probabilities (softmax over dim 1).
function (m::PointNet)(X)
    # [3, N, B] -> [N, 3, B]
    X = permutedims(X, (2, 1, 3))
    # Input transform: m.stn(X) is a [3, 3, B] batch of matrices; X stays [N, 3, B].
    X = batched_mul(X, m.stn(X))
    # Per-point feature lift: [N, 3, B] -> [N, C, B], where C is the output width of
    # conv_block1 (64 with the default constructor).
    X = m.conv_block1(X)
    # Feature transform: m.fstn(X) is [C, C, B]; X stays [N, C, B].
    X = batched_mul(X, m.fstn(X))
    # Per-point MLP + max-pool over points + FC layers: [N, C, B] -> [256, B].
    X = m.feat(X)
    # Classifier head: [256, B] -> [num_classes, B].
    X = m.cls(X)
    return softmax(X, dims = 1)
end
# Make PointNet's fields (stn, fstn, conv_block1, feat, cls) traversable by
# Functors/Flux, e.g. for collecting trainable parameters and moving the model
# between devices with `gpu`/`cpu`.
@functor PointNet