In [1]:
import functools
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

import net
from function import adaptive_instance_normalization, coral

# Device used for the model weights and input tensors throughout the notebook:
# the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# --- Style-transfer configuration ---------------------------------------------
VGG_PATH = 'models/vgg_normalised.pth'  # encoder checkpoint
DECODER_PATH = 'models/decoder.pth'  # decoder checkpoint
CONTENT_SIZE = 512  # resize target for the content image's shorter side (0 = no resize)
STYLE_SIZE = 512  # resize target for the style image's shorter side (0 = no resize)
CROP = False  # center-crop to a size x size square after resizing
DEFAULT_ALPHA = 1.0  # degree of stylization in [0, 1]; 1.0 = fully stylized
DEFAULT_PRESERVE_COLOR = False  # if True, keep the content image's colours


def _test_transform(size, crop):
    """Return the preprocessing pipeline: optional resize, optional center crop, ToTensor."""
    transform_list = []
    if size != 0:
        # An int size resizes the shorter side while keeping the aspect ratio.
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    return transforms.Compose(transform_list)


def _load_image(path, transform):
    """Open the image at `path` as 3-channel RGB and apply `transform` to it."""
    # The context manager closes the file handle. convert('RGB') normalises
    # RGBA / greyscale / palette uploads, which ToTensor would otherwise turn
    # into 4- or 1-channel tensors.
    with Image.open(str(path)) as img:
        return transform(img.convert('RGB'))


@functools.lru_cache(maxsize=None)
def _load_models():
    """Load the encoder and decoder weights once and return ``(vgg, decoder)`` on `device`.

    Cached so that each Gradio request does not re-read both checkpoints from disk.
    A failed load raises and is not cached, so the next call retries.
    """
    decoder = net.decoder
    vgg = net.vgg
    decoder.eval()
    vgg.eval()
    # map_location='cpu' lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; the modules are moved to `device` below.
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location='cpu'))
    vgg.load_state_dict(torch.load(VGG_PATH, map_location='cpu'))
    # Keep only the first 31 encoder layers (presumably up to relu4_1 as in the
    # reference AdaIN code -- confirm against net.vgg).
    vgg = nn.Sequential(*list(vgg.children())[:31])
    return vgg.to(device), decoder.to(device)


def _style_transfer(vgg, decoder, content, style, alpha=1.0):
    """Encode both images, re-normalise content features to style statistics, decode.

    `alpha` blends the AdaIN features with the raw content features
    (1.0 = fully stylized, 0.0 = content features only).

    Raises:
        ValueError: If `alpha` is outside [0, 1].
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f'alpha must be in [0, 1], got {alpha!r}')
    content_f = vgg(content)
    style_f = vgg(style)
    feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)


def show_image(content, style, alpha=DEFAULT_ALPHA,
               preserve_color=DEFAULT_PRESERVE_COLOR):
    """Stylize one content image with one style image using AdaIN.

    Args:
        content: File path of the content image (Gradio ``type="filepath"`` input).
        style: File path of the style image. It is treated as a single path and
            is not split on commas, so file names containing ',' work.
        alpha: Degree of stylization in [0, 1]; 1.0 is fully stylized.
        preserve_color: If True, apply ``coral`` to the style image against the
            content image before transfer, to keep the content's colours.

    Returns:
        The stylized image as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If either image is missing; Gradio shows the message in the UI.
    """
    if not content:
        raise gr.Error('Please upload a content image.')
    if not style:
        raise gr.Error('Please upload a style image.')

    vgg, decoder = _load_models()
    content_tensor = _load_image(content, _test_transform(CONTENT_SIZE, CROP))
    style_tensor = _load_image(style, _test_transform(STYLE_SIZE, CROP))
    if preserve_color:
        style_tensor = coral(style_tensor, content_tensor)

    # Add a batch dimension and move the inputs to the models' device.
    content_tensor = content_tensor.to(device).unsqueeze(0)
    style_tensor = style_tensor.to(device).unsqueeze(0)
    with torch.no_grad():
        stylized = _style_transfer(vgg, decoder, content_tensor, style_tensor, alpha)

    # ToPILImage converts float tensors with mul(255).byte(), which wraps
    # out-of-range values instead of saturating them. Nothing in this file
    # bounds the decoder output, so clamp to [0, 1] before converting.
    stylized = stylized[0].clamp(0, 1).cpu()
    return transforms.ToPILImage()(stylized)


# Gradio UI: two uploaded images (passed to show_image as file paths) -> stylized image.
app = gr.Interface(
    fn=show_image,
    inputs=[
        gr.Image(label="Content Image", type="filepath"),
        gr.Image(label="Style Image", type="filepath"),
    ],
    # show_image returns a PIL image, so declare the output component as such.
    outputs=gr.Image(label="Output Image", type="pil"),
)

app.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


