In [11]:
from PIL import Image
from pathlib import Path

import torch
import torch.nn as nn
from torchvision.utils import save_image

import adain.net as net

from styleclr.test import test_transform, style_transfer
from styleclr.utils import move_to_top_directory


In [12]:
# All paths below are relative to the repository root, so change into it first
# and echo the resulting working directory as the cell's output.
move_to_top_directory()
str(Path.cwd())

'/home/felix/styleclr'

In [13]:
# --- Inputs and outputs (relative to the repository root) ---
content_path = Path('adain/input/content/avril.jpg')
style_path = Path('adain/input/style/asheville.jpg')
output_dir = Path('output')
save_ext = '.jpg'  # extension appended to the generated output filename

# --- Pretrained AdaIN weights ---
decoder_path = Path('adain/models/decoder.pth')
vgg_path = Path('adain/models/vgg_normalised.pth')

# --- Transfer settings ---
# Size / crop options handed to test_transform for the content and style images.
content_size = style_size = 512
crop = False
# Passed straight to style_transfer; presumably the content(0)..style(1)
# interpolation weight as in the AdaIN reference — confirm in styleclr.test.
alpha = 1

In [14]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

decoder = net.decoder
vgg = net.vgg

# map_location remaps tensors onto `device`, so checkpoints that were saved
# from a GPU still load on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
decoder.load_state_dict(torch.load(decoder_path, map_location=device))
vgg.load_state_dict(torch.load(vgg_path, map_location=device))

# Keep only the first 31 child layers of the VGG encoder (presumably up to
# relu4_1 as in the AdaIN reference — confirm against adain/net.py).
vgg = nn.Sequential(*list(vgg.children())[:31])

# Switch to eval mode *after* truncation: the freshly built nn.Sequential
# wrapper starts in training mode.
decoder.eval()
vgg.eval()

vgg.to(device)
decoder.to(device)

# Preprocessing pipelines for the content and style images.
content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

In [15]:
# Load both images as RGB and close the file handles afterwards. Without the
# conversion, grayscale or RGBA inputs would (assuming test_transform ends in
# ToTensor) yield 1- or 4-channel tensors instead of 3-channel ones.
with Image.open(str(content_path)) as content_img:
    content = content_tf(content_img.convert('RGB'))
with Image.open(str(style_path)) as style_img:
    style = style_tf(style_img.convert('RGB'))

# Add a leading batch dimension and move the tensors to the model's device.
style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = style_transfer(vgg, decoder, content, style, alpha)
output = output.cpu()

# Create the output directory on first run; saving into a missing directory raises.
output_dir.mkdir(parents=True, exist_ok=True)
output_name = output_dir / f'{content_path.stem}_stylized_{style_path.stem}{save_ext}'
save_image(output, str(output_name))