In [1]:
# Import the necessary libraries
import syft as sy
sy.requires(">=0.8-beta")
import jax
import jax.numpy as jnp



✅ The installed version of syft==0.8.1b1 matches the requirement >=0.8b0


In [2]:
# Launch an in-process test domain, then register and log in as a data scientist.
# NOTE: these are throwaway credentials for a local demo node only -- never put
# real passwords in a notebook (use getpass / environment variables instead).
node = sy.orchestra.launch(name="test-domain-1")
guest_domain_client = node.client

DS_EMAIL = "jane@caltech.edu"
DS_PASSWORD = "abc123"

guest_domain_client.register(
    name="Jane Doe",
    email=DS_EMAIL,
    password=DS_PASSWORD,
    institution="Caltech",
    website="https://www.caltech.edu/",
)
# Keep the client that `login` returns and use it in later cells, rather than
# relying on the pre-login guest client having been updated in place.
guest_domain_client = guest_domain_client.login(email=DS_EMAIL, password=DS_PASSWORD)
guest_domain_client

SQLite Store Path:
!open file:///tmp/7bca415d13ed1ec841f0d0aede098dbb.sqlite



<SyftClient - test-domain-1 <7bca415d13ed1ec841f0d0aede098dbb>: PythonConnection>

In [8]:
# Inspect the data available to this client: list every dataset on the domain,
# then display the second asset of the first dataset (the MNIST test split here).
results = guest_domain_client.api.services.dataset.get_all()
first_dataset = results[0]
first_dataset.assets[1]

```python
Asset: test_data
Pointer Id: dd76b11202f84aa0a5ef109f2d147711
Description: test data for MNIST
Total Data Subjects: 1
Shape: (10000, 28, 28)
Contributors: 0

```

In [None]:
# Define the training code that will be submitted to, and executed on, the domain node.
# ATTENTION: ALL LIBRARIES AND HELPERS USED MUST BE IMPORTED/DEFINED INSIDE THE FUNCTION
# BODY -- only this function's source is shipped to the node for execution.

@sy.syft_function(input_policy=sy.ExactMatch(),
                  output_policy=sy.SingleExecutionExactOutput())
def training_loop(train_dataset, test_dataset):
    """Train a 784-512-512-10 ReLU MLP with Adam and return its training metrics.

    Args:
        train_dataset: iterable of ``(data, target)`` batches used for training
            and for measuring train accuracy. Assumed (TODO confirm against the
            actual assets) that ``data`` holds 28x28 images convertible to an
            array and ``target`` holds integer class labels in ``[0, 10)``.
        test_dataset: iterable of ``(data, target)`` batches in the same format,
            used only for evaluation.

    Returns:
        Tuple ``(train_loss, train_log, test_log)``: per-batch training losses,
        and train / test accuracies recorded once before training and once
        after each epoch. All values are plain Python floats.

    Raises:
        ValueError: if either dataset yields no samples.
    """
    import time

    import jax.numpy as np
    from jax import jit, random, value_and_grad, vmap
    from jax.scipy.special import logsumexp

    # `optimizers` moved from `jax.experimental` to `jax.example_libraries`
    # and was later removed from `jax.experimental`; prefer the new location.
    try:
        from jax.example_libraries import optimizers
    except ImportError:  # older jax releases
        from jax.experimental import optimizers

    key = random.PRNGKey(1)  # fixed seed -> reproducible weight initialisation

    layer_sizes = [784, 512, 512, 10]
    step_size = 1e-3
    num_epochs = 10
    num_classes = 10
    image_pixels = 28 * 28

    def ReLU(x):
        """Rectified Linear Unit (ReLU) activation function."""
        return np.maximum(0, x)

    def relu_layer(params, x):
        """Apply one dense layer followed by ReLU to a single sample."""
        return ReLU(np.dot(params[0], x) + params[1])

    def initialize_mlp(sizes, key):
        """Return a (weights, bias) tuple for each consecutive pair of layer sizes."""
        keys = random.split(key, len(sizes))

        def initialize_layer(m, n, key, scale=1e-2):
            """Gaussian-initialise an (n, m) weight matrix and an (n,) bias."""
            w_key, b_key = random.split(key)
            return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

        return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    params = initialize_mlp(layer_sizes, key)

    def forward_pass(params, in_array):
        """Return log-probabilities over classes for a single flattened sample."""
        activations = in_array
        # ReLU hidden layers
        for w, b in params[:-1]:
            activations = relu_layer([w, b], activations)
        # Final affine layer to logits, then log-softmax via logsumexp.
        final_w, final_b = params[-1]
        logits = np.dot(final_w, activations) + final_b
        return logits - logsumexp(logits)

    # Batched forward pass: params are shared, samples are mapped over axis 0.
    batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)

    opt_init, opt_update, get_params = optimizers.adam(step_size)
    opt_state = opt_init(params)

    def one_hot(x, k, dtype=np.float32):
        """Create a one-hot encoding of the 1-D integer labels x with k classes."""
        return np.array(x[:, None] == np.arange(k), dtype)

    def flatten_images(data):
        """Flatten a batch of images into (batch, 784) rows for the MLP.

        Works for numpy arrays and array-likes alike (unlike torch's `.size(0)`).
        NOTE(review): pixels are not rescaled; if they are 0-255 integers,
        consider dividing by 255 -- confirm against the asset's dtype.
        """
        return np.asarray(data).reshape(-1, image_pixels)

    def loss(params, in_arrays, targets):
        """Multi-class cross-entropy, summed over the batch."""
        preds = batch_forward(params, in_arrays)
        return -np.sum(preds * targets)

    def accuracy(params, dataset):
        """Fraction of samples in `dataset` whose predicted class equals the label."""
        n_correct = 0
        n_total = 0
        for data, target in dataset:
            images = flatten_images(data)
            target_class = np.asarray(target)
            predicted_class = np.argmax(batch_forward(params, images), axis=1)
            n_correct += int(np.sum(predicted_class == target_class))
            n_total += images.shape[0]
        if n_total == 0:
            raise ValueError("accuracy() received a dataset with no samples")
        return n_correct / n_total

    @jit
    def update(params, x, y, opt_state, step):
        """Take one Adam step on a batch; return (new params, new state, batch loss)."""
        value, grads = value_and_grad(loss)(params, x, y)
        # Adam's bias correction depends on the step index, so it must advance
        # on every call rather than stay fixed.
        opt_state = opt_update(step, grads, opt_state)
        return get_params(opt_state), opt_state, value

    def run_mnist_training_loop(num_epochs, opt_state, net_type="MLP"):
        """Train for `num_epochs` epochs, logging per-batch loss and per-epoch accuracy.

        Returns (train losses, train accuracies, test accuracies); the accuracy
        lists start with the accuracy of the randomly initialised network.
        """
        if net_type not in ("MLP", "CNN"):
            raise ValueError(f"Unknown net_type {net_type!r}; expected 'MLP' or 'CNN'")

        log_acc_train, log_acc_test, train_loss = [], [], []
        params = get_params(opt_state)

        # Baseline accuracy before any training.
        log_acc_train.append(accuracy(params, train_dataset))
        log_acc_test.append(accuracy(params, test_dataset))

        step = 0
        for epoch in range(num_epochs):
            start_time = time.time()
            for data, target in train_dataset:
                if net_type == "MLP":
                    x = flatten_images(data)
                else:
                    # CNN inputs keep their spatial layout. NOTE: no CNN model is
                    # defined in this function; only "MLP" is usable as written.
                    x = np.asarray(data)
                y = one_hot(np.asarray(target), num_classes)
                params, opt_state, batch_loss = update(params, x, y, opt_state, step)
                step += 1
                # Plain floats keep the returned output easily serialisable.
                train_loss.append(float(batch_loss))
            epoch_time = time.time() - start_time

            train_acc = accuracy(params, train_dataset)
            test_acc = accuracy(params, test_dataset)
            log_acc_train.append(train_acc)
            log_acc_test.append(test_acc)
            print("Epoch {} | T: {:0.2f} | Train A: {:0.3f} | Test A: {:0.3f}".format(
                epoch + 1, epoch_time, train_acc, test_acc))

        return train_loss, log_acc_train, log_acc_test

    train_loss, train_log, test_log = run_mnist_training_loop(num_epochs,
                                                              opt_state,
                                                              net_type="MLP")

    # This runs on the domain node, where a plot would never reach the data
    # scientist (and a local `helpers` module is not available). Return the
    # metrics instead and plot them client-side, e.g. with
    # plot_mnist_performance(train_loss, train_log, test_log, "MNIST MLP Performance").
    return train_loss, train_log, test_log