Skip to content

Commit

Permalink
1. Modified Rakefile to add pry.
Browse files Browse the repository at this point in the history
2. Improved specs
  • Loading branch information
Arafatk committed Jul 29, 2016
1 parent 3ae114b commit d8649ea
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 207 deletions.
11 changes: 11 additions & 0 deletions Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,15 @@ YARD::Rake::YardocTask.new(:doc) do |t|
t.files = ['lib/*.rb', 'lib/**/*.rb']
end


task :pry do |task|
cmd = [ 'pry', "-r './lib/tensorflow.rb' "]
run *cmd
end

def run *cmd
sh(cmd.join(" "))
end


task :default => :spec
3 changes: 2 additions & 1 deletion lib/tensorflow/session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run(inputs, outputs, targets)

status = Tensorflow::TF_NewStatus()
Tensorflow::TF_Run_wrapper(self.session, input_names, input_values, output_names, output_values, target_names, self.status)
raise ("Incorrect specifications passed.") if Tensorflow::TF_GetCode(status) != Tensorflow::TF_OK
raise "Incorrect specifications passed." if Tensorflow::TF_GetCode(status) != Tensorflow::TF_OK

output_array = []

Expand All @@ -47,6 +47,7 @@ def run(inputs, outputs, targets)
end

def extend_graph(graph)
graph.graph_def_raw = Tensorflow::GraphDef.encode(graph.graph_def)
self.status = Tensorflow::TF_NewStatus()
Tensorflow::TF_ExtendGraph(self.session, graph_def_to_c_array(graph.graph_def_raw), graph.graph_def_raw.length, self.status)
self.graph = graph
Expand Down
10 changes: 8 additions & 2 deletions lib/tensorflow/tensor.rb
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def shape_proto(array)
Tensorflow::TensorShapeProto.new(dim: dimensions)
end

#
# Helper function to set the data type of tensor.
#
def set_type(type)
case type
when :float
Expand Down Expand Up @@ -207,13 +210,16 @@ def ruby_array_to_c(array, type)
#
# Returns the value of the element contained in the specified position in the tensor.
#
# * *Input* :
# - Dimension array(1 based indexing).
#
# * *Returns* :
# - value of the element contained in the specified position in the tensor.
# - Value of the element contained in the specified position in the tensor.
#
def getval(dimension)
raise("Invalid dimension array passed as input.",ShapeError) if dimension.length != self.dimensions.length
(0..dimension.length-1).each do |i|
raise("Invalid dimension array passed as input.",ShapeError) if dimension[i] > self.dimensions[i] || dimension[i] < 1
raise("Invalid dimension array passed as input.",ShapeError) if dimension[i] > self.dimensions[i] || dimension[i] < 1 || !(dimension[i].is_a? Integer)
end
sum = dimension[dimension.length - 1] - 1
prod = self.dimensions[self.dimensions.length - 1]
Expand Down
11 changes: 2 additions & 9 deletions spec/graph_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,11 @@
input2 = graph.placeholder('input2', Tensorflow::TF_INT32, [2,3])
graph.define_op("Add",'output',[input1,input2],"",nil)

encoder = Tensorflow::GraphDef.encode(graph.graph_def)
session = Tensorflow::Session.new
graph = Tensorflow::Graph.new
graph.graph_def = Tensorflow::GraphDef.decode(encoder)
graph.graph_def_raw = encoder
session.extend_graph(graph)
s = session

input1 = Tensorflow::Tensor.new([[1,3, 5],[2,4, 7]],:int32)
input2 = Tensorflow::Tensor.new([[-5,1,4],[8,2, 3]],:int32)
input = Hash.new
input["input1"] = input1.tensor
input["input2"] = input2.tensor
result = s.run(input, ["output"], nil)
result = session.run({"input1" => input1.tensor, "input2" => input2.tensor}, ["output"], nil)
end
end
Loading

0 comments on commit d8649ea

Please sign in to comment.