Merge pull request #39 from iotamudelta/master

Merge from upstream

iotamudelta committed Jul 17, 2018
2 parents 3fadf87 + b29376c commit 33ebb58
Showing 90 changed files with 2,019 additions and 724 deletions.
5 changes: 5 additions & 0 deletions .jenkins/pytorch/test.sh
@@ -9,6 +9,11 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh"

echo "Testing pytorch"

if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
echo "Skipping ROCm tests for now"
exit 0
fi

# JIT C++ extensions require ninja.
git clone https://github.com/ninja-build/ninja --quiet
pushd ninja
40 changes: 11 additions & 29 deletions README.md
@@ -1,4 +1,4 @@
<p align="center"><img width="40%" src="docs/source/_static/img/pytorch-logo-dark.png" /></p>
![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png)

--------------------------------------------------------------------------------

@@ -34,32 +34,14 @@ See also the [ci.pytorch.org HUD](https://ezyang.github.io/pytorch-ci-hud/build/

At a granular level, PyTorch is a library that consists of the following components:

<table>
<tr>
<td><b> torch </b></td>
<td> a Tensor library like NumPy, with strong GPU support </td>
</tr>
<tr>
<td><b> torch.autograd </b></td>
<td> a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch </td>
</tr>
<tr>
<td><b> torch.nn </b></td>
<td> a neural networks library deeply integrated with autograd designed for maximum flexibility </td>
</tr>
<tr>
<td><b> torch.multiprocessing </b></td>
<td> Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training. </td>
</tr>
<tr>
<td><b> torch.utils </b></td>
<td> DataLoader, Trainer and other utility functions for convenience </td>
</tr>
<tr>
<td><b> torch.legacy(.nn/.optim) </b></td>
<td> legacy code that has been ported over from torch for backward compatibility reasons </td>
</tr>
</table>
| Component | Description |
| ---- | --- |
| **torch** | a Tensor library like NumPy, with strong GPU support |
| **torch.autograd** | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
| **torch.nn** | a neural networks library deeply integrated with autograd designed for maximum flexibility |
| **torch.multiprocessing** | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
| **torch.utils** | DataLoader, Trainer and other utility functions for convenience |
| **torch.legacy(.nn/.optim)** | legacy code that has been ported over from torch for backward compatibility reasons |

Usually one uses PyTorch either as:

@@ -72,7 +54,7 @@ Elaborating further:

If you use NumPy, then you have used Tensors (a.k.a. ndarray).

<p align=center><img width="30%" src="docs/source/_static/img/tensor_illustration.png" /></p>
![Tensor illustration](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/tensor_illustration.png)

PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerate
compute by a huge amount.
@@ -99,7 +81,7 @@ from several research papers on this topic, as well as current and past work such
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
You get the best of speed and flexibility for your crazy research.

<p align=center><img width="80%" src="docs/source/_static/img/dynamic_graph.gif" /></p>
![Dynamic graph](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/dynamic_graph.gif)

### Python First

37 changes: 34 additions & 3 deletions aten/src/ATen/Retainable.h
@@ -7,21 +7,52 @@ namespace at {
// base class for refcounted things, allows for collections of generic
// refcounted objects that include tensors
struct Retainable {
Retainable(): refcount(1) {}
Retainable(): refcount(1), weak_refcount(1) {}
void retain() {
++refcount;
}
void release() {
if(--refcount == 0) {
// If we know that this is the last reference then we can skip
// all the decrements and release_resources().
if (weak_refcount == 1) {
delete this;
} else {
release_resources();
weak_release();
}
}
}
void weak_retain() {
++weak_refcount;
}
void weak_release() {
if (--weak_refcount == 0) {
delete this;
}
}
int use_count() const {
bool weak_lock() {
for (;;) {
auto current_refcount = refcount.load();
if (current_refcount == 0) return false;
if (refcount.compare_exchange_strong(current_refcount, current_refcount + 1)) break;
}
return true;
}
uint32_t use_count() const {
return refcount.load();
}
uint32_t weak_use_count() const {
return weak_refcount.load();
}

virtual void release_resources() {};
virtual ~Retainable() {}
private:
std::atomic<int> refcount;
// INVARIANT: once refcount reaches 0 it can never go up
// INVARIANT: weak_refcount = number of weak references + (refcount > 0 ? 1 : 0)
std::atomic<uint32_t> refcount;
std::atomic<uint32_t> weak_refcount;
};

}
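
The hunk above splits Retainable's single reference count in two: the strong count now owns the payload, which is freed eagerly through release_resources(), while the weak count owns the control block itself. Below is a minimal standalone sketch of that scheme, not the PyTorch code: it uses a toy Block type, omits the weak_refcount == 1 fast path, and shows how weak_lock()'s compare-and-swap loop refuses to resurrect a dead object.

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>

// Toy analogue of the split-count scheme in Retainable. The strong count
// owns the payload, the weak count owns the control block, and, per the
// invariant in the diff, the strong references collectively hold one weak
// reference, so both counts start at 1.
struct Block {
  std::atomic<uint32_t> refcount{1};
  std::atomic<uint32_t> weak_refcount{1};

  void release() {                      // drop one strong reference
    if (--refcount == 0) {
      std::cout << "payload freed\n";   // stands in for release_resources()
      weak_release();                   // return the collective weak ref
    }
  }
  void weak_release() {                 // drop one weak reference
    if (--weak_refcount == 0) delete this;
  }
  // Upgrade a weak reference to a strong one. As in weak_lock(), the CAS
  // loop only increments refcount from a nonzero value, so a dead object
  // can never be resurrected.
  bool weak_lock() {
    uint32_t current = refcount.load();
    while (current != 0) {
      if (refcount.compare_exchange_weak(current, current + 1)) return true;
    }
    return false;
  }
};

int main() {
  Block* b = new Block;
  ++b->weak_refcount;                   // take one weak reference
  std::cout << b->weak_lock() << "\n";  // 1: alive, refcount is now 2
  b->release();                         // drop the upgraded reference
  b->release();                         // last strong ref: payload freed
  std::cout << b->weak_lock() << "\n";  // 0: upgrade fails after death
  b->weak_release();                    // last weak ref: Block deleted
}
```

Note that the invariant stated in the diff holds throughout: weak_refcount equals the number of outstanding weak references, plus one while any strong reference survives.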
82 changes: 55 additions & 27 deletions aten/src/ATen/TensorBase.h
@@ -5,54 +5,62 @@

namespace at { namespace detail {

// TensorBase is the base class for Tensor which handles the reference counting
struct TensorBase {
TensorBase(): TensorBase(UndefinedTensor::singleton(), false) {}
TensorBase(TensorImpl * self, bool retain)
// TensorBaseImpl is the base class for Tensor which handles the reference counting
template<bool is_strong>
struct TensorBaseImpl {
TensorBaseImpl(): TensorBaseImpl(UndefinedTensor::singleton(), false) {}
TensorBaseImpl(TensorImpl * self, bool should_retain)
: pImpl(self) {
if (pImpl == nullptr) {
throw std::runtime_error("TensorBase with nullptr not supported");
throw std::runtime_error("TensorBaseImpl with nullptr not supported");
}
if(should_retain && pImpl != UndefinedTensor::singleton()) {
retain();
}
if(retain && pImpl != UndefinedTensor::singleton())
pImpl->retain();
}
TensorBase(const TensorBase & rhs)
TensorBaseImpl(const TensorBaseImpl & rhs)
: pImpl(rhs.pImpl) {
if (pImpl != UndefinedTensor::singleton())
pImpl->retain();
if (pImpl != UndefinedTensor::singleton()) {
retain();
}
}
TensorBase(TensorBase && rhs) noexcept
TensorBaseImpl(TensorBaseImpl && rhs) noexcept
: pImpl(rhs.pImpl) {
rhs.pImpl = UndefinedTensor::singleton();
}
~TensorBase() {
if (pImpl != UndefinedTensor::singleton())
pImpl->release();
~TensorBaseImpl() {
if (pImpl != UndefinedTensor::singleton()) {
release();
}
}
TensorBase & operator=(TensorBase && rhs) & {
TensorBaseImpl & operator=(TensorBaseImpl && rhs) & {
rhs.swap(*this);
return *this;
}
TensorBase & operator=(TensorBase const & rhs) & {
//TensorBase ctor retains original rhs.pImpl
//then rhs.pImpl is swapped with this->pImpl
//finally TensorBase dtor releases rhs.pImpl, which was originally this->pImpl
TensorBase(rhs).swap(*this);
return *this;
TensorBaseImpl & operator=(TensorBaseImpl const & rhs) & {
//TensorBaseImpl ctor retains original rhs.pImpl
//then rhs.pImpl is swapped with this->pImpl
//finally TensorBaseImpl dtor releases rhs.pImpl, which was originally this->pImpl
TensorBaseImpl(rhs).swap(*this);
return *this;
}
int64_t dim() const {
return pImpl->dim();
if (is_strong) {
return pImpl->dim();
} else {
AT_ERROR("Can't call dim() on a WeakTensor");
}
}
void reset() {
TensorBase().swap(*this);
TensorBaseImpl().swap(*this);
}
void reset(TensorImpl * rhs) {
TensorBase(rhs, true).swap(*this);
TensorBaseImpl(rhs, true).swap(*this);
}
void reset(TensorImpl * rhs, bool retain) {
TensorBase(rhs, retain).swap(*this);
void reset(TensorImpl * rhs, bool should_retain) {
TensorBaseImpl(rhs, should_retain).swap(*this);
}
void swap(TensorBase & rhs) {
void swap(TensorBaseImpl & rhs) {
TensorImpl * tmp = pImpl;
pImpl = rhs.pImpl;
rhs.pImpl = tmp;
@@ -75,6 +83,26 @@ struct TensorBase {
//TODO(zach): sort out friend structs
public:
TensorImpl * pImpl;

private:
void retain() {
if (is_strong) {
pImpl->retain();
} else {
pImpl->weak_retain();
}
}

void release() {
if (is_strong) {
pImpl->release();
} else {
pImpl->weak_release();
}
}
};

using TensorBase = TensorBaseImpl<true>;
using WeakTensorBase = TensorBaseImpl<false>;

}} // namespace at::detail
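
Two ideas in the rewritten TensorBaseImpl are worth isolating: a single class template whose is_strong parameter selects between the strong and weak retain/release pairs, and the copy-and-swap idiom behind the copy assignment operator. The following is a stripped-down, self-contained sketch of both, with toy counters standing in for TensorImpl; it is an illustration under those assumptions, not the ATen code.

```cpp
#include <utility>

// Toy stand-in for TensorImpl, with plain counters instead of atomics.
struct Impl {
  int strong = 1, weak = 1;
  void retain()       { ++strong; }
  void release()      { --strong; }
  void weak_retain()  { ++weak; }
  void weak_release() { --weak; }
};

// One template, two handle types: is_strong picks which pair of
// refcounting calls the handle forwards to, as in TensorBaseImpl.
template <bool is_strong>
struct Handle {
  Impl* p = nullptr;

  Handle() = default;
  explicit Handle(Impl* impl) : p(impl) { retain(); }
  Handle(const Handle& rhs) : p(rhs.p) { retain(); }
  Handle(Handle&& rhs) noexcept : p(rhs.p) { rhs.p = nullptr; }
  ~Handle() { release(); }

  // Copy-and-swap, as in the operator= above: the temporary retains
  // rhs.p, swap hands it our old pointer, and the temporary's destructor
  // releases that old pointer on the way out.
  Handle& operator=(const Handle& rhs) & {
    Handle(rhs).swap(*this);
    return *this;
  }
  void swap(Handle& rhs) { std::swap(p, rhs.p); }

 private:
  void retain()  { if (p) { is_strong ? p->retain()  : p->weak_retain();  } }
  void release() { if (p) { is_strong ? p->release() : p->weak_release(); } }
};

using StrongHandle = Handle<true>;      // analogue of TensorBase
using WeakHandle   = Handle<false>;     // analogue of WeakTensorBase

int main() {
  Impl impl;                            // strong == 1, weak == 1
  { StrongHandle s(&impl); }            // strong: 1 -> 2 -> 1
  { WeakHandle w(&impl); }              // weak:   1 -> 2 -> 1
  return (impl.strong == 1 && impl.weak == 1) ? 0 : 1;
}
```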
2 changes: 0 additions & 2 deletions aten/src/ATen/Utils.h
@@ -12,11 +12,9 @@

#if defined(__clang__)
#define __ubsan_ignore_float_divide_by_zero__ __attribute__((no_sanitize("float-divide-by-zero")))
#define __ubsan_ignore_function__ __attribute__((no_sanitize("function")))
#define __ubsan_ignore_vptr__ __attribute__((no_sanitize("vptr")))
#else
#define __ubsan_ignore_float_divide_by_zero__
#define __ubsan_ignore_function__
#define __ubsan_ignore_vptr__
#endif

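As a usage note, these macros attach a clang no_sanitize attribute to individual declarations and expand to nothing under other compilers. A minimal, hypothetical sketch:

```cpp
#include <ATen/Utils.h>  // provides the __ubsan_ignore_*__ macros above

// Hypothetical function that deliberately relies on IEEE-754 infinity
// semantics; the annotation keeps UBSan's float-divide-by-zero check quiet.
float inverse_or_inf(float x) __ubsan_ignore_float_divide_by_zero__;
float inverse_or_inf(float x) { return 1.0f / x; }  // x == 0.0f yields inf
```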
42 changes: 42 additions & 0 deletions aten/src/ATen/templates/Tensor.h
@@ -13,6 +13,7 @@
#include "ATen/Utils.h"
#include "ATen/Device.h"
#include "ATen/Layout.h"
#include "ATen/optional.h"

namespace at {
struct Type;
@@ -42,6 +43,7 @@ namespace at {
// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
// special care must be taken to handle this.
struct Tensor : public detail::TensorBase {
using TensorBase = detail::TensorBase;
Tensor() : TensorBase() {}
Tensor(TensorImpl * self, bool retain) : TensorBase(self, retain) {}
Tensor(const TensorBase & rhs) : TensorBase(rhs) {}
@@ -198,6 +200,46 @@ struct Tensor : public detail::TensorBase {
auto m(F func, Args&&... params) const -> decltype(func(*this, std::forward<Args>(params)...)) {
return func(*this, std::forward<Args>(params)...);
}

friend struct WeakTensor;
};

struct WeakTensor : public detail::WeakTensorBase {
using WeakTensorBase = detail::WeakTensorBase;
WeakTensor() : WeakTensorBase() {}
WeakTensor(TensorImpl * self, bool retain) : WeakTensorBase(self, retain) {}
WeakTensor(const WeakTensor & rhs) = default;
WeakTensor(WeakTensor && rhs) noexcept = default;
WeakTensor(const Tensor& t) : WeakTensorBase(t.pImpl, true) {}

// reimplemented from TensorBase so the return type is WeakTensor rather than TensorBase
WeakTensor & operator=(WeakTensor && rhs) & {
rhs.swap(*this);
return *this;
}
WeakTensor & operator=(WeakTensor const & rhs) & {
//WeakTensor ctor retains original rhs.pImpl
//then rhs.pImpl is swapped with this->pImpl
//finally WeakTensor dtor releases rhs.pImpl, which was originally this->pImpl
WeakTensor(rhs).swap(*this);
return *this;
}

WeakTensor & operator=(const Tensor& t) {
WeakTensor(t.pImpl, true).swap(*this);
return *this;
}

// non-retaining
TensorImpl * unsafeGetTensorImpl() const {
return pImpl;
}

// XXX: this can return undefined tensors
// Ideally it would be at::optional<Tensor>, but MSVC is too cool for that
Tensor lock() const {
return pImpl->weak_lock() ? Tensor(pImpl, false) : Tensor();
}
};

namespace detail {
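A usage sketch of the new WeakTensor API follows. It assumes this commit's headers and a Type-based factory call (at::CPU(at::kFloat).ones(...)) in the style of this era of ATen; the essential behavior, taken from the diff above, is that lock() yields a defined Tensor only while some strong reference is still alive.

```cpp
#include <cassert>
#include <ATen/ATen.h>

int main() {
  // Factory call assumed for this vintage of the ATen API.
  at::Tensor t = at::CPU(at::kFloat).ones({2, 2});
  at::WeakTensor w(t);              // weak_retain()s t's TensorImpl

  at::Tensor locked = w.lock();     // upgrade succeeds while t is alive
  assert(locked.defined());

  locked.reset();                   // drop the upgraded strong reference
  t.reset();                        // last strong ref gone: the impl runs
                                    // release_resources() but stays allocated
  assert(!w.lock().defined());      // upgrade now fails: undefined Tensor
}
```

This is essentially the property the newly added weakref_test.cpp (see the CMakeLists.txt change below) would be expected to exercise.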
5 changes: 5 additions & 0 deletions aten/src/ATen/templates/TensorDerived.cpp
@@ -49,6 +49,11 @@ void * ${Tensor}::unsafeGetTH(bool retain) {
return tensor;
}

void ${Tensor}::release_resources() {
${THTensor}_free(${state,} tensor);
tensor = nullptr;
}

${TensorDenseOrSparse}

}
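
For concreteness, here is roughly what this template could expand to for one backend. The names CPUFloatTensor and THFloatTensor are illustrative of ATen's generated code rather than taken from this diff, and ${state,} expands to nothing on CPU:

```cpp
// Hypothetical expansion for the CPU float backend: the TH payload is
// freed as soon as the last strong reference dies, while the TensorImpl
// itself stays allocated (empty) for any remaining weak references.
void CPUFloatTensor::release_resources() {
  THFloatTensor_free(tensor);
  tensor = nullptr;
}
```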
1 change: 1 addition & 0 deletions aten/src/ATen/templates/TensorDerived.h
@@ -23,6 +23,7 @@ struct ${Tensor} final : public TensorImpl {
virtual Scalar localScalar() override;
virtual void * unsafeGetTH(bool retain) override;
virtual std::unique_ptr<Storage> storage() override;
virtual void release_resources() override;
static const char * typeString();

//TODO(zach): sort out friend permissions later so this
3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
@@ -18,7 +18,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu