
Smart tensor wrapping #35

Merged
merged 13 commits into from Mar 6, 2020

Conversation

HEmile
Owner

@HEmile HEmile commented Mar 6, 2020

Implemented smart tensor wrapping. Storch is now a wrapper around PyTorch itself, replacing its methods with storch wrappers. This allows torch.Tensors and storch.Tensors to be used interchangeably, except for functionality that is disallowed on storch.Tensor, such as casting to bool. This closes #11, #29, #31.
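The module-level wrapping idea ("replacing its methods with storch wrappers") can be sketched in plain Python without a torch dependency. All names here (`fakelib`, `Box`, `unwrapping`) are illustrative stand-ins, not the actual storch internals:

```python
# Sketch: replace a module's functions with unwrapping/rewrapping versions
# at import time. `fakelib` stands in for torch, `Box` for storch.Tensor.
import inspect
import types

fakelib = types.ModuleType("fakelib")

def _add(a, b):
    return a + b

fakelib.add = _add

class Box:
    """Stand-in for a wrapper tensor type."""
    def __init__(self, value):
        self.value = value

def unwrapping(fn):
    """Unwrap Box arguments, call the original function, rewrap the result."""
    def wrapper(*args, **kwargs):
        # kwargs left unwrapped here for brevity
        unwrapped = [a.value if isinstance(a, Box) else a for a in args]
        return Box(fn(*unwrapped, **kwargs))
    return wrapper

# Replace every plain function in the module with its wrapped version.
for name, fn in vars(fakelib).copy().items():
    if inspect.isfunction(fn):
        setattr(fakelib, name, unwrapping(fn))

result = fakelib.add(Box(1), 2)  # wrapped and plain values mix freely
```

This is why C-implemented functions in `torch.__init__` need explicit wrapping: they never go through Python attribute lookup on the tensor, so patching the module namespace is the only interception point.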

In addition:

Allows storch.Tensors to pass instance checks for torch.Tensor by inheriting from it. By overriding __getattribute__ and __dir__, methods that appear in torch.Tensor but not in storch.Tensor are ignored. __new__ is overridden to prevent construction errors from torch.Tensor's __new__.
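The combination of overrides described here can be sketched with plain Python classes (no torch dependency; `Base`, `Wrapped`, and `deterministic` are illustrative names, not the real storch code):

```python
# Sketch of the inheritance mechanism: pass isinstance checks for the base
# type, hide disallowed methods, rewrap returned base-type results.

class Base:
    """Stand-in for torch.Tensor."""
    def __init__(self, value):
        self.value = value

    def double(self):
        return Base(self.value * 2)

    def forbidden(self):
        # Stands in for disallowed functionality, e.g. casting to bool.
        return bool(self.value)


def deterministic(fn):
    """Wrap a method so Base results are rewrapped as Wrapped."""
    def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        if isinstance(result, Base) and not isinstance(result, Wrapped):
            return Wrapped(result.value)
        return result
    return wrapper


class Wrapped(Base):
    """Stand-in for storch.Tensor."""
    _DISALLOWED = {"forbidden"}

    def __new__(cls, *args, **kwargs):
        # Bypass the base __new__, which may reject our constructor args.
        return object.__new__(cls)

    def __getattribute__(self, name):
        if name in Wrapped._DISALLOWED:
            raise AttributeError(name)
        attr = super().__getattribute__(name)
        if callable(attr) and not name.startswith("__"):
            return deterministic(attr)  # rewrap returned base-type results
        return attr

    def __dir__(self):
        # Hide disallowed base-class methods from introspection.
        return [n for n in super().__dir__() if n not in Wrapped._DISALLOWED]


w = Wrapped(3)
assert isinstance(w, Base)              # passes instance checks for the base
assert isinstance(w.double(), Wrapped)  # base-returning methods are rewrapped
assert "forbidden" not in dir(w)        # disallowed methods are hidden
```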
- Fixes many bugs related to storch.Tensor inheriting from torch.Tensor
- Uses a much cleaner way of inheriting by overriding __getattribute__, which wraps torch.Tensor-returning functions using deterministic
- storch.sample now properly unwraps and wraps torch.Distributions so that they can take storch.Tensors as inputs.
- Rewrote the test script to not use explicit wraps as they are no longer required.
- The batch links are no longer the storch.Tensors that created the batch dimension, but (str, int) tuples giving the name of the plate (or sample) and the plate size (or number of samples).
- batch_links now also includes 1-dimensional plates/samples, to allow for invalid sampling checks. For example, it will now raise an error if a sample with name "z" has some other plate with name "z" as its parent (i.e., in its batch_links).
- Automatically wraps the methods in torch.__init__. These are implemented in C, and therefore will not automatically unwrap when called as methods on the tensor.
- Makes sure the context managers are properly reset when an exception occurs in the wrapped functions.
- Sampling will no longer rewrap in @deterministic, because sampling statements insert an additional dimension at the first batch dimension, causing the rewrapping statement to error as it violates the plating constraints. Plates are checked during storch.Tensor construction.
- Removed the cost wrapper as it did not have much of a logical use anymore. See #32.
- Replaced `DeterministicTensor` by `CostTensor`, as its only remaining functionality was to accommodate positive `is_cost` checks.
- Added an experimental `reduce` wrapper that can reduce a batched dimension without raising an error. It might be removed, as the use case did not actually need it.
- Reworked `storch.nn.b_binary_cross_entropy` to better accommodate the current API.
- Now wraps the methods in torch.Tensor during initialization instead of at runtime in __getattribute__
- Properly handles exceptions on invalid method uses
- If a torch.Tensor method calls other methods on self and receives a storch.Tensor, the storch.Tensor is now also unwrapped. Before, this only worked if the torch.Tensor method was called through a storch.Tensor object, as that is what triggers the __getattribute__ wrapping.
- No longer passes the size to __new__ for storch.Tensor, which likely reduces memory use
- Makes it easier to create other wrappers around losses
- For some reason, the wrapped versions of __getitem__ and __setitem__ would sometimes ignore the first dimension of the mask tensor. This is now caught as an IndexError and corrected by unsqueezing the mask and trying again. The code is extremely messy; see lines 284-288 of storch.Tensor for more info.
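The (name, size) plate-link representation and the duplicate-name check described above might look roughly like this. `add_plate` is an illustrative helper, not the actual storch API:

```python
# Sketch: batch links as (name, size) tuples with a duplicate-name check,
# so that a sample cannot reuse the name of one of its parent plates.

def add_plate(links, name, size):
    """Append a (name, size) plate link, rejecting duplicate plate names."""
    if any(existing == name for existing, _ in links):
        raise ValueError(f"Plate name {name!r} is already used by a parent plate")
    return links + [(name, size)]

links = add_plate([], "batch", 32)  # a plate of size 32
links = add_plate(links, "z", 4)    # a sample "z" with 4 samples
# Reusing "z" with "z" already in its parent links would raise ValueError:
# add_plate(links, "z", 8)
```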
@HEmile HEmile added the critical feature A feature that is of high importance label Mar 6, 2020
@HEmile HEmile added this to In progress in Stochastic computation graphs via automation Mar 6, 2020
@HEmile HEmile merged commit 5f2d9af into scg Mar 6, 2020
Stochastic computation graphs automation moved this from In progress to Done Mar 6, 2020
@HEmile HEmile deleted the smart_tensor_wrapping branch March 6, 2020 15:32