Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Hook method and property -RSTensor #194

Merged
merged 4 commits into from Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 38 additions & 17 deletions src/sympc/tensor/replicatedshare_tensor.py
Expand Up @@ -5,12 +5,19 @@
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Union

# third party
import torch

from sympc.session import Session

from .tensor import SyMPCTensor

PROPERTIES_NEW_SHARE_TENSOR: Set[str] = {"T"}
METHODS_NEW_SHARE_TENSOR: Set[str] = {"unsqueeze", "view", "t", "sum", "clone"}
PROPERTIES_NEW_RS_TENSOR: Set[str] = {"T"}
METHODS_NEW_RS_TENSOR: Set[str] = {"unsqueeze", "view", "t", "sum", "clone"}
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved


class ReplicatedSharedTensor(metaclass=SyMPCTensor):
Expand All @@ -30,11 +37,15 @@ class ReplicatedSharedTensor(metaclass=SyMPCTensor):
METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
PROPERTIES_FORWARD = {"T"}

def __init__(self, shares=None, session=None):
def __init__(
self,
shares: Optional[Union[float, int, torch.Tensor]] = None,
session: Optional[Session] = None,
):
"""Initialize ShareTensor.

Args:
shares (Optional[List[ShareTensor]]): Shares from which RSTensor is created.
shares (Optional[Union[float,int,torch.Tensor]]): Shares from which RSTensor is created.
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
session (Optional[Session]): The session. Defaults to None.
"""
self.session = session
Expand Down Expand Up @@ -154,18 +165,24 @@ def hook_property(property_name: str) -> Any:
A hooked property
"""

def property_new_share_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
tensor = getattr(_self.tensor, property_name)
res = ReplicatedSharedTensor(session=_self.session)
res.tensor = tensor
def property_new_rs_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
shares = []
nr_parties = _self.session.nr_parties

for i in range(nr_parties - 1): # each party has n-1 shares.
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
tensor = getattr(_self.shares[i], property_name)
shares.append(tensor)

res = ReplicatedSharedTensor(session=_self.session, shares=shares)

return res

def property_getter(_self: "ReplicatedSharedTensor") -> Any:
prop = getattr(_self.tensor, property_name)
prop = getattr(_self.shares[0], property_name)
return prop

if property_name in PROPERTIES_NEW_SHARE_TENSOR:
res = property(property_new_share_tensor_getter, None)
if property_name in PROPERTIES_NEW_RS_TENSOR:
res = property(property_new_rs_tensor_getter, None)
else:
res = property(property_getter, None)

Expand All @@ -192,20 +209,24 @@ def hook_method(method_name: str) -> Callable[..., Any]:
def method_new_rs_tensor(
_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
) -> Any:
method = getattr(_self.tensor, method_name)
tensor = method(*args, **kwargs)
res = ReplicatedSharedTensor(session=_self.session, shares=_self.shares)
res.tensor = tensor
shares = []
nr_parties = _self.session.nr_parties
for i in range(nr_parties - 1): # each party has n-1 shares
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
tensor = getattr(_self.shares[i], method_name)(*args, **kwargs)
shares.append(tensor)

res = ReplicatedSharedTensor(session=_self.session, shares=shares)

return res

def method(
_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
) -> Any:
method = getattr(_self.tensor, method_name)
method = getattr(_self.shares[0], method_name)
res = method(*args, **kwargs)
return res

if method_name in METHODS_NEW_SHARE_TENSOR:
if method_name in METHODS_NEW_RS_TENSOR:
res = method_new_rs_tensor
else:
res = method
Expand Down
39 changes: 38 additions & 1 deletion tests/sympc/tensor/replicatedshare_tensor_test.py
@@ -1,6 +1,43 @@
# third party
import torch

from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import ReplicatedSharedTensor


def test_import_RSTensor():
def test_import_RSTensor() -> None:

ReplicatedSharedTensor()


def test_hook_method(get_clients) -> None:
alice, bob = get_clients(2)
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
session = Session(parties=[alice, bob])
SessionManager.setup_mpc(session)

x = torch.randn(1, 3)
y = torch.randn(1, 3) # noqa: F841
shares = [x, y]

rst = ReplicatedSharedTensor(shares=shares, session=session)

assert rst.numel() == x.numel()
assert (rst.t().shares[0] == x.t()).all()
assert (rst.unsqueeze(dim=0).shares[0] == x.unsqueeze(dim=0)).all()
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
assert (rst.view(3, 1).shares[0] == x.view(3, 1)).all()
assert (rst.sum().shares[0] == x.sum()).all()


def test_hook_property(get_clients) -> None:
alice, bob = get_clients(2)
session = Session(parties=[alice, bob])
SessionManager.setup_mpc(session)

x = torch.randn(1, 3)
y = torch.randn(1, 3) # noqa: F841
shares = [x, y]

rst = ReplicatedSharedTensor(shares=shares, session=session)

assert (rst.T.shares[0] == x.T).all()