cat concatenates ReplicatedSharedTensor (#326)
* `cat` concatenates `ReplicatedSharedTensor`

- Refers to issue #324.
- `cat` can concatenate `ShareTensor`s, but it does not work for `ReplicatedSharedTensor`.
- Added `cat_replicatedShare_tensor` in *api.py* and *static.py* to concatenate `ReplicatedSharedTensor`s (a rough usage sketch follows below).
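
As a rough usage sketch (not part of this commit): assuming three syft parties are already available in a `parties` list and that the current `Session`, `SessionManager`, `MPCTensor`, and `Falcon` signatures hold, concatenating Falcon-backed tensors would look roughly like this:

```python
# Hypothetical sketch -- `parties` (three syft clients) is assumed to exist,
# and the Session/MPCTensor/Falcon signatures are assumed from the current API.
import torch

from sympc.protocol import Falcon
from sympc.session import Session, SessionManager
from sympc.tensor import MPCTensor
from sympc.tensor.static import cat

session = Session(parties=parties, protocol=Falcon("semi-honest"))
SessionManager.setup_mpc(session)

x = MPCTensor(secret=torch.tensor([[1.0, 2.0]]), session=session)
y = MPCTensor(secret=torch.tensor([[3.0, 4.0]]), session=session)

# Each party now concatenates its ReplicatedSharedTensor shares via
# cat_replicatedShare_tensor instead of cat_share_tensor.
z = cat([x, y], dim=0)
print(z.reconstruct())  # expected: tensor([[1., 2.], [3., 4.]])
```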

* Saving a variable and formatting

Format
The following commands have been run:
- `python -m black tensor/static.py`
- `python -m black api.py`
- `isort tensor/static.py`

Save number of replicated shares as a variable
- Q: Could you save `len(shares[0].shares)` in a variable and then use it here and on line 140? [x] *done*

Amended docstring
- `shares` is a tuple of `ReplicatedSharedTensor`s.
- The stray blank line in `cat_replicatedShare_tensor` was removed.

* For Python 3.7, which doesn't support `math.prod`

- Python 3.7 does not provide `math.prod`; it is available only in Python 3.8+ (see the fallback sketch below).
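
Not part of this commit (which drops 3.7 from CI instead), but for reference, a 3.7-compatible stand-in for `math.prod` would be something like:

```python
# Sketch of a Python 3.7 fallback for math.prod (this commit removes 3.7
# support instead of shipping such a shim).
import operator
from functools import reduce


def prod(iterable, start=1):
    """Multiply all items of `iterable` together, like math.prod in 3.8+."""
    return reduce(operator.mul, iterable, start)


# falcon.py could then use `d = m * prod(c)` in place of `d = m * math.prod(c)`.
```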

* Remove Python 3.7

* Remove Python 3.7 from the tutorials workflow

Co-authored-by: George Muraru <murarugeorgec@gmail.com>
Timo9Madrid7 and gmuraru committed May 15, 2022
1 parent 41da2eb commit 91cf991
Showing 5 changed files with 47 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yml
@@ -22,7 +22,7 @@ jobs:
strategy:
max-parallel: 3
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8, 3.9]

steps:
- uses: actions/checkout@v2
10 changes: 5 additions & 5 deletions CONTRIBUTING.md
@@ -65,7 +65,7 @@ If you are new to the project and want to get into the code, we recommend pickin
Before you get started you will need a few things installed depending on your operating system.

- OS Package Manager
- Python 3.7+
- Python 3.8+
- git

### OSes
@@ -117,7 +117,7 @@ $ brew install git

## Python Versions

This project supports Python 3.7+, however, if you are contributing it can help to be able to switch between python versions to fix issues or bugs that relate to a specific python version. Depending on your operating system there are a number of ways to install different versions of python however one of the easiest is with the `pyenv` tool. Additionally, as we will be frequently be installing and changing python packages for this project we should isolate it from your system python and other projects you have using a virtualenv.
This project supports Python 3.8+, however, if you are contributing it can help to be able to switch between python versions to fix issues or bugs that relate to a specific python version. Depending on your operating system there are a number of ways to install different versions of python however one of the easiest is with the `pyenv` tool. Additionally, as we will be frequently be installing and changing python packages for this project we should isolate it from your system python and other projects you have using a virtualenv.

### MacOS

@@ -152,10 +152,10 @@ $ pyenv install --list | grep 3.9
3.9.4
```

Wow, there are lots of options, lets install 3.7.
Wow, there are lots of options, lets install 3.8.

```
$ pyenv install 3.7.9
$ pyenv install 3.8.0
```

Now, lets see what versions are installed:
@@ -461,7 +461,7 @@ $ pydocstyle .

### Imports Formatting

We use isort to automatically format the python imports.
We use isort to automatically format the python imports.
Run isort manually like this:

```
4 changes: 4 additions & 0 deletions src/sympc/api.py
@@ -262,6 +262,10 @@
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor.repeat",
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor",
),
(
"sympc.tensor.static.cat_replicatedShare_tensor",
"sympc.tensor.replicatedshare_tensor.ReplicatedSharedTensor",
),
]

allowed_external_attrs = [
2 changes: 1 addition & 1 deletion src/sympc/protocol/falcon/falcon.py
@@ -545,7 +545,7 @@ def private_compare(x: List[MPCTensor], r: torch.Tensor) -> MPCTensor:
c[i] = u[i] + 1 + w
w += x[i] ^ r_i

d = m * math.prod(c)
d = m * math.prod(c)

d_val = d.reconstruct(decode=False) # plaintext d.
d_val[d_val != 0] = 1 # making all non zero values as 1.
38 changes: 36 additions & 2 deletions src/sympc/tensor/static.py
@@ -15,8 +15,10 @@
import numpy as np
import torch

import sympc.protocol as protocol
from sympc.session import get_session
from sympc.tensor.mpc_tensor import MPCTensor
from sympc.tensor.replicatedshare_tensor import ReplicatedSharedTensor
from sympc.tensor.share_tensor import ShareTensor
from sympc.utils import parallel_execution

@@ -90,8 +92,12 @@ def cat(tensors: List, dim: int = 0) -> MPCTensor:
)
)

stack_shares = parallel_execution(cat_share_tensor, session.parties)(args)
from sympc.tensor import MPCTensor
if isinstance(session.protocol, protocol.FSS):
stack_shares = parallel_execution(cat_share_tensor, session.parties)(args)
elif isinstance(session.protocol, protocol.Falcon):
stack_shares = parallel_execution(cat_replicatedShare_tensor, session.parties)(
args
)

expected_shape = torch.cat(
[torch.empty(each_tensor.shape) for each_tensor in tensors], dim=dim
@@ -118,6 +124,34 @@ def cat_share_tensor(session_uuid_str: str, *shares: Tuple[ShareTensor]) -> Shar
return result


def cat_replicatedShare_tensor(
session_uuid_str: str, *shares: Tuple[ReplicatedSharedTensor]
) -> ReplicatedSharedTensor:
"""Helper method that performs torch.cat on the replicated shares of the Tensors.
Args:
session_uuid_str (str): UUID to identify the session on each party side.
shares (Tuple[ReplicatedSharedTensor]): Replicated shares of the tensors to be concatenated.
Returns:
ReplicatedSharedTensor: Respective replicated shares after concatenation
"""
session = get_session(session_uuid_str)
result = ReplicatedSharedTensor(
session_uuid=UUID(session_uuid_str), config=session.config
)

num_shares = len(shares[0].shares)

cat_result = [torch.tensor([]).type(torch.LongTensor) for _ in range(num_shares)]
for share in shares:
for i in range(num_shares):
cat_result[i] = torch.cat([cat_result[i], share.shares[i]])

result.shares = cat_result
return result


def helper_argmax(
x: MPCTensor,
dim: Optional[Union[int, Tuple[int]]] = None,
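
For reference, the share-wise concatenation performed by the new helper can be illustrated with plain `torch` tensors standing in for the replicated shares (hypothetical values, not taken from the test suite):

```python
# Plain-torch illustration of cat_replicatedShare_tensor's loop: share i of the
# result is the concatenation of share i from every input tensor.
import torch

# Two tensors' replicated shares; in Falcon each party holds 2 of the 3 shares.
a_shares = [torch.tensor([1, 2]), torch.tensor([3, 4])]
b_shares = [torch.tensor([5, 6]), torch.tensor([7, 8])]

num_shares = len(a_shares)
cat_result = [torch.tensor([], dtype=torch.long) for _ in range(num_shares)]
for shares in (a_shares, b_shares):
    for i in range(num_shares):
        cat_result[i] = torch.cat([cat_result[i], shares[i]])

print(cat_result)  # [tensor([1, 2, 5, 6]), tensor([3, 4, 7, 8])]
```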
