In [None]:
import torch  # Deep learning framework
import torch.nn as nn  # Neural network components
import torchvision.transforms as T  # Image transformations
import torch.nn.functional as F  # Functional operations
import numpy as np  # Numerical operations
from torchvision.transforms.functional import resize  # Import the resize function from torchvision.transforms.functional for resizing images
from torchvision.utils import save_image  # Import the save_image function from torchvision.utils for saving images
from torchvision.models import vgg19  # Pre-trained VGG19 network

from PIL import Image  # Image handling
from torchvision.transforms.transforms import Resize  # Resize transform from torchvision.transforms.transforms for resizing images

# ENCODER

In [None]:
class VGGEncoder(nn.Module):
    """
    Frozen VGG19 feature extractor split into four consecutive stages.

    The stages are back-to-back slices of ``vgg19(...).features`` (indices 0-1, 2-6, 7-11 and 12-20),
    so the output of one stage is the input of the next.
    In the torchvision VGG19 layout these slices should end at relu1_1, relu2_1, relu3_1 and relu4_1.
    """

    def __init__(self):
        super().__init__()

        # Convolutional part of VGG19, loaded with torchvision's default pre-trained weights
        backbone = vgg19(weights='DEFAULT').features

        # Consecutive chunks of the backbone used as feature taps.
        # The attribute names double as state_dict keys, so they must stay as they are.
        self.slice1 = backbone[:2]
        self.slice2 = backbone[2:7]
        self.slice3 = backbone[7:12]
        self.slice4 = backbone[12:21]

        # The encoder is a fixed feature extractor: freeze every pre-trained weight
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, images, output_last_feature=False):
        """
        Run the images through the four stages in order.

        Args:
            images (Tensor): Input images to be encoded.
            output_last_feature (bool): When True only the output of the final stage is returned;
                otherwise the outputs of all four stages are returned.

        Returns:
            Tensor or Tuple[Tensor]: The last feature map, or the tuple (h1, h2, h3, h4) holding
            the feature map produced by each stage.
        """
        stage_outputs = []
        hidden = images

        # Each stage consumes the output of the previous one
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4):
            hidden = stage(hidden)
            stage_outputs.append(hidden)

        return stage_outputs[-1] if output_last_feature else tuple(stage_outputs)


# ADAIN

In [None]:
def calc_mean_std(features):
    """
    Compute per-sample, per-channel statistics over all remaining (spatial) positions.

    Args:
        features (Tensor): Input features of shape [batch_size, c, h, w].

    Returns:
        features_mean (Tensor): Channel-wise mean, shape [batch_size, c, 1, 1].
        features_std (Tensor): Channel-wise standard deviation plus a 1e-6 offset, shape [batch_size, c, 1, 1].
    """
    n, channels = features.size()[:2]

    # Collapse everything after the channel axis so the statistics are taken over one dimension
    flat = features.reshape(n, channels, -1)

    # keepdim + unsqueeze yields the [batch_size, c, 1, 1] layout that broadcasts against the features
    features_mean = flat.mean(dim=2, keepdim=True).unsqueeze(-1)

    # The small constant keeps the later division by the std away from zero
    features_std = flat.std(dim=2, keepdim=True).unsqueeze(-1) + 1e-6

    return features_mean, features_std


def adain(content_features, style_features):
    """
    Adaptive Instance Normalization (AdaIN): re-scale and re-shift the content features so that
    their channel-wise mean and standard deviation match those of the style features.

    Args:
        content_features (Tensor): Content features of shape [batch_size, c, h, w].
        style_features (Tensor): Style features of shape [batch_size, c, h, w].

    Returns:
        normalized_features (Tensor): Re-normalized content features of shape [batch_size, c, h, w].
    """
    c_mean, c_std = calc_mean_std(content_features)
    s_mean, s_std = calc_mean_std(style_features)

    # Remove the content mean, then swap the content scale for the style scale.
    # The multiply-then-divide order is kept so results stay numerically identical.
    centered = content_features - c_mean
    rescaled = s_std * centered / c_std

    # Shift to the style mean
    return rescaled + s_mean

# Decoder

In [None]:
class RC(torch.nn.Module):
    """
    A wrapper of ReflectionPad2d and Conv2d

    This class represents a combination of reflection padding and a convolutional layer.
    It applies reflection padding to the input and then performs convolution on the padded input.
    Optionally, it applies ReLU activation to the output of the convolution.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the convolution kernel. Default is 3.
        pad_size (int): Size of the reflection padding applied on every side. Default is 1.
        activated (bool): Whether to apply activation (ReLU) after convolution. Default is True.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1, activated=True):
        super().__init__()
        self.pad = nn.ReflectionPad2d((pad_size, pad_size, pad_size, pad_size))
        # The convolution itself is unpadded; all padding comes from the reflection layer above
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.activated = activated

    def forward(self, x):
        """
        Forward pass of the RC module.

        Args:
            x: Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            Output tensor after applying reflection padding, convolution and activation (if enabled).
            With the default kernel_size=3 and pad_size=1 the spatial size is preserved, i.e. the
            shape is (batch_size, out_channels, height, width).
        """
        h = self.pad(x)     # Apply reflection padding to the input tensor
        h = self.conv(h)    # Perform convolution on the padded input
        if self.activated:  # Apply ReLU activation if activated is True
            return F.relu(h)
        return h            # Otherwise, return the output without activation


class Decoder(nn.Module):
    """
    Decoder network for image reconstruction.

    This network takes features extracted by an encoder network with
    adaptive instance normalization applied using style features and generates a reconstructed image.
    It consists of nine RC (ReflectionPad2d + Conv2d) layers that reduce 512 input channels to
    3 output channels, interleaved with three 2x upsampling steps (8x larger spatially overall).
    """
    def __init__(self):
        super().__init__()
        self.rc1 = RC(512, 256, 3, 1)
        self.rc2 = RC(256, 256, 3, 1)
        self.rc3 = RC(256, 256, 3, 1)
        self.rc4 = RC(256, 256, 3, 1)
        self.rc5 = RC(256, 128, 3, 1)
        self.rc6 = RC(128, 128, 3, 1)
        self.rc7 = RC(128, 64, 3, 1)
        self.rc8 = RC(64, 64, 3, 1)
        self.rc9 = RC(64, 3, 3, 1, False)     # Final layer: no ReLU on the output image

    def forward(self, features):
        """
        Forward pass of the Decoder module.

        Args:
            features (torch.Tensor): Input features from the encoder module
                (512 channels, as expected by the first RC layer).

        Returns:
            torch.Tensor: Output tensor representing the reconstructed image (3 channels).
        """
        # Forward pass of the Decoder module for image upsampling and reconstruction
        h = self.rc1(features)
        h = F.interpolate(h, scale_factor=2)      # Perform upsampling using F.interpolate with scale factor 2
        h = self.rc2(h)
        h = self.rc3(h)
        h = self.rc4(h)
        h = self.rc5(h)
        h = F.interpolate(h, scale_factor=2)      # Perform another upsampling using F.interpolate with scale factor 2
        h = self.rc6(h)
        h = self.rc7(h)
        h = F.interpolate(h, scale_factor=2)      # Perform another upsampling using F.interpolate with scale factor 2
        h = self.rc8(h)
        h = self.rc9(h)
        return h

# Image Generation

In [None]:
class Model(nn.Module):
    """Style-transfer network: a VGG encoder, an AdaIN step and a decoder."""

    def __init__(self):
        super().__init__()
        self.vgg_encoder = VGGEncoder()       # Feature extractor shared by content and style images
        self.decoder = Decoder()              # Maps the blended features back to an image

    def generate(self, content_images, style_images, alpha=1.0):
        """
        Produce stylized images from content and style images.

        Both inputs are encoded with the VGG encoder (last feature map only), the content features
        are re-normalized with AdaIN so that their mean and variance match the style features, the
        result is blended with the untouched content features according to 'alpha', and the blend
        is decoded into an image.

        Args:
            content_images (torch.Tensor): A tensor of shape (batch_size, channels, height, width)
                representing the content image(s).
            style_images (torch.Tensor): A tensor of shape (batch_size, channels, height, width)
                representing the style image(s).
            alpha (float, optional): Strength of the style transfer, intended to lie between 0 and 1.
                With 0 only the content features are decoded; with 1 the fully AdaIN-aligned features
                are decoded. Default is 1.0.

        Returns:
            out (torch.Tensor): The decoder output for the blended features, i.e. the generated
                stylized image(s).
        """
        encoder = self.vgg_encoder

        # Only the deepest feature map of each image is needed
        content_features = encoder(content_images, output_last_feature=True)
        style_features = encoder(style_images, output_last_feature=True)

        # Content features carrying the style statistics
        stylized_features = adain(content_features, style_features)

        # Linear blend between fully stylized features (alpha=1) and plain content features (alpha=0)
        blended_features = alpha * stylized_features + (1 - alpha) * content_features

        # Reconstruct the image from the blended features
        return self.decoder(blended_features)

# Transforms

In [None]:
# ImageNet channel statistics as (mean, std)
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

def transforms(H, W):
    """
    Build the preprocessing pipeline applied to an image before it is fed to the model.

    The pipeline resizes the image to (H, W), converts it to a tensor and normalizes it
    with the mean and standard deviation stored in 'stats'.

    Args:
        H (int): The desired height of the transformed image.
        W (int): The desired width of the transformed image.

    Returns:
        torchvision.transforms.Compose: A composition of image transformations.

    Example:
        # Build a pipeline that produces 224x224 inputs and apply it to an image
        transform = transforms(224, 224)
        transformed_image = transform(image)
    """
    mean, std = stats

    steps = [
        T.Resize((H, W)),                      # Bring the image to the requested height and width
        T.ToTensor(),                          # Convert the image to a PyTorch tensor
        T.Normalize(mean, std, inplace=True),  # Normalize each channel with the statistics above
    ]

    return T.Compose(steps)


# Denormalization

In [None]:
def denorm(tensor):
    """
    De-normalize the input tensor using ImageNet statistics (mean and standard deviation).

    This function reverses a per-channel normalization done with the ImageNet mean and standard
    deviation, i.e. it computes ``tensor * std + mean`` channel by channel.

    Args:
        tensor (torch.Tensor): Normalized tensor with 3 channels, of shape [c, h, w] or
            [batch_size, c, h, w].

    Returns:
        torch.Tensor: De-normalized tensor with the same shape as the input tensor.
            Values are not clamped, so they may fall slightly outside [0, 1].

    Notes:
        The statistics are created on the same device as the input tensor, so the function
        works for CPU and GPU tensors alike and does not depend on any global device variable.

    Example:
        # Assume 'tensor' is a normalized tensor with ImageNet statistics
        denormalized_tensor = denorm(tensor)
    """
    # Build the statistics on the input's own device; reshaping to (3, 1, 1) lets them
    # broadcast over the spatial dimensions (and over a leading batch dimension, if present)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).reshape(-1, 1, 1)

    # De-normalize the input tensor by applying the reverse transformation
    res = tensor * std + mean

    # Optionally clamp the values to ensure they are within the valid range [0, 1]
    # res = torch.clamp(res, 0, 1)

    return res

# Choosing device

In [None]:
def get_default_device():
    """Return the CUDA device when a GPU is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

# Device used by the rest of the notebook for the model and its inputs
device = get_default_device()

# Image prediction

In [None]:
def predict_image(content_path, style_path, model, output_path, alpha=1.0):
    """
    Generates a stylized image by aligning the mean and variance of the
    content features with those of the style features using the provided model
    and saves the result to the output path.

    Args:
        content_path (str): The file path to the content image.
        style_path (str): The file path to the style image.
        model (nn.Module): The style-transfer model; it must expose
            ``generate(content_images, style_images, alpha)``.
        output_path (str): The file path of the output image WITHOUT extension; '.jpg' is appended.
            e.g. /content/output/starrynight_goldengate
        alpha (float): Controls the degree of stylization, with 0 being no style and 1 (default)
            being full style.

    Returns:
        None: The function saves the stylized image to f'{output_path}.jpg' on disk.

    Note:
        - Both images are converted to 3-channel RGB before processing, so grayscale, palette
          and RGBA files are accepted.
        - Content images taller than 2000 pixels are processed at 90% of their size; the result
          is always resized back to the original content size before saving.
        - The style image is processed at its own original size.
        - The model is switched to eval mode and run without gradient tracking.
    """

    # Clear GPU memory to avoid potential memory issues
    torch.cuda.empty_cache()

    # Load content and style images. convert('RGB') both forces three channels (required by the
    # 3-channel normalization) and loads the pixel data, so the files can be closed right away.
    with Image.open(content_path) as content_file:
        c = content_file.convert('RGB')
    with Image.open(style_path) as style_file:
        s = style_file.convert('RGB')

    # PIL reports sizes as (width, height); the transforms expect (height, width)
    og_size = (c.size[1], c.size[0])
    s_size = s.size                        # (width, height) of the style image

    # Converting images to tensors and normalizing them
    if og_size[0] <= 2000:
        tfms = transforms(int(og_size[0]), int(og_size[1]))
    else:
        # Very tall content images are shrunk to 90% before being processed
        tfms = transforms(int(og_size[0] * 0.9), int(og_size[1] * 0.9))
    c_tensor = tfms(c).unsqueeze(0).to(device)
    tfms = transforms(s_size[1], s_size[0])
    s_tensor = tfms(s).unsqueeze(0).to(device)

    # Generate stylized image using the model
    model.eval()
    with torch.no_grad():
        out = model.generate(c_tensor, s_tensor, alpha)
    img = out.squeeze(0)                   # Drop the batch dimension only

    # Denormalizing the image, restoring the original content size and saving it
    img = denorm(img).cpu().detach()
    save_image(resize(img, size=(og_size)), f'{output_path}.jpg', nrow=1)

    # Clear GPU memory again to release resources
    torch.cuda.empty_cache()

# File paths

In [None]:
# Creating a folder for outputs
!mkdir /content/output

mkdir: cannot create directory ‘/content/output’: File exists


In [None]:
# Loading the pre-trained model.
model = Model().to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU can also be restored on a CPU-only runtime
model.load_state_dict(torch.load('/content/drive/MyDrive/ADAIN/final.pth', map_location=device))

<All keys matched successfully>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

In [None]:
# Creating 9 images with different degrees of stylization: alpha = 0, 0.125, ..., 1.
# np.linspace states the number of steps and the end point explicitly instead of relying on a
# float-step np.arange with a padded stop value.
# NOTE(review): the content image used here differs from the one shown as 'Content' in the
# visualization cells below (pexels-pixabay-208745.jpg) - confirm which one is intended.
for alpha in np.linspace(0.0, 1.0, 9):
    alpha = float(alpha)  # plain Python float for the model; file names stay '0.0', '0.125', ..., '1.0'
    predict_image(content_path='/content/istockphoto-1359275011-170667a.jpg',
                  style_path='/content/starry-night-g89b7431a8_1920.jpg',
                  model=model, output_path=f'/content/output/{alpha}', alpha=alpha)

In [None]:
# Names (not full paths) of the generated images, in sorted order
names = sorted(os.listdir('/content/output'))

# Content image first, style image second, then every file name from the output folder
file = [
    '/content/pexels-pixabay-208745.jpg',
    '/content/starry-night-g89b7431a8_1920.jpg',
    *names,
]

# Visualizing the effect of alpha parameter

The function show_images(names) defined below displays a grid of images. It takes a list of image file paths as input and generates a grid of images with a title for each image.

The titles for the images are set based on their position in the grid. The first image is given the title 'Content', the second image is titled 'Style', and the rest of the images are given titles that indicate the corresponding value of 'alpha'. 'Alpha' controls the degree of stylization, and its values range from 0 to 1. The 'alpha' value is incremented by 0.125 for each successive stylized image.

After displaying all the images in the grid, the function saves the final grid of images as 'alpha.png' in the '/content/' directory.

In [None]:
def show_images(names):
    """
    Display a grid of images specified by the list of file names.

    Args:
        names (list of str): Images to display. The first two entries are full file paths
            (content image, then style image); every further entry is a file name inside
            '/content/output/'.

    Returns:
        None

    Note:
        - The function uses the matplotlib library for displaying images in a grid format.
        - The images are arranged in a grid with 3 columns and as many rows as needed to hold
          all images (20 x 20 inches per grid cell).
        - The first two images in the 'names' list are considered as the content and style images, respectively.
          The rest of the images (if any) are the stylized images corresponding to different values of 'alpha'.
          'alpha' determines the degree of stylization, ranging from 0 to 1.
        - The content and style images will have titles 'Content' and 'Style', respectively.
        - The stylized images are titled with their 'alpha' value, assumed to start at 0.0 and to
          grow by 0.125 per image in list order.
        - The final grid of images will be saved as '/content/alpha.png'.
    """

    # Grid layout: fixed number of columns, row count derived from the number of images
    # (ceiling division; at least one row so the figure never has zero height)
    ncols = 3
    nrows = max(1, -(-len(names) // ncols))
    alpha_step = 0.125
    fontsize = 60
    fig = plt.gcf()
    fig.set_size_inches(ncols * 20, nrows * 20)

    for i, name in enumerate(names):
        # The first two entries are full paths; the rest live in the output folder
        path = name if i < 2 else '/content/output/' + name
        img = mpimg.imread(path)

        # Set up subplot; subplot indices start at 1
        sp = plt.subplot(nrows, ncols, i + 1)
        sp.axis('Off')  # Don't show axes (or gridlines)
        sp.imshow(img)

        # Set titles for content, style, and stylized images
        if i == 0:
            title = 'Content'
        elif i == 1:
            title = 'Style'
        else:
            # Derive alpha from the position instead of accumulating it across iterations
            title = f'α = {str((i - 2) * alpha_step)}'
        sp.set_title(title, fontsize=fontsize)

    # Save the grid of images as 'alpha.png' with tight bounding box
    plt.savefig('/content/alpha.png', bbox_inches='tight')

In [None]:
# Render the grid: content image, style image, then the files found in the output folder
show_images(file)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
import numpy as np
import cv2

def _read_image(path):
    """Read an image with OpenCV and fail loudly when it cannot be loaded."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None instead of raising
        raise FileNotFoundError(f'Could not read image: {path}')
    return img

# read three images of different sizes
img1 = _read_image('/content/pexels-pixabay-208745.jpg')
img2 = _read_image('/content/starry-night-g89b7431a8_1920.jpg')
img3 = _read_image('/content/output/1.0.jpg')

# the canvas is as tall as the tallest image and as wide as all three side by side
h_max = max(img1.shape[0], img2.shape[0], img3.shape[0])
w_max = img1.shape[1] + img2.shape[1] + img3.shape[1]

# create an empty (black) array of the required size
img_concat = np.zeros((h_max, w_max, 3), dtype=np.uint8)

# copy the individual images into the appropriate locations in the array:
# left to right, each aligned to the top edge
img_concat[:img1.shape[0], :img1.shape[1], :] = img1
img_concat[:img2.shape[0], img1.shape[1]:img1.shape[1]+img2.shape[1], :] = img2
img_concat[:img3.shape[0], img1.shape[1]+img2.shape[1]:, :] = img3

# save the concatenated image (the cell output is cv2.imwrite's success flag)
cv2.imwrite('/content/img_concat.jpg', img_concat)


True