Skip to content

Commit

Permalink
Fix a dangling pointer issue in the C++ layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Feb 6, 2020
1 parent 312a28d commit db23c54
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
9 changes: 9 additions & 0 deletions tests/tensor_indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,12 @@ fn complex_index() {
&[157, 158, 159, 160, 164, 165, 166, 167, 143, 144, 145, 146]
);
}

#[test]
fn index_3d() {
let values: Vec<i64> = (0..24).collect();
let tensor = tch::Tensor::of_slice(&values).view((2, 3, 4));
assert_eq!(Vec::<i64>::from(tensor.i((0, 0, 0))), &[0]);
assert_eq!(Vec::<i64>::from(tensor.i((1, 0, 0))), &[12]);
assert_eq!(Vec::<i64>::from(tensor.i((0..2, 0, 0))), &[0, 12]);
}
13 changes: 8 additions & 5 deletions torch-sys/libtch/torch_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,13 @@ void at_copy_data(tensor tensor, void *vs, size_t numel, size_t elt_size_in_byte
if (numel > tensor->numel())
throw std::invalid_argument("target numel is larger than tensor numel");
if (tensor->device().type() != at::kCPU) {
torch::Tensor tmp_tensor = tensor->to(at::kCPU);
void *tensor_data = tmp_tensor.contiguous().data_ptr();
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
}
else {
void *tensor_data = tensor->contiguous().data_ptr();
auto tmp_tensor = tensor->contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
}
)
Expand Down Expand Up @@ -314,7 +315,8 @@ int at_save_image(tensor tensor, char *filename) {
int h = sizes[0];
int w = sizes[1];
int c = sizes[2];
void *tensor_data = tensor->contiguous().data_ptr();
auto tmp_tensor = tensor->contiguous();
void *tensor_data = tmp_tensor.data_ptr();
if (ends_with(filename, ".jpg"))
return stbi_write_jpg(filename, w, h, c, tensor_data, 90);
if (ends_with(filename, ".bmp"))
Expand All @@ -334,7 +336,8 @@ tensor at_resize_image(tensor tensor, int out_w, int out_h) {
int h = sizes[0];
int w = sizes[1];
int c = sizes[2];
const unsigned char *tensor_data = (unsigned char*)tensor->contiguous().data_ptr();
auto tmp_tensor = tensor->contiguous();
const unsigned char *tensor_data = (unsigned char*)tmp_tensor.data_ptr();
torch::Tensor out = torch::zeros({ out_h, out_w, c }, at::ScalarType::Byte);
stbir_resize_uint8(tensor_data, w, h, 0, (unsigned char*)out.data_ptr(), out_w, out_h, 0, c);
return new torch::Tensor(out);
Expand Down

0 comments on commit db23c54

Please sign in to comment.