/
ops.jl
87 lines (72 loc) · 2.13 KB
/
ops.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Registry mapping Keras layer-type symbols to converter callbacks.
# Each callback receives a parsed layer node and returns the Julia op
# (function, symbol, or call expression) that implements that layer.
ops = Dict{Symbol, Any}()

# InputLayer carries no computation of its own; it contributes nothing.
ops[:InputLayer] = _ -> nothing
# Input node: emit an identity-like broadcast add of zero so the tensor
# flows through unchanged (`vcall` builds the call expression — defined
# elsewhere in this project).
ops[:Input] = layer -> vcall(:.+, layer, 0)
# Convert a Keras Conv layer node into a Conv call expression.
# Reads kernel/bias from the global `weight` table (populated elsewhere
# from the Keras HDF5 file, keyed twice by layer name per its group layout).
ops[:Conv] = function(a)
    # Keras reports the identity activation as "linear"; this converter
    # substitutes "relu" — NOTE(review): linear ≠ relu, confirm intended.
    activation = a.fields["activation"]
    if activation == "linear"
        activation = "relu"
    end
    # Keras stores kernel dims in reversed order relative to Julia.
    kern = weight[a.fields["name"]][a.fields["name"]]["kernel:0"]
    kernel_weight = reshape(kern, reverse(size(kern)))
    kernel_bias = weight[a.fields["name"]][a.fields["name"]]["bias:0"]
    # `(xs...)` outside a call is a syntax error in Julia 1.x; use Tuple().
    strides = Tuple(a.fields["strides"])
    if a.fields["padding"] == "valid"
        pads = (0, 0)
    elseif a.fields["padding"] == "same"
        # (k-1)/2 is exact for odd kernel sizes; an even kernel size will
        # raise InexactError here — presumably only odd kernels occur.
        pads = Tuple(Int64.((a.fields["kernel_size"] .- 1) ./ 2))
    else
        # Previously fell through leaving `pads` undefined (UndefVarError
        # at the return below); fail loudly with the offending value.
        throw(ArgumentError("unsupported padding: $(a.fields["padding"])"))
    end
    return vcall(:Conv, Symbol(activation), kernel_weight, kernel_bias, strides, pads)
end
# Dropout: forward the Keras dropout rate into a Dropout call expression.
ops[:Dropout] = layer -> vcall(:Dropout, layer.fields["rate"])
# Convert a Keras MaxPool layer node into a pooling closure.
ops[:MaxPool] = function(a)
    # `(xs...)` outside a call is a syntax error in Julia 1.x; Tuple()
    # converts the Keras config lists into the tuples maxpool expects.
    pool = Tuple(a.fields["pool_size"])
    strides = Tuple(a.fields["strides"])
    # Padding is fixed at zero ("valid"-style pooling).
    return x -> maxpool(x, pool, pad=(0, 0), stride=strides)
end
# Flatten maps directly onto Base.vec; hand back its symbol for the emitter.
ops[:Flatten] = layer -> :vec
# Convert a Keras BatchNormalization node into a normalization closure.
ops[:BatchNormalization] = function(layer)
    # Numerics configured on the Keras side.
    epsval = layer.fields["epsilon"]
    momval = layer.fields["momentum"]
    # NOTE(review): a fresh (untrained) BatchNorm is constructed on every
    # forward call because the channel count is only known from `x` at run
    # time — confirm saved scale/offset weights aren't meant to load here.
    return x -> BatchNorm(size(x)[3], ϵ=epsval, momentum=momval)(x)
end
# Convert a Keras Dense node into a Flux Dense layer, pulling kernel and
# bias from the global `weight` table (keyed twice by layer name per the
# Keras HDF5 group layout — populated elsewhere in this project).
# Returns either `Dense` alone or `(Dense, activation_symbol)`.
ops[:Dense] = function(layer)
    layername = layer.fields["name"]
    w = weight[layername][layername]["kernel:0"]
    b = weight[layername][layername]["bias:0"]
    # No activation recorded: return the bare layer.
    if !haskey(layer.fields, "activation")
        return Dense(w, b)
    end
    # Keras labels the identity activation "linear"; this converter rewrites
    # it to "relu" in place — NOTE(review): mutates the node, and linear ≠
    # relu; confirm this substitution is intended.
    if layer.fields["activation"] == "linear"
        layer.fields["activation"] = "relu"
    end
    return Dense(w, b), Symbol(layer.fields["activation"])
end
# Convert a Keras Reshape node into a reshaping closure.
ops[:Reshape] = function(a)
    # `(xs...)` outside a call is a syntax error in Julia 1.x; build the
    # target shape tuple with Tuple() instead.
    return x -> reshape(x, Tuple(a.fields["target_shape"]))
end
# Activation nodes: each maps a Keras activation name straight to the
# corresponding Julia function; the layer argument is ignored.
ops[:relu]     = _ -> relu
ops[:tanh]     = _ -> tanh
ops[:sigmoid]  = _ -> sigmoid
ops[:elu]      = _ -> elu
ops[:softplus] = _ -> softplus
ops[:softmax] = function(a)
    # BUG FIX: this previously returned `relu`, silently replacing every
    # softmax activation with relu in converted models.
    return softmax
end