Skip to content

Commit

Permalink
Fixes breakage with PyTorch nightly rusty1s#22
Browse files Browse the repository at this point in the history
This PR is supposed to fix issue rusty1s#22

Replaces THRandom with the new CPUGenerator
  • Loading branch information
Dawars committed Jun 13, 2019
1 parent f7e9e8f commit 1088c64
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions cpu/sampler.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#include <TH/THRandom.h>
#include <ATen/CPUGenerator.h>
#include <torch/extension.h>

#include <TH/THGenerator.hpp>

at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
float factor) {
THGenerator *generator = THGenerator_new();
auto* generator = at::detail::getDefaultCPUGenerator();

auto start_ptr = start.data<int64_t>();
auto cumdeg_ptr = cumdeg.data<int64_t>();
Expand All @@ -26,7 +24,7 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
std::unordered_set<int64_t> set;
if (size_i < 0.7 * float(num_neighbors)) {
while (set.size() < size_i) {
int64_t z = THRandom_random(generator) % num_neighbors;
int64_t z = generator->random() % num_neighbors;
set.insert(z + low);
}
std::vector<int64_t> v(set.begin(), set.end());
Expand All @@ -40,8 +38,6 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
}
}

THGenerator_free(generator);

int64_t len = e_ids.size();
auto e_id = torch::from_blob(e_ids.data(), {len}, start.options()).clone();
return e_id;
Expand Down

0 comments on commit 1088c64

Please sign in to comment.