Skip to content

Commit

Permalink
Revert & improve wrapper creation (#19)
Browse files Browse the repository at this point in the history
* revert object_type, pass  param to .wrap() instead

* ben's README suggest

Co-Authored-By: Benjamin DeCoste <bendecoste@gmail.com>
  • Loading branch information
jvmncs and bendecoste committed Oct 10, 2019
1 parent 68c7126 commit 70c8087
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 20 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -15,6 +15,8 @@ PySyft-TensorFlow is available on pip
pip install syft-tensorflow
```

NOTE: We aren't yet on a proper release schedule. Until then, we recommend building the code from source. The master branch is intended to be kept in line with [this branch](https://github.com/dropoutlabs/PySyft/tree/dev) on the [DropoutLabs](https://github.com/dropoutlabs/PySyft) fork of PySyft. If you have any trouble, please open an issue or reach out on Slack via the #team_tensorflow or #team_pysyft channels.

## Usage

See the [PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)
Expand Down
7 changes: 0 additions & 7 deletions syft_tensorflow/hook/hook.py
Expand Up @@ -282,13 +282,6 @@ def _add_methods_from_native_tensor(tensor_type: type, syft_type: type):
# Add this method to the TF tensor
setattr(tensor_type, attr, getattr(syft_type, attr))

@classmethod
def create_wrapper(cls, child_to_wrap, *args, **kwargs):
if child_to_wrap.object_type==tf.Variable:
return tf.Variable([])
else:
return tf.constant([])

@classmethod
def create_shape(cls, shape_dims):
return tf.TensorShape(shape_dims)
Expand Down
9 changes: 2 additions & 7 deletions syft_tensorflow/tensor/tensor.py
Expand Up @@ -162,7 +162,7 @@ def send(
self.child = ptr
return self
else:
output = ptr if no_wrap else ptr.wrap()
output = ptr if no_wrap else ptr.wrap(type=tf.constant, value=[])

else:

Expand All @@ -173,7 +173,7 @@ def send(
output = syft.MultiPointerTensor(children=children)

if not no_wrap:
output = output.wrap()
output = output.wrap(type=tf.constant, value=[])

return output

Expand Down Expand Up @@ -212,7 +212,6 @@ def create_pointer(
ptr_id: (str or int) = None,
garbage_collect_data: bool = True,
shape=None,
object_type=None,
) -> PointerTensor:
"""Creates a pointer to the "self" torch.Tensor object.
Expand All @@ -232,9 +231,6 @@ def create_pointer(
if shape is None:
shape = self.shape

if object_type is None:
object_type = tf.Tensor

ptr = syft.PointerTensor.create_pointer(
self,
location,
Expand All @@ -244,7 +240,6 @@ def create_pointer(
ptr_id,
garbage_collect_data,
shape,
object_type=object_type,
)

return ptr
Expand Down
8 changes: 2 additions & 6 deletions syft_tensorflow/tensor/variable.py
Expand Up @@ -162,7 +162,7 @@ def send(
self.child = ptr
return self
else:
output = ptr if no_wrap else ptr.wrap()
output = ptr if no_wrap else ptr.wrap(type=tf.Variable, initial_value=[])

else:

Expand All @@ -173,7 +173,7 @@ def send(
output = syft.MultiPointerTensor(children=children)

if not no_wrap:
output = output.wrap()
output = output.wrap(type=tf.Variable, initial_value=[])

return output

Expand Down Expand Up @@ -232,9 +232,6 @@ def create_pointer(
if shape is None:
shape = self.shape

if object_type is None:
object_type = tf.Variable

ptr = syft.PointerTensor.create_pointer(
self,
location,
Expand All @@ -244,7 +241,6 @@ def create_pointer(
ptr_id,
garbage_collect_data,
shape,
object_type=object_type,
)

return ptr
Expand Down

0 comments on commit 70c8087

Please sign in to comment.