Skip to content

Commit

Permalink
Merge pull request #25 from iotamudelta/master
Browse files Browse the repository at this point in the history
Merge from upstream
  • Loading branch information
iotamudelta committed Jul 10, 2018
2 parents 6af217c + 907c9f1 commit 22a194d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
16 changes: 15 additions & 1 deletion torch/csrc/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,26 @@ PyObject *THPModule_hasDistributed(PyObject *_unused)
#endif
}

// Capsule destructor for "dltensor" capsules produced by THPModule_toDLPack.
// If the consumer (e.g. another framework's from_dlpack) has taken ownership,
// it renames the capsule, PyCapsule_GetPointer fails here, and we must not
// free anything. Otherwise the tensor was never consumed and we release it
// through the DLPack-provided deleter.
void DLPack_Capsule_Destructor(PyObject* data) {
  HANDLE_TH_ERRORS
  auto* managed =
      static_cast<DLManagedTensor*>(PyCapsule_GetPointer(data, "dltensor"));
  if (!managed) {
    // Name no longer matches: the capsule was consumed and ownership moved.
    // PyCapsule_GetPointer set an error indicator; swallow it.
    PyErr_Clear();
  } else {
    // Unconsumed capsule: we still own the tensor, so invoke its deleter.
    managed->deleter(managed);
  }
  END_HANDLE_TH_ERRORS_RET()
}

// Export a torch Tensor as a DLPack "dltensor" capsule.
// The capsule owns the DLManagedTensor until a consumer takes it; the
// destructor ensures the tensor is freed if the capsule is never consumed.
PyObject *THPModule_toDLPack(PyObject *_unused, PyObject *data)
{
  HANDLE_TH_ERRORS
  THPUtils_assert(THPVariable_Check(data), "data must be a Tensor");
  DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_UnpackData(data));
  // Pass DLPack_Capsule_Destructor (not NULL): with no destructor, an
  // unconsumed capsule would leak the DLManagedTensor on garbage collection.
  return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
  END_HANDLE_TH_ERRORS
}

Expand Down
4 changes: 2 additions & 2 deletions torch/nn/utils/clip_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2):
total_norm = 0
for p in parameters:
param_norm = p.grad.data.norm(norm_type)
total_norm += param_norm ** norm_type
total_norm += param_norm.item() ** norm_type
total_norm = total_norm ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef.item())
p.grad.data.mul_(clip_coef)
return total_norm


Expand Down

0 comments on commit 22a194d

Please sign in to comment.