In [1]:
import gradio as gr
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models

In [2]:
# Use the default CUDA device when one is available; otherwise compute on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing shared by every input image: a fixed 256x256 resize (aspect
# ratio is not preserved) followed by ToTensor, which turns the PIL image into
# a float tensor laid out as (C, H, W).
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

In [3]:
# Pretrained VGG16 convolutional stack (`.features`; the classifier head is
# discarded), moved to `device` and put in eval mode as a fixed feature extractor.
if hasattr(models, "VGG16_Weights"):
    # torchvision >= 0.13 deprecates `pretrained=True` in favour of `weights=`;
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
    _vgg16_full = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
else:
    _vgg16_full = models.vgg16(pretrained=True)
# Freeze the weights. Only the input image is optimised, so gradients for the VGG
# parameters are wasted work. They would also accumulate across runs, because the
# optimizer only zeroes the image's gradient.
vgg = _vgg16_full.features.to(device).eval().requires_grad_(False)
# Drop the reference to the full model so the large classifier weights can be freed.
del _vgg16_full



In [4]:
# Indices into `vgg16().features` (each conv is followed by a ReLU; pools at 4/9/16/23/30).
# Gatys et al. use conv4_2 for content and conv1_1..conv5_1 for style. In VGG16 those
# convolutions sit at 19 and 0/5/10/17/24. The previous values (21; 0/5/10/19/28) are the
# VGG19 indices, which on VGG16 select conv4_3, conv4_2 and conv5_3 instead.
content_layers = ['19']
style_layers = ['0', '5', '10', '17', '24']

In [5]:
class VGGFeatures(nn.Module):
    """Run an image through a sequential CNN and collect intermediate activations.

    Args:
        model: a sequential-style module whose children are looked up by name
            (for ``nn.Sequential`` such as ``vgg16().features`` the names are
            '0', '1', ...).
        style_layers: child names whose outputs go into the style-feature dict.
        content_layers: child names whose outputs go into the content-feature dict.
        normalize: if True (default), standardise the input with the ImageNet
            per-channel mean/std before the first layer. torchvision's pretrained
            VGG weights were trained on inputs normalised this way, whereas
            ``ToTensor()`` output is in [0, 1]. Pass False to feed ``x`` unchanged.
    """

    # ImageNet channel statistics used by torchvision's pretrained classifiers.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, model, style_layers, content_layers, normalize=True):
        super().__init__()
        self.model = model
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.normalize = normalize
        # Buffers move with the module on .to(device). Shape (1, 3, 1, 1)
        # broadcasts over a 3-channel NCHW input.
        self.register_buffer("mean", torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1))

    def forward(self, x):
        """Return ``(content_features, style_features)``, two dicts keyed by layer name."""
        if self.normalize:
            x = (x - self.mean) / self.std
        content_features = {}
        style_features = {}
        # Layers after the last requested one cannot affect the result, so stop early.
        remaining = set(self.content_layers) | set(self.style_layers)
        for name, layer in self.model.named_children():
            if not remaining:
                break
            x = layer(x)
            if name in self.content_layers:
                content_features[name] = x
            if name in self.style_layers:
                style_features[name] = x
            remaining.discard(name)
        return content_features, style_features

In [6]:
def gram_matrix(tensor):
    """Normalised Gram matrix of a 4-D activation tensor.

    The tensor is flattened to ``(b*c, h*w)`` and multiplied by its own
    transpose, then divided by the total element count ``b*c*h*w``. With
    ``b == 1`` this is the usual per-image channel Gram matrix. With ``b > 1``
    the batch and channel axes are merged, so cross-image products are included.
    """
    batch, channels, height, width = tensor.size()
    flat = tensor.view(batch * channels, height * width)
    element_count = batch * channels * height * width
    return (flat @ flat.t()) / element_count

In [7]:
def load_image(image):
    """Preprocess a PIL image with the module-level ``transform``, add a leading
    batch axis of size 1, and move the result to ``device``."""
    preprocessed = transform(image)
    batched = preprocessed.unsqueeze(0)
    return batched.to(device)

In [8]:
def run_style_transfer(content_img, style_img, mask_img, num_steps=300, style_weight=1e6, content_weight=1):
    """Stylise ``content_img`` with ``style_img`` and keep the result only where ``mask_img`` allows.

    The whole content image is optimised with L-BFGS to match the content
    features of ``content_img`` and the Gram matrices of ``style_img``. The
    optimised image is then blended with the (resized) content image using
    ``mask_img`` as alpha. White mask pixels take the stylised image, black
    ones keep the content image, and grey values mix the two.

    Args:
        content_img: PIL image supplying the structure.
        style_img: PIL image supplying the texture.
        mask_img: PIL image; converted to greyscale ("L") before compositing.
        num_steps: budget of loss evaluations. The counter is only checked
            between ``LBFGS.step`` calls, and one step may evaluate the loss
            several times, so the final count can overshoot this value.
        style_weight: multiplier on every style-layer MSE term.
        content_weight: multiplier on every content-layer MSE term.

    Returns:
        The composited PIL image, sized like the transformed input tensor.
    """
    content = load_image(content_img)
    style = load_image(style_img)

    model = VGGFeatures(vgg, style_layers, content_layers).to(device)
    # Only the input image is optimised. Freezing the network stops backward()
    # from computing gradients for every VGG weight on each evaluation. It also
    # stops them accumulating across calls, since zero_grad() only touches the image.
    model.requires_grad_(False)

    # The targets are constants: compute them without recording an autograd graph.
    with torch.no_grad():
        content_features, _ = model(content)
        _, style_features = model(style)
    content_targets = dict(content_features)
    style_targets = {name: gram_matrix(feat) for name, feat in style_features.items()}

    input_img = content.clone().requires_grad_(True)
    optimizer = torch.optim.LBFGS([input_img])
    run = [0]  # boxed in a list so the closure can increment it

    def closure():
        # Keep pixels in the valid [0, 1] range before each evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)
        optimizer.zero_grad()
        content_pred, style_pred = model(input_img)

        content_loss = 0
        for name, target in content_targets.items():
            content_loss += content_weight * torch.nn.functional.mse_loss(content_pred[name], target)
        style_loss = 0
        for name, target in style_targets.items():
            style_loss += style_weight * torch.nn.functional.mse_loss(gram_matrix(style_pred[name]), target)

        total_loss = content_loss + style_loss
        total_loss.backward()
        run[0] += 1
        return total_loss

    while run[0] <= num_steps:
        optimizer.step(closure)

    with torch.no_grad():
        input_img.clamp_(0, 1)
    stylized = transforms.ToPILImage()(input_img.detach().cpu().squeeze(0))

    # Image.composite takes image1 where the mask is 255 and image2 where it is 0.
    # Both images must share a mode, so the content image is converted to match.
    alpha = mask_img.convert("L").resize(stylized.size)
    base = content_img.convert(stylized.mode).resize(stylized.size)
    return Image.composite(stylized, base, alpha)

In [None]:
def _open_upload(path, mode, label):
    """Open an uploaded image file as a 256x256 PIL image in ``mode``. Raise ``gr.Error`` if it is missing."""
    if path is None:
        # Gradio passes None for an input the user left empty; report it in the UI.
        raise gr.Error(f"Please upload a {label}.")
    # `with` closes the file handle that Image.open keeps open for lazy loading.
    # convert() loads the pixels and returns an independent image, so the
    # result stays valid after the file is closed.
    with Image.open(path) as img:
        return img.convert(mode).resize((256, 256))


def stylize_image(content, style, mask):
    """Gradio callback: load the three uploaded files and run masked style transfer.

    Args:
        content: file path of the content image (from ``gr.Image(type="filepath")``).
        style: file path of the style image.
        mask: file path of the mask image; loaded as greyscale.

    Returns:
        The stylised PIL image produced by ``run_style_transfer``.

    Raises:
        gr.Error: if any of the three inputs was not provided.
    """
    content = _open_upload(content, "RGB", "content image")
    style = _open_upload(style, "RGB", "style image")
    mask = _open_upload(mask, "L", "mask image")
    output = run_style_transfer(content, style, mask)
    return output

# Gradio UI: three file-path image inputs feed stylize_image, whose PIL result is displayed.
_input_components = [
    gr.Image(type="filepath", label="Content Image"),
    gr.Image(type="filepath", label="Style Image"),
    gr.Image(type="filepath", label="Mask Image"),
]
_output_component = gr.Image(type="pil", label="Stylized Output")

interface = gr.Interface(
    fn=stylize_image,
    inputs=_input_components,
    outputs=_output_component,
    title="Masked Style Transfer with VGG16",
    description="Upload a content image, a style image, and a binary mask to selectively stylize parts of your image.",
)

# share=False serves only on the local URL (no public Gradio link).
interface.launch(debug=True, share=False)

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "C:\Users\visio\AppData\Local\Programs\Python\Python310\lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "C:\Users\visio\AppData\Local\Programs\Python\Python310\lib\site-packages\uvicorn\middleware\proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
  File "C:\Users\visio\AppData\Local\Programs\Python\Python310\lib\site-packages\fastapi\applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "C:\Users\visio\AppData\Local\Programs\Python\Python310\lib\site-packages\starlette\applications.py", line 112, in __call__
    await self.middleware_stack(scope, receive, send)
  File "C:\Users\visio\AppData\Local\Programs\Python\Python310\lib\site-packages\starlette\middleware\errors.py", line 187, in __call__
    raise exc
  File "C:\Users\visio\