# MNIST - Lightning ⚡️ Syft Duet - Data Scientist 🥁

## PART 1: Connect to a Remote Duet Server

As the Data Scientist, you want to perform data science on data that is sitting in the Data Owner's Duet server in their Notebook.

To do this, we must run the code that the Data Owner sends us, which includes their Duet Server ID. The code will look like the snippet below, but with their real Server ID in place of the placeholder.

```
import syft as sy
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
```

This will create a direct connection from my notebook to the remote Duet server. Once the connection is established all traffic is sent directly between the two nodes.

Paste the code or Server ID that the Data Owner gives you and run it in the cell below. It will return your Client ID which you must send to the Data Owner to enter into Duet so it can pair your notebooks.

In [None]:
import torch
import torchvision
import syft as sy

from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.experimental.plugins.secure.pysyft import SyLightningModule
from pytorch_lightning.utilities.imports import is_syft_initialized
from pytorch_lightning.metrics import Accuracy
from syft.util import get_root_data_path

# Join the Data Owner's Duet session using the loopback option.
duet = sy.join_duet(loopback=True)

# Register the session in syft's client cache under the "duet" key, then check
# that the Lightning integration reports syft as initialized.
# NOTE(review): presumably is_syft_initialized() inspects this cache entry —
# confirm against the pytorch_lightning helper.
sy.client_cache["duet"] = duet
assert is_syft_initialized()

## PART 2: Setting up a Model and our Data
The majority of the code below has been adapted closely from the original PyTorch MNIST example which is available in the `original` directory with these notebooks.

The `duet` variable is now your reference to a whole world of remote operations including supported libraries like torch.

Let's take a look at the `duet.torch` attribute.
```
duet.torch
```

Let's create a model just like the one in the MNIST example. We do this in almost exactly the same way as in PyTorch. The main difference is that we inherit from `sy.Module` instead of `nn.Module`, and we need to pass in a variable called `torch_ref`, which we will use internally for any calls that would normally be made to `torch`.

In [None]:
class SyNet(sy.Module):
    """MNIST convolutional classifier built through a ``torch_ref`` handle.

    Every layer and functional call goes through ``self.torch_ref`` instead of
    a directly imported ``torch``, so the same definition works with whichever
    torch reference is passed to the constructor.
    """

    def __init__(self, torch_ref) -> None:
        """Create the layers and accuracy metrics.

        Args:
            torch_ref: Object exposing the ``torch`` API (``nn``, ``flatten``)
                used to build and run the network.
        """
        super().__init__(torch_ref=torch_ref)
        layers = self.torch_ref.nn

        # Convolutional feature extractor.
        self.conv1 = layers.Conv2d(1, 32, 3, 1)
        self.conv2 = layers.Conv2d(32, 64, 3, 1)

        # Regularisation applied after pooling and after the first dense layer.
        self.dropout1 = layers.Dropout2d(0.25)
        self.dropout2 = layers.Dropout2d(0.5)

        # Classifier head; 9216 = 64 * 12 * 12 for 1x28x28 MNIST inputs.
        self.fc1 = layers.Linear(9216, 128)
        self.fc2 = layers.Linear(128, 10)

        # Metrics are created locally rather than through ``torch_ref``.
        self.train_acc = Accuracy()
        self.test_acc = Accuracy()

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        functional = self.torch_ref.nn.functional

        features = functional.relu(self.conv1(x))
        features = functional.relu(self.conv2(features))
        features = functional.max_pool2d(features, 2)
        features = self.dropout1(features)

        # Collapse everything except the batch dimension before the dense layers.
        flat = self.torch_ref.flatten(features, 1)
        hidden = functional.relu(self.fc1(flat))
        hidden = self.dropout2(hidden)
        logits = self.fc2(hidden)

        return functional.log_softmax(logits, dim=1)

In [None]:
class LiftSyLightningModule(SyLightningModule):
    """Lightning wrapper that trains and tests the wrapped module on MNIST.

    ``self.torch``, ``self.model``, ``self.module``, ``self.remote_model`` and
    ``self.is_remote()`` are not defined in this class; they are expected to be
    provided by ``SyLightningModule``.

    NOTE(review): ``torchvision`` and ``get_transforms`` read the notebook-level
    ``duet`` global rather than the ``duet`` argument given to ``__init__`` —
    confirm whether the base class exposes the session so the global can go.
    """

    def __init__(self, module, duet):
        """Forward the wrapped module and the Duet session to the base class."""
        super().__init__(module, duet)

    def train(self, mode: bool = True):
        """Switch the active model between training and evaluation mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            Whatever the underlying model's ``train`` call returns.
        """
        # ``is_remote`` is a method (it is called as such everywhere else in
        # this class). Testing the bound method itself is always truthy, which
        # made the local branch unreachable, so it must be called here.
        if self.is_remote():
            return self.remote_model.train(mode)
        return self.module.train(mode)

    def eval(self):
        """Put the active model in evaluation mode (``train(False)``)."""
        return self.train(False)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log ``train_acc`` and return the loss."""
        data_ptr, target_ptr = batch[0], batch[1]  # batch is list so no destructuring
        output = self.forward(data_ptr)
        loss = self.torch.nn.functional.nll_loss(output, target_ptr)

        # NOTE(review): presumably ``.get(delete_obj=False)`` fetches local
        # copies of the targets and outputs while keeping the remote objects,
        # so the accuracy metric can be computed locally — confirm.
        target = target_ptr.get(delete_obj=False)
        real_output = output.get(delete_obj=False)

        self.log(
            "train_acc",
            self.module.train_acc(real_output.argmax(-1), target),
            on_epoch=True,
            prog_bar=True,
        )

        return loss

    def test_step(self, batch, batch_idx):
        """Run one test batch and log its NLL loss as ``test_loss``."""
        data, target = batch[0], batch[1]  # batch is list so no destructuring
        output = self.forward(data)
        loss = self.torch.nn.functional.nll_loss(output, target)
        self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Return an SGD optimizer (lr=0.1) over the active model's parameters."""
        optimizer = self.torch.optim.SGD(self.model.parameters(), lr=0.1)
        return optimizer

    @property
    def torchvision(self):
        """``duet.torchvision`` when remote, otherwise the local ``torchvision``."""
        tv = duet.torchvision if self.is_remote() else torchvision
        return tv

    def get_transforms(self):
        """Build the ``ToTensor`` + ``Normalize(0.1307, 0.3081)`` pipeline.

        The transform list is created as ``duet.python.List`` when remote and
        as a plain ``list`` otherwise, then wrapped in ``Compose``.
        """
        current_list = duet.python.List if self.is_remote() else list
        transforms = current_list()
        transforms.append(self.torchvision.transforms.ToTensor())
        transforms.append(self.torchvision.transforms.Normalize(0.1307, 0.3081))
        transforms_compose = self.torchvision.transforms.Compose(transforms)
        return transforms_compose

    def train_dataloader(self):
        """Return a DataLoader over the MNIST training split (batch size 500)."""
        transforms_ptr = self.get_transforms()
        train_dataset_ptr = self.torchvision.datasets.MNIST(
            str(get_root_data_path()),
            train=True,
            download=True,
            transform=transforms_ptr,
        )
        train_loader_ptr = self.torch.utils.data.DataLoader(
            train_dataset_ptr, batch_size=500
        )
        return train_loader_ptr

    def test_dataloader(self):
        """Return a DataLoader over the MNIST test split (batch size 1)."""
        transforms = self.get_transforms()
        test_dataset = self.torchvision.datasets.MNIST(
            str(get_root_data_path()),
            train=False,
            download=True,
            transform=transforms,
        )
        test_loader = self.torch.utils.data.DataLoader(test_dataset, batch_size=1)
        return test_loader

In [None]:
# Build the network with the local torch as its torch_ref, then wrap it
# together with the Duet session in the Lightning module defined above.
module = SyNet(torch)
model = LiftSyLightningModule(module=module, duet=duet)

In [None]:
# Fraction of the training batches to use per epoch.
limit_train_batches = 1.0 # 1.0 is 100% of data

# Single-epoch trainer that uses the current directory as its root directory.
trainer = Trainer(
    default_root_dir="./",
    max_epochs=1,
    limit_train_batches=limit_train_batches
)

In [None]:
# Run training; the dataloaders come from the model's train_dataloader().
trainer.fit(model)

In [None]:
# Rebuild the Lightning module from the best checkpoint recorded by the
# trainer, passing the same wrapped module and Duet session as before.
model = LiftSyLightningModule.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path, module=module, duet=duet
)

In [None]:
# If the wrapped module is not local, request a copy of it from the Duet
# session (blocking request, 5 second timeout).
if not model.module.is_local:
    local_model = model.module.get(
        request_block=True,
        reason="test evaluation",
        timeout_secs=5
    )
else:
    # NOTE(review): this binds the Lightning wrapper, whereas the branch above
    # binds the result of model.module.get(); the state_dict saved below may
    # therefore differ between branches — confirm whether model.module is meant.
    local_model = model

# Persist the weights so they can be loaded into a plain torch model.
torch.save(local_model.state_dict(), "weights.pt")

In [None]:
from torch import nn
class NormalModel(nn.Module):
    """Plain PyTorch MNIST classifier: two conv layers and two dense layers.

    The layer names and shapes are what ``load_state_dict`` matches against
    when previously saved weights are loaded into this model.
    """

    def __init__(self) -> None:
        """Create the convolutional, dropout and fully connected layers."""
        super().__init__()

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

        # Dropout after pooling and after the first dense layer.
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)

        # Classifier head; 9216 = 64 * 12 * 12 for 1x28x28 inputs.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        activate = nn.functional.relu

        features = activate(self.conv1(x))
        features = activate(self.conv2(features))
        features = nn.functional.max_pool2d(features, 2)
        features = self.dropout1(features)

        # Collapse everything except the batch dimension before the dense layers.
        flat = torch.flatten(features, 1)
        hidden = activate(self.fc1(flat))
        hidden = self.dropout2(hidden)
        logits = self.fc2(hidden)

        return nn.functional.log_softmax(logits, dim=1)

In [None]:
# Fresh, randomly initialised plain-PyTorch model to receive the saved weights.
torch_model = NormalModel()

In [None]:
# Read back the weights written to "weights.pt" earlier in this notebook and
# copy them into the plain torch model.
saved_state_dict = torch.load("weights.pt")
torch_model.load_state_dict(saved_state_dict)

In [None]:
# TorchVision hotfix https://github.com/pytorch/vision/issues/3549
from syft.util import get_root_data_path
from torchvision import datasets
import torch.nn.functional as F

# Override the class-level download list used by torchvision's MNIST dataset.
# Each entry pairs a download URL with a checksum string for that file
# (presumably MD5 — confirm against torchvision).
datasets.MNIST.resources = [
    (
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz",
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
    ),
    (
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz",
        "d53e105ee54ea40749a09fcbcd1e9432",
    ),
    (
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz",
        "9fb629c4189551a2d022fa330f9573f3",
    ),
    (
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ),
]

batch_size_test = 100

# Preprocessing applied to every test image: tensor conversion followed by
# normalisation with the constants 0.1307 (mean) and 0.3081 (std).
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST test split stored under syft's root data path, downloaded if needed.
mnist_test_set = torchvision.datasets.MNIST(
    get_root_data_path(),
    train=False,
    download=True,
    transform=mnist_transform,
)

# Shuffled loader that yields batches of ``batch_size_test`` samples.
test_loader = torch.utils.data.DataLoader(
    mnist_test_set,
    batch_size=batch_size_test,
    shuffle=True,
)

In [None]:
def test(network, test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset), accuracy))
    return accuracy.item()

In [None]:
# Evaluate the plain torch model on the local test loader; result is the
# accuracy in percent returned by test().
result = test(torch_model, test_loader)

In [None]:
# Minimum accuracy (in percent) the trained model must exceed for the
# notebook run to count as successful.
expected_accuracy = 93.0
assert result > expected_accuracy