# SuperResolution - Syft Duet - Data Scientist 🥁

Contributed by [@Koukyosyumei](https://github.com/Koukyosyumei)

This example trains a SuperResolution network on the BSD300 dataset with Syft.
This notebook is mainly based on the original pytorch [example](https://github.com/OpenMined/PySyft/tree/dev/examples/duet/super_resolution/original).

## PART 1: Connect to a Remote Duet Server

As the Data Scientist, you want to perform data science on data that is sitting in the Data Owner's Duet server in their Notebook.

In order to do this, we must run the code that the Data Owner sends us, which, importantly, includes their Duet Session ID. The code will look like the snippet below, but with their real Server ID in place of the placeholder.

```
import syft as sy
duet = sy.duet('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
```

This will create a direct connection from my notebook to the remote Duet server. Once the connection is established all traffic is sent directly between the two nodes.

Paste the code or Server ID that the Data Owner gives you and run it in the cell below. It will return your Client ID which you must send to the Data Owner to enter into Duet so it can pair your notebooks.

In [None]:
import syft as sy
# Join the Data Owner's Duet session. `loopback=True` is forwarded to
# join_duet as-is (presumably for a same-machine connection — confirm against
# the Data Owner notebook).
duet = sy.join_duet(loopback=True)
# Register a file sink with syft's logger at ./syft_ds.log.
sy.logger.add(sink="./syft_ds.log")

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 0 : Now STOP and run the Data Owner notebook until Checkpoint 1.

In [None]:
import os
from os import listdir, makedirs, remove
from os.path import exists, join, basename
import tarfile
import subprocess
from six.moves import urllib

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.utils.data as data
import torchvision.utils as vutils
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize

import numpy as np

from PIL import Image


try:
    # make notebook progress bars nicer
    from tqdm.notebook import tqdm
except ImportError:
    print(f"Unable to import tqdm")

# Set params

The following parameters are based on the original implementation.

In [None]:
# Run settings read by the cells below.
config = dict(
    upscale_factor=2,    # passed to Net when the model is built
    threads=0,           # used as DataLoader num_workers
    batchSize=1,         # other sizes may not work
    testBatchSize=1,
    lr=0.001,            # learning rate handed to the Adam optimizer
    epochs=2,            # number of passes of the training loop
    no_cuda=True,        # when True, CUDA is not used even if available
    log_batch_size=10,   # the loss is fetched and printed every this many batches
    seed=42,             # seed given to remote_torch.manual_seed
    dry_run=True,        # when True, each epoch stops after its first batch
    test=False,
)

# Load data

You can access the data that the Data Owner has sent with the following code. You also need a custom Dataset and collate_fn that can handle tensor pointers. As of now, tensor pointers do not support slicing, so the batch size must be one.

In [None]:
# Handle to the torch namespace exposed by the Duet session. The cells below
# use it for tensor ops, the loss, the optimizer and the device (presumably
# executed on the Data Owner's side — confirm against the syft docs).
remote_torch = duet.torch

In [None]:
# Inspect what is currently in the Duet store. As the last expression of the
# cell, its value is rendered by the notebook (presumably a pandas table of
# the stored objects — confirm).
duet.store.pandas

In [None]:
# Look up the objects stored under these keys in the Duet store. They are
# handed to DatasetFromPointer below as the input pointer, the target pointer
# and the pointer to the number of training samples.
X_train = duet.store["X_train"]
y_train = duet.store["y_train"]
train_num = duet.store["train_num"]

In [None]:
class DatasetFromPointer(data.Dataset):
    """Map-style dataset that indexes into two pointer objects.

    Args:
        X_tensorpointer: Indexable object holding the inputs; item ``i`` is
            returned as the input of sample ``i``.
        y_tensorpointer: Indexable object holding the targets; item ``i`` is
            returned as the target of sample ``i``.
        datanum_pointer: Object whose ``get(...)`` returns the number of
            samples. It is queried at most once per dataset instance.
    """

    def __init__(self,
                 X_tensorpointer,
                 y_tensorpointer,
                 datanum_pointer,
                 ):
        super().__init__()
        self.X_tensorpointer = X_tensorpointer
        self.y_tensorpointer = y_tensorpointer
        self.datanum_pointer = datanum_pointer
        # Cached result of the length request; None means "not fetched yet".
        self._length = None

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        sample = self.X_tensorpointer[index]
        target = self.y_tensorpointer[index]
        return sample, target

    def __len__(self):
        """Return the number of samples, fetching it on first use.

        The value is obtained with a blocking ``get`` request. DataLoader and
        its sampler call ``len()`` repeatedly, so the answer is cached to
        avoid issuing a new request each time. A ``None`` answer is not
        cached, so a later call will ask again.
        """
        if self._length is None:
            self._length = self.datanum_pointer.get(
                request_block=True,
                reason="To write the training loop",
                timeout_secs=30,
                delete_obj=False,  # keep the remote value available afterwards
            )
        return self._length
        
def batch_idx_fn(batch):
    """Collate function that returns the first element of ``batch`` unchanged.

    Intended for a batch size of one, where the batch holds a single sample.
    """
    first_sample = batch[0]
    return first_sample

In [None]:
# Wrap the store pointers in a Dataset so that a DataLoader can iterate them.
train_set = DatasetFromPointer(X_train, y_train, train_num)

# batch_idx_fn hands back the single sample of each batch as-is instead of
# stacking samples into one tensor.
training_data_loader = DataLoader(
    train_set,
    batch_size=config["batchSize"],
    shuffle=True,
    num_workers=config["threads"],
    collate_fn=batch_idx_fn,
)

# Define and create the model

In [None]:
class Net(sy.Module):
    """Super-resolution network: four convolutions followed by a PixelShuffle.

    Args:
        torch_ref: torch-like namespace used to build the layers; it is passed
            on to ``sy.Module`` and read back as ``self.torch_ref``.
        upscale_factor: Factor given to ``PixelShuffle``. The last convolution
            outputs ``upscale_factor ** 2`` channels.
    """

    def __init__(self, torch_ref, upscale_factor):
        super().__init__(torch_ref=torch_ref)

        nn_ref = self.torch_ref.nn

        self.relu = nn_ref.ReLU()
        # Conv2d positional arguments: in_channels, out_channels, kernel_size,
        # stride, padding. The first layer takes a single input channel.
        self.conv1 = nn_ref.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn_ref.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn_ref.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn_ref.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn_ref.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Apply three ReLU-activated convolutions, then conv4 and PixelShuffle."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise all four convolution weights.

        The layers followed by a ReLU use the ReLU gain; conv4 uses the
        default gain of ``init.orthogonal_``.
        """
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

In [None]:
# Build the model with the local torch module, then send it over the Duet
# connection; `remote_net` is what the training cells below operate on.
local_net = Net(torch_ref=torch, upscale_factor=config["upscale_factor"])
remote_net = local_net.send(duet)

# Check cuda

You should ask the data owner whether he/she has GPUs or not.

In [None]:
# Default value, in place before the remote answer arrives.
has_cuda = False

# Ask the remote torch whether CUDA is available. The call returns a pointer,
# so the actual value has to be requested from the Data Owner.
has_cuda_ptr = remote_torch.cuda.is_available()
has_cuda = bool(
    has_cuda_ptr.get(
        request_block=True,
        reason="To run test and inference locally",
        timeout_secs=3,  # increase if the Data Owner needs longer to answer
    )
)
print("Is cuda available ? : ", has_cuda)

# CUDA is used only if the config allows it and the Data Owner has it.
use_cuda = not config["no_cuda"] and has_cuda

# With use_cuda settled, seed the remote torch.
remote_torch.manual_seed(config["seed"])

device_type = "cuda" if use_cuda else "cpu"
device = remote_torch.device(device_type)
#print(f"Data Owner device is {device.type.get()}")

In [None]:
# Move the remote model to the device selected for this run.
# The branch follows `use_cuda` (config allows CUDA *and* the Data Owner has
# it) rather than `has_cuda` alone: `device` is "cuda" only when `use_cuda` is
# true, so branching on `has_cuda` would call `.cuda()` with a CPU device
# whenever config["no_cuda"] is set on a CUDA-capable Data Owner.
if use_cuda:
    remote_net.cuda(device)
else:
    remote_net.cpu()

# Training

In [None]:
# Loss and optimizer are both created through remote_torch, so they work on
# the parameters of remote_net.
criterion = remote_torch.nn.MSELoss()

model_params = remote_net.parameters()
optimizer = remote_torch.optim.Adam(model_params, lr=config["lr"])

In [None]:
for epoch in range(config["epochs"]):

    # Switch the remote model to training mode at the start of every epoch.
    remote_net.train()
    epoch_loss = 0  # initialised each epoch; nothing below updates it

    for batch_idx, (data_ptr, target_ptr) in enumerate(training_data_loader):

        optimizer.zero_grad()

        # Prepend two axes of size one to the sample and to its target
        # (presumably batch and channel, to match the Conv2d input layout —
        # confirm against the Data Owner's tensors).
        data_ptr_reshape = remote_torch.unsqueeze(data_ptr, 0)
        data_ptr_reshape = remote_torch.unsqueeze(data_ptr_reshape, 0)
        target_ptr_reshape = remote_torch.unsqueeze(target_ptr, 0)
        target_ptr_reshape = remote_torch.unsqueeze(target_ptr_reshape, 0)

        output_ptr = remote_net(data_ptr_reshape)

        loss = criterion(output_ptr, target_ptr_reshape)
        loss.backward()
        optimizer.step()

        # Fetching the loss value needs a request to the Data Owner, so it is
        # only done every `log_batch_size` batches.
        should_log = batch_idx % config["log_batch_size"] == 0
        if should_log:
            loss_item = loss.item().get(
                reason="To evaluate training progress",
                request_block=True,
                timeout_secs=3,
                delete_obj=False,
                verbose=False,
            )
            print(f"epoch {epoch}, batch_idx {batch_idx}, loss {loss_item}")

        # Only the inner loop is left here: in a dry run every epoch still
        # starts, but processes a single batch.
        if config["dry_run"]:
            break

# Save model

In [None]:
# Request the trained model back from the Data Owner; the call blocks until
# the request is answered or the timeout expires.
local_net = remote_net.get(
    timeout_secs=5,
    reason="test evaluation",
    request_block=True,
)

# Save the retrieved model under a fixed file name in the working directory.
local_net.save("super_resolve.pt")

# Inference

In [None]:
image_url = "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300/html/images/plain/normal/color/12084.jpg"
# curl -O stores the download under the last path component of the URL.
test_img_name = basename(image_url)
# Run curl without a shell and fail loudly: `check=True` raises
# CalledProcessError on a non-zero exit status, and `--fail` makes curl return
# non-zero on HTTP errors instead of saving the error page as the image.
subprocess.run(["curl", "--fail", "-O", image_url], check=True)

In [None]:
output_img_name = "output.jpg"

# Work in YCbCr: only the Y channel goes through the network, while Cb and Cr
# are resized with bicubic interpolation further down.
img = Image.open(test_img_name).convert("YCbCr")
y, cb, cr = img.split()
img_to_tensor = ToTensor()
# PIL's `size` is (width, height); the tensor is viewed as (1, C, height, width).
# Named `input_tensor` so the builtin `input` is not rebound for the notebook.
input_tensor = img_to_tensor(y).view(1, -1, y.size[1], y.size[0])

# Inference runs on this machine, so the decision depends on the local CUDA
# availability, not on `has_cuda`, which describes the Data Owner's machine.
if torch.cuda.is_available():
    local_device = torch.device("cuda")
    # Pass the device explicitly, as done for remote_net.cuda(device).
    local_net = local_net.cuda(local_device)
    input_tensor = input_tensor.to(local_device)

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    out = local_net(input_tensor)
out = out.cpu()
out_img_y = out[0].detach().numpy()
# Scale to 0-255 and clip before converting to an 8-bit image.
out_img_y *= 255.0
out_img_y = out_img_y.clip(0, 255)
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L")

# Bring the chroma channels to the size of the network output and recombine.
out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")

out_img.save(output_img_name)
print("output image saved to ", output_img_name)

In [None]:
# Re-open both images through the names set by the earlier cells
# (`test_img_name`, `output_img_name`) instead of repeating the file names, so
# this cell keeps working if the test image or the output path is changed.
original_image = Image.open(test_img_name)
super_image = Image.open(output_img_name)
# Uncomment inside a notebook to render the two images:
#display(original_image)
#display(super_image)

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 1 : Now STOP and run the Data Owner notebook until Checkpoint 2.