# Neural Network with SPU

> Please read the lab `Logistic Regression On SPU` first if you haven't already.

In lab `Logistic Regression On SPU`, we showed how to use SecretFlow/SPU to convert a plaintext JAX training program into a secure MPC training program.

In this lab, the idea is quite similar, but this time we will work with a neural network model.

We are going to use the same dataset and settings as in lab `Logistic Regression On SPU`.

First, let's work out the plaintext model.

*The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production.*

## Train a model with JAX/FLAX

### Load the Dataset

The code below is copied from lab `Logistic Regression On SPU`, so we won't explain it again.

In [1]:
```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def load_train_dataset(party_id=None) -> (np.ndarray, np.ndarray):
    features, label = load_breast_cancer(return_X_y=True)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    # test_size=0.8 keeps only 20% of the samples for training.
    X_train, _, y_train, _ = train_test_split(
        features, label, test_size=0.8, random_state=42
    )

    if party_id:
        if party_id == 1:
            # Party 1 holds features 15-29 and no labels.
            return X_train[:, 15:], None
        else:
            # Party 2 holds features 0-14 and the labels.
            return X_train[:, :15], y_train
    else:
        return X_train, y_train


def load_test_dataset():
    features, label = load_breast_cancer(return_X_y=True)
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    _, X_test, _, y_test = train_test_split(
        features, label, test_size=0.8, random_state=42
    )
    return X_test, y_test
```
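`load_train_dataset` splits the data vertically: both parties hold the same rows (samples) but different columns (features), and only party 2 holds the labels. The numpy sketch below uses a small random matrix as a hypothetical stand-in for the scaled breast cancer features to show what each party sees and how `train` later joins the columns. Because party 1 holds columns 15 onward and `train` concatenates `[train_x1, train_x2]`, the joined matrix is the original with its two halves swapped, whereas `load_test_dataset` returns the columns in their original order.

```python
import numpy as np

# Hypothetical stand-in for the scaled training features: 6 samples x 30 features.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 30))
y = rng.integers(0, 2, size=6)

# Vertical split, mirroring load_train_dataset: same rows, different columns.
x1 = X[:, 15:]          # party 1: features 15-29, no labels
x2, y2 = X[:, :15], y   # party 2: features 0-14 plus the labels

# train() joins the parties' columns in the order [x1, x2].
joined = np.concatenate([x1, x2], axis=1)

# The joined matrix equals X with its two halves swapped, not X itself.
perm = list(range(15, 30)) + list(range(15))
print(joined.shape)                         # (6, 30)
print(np.array_equal(joined, X[:, perm]))   # True
print(np.array_equal(joined, X))            # False
```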


### Define the Model


We are going to use a 4-layer [MLP](https://en.wikipedia.org/wiki/Multilayer_perceptron) model with a [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) activation function here.

In [2]:
```python
from jax.example_libraries import stax
from jax.example_libraries.stax import (
    Dense,
    Relu,
)


def MLP():
    nn_init, nn_apply = stax.serial(
        Dense(30),
        Relu,
        Dense(15),
        Relu,
        Dense(8),
        Relu,
        Dense(1),
    )

    return nn_init, nn_apply
```
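`stax.serial` chains the layers: each `Dense(n)` computes `x @ W + b` with `n` output units, and each `Relu` applies `max(x, 0)` elementwise. The numpy sketch below writes out the same forward pass by hand with hypothetical random weights (stax uses its own initializers) to make the shapes concrete.

```python
import numpy as np

# Layer widths of the MLP above: 30 input features -> 30 -> 15 -> 8 -> 1 output.
SIZES = [30, 30, 15, 8, 1]

rng = np.random.default_rng(0)
# Hypothetical initialization; stax.Dense uses its own default initializers.
params = [
    (rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out))
    for n_in, n_out in zip(SIZES[:-1], SIZES[1:])
]


def mlp_forward(params, x):
    # Dense followed by ReLU for every hidden layer.
    for W, b in params[:-1]:
        x = np.maximum(x @ W + b, 0.0)
    # The last Dense layer has no activation.
    W, b = params[-1]
    return x @ W + b


x = rng.normal(size=(4, 30))  # a batch of 4 samples
out = mlp_forward(params, x)
print(out.shape)  # (4, 1)

n_params = sum(W.size + b.size for W, b in params)
print(n_params)  # 1532
```

The final `Dense(1)` has no activation, so the model emits one unbounded score per sample with shape `(batch, 1)`; the whole network has 1,532 trainable parameters.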


Next, we define the training function.

In [3]:
```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


KEY = jax.random.PRNGKey(42)
INPUT_SHAPE = (-1, 30)


def init_state(learning_rate):
    # Initialize the MLP parameters and wrap them in an SGD optimizer state.
    init_fun, _ = MLP()
    _, params_init = init_fun(KEY, INPUT_SHAPE)
    opt_init, _, _ = optimizers.sgd(learning_rate)
    opt_state = opt_init(params_init)
    return opt_state


def train(
    train_x1,
    train_x2,
    train_y,
    opt_state,
    learning_rate,
    epochs,
    batch_size,
):
    # Join the two parties' feature columns into one training matrix.
    train_x = jnp.concatenate([train_x1, train_x2], axis=1)

    _, predict_fun = MLP()
    _, opt_update, get_params = optimizers.sgd(learning_rate)

    def update_model(state, imgs, labels, i):
        def mse(y, pred):
            return jnp.mean(jnp.multiply(y - pred, y - pred) / 2.0)

        def loss_func(params):
            y = predict_fun(params, imgs)
            return mse(y, labels), y

        # Compute the loss and its gradient, then take one SGD step.
        grad_fn = jax.value_and_grad(loss_func, has_aux=True)
        (loss, y), grads = grad_fn(get_params(state))
        return opt_update(i, grads, state)

    for i in range(1, epochs + 1):
        # Split the data into an integer number of roughly equal mini-batches.
        imgs_batchs = jnp.array_split(train_x, len(train_x) // batch_size, axis=0)
        labels_batchs = jnp.array_split(train_y, len(train_y) // batch_size, axis=0)

        for batch_images, batch_labels in zip(imgs_batchs, labels_batchs):
            opt_state = update_model(opt_state, batch_images, batch_labels, i)
    return get_params(opt_state)
```
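`train` performs plain mini-batch SGD: in every epoch it splits the data into about `len(train_x) / batch_size` batches (rounded down) with `array_split` and, for each batch, takes one step `params <- params - learning_rate * grad`. The numpy sketch below reproduces that loop on synthetic data for a hypothetical linear model with a hand-derived MSE gradient; the lab itself uses the MLP and `jax.value_and_grad` instead.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data: 113 samples (the size of this lab's training split), 30 features.
X = rng.normal(size=(113, 30))
true_w = rng.normal(size=(30, 1))
y = X @ true_w  # column vector, shape (113, 1)

w = np.zeros((30, 1))  # hypothetical linear model in place of the MLP
learning_rate, epochs, batch_size = 0.1, 15, 5


def mse(w, X, y):
    r = X @ w - y
    return np.mean(r * r / 2.0)


n_batches = len(X) // batch_size
initial_loss = mse(w, X, y)
for epoch in range(1, epochs + 1):
    for xb, yb in zip(np.array_split(X, n_batches), np.array_split(y, n_batches)):
        residual = xb @ w - yb            # shape (batch, 1)
        grad = xb.T @ residual / len(xb)  # gradient of mean(residual**2 / 2) w.r.t. w
        w = w - learning_rate * grad      # one SGD step

print(n_batches)                     # 22
print(mse(w, X, y) < initial_loss)   # True
```

With 113 rows and `batch_size = 5`, this gives 22 batches: three of 6 rows and nineteen of 5, because `array_split` spreads the remainder over the first batches.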




### Validate the Model

We use AUC as the validation metric.

In [4]:
```python
from sklearn.metrics import roc_auc_score


def validate_model(params, X_test, y_test):
    _, predict_fun = MLP()
    y_pred = predict_fun(params, X_test)
    return roc_auc_score(y_test, y_pred)
```
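ROC AUC equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counting as one half. Because it depends only on the ranking of the scores, `validate_model` can pass the raw, unbounded MLP outputs to `roc_auc_score` directly. The sketch below computes that definition pairwise with a hypothetical `pairwise_auc` helper; it is quadratic in the number of samples, so it is for illustration only.

```python
import numpy as np


def pairwise_auc(y_true, y_score):
    # Flatten so that (n, 1) model outputs work as well as (n,) ones.
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Compare every positive score with every negative score.
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1]
scores = np.array([0.1, 0.4, 0.35, 0.8])
print(pairwise_auc(labels, scores))           # 0.75
# An increasing affine map keeps the ranking, so the AUC is unchanged.
print(pairwise_auc(labels, 30 * scores - 6))  # 0.75
```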


### Put It All Together

Let's put everything together and train a plaintext NN model!

In [5]:
```python
# Load the data
x1, _ = load_train_dataset(party_id=1)
x2, y = load_train_dataset(party_id=2)

# Hyperparameters
batch_size = 5
epochs = 15
learning_rate = 0.1

init_params = init_state(learning_rate)

params = train(x1, x2, y, init_params, learning_rate, epochs, batch_size)

print(params)

X_test, y_test = load_test_dataset()
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
```


[(DeviceArray([[ 0.05228749,  0.2254616 ,  0.14233142,  0.2218795 ,
               0.14313836, -0.11788452, -0.25514612,  0.05024785,
               0.19256376, -0.16702715,  0.1360347 ,  0.09218568,
              -0.14514978, -0.00203137,  0.1026915 ,  0.16864221,
              -0.19459064, -0.1680608 , -0.0291565 ,  0.24780752,
              -0.10630313,  0.19863537, -0.07401107, -0.24004339,
              -0.11391521,  0.3779385 ,  0.1651407 ,  0.0792419 ,
              -0.21970308, -0.07381143],
             [-0.23652235,  0.3905772 , -0.0325642 ,  0.16153163,
               0.03559006, -0.34141707,  0.36782634,  0.35197827,
              -0.03010983,  0.17985871,  0.1022082 ,  0.15198684,
               0.20280722, -0.15535937, -0.3000578 ,  0.28659174,
              -0.2538568 , -0.10807706,  0.02643047, -0.01051225,
              -0.23814394, -0.11327465, -0.2019997 ,  0.03128352,
              -0.046227  ,  0.13949452, -0.05219001,  0.10902884,
               0.134705  , -0.110

You may find the AUC here quite low, but this is not a machine learning course. Just keep the number in mind: we are going to repeat the training with SPU. Let's work some magic!
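If you do want to investigate the AUC, checking shapes is a good start. The last `Dense(1)` layer produces predictions of shape `(batch, 1)`, while each batch of labels from `array_split` has shape `(batch,)`. In NumPy and JAX, subtracting these broadcasts to a `(batch, batch)` matrix, so an MSE written as `mean((y - pred) ** 2 / 2)` then averages over every prediction/label pair instead of matching each prediction with its own label. The numpy sketch below, with hypothetical values, shows the effect and a shape-safe variant that flattens the predictions first.

```python
import numpy as np


def mse(y, pred):
    # Same formula as the mse helper in train().
    return np.mean(np.multiply(y - pred, y - pred) / 2.0)


preds = np.array([[0.9], [0.1], [0.8]])  # shape (3, 1), like the Dense(1) output
labels = np.array([1.0, 0.0, 1.0])       # shape (3,), like a batch of train_y

print((preds - labels).shape)          # (3, 3): every pair is compared
print(mse(preds, labels))              # pairwise average, not the per-sample MSE
print(mse(preds.reshape(-1), labels))  # per-sample MSE after flattening
```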


## Train a Model with SPU
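The next cell runs the same `train` function under SPU. Conceptually, once `x1`, `x2` and `y` are moved to the SPU, each value is split into secret shares held by the participating parties, and the computation runs on those shares. SPU's actual protocols are considerably more involved; the numpy sketch below is only a toy two-party additive secret sharing scheme with hypothetical `share`/`reveal` helpers, to show why no single share reveals the data while additions can still be computed locally.

```python
import numpy as np

# Toy 2-party additive secret sharing over the ring of integers modulo 2**64.
# numpy uint64 array arithmetic wraps around modulo 2**64.
rng = np.random.default_rng(0)


def share(x):
    # Hypothetical helper: split x into shares s0, s1 with s0 + s1 == x (mod 2**64).
    s0 = np.frombuffer(rng.bytes(8 * x.size), dtype=np.uint64).reshape(x.shape)
    s1 = x - s0
    return s0, s1


def reveal(s0, s1):
    # Hypothetical helper: recombine the two shares.
    return s0 + s1


a = np.array([3, 5, 7], dtype=np.uint64)     # e.g. values private to alice
b = np.array([10, 20, 30], dtype=np.uint64)  # e.g. values private to bob

a0, a1 = share(a)  # alice keeps a0 and sends a1 to bob
b0, b1 = share(b)  # bob keeps b1 and sends b0 to alice

# Each party adds the shares it holds locally, without seeing a or b.
c0, c1 = a0 + b0, a1 + b1
print(reveal(c0, c1))  # [13 25 37]
```

In schemes of this kind, multiplying two shared values or comparing them (as ReLU requires) needs interaction between the parties, which is one reason MPC training is typically much slower than its plaintext counterpart.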

In [6]:
```python
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

sf.init(['alice', 'bob'], num_cpus=8, log_to_driver=True)

alice, bob = sf.PYU('alice'), sf.PYU('bob')
spu = sf.SPU(sf.utils.testing.cluster_def(['alice', 'bob']))

# Each party loads only its own slice of the training data.
x1, _ = alice(load_train_dataset)(party_id=1)
x2, y = bob(load_train_dataset)(party_id=2)

# Move the data and the initial optimizer state to the SPU.
device = spu
x1_, x2_, y_ = x1.to(device), x2.to(device), y.to(device)
init_params_ = sf.to(spu, lambda: init_state(learning_rate))

# Run the same `train` function on the SPU; the hyperparameters are static arguments.
params_spu = spu(train, static_argnames=['learning_rate', 'epochs', 'batch_size'])(
    x1_, x2_, y_, init_params_,
    learning_rate=learning_rate, epochs=epochs, batch_size=batch_size,
)
```


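`train` uses `epochs` and `batch_size` in ordinary Python control flow (`range` and the number of sections passed to `array_split`), so they must be concrete Python values when the function is traced and compiled. That is presumably why they are passed through `static_argnames`; `jax.jit` has the same concept. The JAX sketch below illustrates it with `jax.jit`, not SPU, using a hypothetical `batched_sum` function.

```python
import jax
import jax.numpy as jnp


def batched_sum(x, batch_size):
    # batch_size controls how the array is split, so it must be a concrete int.
    n_batches = len(x) // batch_size
    return sum(jnp.sum(b) for b in jnp.array_split(x, n_batches))


# Marking batch_size as static makes it a compile-time constant.
fast_sum = jax.jit(batched_sum, static_argnames=['batch_size'])
print(fast_sum(jnp.arange(12.0), batch_size=4))  # 66.0

# Without static_argnames, batch_size is traced and cannot size the split.
try:
    jax.jit(batched_sum)(jnp.arange(12.0), 4)
    failed = False
except Exception:  # batch_size arrives as a tracer, not a Python int
    failed = True
print(failed)  # True
```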


Let's check the parameters produced by the SPU program.

In [7]:
```python
params = sf.reveal(params_spu)
print(params)
```


[2m[36m(_run pid=2460284)[0m 2022-06-16 13:01:30.776125: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib
[2m[36m(_run pid=2460288)[0m 2022-06-16 13:01:30.800513: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/dev

[(array([[ 5.21224737e-02,  2.25653112e-01,  1.42217025e-01,
         2.22038493e-01,  1.43407062e-01, -1.17939964e-01,
        -2.54890293e-01,  5.02917618e-02,  1.92605436e-01,
        -1.66738570e-01,  1.35716438e-01,  9.23201591e-02,
        -1.45163283e-01, -2.06950307e-03,  1.02719918e-01,
         1.68477565e-01, -1.94680765e-01, -1.68111518e-01,
        -2.91620046e-02,  2.47961149e-01, -1.06470272e-01,
         1.98481962e-01, -7.39427507e-02, -2.39929244e-01,
        -1.13915399e-01,  3.78076851e-01,  1.65145919e-01,
         7.91657716e-02, -2.19707474e-01, -7.42163062e-02],
       [-2.36553714e-01,  3.90677869e-01, -3.25132608e-02,
         1.61681026e-01,  3.65113467e-02, -3.41343939e-01,
         3.67962658e-01,  3.51848811e-01, -3.00575048e-02,
         1.80104598e-01,  1.01788193e-01,  1.52144164e-01,
         2.02847883e-01, -1.55509651e-01, -3.00182939e-01,
         2.86474586e-01, -2.53845513e-01, -1.08098716e-01,
         2.64337212e-02, -1.04821473e-02, -2.38679454
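The revealed parameters are close to, but not identical with, the plaintext ones: the first weight is 0.05228749 in the plaintext run and 5.21224737e-02 from SPU. One plausible contributor is that SPU, like many MPC frameworks, computes on fixed-point encodings of real numbers, so inputs and intermediate results carry small rounding errors. The numpy sketch below illustrates fixed-point encoding with a hypothetical 18 fractional bits; SPU's actual precision is configurable and may differ.

```python
import numpy as np

FRAC_BITS = 18  # hypothetical precision, chosen only for illustration
SCALE = 2**FRAC_BITS


def encode(x):
    # Real number -> integer carrying FRAC_BITS fractional bits (round to nearest).
    return np.round(np.asarray(x) * SCALE).astype(np.int64)


def decode(q):
    # Integer encoding -> real number.
    return np.asarray(q).astype(np.float64) / SCALE


w = np.array([0.05228749, -0.11788452, 0.3779385])  # a few plaintext weights
err = np.abs(decode(encode(w)) - w)
print(err.max() <= 0.5 / SCALE)  # True: encoding loses at most 2**-19 here

# Multiplying two encodings doubles the fractional bits, so the product is
# shifted right by FRAC_BITS (truncation), which adds another small error.
a, b = encode(0.1), encode(0.3)
prod = decode((a * b) >> FRAC_BITS)
print(abs(prod - 0.03) < 1e-5)  # True
```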

Lastly, let's validate the model.

In [8]:
```python
auc = validate_model(params, X_test, y_test)
print(f'auc={auc}')
```


auc=0.7982343165766513


This is the end of the lab.