# Introduction to TrainConfig

### Context 

> Warning: This is still experimental and may change during June / July 2019

We introduce here the TrainConfig abstraction, a serializable wrapper around the usual setup for federated training: a model, a loss function, an optimizer type, and training hyperparameters (batch_size, lr, ...).

The main reason to use TrainConfig is to draw a clear boundary between a worker that holds private data and performs training, and another worker that acts as a scheduler: it knows the workers, has a model, and requests training from those workers.

Authors:
- Marianne Monteiro - Twitter [@hereismari](https://twitter.com/hereismari) - GitHub: [@mari-linhares](https://github.com/mari-linhares)
- Silvia - GitHub [@midokura-silvia](https://github.com/midokura-silvia)
- Bobby Wagner - Twitter [@bobbyawagner](https://twitter.com/bobbyawagner) - GitHub: [@robert-wagner](https://github.com/robert-wagner)

## Remote Training in a Federated Learning setup

For a Federated Learning setup with TrainConfig we consider at least two participants:

* A worker that owns a dataset.

* An entity that knows the workers and the name of the dataset that lives on each worker. We'll call this a scheduler.




### Create a worker

Let's create a remote worker that holds some data!

#### Preparation: Start the websocket worker

First, we need to create a remote worker, which we'll call alice. To do so, run the following in a terminal (this is not possible from the notebook):

```bash
python start_worker.py --port 8777 --id alice
```

#### What's going on?

Let's have a look at the main function of `start_worker.py`:

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import inspect
import start_worker

print(inspect.getsource(start_worker.main))

This script creates a worker and populates it with some toy data using `worker.add_dataset`. The dataset is identified by a key, in this case `xor`.

The scheduler needs to know the worker (alice) and its dataset (xor) so it can say: "hey alice, here is a TrainConfig definition, could you train using dataset `xor`?"

We can add multiple datasets to a single worker.
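
To make the idea of key-identified datasets concrete, here is a minimal plain-Python sketch. `ToyWorker` and its attributes are hypothetical names used only for illustration; this is not the PySyft worker class, it only mimics the idea of registering datasets under string keys.

```python
class ToyWorker:
    """Hypothetical stand-in for a worker; NOT the PySyft class."""

    def __init__(self, worker_id):
        self.id = worker_id
        self.datasets = {}  # dataset key -> (inputs, targets)

    def add_dataset(self, dataset, key):
        # Refuse to silently overwrite an existing dataset
        if key in self.datasets:
            raise ValueError("dataset key %r already exists on %s" % (key, self.id))
        self.datasets[key] = dataset


alice_toy = ToyWorker("alice")
xor_inputs = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
xor_targets = [[1.0], [1.0], [0.0], [0.0]]

# One worker can hold several datasets, each under its own key
alice_toy.add_dataset((xor_inputs, xor_targets), key="xor")
alice_toy.add_dataset((xor_inputs[:2], xor_targets[:2]), key="xor_small")

print(sorted(alice_toy.datasets))  # ['xor', 'xor_small']
```

A scheduler only needs to remember the pair (worker id, dataset key) to request training; the data itself never leaves the worker.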


### Setting up a scheduler

We'll use this notebook as a scheduler. For this we'll need to:

* Have a model
* Have a loss function
* Define an optimizer
* Define hyper-parameters

In [None]:
# Dependencies
import torch as th
import torch.nn.functional as F
from torch import nn

use_cuda = th.cuda.is_available()
th.manual_seed(1)
device = th.device("cuda" if use_cuda else "cpu")

import syft as sy
from syft import workers

hook = sy.TorchHook(th)  # hook torch as always :)

### Model

A model for TrainConfig is a regular torch model with one significant difference: it needs to be serializable.

To achieve that, we can turn a regular torch model into a [jit](https://pytorch.org/docs/stable/jit.html) module. Jit modules use TorchScript.

> TorchScript creates serializable and optimizable models from PyTorch code. Any code written in TorchScript can be saved from a Python process and loaded in a process where there is no Python dependency. This facility will allow us to send this model to remote workers. - [jit documentation](https://pytorch.org/docs/stable/jit.html)

We can turn a regular module into a jit module using `th.jit.trace`. First we can implement a regular torch model.

In [None]:
class Net(th.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 1)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Then we can trace it with `th.jit.trace`, using some mock data.

In [None]:
# Instantiate the model
model = Net()

# The data itself doesn't matter as long as the shape is right
mock_data = th.zeros(1, 2)

# Create a jit version of the model
traced_model = th.jit.trace(model, mock_data)

type(traced_model)
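
As a quick sanity check of what "serializable" buys us, the sketch below traces a small stand-in model, saves it to an in-memory buffer, and loads it back. The names `tiny_model`, `reloaded`, and `outputs_match` are illustrative and not part of the tutorial; the round trip through `io.BytesIO` is only an assumption about how one might test serialization locally, not a description of how PySyft ships the model to a worker.

```python
import io

import torch
from torch import nn

# A small stand-in model with the same input width as the tutorial's Net
tiny_model = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

# Tracing records the operations executed on the example input
example_input = torch.zeros(1, 2)
traced = torch.jit.trace(tiny_model, example_input)

# Serialize to an in-memory buffer and load it back
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

# The reloaded module should reproduce the original model's outputs
batch = torch.rand(3, 2)
with torch.no_grad():
    outputs_match = torch.allclose(tiny_model(batch), reloaded(batch))
print(outputs_match)
```

Note that tracing only records the operations run for the example input, so models with data-dependent control flow are better handled with `torch.jit.script`.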

### Loss function

The same applies to the loss function: it needs to be serializable. We can define a regular function and convert it with jit. We can either trace the function the same way we did for the model, or use a function decorator called `th.jit.script`.

You can read more about jit trace and jit script in the [pytorch jit documentation](https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting).

In [None]:
# Loss function
@th.jit.script
def loss_fn(target, pred):
    return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

type(loss_fn)
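
The loss defined above is a mean squared error. For readers who want to check the arithmetic by hand, here is an equivalent plain-Python sketch operating on flat lists; `mse` and `constant_pred` are illustrative names, not part of the tutorial.

```python
def mse(target, pred):
    """Mean of the squared differences between two equal-length lists."""
    if len(target) != len(pred):
        raise ValueError("target and pred must have the same length")
    return sum((t - p) ** 2 for t, p in zip(target, pred)) / len(target)


xor_target = [1.0, 1.0, 0.0, 0.0]

# A model that always predicts 0.5 is off by 0.5 on every sample
constant_pred = [0.5, 0.5, 0.5, 0.5]
print(mse(xor_target, constant_pred))  # 0.25

# A perfect prediction gives zero loss
print(mse(xor_target, xor_target))  # 0.0
```

A loss of 0.25 is therefore a useful baseline for the xor data: a trained model should end up below it.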

### Optimizer

Just name the optimizer you want to use (for now only built-in torch optimizers are available).

In [None]:
optimizer = "SGD"

### General hyperparameters and training options

TrainConfig currently supports:
* batch_size
* optimizer_args: A dict of arguments used to initialize the optimizer
* epochs
* max_nr_batches: Maximum number of training steps to perform. For large datasets, this can be used to stop training before the requested number of epochs has completed.
* shuffle

In [None]:
batch_size = 4
optimizer_args = {"lr" : 0.1, "weight_decay" : 0.01}
epochs = 1
max_nr_batches = -1  # not used in this example
shuffle = True
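
To see how these options interact, the following sketch estimates how many optimizer steps a single training call would perform. It rests on assumptions that are not confirmed by this tutorial: a negative `max_nr_batches` disables the cap, the cap applies to the total number of steps across all epochs, and a final partial batch still counts as one step. `training_steps` is a hypothetical helper, not a PySyft function.

```python
import math


def training_steps(n_samples, batch_size, epochs, max_nr_batches=-1):
    """Estimate the number of optimizer steps (assumptions in the text above)."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    total = steps_per_epoch * epochs
    if max_nr_batches >= 0:  # assumed: negative values mean "no cap"
        total = min(total, max_nr_batches)
    return total


# The xor example: 4 samples, batch_size=4, epochs=1, no cap
print(training_steps(4, 4, 1))  # 1

# A larger dataset where the cap cuts training short
print(training_steps(1000, 4, 2, max_nr_batches=100))  # 100
```

With the settings used here, each training call amounts to a single step over the whole xor dataset, which is why the tutorial later repeats the call in a loop.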

## Create a TrainConfig

TrainConfig is just a wrapper around everything we defined for the scheduler. Creating one consists only of passing those pieces to `sy.TrainConfig`:

In [None]:
train_config = sy.TrainConfig(model=traced_model,
                              loss_fn=loss_fn,
                              optimizer=optimizer,
                              batch_size=batch_size,
                              optimizer_args=optimizer_args,
                              epochs=epochs,
                              shuffle=shuffle)
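
Conceptually, the hyperparameter part of such a config is a plain bundle of values that can be serialized and rebuilt on the other side. The sketch below illustrates this with a dataclass and JSON. `ToyTrainConfig` is a hypothetical class and its defaults are arbitrary; it is not how PySyft implements or serializes `TrainConfig`, and it leaves out the model and loss function, which rely on TorchScript for serialization.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyTrainConfig:
    """Hypothetical bundle of training options; NOT the PySyft class."""

    optimizer: str = "SGD"
    batch_size: int = 32
    optimizer_args: dict = field(default_factory=dict)
    epochs: int = 1
    max_nr_batches: int = -1
    shuffle: bool = True


cfg = ToyTrainConfig(
    optimizer="SGD",
    batch_size=4,
    optimizer_args={"lr": 0.1, "weight_decay": 0.01},
    epochs=1,
    shuffle=True,
)

# Serialize to a string that could be sent over a connection...
payload = json.dumps(asdict(cfg))

# ...and rebuild an equal object on the receiving side
restored_cfg = ToyTrainConfig(**json.loads(payload))
print(restored_cfg == cfg)  # True
```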

## Run training remotely

Now that we have a TrainConfig instance, we can just send it to a remote worker and the worker will know how it should execute training (which model, loss function, optimizer, ... to use).

### Connect to remote worker


We'll connect to the worker (alice) that we initiated at the beginning of the tutorial. We'll instantiate a websocket client, our local access point (proxy) to the remote worker.
Note that **this step will fail if the worker is not running**.

In [None]:
kwargs_websocket = {"host": "localhost", "hook": hook, "verbose": False}
alice = workers.websocket_client.WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)

### Send TrainConfig to worker

In [None]:
# Send train config
train_config.send(alice)

Now we can execute remote training using our TrainConfig instance!

### Training remotely with TrainConfig

First let's evaluate our model before training.

In [None]:
# Setup toy data (xor example)
data = th.tensor([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]], requires_grad=True)
target = th.tensor([[1.0], [1.0], [0.0], [0.0]], requires_grad=False)

print("\nEvaluation before training")
pred = model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

We can now train the model on alice's data.

We know that alice has a dataset identified by "xor", so let's ask her to train using this data. Alice knows how to train because we already specified it in the TrainConfig.

In [None]:
for epoch in range(10):
    loss = alice.fit(dataset_key="xor")  # ask alice to train using "xor" dataset
    print("-" * 50)
    print("Iteration %s: alice's loss: %s" % (epoch, loss))

In [None]:
# Retrieve the trained model from alice
new_model = train_config.model_ptr.get()

print("\nEvaluation after training:")
pred = new_model(data)
loss = loss_fn(target=target, pred=pred)
print("Loss: {}".format(loss))
print("Target: {}".format(target))
print("Pred: {}".format(pred))

### Star PySyft on GitHub

The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.

- [Star PySyft](https://github.com/OpenMined/PySyft)

### Pick our tutorials on GitHub!

We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.

- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)


### Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! 

- [Join slack.openmined.org](http://slack.openmined.org)

### Join a Code Project!

The best way to contribute to our community is to become a code contributor! If you want to start "one off" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.

- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)

### Donate

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)