# Federated Learning Training Plan: Create Plan

Let's try to make a protobuf-serializable Training Plan and Model that still work after deserialization :)

Current list of problems:
 * No support for autograd in Plan tracing (.backward() doesn't work inside the Plan).
 * `tensor.shape` values seem to be recorded as constants during Plan tracing, so we need to pass `batch_size` explicitly; we can't take it from the tensor itself.
 * Plan needs a list of all Model params in its argument list. It would be nicer if this list were figured out automatically so that you could just pass the Model (not sure this is solvable: jit might not accept the model as a ScriptModule input?).
 * Since loops aren't supported inside a Plan, working with Model params (e.g. the weight update) is awkward.
 * others? 


In [1]:
%load_ext autoreload
%autoreload 2
import syft as sy
import torch as th
from torch import jit
from syft.serde import protobuf
import base64
import os
from syft.messaging.plan.state import State
from syft.frameworks.torch.tensors.interpreters.placeholder import PlaceHolder


# NOTE(review): passing globals() presumably lets PySyft inject names such as
# `hook` (used later as `hook.local_worker`) into this namespace — confirm.
sy.hook(globals())
# Seed torch's RNG so the random tensors created below are reproducible.
th.random.manual_seed(1)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was 'C:\Users\Vova\AppData\Local\conda\conda\envs\pysyft\lib\site-packages\tf_encrypted/operations/secure_random/secure_random_module_tf_1.13.1.so'


Setting up Sandbox...
Done!


<torch._C.Generator at 0x2b0896b3cb0>

This utility function serializes any object to protobuf binary,
encodes it to base64 and saves it to a file.

In [2]:
def serializeToBase64Pb(worker, obj, filename):
    """Write ``obj`` to ``filename`` as base64-encoded protobuf bytes.

    ``obj`` is converted to a protobuf message with
    ``protobuf.serde._bufferize(worker, obj)``, the message is serialized to
    bytes, the bytes are base64-encoded, and the resulting text is written
    to ``filename`` (opened in text mode, overwriting any existing file).
    A short progress line is printed before writing.
    """
    message = protobuf.serde._bufferize(worker, obj)
    encoded = base64.b64encode(message.SerializeToString()).decode('utf-8')
    type_name = obj.__class__.__name__
    print("Writing %s to %s/%s" % (type_name, os.getcwd(), filename))
    with open(filename, "w") as out_file:
        out_file.write(encoded)

## Step 1: Define the model

This model will train on MNIST data; it's very simple, yet it can demonstrate the learning process.
There are 2 linear layers with a ReLU in between:

* Linear 784x392
* ReLU
* Linear 392x10 

Not using nn.Module or nn.Linear for now, just vanilla class and tensors.
No autograd, gradients are hand-coded. 

As no loops are supported inside a Plan,
we can't iterate over parameters, so everything that works with params
(get/set, step) is moved into the model to make the Plan code more generic.

In [3]:
class Net:
    """Two-layer perceptron (784 -> 392 -> 10) with hand-coded gradients.

    Plain tensors are used instead of ``nn.Module``/autograd. Everything that
    touches the parameters (get/set, gradient, update step) lives in this
    class so that calling code never has to iterate over the weights itself.

    ``update_fn`` must be assigned before ``update_weights`` is called; it is
    invoked as ``update_fn(param, grad, **update_kwargs)`` for each parameter.
    """

    def __init__(self):
        super().__init__()
        # Weights are random normal values divided by sqrt(number of inputs);
        # biases start at zero.
        self.W1 = th.randn(392, 784) / th.sqrt(th.tensor(784.))
        self.b1 = th.zeros(392)
        self.W2 = th.randn(10, 392) / th.sqrt(th.tensor(392.))
        self.b2 = th.zeros(10)
        self.update_fn = None

    def forward(self, X):
        """Return the output-layer logits for the batch ``X``.

        The hidden pre-activation (``Z1``), hidden activation (``A1``) and
        logits (``Z2``) are stored on the instance; ``grad`` reads ``Z1`` and
        ``A1``, so ``forward`` has to be called before ``grad``.
        """
        self.Z1 = X @ self.W1.t() + self.b1
        self.A1 = th.nn.functional.relu(self.Z1)
        self.Z2 = self.A1 @ self.W2.t() + self.b2
        return self.Z2

    def get_params(self):
        """Return the parameters as the tuple ``(W1, b1, W2, b2)``."""
        return self.W1, self.b1, self.W2, self.b2

    def set_params(self, *model_params):
        """Replace the parameters; expects exactly ``W1, b1, W2, b2`` in that order."""
        self.W1, self.b1, self.W2, self.b2 = model_params

    def grad(self, X, error):
        """Back-propagate ``error`` (gradient w.r.t. the logits) through the net.

        Uses the ``Z1``/``A1`` values stored by the most recent ``forward``
        call and returns ``(W1_grad, b1_grad, W2_grad, b2_grad)``.
        """
        # ReLU derivative: pass the gradient only where the pre-activation was positive.
        Z1_grad = (error @ self.W2) * (self.Z1 > 0).float()
        W1_grad = Z1_grad.t() @ X
        b1_grad = Z1_grad.sum(0)
        W2_grad = error.t() @ self.A1
        b2_grad = error.sum(0)
        return W1_grad, b1_grad, W2_grad, b2_grad

    def update_weights(self, *grads, **update_kwargs):
        """Apply ``update_fn`` to every parameter with its matching gradient.

        Args:
            *grads: ``W1_grad, b1_grad, W2_grad, b2_grad`` as returned by ``grad``.
            **update_kwargs: forwarded unchanged to ``update_fn``.

        Raises:
            TypeError: if ``update_fn`` has not been assigned.
        """
        W1_grad, b1_grad, W2_grad, b2_grad = grads
        if self.update_fn is None:
            # Same exception type as calling None would raise, but with a
            # message that names the actual cause.
            raise TypeError(
                "Net.update_fn is not set; assign a callable taking "
                "(param, grad, **kwargs) before calling update_weights()"
            )
        self.update_fn(self.W1, W1_grad, **update_kwargs)
        self.update_fn(self.b1, b1_grad, **update_kwargs)
        self.update_fn(self.W2, W2_grad, **update_kwargs)
        self.update_fn(self.b2, b2_grad, **update_kwargs)

model = Net()

## Step 2: Define Training Plan
### Loss function 
Batch size needs to be passed explicitly because otherwise `target.shape[0]` would be saved as the constant `1` during the Plan trace with dummy data.
The gradient of the loss is also returned here.

In [4]:
def cross_entropy_with_logits(output, target, batch_size):
    """Softmax cross-entropy computed from raw logits, with a hand-coded gradient.

    Args:
        output: logits; softmax is taken over dim 1.
        target: tensor multiplied element-wise with the log-probabilities
            (same layout as ``output``).
        batch_size: number of samples, supplied by the caller instead of
            being read from ``target.shape[0]``.

    Returns:
        ``(probs, loss, loss_grad)`` where ``probs`` is the softmax of
        ``output``, ``loss`` is ``-(target * log(probs)).mean()`` and
        ``loss_grad`` is ``(probs - target) / (batch_size * target.shape[1])``.
    """
    class_probs = th.nn.functional.softmax(output, dim=1)
    weighted_log_probs = target * th.log(class_probs)
    loss = -weighted_log_probs.mean()
    # Same normalisation as the mean above, provided batch_size equals the
    # number of rows in target.
    n_elements = batch_size * target.shape[1]
    loss_grad = (class_probs - target) / n_elements
    return class_probs, loss, loss_grad

### Optimization function
 
Just updates weights with grad*lr.

Adding it into a Model is a way to decouple
Plan from dealing with param update directly,
so that only Model knows about its weights.

In [5]:
def naive_sgd(param, grad, **kwargs):
    """Plain SGD step applied in place: ``param <- param - grad * lr``.

    The learning rate is required and is read from the ``lr`` keyword argument.
    """
    step = -grad * kwargs['lr']
    param.add_(step)

# Install the update rule on the model; Net.update_weights calls it once per parameter.
model.update_fn = naive_sgd
   

### Training Plan procedure

In [6]:
# Current model parameters, in the order returned by the model
model_params = model.get_params()

# Plan input dimensions (-1 presumably stands for the variable batch dimension)
X_size = (-1, 784)
y_size = (-1, 10)
scalar_size = (1,)
model_params_shapes = [param.shape for param in model_params]

# One entry per positional argument of training_plan, in order:
# X, y, batch_size, lr, then every model parameter
args_shape = [X_size, y_size, scalar_size, scalar_size] + model_params_shapes

@sy.func2plan(args_shape=args_shape)
def training_plan(X, y, batch_size, lr, *model_params):
    """Run one training step: forward pass, loss, hand-coded backprop, weight update.

    Args:
        X: input batch (declared shape ``(-1, 784)`` in ``args_shape``).
        y: targets (declared shape ``(-1, 10)``); the argmax over dim 1 is
            used as the class label when computing accuracy.
        batch_size: number of samples, passed explicitly rather than read
            from a tensor shape.
        lr: learning rate, forwarded to the model's update function.
        *model_params: parameters in the order expected by ``model.set_params``.

    Returns:
        Tuple ``(*model_params, loss, acc)``. The parameters are the same
        objects that were passed in; they are handed to ``model.update_fn``
        by ``model.update_weights``.
    """
    # inject params into model
    model.set_params(*model_params)

    # forward pass
    output = model.forward(X)
    
    # loss
    probs, loss, loss_grad = cross_entropy_with_logits(output, y, batch_size)

    # backprop
    grads = model.grad(X, loss_grad)

    # step
    model.update_weights(*grads, lr=lr)

    # accuracy: number of matching argmax labels divided by batch_size
    pred = th.argmax(probs, dim=1)
    target = th.argmax(y, dim=1)
    acc = pred.eq(target).float().sum() / batch_size

    return (
        *model_params,
        loss,
        acc,
    )

# check that operations look good
print(training_plan.operations)

[(Operation (('t', PlaceHolder[Tags:#1 #input-4], (), {}), PlaceHolder[Tags:#2])), (Operation (('__matmul__', PlaceHolder[Tags:#3 #input-0], (PlaceHolder[Tags:#2],), {}), PlaceHolder[Tags:#4])), (Operation (('__add__', PlaceHolder[Tags:#4], (PlaceHolder[Tags:#5 #input-5],), {}), PlaceHolder[Tags:#6])), (Operation (('torch.nn.functional.relu', None, (PlaceHolder[Tags:#6],), {}), PlaceHolder[Tags:#7])), (Operation (('t', PlaceHolder[Tags:#input-6 #8], (), {}), PlaceHolder[Tags:#9])), (Operation (('__matmul__', PlaceHolder[Tags:#7], (PlaceHolder[Tags:#9],), {}), PlaceHolder[Tags:#10])), (Operation (('__add__', PlaceHolder[Tags:#10], (PlaceHolder[Tags:#input-7 #11],), {}), PlaceHolder[Tags:#12])), (Operation (('torch.nn.functional.softmax', None, (PlaceHolder[Tags:#12],), {'dim': 1}), PlaceHolder[Tags:#13])), (Operation (('torch.log', None, (PlaceHolder[Tags:#13],), {}), PlaceHolder[Tags:#14])), (Operation (('__mul__', PlaceHolder[Tags:#15 #input-1], (PlaceHolder[Tags:#14],), {}), PlaceHol

## Step 3: JIT Trace Training Plan

Note: Plan expects everything to be a tensor.


In [7]:
# Dummy inputs used only for tracing; every Plan argument is a tensor
X_trace = th.randn(2, 784)
y_trace = th.randn(2, 10)
lr = th.tensor(0.001)
batch_size = th.tensor(32)

# jit trace, with the arguments in the same order as the training_plan signature
trace_inputs = (X_trace, y_trace, batch_size, lr, *model_params)
training_plan_torchscript = th.jit.trace(training_plan.__call__, trace_inputs)

# Show the generated TorchScript
print(training_plan_torchscript.code)

def __call__(argument_0: Tensor,
    argument_1: Tensor,
    argument_2: Tensor,
    argument_3: Tensor,
    argument_4: Tensor,
    argument_5: Tensor,
    argument_6: Tensor,
    argument_7: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  _0 = torch.matmul(argument_0, torch.t(argument_4))
  _1 = torch.add(_0, argument_5, alpha=1)
  _2 = torch.relu(_1)
  _3 = torch.add(torch.matmul(_2, torch.t(argument_6)), argument_7, alpha=1)
  _4 = torch.softmax(_3, 1, None)
  _5 = torch.mean(torch.mul(argument_1, torch.log(_4)), dtype=None)
  _6 = torch.neg(_5)
  _7 = ops.prim.NumToTensor(torch.size(argument_1, 1))
  _8 = torch.div(torch.sub(_4, argument_1, alpha=1), torch.mul(argument_2, _7))
  _9 = torch.matmul(_8, argument_6)
  _10 = torch.to(torch.gt(_1, 0), 6, False, False, None)
  _11 = torch.mul(_9, _10)
  _12 = torch.matmul(torch.t(_11), argument_0)
  _13 = torch.sum(_11, [0], False, dtype=None)
  _14 = torch.matmul(torch.t(_8), _2)
  _15 = torch.sum(_8, [0], False, dty

## Step 4: Serialize!

Note that we don't serialize the full Model, only its weights.
State is a suitable protobuf class for wrapping the list of Model param tensors.

In [8]:
local_worker = hook.local_worker

# the Plan (list of operations) and its TorchScript trace
serializeToBase64Pb(local_worker, training_plan, "tp_ops.b64")
serializeToBase64Pb(local_worker, training_plan_torchscript, "tp_ts.b64")

# wrap weights in State to serialize
param_placeholders = [PlaceHolder().instantiate(weights) for weights in model_params]
model_params_state = State(owner=local_worker, state_placeholders=param_placeholders)

serializeToBase64Pb(local_worker, model_params_state, "model_params.b64")

Writing Plan to e:\ml/tp_ops.b64
Writing ScriptFunction to e:\ml/tp_ts.b64
Writing State to e:\ml/model_params.b64


In the next notebook, we load and execute this plan.