In [None]:
from pprint import pprint

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import syft as sy
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax

from mnist_dataset import mnist

## 1. The DS logs in to the domain with the credentials created by the DO

In [None]:
# Launch the node named "dk-domain" in dev mode and log in as the data scientist (DS).
# NOTE(review): the email/password are hardcoded tutorial credentials; do not
# copy this pattern for real accounts (read them from the environment instead).
node = sy.orchestra.launch(name="dk-domain", dev_mode=True)
ds_client = node.login(email="sheldon@caltech.edu", password="changethis")

### Inspect the datasets on the domain

In [None]:
# Ask the domain for all datasets through the DS client and display the listing.
datasets = ds_client.datasets.get_all()
datasets

In [None]:
# Assets of the first dataset in the listing.
assets = datasets[0].assets
assets

In [None]:
# First asset of the dataset; used below as the training images.
training_images = assets[0]
training_images

In [None]:
# Second asset of the dataset; used below as the training labels.
training_labels = assets[1]
training_labels

#### The DS cannot access the real data

In [None]:
# Display what the DS receives when asking for the real (non-mock) data of the asset.
training_images.data

In [None]:
# The real data must not be readable from the DS side.
# NOTE(review): `== None` is kept instead of `is None` in case `.data` returns a
# wrapper object with custom equality rather than a bare None -- confirm before changing.
assert training_images.data == None

#### The DS can only access the mock data, which is some random noise

In [None]:
# The mock is the stand-in data the DS is allowed to read locally.
mock_images = training_images.mock

# Render the first mock sample as a 28x28 image.
# (`plt` and `np` are imported in the notebook's top import cell.)
plt.imshow(np.reshape(mock_images[0], (28, 28)))

#### We need the pointers to the mock data to construct a `syft` function (later in the notebook)

In [None]:
# Pointer for the images asset; passed to the input policy of the syft function further down.
mock_images_ptr = training_images.pointer
mock_images_ptr

In [None]:
# Show the concrete class of the pointer object.
type(mock_images_ptr)

In [None]:
# Same pair for the labels asset: the locally readable mock and the pointer
# that is later passed to the input policy of the syft function.
mock_labels = training_labels.mock
mock_labels_ptr = training_labels.pointer
mock_labels_ptr

## 2. The DS prepares the training code and experiments on the mock data

In [None]:
def mnist_3_linear_layers(mnist_images, mnist_labels):
    """Train a dense 1024-1024-10 classifier and track its training accuracy.

    Every dependency is imported inside the body so the function is
    self-contained.

    Args:
        mnist_images: array of samples indexed along axis 0; the network is
            initialised for inputs of width 784.
        mnist_labels: array of per-sample targets with classes along axis 1
            (presumably one-hot; the argmax over axis 1 is used as the true
            class).

    Returns:
        A tuple ``(train_accs, params)``: the list of training-set accuracies
        (one entry per epoch) and the network parameters after the last epoch.
    """
    # dependencies (kept local to the function)
    import itertools
    import time

    import jax
    import jax.numpy as jnp
    import numpy.random as npr
    from jax import grad, jit, random
    from jax.example_libraries import optimizers, stax
    from jax.example_libraries.stax import Dense, LogSoftmax, Relu

    # hyper-parameters
    num_epochs = 10
    batch_size = 4
    step_size = 0.001
    momentum_mass = 0.9

    # number of batches per epoch, counting a trailing partial batch (ceil division)
    num_train = mnist_images.shape[0]
    num_batches = (num_train + batch_size - 1) // batch_size

    # network: two hidden ReLU layers followed by a log-softmax output layer
    layers = (Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax)
    init_random_params, predict = stax.serial(*layers)
    _, init_params = init_random_params(random.PRNGKey(0), (-1, 784))

    # momentum optimizer and its initial state
    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
    opt_state = opt_init(init_params)

    def loss(params, batch):
        """Negative mean over the batch of sum(prediction * target) along axis 1."""
        inputs, targets = batch
        log_probs = predict(params, inputs)
        return -jnp.mean(jnp.sum(log_probs * targets, axis=1))

    def accuracy(params, batch):
        """Fraction of samples whose arg-max prediction equals the arg-max target."""
        inputs, targets = batch
        predicted_class = jnp.argmax(predict(params, inputs), axis=1)
        target_class = jnp.argmax(targets, axis=1)
        return jnp.mean(predicted_class == target_class)

    @jit
    def update(i, opt_state, batch):
        """Apply one optimizer step for `batch` at step index `i`."""
        gradients = grad(loss)(get_params(opt_state), batch)
        return opt_update(i, gradients, opt_state)

    def data_stream():
        """Yield (images, labels) mini-batches forever, reshuffling on every pass."""
        shuffler = npr.RandomState(0)
        while True:
            order = shuffler.permutation(num_train)
            for start in range(0, num_train, batch_size):
                picked = order[start:start + batch_size]
                yield mnist_images[picked], mnist_labels[picked]

    batches = data_stream()
    itercount = itertools.count()
    train_accs = []
    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        # one pass over the data = exactly `num_batches` items from the endless stream
        for batch in itertools.islice(batches, num_batches):
            opt_state = update(next(itercount), opt_state, batch)
        epoch_time = time.time() - start_time
        params = get_params(opt_state)
        train_acc = accuracy(params, (mnist_images, mnist_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        train_accs.append(train_acc)

    return train_accs, params

In [None]:
# Dry run of the training function on the mock data to check that the code executes end to end.
train_accs, params = mnist_3_linear_layers(mnist_images=mock_images, mnist_labels=mock_labels)

In [None]:
# Per-epoch training accuracies from the mock-data run.
train_accs

In [None]:
# shape of the model's parameters
# `jax.tree_map` is a deprecated alias that newer JAX releases remove;
# `jax.tree_util.tree_map` is the long-standing equivalent.
jax.tree_util.tree_map(lambda x: x.shape, params)

## 3. Now that the code works on mock data, the DS submits the code request for execution to the DO

#### First the DS wraps the training function with the `@sy.syft_function` decorator

In [None]:
@sy.syft_function(input_policy=sy.ExactMatch(mnist_images=mock_images_ptr, mnist_labels=mock_labels_ptr),
                  output_policy=sy.SingleExecutionExactOutput())
def mnist_3_linear_layers(mnist_images, mnist_labels):
    """Train a dense 1024-1024-10 classifier and track its training accuracy.

    Every dependency is imported inside the body so the function is
    self-contained.

    Args:
        mnist_images: array of samples indexed along axis 0; the network is
            initialised for inputs of width 784.
        mnist_labels: array of per-sample targets with classes along axis 1
            (presumably one-hot; the argmax over axis 1 is used as the true
            class).

    Returns:
        A tuple ``(train_accs, params)``: the list of training-set accuracies
        (one entry per epoch) and the network parameters after the last epoch.
    """
    # dependencies (kept local to the function)
    import itertools
    import time

    import jax
    import jax.numpy as jnp
    import numpy.random as npr
    from jax import grad, jit, random
    from jax.example_libraries import optimizers, stax
    from jax.example_libraries.stax import Dense, LogSoftmax, Relu

    # hyper-parameters
    num_epochs = 10
    batch_size = 4
    step_size = 0.001
    momentum_mass = 0.9

    # number of batches per epoch, counting a trailing partial batch (ceil division)
    num_train = mnist_images.shape[0]
    num_batches = (num_train + batch_size - 1) // batch_size

    # network: two hidden ReLU layers followed by a log-softmax output layer
    layers = (Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax)
    init_random_params, predict = stax.serial(*layers)
    _, init_params = init_random_params(random.PRNGKey(0), (-1, 784))

    # momentum optimizer and its initial state
    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
    opt_state = opt_init(init_params)

    def loss(params, batch):
        """Negative mean over the batch of sum(prediction * target) along axis 1."""
        inputs, targets = batch
        log_probs = predict(params, inputs)
        return -jnp.mean(jnp.sum(log_probs * targets, axis=1))

    def accuracy(params, batch):
        """Fraction of samples whose arg-max prediction equals the arg-max target."""
        inputs, targets = batch
        predicted_class = jnp.argmax(predict(params, inputs), axis=1)
        target_class = jnp.argmax(targets, axis=1)
        return jnp.mean(predicted_class == target_class)

    @jit
    def update(i, opt_state, batch):
        """Apply one optimizer step for `batch` at step index `i`."""
        gradients = grad(loss)(get_params(opt_state), batch)
        return opt_update(i, gradients, opt_state)

    def data_stream():
        """Yield (images, labels) mini-batches forever, reshuffling on every pass."""
        shuffler = npr.RandomState(0)
        while True:
            order = shuffler.permutation(num_train)
            for start in range(0, num_train, batch_size):
                picked = order[start:start + batch_size]
                yield mnist_images[picked], mnist_labels[picked]

    batches = data_stream()
    itercount = itertools.count()
    train_accs = []
    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        # one pass over the data = exactly `num_batches` items from the endless stream
        for batch in itertools.islice(batches, num_batches):
            opt_state = update(next(itercount), opt_state, batch)
        epoch_time = time.time() - start_time
        params = get_params(opt_state)
        train_acc = accuracy(params, (mnist_images, mnist_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        train_accs.append(train_acc)

    return train_accs, params

#### Then the DS creates a new project with a relevant name and description, and specifies itself as a member of the project

In [None]:
# Build the project object locally; the DS client is listed as its only member.
# The description is a triple-quoted string, so its line break and indentation
# are part of the text.
new_project = sy.Project(
    name="Training a 3-layer jax neural network on MNIST data",
    description="""Hi, I would like to train my neural network on your MNIST data 
                (I can download it online too but I just want to use Syft coz it's cool)""",
    members=[ds_client],
) 
new_project

#### Add a code request to the project

In [None]:
# Attach the decorated training function to the project as a code request made through the DS client.
new_project.create_code_request(obj=mnist_3_linear_layers, client=ds_client)

In [None]:
# Display the code entries visible through the DS client.
ds_client.code

#### Start the project, which will notify the DO

In [None]:
# Start the project; the returned object is used for all later project queries.
project = new_project.start()

In [None]:
# Display the events recorded on the started project.
project.events

In [None]:
# Display the requests attached to the project.
project.requests

In [None]:
# Inspect the first request of the project.
project.requests[0]

### 📓 Now switch back to the [DO's notebook](./Data%20Owner%20(DO).ipynb) at step 2

## 4. After the DO has run the code and deposited the results, the DS gets them

In [None]:
# Display the code entries again, now that the DO has processed the request.
ds_client.code

In [None]:
# Display the project's requests again to see their current state.
project.requests

In [None]:
# Invoke the submitted function by name through the DS client, passing the
# dataset assets themselves (not the local mock arrays) as inputs.
result = ds_client.code.mnist_3_linear_layers(mnist_images=training_images, mnist_labels=training_labels)

In [None]:
# Display the object returned by the call above.
result

In [None]:
# Fetch the result through the DS client and unpack it as (accuracies, parameters),
# the same pair the training function returns. This rebinds `train_accs` and
# `params` from the earlier mock-data run.
train_accs, params = result.get_from(ds_client)

In [None]:
# Pretty-print the per-epoch training accuracies fetched above.
pprint(train_accs)

In [None]:
# Shapes of the fetched parameters. `jax.tree_map` is a deprecated alias that
# newer JAX releases remove; `jax.tree_util.tree_map` is the long-standing equivalent.
jax.tree_util.tree_map(lambda x: x.shape, params)

## 5. The DS gets the trained weights and does inference on the MNIST test dataset

In [None]:
# `mnist()` returns four values; only the last two (test images and labels) are
# kept. The first two are presumably the training split -- confirm in mnist_dataset.
_, _, test_images, test_labels = mnist()

In [None]:
# Sanity check of the test arrays' shapes (the `=` specifier prints the expression and its value).
print(f"{test_images.shape = }")
print(f"{test_labels.shape = }")

In [None]:
# define the neural network
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

In [None]:
def accuracy(params, batch):
    """Return the fraction of samples in `batch` that the network classifies correctly.

    Args:
        params: network parameters accepted by the notebook-level `predict`.
        batch: an ``(inputs, targets)`` pair; the true class of each sample is
            the arg-max of `targets` along axis 1.

    Returns:
        Mean of the element-wise match between predicted and true classes.
    """
    inputs, targets = batch
    scores = predict(params, inputs)
    is_correct = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(is_correct)

#### Test inference using random weights

In [None]:
# Untrained parameters as a baseline: fixed PRNG key 0, input width 28 * 28 = 784.
rng = random.PRNGKey(0)
_, random_params = init_random_params(rng, (-1, 28 * 28))

# Accuracy on the test data before any trained weights are used.
test_acc = accuracy(random_params, (test_images, test_labels))
print(f"Test set accuracy with random weights {test_acc * 100 : .2f}%")

#### Test inference using the trained weights received from the DO

In [None]:
# Same evaluation with the parameters fetched from the domain; rebinds `test_acc`.
test_acc = accuracy(params, (test_images, test_labels))
print(f"Test set accuracy with trained weights {test_acc * 100 : .2f}%")