In [None]:
# third party
# syft absolute
import syft as sy
import torch
from mnist_dataset import mnist

In [None]:
# Launch the local datasite in dev mode and log in as the data scientist.
# NOTE(review): hard-coded credentials are presumably acceptable only because this
# is a dev_mode demo account — confirm before reusing outside local testing.
server = sy.orchestra.launch(
    name="mnist-torch-datasite",
    dev_mode=True,
)
ds_client = server.login(
    email="sheldon@caltech.edu",
    password="changethis",
)

## After the data owner (DO) has run the code and deposited the results, the data scientist (DS) downloads them

In [None]:
# Fetch the datasets visible to the DS; the first one must carry exactly two assets
# (the training images and the training labels used below).
datasets = ds_client.datasets.get_all()
assert datasets, "no datasets found on the datasite"
assets = datasets[0].assets
assert len(assets) == 2, f"expected 2 assets (images, labels), got {len(assets)}"

In [None]:
# The first asset holds the training images, the second the training labels.
training_images, training_labels = assets[0], assets[1]

In [None]:
# Display the user code visible to this client; `mnist_3_linear_layers_torch` below is accessed through it.
ds_client.code

In [None]:
# Invoke the training function (run by the DO, per the header above) on the two assets.
result = ds_client.code.mnist_3_linear_layers_torch(
    mnist_images=training_images,
    mnist_labels=training_labels,
)

In [None]:
# Download the deposited output: the training accuracies and the trained parameters
# (the latter are later loaded into the model via load_state_dict).
train_accs, params = result.get_from(ds_client)

In [None]:
# Guard the expected return type, then display the list of training accuracies.
assert isinstance(train_accs, list)
train_accs

In [None]:
# `params` is passed to load_state_dict inside accuracy(), so it must be a dict.
assert isinstance(params, dict)
params

## With the trained weights, the DS can run inference on its own MNIST test dataset

In [None]:
# Only the last two return values (the test split) are needed; the first two
# (presumably the training split) are discarded.
_, _, test_images, test_labels = mnist()

In [None]:
# Expect 10,000 flattened images of 784 features each and one-hot labels over 10 classes.
assert test_images.shape == (10000, 784)
assert test_labels.shape == (10000, 10)

#### Define the neural network and the accuracy function

In [None]:
# third party
from torch import nn


class MLP(nn.Module):
    """Fully connected MNIST classifier: 784 -> 1024 -> 1024 -> 10.

    ``forward`` flattens each input to 784 features, applies two ReLU hidden
    layers, and returns per-class log-probabilities (log-softmax over dim 1).
    The submodule names ``fc1``/``fc2``/``fc3`` define the ``state_dict`` keys,
    so they must match the keys of any parameters loaded into this model.
    """

    def __init__(self) -> None:
        super().__init__()
        in_features, hidden_features, num_classes = 784, 1024, 10
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        flat = x.view(-1, 784)
        hidden = torch.relu(self.fc1(flat))
        hidden = torch.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return torch.log_softmax(logits, dim=1)


# Instantiate the network; as the cell's last expression, `model` displays its architecture.
model = MLP()

model

In [None]:
def accuracy(model, batch, params=None):
    """Compute the classification accuracy of ``model`` on a single batch.

    Args:
        model: A ``torch.nn.Module`` whose output has class scores along dim 1.
        batch: An ``(inputs, targets)`` pair of arrays or tensors. ``targets`` may
            be one-hot encoded (2-D) or a 1-D vector of class indices.
        params: Optional state dict. When given, it is loaded into ``model``
            *in place* before evaluating, so the caller's model keeps these weights.

    Returns:
        The fraction of correctly classified samples as a Python ``float``.
    """
    if params is not None:
        model.load_state_dict(params)

    inputs, targets = batch
    # as_tensor avoids an extra copy (and torch's copy-construct warning) when the
    # data is already a tensor, and shares memory with NumPy arrays.
    inputs = torch.as_tensor(inputs)
    targets = torch.as_tensor(targets)
    # One-hot targets are reduced to class indices; index vectors are used as-is.
    target_class = targets.argmax(dim=1) if targets.ndim > 1 else targets

    # Evaluate in inference mode (matters for dropout / batch-norm layers), then
    # restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(inputs)
    finally:
        model.train(was_training)

    predicted_class = outputs.argmax(dim=1)
    correct = (predicted_class == target_class).float()
    return correct.mean().item()  # Python scalar

#### Test inference using random weights

In [None]:
# Baseline on a freshly initialized MLP: accuracy(..., params) below loads the trained
# weights into `model` in place, so reusing `model` here would not be idempotent on re-run.
random_test_acc = accuracy(MLP(), (test_images, test_labels))
random_test_acc

#### Test inference using the trained weights received from the DO

In [None]:
# accuracy() loads `params` into `model` in place, then evaluates it on the test split.
test_acc = accuracy(model, (test_images, test_labels), params)

In [None]:
# The trained weights should clearly beat the ~10% chance level of a 10-class model.
assert test_acc * 100 > 70, f"test accuracy {test_acc:.2%} is not above 70%"
test_acc