# VAE - Syft Duet - Data Scientist

This example trains a variational autoencoder (VAE) on the MNIST dataset with Syft.


## PART 1: Connect to a Remote Duet Server
As the Data Scientist, you want to perform data science on data that is sitting in the Data Owner's Duet server in their Notebook.

To do this, we must run the code that the Data Owner sends us, which includes their Duet Server ID. It will look like the snippet below, but with their real Server ID in place of the placeholder.

import syft as sy
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
This will create a direct connection from your notebook to the remote Duet server. Once the connection is established, all traffic is sent directly between the two nodes.

Paste the code or Server ID that the Data Owner gives you and run it in the cell below. It will return your Client ID which you must send to the Data Owner to enter into Duet so it can pair your notebooks.

In [None]:
import syft as sy

# Connect to the Data Owner's Duet server. Replace the placeholder with the
# real Server ID the Data Owner sends you; running this shows the Client ID
# you must send back to them so the two notebooks can be paired.
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 0 : Now STOP and run the Data Owner notebook until Checkpoint 1.

In [None]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

In [None]:
import os
import torch
import torchvision
import torchvision.utils as vutils
from torchvision.utils import save_image

from PIL import Image

try:
    # make notebook progress bars nicer
    from tqdm.notebook import tqdm
except ImportError:
    # tqdm is optional here: report that it is missing and carry on.
    print("Unable to import tqdm")

In [None]:
# Proxy for the Data Owner's torch library: calls made through remote_torch
# create objects / run operations on the Data Owner's side and return
# pointers, whose values must be fetched with .get(...).
remote_torch = duet.torch

In [None]:
# With dry_run enabled, the training loop, the test loop and the epoch loop
# each stop after their first batch / epoch. Disable it (and raise epochs)
# for a full run.
dry_run = True
epochs = 1

# Hyper-parameters and run settings shared by every cell below.
config = dict(
    batch_size=128,
    no_cuda=True,  # keep everything on the CPU even if the Data Owner has CUDA
    seed=42,
    epochs=epochs,
    dry_run=dry_run,
    log_interval=10,  # print training progress every N batches
)

## Load the data

In [None]:
from syft.util import get_root_data_path

# we need some transforms for the MNIST data set
remote_torchvision = duet.torchvision

# ToTensor converts the PIL images to tensors
transform_1 = remote_torchvision.transforms.ToTensor()

# Compose receives a remote list (created through duet) holding the transforms
remote_list = duet.python.List()
remote_list.append(transform_1)

# compose our transforms
transforms = remote_torchvision.transforms.Compose(remote_list)

# The DO has kindly let us initialise DataLoaders for their training and test sets
train_kwargs = {"batch_size": config["batch_size"], "shuffle": True}
# Evaluation does not need shuffling. Keeping the test order fixed means that in
# a dry run (where only the first test batch is scored) the same batch is used
# regardless of how much randomness training consumed beforehand.
test_kwargs = {"batch_size": config["batch_size"], "shuffle": False}

train_data_ptr = remote_torchvision.datasets.MNIST(
    str(get_root_data_path()), train=True, download=True, transform=transforms
)
train_loader_ptr = remote_torch.utils.data.DataLoader(train_data_ptr, **train_kwargs)

test_data_ptr = remote_torchvision.datasets.MNIST(
    str(get_root_data_path()), train=False, download=True, transform=transforms
)
test_loader_ptr = remote_torch.utils.data.DataLoader(test_data_ptr, **test_kwargs)

In [None]:
# Normally we would not know the length of a remote dataset, so ask the Data
# Owner for it; the training and test loops use it to report progress and to
# average the loss.
def get_train_length(train_data_ptr):
    """Return the number of samples in the remote dataset behind ``train_data_ptr``.

    Works for any dataset pointer; it is also used for the test set.
    ``__len__()`` is called explicitly because the builtin ``len()`` only
    accepts a plain ``int`` result. If the call hands back a pointer (anything
    with a ``get`` method) instead of a local value, the value is fetched from
    the Data Owner, blocking until the request is answered. Without this, a
    pointer would flow into the progress and average-loss arithmetic.

    NOTE(review): presumably the fetched value can be ``None`` if the Data
    Owner does not approve the request -- confirm.
    """
    train_length = train_data_ptr.__len__()
    if hasattr(train_length, "get"):
        train_length = train_length.get(request_block=True)
    return train_length


# Fetch each dataset length only once per notebook session: if this cell is
# re-run and a value is already stored in the notebook namespace, reuse it
# instead of asking the Data Owner again.
if globals().get("train_data_length") is None:
    train_data_length = get_train_length(train_data_ptr)

# get_train_length only calls __len__ on its argument, so it works for the test set too
if globals().get("test_data_length") is None:
    test_data_length = get_train_length(test_data_ptr)

print(f"Training Dataset size is: {train_data_length}")
print(f"Test Dataset size is: {test_data_length}")

In [None]:
# Number of batches in the remote training loader, fetched from the Data
# Owner (blocking until answered); train() uses it for the progress percentage.
loader_length_ptr = train_loader_ptr.len()
train_loader_length = loader_length_ptr.get(request_block=True)

## Check GPU

In [None]:
# Ask the Data Owner whether their machine has CUDA. The answer lives on the
# remote side, so it has to be fetched with .get(); bool() maps a missing
# answer (None, presumably when the request is not approved) to False.
has_cuda_ptr = remote_torch.cuda.is_available()
cuda_available = bool(has_cuda_ptr.get(request_block=True))
print("Is cuda available ? : ", cuda_available)

# Use the GPU only when it exists AND the config allows it.
use_cuda = not config["no_cuda"] and cuda_available
# Later cells check `has_cuda` before moving the model to the GPU, so tie it to
# the device actually selected below. Otherwise a CUDA-capable Data Owner with
# no_cuda=True would get a CUDA model while the data stays on the CPU.
has_cuda = use_cuda

# now we can set the seed
remote_torch.manual_seed(config["seed"])

device = remote_torch.device("cuda" if use_cuda else "cpu")
# print(f"Data Owner device is {device.type.get()}")

## Define and create the model

In [None]:
class VAE(sy.Module):
    """Variational autoencoder for flattened 28x28 (784-pixel) images.

    Encoder: 784 -> 400 -> (mu, logvar), each of size 20.
    Decoder: 20 -> 400 -> 784.

    Every layer and operation goes through ``self.torch_ref``, the torch module
    passed to the constructor. The same class is used with local ``torch``
    here, and the sent copy presumably carries the remote torch proxy --
    confirm against ``sy.Module.send``.
    """

    def __init__(self, torch_ref):
        super(VAE, self).__init__(torch_ref=torch_ref)

        # Encoder: image -> hidden features
        self.fc1 = self.torch_ref.nn.Linear(784, 400)
        # Encoder heads: hidden -> latent mean (fc21) and log-variance (fc22)
        self.fc21 = self.torch_ref.nn.Linear(400, 20)
        self.fc22 = self.torch_ref.nn.Linear(400, 20)
        # Decoder: latent -> hidden -> image
        self.fc3 = self.torch_ref.nn.Linear(20, 400)
        self.fc4 = self.torch_ref.nn.Linear(400, 784)

    def encode(self, x):
        """Map flattened images to the latent mean and log-variance."""
        h1 = self.torch_ref.nn.ReLU()(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``std = exp(0.5 * logvar)``."""
        # Use the model's own torch reference, like every other method here,
        # rather than the global remote proxy. Otherwise a model built with
        # local torch would have this step routed to the Data Owner.
        std = self.torch_ref.exp(0.5 * logvar)
        # NOTE(review): eps is drawn with the local torch and a fixed shape of
        # (config["batch_size"], 20). A batch of any other size -- e.g. a final,
        # smaller batch when dry_run is False and the dataset size is not a
        # multiple of batch_size -- will not match mu/std. Presumably a
        # workaround for sampling remotely; TODO draw eps with std's shape.
        eps = torch.randn(config["batch_size"], 20)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to 784 sigmoid-activated pixel values."""
        h3 = self.torch_ref.nn.ReLU()(self.fc3(z))
        return self.torch_ref.nn.Sigmoid()(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
local_vae = VAE(torch)
# send() hands back the copy used for training from here on (checked to be
# non-local in the next cell).
vae = local_vae.send(duet)

# Put the model on the same device the data is sent to. `device` is CUDA only
# when use_cuda is set; checking raw CUDA availability instead would move the
# model to the GPU even when config["no_cuda"] keeps the data on the CPU.
if use_cuda:
    vae.cuda(device)
else:
    vae.cpu()

In [None]:
# Guard: the following cells train through the handle returned by send(), so
# make sure it is not a local model.
assert not vae.is_local, "Training requires remote model"

In [None]:
# Adam optimizer (lr=1e-3) created through remote_torch over the sent model's
# parameters; train() calls zero_grad()/step() on it for every batch.
optimizer = remote_torch.optim.Adam(vae.parameters(), lr=1e-3)

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction BCE plus the KL divergence.

    ``recon_x`` is the decoder output. ``x`` is the input batch, flattened to
    784 features for the comparison. ``mu`` and ``logvar`` are the encoder
    outputs. Both terms are summed, not averaged. All operations go through
    ``remote_torch``.
    """
    bce_criterion = remote_torch.nn.BCELoss(reduction="sum")
    target = x.view(-1, 784)
    reconstruction_term = bce_criterion(recon_x, target)

    # Closed-form KL term; see Appendix B of
    # Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    # (https://arxiv.org/abs/1312.6114):
    #     -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_term = -0.5 * remote_torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return reconstruction_term + kl_term

In [None]:
def train(epoch):
    """Run one training epoch of the remote VAE and print the losses.

    Each batch goes through these steps:
      * wrapped as a remote tensor and moved to ``device``;
      * a forward pass, ``loss.backward()`` and ``optimizer.step()`` are run
        through the remote objects;
      * the scalar loss is fetched back, blocking until the request is
        answered.

    Batches whose fetched loss is ``None`` are not counted. Progress is
    printed every ``config["log_interval"]`` batches, and only the first batch
    is processed when ``config["dry_run"]`` is set.
    """
    vae.train()
    running_loss = 0
    log_every = config["log_interval"]
    batch_size = config["batch_size"]

    for step, batch in enumerate(train_loader_ptr):
        inputs = remote_torch.Tensor(batch[0]).to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = vae(inputs)
        loss = loss_function(reconstruction, inputs, mu, logvar)
        loss.backward()
        optimizer.step()

        step_loss = loss.item().get(request_block=True)
        if step_loss is not None:
            running_loss += step_loss
            if step % log_every == 0:
                progress = 100.0 * (step / train_loader_length)
                print(
                    f"Train Epoch: {epoch} [{step * batch_size}/{train_data_length} "
                    f"({progress:.0f}%)]\tLoss: {step_loss / batch_size:.6f}"
                )

        if config["dry_run"]:
            break

    average = running_loss / train_data_length
    print(f"====> Epoch: {epoch} Average loss: {average:.4f}")

In [None]:
def test(epoch):
    """Evaluate the remote VAE on the test set and print the average loss.

    ``epoch`` is accepted for symmetry with ``train`` but is not used.
    Like ``train``, each batch's summed loss is reduced to a scalar with
    ``.item()`` before being fetched, so only a number (not a tensor) is
    transferred and accumulated. Batches whose fetched loss is ``None`` are
    skipped. Only the first batch is evaluated when ``config["dry_run"]`` is
    set.
    """
    vae.eval()
    test_loss = 0
    # NOTE(review): this is the *local* torch's no_grad; it presumably does not
    # stop the Data Owner's side from tracking gradients -- confirm.
    with torch.no_grad():
        for data in test_loader_ptr:
            data_ptr = remote_torch.Tensor(data[0]).to(device)
            recon_batch, mu, logvar = vae(data_ptr)
            loss = loss_function(recon_batch, data_ptr, mu, logvar)
            batch_loss = loss.item().get(request_block=True)
            if batch_loss is not None:
                test_loss += batch_loss / test_data_length

            if config["dry_run"]:
                break

    print(f"====> Test set loss: {test_loss:.4f}")

## Training

In [None]:
sample_device = "cuda" if use_cuda else "cpu"

for epoch in range(1, config["epochs"] + 1):
    train(epoch)
    test(epoch)

    # After each epoch, decode 64 random latent vectors with the remote model,
    # fetch the result and save it as an image file for this epoch.
    with torch.no_grad():
        sample = vae.decode(torch.randn(64, 20).to(sample_device)).cpu()
        sample_image = sample.get(request_block=True)
        images = sample_image.view(64, 1, 28, 28)
        save_image(images, f"sample_{epoch}.png")

    if config["dry_run"]:
        break

## Inference

In [None]:
import PIL

# Decode 64 random latent vectors with the trained remote model, fetch the
# decoded batch from the Data Owner and save it as output.png.
noise = torch.randn(64, 20)
sample = vae.decode(noise.to("cuda" if use_cuda else "cpu")).cpu()
sample_image = sample.get(request_block=True)

images = sample_image.view(64, 1, 28, 28)
save_image(images, "output.png")

# PIL.Image.open("output.png")

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 1 : Now STOP and run the Data Owner notebook until the next checkpoint.