In [None]:
# Version range of syft this notebook is written against; also passed to
# sy.requires() after the imports.
SYFT_VERSION = ">=0.8.2.b0,<0.9"
# pip requirement string for the install hint below. The inner double quotes are
# part of the value — presumably so a shell does not interpret the < and >.
package_string = f'"syft{SYFT_VERSION}"'
# Uncomment to install a matching syft into the kernel's environment.
# %pip install {package_string} -q

In [None]:
# third party
import torch
from torch import nn
import torch.nn.functional as F

# syft absolute
import syft as sy

# Check the installed syft against the declared version range.
# NOTE(review): whether a mismatch warns or raises is not visible here — confirm.
sy.requires(SYFT_VERSION)

In [None]:
# Launch a server named "test-datasite-1" through sy.orchestra.
# NOTE(review): reset=True presumably discards state from earlier runs — confirm.
server = sy.orchestra.launch(name="test-datasite-1", dev_mode=True, reset=True)

In [None]:
# Log in to the freshly launched server to obtain a client.
# NOTE(review): these look like default development credentials — confirm they
# are never reused outside a throwaway test server.
datasite_client = server.login(email="info@openmined.org", password="changethis")

In [None]:
# Fix torch's global RNG seed so the random data, the model initialisation and
# the hard-coded values asserted later in the notebook are reproducible.
torch.manual_seed(42)

In [None]:
# Generate random data: a batch of 4 samples, each shaped (28, 28, 1),
# drawn uniformly from [0, 1) by torch.rand.
train_data = torch.rand((4, 28, 28, 1))
# Display the shape as the cell output.
train_data.shape

In [None]:
# Sanity check tied to the seed above: the rounded sum of the data must be 1557.
assert torch.round(train_data.sum()) == 1557

In [None]:
# Wrap the local tensor in a syft ActionObject.
train = sy.ActionObject.from_obj(train_data)

In [None]:
# Inspect the wrapped payload's type, the object's id and its shape.
type(train.syft_action_data), train.id, train.shape

In [None]:
# Send the action object to the datasite through the logged-in client and
# show the type of the object that comes back.
train_datasite_obj = train.send(datasite_client)
type(train_datasite_obj)

In [None]:
# Display the object returned by send().
train_datasite_obj

In [None]:
# The data reachable through the sent object must match the local tensor:
# same rounded sum as the earlier check on train_data.
assert torch.round(train_datasite_obj.syft_action_data.sum()) == 1557

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: 784 inputs -> 128 hidden units -> ``out_dims`` outputs."""

    def __init__(self, out_dims):
        """Create the two linear layers.

        Args:
            out_dims: Number of output features produced by the last layer.
        """
        super().__init__()
        self.out_dims = out_dims
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, out_dims)

    def forward(self, x):
        """Flatten each sample, then apply linear1 -> ReLU -> linear2.

        Args:
            x: Batched input; all dimensions after the first are flattened
                and must total 784 features per sample.

        Returns:
            Tensor of shape ``(x.size(0), out_dims)``.
        """
        # reshape() returns the same view as view() for contiguous input but
        # also accepts non-contiguous tensors (e.g. transposed batches),
        # where view() would raise.
        x = x.reshape(x.size(0), -1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x


# Instantiate the network with 10 outputs; the bare name displays its repr.
model = MLP(out_dims=10)
model

In [None]:
# Take the model's parameters as a state dict so they can be shipped separately.
weights = model.state_dict()

In [None]:
# Confirm the state dict is a dict (or subclass) before wrapping it.
assert isinstance(weights, dict)

In [None]:
# Wrap the state dict in a syft ActionObject, as was done for the data.
w = sy.ActionObject.from_obj(weights)

In [None]:
# Inspect the wrapped payload's type and the object's id.
type(w.syft_action_data), w.id

In [None]:
# Send the wrapped weights to the datasite through the logged-in client.
weight_datasite_obj = w.send(datasite_client)

In [None]:
@sy.syft_function(
    input_policy=sy.ExactMatch(
        weights=weight_datasite_obj.id, data=train_datasite_obj.id
    ),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_mlp(weights, data):
    """Run one forward pass of a two-layer MLP over ``data``.

    Despite the name, nothing is trained: the model is switched to eval mode
    and evaluated under ``torch.no_grad()``.

    Args:
        weights: State dict loaded into the model with ``load_state_dict``;
            its keys must match the ``linear1``/``linear2`` layers below.
        data: Batched input; all dimensions after the first are flattened
            and must total 784 features per sample.

    Returns:
        The model output for ``data``, 10 values per sample.
    """
    # The imports and the model class are declared inside the function body —
    # presumably because the function's source is submitted to the datasite
    # and must be self-contained there (confirm against the syft docs).
    # third party
    import torch
    from torch import nn
    import torch.nn.functional as F

    class MLP(nn.Module):
        """Same architecture as the notebook-level MLP: 784 -> 128 -> out_dims."""

        def __init__(self, out_dims):
            super().__init__()
            self.out_dims = out_dims
            self.linear1 = nn.Linear(784, 128)
            self.linear2 = nn.Linear(128, out_dims)

        def forward(self, x):
            # reshape() matches view() on contiguous input and additionally
            # handles non-contiguous tensors, where view() would raise.
            x = x.reshape(x.size(0), -1)
            x = self.linear1(x)
            x = F.relu(x)
            x = self.linear2(x)
            return x

    # Initialize the model
    model = MLP(out_dims=10)

    # Load weights into the model
    model.load_state_dict(weights)

    # Perform a forward pass
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        output = model(data)

    return output

In [None]:
# Call the decorated function directly with the datasite objects, then fetch
# the concrete result with .get().
# NOTE(review): this happens before any code request is submitted, so it
# presumably executes without datasite approval — confirm where it runs.
pointer = train_mlp(weights=weight_datasite_obj, data=train_datasite_obj)
output = pointer.get()

In [None]:
# The expected sum is only quoted to four decimal places, so compare with a
# matching absolute tolerance; allclose's defaults (rtol=1e-5, atol=1e-8) are
# tighter than the literal's own precision and make the check brittle.
assert torch.allclose(torch.sum(output), torch.tensor(1.3907), atol=1e-4)

In [None]:
# Submit the function to the datasite as a code-execution request; the bare
# name displays the request.
request = datasite_client.code.request_code_execution(train_mlp)
request

In [None]:
# Approve the request.
# NOTE(review): approval comes from the same client that submitted it —
# presumably that account has sufficient permissions; confirm.
request.approve()

In [None]:
# Drop the client's cached API object and access `.api` again.
# NOTE(review): presumably this forces the API to be rebuilt so the newly
# approved `train_mlp` endpoint shows up — confirm, and prefer a public
# refresh method over the private `_api` attribute if one exists.
datasite_client._api = None
_ = datasite_client.api

In [None]:
# Invoke the approved code through the client API, passing the inputs by id.
# NOTE(review): these are the ids of the local objects `w` and `train`, while
# the input policy was built from the sent objects' ids — this presumes send()
# preserves the id; confirm.
result_ptr = datasite_client.code.train_mlp(weights=w.id, data=train.id)

In [None]:
# Retrieve the concrete result behind the returned pointer.
result = result_ptr.get()

In [None]:
# The expected sum is only quoted to four decimal places, so compare with a
# matching absolute tolerance; allclose's defaults (rtol=1e-5, atol=1e-8) are
# tighter than the literal's own precision and make the check brittle.
assert torch.allclose(torch.sum(result), torch.tensor(1.3907), atol=1e-4)

In [None]:
# Tear the server down, but only when its server type is "python".
# NOTE(review): presumably that denotes a server started in-process by this
# notebook, as opposed to an externally managed deployment — confirm.
if server.server_type.value == "python":
    server.land()