Skip to content

Commit

Permalink
Fixed int types for imported tf networks
Browse files Browse the repository at this point in the history
  • Loading branch information
theraysmith committed May 5, 2017
1 parent 4fa463c commit d18931e
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions lstm/tfnetwork.cpp
Expand Up @@ -91,33 +91,36 @@ void TFNetwork::Forward(bool debug, const NetworkIO& input,
// objects.
if (!model_proto_.image_widths().empty()) {
TensorShape size_shape{1};
Tensor width_tensor(tensorflow::DT_INT32, size_shape);
auto eigen_wtensor = width_tensor.flat<int32>();
Tensor width_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_wtensor = width_tensor.flat<int64>();
*eigen_wtensor.data() = stride_map.Size(FD_WIDTH);
tf_inputs.emplace_back(model_proto_.image_widths(), width_tensor);
}
if (!model_proto_.image_heights().empty()) {
TensorShape size_shape{1};
Tensor height_tensor(tensorflow::DT_INT32, size_shape);
auto eigen_htensor = height_tensor.flat<int32>();
Tensor height_tensor(tensorflow::DT_INT64, size_shape);
auto eigen_htensor = height_tensor.flat<int64>();
*eigen_htensor.data() = stride_map.Size(FD_HEIGHT);
tf_inputs.emplace_back(model_proto_.image_heights(), height_tensor);
}
std::vector<string> target_layers = {model_proto_.output_layer()};
std::vector<Tensor> outputs;
Status s = session_->Run(tf_inputs, target_layers, {}, &outputs);
if (!s.ok()) tprintf("session->Run failed:%s\n", s.error_message().c_str());
ASSERT_HOST(s.ok());
ASSERT_HOST(outputs.size() == 1);
const Tensor& output_tensor = outputs[0];
// Check the dimensions of the output.
ASSERT_HOST(output_tensor.shape().dims() == 2);
int output_dim0 = output_tensor.shape().dim_size(0);
int output_dim1 = output_tensor.shape().dim_size(1);
ASSERT_HOST(output_dim1 == output_shape_.depth());
output->Resize2d(false, output_dim0, output_dim1);
ASSERT_HOST(output_tensor.shape().dims() == 3);
int output_batch = output_tensor.shape().dim_size(0);
int output_steps = output_tensor.shape().dim_size(1);
int output_depth = output_tensor.shape().dim_size(2);
ASSERT_HOST(output_batch == 1);
ASSERT_HOST(output_depth == output_shape_.depth());
output->Resize2d(false, output_steps, output_depth);
auto eigen_output = output_tensor.flat<float>();
memcpy(output->f(0), eigen_output.data(),
output_dim0 * output_dim1 * sizeof(output->f(0)[0]));
output_steps * output_depth * sizeof(output->f(0)[0]));
}

int TFNetwork::InitFromProto() {
Expand Down

0 comments on commit d18931e

Please sign in to comment.