Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add utilities for hooking and chaining methods calls along tensor chains #3636

Merged
merged 1 commit into from
May 29, 2020

Conversation

karlhigley
Copy link
Contributor

@karlhigley karlhigley commented May 29, 2020

Description

Adds utilities for inserting optional _before_method()/_after_method() hooks on a method decorated with @hookable, which allows custom tensor types to insert any functionality they individually require before or after a generic tensor method like send. This should help clean up the copious if isinstance(self, TensorClass) code in AbstractTensor and beyond.

In order to make sure that all custom functionality for types in the tensor chain gets called, this also creates a chain_call() function, which accepts a method name with args/kwargs, and calls that method (if it exists) on every tensor in the chain from outermost to innermost, returning the results in an ordered list.

Type of Change

Please mark options that are relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Documentation (non-breaking change which adds documentation)
  • Improvement (non-breaking change that improves the performance or reliability of existing functionality)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Checklist

@karlhigley karlhigley added the Type: Refactor 🔨 A complete overhaul of a file, feature, or codebase label May 29, 2020
@karlhigley karlhigley requested a review from a team as a code owner May 29, 2020 16:29
@karlhigley karlhigley self-assigned this May 29, 2020
@karlhigley karlhigley added this to the Cross-platform Execution milestone May 29, 2020
@karlhigley
Copy link
Contributor Author

karlhigley commented May 29, 2020

Here are some examples of the kinds of code I'm trying to remove:

  • TorchTensor.send() checks permissions on the whole tensor chain before sending a tensor, which means it has to know about functionality that properly belongs in PrivateTensor.
  • TorchTensor.send() checks child types in order to determine whether or not to reset the garbage collection flag.
  • TorchTensor.send() checks several flags to determine how to handle gradients, in order to make AutogradTensor compatible with MultiPointerTensor.

With _before_send() and _after_send() hooks that get called on each tensor in the chain, we should be able to move that kind of code into the relevant classes and avoid creating loops in the dependencies of the abstract and custom tensor types.

@codecov
Copy link

codecov bot commented May 29, 2020

Codecov Report

Merging #3636 into master will increase coverage by 0.01%.
The diff coverage is 100.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3636      +/-   ##
==========================================
+ Coverage   94.68%   94.70%   +0.01%     
==========================================
  Files         161      163       +2     
  Lines       17246    17294      +48     
==========================================
+ Hits        16330    16378      +48     
  Misses        916      916              
Impacted Files Coverage Δ
syft/generic/abstract/hookable.py 100.00% <100.00%> (ø)
test/generic/test_hookable.py 100.00% <100.00%> (ø)

from functools import wraps


def chain_call(obj, method_name, *args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a nice convenience function - although I'm not sure that we ever truly go more than one child deep (or two children deep in special circumstances) within any call. I do assume that you have a use for it in mind though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we don’t usually reach very far into the chain yet, but having stared at the code for a while, it’s starting to look like it would get simpler/cleaner if we did. There are also a few ad-hoc implementations of this pattern floating around.

from syft.generic.abstract.hookable import hookable


def test_chain_call():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh! I get it... wow that's clever! Love that!

return results


def hookable(hookable_method):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does overlap with ObjectConstructor a bit - but I'm ok with it given that ObjectConstructor doesn't exist yet - and there's a good chance they'll actually be compatible pieces of functionality with each other (aka ObjectConstructor creates the method names which this decorator detects)

@iamtrask iamtrask merged commit edbb576 into OpenMined:master May 29, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Type: Refactor 🔨 A complete overhaul of a file, feature, or codebase
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants