Skip to content

Commit

Permalink
Added functional API support
Browse files Browse the repository at this point in the history
  • Loading branch information
ayush1999 committed May 30, 2018
1 parent f0d1cf1 commit 867775a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/graph.jl
Expand Up @@ -56,8 +56,12 @@ end

function load(structure_file, weight_file)
global weight = weights(weight_file)

s = load_structure(structure_file)
if check_modeltype(structure_file) == "Sequential"
s = load_structure(structure_file)
elseif check_modeltype(structure_file) == "Model"
s = load_structure(structure_file)["layers"]
filter!(x->x["class_name"]!="InputLayer", s)
end
l = load_layers(s)
go = get_ops(l)
return go, weight
Expand Down
5 changes: 5 additions & 0 deletions src/ops.jl
@@ -1,4 +1,9 @@
ops = Dict{Symbol, Any}()

ops[:InputLayer] = function(a)
return nothing
end

ops[:Input] = function(a)
return vcall(:.+, a, 0)
end
Expand Down
12 changes: 12 additions & 0 deletions src/read.jl
Expand Up @@ -19,6 +19,14 @@ function load_structure(file="structure.json")
return res["config"]
end

"""
Check if model uses sequential/functional API.
"""
function check_modeltype(file)
res = JSON.parse(String(read(open(file, "r"))))
return res["class_name"]
end

struct new_type
layer_type::Symbol
fields::Any
Expand All @@ -44,6 +52,8 @@ function layer_type(a)
return :Reshape
elseif (a["class_name"] == "BatchNormalization")
return :BatchNormalization
elseif (a["class_name"] == "InputLayer")
return :InputLayer
end
end

Expand All @@ -69,6 +79,8 @@ function fields(a)
return ["name", "target_shape"]
elseif layer_type(a) == :BatchNormalization
return ["name", "momentum", "epsilon"]
elseif layer_type(a) == :InputLayer
return ["name"]
end
end

Expand Down

0 comments on commit 867775a

Please sign in to comment.