Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Hook method and property -RSTensor #194

Merged
merged 4 commits into from Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 40 additions & 19 deletions src/sympc/tensor/replicatedshare_tensor.py
Expand Up @@ -5,20 +5,28 @@
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Union

# third party
import torch

from sympc.session import Session

from .tensor import SyMPCTensor

PROPERTIES_NEW_SHARE_TENSOR: Set[str] = {"T"}
METHODS_NEW_SHARE_TENSOR: Set[str] = {"unsqueeze", "view", "t", "sum", "clone"}
PROPERTIES_NEW_RS_TENSOR: Set[str] = {"T"}
METHODS_NEW_RS_TENSOR: Set[str] = {"unsqueeze", "view", "t", "sum", "clone"}
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved


class ReplicatedSharedTensor(metaclass=SyMPCTensor):
"""RSTensor is used when a party holds more than a single share,required by various protocols.

Arguments:
session (Session): the session
shares: The shares held by the party
shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
from which RSTensor is created.
session (Optional[Session]): The session.

Attributes:
shares: The shares held by the party
Expand All @@ -30,11 +38,16 @@ class ReplicatedSharedTensor(metaclass=SyMPCTensor):
METHODS_FORWARD = {"numel", "t", "unsqueeze", "view", "sum", "clone"}
PROPERTIES_FORWARD = {"T"}

def __init__(self, shares=None, session=None):
def __init__(
self,
shares: Optional[List[Union[float, int, torch.Tensor]]] = None,
session: Optional[Session] = None,
):
"""Initialize ShareTensor.

Args:
shares (Optional[List[ShareTensor]]): Shares from which RSTensor is created.
shares (Optional[List[Union[float, int, torch.Tensor]]]): Shares list
from which RSTensor is created.
session (Optional[Session]): The session. Defaults to None.
"""
self.session = session
Expand Down Expand Up @@ -154,18 +167,23 @@ def hook_property(property_name: str) -> Any:
A hooked property
"""

def property_new_share_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
tensor = getattr(_self.tensor, property_name)
res = ReplicatedSharedTensor(session=_self.session)
res.tensor = tensor
def property_new_rs_tensor_getter(_self: "ReplicatedSharedTensor") -> Any:
shares = []

for i in range(len(_self.shares)):
tensor = getattr(_self.shares[i], property_name)
shares.append(tensor)

res = ReplicatedSharedTensor(session=_self.session, shares=shares)

return res

def property_getter(_self: "ReplicatedSharedTensor") -> Any:
prop = getattr(_self.tensor, property_name)
prop = getattr(_self.shares[0], property_name)
return prop

if property_name in PROPERTIES_NEW_SHARE_TENSOR:
res = property(property_new_share_tensor_getter, None)
if property_name in PROPERTIES_NEW_RS_TENSOR:
res = property(property_new_rs_tensor_getter, None)
else:
res = property(property_getter, None)

Expand All @@ -192,20 +210,23 @@ def hook_method(method_name: str) -> Callable[..., Any]:
def method_new_rs_tensor(
_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
) -> Any:
method = getattr(_self.tensor, method_name)
tensor = method(*args, **kwargs)
res = ReplicatedSharedTensor(session=_self.session, shares=_self.shares)
res.tensor = tensor
shares = []
for i in range(len(_self.shares)):
tensor = getattr(_self.shares[i], method_name)(*args, **kwargs)
shares.append(tensor)

res = ReplicatedSharedTensor(session=_self.session, shares=shares)

return res

def method(
_self: "ReplicatedSharedTensor", *args: List[Any], **kwargs: Dict[Any, Any]
) -> Any:
method = getattr(_self.tensor, method_name)
method = getattr(_self.shares[0], method_name)
res = method(*args, **kwargs)
return res

if method_name in METHODS_NEW_SHARE_TENSOR:
if method_name in METHODS_NEW_RS_TENSOR:
res = method_new_rs_tensor
else:
res = method
Expand Down
46 changes: 45 additions & 1 deletion tests/sympc/tensor/replicatedshare_tensor_test.py
@@ -1,6 +1,50 @@
# third party
import torch

from sympc.session import Session
from sympc.session import SessionManager
from sympc.tensor import ReplicatedSharedTensor


def test_import_RSTensor():
def test_import_RSTensor() -> None:

ReplicatedSharedTensor()


def test_hook_method(get_clients) -> None:
clients = get_clients(3)
session = Session(parties=clients)
SessionManager.setup_mpc(session)

x = torch.randn(1, 3)
y = torch.randn(1, 3)
shares = [x, y]

rst = ReplicatedSharedTensor(shares=shares, session=session)

assert rst.numel() == x.numel()
assert (rst.t().shares[0] == x.t()).all()
assert (rst.unsqueeze(dim=0).shares[0] == x.unsqueeze(dim=0)).all()
rasswanth-s marked this conversation as resolved.
Show resolved Hide resolved
assert (rst.view(3, 1).shares[0] == x.view(3, 1)).all()
assert (rst.sum().shares[0] == x.sum()).all()

assert rst.numel() == y.numel()
assert (rst.t().shares[1] == y.t()).all()
assert (rst.unsqueeze(dim=0).shares[1] == y.unsqueeze(dim=0)).all()
assert (rst.view(3, 1).shares[1] == y.view(3, 1)).all()
assert (rst.sum().shares[1] == y.sum()).all()


def test_hook_property(get_clients) -> None:
clients = get_clients(3)
session = Session(parties=clients)
SessionManager.setup_mpc(session)

x = torch.randn(1, 3)
y = torch.randn(1, 3)
shares = [x, y]

rst = ReplicatedSharedTensor(shares=shares, session=session)

assert (rst.T.shares[0] == x.T).all()
assert (rst.T.shares[1] == y.T).all()