Skip to content

Commit

Permalink
Working summary [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 1, 2019
1 parent 8948fa8 commit 1469176
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
6 changes: 4 additions & 2 deletions lib/tensorflow/keras/layers/dense.rb
Expand Up @@ -22,15 +22,17 @@ def build(input_shape)
@bias = nil
end

@output_shape = [last_dim, @units]

@built = true
end

def output_shape
[]
@output_shape
end

def count_params
0
@units + @kernel.shape.inject(&:*)
end

def call(inputs)
Expand Down
6 changes: 5 additions & 1 deletion lib/tensorflow/keras/layers/dropout.rb
Expand Up @@ -5,13 +5,17 @@ class Dropout
def initialize(rate)
end

def build(input_shape)
@output_shape = input_shape
end

def call(inputs)
# TODO implement
TensorFlow.identity(inputs)
end

def output_shape
[] # TODO same as input shape
@output_shape
end

def count_params
Expand Down
9 changes: 8 additions & 1 deletion lib/tensorflow/keras/models/sequential.rb
Expand Up @@ -28,14 +28,21 @@ def evaluate(x, y)

def summary
sep = "_________________________________________________________________\n"

output_shape = nil
@layers.each do |layer|
layer.build(output_shape) if layer.respond_to?(:build)
output_shape = layer.output_shape
end

total_params = @layers.map(&:count_params).sum

summary = String.new("")
summary << "Model: \"sequential\"\n"
summary << sep
summary << "Layer (type) Output Shape Param # \n"
summary << "=================================================================\n"
summary << @layers.map { |l| "%-28s %-25s %-10s\n" % [l.class.name.split("::").last, l.output_shape.map { |v| v == -1 ? nil : v }.inspect, l.count_params] }.join(sep)
summary << @layers.map { |l| "%-28s %-25s %-10s\n" % [l.class.name.split("::").last, ([nil] + l.output_shape[1..-1]).inspect, l.count_params] }.join(sep)
summary << "=================================================================\n"
summary << "Total params: #{total_params}\n"
summary << "Trainable params: #{total_params}\n"
Expand Down

0 comments on commit 1469176

Please sign in to comment.