# MNIST - Lightning ⚡️ Syft Duet - Data Scientist 🥁

## PART 1: Connect to a Remote Duet Server

As the Data Scientist, you want to perform data science on data that is sitting in the Data Owner's Duet server in their Notebook.

To do this, we must run the code that the Data Owner sends us, which includes their Duet Server ID. The code will look like this, but with their real Server ID in place of the placeholder.

```
import syft as sy
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
```

This will create a direct connection from my notebook to the remote Duet server. Once the connection is established, all traffic is sent directly between the two nodes.

Paste the code or Server ID that the Data Owner gives you and run it in the cell below. It will return your Client ID, which you must send to the Data Owner to enter into Duet so it can pair your notebooks.

In [None]:
# Join the Data Owner's Duet session from the Data Scientist side.
import syft as sy
# NOTE(review): loopback=True presumably pairs with a Duet that the Data Owner
# launched in loopback mode, so no Server ID is pasted here — confirm.
duet = sy.join_duet(loopback=True)
# Register an additional log sink: a file in the current working directory.
sy.logger.add(sink="./syft_ds.log")

In [None]:
# Inspect the Duet store; as the last expression of the cell its value is displayed.
# NOTE(review): `.pandas` presumably renders the store contents as a DataFrame — confirm.
duet.store.pandas

## PART 2: Setting up a Model and our Data
The majority of the code below has been adapted closely from the original PyTorch MNIST example which is available in the `original` directory with these notebooks.

The `duet` variable is now your reference to a whole world of remote operations including supported libraries like torch.

Let's take a look at the duet.torch attribute.
```
duet.torch
```

In [None]:
# stdlib
from types import ModuleType
from typing import Any
from typing import List
from typing import Optional
from typing import Union

# third party
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.experimental.plugins.secure.pysyft import SyLightningModule
import torch
from torchvision import transforms

# syft absolute
import syft as sy
from syft.ast.module import Module
from torch import nn
from syft.util import get_root_data_path

# A library handle: either a plain Python module (e.g. local `torch`) or a
# syft AST `Module` (imported from syft.ast.module).
SyModuleProxyType = Union[ModuleType, Module]
# A model: either a plain torch `nn.Module` or a syft `sy.Module`.
SyModelProxyType = Union[nn.Module, sy.Module]

# can't use lib_ast during test search time, so the pointer types below are
# aliased to Any; the trailing comments name the types they stand in for.
TorchTensorPointerType = Any  # sy.lib_ast.torch.Tensor.pointer_type
TorchDataLoaderPointerType = Any  # sy.lib_ast.torch.utils.data.DataLoader
# Either the local torch object or the (Any-typed) pointer stand-in above.
SyTensorProxyType = Union[torch.Tensor, TorchTensorPointerType]  # type: ignore
SyDataLoaderProxyType = Union[torch.utils.data.DataLoader, TorchDataLoaderPointerType]  # type: ignore

Let's create a model just like the one in the MNIST example. We do this in almost exactly the same way as in PyTorch. The main difference is that we inherit from `sy.Module` instead of `nn.Module`, and we need to pass in a variable called `torch_ref`, which we will use internally for any calls that would normally be made to `torch`.

In [None]:
class SyNet(sy.Module):
    """Small convolutional classifier built through a torch reference.

    Every layer and every functional call is resolved via ``self.torch_ref``
    instead of a module-level ``torch`` import, so the same definition can be
    driven by whichever torch handle is passed to the constructor.
    """

    def __init__(self, torch_ref: SyModuleProxyType) -> None:
        """Create the layers using ``torch_ref`` as the torch namespace.

        Args:
            torch_ref: Module-like handle that exposes ``nn`` (and, for
                ``forward``, ``flatten`` and ``nn.functional``).
        """
        super().__init__(torch_ref=torch_ref)
        layers = self.torch_ref.nn
        # Layers are created in the same order as before: convolutions,
        # dropouts, then the fully connected head.
        self.conv1 = layers.Conv2d(1, 32, 3, 1)
        self.conv2 = layers.Conv2d(32, 64, 3, 1)
        self.dropout1 = layers.Dropout2d(0.25)
        self.dropout2 = layers.Dropout2d(0.5)
        # NOTE(review): 9216 presumably equals 64 * 12 * 12 for 28x28 inputs — confirm.
        self.fc1 = layers.Linear(9216, 128)
        self.fc2 = layers.Linear(128, 10)

    def forward(self, x: SyTensorProxyType) -> SyTensorProxyType:
        """Run the network and return log-softmax scores along ``dim=1``.

        Args:
            x: Input tensor (or tensor proxy) fed to the first convolution.

        Returns:
            The output of ``log_softmax(..., dim=1)`` applied to the final
            linear layer's result.
        """
        functional = self.torch_ref.nn.functional

        # Convolutional stage: conv -> relu, twice, then 2x2 max pooling.
        features = functional.relu(self.conv1(x))
        features = functional.relu(self.conv2(features))
        features = functional.max_pool2d(features, 2)
        features = self.dropout1(features)

        # Flatten everything from dim 1 onward before the linear head.
        flat = self.torch_ref.flatten(features, 1)
        hidden = functional.relu(self.fc1(flat))
        hidden = self.dropout2(hidden)
        logits = self.fc2(hidden)
        return functional.log_softmax(logits, dim=1)

In [None]:
# Directory later passed to the Trainer as `default_root_dir`.
tmpdir = "./"

In [None]:
from pytorch_lightning.utilities.imports import is_syft_initialized

In [None]:
# Evaluate Lightning's syft check; as the last expression of the cell its result is displayed.
# NOTE(review): presumably reports whether a Duet client has been registered with syft — confirm.
is_syft_initialized()

In [None]:
# bookkeeping: store this Duet client in syft's client cache under the key "duet"
# NOTE(review): presumably read back later by library code — confirm which component uses it.
sy.client_cache["duet"] = duet

class LiftSyLightningModule(SyLightningModule):
    """Lightning module that fits and tests a ``sy.Module`` on MNIST.

    Library handles are chosen per call: when ``self.is_remote()`` is true the
    Duet-provided ``torchvision`` / ``python.List`` are used, otherwise the
    local equivalents. ``self.torch`` and ``self.model`` come from the base
    class.

    NOTE(review): the remote branches read the notebook-level ``duet`` global
    rather than the ``duet`` passed to ``__init__`` — confirm both are always
    the same client.
    """

    def __init__(self, module: sy.Module, duet: Any) -> None:
        """Forward the wrapped module and the Duet client to the base class."""
        super().__init__(module, duet)

    def training_step(
        self, batch: SyTensorProxyType, batch_idx: Optional[int]
    ) -> SyTensorProxyType:
        """Return the training loss for one batch.

        The loss is the mean squared error between the forward output and a
        same-shaped tensor of ones.
        NOTE(review): regressing the output to all-ones looks like a
        smoke-test objective rather than a classification loss — confirm.
        """
        output = self.forward(batch)
        return self.torch.nn.functional.mse_loss(
            output, self.torch.ones_like(output)
        )

    def test_step(self, batch: SyTensorProxyType, batch_idx: Optional[int]) -> None:
        """Compute the loss for one test batch and log it as ``test_loss``."""
        output = self.forward(batch)
        # NOTE(review): `self.loss` is not defined in this class; it must be
        # provided by the base class, otherwise this raises AttributeError.
        # If it is missing, the MSE objective from `training_step` is the
        # likely intended replacement — confirm.
        loss = self.loss(output, self.torch.ones_like(output))
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

    def configure_optimizers(self) -> List:
        """Return a single SGD optimizer (lr=0.1) over the model's parameters."""
        optimizer = self.torch.optim.SGD(self.model.parameters(), lr=0.1)  # type: ignore
        return [optimizer]

    @property
    def torchvision(self) -> SyModuleProxyType:
        """Return Duet's ``torchvision`` when remote, else the local package."""
        if self.is_remote():
            return duet.torchvision
        # Imported here because the notebook only runs
        # `from torchvision import transforms`, which does not bind the name
        # `torchvision`; referring to it directly raised NameError.
        import torchvision as local_torchvision

        return local_torchvision

    def get_transforms(self) -> Any:
        """Build the ToTensor + Normalize pipeline.

        Returns:
            The result of ``Compose([...])`` from whichever torchvision handle
            is active. It is annotated as ``Any`` because that is either a
            local object or a Duet-side one; the previous annotation
            evaluated to the metaclass ``type``.
        """
        # The container must match the active side: a Duet list when remote,
        # a built-in list otherwise.
        list_type = duet.python.List if self.is_remote() else list
        # Named `steps` so the imported `transforms` module is not shadowed.
        steps = list_type()
        steps.append(self.torchvision.transforms.ToTensor())  # type: ignore
        # 0.1307 / 0.3081 are the mean / std passed to Normalize, as in the
        # PyTorch MNIST example this notebook is adapted from.
        steps.append(self.torchvision.transforms.Normalize(0.1307, 0.3081))  # type: ignore
        return self.torchvision.transforms.Compose(steps)  # type: ignore

    def _mnist_dataloader(self, train: bool) -> SyDataLoaderProxyType:
        """Create a batch-size-1 DataLoader over the chosen MNIST split."""
        pipeline = self.get_transforms()
        dataset = self.torchvision.datasets.MNIST(  # type: ignore
            str(get_root_data_path()),
            train=train,
            download=True,
            transform=pipeline,
        )
        return self.torch.utils.data.DataLoader(dataset, batch_size=1)  # type: ignore

    def train_dataloader(self) -> SyDataLoaderProxyType:
        """Return the DataLoader over the MNIST training split."""
        return self._mnist_dataloader(train=True)

    def test_dataloader(self) -> SyDataLoaderProxyType:
        """Return the DataLoader over the MNIST test split."""
        return self._mnist_dataloader(train=False)

In [None]:
# Build the network against the local `torch` module, then wrap it together
# with the Duet client in the Lightning module.
module = SyNet(torch)
model = LiftSyLightningModule(module=module, duet=duet)

In [None]:
# Keep the run short: a single epoch, capped at two training batches and two
# test batches; Trainer output is rooted at `tmpdir`.
trainer = Trainer(
    default_root_dir=tmpdir,
    max_epochs=1,
    limit_train_batches=2,
    limit_test_batches=2,
)

In [None]:
# Train, then run the test loop twice: once without passing a model and once
# with `model` passed explicitly.
# NOTE(review): the argument-less call presumably reuses the trainer's
# fitted/checkpointed weights — confirm the double run is intentional.
trainer.fit(model)
trainer.test()
trainer.test(model)

In [None]:
# Rebuild the Lightning module from the path recorded as `best_model_path` on
# the trainer's checkpoint callback; `module` and `duet` are passed again as
# the constructor arguments.
model = LiftSyLightningModule.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path, module=module, duet=duet
)

In [None]:
# Fit again, this time with the module restored from the checkpoint.
trainer.fit(model)