In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms as T
from torchvision.transforms import ToPILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr

In [2]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
# Models and tensors created later in the notebook are moved onto this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class ImageLoader:
    """Load images from disk as float tensors on ``device`` and display/save tensors.

    Every loaded image is converted to 3-channel RGB, optionally resized, and turned
    into a float tensor by ``T.ToTensor``.
    """

    # PIL integer resampling codes -> torchvision ``InterpolationMode`` member names.
    _PIL_INTERPOLATION = {0: "NEAREST", 1: "LANCZOS", 2: "BILINEAR",
                          3: "BICUBIC", 4: "BOX", 5: "HAMMING"}

    def __init__(self, size: (int, tuple), resize: bool = True, interpolation=2):
        """Build the preprocessing pipeline.

        Args:
            size: target size passed to ``T.Resize`` (an int or an ``(h, w)`` tuple).
            resize: when False, images keep their original size.
            interpolation: a torchvision ``InterpolationMode`` or a legacy PIL integer
                code (2 = bilinear). Integer codes are translated to the enum because
                torchvision deprecates passing ints.
        """
        steps = []
        if resize:
            mode_name = self._PIL_INTERPOLATION.get(interpolation) if isinstance(interpolation, int) else None
            if mode_name is not None and hasattr(T, "InterpolationMode"):
                interpolation = getattr(T.InterpolationMode, mode_name)
            steps.append(T.Resize(size=size, interpolation=interpolation))
        steps.append(T.ToTensor())
        self.transforms = T.Compose(steps)

    def read_image(self, filepath: str) -> torch.Tensor:
        """Read ``filepath`` (a path or file object) and return a float tensor on ``device``."""
        # The context manager releases the underlying file handle once decoded.
        with Image.open(filepath) as img:
            # Force 3 channels: RGBA, grayscale or palette images would otherwise break
            # the 3-channel mean/std normalisation applied by MyModel.
            image = self.transforms(img.convert("RGB"))
        return image.to(device, torch.float)

    @staticmethod
    def show_image(tensor: torch.Tensor, title: str = "Image", save_: bool = False, filename: str = None):
        """Display ``tensor`` with matplotlib and optionally save it with PIL.

        Accepts (1, C, H, W), (C, H, W) or (H, W) tensors.

        Raises:
            ValueError: if the tensor rank is not 2, 3 or 4, or if ``save_`` is True
                but no ``filename`` is given.
        """
        if save_ and not filename:
            raise ValueError("filename is required when save_=True")

        # detach(): tensors that require grad cannot be converted to numpy/PIL.
        tensor = tensor.detach().cpu().clone()
        ndim = tensor.dim()
        if ndim == 4:
            tensor = tensor.squeeze(0)
        elif ndim == 2:
            tensor = tensor.unsqueeze(0)
        elif ndim != 3:
            raise ValueError(f"Bad Input shape: {tensor.shape}")

        img = ToPILImage()(tensor)
        plt.imshow(img)
        plt.title(title)
        plt.pause(0.001)

        if save_:
            img.save(fp=filename)

class MyModel(nn.Module):
    """Frozen VGG19 feature extractor returning activations for content and style losses.

    The torchvision VGG19 ``features`` stack is loaded with ImageNet weights, moved to
    ``device``, put in eval mode and frozen (only the input image is meant to be
    optimised). Its max-pooling layers are replaced by average pooling. Inputs are
    normalised with ``mean``/``stdv`` before the forward pass.

    Args:
        con_layers: layer names (keys of ``_LAYER_INDEX``) whose activations are
            returned as content features. ``None`` falls back to ``("conv4_2",)``.
        sty_layers: layer names whose activations are returned as style features.
            ``None`` (the default) selects conv1_1, conv2_1, conv3_1, conv4_1, conv5_1.
        mean, stdv: per-channel normalisation statistics (ImageNet by default).

    Raises:
        KeyError: if a layer name is not a key of ``_LAYER_INDEX``.
    """

    # Position of each named convolution inside ``vgg19().features``.
    _LAYER_INDEX = {"conv1_1": 0, "conv1_2": 2,
                    "conv2_1": 5, "conv2_2": 7,
                    "conv3_1": 10, "conv3_2": 12, "conv3_3": 14, "conv3_4": 16,
                    "conv4_1": 19, "conv4_2": 21, "conv4_3": 23, "conv4_4": 25,
                    "conv5_1": 28, "conv5_2": 30, "conv5_3": 32, "conv5_4": 34}
    _DEFAULT_CON_LAYERS = ("conv4_2",)
    _DEFAULT_STY_LAYERS = ("conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1")

    def __init__(self, con_layers: list = ('conv4_2',), sty_layers: list = None,
                 mean: list = (0.485, 0.456, 0.406), stdv: list = (0.229, 0.224, 0.225)):
        super().__init__()

        # None used to crash when iterated; fall back to the standard layer choices.
        if con_layers is None:
            con_layers = self._DEFAULT_CON_LAYERS
        if sty_layers is None:
            sty_layers = self._DEFAULT_STY_LAYERS

        self.transforms = T.Normalize(
            torch.tensor(mean, dtype=torch.float, device=device),
            torch.tensor(stdv, dtype=torch.float, device=device),
        )

        # +1: record the output of the module right after each conv (its ReLU in
        # torchvision's VGG19 layout) rather than the raw convolution output.
        self.con_layers = [self._LAYER_INDEX[name] + 1 for name in con_layers]
        self.sty_layers = [self._LAYER_INDEX[name] + 1 for name in sty_layers]

        if hasattr(models, "VGG19_Weights"):
            features = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
        else:  # older torchvision without the weights enum
            features = models.vgg19(pretrained=True).features
        self.vgg19 = features.to(device).eval()
        # The network is a fixed feature extractor: freezing it avoids computing and
        # accumulating unused parameter gradients on every backward pass.
        self.vgg19.requires_grad_(False)

        pool_positions = [i for i, layer in enumerate(self.vgg19) if isinstance(layer, nn.MaxPool2d)]
        for i in pool_positions:
            self.vgg19[i] = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, tensor: torch.Tensor) -> dict:
        """Run a (C, H, W) ``tensor`` through VGG19 and collect the selected activations.

        Returns:
            ``{"Con_features": [...], "Sty_features": [...]}``. Each list is ordered by
            depth in the network, and each entry keeps the leading batch dimension of 1
            added here.
        """
        x = self.transforms(tensor).unsqueeze(0)
        # Layers deeper than the deepest requested one never affect the result.
        deepest = max([*self.con_layers, *self.sty_layers], default=-1)

        con_feat_maps, sty_feat_maps = [], []
        for idx, layer in enumerate(self.vgg19):
            if idx > deepest:
                break
            x = layer(x)
            if idx in self.con_layers:
                con_feat_maps.append(x)
            if idx in self.sty_layers:
                sty_feat_maps.append(x)

        return {"Con_features": con_feat_maps, "Sty_features": sty_feat_maps}

class NeuralStyleTransfer:
    """Neural style transfer that directly optimises the pixels of one image.

    ``var_image`` starts as a copy of the content image. It is updated with Adam to
    minimise a weighted sum of:
      * content loss on ``MyModel`` content features,
      * Gram-matrix style loss on ``MyModel`` style features,
      * total-variation loss on the image itself.
    """

    _DEFAULT_CON_LAYERS = ["conv4_2"]
    _DEFAULT_STY_LAYERS = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]

    def __init__(self, con_image: torch.Tensor, sty_image: torch.Tensor, size=512,
                 con_layers: list = None, sty_layers: list = None,
                 con_loss_wt: float = 1., sty_loss_wt: float = 1., var_loss_wt=1.):
        """Build the feature extractor and precompute the (constant) loss targets.

        Args:
            con_image, sty_image: image tensors accepted by ``MyModel.forward``.
            size: stored as ``self.size``; not used by the optimisation itself.
            con_layers, sty_layers: VGG19 layer names; ``None`` selects conv4_2 for
                content and conv1_1..conv5_1 for style.
            con_loss_wt, sty_loss_wt, var_loss_wt: weights of the three loss terms.
        """
        self.con_loss_wt = con_loss_wt
        self.sty_loss_wt = sty_loss_wt
        self.var_loss_wt = var_loss_wt
        self.size = size

        if con_layers is None:
            con_layers = list(self._DEFAULT_CON_LAYERS)
        if sty_layers is None:
            sty_layers = list(self._DEFAULT_STY_LAYERS)
        self.model = MyModel(con_layers=con_layers, sty_layers=sty_layers)

        # Targets are constants of the optimisation: compute them once, without an
        # autograd graph. The style target is stored as Gram matrices so that fit()
        # can be called repeatedly.
        with torch.no_grad():
            sty_features = self.model(sty_image)["Sty_features"]
            con_features = self.model(con_image)["Con_features"]
        self.sty_target = [self._get_gram_matrix(f) for f in sty_features]
        self.con_target = [f.detach() for f in con_features]

        # Move to the device *before* enabling grad so var_image stays a leaf tensor
        # that the optimiser accepts.
        self.var_image = con_image.detach().clone().to(device).requires_grad_(True)

    @staticmethod
    def _get_var_loss(tensor: torch.Tensor) -> torch.Tensor:
        """Anisotropic total variation: L1 norm of horizontal plus vertical differences."""
        return (torch.sum(torch.abs(tensor[:, :, :-1] - tensor[:, :, 1:])) +
                torch.sum(torch.abs(tensor[:, :-1, :] - tensor[:, 1:, :])))

    @staticmethod
    def _get_con_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Half the summed squared error between predicted and target features."""
        return 0.5 * torch.sum(torch.pow(pred - target, 2))

    @staticmethod
    def _get_gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
        """Unnormalised Gram matrix of a (b, c, h, w) feature map, shape (b*c, b*c)."""
        b, c, h, w = tensor.size()
        flat = tensor.view(b * c, h * w)
        return torch.mm(flat, flat.t())

    def _get_sty_loss(self, pred: torch.Tensor, target: torch.Tensor):
        """Style loss 0.25 * ||G(pred) - target||^2 / numel(pred)^2.

        With a batch of 1 this is the 1 / (4 N^2 M^2) normalisation (N channels,
        M = h * w positions).
        """
        norm = float(pred.numel()) ** 2
        gram = self._get_gram_matrix(pred)
        return 0.25 * torch.sum(torch.pow(gram - target, 2)) / norm

    def _get_tot_loss(self, output: dict):
        """Return the weighted (content, style, total-variation) losses for ``output``."""
        con_losses = [self._get_con_loss(p, t) for p, t in zip(output["Con_features"], self.con_target)]
        sty_losses = [self._get_sty_loss(p, t) for p, t in zip(output["Sty_features"], self.sty_target)]

        con_loss = torch.mean(torch.stack(con_losses)) * self.con_loss_wt
        sty_loss = torch.mean(torch.stack(sty_losses)) * self.sty_loss_wt
        var_loss = self._get_var_loss(self.var_image) * self.var_loss_wt
        return con_loss, sty_loss, var_loss

    def fit(self, nb_epochs: int = 1, nb_iters: int = 1000, lr: float = 1e-2, eps: float = 1e-8,
            betas: tuple = (0.9, 0.999)) -> torch.Tensor:
        """Optimise ``var_image`` with Adam and return a detached copy clamped to [0, 1].

        Safe to call repeatedly. Each call starts a fresh Adam optimiser from the
        current ``var_image``.
        """
        optimizer = optim.Adam([self.var_image], lr=lr, betas=betas, eps=eps)

        for _ in range(nb_epochs):
            for _ in range(nb_iters):
                # Keep pixels in the valid range without recording the op in autograd.
                with torch.no_grad():
                    self.var_image.clamp_(0, 1)
                optimizer.zero_grad()
                output = self.model(self.var_image)

                con_loss, sty_loss, var_loss = self._get_tot_loss(output)
                tot_loss = con_loss + sty_loss + var_loss

                tot_loss.backward()
                optimizer.step()

        with torch.no_grad():
            self.var_image.clamp_(0, 1)
        # Return a copy so later fit() calls do not mutate images already handed out.
        return self.var_image.detach().clone()

def _combine_stylized_strips(images):
    """Build one RGB image from equal-width vertical strips; strip i comes from images[i].

    Assumes every tensor in ``images`` is (C, H, W) with the size of ``images[0]``.
    """
    height, width = images[0].shape[1], images[0].shape[2]
    strip_width = width // len(images)
    combined = Image.new('RGB', (strip_width * len(images), height))
    for i, tensor in enumerate(images):
        pil_image = ToPILImage()(tensor.cpu())
        left = i * strip_width
        strip = pil_image.crop((left, 0, left + strip_width, pil_image.height))
        combined.paste(strip, (left, 0))
    return combined


def style_transfer(content_image_path, *style_image_paths):
    """Stylise one content image with each provided style image (Gradio callback).

    Args:
        content_image_path: path or file object of the content image.
        *style_image_paths: one entry per style slot. ``None`` entries (style slots
            the user left empty in the UI) are skipped.

    Returns:
        One entry per style slot, followed by the combined image entry. Each style entry
        is the saved ``stylized_image_<k>.jpg`` path, or ``None`` for an empty slot. The
        final entry is the ``combined_image.jpg`` path, or ``None`` when nothing was
        stylised. The length is always ``len(style_image_paths) + 1`` so it matches the
        number of Gradio outputs.

    Raises:
        gr.Error: if no content image was provided.
    """
    if content_image_path is None:
        raise gr.Error("Please upload a content image.")

    img_loader = ImageLoader(size=(512, 512), resize=True)
    con_image = img_loader.read_image(filepath=content_image_path)

    con_layers = ["conv4_2"]
    sty_layers = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]

    # NOTE(review): outputs go to fixed filenames in the working directory, so
    # concurrent requests would overwrite each other's files.
    results = []
    output_images = []
    for slot, style_image_path in enumerate(style_image_paths, start=1):
        if style_image_path is None:
            results.append(None)
            continue

        sty_image = img_loader.read_image(filepath=style_image_path)
        nst = NeuralStyleTransfer(con_image=con_image, sty_image=sty_image, size=(512, 512),
                                  con_layers=con_layers, sty_layers=sty_layers,
                                  con_loss_wt=1e-5, sty_loss_wt=1e4, var_loss_wt=5e-5)
        output_image = nst.fit(nb_epochs=1, nb_iters=1000, lr=1e-2, eps=1e-8, betas=(0.9, 0.999))
        output_images.append(output_image)

        filename = f"stylized_image_{slot}.jpg"
        # Displays the result and saves it once (the file was previously written twice).
        img_loader.show_image(output_image, save_=True, filename=filename)
        results.append(filename)

    combined_path = None
    if output_images:
        combined_path = "combined_image.jpg"
        _combine_stylized_strips(output_images).save(combined_path)

    # Never return a path that this call did not write (it could be stale or missing).
    return results + [combined_path]

def launch_app(n=None):
    """Build and launch the Gradio UI around ``style_transfer``.

    Args:
        n: number of style-image upload slots. When ``None`` (the default) the user is
            prompted for it with ``input()``, as before.

    Raises:
        ValueError: if the typed value is not an integer, or if ``n`` is less than 1.
    """
    if n is None:
        n = int(input("Enter the number of style images: "))
    if n < 1:
        raise ValueError(f"The number of style images must be at least 1, got {n}.")

    # One content slot followed by n style slots; style_transfer receives them positionally.
    inputs = [gr.File(label="Content image")]
    inputs += [gr.File(label=f"Style image {i}") for i in range(1, n + 1)]
    # One output per style slot plus the combined image; style_transfer is expected
    # to return n + 1 values in this order.
    outputs = [gr.Image(label=f"Stylized image {i}") for i in range(1, n + 1)]
    outputs.append(gr.Image(label="Combined image"))

    interface = gr.Interface(
        fn=style_transfer,
        inputs=inputs,
        outputs=outputs,
        title="Neural Style Transfer",
        description="Upload one content image and up to " + str(n) + " style images."
    )
    interface.launch()

launch_app()  # Prompts via input() for the number of style images, then builds and launches the Gradio interface.

Enter the number of style images: 2
It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://c9deea0aad9fd4d4ec.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
