In [None]:
from PIL import Image
import numpy as np
import torch.nn as nn
import torch
from cv2.ximgproc import guidedFilter

from photorealistic_smoothing import photorealistic_smoothing
from secondary_smoothing import secondary_smoothing
from networks import VGGEncoder, VGGDecoder
from wct import wct

# Load images

In [None]:
# Load the content photo. The context manager closes the underlying file once
# the pixels have been copied, and convert('RGB') guarantees three channels
# even if the file is grayscale, RGBA, palette-based or CMYK.
with Image.open('images/Tuebingen_Neckarfront.jpg') as content_file:
    content = content_file.convert('RGB')

# Round both sides down to a multiple of 8.
# NOTE(review): presumably required by the pooling/unpooling stages of the
# encoder/decoder - confirm against `networks`.
w, h = content.size
content = content.resize(((w // 8) * 8, (h // 8) * 8), Image.LANCZOS)
content

In [None]:
# Load the style image. convert('RGB') guarantees three channels regardless of
# the file's native mode, and the context manager closes the file handle once
# the pixel data has been copied.
with Image.open('styles/andre derain the dance.jpg') as style_file:
    style = style_file.convert('RGB')
style

# Load the models

In [None]:
class PhotoWCT(nn.Module):
    """
    Stylization network made of one shared encoder and
    one decoder per level (levels 1, 2, 3 and 4).

    The forward pass runs the levels in increasing order:
    encode -> `wct` against the style features -> decode,
    feeding each level's output into the next level.
    """

    def __init__(self):
        super().__init__()
        self.encoder = VGGEncoder()

        # `nn.ModuleDict` keys must be strings,
        # so the decoders are stored under '1', '2', '3', '4'.
        decoders = nn.ModuleDict()
        for level in (1, 2, 3, 4):
            decoders[str(level)] = VGGDecoder(level)
        self.decoders = decoders

    def forward(self, content, style):
        """
        Arguments:
            content: a float tensor.
            style: a float tensor.
        Returns:
            a float tensor, the output of the level 4 decoder.
        """
        # inference only, no gradients are tracked
        with torch.no_grad():

            # the style image is encoded once,
            # the second returned value is not used
            style_features, _ = self.encoder(style)

            stylized = content
            for level in (1, 2, 3, 4):
                decoder = self.decoders[str(level)]

                # re-encode the current result up to this level
                features, pooling_indices = self.encoder(stylized, level=level)
                transformed = wct(features[level], style_features[level])
                stylized = decoder(transformed, pooling_indices)

        return stylized
    

def to_tensor(x):
    """
    Arguments:
        x: an instance of PIL image (any mode, it is
            converted to RGB when needed). An array-like
            with shape [h, w, 3] and values in [0, 255]
            is also accepted and used as is.
    Returns:
        a float tensor with shape [1, 3, h, w],
        it represents a RGB image with
        pixel values in [0, 1] range.
    """
    # Grayscale/palette images give a 2-D array (so `permute` would raise)
    # and RGBA images give four channels. Normalise to RGB first.
    # Objects without a `mode` attribute (e.g. numpy arrays) are left alone.
    if getattr(x, 'mode', 'RGB') != 'RGB':
        x = x.convert('RGB')

    # `np.array` makes a writable copy, which `torch.from_numpy` can wrap.
    pixels = np.array(x)
    tensor = torch.from_numpy(pixels).to(torch.float32)

    # [h, w, 3] -> [1, 3, h, w], then rescale 0..255 to 0..1
    return tensor.permute(2, 0, 1).unsqueeze(0).div(255.0)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): the CPU path assumes `wct` and the networks contain no
# CUDA-only calls - confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# eval() puts every submodule in inference mode (it only matters if the
# networks contain dropout or normalisation layers, and is harmless otherwise).
transform = PhotoWCT().to(device).eval()

# map_location='cpu' makes loading independent of the device the checkpoints
# were saved from. load_state_dict then copies the weights into the model's
# parameters on `device`.
transform.encoder.load_state_dict(torch.load('models/encoder.pth', map_location='cpu'))
for i, m in transform.decoders.items():
    m.load_state_dict(torch.load(f'models/decoder{i}.pth', map_location='cpu'))

# Apply the whitening and coloring transform (WCT)

In [None]:
# Put the inputs on whatever device the model's weights live on,
# instead of assuming CUDA.
model_device = next(transform.parameters()).device
content_tensor = to_tensor(content).to(model_device)
style_tensor = to_tensor(style).to(model_device)

# Stylize the content image with the style image.
output_tensor = transform(content_tensor, style_tensor)

In [None]:
# Take the single image of the batch, clip to the valid [0, 1] range and
# reorder [3, h, w] -> [h, w, 3].
output_float = output_tensor.cpu().clamp(0.0, 1.0)[0].permute(1, 2, 0).numpy()

# Round to the nearest integer level before the cast: a bare astype('uint8')
# truncates, which darkens the image by up to one level and maps only an
# exact 1.0 to 255. The array comes from a permuted view, so it is also made
# C-contiguous for the image-processing calls that consume it.
output_array = np.ascontiguousarray(np.rint(255 * output_float).astype('uint8'))

In [None]:
# Display the stylized result (uint8 array with layout [h, w, 3]) before any smoothing.
Image.fromarray(output_array)

# First smoothing step: photorealistic smoothing, with a guided-filter alternative

In [None]:
%%time
# Timed cell. The content image as a numpy array is reused by the later smoothing cells.
content_array = np.array(content)
# First smoothing of the stylized output. The content image is passed as the first
# argument - presumably as the structural reference, see `photorealistic_smoothing`.
r1 = photorealistic_smoothing(content_array, output_array)

In [None]:
# Display the result of `photorealistic_smoothing`.
Image.fromarray(r1)

In [None]:
# Alternative first smoothing step: OpenCV's guided filter, with the content
# image as the guide and the stylized output as the image being filtered.
GUIDED_FILTER_RADIUS = 35
GUIDED_FILTER_EPS = 1e-3

r1_another = guidedFilter(
    guide=content_array,
    src=output_array,
    radius=GUIDED_FILTER_RADIUS,
    eps=GUIDED_FILTER_EPS,
)
Image.fromarray(r1_another)

# Second smoothing step, applied to both first-step results

In [None]:
# Second smoothing pass over the `photorealistic_smoothing` result,
# with the content image passed as the second argument.
r2 = secondary_smoothing(r1, content_array)
Image.fromarray(r2)

In [None]:
# Same second smoothing pass, applied to the guided-filter result
# so the two pipelines can be compared.
r2_another = secondary_smoothing(r1_another, content_array)
Image.fromarray(r2_another)