In [None]:
import syft as sy
# Launch a Duet session; every later cell talks to the peer through `duet`.
# `loopback=True` presumably pairs with a partner notebook on the same machine — confirm in the syft docs.
duet = sy.launch_duet(loopback=True)

♫♫♫ > DUET LIVE STATUS  -  Objects: 24  Requests: 0   Messages: 392  Request Handlers: 1                                

In [None]:
# stdlib
import re
from typing import Any
from typing import List as TypeList

# third party
import torch
from torchvision import transforms

from syft import SyModule
from syft import SySequential
from syft.core.plan.plan_builder import PLAN_BUILDER_VM
from syft.core.plan.plan_builder import ROOT_CLIENT
from syft.core.plan.plan_builder import make_plan

# local
from original.neural_style import utils
# from original.neural_style.vgg import Vgg16
# from original.neural_style.transformer_net import TransformerNet # redefined below

In [None]:
# Register a request handler with no tags and action="accept": every incoming request
# from the Duet peer is auto-accepted. Convenient for a loopback demo, but it presumably
# lets the peer obtain any object in the store — do not use this with an untrusted peer.
# (Finer-grained handlers are planned upstream.)
duet.requests.add_handler(action="accept")

In [None]:
# Notebook settings. Of these, only "cuda" is referenced, and only inside the
# disabled weight-loading snippet further down.
args = dict(
    image_size=None,
    dataset=None,
    batch_size=4,
    cuda=False,
)

In [None]:
class TransformerNet(sy.Module):
    """Image transformation network for fast neural style transfer.

    Channel layout: 3 -> 32 -> 64 -> 128 through three reflection-padded
    convolutions (the last two with stride 2), then five 128-channel residual
    blocks, then two upsample-then-convolve stages back to 64 and 32 channels,
    and finally a 9x9 convolution to 3 channels. Every stage except the
    residual blocks and the final convolution is followed by an affine
    InstanceNorm2d and a ReLU. Layers are created through ``torch_ref``
    (``torch`` in this notebook).

    NOTE: keep the attribute assignment order unchanged. It presumably fixes
    the order of ``parameters()`` (as in ``torch.nn.Module``), and the flat
    parameter lists passed to ``set_remote_model_params`` depend on that order.
    """

    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn

        # Encoder: 9x9 stride-1 conv, then two 3x3 stride-2 (downsampling) convs.
        self.conv1 = ConvLayer(self.torch_ref, 3, 32, kernel_size=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(self.torch_ref, 32, 64, kernel_size=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(self.torch_ref, 64, 128, kernel_size=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)

        # Residual trunk at 128 channels.
        self.res1 = ResidualBlock(self.torch_ref, 128)
        self.res2 = ResidualBlock(self.torch_ref, 128)
        self.res3 = ResidualBlock(self.torch_ref, 128)
        self.res4 = ResidualBlock(self.torch_ref, 128)
        self.res5 = ResidualBlock(self.torch_ref, 128)

        # Decoder: 2x nearest-neighbour upsampling before each of the first two convs.
        self.deconv1 = UpsampleConvLayer(
            self.torch_ref, 128, 64, kernel_size=3, stride=1, upsample=2
        )
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(
            self.torch_ref, 64, 32, kernel_size=3, stride=1, upsample=2
        )
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(self.torch_ref, 32, 3, kernel_size=9, stride=1)

        # One ReLU module shared by every normalised stage.
        self.relu = nn.ReLU()

    def forward(self, x):
        y = x
        # Encoder: conv -> instance norm -> ReLU, three times.
        for conv, norm in ((self.conv1, self.in1), (self.conv2, self.in2), (self.conv3, self.in3)):
            y = self.relu(norm(conv(y)))
        # Residual trunk.
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(y)
        # Decoder: upsample+conv -> instance norm -> ReLU, twice.
        for deconv, norm in ((self.deconv1, self.in4), (self.deconv2, self.in5)):
            y = self.relu(norm(deconv(y)))
        # Final projection to 3 channels (no norm, no activation).
        return self.deconv3(y)


class ConvLayer(sy.Module):
    """Reflection-padded 2-D convolution.

    The input is padded by ``kernel_size // 2`` on every side using reflection
    padding. ``Conv2d(in_channels, out_channels, kernel_size, stride)`` is then
    applied with no padding of its own. For odd kernel sizes with stride 1,
    the spatial size is therefore preserved.
    """

    def __init__(self, torch_ref, in_channels, out_channels, kernel_size, stride):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        padded = self.reflection_pad(x)
        return self.conv2d(padded)


class ResidualBlock(sy.Module):
    """Residual block with an identity shortcut.

    Introduced in: https://arxiv.org/abs/1512.03385
    Recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html

    The path is conv3x3 -> instance norm -> ReLU -> conv3x3 -> instance norm.
    The block input is then added back. The channel count is unchanged, and no
    activation is applied after the addition.
    """

    def __init__(self, torch_ref, channels):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        self.conv1 = ConvLayer(self.torch_ref, channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(self.torch_ref, channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.in1(self.conv1(x)))
        hidden = self.in2(self.conv2(hidden))
        # Identity shortcut.
        return hidden + x


class UpsampleConvLayer(sy.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    Upsampling then convolving is used instead of ConvTranspose2d to avoid
    checkerboard artifacts (see http://distill.pub/2016/deconv-checkerboard/).
    When ``upsample`` is falsy (the default ``None``), the input is not resized
    and the layer behaves like a plain reflection-padded convolution.
    """

    def __init__(self, torch_ref, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__(torch_ref=torch_ref)
        nn = self.torch_ref.nn
        # Scale factor for interpolate(); falsy means "do not upsample".
        self.upsample = upsample
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        if not self.upsample:
            return self.conv2d(self.reflection_pad(x))
        upsampled = self.torch_ref.nn.functional.interpolate(
            x, mode="nearest", scale_factor=self.upsample
        )
        return self.conv2d(self.reflection_pad(upsampled))

In [None]:
# Handles to the peer's `torch`, `python` and `torchvision` library references.
remote_torch, remote_python, remote_torchvision = duet.torch, duet.python, duet.torchvision

In [None]:
# IPython magic: run the helper script that (judging by its name) downloads the
# pretrained style-transfer checkpoints. Nothing in this notebook loads them yet.
%run "original/download_saved_models.py"

In [None]:
# Local model instance; its parameters are the ones shipped to the remote plan.
model = TransformerNet(torch_ref=torch)

# Zero-valued stand-ins with the same shapes as the model's parameters.
# They serve as the example `model_params` input when the plan is built.
model_params_zeros = sy.lib.python.List(
    [
        torch.nn.Parameter(torch.zeros_like(weight))
        for weight in model.parameters()
    ]
)

# The model's actual (currently randomly initialised) parameters, in parameters() order.
model_params = sy.lib.python.List(model.parameters())

In [None]:
def set_remote_model_params(module_ptrs, params_list_ptr):  # type: ignore
    """Register parameters from ``params_list_ptr`` on each remote module, in order.

    For every module pointer in ``module_ptrs`` (visited in the mapping's own
    order), the underlying module object is fetched from
    ``PLAN_BUILDER_VM.store`` to list its ``named_parameters()``. Each parameter
    name is then registered on the pointer with the next consecutive element
    of ``params_list_ptr``.

    Because of that store lookup, the modules must live on ``PLAN_BUILDER_VM``
    (e.g. a model sent to ``ROOT_CLIENT`` inside a plan). ``params_list_ptr``
    must also be ordered the same way as this traversal.

    Args:
        module_ptrs: mapping of module name -> module pointer (e.g. ``sent_model.modules``).
        params_list_ptr: indexable list (or pointer to one) of parameters.
    """
    cursor = 0
    for _, mod_ptr in module_ptrs.items():
        stored_module = PLAN_BUILDER_VM.store[mod_ptr.id_at_location].data
        for param_name, _unused in stored_module.named_parameters():
            mod_ptr.register_parameter(param_name, params_list_ptr[cursor])
            cursor += 1

In [None]:
# ToTensor() then x255, i.e. values in [0, 255] (assuming utils.load_image returns an 8-bit
# PIL image — confirm); unsqueeze(0) adds a leading batch dimension of size 1.
content_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda tensor: tensor.mul(255)),
    ]
)
content_image = content_transform(
    utils.load_image("original/images/content_images/amber.jpg")
).unsqueeze(0)

In [None]:
# Send the locally built model to the Duet peer. `local_model` is the returned handle,
# whose `.modules` are passed to set_remote_model_params in a later cell.
local_model = model.send(duet)

In [None]:
# Smoke test: forward pass through the sent model before parameters are re-registered.
# The result is not stored; Jupyter only echoes it as the cell output.
local_model(content_image)

In [None]:
# NOTE(review): set_remote_model_params looks each module up in PLAN_BUILDER_VM.store, but
# `local_model` was sent to `duet`, not ROOT_CLIENT. This lookup may therefore fail
# (KeyError?) outside a plan. Confirm that this exploratory cell is intended.
set_remote_model_params(local_model.modules, model_params)

In [None]:
# Forward pass of the content image through the sent model.
output = local_model(content_image)

In [None]:
# Display the forward-pass result from the previous cell.
output

In [None]:
@make_plan
def stylize(x=content_image, model_params=model_params_zeros):
    """Plan: run TransformerNet on ``x`` with the parameters in ``model_params``.

    ``make_plan`` presumably traces this function using the default arguments
    as example inputs (the content image and zero-valued placeholders with the
    right parameter shapes). Keep the defaults bound to those objects.

    Inside the plan, the model is sent to ``ROOT_CLIENT`` without its own
    parameters. ``model_params`` are attached via ``set_remote_model_params``,
    so the weights come from whatever the caller passes at run time.

    Args:
        x: input image batch.
        model_params: flat parameter list in ``model.parameters()`` order.

    Returns:
        The stylized output of the model.
    """
    # Named plan_model so it does not shadow the notebook-level `local_model`.
    plan_model = model.send(ROOT_CLIENT, send_parameters=False)
    set_remote_model_params(plan_model.modules, model_params)
    return plan_model(x)

In [None]:
# Tag the plan so it can be fetched by name (duet.store["stylize"]), then send it to the peer.
stylize.tag("stylize")
stylize.send(duet)

In [None]:
#%run "original/download_saved_models.py"

In [None]:
# TODO (disabled): load pretrained "mosaic" weights into the style model.
# This entire snippet is a string literal, so none of it runs. Jupyter simply
# echoes the string as the cell output.
# NOTE(review): `TransformerNet(input_size=...)` does not match the current
# `TransformerNet(torch_ref)` signature. It must be updated before re-enabling.
"""
model_path = "saved_models/mosaic.pth" 
device = torch.device("cuda" if args["cuda"] else "cpu")

# TODO
# load weights into the model
with torch.no_grad():
    style_model = TransformerNet(input_size=(1, 3, 1080, 1080))
    state_dict = torch.load(model_path)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r"in\d+\.running_(mean|var)$", k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
"""

In [None]:
# Display the objects currently in the Duet store (presumably as a pandas table, per the attribute name).
duet.store.pandas

In [None]:
# Pointer to the plan, looked up by the tag it was given before being sent.
stylize_ptr = duet.store["stylize"]

In [None]:
# Run the remote plan. The keyword names must match the parameters of `stylize`
# (`x`, `model_params`). The previous call passed an undefined `dummy_net` under a
# non-existent `style_model` keyword. `model_params` holds the local model's
# current parameters, in the order the plan expects.
stylized_image = stylize_ptr(x=content_image, model_params=model_params)