# Neural Style Transfer - Syft Duet - Data Owner 🎸

## PART 1: Launch a Duet Server and Connect

As a Data Owner, you want to allow someone else to perform data science on data that you own and likely want to protect.

In order to do this, we must load our data into a locally running server within this notebook. We call this server a "Duet".

To begin, you must launch Duet and help your Duet "partner" (a Data Scientist) connect to this server.

You do this by running the code below, sending the code snippet containing your unique Server ID to your partner, and following the instructions it gives!

In [None]:
import syft as sy
# Launch the Duet server for this notebook and keep the handle in `duet`;
# later cells pass it to `style_model.send(...)`.
# NOTE(review): `loopback=True` presumably targets a partner running on the
# same machine -- confirm against the Syft `launch_duet` documentation.
duet = sy.launch_duet(loopback=True)

In [None]:
import torch
import re
from torchvision import transforms

from original.neural_style import utils
# from original.neural_style.vgg import Vgg16
# from original.neural_style.transformer_net import TransformerNet # redefined below

## Load the pre-trained model (together with the style it was trained on)

In [None]:
# Run the helper script inside this notebook session (IPython `%run` magic).
# NOTE(review): presumably this fetches the checkpoints that a later cell
# reads from "saved_models/" -- confirm against the script itself.
%run "original/download_saved_models.py"

In [None]:
class TransformerNet(sy.Module):
    """Feed-forward image transformation network for style transfer.

    The network chains three reflection-padded convolutions (the second and
    third use stride 2), five residual blocks on 128 channels, two upsampling
    convolutions (factor 2 each) and a final 9x9 convolution that maps back
    to 3 channels.  Every top-level convolution except the last one is
    followed by an affine ``InstanceNorm2d`` and a ReLU.

    Args:
        torch_ref: The torch namespace handed to ``sy.Module``; every layer is
            built from ``self.torch_ref.nn``.
    """

    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        ref = self.torch_ref
        nn = ref.nn

        # Encoder: 3 -> 32 -> 64 -> 128 channels.
        self.conv1 = ConvLayer(ref, 3, 32, kernel_size=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(ref, 32, 64, kernel_size=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(ref, 64, 128, kernel_size=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)

        # Bottleneck: five residual blocks, all on 128 channels.
        self.res1 = ResidualBlock(ref, 128)
        self.res2 = ResidualBlock(ref, 128)
        self.res3 = ResidualBlock(ref, 128)
        self.res4 = ResidualBlock(ref, 128)
        self.res5 = ResidualBlock(ref, 128)

        # Decoder: 128 -> 64 -> 32 -> 3 channels, upsampling by 2 twice.
        self.deconv1 = UpsampleConvLayer(ref, 128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(ref, 64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(ref, 32, 3, kernel_size=9, stride=1)

        # Shared activation used after every normalisation layer.
        self.relu = nn.ReLU()

    def forward(self, X):
        """Run ``X`` through the encoder, residual and decoder stages.

        Args:
            X: Input batch; assumed to be a 4-D image tensor with 3 channels
                (TODO confirm against callers).

        Returns:
            The output of the final convolution; no activation is applied
            to it.
        """
        encoder = (
            (self.conv1, self.in1),
            (self.conv2, self.in2),
            (self.conv3, self.in3),
        )
        bottleneck = (self.res1, self.res2, self.res3, self.res4, self.res5)
        decoder = (
            (self.deconv1, self.in4),
            (self.deconv2, self.in5),
        )

        features = X
        for conv, norm in encoder:
            features = self.relu(norm(conv(features)))
        for block in bottleneck:
            features = block(features)
        for conv, norm in decoder:
            features = self.relu(norm(conv(features)))
        return self.deconv3(features)


class ConvLayer(sy.Module):
    """Reflection padding followed by a 2-D convolution.

    Args:
        torch_ref: The torch namespace handed to ``sy.Module``.
        in_channels: Number of channels expected by the convolution.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size; half of it (floor division) is
            used as the reflection padding on every side.
        stride: Convolution stride.
    """

    def __init__(self, torch_ref, in_channels, out_channels, kernel_size, stride):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        # Padding by kernel_size // 2 keeps the spatial size unchanged for an
        # odd kernel at stride 1, because the convolution itself is unpadded.
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Return the convolution of the reflection-padded input ``x``."""
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(sy.Module):
    """Two conv + instance-norm stages wrapped in an identity skip connection.

    Residual blocks were introduced in: https://arxiv.org/abs/1512.03385
    Recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    Args:
        torch_ref: The torch namespace handed to ``sy.Module``.
        channels: Channel count used for both the input and the output of
            each convolution in the block.
    """

    def __init__(self, torch_ref, channels):
        super().__init__(torch_ref=torch_ref)
        ref = self.torch_ref
        self.conv1 = ConvLayer(ref, channels, channels, kernel_size=3, stride=1)
        self.in1 = ref.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(ref, channels, channels, kernel_size=3, stride=1)
        self.in2 = ref.nn.InstanceNorm2d(channels, affine=True)
        self.relu = ref.nn.ReLU()

    def forward(self, x):
        """Return the transformed branch added to the untouched input ``x``."""
        branch = self.conv1(x)
        branch = self.in1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        # No ReLU after the second normalisation; the skip is added directly.
        branch = self.in2(branch)
        return branch + x


class UpsampleConvLayer(sy.Module):
    """Optional nearest-neighbour upsampling followed by a padded convolution.

    Upsampling first and convolving afterwards gives better results than
    ConvTranspose2d.
    ref: http://distill.pub/2016/deconv-checkerboard/

    Args:
        torch_ref: The torch namespace handed to ``sy.Module``.
        in_channels: Number of channels expected by the convolution.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size; half of it (floor division) is
            used as the reflection padding on every side.
        stride: Convolution stride.
        upsample: Scale factor passed to ``interpolate``; when it is falsy
            (the default ``None``) the input is not resized.
    """

    def __init__(self, torch_ref, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        self.upsample = upsample
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Resize ``x`` if a scale factor was configured, then pad and convolve."""
        if self.upsample:
            x = self.torch_ref.nn.functional.interpolate(
                x, mode="nearest", scale_factor=self.upsample
            )
        return self.conv2d(self.reflection_pad(x))

In [None]:
import torch
# Checkpoint file that this cell reads with `torch.load`.
model_path = "saved_models/mosaic.pth" 
    
# TODO
# load weights into the model
# NOTE(review): the checkpoint is read into `state_dict` below, but the key
# filtering and the `load_state_dict` call are still commented out, so
# `style_model` keeps its freshly initialised weights for now.
with torch.no_grad():
    # Build the network against the local torch module.
    style_model = TransformerNet(torch)
    state_dict = torch.load(model_path)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
#     for k in list(state_dict.keys()):
#         if re.search(r"in\d+\.running_(mean|var)$", k):
#             del state_dict[k]
#     style_model.load_state_dict(state_dict)
    
# sy_model = sy.Module.from_pytorch(torch, dict(style_model.named_modules()))
# Print the model's `modules` attribute as a quick look at what was built.
print(style_model.modules)

In [None]:
# TODO
# Make the top level of the sy.Module sendable so that a pointer reference can
# be obtained on the Data Scientist (DS) side and used to execute the model.
# Send the model to the Duet and keep whatever `send` returns in `ptr`
# (presumably a pointer to the sent copy -- confirm against the Syft docs).
ptr = style_model.send(duet)

In [None]:
# Load the content image that will be stylized.
# NOTE(review): `utils.load_image` presumably returns a PIL image, since a
# later cell feeds the result to `transforms.ToTensor()` -- confirm.
content_img = utils.load_image("original/images/content_images/amber.jpg")

In [None]:
# Bare expression: the notebook renders the loaded content image as the cell output.
content_img

In [None]:
# Preprocess the content image for the network: convert it to a tensor,
# multiply every value by 255, then prepend a batch dimension.
# NOTE: this rebinds `content_img` from the loaded image to the batched tensor.
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda pixels: pixels.mul(255)),
])
content_img = content_transform(content_img).unsqueeze(0)

In [None]:
# Stylize the content image and keep the first (only) item of the batch on
# the CPU.  This is pure inference, so autograd is switched off: no graph is
# recorded for the full-resolution image, which keeps memory usage down and
# leaves `output` as a plain tensor that does not require grad.
with torch.no_grad():
    output = style_model(content_img)[0].cpu()

In [None]:
# Show the stylized image directly in the notebook (the util function from
# the PyTorch example repository saves the image to disk instead).
#
# The network output is an unbounded float tensor, so it is clamped to
# [0, 255] before the cast to uint8; casting out-of-range values would wrap
# them around and produce speckled pixels.  The array is then reordered from
# channels-first (C, H, W) to channels-last (H, W, C) for `ToPILImage`.
transforms.ToPILImage()(
    output.detach().clamp(0, 255).numpy().astype("uint8").transpose(1, 2, 0)
)