<a href="https://colab.research.google.com/github/Siddharth17542/project-sid/blob/main/image_style_transfer_playground_with_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
import PIL
import yaml
import torch
import torch.nn as nn
from torchvision.utils import save_image
from PIL import Image
from torchvision import transforms
from torchvision.models import vgg19
from torchvision.utils import save_image
from google.colab import files

### Upload your content and style images by running the cell below:

In [None]:
# Upload the content and style images through Colab's file-upload helper.
# The file names entered in the next cell must match the files uploaded here.
uploaded = files.upload()

Saving figures.jpg to figures.jpg
Saving mosaic.jpg to mosaic.jpg


In [None]:
#@title Specify the file names of the uploaded content image and style image:

#@markdown Enter the content image file name.
content_filename = "figures.jpg" #@param {type:"string"}

#@markdown Enter the style image file name.
style_filename = "mosaic.jpg" #@param {type:"string"}


In [None]:
#@title Some general settings for user experience:

#@markdown Check this to hide debugging messages and loss values during training, and to skip saving intermediate images.
quiet = True #@param {type:"boolean"}

#@markdown Check this to automatically download the output image after it is generated.
download = False #@param {type:"boolean"}

In [None]:
#@title Here are some settings you could adjust for the output image (leave a field blank to apply its default value):

#@markdown Size of the output image. Either one integer or two integers (height, width) separated by a comma are accepted. The dimensions of the content image are used if not provided.
output_size = "128" #@param {type:"string"}

#@markdown Format of the output image. Can be either "jpg", "png", "jpeg", or "same". If "same", the output image will have the same format as the content image. "jpg" is the default format.
output_image_format = "jpg" #@param {type:"string"}

You may also provide a training configuration file in YAML format to set custom hyperparameter values for the training process. The file may include any of the following hyperparameters:

- num_epochs
- learning_rate
- alpha
- beta
- capture_content_features_from
- capture_style_features_from

Providing a configuration file is optional: without one, the output image is synthesized using the default hyperparameter values. If you have a configuration file (it must be a .yaml file), run the following cell to upload it:

In [None]:
# Upload an optional YAML file holding training hyperparameters.
train_config_uploaded = files.upload()

# Start from an empty configuration so that a skipped or failed upload never
# leaves a stale configuration from an earlier run of this cell in place.
training_configuration = {}

if not train_config_uploaded:
    # NOTE(review): assumes files.upload() returns an empty mapping when the
    # dialog is dismissed without choosing a file -- confirm on Colab.
    print("No training configuration file uploaded; default hyperparameter values will be used.")
else:
    train_config_filename = next(iter(train_config_uploaded))

    print("Loading training configuration file...")
    try:
        with open(train_config_filename, 'r') as f:
            loaded_configuration = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"ERROR: could not find such file: '{train_config_filename}'.")
    except yaml.YAMLError:
        print(f"ERROR: fail to load yaml file: '{train_config_filename}'.")
    else:
        if loaded_configuration is None:
            # yaml.safe_load returns None for an empty document.
            print("Training configuration file is empty; default hyperparameter values will be used.")
        elif not isinstance(loaded_configuration, dict):
            print(f"ERROR: '{train_config_filename}' must contain a mapping of hyperparameter names to values.")
        else:
            training_configuration = loaded_configuration
            print("Training configuration file successfully loaded.")

Saving example_train_config.yaml to example_train_config (1).yaml
Loading training configuration file...
Training configuration file successfully loaded.


In [None]:
def load_image(image_path, device, output_size=None, normalize=False):
    """Load an image file as a batched RGB tensor on `device`.

    Args:
        image_path: Path of the image file to open with PIL.
        device: Device the resulting tensor is moved to.
        output_size: Target size of the image. ``None`` keeps the image's own
            dimensions, an ``int`` produces a square ``(size, size)`` image,
            and a 2-item tuple or list of ints is used as ``(height, width)``.
        normalize: Accepted for interface compatibility; currently unused.

    Returns:
        The image as a tensor with a leading batch dimension of size 1.

    Raises:
        ValueError: If `output_size` is not one of the accepted forms.
    """
    def _is_dimension(value):
        # bool is a subclass of int but is never a meaningful image dimension.
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    # Validate the requested size before touching the file so that a bad
    # argument fails fast instead of surfacing later inside the resize step.
    if output_size is None:
        output_dim = None
    elif _is_dimension(output_size):
        output_dim = (output_size, output_size)
    elif (
        isinstance(output_size, (tuple, list))
        and len(output_size) == 2
        and all(_is_dimension(item) for item in output_size)
    ):
        output_dim = tuple(output_size)
    else:
        raise ValueError("ERROR: output_size must be an integer or a 2-tuple of (height, width) if provided.")

    # Convert to RGB so grayscale / palette / alpha images yield three
    # channels; the converted copy stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')

    if output_dim is None:
        # PIL reports size as (width, height); Resize expects (height, width).
        output_dim = (rgb_img.size[1], rgb_img.size[0])

    torch_loader = transforms.Compose(
        [
            transforms.Resize(output_dim),
            transforms.ToTensor()
        ]
    )

    img_tensor = torch_loader(rgb_img).unsqueeze(0)
    return img_tensor.to(device)


def get_image_name_ext(img_path):
    """Return ``(name, extension)`` of the file at `img_path`.

    The directory part is dropped and the extension is returned without its
    leading dot (an empty string when the file has no extension).
    """
    stem, dotted_ext = os.path.splitext(os.path.basename(img_path))
    return stem, dotted_ext[1:]

In [None]:
class ImageStyleTransfer_VGG19(nn.Module):
    """Feature extractor built from the first 29 layers of a pre-trained VGG19.

    `chosen_features` maps layer indices inside the truncated network to the
    names under which their outputs are returned by `forward`.
    """

    def __init__(self):
        super().__init__()

        self.chosen_features = {0: 'conv11', 5: 'conv21', 10: 'conv31', 19: 'conv41', 28: 'conv51'}
        self.model = vgg19(weights='DEFAULT').features[:29]

        # The network is only used as a fixed feature extractor. Without this,
        # every backward pass would also compute (and keep accumulating)
        # gradients for all VGG weights although nothing ever updates them.
        self.model.requires_grad_(False)

    def forward(self, x):
        """Run `x` through the network and return ``{name: output}`` for the chosen layers."""
        feature_maps = {}
        for idx, layer in enumerate(self.model):
            x = layer(x)
            feature_name = self.chosen_features.get(idx)
            if feature_name is not None:
                # NOTE(review): torchvision's VGG is believed to use in-place
                # ReLU, in which case the tensor stored here is overwritten by
                # the following activation -- confirm before relying on
                # pre-activation values.
                feature_maps[feature_name] = x

        return feature_maps


def _get_content_loss(content_feature, generated_feature):
    """Compute MSE between content feature map and generated feature map as content loss."""
    return torch.mean((generated_feature - content_feature) ** 2)


def _get_style_loss(style_feature, generated_feature):
    """Compute MSE between gram matrix of style feature map and of generated feature map as style loss."""
    _, channel, height, width = generated_feature.shape
    style_gram = style_feature.view(channel, height*width).mm(
        style_feature.view(channel, height*width).t()
    )
    generated_gram = generated_feature.view(channel, height*width).mm(
        generated_feature.view(channel, height*width).t()
    )

    return torch.mean((generated_gram - style_gram) ** 2)


_FEATURE_LAYER_NAMES = frozenset({'conv11', 'conv21', 'conv31', 'conv41', 'conv51'})


def _read_hyperparameter(train_config, key, default, cast):
    """Return `train_config[key]` converted with `cast`, or `default` when the key is absent or None."""
    raw_value = train_config.get(key)
    if raw_value is None:
        return default
    try:
        # Configuration loaders may deliver numbers as strings (for example a
        # value written as 1e-3), so convert explicitly.
        return cast(raw_value)
    except (TypeError, ValueError):
        raise ValueError(f"'{key}' = {raw_value!r}") from None


def _resolve_feature_layers(train_config, key):
    """Return the validated set of layer names configured under `key`, or None (after printing an error) if invalid."""
    value = train_config.get(key)
    if value is None:
        return set(_FEATURE_LAYER_NAMES)

    layers = None
    if isinstance(value, dict):
        layers = set(value.keys())
    elif isinstance(value, str):
        layers = {item.strip() for item in value.split(',')}
    elif isinstance(value, (set, frozenset, list, tuple)):
        try:
            layers = set(value)
        except TypeError:
            # unhashable items cannot be layer names
            layers = None

    if layers is None or not layers.issubset(_FEATURE_LAYER_NAMES):
        shown = value if layers is None else layers
        print(f"ERROR: invalid value for '{key}' in training configuration file: {shown}.")
        return None
    return layers


def train(content, style, generated, device, train_config, output_dir, output_img_fmt, content_img_name, style_img_name, verbose=False):
    """Optimize `generated` in place so that it matches the content of `content` and the style of `style`.

    Args:
        content: Content image tensor.
        style: Style image tensor.
        generated: Tensor being optimized; it is handed to the Adam optimizer.
        device: Device the VGG19 feature extractor is moved to.
        train_config: Mapping with optional keys 'num_epochs', 'learning_rate',
            'alpha', 'beta', 'capture_content_features_from' and
            'capture_style_features_from'. Missing or None entries fall back to
            the defaults (6000, 0.001, 1, 0.01 and all five layers). The layer
            selections may be a set, list, tuple, dict (keys are used) or a
            comma-separated string.
        output_dir: Directory in which the intermediate-image folder is created
            when `verbose` is set.
        output_img_fmt: File extension used for intermediate images.
        content_img_name: Content image name used in output file names.
        style_img_name: Style image name used in output file names.
        verbose: When True, print the loss and save an intermediate image every
            200 epochs.

    Returns:
        1 on success, 0 when the training configuration is invalid.
    """
    train_config = train_config or {}

    # Validate the whole configuration before loading the model so that an
    # invalid file is reported without paying for the model setup.
    try:
        num_epochs = _read_hyperparameter(train_config, 'num_epochs', 6000, int)
        lr = _read_hyperparameter(train_config, 'learning_rate', 0.001, float)
        alpha = _read_hyperparameter(train_config, 'alpha', 1, float)
        beta = _read_hyperparameter(train_config, 'beta', 0.01, float)
    except ValueError as err:
        print(f"ERROR: invalid value for {err} in training configuration file.")
        return 0

    capture_content_features_from = _resolve_feature_layers(train_config, 'capture_content_features_from')
    if capture_content_features_from is None:
        return 0

    capture_style_features_from = _resolve_feature_layers(train_config, 'capture_style_features_from')
    if capture_style_features_from is None:
        return 0

    if not capture_content_features_from and not capture_style_features_from:
        # With nothing selected the loss would be a plain number and could not
        # be back-propagated.
        print("ERROR: 'capture_content_features_from' and 'capture_style_features_from' must not both be empty.")
        return 0

    # eval() only switches layers to inference behaviour; requires_grad_(False)
    # is what actually freezes the weights.
    model = ImageStyleTransfer_VGG19().to(device).eval()
    model.requires_grad_(False)

    optimizer = torch.optim.Adam([generated], lr=lr)

    if verbose:
        # create a directory to save intermediate results
        intermediate_dir = os.path.join(output_dir, f'nst-{content_img_name}-{style_img_name}-intermediate')
        os.makedirs(intermediate_dir, exist_ok=True)

    # The content and style images are never changed by the optimizer, so
    # their feature maps are constant: compute them once, without a graph.
    with torch.no_grad():
        content_features = model(content)
        style_features = model(style)

    for epoch in range(num_epochs):
        generated_features = model(generated)

        content_loss = style_loss = 0

        for layer_name, generated_feature in generated_features.items():
            # only evaluate the losses that actually contribute
            if layer_name in capture_content_features_from:
                content_loss = content_loss + _get_content_loss(content_features[layer_name], generated_feature)

            if layer_name in capture_style_features_from:
                style_loss = style_loss + _get_style_loss(style_features[layer_name], generated_feature)

        # compute loss
        total_loss = alpha * content_loss + beta * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # print loss value and save progress every 200 epochs
        if verbose and (epoch + 1) % 200 == 0:
            save_image(generated, os.path.join(intermediate_dir, f'nst-{content_img_name}-{style_img_name}-{epoch + 1}.{output_img_fmt}'))

            print(f"\tEpoch {epoch + 1}/{num_epochs}, loss = {total_loss.item()}")

    if verbose:
        print("\t================================")
        print(f"\tIntermediate images are saved in directory: '{intermediate_dir}'")
        print("\t================================")

    return 1


In [None]:
def _can_open_image(image_path):
    """Return True if PIL can open `image_path`; otherwise print an error and return False."""
    try:
        # Opened only to validate the file; the handle is closed right away.
        with Image.open(image_path):
            pass
    except FileNotFoundError:
        print(f"ERROR: could not find such file: '{image_path}'.")
        return False
    except PIL.UnidentifiedImageError:
        print(f"ERROR: could not identify image file: '{image_path}'.")
        return False
    return True


def _parse_output_size(raw_size):
    """Parse the size setting into None (blank), an int, or a (height, width) tuple.

    Raises:
        ValueError: If the text is not one or two positive integers separated by a comma.
    """
    text = str(raw_size).strip()
    if not text:
        return None

    dims = [int(item.strip()) for item in text.split(',')]
    if len(dims) not in (1, 2) or any(dim <= 0 for dim in dims):
        raise ValueError(text)
    return dims[0] if len(dims) == 1 else tuple(dims)


def main():
    """Run style transfer using the settings defined in the notebook cells.

    Reads the notebook-level names `content_filename`, `style_filename`,
    `quiet`, `download`, `output_size`, `output_image_format` and, when it
    exists, `training_configuration`. On success the result is written to
    'nst-<content>-<style>-final.<format>' in the working directory.
    """
    image_dir = output_dir = '.'
    content_path = os.path.join(image_dir, content_filename)
    style_path = os.path.join(image_dir, style_filename)

    verbose = not quiet

    if verbose:
        print("Loading content and style images...")

    if not (_can_open_image(content_path) and _can_open_image(style_path)):
        return

    os.makedirs(output_dir, exist_ok=True)

    # load content and style images
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `output_size` must only be READ here: assigning to it anywhere in this
    # function would make it a local name and turn this read into an
    # UnboundLocalError.
    try:
        output_dim = _parse_output_size(output_size)
    except ValueError:
        print(f"ERROR: invalid input for output_size: '{output_size}'. Should be one or two positive integers separated by comma.")
        return

    content_tensor = load_image(content_path, device, output_size=output_dim)
    # Load the style image at the content tensor's exact (height, width): the
    # losses compare feature maps of both images, which needs matching shapes
    # even when no explicit output size was given.
    content_dim = (content_tensor.shape[2], content_tensor.shape[3])
    style_tensor = load_image(style_path, device, output_size=content_dim)

    if verbose:
        print("Content and style images successfully loaded.")
        print()
        print("Initializing output image...")

    # initialize output image
    generated_tensor = content_tensor.clone().requires_grad_(True)

    if verbose:
        print("Output image successfully initialized.")
        print()

    # The configuration cell is optional, and an empty YAML file loads as None.
    loaded_config = globals().get('training_configuration')
    if loaded_config is None:
        train_config = dict()
    elif isinstance(loaded_config, dict):
        train_config = loaded_config.copy()
    else:
        print("ERROR: training configuration must be a mapping of hyperparameter names to values.")
        return

    if verbose:
        print("Training...")

    content_img_name, content_img_fmt = get_image_name_ext(content_path)
    style_img_name, _ = get_image_name_ext(style_path)

    output_img_fmt = output_image_format.strip().lower()
    if len(output_img_fmt) == 0:
        output_img_fmt = 'jpg'
    elif output_img_fmt == 'same':
        # fall back to the default when the content file has no extension
        output_img_fmt = content_img_fmt or 'jpg'
    elif output_img_fmt not in {'jpg', 'png', 'jpeg'}:
        print(f"ERROR: invalid input for output_img_fmt: {output_img_fmt}. Should be one of \"jpg\", \"png\", \"jpeg\", \"same\".")
        return

    # train model
    success = train(content_tensor, style_tensor, generated_tensor, device, train_config, output_dir, output_img_fmt, content_img_name, style_img_name, verbose=verbose)

    # Nothing was produced if training rejected the configuration, so neither
    # report success nor try to download a file that does not exist.
    if not success:
        return

    # save output image to specified directory
    final_path = os.path.join(output_dir, f'nst-{content_img_name}-{style_img_name}-final.{output_img_fmt}')
    save_image(generated_tensor, final_path)

    if verbose:
        print(f"Output image successfully generated as {final_path}.")

    # download final output image
    if download:
        files.download(final_path)

In [None]:
# Run the style transfer with the settings chosen in the cells above.
main()

UnboundLocalError: ignored