/
__init__.py
135 lines (110 loc) · 4.34 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
r"""
PySyft is a Python library for secure, private Deep Learning.
PySyft decouples private data from model training, using Federated Learning,
Differential Privacy, and Multi-Party Computation (MPC) within PyTorch.
"""
# We load these modules first so that syft knows which are available
from syft import dependency_check
from syft import frameworks # Triggers registration of any available frameworks
# Major imports
from syft.version import __version__
# This import statement is strictly here to trigger registration of syft
# tensor types inside hook_args.py.
import syft.frameworks.torch.hook.hook_args
import logging
logger = logging.getLogger(__name__)
# The purpose of the following import section is to increase the convenience of using
# PySyft by making it possible to import the most commonly used objects from syft
# directly (i.e., syft.TorchHook or syft.VirtualWorker or syft.LoggingTensor)
# Tensorflow / Keras dependencies
# Import Hooks
# TF Encrypted (TFE) backs the Keras integration and is an optional
# dependency, so the Keras/TFE public names are only imported and exported
# when dependency_check reports it as available.
if dependency_check.tfe_available:
    from syft.frameworks.keras import KerasHook
    from syft.workers.tfe import TFECluster
    from syft.workers.tfe import TFEWorker
    # Seed the module's public-API list; it is extended further below.
    __all__ = ["KerasHook", "TFECluster", "TFEWorker"]
else:
    # Informational only: syft still works without TF Encrypted.
    logger.info("TF Encrypted Keras not available.")
    __all__ = []
# Pytorch dependencies
# Import Hook
from syft.frameworks.torch.hook.hook import TorchHook
# Import grids
from syft.grid.private_grid import PrivateGridNetwork
from syft.grid.public_grid import PublicGridNetwork
# Import sandbox
from syft.sandbox import create_sandbox, hook
# Import federate learning objects
from syft.frameworks.torch.fl import FederatedDataset, FederatedDataLoader, BaseDataset
from syft.federated.train_config import TrainConfig
# Import messaging objects
from syft.messaging.protocol import Protocol
from syft.messaging.plan import Plan
from syft.messaging.plan import func2plan
from syft.messaging.plan import method2plan
from syft.messaging.promise import Promise
# Import Worker Types
from syft.workers.virtual import VirtualWorker
from syft.workers.websocket_client import WebsocketClientWorker
from syft.workers.websocket_server import WebsocketServerWorker
# Import Syft's Public Tensor Types
from syft.frameworks.torch.tensors.decorators.logging import LoggingTensor
from syft.frameworks.torch.tensors.interpreters.additive_shared import AdditiveSharingTensor
from syft.frameworks.torch.tensors.interpreters.crt_precision import CRTPrecisionTensor
from syft.frameworks.torch.tensors.interpreters.autograd import AutogradTensor
from syft.frameworks.torch.tensors.interpreters.precision import FixedPrecisionTensor
from syft.frameworks.torch.tensors.interpreters.numpy import create_numpy_tensor as NumpyTensor
from syft.frameworks.torch.tensors.interpreters.private import PrivateTensor
from syft.frameworks.torch.tensors.interpreters.large_precision import LargePrecisionTensor
from syft.frameworks.torch.tensors.interpreters.promise import PromiseTensor
from syft.generic.pointers.pointer_plan import PointerPlan
from syft.generic.pointers.pointer_protocol import PointerProtocol
from syft.generic.pointers.pointer_tensor import PointerTensor
from syft.generic.pointers.multi_pointer import MultiPointerTensor
# Import serialization tools
from syft import serde
# import functions
from syft.frameworks.torch.functions import combine_pointers
from syft.frameworks.torch.he.paillier import keygen
def pool():
    """Return the shared multiprocessing pool, creating it on first use.

    The pool is cached on the ``syft`` module object itself (as
    ``syft._pool``) so that all callers share a single pool instance for
    the lifetime of the process.
    """
    try:
        return syft._pool
    except AttributeError:
        # First call: lazily create the pool and cache it on the module.
        import multiprocessing

        syft._pool = multiprocessing.Pool()
        return syft._pool
# Extend the public API (what ``from syft import *`` exposes) with the
# names imported above.  Every entry listed here must actually be bound in
# this module: a stale entry makes ``from syft import *`` raise
# AttributeError at import time.
__all__.extend(
    [
        "frameworks",
        "serde",
        "TorchHook",
        "VirtualWorker",
        "WebsocketClientWorker",
        "WebsocketServerWorker",
        "Plan",
        "func2plan",
        "method2plan",
        # NOTE(review): "make_plan" was removed from this list -- it is
        # never imported in this module, and exporting a name that is not
        # bound breaks ``from syft import *`` with AttributeError.
        "LoggingTensor",
        "AdditiveSharingTensor",
        "CRTPrecisionTensor",
        "AutogradTensor",
        "FixedPrecisionTensor",
        "LargePrecisionTensor",
        "PointerTensor",
        "MultiPointerTensor",
        "PrivateGridNetwork",
        "PublicGridNetwork",
        "create_sandbox",
        "hook",
        "combine_pointers",
        "FederatedDataset",
        "FederatedDataLoader",
        "BaseDataset",
        "TrainConfig",
    ]
)
# Module-level state placeholders.  They start as None so downstream code
# can detect whether a hook has been created yet; presumably TorchHook
# populates them when instantiated (local_worker = default worker,
# torch/framework = the hooked framework module) -- confirm against
# syft.frameworks.torch.hook.hook.
local_worker = None
torch = None
framework = None
# Process-wide unique-ID generator.  The globals() guard ensures a single
# IdProvider instance survives even if this module's top level runs more
# than once (e.g. via importlib.reload).
if "ID_PROVIDER" not in globals():
    from syft.generic.id_provider import IdProvider
    ID_PROVIDER = IdProvider()