In [1]:
# third party
import jax
import matplotlib.pyplot as plt
import numpy as np

# syft absolute
import syft as sy

## 1. The DS logs in to the domain with the credentials created by the DO

In [2]:
# Start the "esol-domain" node through sy.orchestra in dev mode, then log in
# with the data scientist (DS) account.
# NOTE(review): the e-mail and password are hard-coded; fine for a local demo
# node, but they should come from the environment for anything shared.
node = sy.orchestra.launch(name="esol-domain", dev_mode=True)
ds_client = node.login(email="sheldon@caltech.edu", password="changethis")

### Inspect the datasets on the domain

In [3]:
# Fetch all datasets visible to this client; the notebook expects exactly one.
datasets = ds_client.datasets.get_all()
assert len(datasets) == 1
datasets

In [4]:
# The single dataset is expected to hold five assets; the cells below bind
# them to names by position, so their order matters.
assets = datasets[0].assets
assert len(assets) == 5
assets

In [5]:
# Asset 0 -> `y` (presumably the regression targets; TODO confirm against the
# data-owner upload notebook).
y = assets[0]

In [6]:
# Asset 1 -> `x` (presumably the node features; TODO confirm against the
# data-owner upload notebook).
x = assets[1]

In [7]:
# Asset 2 -> `x_masks` (presumably marks which entries of `x` are real rather
# than padding; TODO confirm against the data-owner upload notebook).
x_masks = assets[2]

In [8]:
# Asset 3 -> `edge_index` (presumably the graph connectivity; TODO confirm
# against the data-owner upload notebook).
edge_index = assets[3]

In [9]:
# Asset 4 -> `edge_index_masks` (presumably the mask that goes with
# `edge_index`; TODO confirm). The trailing expression displays the asset.
edge_index_masks = assets[4]
edge_index_masks

#### The DS cannot access the real data

In [13]:
# On the DS side the `.data` attribute of each asset must be empty: only the
# mock is available locally.
assert x.data is None
assert edge_index.data is None
assert y.data is None

#### The DS can only access the mock data, which is some random noise

In [16]:
# Pull the mock counterpart of each asset and print its shape so the model
# dimensions can be chosen from it.
x_mock = x.mock
edge_index_mock = edge_index.mock
y_mock = y.mock
print(x_mock.shape)
print(edge_index_mock.shape)
print(y_mock.shape)

#### We need the pointers to the mock data to construct a `syft` function (later in the notebook)

In [21]:
# Despite the `_mock_` in the name, this holds `x.pointer`, i.e. the pointer
# attribute of the asset itself, not something derived from `x_mock`.
x_mock_ptr = x.pointer

In [18]:
# Show the concrete Python type of the mock object (inspection only).
type(x_mock)

In [22]:
# `edge_index_mock` is re-assigned to the same `edge_index.mock` value that an
# earlier cell already stored; the new name introduced here is the pointer.
edge_index_mock = edge_index.mock
# Holds the asset's `.pointer`, despite the `_mock_` in the variable name.
edge_index_mock_ptr = edge_index.pointer

In [24]:
# `y_mock` is re-assigned to the same `y.mock` value that an earlier cell
# already stored; the new name introduced here is the pointer.
y_mock = y.mock
# Holds the asset's `.pointer`, despite the `_mock_` in the variable name.
y_mock_ptr = y.pointer

In [None]:
# @todo: recreate the DATA as it was in the 00-data-owner-upload-dat... then change "in_channels=x.shape[-1]," to "in_channels=DATA.num_features,"

## 2. The DS prepares the training code and experiments on the mock data

In [27]:
import torch
from torch.nn import Linear
import torch.nn.functional as F 
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

EMBEDDING_DIM = 64

class GCN(torch.nn.Module):
    """Graph convolutional network that maps every graph in a batch to one scalar.

    Four ``GCNConv`` layers, each followed by ``tanh``, produce per-node
    embeddings. These are reduced per graph with global max pooling and global
    mean pooling, the two results are concatenated, and a single linear layer
    turns the pooled vector into one output value.

    Args:
        in_channels: Number of input features per node. When omitted, the last
            dimension of the notebook-level ``x_mock`` is used, so a bare
            ``GCN()`` behaves exactly as before.
        embedding_dim: Width of every graph-convolution layer. When omitted,
            the notebook-level ``EMBEDDING_DIM`` is used.
    """

    def __init__(self, in_channels=None, embedding_dim=None):
        super().__init__()
        # Seeds the *global* torch RNG so that layer initialisation is
        # reproducible; note that this also affects any later random draws.
        torch.manual_seed(42)

        # Resolve defaults lazily so the class can be defined before the
        # notebook globals exist and can be reused with other feature widths.
        if in_channels is None:
            in_channels = x_mock.shape[-1]
        if embedding_dim is None:
            embedding_dim = EMBEDDING_DIM

        # GCNConv: graph convolution from "Semi-supervised Classification with
        # Graph Convolutional Networks".
        self.initial_conv = GCNConv(
            in_channels=in_channels,  # node features before the transformation
            out_channels=embedding_dim,  # node features after the transformation
        )
        self.conv1 = GCNConv(embedding_dim, embedding_dim)
        self.conv2 = GCNConv(embedding_dim, embedding_dim)
        self.conv3 = GCNConv(embedding_dim, embedding_dim)
        self.out = Linear(
            # Twice the embedding width because max- and mean-pooled vectors
            # are concatenated in ``forward``.
            in_features=embedding_dim * 2,
            out_features=1,
        )

    def forward(self, x, edge_index, batch_index):
        """Compute one prediction per graph.

        Args:
            x: Node feature tensor passed to the first ``GCNConv``.
            edge_index: Connectivity tensor passed to every ``GCNConv``.
            batch_index: Node-to-graph assignment used by the pooling functions.

        Returns:
            ``(out, hidden)``: the output of the final linear layer and the
            concatenated pooled embedding it was computed from.
        """
        # torch.tanh is used instead of F.tanh, which older PyTorch releases
        # flag as deprecated; the result is identical.
        hidden = torch.tanh(self.initial_conv(x, edge_index))
        hidden = torch.tanh(self.conv1(hidden, edge_index))
        hidden = torch.tanh(self.conv2(hidden, edge_index))
        hidden = torch.tanh(self.conv3(hidden, edge_index))

        # Global pooling: stack two different aggregations over each graph's nodes.
        hidden = torch.cat(
            [gmp(hidden, batch_index), gap(hidden, batch_index)],
            dim=1,
        )

        # Final linear head with a single output (trained with MSELoss in this
        # notebook, i.e. used as a regressor rather than a classifier).
        out = self.out(hidden)

        return out, hidden

# Build the model with its default settings, then print the module structure
# and the total number of parameter elements.
MODEL = GCN()
print(MODEL)
print("Number of parameters: ", sum(p.numel() for p in MODEL.parameters()))

In [30]:
# NOTE(review): newer PyTorch Geometric releases expose DataLoader from
# `torch_geometric.loader`; this import path is the legacy location. Confirm
# against the installed version.
from torch_geometric.data import DataLoader
import warnings
# Silences every warning for the rest of the kernel session, including
# deprecation and shape-broadcast warnings that would help while debugging.
warnings.filterwarnings("ignore")

# Number of graphs per mini-batch handed to DataLoader.
NUM_GRAPHS_PER_BATCH = 64

# Use GPU for training (if available)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [33]:
# Bundle of the three *asset* objects (not their `.mock` values) keyed by the
# attribute names the training loop reads from each batch.
# NOTE(review): `train` slices its `data` argument and passes it to a
# DataLoader; a plain dict supports neither, so this presumably needs to be
# rebuilt as a dataset of graphs from the mock arrays. TODO confirm.
DATA_ = {
 "x": x,
 "edge_index": edge_index,
 "y": y,
}


In [34]:
def train(model, data):
    """Run one optimisation pass over the first 80% of ``data``.

    Args:
        model: Module called as ``model(x, edge_index, batch)`` and returning
            ``(prediction, embedding)``.
        data: Sized, sliceable collection of graphs that ``DataLoader`` can
            batch (each batch must expose ``x``, ``edge_index``, ``batch``
            and ``y``).

    Returns:
        ``(loss, embedding)`` of the last processed mini-batch.

    Raises:
        TypeError: If ``data`` is a mapping, which cannot be sliced into a
            training split.
        ValueError: If the 80% training slice contains no graphs.
    """
    if isinstance(data, dict):
        raise TypeError(
            "train() expects a sliceable dataset of graphs, not a dict; "
            "build the dataset before calling train()."
        )

    model = model.to(DEVICE)
    loss_fn = torch.nn.MSELoss()
    # A fresh Adam optimizer is created on every call, so its running moment
    # estimates do not carry over from one call to the next.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

    # Size the split from the argument itself instead of the notebook-level
    # asset `x`, so the function works on whatever dataset it is given.
    data_size = len(data)
    train_data = data[: int(data_size * 0.8)]
    if len(train_data) == 0:
        # Without this guard the loop body never runs and the return
        # statement fails on unbound `loss` / `embedding`.
        raise ValueError("Training split is empty; need at least 2 graphs.")

    train_loader = DataLoader(
        train_data,
        batch_size=NUM_GRAPHS_PER_BATCH,
        shuffle=True,
    )

    for batch in train_loader:
        # Keep the returned object rather than relying on in-place movement.
        batch = batch.to(DEVICE)
        # Reset gradients
        optimizer.zero_grad()
        # Pass the node features and the edge info
        pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        # NOTE(review): MSELoss broadcasts when `pred` and `batch.y` differ in
        # shape (e.g. (B, 1) vs (B,)); confirm the target shape.
        loss = loss_fn(pred, batch.y)
        loss.backward()
        # Update using the gradients
        optimizer.step()
    return loss, embedding

def train_wrapper(num_epochs=2000, log_every=100):
    """Call ``train(MODEL, DATA_)`` repeatedly and collect the returned losses.

    Args:
        num_epochs: Number of calls to ``train``. Defaults to 2000, the value
            that was previously hard-coded.
        log_every: Print the loss whenever ``epoch % log_every == 0``.
            Defaults to 100, the value that was previously hard-coded.

    Returns:
        List with the loss returned by ``train`` for every epoch, in order.
    """
    print("Starting training...")
    losses = []
    for epoch in range(num_epochs):
        # The embedding returned by train() is not needed here.
        loss, _ = train(MODEL, DATA_)
        losses.append(loss)
        if epoch % log_every == 0:
            print(f"Epoch {epoch} | Train Loss {loss}")
    return losses

# Runs the training loop as soon as this cell executes and keeps the list of
# per-epoch losses that train_wrapper returns.
LOSSES = train_wrapper()

In [None]:
# NOTE(review): this cell looks like a leftover from an MNIST template.
# `mock_images` and `mock_labels` are not defined anywhere in the visible
# notebook, and `mnist_3_linear_layers` is only defined in a later cell, so
# this fails on Restart & Run All. Confirm whether it should be removed or
# replaced by a call that uses the ESOL mock data.
train_accs, params = mnist_3_linear_layers(
    mnist_images=mock_images, mnist_labels=mock_labels
)

#### Inspect the training accuracies and the shape of the model's parameters

In [None]:
# Display the accuracies produced by the preceding mock run.
train_accs

In [None]:
# Show the shape of every leaf in the parameter pytree. `jax.tree_util.tree_map`
# is used because the top-level `jax.tree_map` alias is deprecated and removed
# in newer JAX releases; `leaf` also avoids visually shadowing the notebook's `x`.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

## 3. Now that the code works on mock data, the DS submits the code request for execution to the DO

#### First the DS wraps the training function with the `@sy.syft_function` decorator

In [None]:
# Wrap the training function as a Syft function: the input policy names the two
# pointers the function may be called with, and an output policy is attached.
# NOTE(review): `mock_images_ptr` / `mock_labels_ptr` are not defined in the
# visible part of this notebook, and the function still targets MNIST rather
# than the ESOL graph data used above. Confirm before submitting.
@sy.syft_function(
    input_policy=sy.ExactMatch(
        mnist_images=mock_images_ptr, mnist_labels=mock_labels_ptr
    ),
    output_policy=sy.SingleExecutionExactOutput(),
)
def mnist_3_linear_layers(mnist_images, mnist_labels):
    """Train a small dense network with JAX and report per-epoch accuracy.

    The network is Dense(1024) -> Relu -> Dense(1024) -> Relu -> Dense(10) ->
    LogSoftmax, initialised for inputs with 784 features and optimised with
    momentum SGD for 10 epochs using mini-batches of 4.

    Args:
        mnist_images: Array of inputs; indexed with integer index arrays and
            read via ``.shape[0]``. Assumed to be (num_examples, 784) given
            the init shape; TODO confirm.
        mnist_labels: Array of targets aligned with ``mnist_images``. Assumed
            to be one-hot rows over 10 classes, since the loss multiplies
            them with log-probabilities and accuracy takes ``argmax`` over
            axis 1; TODO confirm.

    Returns:
        ``(train_accs, params)``: the list of training-set accuracies, one per
        epoch, and the parameters after the last epoch.
    """
    # import the packages
    # (all imports live inside the function body; presumably so the function is
    # self-contained when it is shipped for remote execution. TODO confirm)
    # stdlib
    import itertools
    import time

    # third party
    from jax import grad
    from jax import jit
    from jax import random
    from jax.example_libraries import optimizers
    from jax.example_libraries import stax
    from jax.example_libraries.stax import Dense
    from jax.example_libraries.stax import LogSoftmax
    from jax.example_libraries.stax import Relu
    import jax.numpy as jnp
    import numpy.random as npr

    # define the neural network
    init_random_params, predict = stax.serial(
        Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax
    )

    # initialize the random parameters
    # (fixed PRNG key; the input shape declares 784 features per example)
    rng = random.PRNGKey(0)
    _, init_params = init_random_params(rng, (-1, 784))

    # the hyper parameters
    num_epochs = 10
    batch_size = 4
    num_train = mnist_images.shape[0]
    # Ceiling division: a trailing partial batch counts as one more batch.
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)
    step_size = 0.001
    momentum_mass = 0.9

    # initialize the optimizer
    opt_init, opt_update, get_params = optimizers.momentum(
        step_size, mass=momentum_mass
    )
    opt_state = opt_init(init_params)
    # Monotonically increasing step counter passed to the optimizer update.
    itercount = itertools.count()

    @jit
    def update(i, opt_state, batch):
        """Apply one optimizer step for ``batch`` and return the new state."""
        params = get_params(opt_state)
        # `loss` is defined further down; it is resolved when `update` is
        # first called, which happens after that definition.
        return opt_update(i, grad(loss)(params, batch), opt_state)

    def data_stream():
        """
        Yield (inputs, targets) mini-batches forever.

        Each pass draws a new permutation of the example indices from a
        RandomState seeded with 0 and walks through it in `batch_size` chunks.
        """
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield mnist_images[batch_idx], mnist_labels[batch_idx]

    def loss(params, batch):
        """Negative mean over examples of sum(prediction * target) per row."""
        inputs, targets = batch
        preds = predict(params, inputs)
        return -jnp.mean(jnp.sum(preds * targets, axis=1))

    def accuracy(params, batch):
        """Fraction of rows whose argmax prediction equals the argmax target."""
        inputs, targets = batch
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(predict(params, inputs), axis=1)
        return jnp.mean(predicted_class == target_class)

    batches = data_stream()
    train_accs = []
    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_time = time.time() - start_time
        # Accuracy is measured on the full training set after every epoch.
        params = get_params(opt_state)
        train_acc = accuracy(params, (mnist_images, mnist_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
        print(f"Training set accuracy {train_acc}")
        train_accs.append(train_acc)

    return train_accs, params

#### Then the DS creates a new project with a relevant name and description, and specifies themselves as a member of the project

In [None]:
# Create the project object locally with the DS client as its only member.
# NOTE(review): the name and description still talk about a JAX network on
# MNIST, while this notebook logs in to "esol-domain" and builds a GCN.
# Confirm the intended wording before sending.
new_project = sy.Project(
    name="Training a 3-layer jax neural network on MNIST data",
    description="""Hi, I would like to train my neural network on your MNIST data 
                (I can download it online too but I just want to use Syft coz it's cool)""",
    members=[ds_client],
)
new_project

#### Add a code request to the project

In [None]:
# Attach the decorated function to the project as a code request made through
# the DS client.
new_project.create_code_request(obj=mnist_3_linear_layers, client=ds_client)

In [None]:
# Display the code objects visible through the DS client.
ds_client.code

#### Start the project, which will notify the DO

In [None]:
# Send the locally built project; the returned object is used below to inspect
# its events and requests.
project = new_project.send()

In [None]:
# Display the events recorded on the project.
project.events

In [None]:
# Display the requests attached to the project.
project.requests

In [None]:
# Display the first request on the project (presumably the code request
# created above; TODO confirm).
project.requests[0]

### 📓 Now switch to the [second DO's notebook](./02-data-owner-review-approve-code.ipynb)