In [None]:
import syft as sy
# Launch a Duet session for this notebook; `duet` is the client used by every
# later cell (request handling, remote libraries, plan store).
# NOTE(review): loopback=True presumably pairs with a peer notebook on the same
# machine instead of going through the network broker — confirm for this syft version.
duet = sy.launch_duet(loopback=True)

In [None]:
# stdlib
import re
from typing import Any
from typing import List as TypeList
from typing import Optional

# third party
import torch
from torchvision import transforms

from syft import SyModule
from syft import SySequential

# relative
from original.neural_style import utils

# from original.neural_style.vgg import Vgg16
# from original.neural_style.transformer_net import TransformerNet # redefined below

In [None]:
# Auto-accept incoming requests: a handler registered with no tags matches every
# request. NOTE(review): only appropriate for a trusted demo/loopback session;
# replace with tag-scoped handlers once available.
duet.requests.add_handler(action="accept")

In [None]:
# Run settings, looked up by key (e.g. args["cuda"]).
args = dict(
    image_size=None,
    dataset=None,
    batch_size=4,
    cuda=False,
)

In [None]:
class TransformerNet(SyModule):
    """Image transformation network assembled from SyModule sub-layers.

    Structure, in ``forward`` order:
      * encoder: ConvLayer 3->32 (9x9, stride 1), 32->64 (3x3, stride 2) and
        64->128 (3x3, stride 2), each followed by InstanceNorm2d + ReLU;
      * five ResidualBlocks at 128 channels;
      * decoder: UpsampleConvLayer 128->64 and 64->32 (3x3, stride 1), each
        followed by InstanceNorm2d + ReLU, then a plain 9x9 ConvLayer 32->3.

    Every sub-SyModule receives an explicit ``input_size`` of the form
    (batch=1, channels, side, side), derived from a 1080x1080 input.

    NOTE(review): ``upsample=2`` is commented out on both decoder layers, so the
    decoder never restores resolution. For a 1080x1080 input the output is
    presumably 270x270, even though ``deconv2``/``deconv3`` are declared with
    540/1080 input sizes. Confirm this is intentional.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        full_side = 1080
        half_side = full_side // 2     # side after one stride-2 conv
        quarter_side = full_side // 4  # side after two stride-2 convs

        def nchw(channels: int, side: int) -> tuple:
            # Square, batch-of-one tracing shape for a sub-SyModule.
            return (1, channels, side, side)

        # Encoder
        self.conv1 = ConvLayer(in_channels=3, out_channels=32, kernel_size=9,
                               stride=1, input_size=nchw(3, full_side))
        self.in1 = torch.nn.InstanceNorm2d(num_features=32, affine=True)
        self.conv2 = ConvLayer(in_channels=32, out_channels=64, kernel_size=3,
                               stride=2, input_size=nchw(32, full_side))
        self.in2 = torch.nn.InstanceNorm2d(num_features=64, affine=True)
        self.conv3 = ConvLayer(in_channels=64, out_channels=128, kernel_size=3,
                               stride=2, input_size=nchw(64, half_side))
        self.in3 = torch.nn.InstanceNorm2d(num_features=128, affine=True)

        # Residual trunk: registers res1 ... res5 in order. setattr goes through
        # nn.Module.__setattr__, so each block is tracked as a submodule.
        for idx in range(1, 6):
            setattr(self, f"res{idx}",
                    ResidualBlock(input_size=nchw(128, quarter_side)))

        # Decoder (no upsample factor is passed; see the class NOTE)
        self.deconv1 = UpsampleConvLayer(in_channels=128, out_channels=64,
                                         kernel_size=3, stride=1,
                                         input_size=nchw(128, quarter_side))
        self.in4 = torch.nn.InstanceNorm2d(num_features=64, affine=True)
        self.deconv2 = UpsampleConvLayer(in_channels=64, out_channels=32,
                                         kernel_size=3, stride=1,
                                         input_size=nchw(64, half_side))
        self.in5 = torch.nn.InstanceNorm2d(num_features=32, affine=True)
        self.deconv3 = ConvLayer(in_channels=32, out_channels=3, kernel_size=9,
                                 stride=1, input_size=nchw(32, full_side))

        # Shared non-linearity
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Each sub-SyModule call returns a sequence; [0] takes its first output.
        y = x
        for conv, norm in ((self.conv1, self.in1),
                           (self.conv2, self.in2),
                           (self.conv3, self.in3)):
            y = self.relu(norm(conv(x=y)[0]))
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(x=y)[0]
        for deconv, norm in ((self.deconv1, self.in4),
                             (self.deconv2, self.in5)):
            y = self.relu(norm(deconv(x=y)[0]))
        return self.deconv3(x=y)[0]
    
class ConvLayer(SyModule):
    """Reflection-padded 2-D convolution.

    Pads every side by ``kernel_size // 2`` with ReflectionPad2d, then applies
    a Conv2d with the given channels, kernel size and stride. Extra keyword
    arguments (e.g. ``input_size``) are forwarded unchanged to SyModule.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.reflection_pad = torch.nn.ReflectionPad2d(padding=kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))


class ResidualBlock(SyModule):
    """Residual block: two 3x3 reflection-padded convs with InstanceNorm, plus skip.

    ``forward`` computes ``in2(conv2(relu(in1(conv1(x))))) + x``; no activation
    is applied after the addition.

    Args:
        channels: input and output width of the block (default 128). Both convs
            keep this width so the skip connection can be added directly.
        **kwargs: forwarded to SyModule. If ``input_size`` is given it is also
            used to trace the inner ConvLayers; otherwise they default to
            ``(1, channels, 270, 270)``.
    """

    def __init__(self, channels: int = 128, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Both convs are 3x3, stride 1, with 1-pixel reflection padding, so they
        # see exactly the block's input shape; reuse the caller's input_size.
        conv_input_size = kwargs.get("input_size") or (1, channels, 270, 270)
        self.conv1 = ConvLayer(in_channels=channels,
                               out_channels=channels,
                               kernel_size=3,
                               stride=1,
                               input_size=conv_input_size)
        self.in1 = torch.nn.InstanceNorm2d(num_features=channels, affine=True)
        self.conv2 = ConvLayer(in_channels=channels,
                               out_channels=channels,
                               kernel_size=3, stride=1,
                               input_size=conv_input_size)
        self.in2 = torch.nn.InstanceNorm2d(num_features=channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        residual = x
        # SyModule calls return a sequence; [0] takes the first output.
        out = self.relu(self.in1(self.conv1(x=x)[0]))
        out = self.in2(self.conv2(x=out)[0])
        return out + residual


class UpsampleConvLayer(SyModule):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    Args:
        in_channels, out_channels, kernel_size, stride: passed to Conv2d.
        upsample: ``scale_factor`` for nearest-mode ``interpolate`` applied
            before padding. ``None`` (or any falsy value) skips upsampling.
        **kwargs: forwarded to SyModule (e.g. ``input_size``).

    NOTE(review): the upsampling path calls the module-level ``remote_torch``
    (``duet.torch``), which is only assigned in a later notebook cell. It is
    presumably not traced the way local torch ops are; confirm it works inside
    SyModule/plan tracing before passing ``upsample``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 upsample: Optional[float] = None,
                 **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.upsample = upsample
        reflection_padding = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(padding=reflection_padding)
        self.conv2d = torch.nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = remote_torch.nn.functional.interpolate(
                x_in, mode="nearest", scale_factor=self.upsample
            )
        out = self.reflection_pad(x_in)
        return self.conv2d(out)

In [None]:
class DummyNet(SyModule):
    """Identity SyModule: ``forward`` returns its input unchanged.

    Serves as a placeholder ``style_model`` with no parameters; all keyword
    arguments (e.g. ``input_size``) go straight to SyModule.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def forward(self, x):
        # No computation: hand the input straight back.
        return x

In [None]:
# Remote library handles exposed by the duet connection. UpsampleConvLayer.forward
# reads `remote_torch` as a global, so this cell must run before that path is used.
remote_torch, remote_python, remote_torchvision = (
    duet.torch,
    duet.python,
    duet.torchvision,
)

In [None]:
# Load the demo content image, convert it to a tensor, multiply by 255
# (presumably restoring a 0-255 pixel range after ToTensor), and add a
# leading batch dimension with unsqueeze(0).
content_image = utils.load_image("original/images/content_images/amber.jpg")
content_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.mul(255)),
])
content_image = content_transform(content_image).unsqueeze(0)

In [None]:
# Tracing shape for the placeholder model: (batch, channels, height, width).
DUMMY_INPUT_SIZE = (1, 3, 1080, 1080)
dummy_net = DummyNet(input_size=DUMMY_INPUT_SIZE)

In [None]:
from syft.core.plan.plan_builder import make_plan


@make_plan
def stylize(x=content_image, style_model=dummy_net):
    """Plan that applies ``style_model`` to ``x`` and returns its raw output.

    The defaults are the local example values the plan is built from
    (presumably used by make_plan to trace the body).
    """
    return style_model(x=x)

In [None]:
# Tag the plan so the peer can look it up by name (duet.store["stylize"]),
# then send it to the duet store.
stylize.tag("stylize")
stylize.send(duet)

In [None]:
#%run "original/download_saved_models.py"

In [None]:
# Disabled cell: everything below is a bare string literal, so none of it runs.
# When enabled it would build a TransformerNet, load the "mosaic" checkpoint from
# saved_models/, delete the deprecated InstanceNorm running_mean/running_var keys
# from the state dict, and load the remaining weights into the model.
# NOTE(review): `device` is computed but never used (torch.load gets no
# map_location), and torch.load unpickles the file — only load trusted checkpoints.
"""
model_path = "saved_models/mosaic.pth" 
device = torch.device("cuda" if args["cuda"] else "cpu")

# TODO
# load weights into the model
with torch.no_grad():
    style_model = TransformerNet(input_size=(1, 3, 1080, 1080))
    state_dict = torch.load(model_path)
    # remove saved deprecated running_* keys in InstanceNorm from the checkpoint
    for k in list(state_dict.keys()):
        if re.search(r"in\d+\.running_(mean|var)$", k):
            del state_dict[k]
    style_model.load_state_dict(state_dict)
"""

In [None]:
# Show what is currently in the duet store (presumably rendered as a pandas
# DataFrame) to confirm the "stylize" plan arrived.
duet.store.pandas

In [None]:
# Pointer to the stored plan, looked up by its "stylize" tag.
stylize_ptr = duet.store["stylize"]

In [None]:
# Invoke the remote plan on the content image with the identity DummyNet.
# NOTE(review): both arguments are local objects, not pointers; confirm whether
# they must be sent to duet first (e.g. via .send(duet)) for this call to work.
stylized_image = stylize_ptr(x=content_image, style_model=dummy_net)