In [11]:
# Import Gradio to create a simple web interface for the model
import gradio as gr

# Import PIL (Python Imaging Library) to handle image loading and manipulation
from PIL import Image

# Import PyTorch core library
import torch

# Import the neural network module from PyTorch
import torch.nn as nn

# Import commonly used image transformation tools from torchvision
import torchvision.transforms as transforms

# Import pre-trained models from torchvision
import torchvision.models as models


In [12]:
# Run on CUDA when torch reports a usable GPU; otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every image before it reaches the network:
#   1. force a fixed 256x256 resolution
#   2. convert the image into a PyTorch tensor
# NOTE(review): no mean/std normalization step is applied here; torchvision
# documents its ImageNet-pretrained weights as expecting normalized input —
# confirm whether leaving it out is intentional.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)


In [13]:
# Load the convolutional feature extractor (`.features`) of an
# ImageNet-pretrained VGG-16 from torchvision.
# The `pretrained=` flag is deprecated in newer torchvision releases in favour
# of `weights=`; use the weights enum when it exists and fall back otherwise.
if hasattr(models, "VGG16_Weights"):
    vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
else:
    vgg = models.vgg16(pretrained=True).features

# Move the network to the selected device (GPU if available, otherwise CPU)
# and switch it to evaluation (inference) mode
vgg = vgg.to(device).eval()

# The network is only used as a fixed feature extractor: freezing its weights
# keeps autograd from computing and accumulating gradients for them while an
# image is being optimized. Assigned back so the cell produces no output.
vgg = vgg.requires_grad_(False)




In [14]:
# Content representation: the activation captured at child index '21' of `vgg`.
# NOTE(review): the index/name pairing previously documented here ('21' =
# relu4_2) matches the VGG-19 layout. In torchvision's VGG-16 `features`,
# index 21 appears to be the third convolution of block 4 — confirm that this
# is the intended content layer.
content_layers = ["21"]

# Style representation: activations captured at several depths, so that both
# fine and coarse style patterns contribute to the style loss.
# Indices '0', '5' and '10' are the first convolutions of blocks 1, 2 and 3.
# NOTE(review): '19' and '28' are the first convolutions of blocks 4 and 5 in
# VGG-19; in torchvision's VGG-16 they appear to be the second convolution of
# block 4 and the third convolution of block 5 (the block-initial ones would be
# '17' and '24') — confirm which layers are wanted.
style_layers = ["0", "5", "10", "19", "28"]


In [15]:
# Wrapper module that feeds an input through a layered model and records the
# outputs of selected layers for content and style comparison
class VGGFeatures(nn.Module):
    """Collect intermediate activations of ``model``, keyed by child-layer name.

    Args:
        model: module whose direct children are applied one after another
            (for example the ``features`` part of a VGG network).
        style_layers: names of the children whose outputs are returned as
            style features.
        content_layers: names of the children whose outputs are returned as
            content features.
    """

    def __init__(self, model, style_layers, content_layers):
        super().__init__()
        self.model = model
        self.style_layers = style_layers
        self.content_layers = content_layers

    def forward(self, x):
        """Run ``x`` through every child of the model in registration order.

        Returns:
            A ``(content_features, style_features)`` pair of dicts mapping a
            layer name to the tensor produced by that layer. A layer listed in
            both selections appears in both dicts.
        """
        content_features, style_features = {}, {}
        activation = x

        for layer_name, module in self.model._modules.items():
            activation = module(activation)

            # The tensor is stored as-is (no copy), so a later layer that
            # operates in place would also change the stored value.
            # NOTE(review): check whether the wrapped model uses in-place
            # activations if exact pre-activation outputs are required.
            if layer_name in self.content_layers:
                content_features[layer_name] = activation
            if layer_name in self.style_layers:
                style_features[layer_name] = activation

        return content_features, style_features


In [16]:
# Function to compute the Gram matrix of a feature map
# Used to capture style information from an image
def gram_matrix(tensor):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    Args:
        tensor: feature map of shape (batch, channels, height, width).

    Returns:
        A (batch * channels) x (batch * channels) tensor holding the inner
        products between the flattened feature maps, divided by the total
        number of elements in ``tensor``.

    Note:
        The batch and channel axes are flattened together, so with a batch
        size above 1 the result also contains products between feature maps
        of different batch items.
    """
    # Unpack tensor dimensions: batch size, channels, height, width
    b, c, h, w = tensor.size()

    # Flatten each feature map (channel) into a row vector.
    # `reshape` is used instead of `view` so that non-contiguous inputs
    # (e.g. transposed or permuted tensors) are accepted as well.
    features = tensor.reshape(b * c, h * w)

    # Multiply the feature matrix with its transpose: entry (i, j) measures
    # how strongly feature map i correlates with feature map j
    G = torch.mm(features, features.t())

    # Normalize by the total number of elements
    return G.div(b * c * h * w)


In [18]:
# Turn an input image into a batched tensor on the active device
def load_image(image):
    """Preprocess ``image`` with the module-level ``transform``.

    The transformed tensor gets a leading batch dimension of size 1 and is
    moved to the module-level ``device``.
    """
    tensor = transform(image)
    batched = tensor.unsqueeze(0)
    return batched.to(device)


In [19]:
# Function to perform neural style transfer and blend the result through a mask
def run_style_transfer(content_img, style_img, mask_img, num_steps=300, style_weight=1e6, content_weight=1):
    """Stylize ``content_img`` with ``style_img`` and composite it via ``mask_img``.

    The output image starts as a copy of the preprocessed content image and is
    optimized with L-BFGS to minimize a weighted sum of a content loss (MSE
    between feature maps) and a style loss (MSE between Gram matrices).

    Args:
        content_img: image providing the content. It is passed to the
            preprocessing transform and later to ``Image.composite``, so a PIL
            image is expected.
        style_img: image providing the style (preprocessed the same way).
        mask_img: image converted to mode "L" and used as the compositing
            mask: the stylized result is kept where the mask is bright, the
            resized content image where it is dark.
        num_steps: optimizer steps are taken while the number of closure
            evaluations is <= ``num_steps``. L-BFGS may evaluate the closure
            several times per step, so this is a threshold, not an exact count.
        style_weight: multiplier applied to each style-layer loss.
        content_weight: multiplier applied to each content-layer loss.

    Returns:
        The blended PIL image, at the resolution produced by preprocessing.
    """
    # Load and preprocess content and style images
    content = load_image(content_img)
    style = load_image(style_img)

    # Clone the content image as the starting point for the output image
    input_img = content.clone().requires_grad_(True)

    # Wrap the VGG model for feature extraction
    model = VGGFeatures(vgg, style_layers, content_layers).to(device)

    # Only the image itself is optimized
    optimizer = torch.optim.LBFGS([input_img])

    # The targets are constants, so compute them without recording a graph.
    # Style targets are built from the style image's own features.
    with torch.no_grad():
        content_features, _ = model(content)
        _, style_features_ref = model(style)
        content_targets = {
            name: feature.detach() for name, feature in content_features.items()
        }
        style_targets = {
            name: gram_matrix(feature) for name, feature in style_features_ref.items()
        }

    # Mutable counter so the closure can record how often it was evaluated
    run = [0]

    def closure():
        # Keep pixel values inside [0, 1] without tracking the clamp in autograd
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()

        # Features of the image currently being optimized
        content_pred, style_pred = model(input_img)

        content_loss = 0
        style_loss = 0

        # Content loss: MSE between current and target feature maps
        for name in content_pred:
            content_loss += content_weight * torch.nn.functional.mse_loss(
                content_pred[name], content_targets[name]
            )

        # Style loss: MSE between current and target Gram matrices
        for name in style_pred:
            G = gram_matrix(style_pred[name])
            A = style_targets[name]
            style_loss += style_weight * torch.nn.functional.mse_loss(G, A)

        total_loss = content_loss + style_loss
        total_loss.backward()

        run[0] += 1
        return total_loss

    # Optimization loop
    while run[0] <= num_steps:
        optimizer.step(closure)

    # Final clamping to ensure image values are within range
    with torch.no_grad():
        input_img.clamp_(0, 1)

    # Convert the optimized tensor (detached from autograd) to a PIL image
    result = transforms.ToPILImage()(input_img.detach().cpu().squeeze(0))

    # Bring mask and content image to the result's size before blending
    mask_resized = mask_img.convert("L").resize(result.size)
    content_resized = content_img.resize(result.size)

    # Stylized pixels where the mask is bright, original content elsewhere
    return Image.composite(result, content_resized, mask_resized)


In [None]:
# Read an image file, convert it to the requested mode and resize it
def _open_resized(path, mode, size=(256, 256)):
    """Return the image at ``path`` converted to ``mode`` and resized to ``size``.

    The file is opened in a ``with`` block so its handle is released as soon as
    the converted copy has been created.
    """
    with Image.open(path) as img:
        return img.convert(mode).resize(size)


# Define a function to perform masked style transfer from file paths
def stylize_image(content, style, mask):
    """Run masked style transfer on images given as file paths.

    Args:
        content: path of the content image (loaded as RGB, 256x256).
        style: path of the style image (loaded as RGB, 256x256).
        mask: path of the mask image (loaded as grayscale, 256x256). When it is
            ``None`` an all-white mask is used, so the whole image is stylized.

    Returns:
        The PIL image returned by ``run_style_transfer``.

    Raises:
        ValueError: if ``content`` or ``style`` is ``None``.
    """
    if content is None or style is None:
        raise ValueError("Both a content image and a style image are required.")

    content_img = _open_resized(content, "RGB")
    style_img = _open_resized(style, "RGB")

    if mask is None:
        # No mask supplied: select every pixel of the stylized result
        mask_img = Image.new("L", content_img.size, 255)
    else:
        mask_img = _open_resized(mask, "L")

    # Perform style transfer using the previously defined function
    return run_style_transfer(content_img, style_img, mask_img)

# Build the Gradio front end: three uploaded images are handed to
# `stylize_image` as file paths, and the PIL image it returns is displayed
interface = gr.Interface(
    fn=stylize_image,
    # One upload widget each for the content image, the style image and the mask
    inputs=[
        gr.Image(type="filepath", label=input_label)
        for input_label in ("Content Image", "Style Image", "Mask Image")
    ],
    outputs=gr.Image(type="pil", label="Stylized Output"),
    title="Masked Style Transfer with VGG16",
    description="Upload a content image, a style image, and a binary mask to selectively stylize parts of your image.",
)

# Start the app with debug mode enabled and without a public share link
interface.launch(debug=True, share=False)


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.
