# SuperResolution - Syft Duet - Data Owner 🎸

Contributed by [@Koukyosyumei](https://github.com/Koukyosyumei)

This example trains a SuperResolution network on the BSD300 dataset with Syft.
This notebook is based mainly on the original PyTorch [example](https://github.com/OpenMined/PySyft/tree/dev/examples/duet/super_resolution/original).

## PART 1: Launch a Duet Server and Connect

As a Data Owner, you want to allow someone else to perform data science on data that you own and likely want to protect.

In order to do this, we must load our data into a locally running server within this notebook. We call this server a "Duet".

To begin, you must launch Duet and help your Duet "partner" (a Data Scientist) connect to this server.

Run the code below, send the generated code snippet containing your unique Server ID to your partner, and follow the instructions it prints!

In [None]:
import syft as sy
duet = sy.launch_duet(loopback=True)
sy.logger.add(sink="./syft_do.log")

In [None]:
from os import listdir, makedirs, remove
from os.path import exists, join, basename
from six.moves import urllib
import tarfile

from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize

# Add handler

In [None]:
# A handler with no tags accepts every request. Better handlers are coming soon.
duet.requests.add_handler(action="accept")
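The handler above auto-accepts requests. To make the rule "a handler with no tags accepts everything" concrete, here is a toy, pure-Python model of tag-based matching. It is a hypothetical sketch, not Syft's implementation: the `Handler` class, the `decide` function, and the "first matching handler wins, otherwise deny" policy are all assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Handler:
    """Hypothetical stand-in for a Duet request handler (not Syft's class)."""
    action: str                                   # e.g. "accept" or "deny"
    tags: Set[str] = field(default_factory=set)   # empty set = matches everything


def decide(request_tags: Set[str], handlers: List[Handler], default: str = "deny") -> str:
    """Return the action of the first handler whose tags all appear on the request."""
    for handler in handlers:
        # An empty tag set is a subset of every set, so it matches any request.
        if handler.tags <= request_tags:
            return handler.action
    return default  # assumed fallback when no handler matches


# Mirrors add_handler(action="accept"): no tags, so every request is accepted.
catch_all = [Handler(action="accept")]
print(decide({"X_train"}, catch_all))  # accept
print(decide(set(), catch_all))        # accept

# A more selective (still hypothetical) policy: only accept requests tagged "y_train".
selective = [Handler(action="accept", tags={"y_train"})]
print(decide({"X_train"}, selective))  # deny
```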

# Set params and functions

For the Data Scientist to train the model, you send the data through Duet. The images therefore have to be converted into `torch.Tensor` objects first; the transforms defined below do this with `ToTensor`.
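For a single-channel image such as the Y channel used here, torchvision's `ToTensor` turns an H×W grid of 0 to 255 pixel values into a float tensor of shape (1, H, W) scaled to [0, 1]. The pure-Python sketch below mimics that conversion on nested lists so the shape and scaling are easy to see; `to_chw_floats` is a hypothetical helper, and the real pipeline should keep using `ToTensor`.

```python
from typing import List


def to_chw_floats(pixels: List[List[int]]) -> List[List[List[float]]]:
    """Mimic ToTensor on a single-channel 8-bit image (hypothetical helper).

    Input:  H x W nested list of ints in [0, 255]
    Output: 1 x H x W nested list of floats in [0.0, 1.0]
    """
    for row in pixels:
        for value in row:
            if not 0 <= value <= 255:
                raise ValueError(f"pixel value {value} outside [0, 255]")
    # Add a leading channel axis and rescale each pixel by 1/255.
    return [[[value / 255 for value in row] for row in pixels]]


image = [[0, 255], [51, 102]]  # a 2 x 2 "Y channel"
tensor_like = to_chw_floats(image)
print(len(tensor_like), len(tensor_like[0]), len(tensor_like[0][0]))  # 1 2 2
print(tensor_like[0][1])  # [0.2, 0.4]
```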

In [None]:
config = {"upscale_factor": 2,
          "threads": 4,
          "batchSize": 1,
          "testBatchSize": 10}

In [None]:
def is_image_file(filename):
    return any(filename.lower().endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y
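`load_img` keeps only the Y (luma) channel of the YCbCr image, so the network works on brightness alone. As a rough illustration, the JPEG/JFIF (ITU-R BT.601) luma formula that PIL's YCbCr conversion is based on can be sketched as below; `rgb_to_y` is a hypothetical helper, and PIL's integer arithmetic may differ from this float version by rounding.

```python
def rgb_to_y(r: int, g: int, b: int) -> float:
    """Full-range BT.601 / JFIF luma from 8-bit RGB (hypothetical helper)."""
    # The weights sum to 1, so gray pixels (r == g == b) keep their value.
    return 0.299 * r + 0.587 * g + 0.114 * b


print(round(rgb_to_y(255, 255, 255)))  # 255 (white)
print(round(rgb_to_y(0, 0, 0)))        # 0 (black)
print(round(rgb_to_y(255, 0, 0)))      # 76 (pure red is fairly dark in luma)
```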

In [None]:
class Prepare_DataSet:
    """Load every image in `image_dir` as an (input, target) pair of Y-channel tensors."""

    def __init__(self, image_dir, input_transform=None, target_transform=None):
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self):
        # Unlike a torch Dataset, this takes no index: it eagerly returns
        # all inputs and targets at once so they can be sent in one step.
        inputs = []
        targets = []

        for path in self.image_filenames:
            img = load_img(path)
            target = img.copy()  # the full-resolution Y channel is the target
            if self.input_transform:
                img = self.input_transform(img)  # crop, then downscale
            if self.target_transform:
                target = self.target_transform(target)  # crop only

            inputs.append(img)
            targets.append(target)

        return inputs, targets

    def __len__(self):
        return len(self.image_filenames)
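`Prepare_DataSet.__getitem__` takes no index and returns every (input, target) pair at once. That suits sending the whole training set through Duet in one step, but it differs from the usual `__getitem__(index)` protocol that `torch.utils.data.DataLoader` expects. A minimal pure-Python sketch of the conventional shape is shown below; `PairDataset`, `load_all`, and the `load` callable are hypothetical names, and plain functions stand in for the image transforms.

```python
from typing import Any, Callable, List, Optional, Tuple


class PairDataset:
    """Hypothetical index-based variant of Prepare_DataSet."""

    def __init__(self, paths: List[str], load: Callable[[str], Any],
                 input_transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        self.paths = paths
        self.load = load
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Input and target start as the same image; only the transforms differ.
        # (With mutable images, copy the source as the notebook does.)
        source = self.load(self.paths[index])
        inp, target = source, source
        if self.input_transform:
            inp = self.input_transform(inp)
        if self.target_transform:
            target = self.target_transform(target)
        return inp, target

    def __len__(self) -> int:
        return len(self.paths)

    def load_all(self) -> Tuple[List[Any], List[Any]]:
        # Eager equivalent of Prepare_DataSet.__getitem__(): all pairs at once.
        pairs = [self[i] for i in range(len(self))]
        return [p[0] for p in pairs], [p[1] for p in pairs]


# Toy usage: "images" are strings; the input transform halves them, the target is unchanged.
ds = PairDataset(["abcd", "wxyz"], load=lambda p: p,
                 input_transform=lambda s: s[: len(s) // 2])
print(len(ds), ds[1])  # 2 ('wx', 'wxyz')
print(ds.load_all())   # (['ab', 'wx'], ['abcd', 'wxyz'])
```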

In [None]:
from syft.util import get_root_data_path


def download_bsd300(dest=get_root_data_path()):
    output_image_dir = join(dest, "BSDS300/images")

    # Download and extract the archive only on the first run
    if not exists(output_image_dir):
        makedirs(dest, exist_ok=True)
        url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
        print("downloading url ", url)

        file_path = join(dest, basename(url))
        # `response` avoids shadowing the `torch.utils.data` import named `data`
        with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
            f.write(response.read())

        print("Extracting data")
        with tarfile.open(file_path) as tar:
            tar.extractall(dest)

        remove(file_path)

    return output_image_dir


def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)


def input_transform(crop_size, upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor(),
    ])


def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),
    ])


def get_training_set(upscale_factor):
    root_dir = download_bsd300()
    train_dir = join(root_dir, "train")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return Prepare_DataSet(train_dir,
                           input_transform=input_transform(crop_size, upscale_factor),
                           target_transform=target_transform(crop_size))


def get_test_set(upscale_factor):
    root_dir = download_bsd300()
    test_dir = join(root_dir, "test")
    crop_size = calculate_valid_crop_size(256, upscale_factor)

    return Prepare_DataSet(test_dir,
                           input_transform=input_transform(crop_size, upscale_factor),
                           target_transform=target_transform(crop_size))
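The crop and resize sizes follow directly from `calculate_valid_crop_size` and `upscale_factor`: the crop is trimmed to a multiple of the factor, the target keeps that crop, and the input is the crop divided by the factor. The sketch below works through that arithmetic in plain Python; `expected_sizes` is a hypothetical helper, and it assumes every image is at least 256 pixels on each side.

```python
def calculate_valid_crop_size(crop_size: int, upscale_factor: int) -> int:
    # Same rule as in the notebook: trim to a multiple of the upscale factor.
    return crop_size - (crop_size % upscale_factor)


def expected_sizes(upscale_factor: int, base_crop: int = 256):
    """Hypothetical helper: side lengths of the square input and target images."""
    crop = calculate_valid_crop_size(base_crop, upscale_factor)
    # CenterCrop(crop) gives a crop x crop image; Resize(crop // factor) then
    # shrinks that square so its edges are crop // factor pixels long.
    return crop // upscale_factor, crop


for factor in (2, 3, 4):
    low_res, high_res = expected_sizes(factor)
    print(f"upscale_factor={factor}: input {low_res}x{low_res}, target {high_res}x{high_res}")
# upscale_factor=2: input 128x128, target 256x256
# upscale_factor=3: input 85x85, target 255x255
# upscale_factor=4: input 64x64, target 256x256
```

Trimming the crop first guarantees that upscaling the input by exactly `upscale_factor` reproduces the target size.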

# Load Data

In [None]:
train_set = get_training_set(config["upscale_factor"])
test_set = get_test_set(config["upscale_factor"])

In [None]:
X_train, y_train = train_set.__getitem__()
X_train = torch.cat(X_train)
y_train = torch.cat(y_train)
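Each transformed image is a (1, H, W) tensor, so `torch.cat` along its default dimension 0 merges the per-image channel axes into the batch axis: `X_train` ends up with shape (N, H, W), whereas `torch.stack` would add a new batch axis and give (N, 1, H, W). A batched 2-D convolution expects the explicit channel axis, so it has to be restored at some point (for example with `unsqueeze(1)`); this notebook does not show where that happens. The helpers below are hypothetical shape calculators that reproduce the `cat`/`stack` rules for non-negative `dim` without needing PyTorch.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def cat_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Result shape of torch.cat(tensors, dim): sizes along `dim` are summed."""
    first = shapes[0]
    for s in shapes[1:]:
        same_rank = len(s) == len(first)
        if not same_rank or any(a != b for i, (a, b) in enumerate(zip(s, first)) if i != dim):
            raise ValueError("all shapes must match except along the concatenation dim")
    total = sum(s[dim] for s in shapes)
    return first[:dim] + (total,) + first[dim + 1:]


def stack_shape(shapes: Sequence[Shape], dim: int = 0) -> Shape:
    """Result shape of torch.stack(tensors, dim): a new axis of length len(shapes)."""
    first = shapes[0]
    if any(s != first for s in shapes):
        raise ValueError("torch.stack needs identically shaped tensors")
    return first[:dim] + (len(shapes),) + first[dim:]


# Three transformed Y-channel inputs at upscale_factor=2 (1 x 128 x 128 each).
inputs = [(1, 128, 128)] * 3
print(cat_shape(inputs))    # (3, 128, 128), the layout torch.cat gives X_train
print(stack_shape(inputs))  # (3, 1, 128, 128)
```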

# Send Data and its size

In [None]:
X_train.tag("X_train")
X_train.send(duet, pointable=True)

y_train.tag("y_train")
y_train.send(duet, pointable=True)

In [None]:
train_num = sy.lib.python.Int(X_train.shape[0])
train_num.tag("train_num")
train_num.send(duet, pointable=True)
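Sending `train_num` lets the Data Scientist plan mini-batches without downloading the data itself. A minimal sketch of how that count could be combined with `config["batchSize"]` to produce index ranges is below; `batch_ranges` is a hypothetical helper, and the Data Scientist notebook may iterate differently.

```python
from typing import List, Tuple


def batch_ranges(num_samples: int, batch_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: half-open [start, end) ranges covering every sample once."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [(start, min(start + batch_size, num_samples))
            for start in range(0, num_samples, batch_size)]


print(batch_ranges(10, 4))        # [(0, 4), (4, 8), (8, 10)]
print(len(batch_ranges(200, 1)))  # 200: one batch per sample when batchSize is 1
```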

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 1: Now STOP and run the Data Scientist notebook until it reaches the same checkpoint.

In [None]:
duet.store.pandas

In [None]:
duet.requests.pandas

### <img src="https://github.com/OpenMined/design-assets/raw/master/logos/OM/mark-primary-light.png" alt="he-black-box" width="100"/> Checkpoint 2: Well done!