<a href="https://colab.research.google.com/github/MammadovN/Machine_Learning/blob/main/projects/03_deep_learning/neural_style_transfer/style_transfer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import copy
from google.colab import files
import io

# Run on CUDA when PyTorch reports a usable GPU; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# ⬆️ 1. Upload helper ----------------------------------------------------------
def upload_images():
    """Let the user upload two files, then determine which is content/style.

    A file whose lower-cased name contains "content" is taken as the content
    image; otherwise, one whose name contains "style" is taken as the style
    image. When several files match the same role, the last one wins. If
    either role is still unassigned, the user is asked to pick both images by
    number; replies that are not a listed number are rejected and the
    question is asked again.

    Returns
    -------
    tuple[str, str]
        A pair of paths: (content_path, style_path)

    Raises
    ------
    ValueError
        If the upload returned no files at all.
    """
    uploaded = files.upload()
    file_list = list(uploaded)
    if not file_list:
        # Fail early with a clear message instead of an IndexError later on.
        raise ValueError(
            "No files were uploaded; a content image and a style image are required."
        )

    content_path, style_path = None, None

    # Auto-detect by filename
    for filename in file_list:
        fname = filename.lower()
        if "content" in fname:
            content_path = filename
        elif "style" in fname:
            style_path = filename

    # Ask the user if auto-detection failed
    if content_path is None or style_path is None:
        print(
            "Please indicate which of the uploaded files is the CONTENT image "
            "and which is the STYLE image:"
        )
        for i, file in enumerate(file_list):
            print(f"{i}: {file}")

        content_idx = _ask_file_index("Number of the content image: ", len(file_list))
        style_idx = _ask_file_index("Number of the style image  : ", len(file_list))

        content_path = file_list[content_idx]
        style_path = file_list[style_idx]

    print(f"Content image: {content_path}")
    print(f"Style image  : {style_path}")

    return content_path, style_path


def _ask_file_index(prompt, count):
    """Prompt until the user enters an integer in ``range(count)``."""
    while True:
        reply = input(prompt).strip()
        try:
            index = int(reply)
        except ValueError:
            print(f"'{reply}' is not a whole number; enter a value from 0 to {count - 1}.")
            continue
        # Negative values are rejected so that e.g. -1 cannot silently select
        # the last file through Python's negative indexing.
        if 0 <= index < count:
            return index
        print(f"{index} is out of range; enter a value from 0 to {count - 1}.")


# ⬆️ 2. Image-loading / preprocessing -----------------------------------------
def load_image(img_path, max_size: int = 400, shape=None):
    """Load an image file and transform it into a pre-processed tensor.

    The image is converted to RGB, resized, turned into a tensor and
    normalised with the per-channel mean/std listed in the transform below.

    Parameters
    ----------
    img_path : str
        File path of the image to load.
    max_size : int, optional
        Maximum dimension (width or height). Images whose longer side exceeds
        this are scaled down, keeping the aspect ratio, so that the longer
        side equals `max_size`; smaller images keep their size.
    shape : tuple[int, int] | None, optional
        Force-resize to this (height, width) instead of using `max_size`.
        A plain int is handed to `transforms.Resize` unchanged.

    Returns
    -------
    torch.Tensor
        A 4-D tensor (1, 3, H, W) ready for VGG.
    """
    # Force exactly three colour channels up front. Without this, RGBA or
    # grayscale files reach Normalize with 4 or 1 channels and fail against
    # the 3-element mean/std. The context manager also closes the file handle.
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    # Decide new size
    if shape is not None:
        size = shape if isinstance(shape, int) else tuple(int(s) for s in shape)
    else:
        # An explicit (height, width) is passed because Resize(int) matches
        # the *shorter* edge, which would let the longer edge exceed max_size
        # and would upscale small images.
        width, height = image.size  # PIL reports (width, height)
        scale = min(1.0, max_size / max(width, height))
        size = (max(1, round(height * scale)), max(1, round(width * scale)))

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225)
        ),
    ])

    # Convert to tensor and add batch dimension
    return transform(image).unsqueeze(0)


# ⬆️ 3. Tensor → displayable image --------------------------------------------
def im_convert(tensor):
    """Undo the normalisation of a pre-processed tensor for display.

    The tensor is copied to the CPU, stripped of size-1 dimensions, moved to
    height x width x channel order, de-normalised with the per-channel
    std/mean below, and clipped to [0, 1].
    """
    channel_std = np.array((0.229, 0.224, 0.225))
    channel_mean = np.array((0.485, 0.456, 0.406))

    chw = tensor.to("cpu").clone().detach().numpy().squeeze()
    hwc = np.transpose(chw, (1, 2, 0))  # CHW -> HWC
    restored = hwc * channel_std + channel_mean
    return np.clip(restored, 0, 1)

In [None]:
class ContentStyleLoss(nn.Module):
    """
    Compute the total loss (content + style) for neural style transfer.
    Uses a pre-trained VGG-19 feature extractor.

    The VGG weights are frozen, so a backward pass only produces gradients
    for the image being optimised. The content/style targets are computed
    once, on the first forward call, and reused afterwards; they are rebuilt
    automatically if `style_img`, `content_img` or the layer lists are
    replaced.
    """
    def __init__(self, style_img, content_img,
                 content_weight: float = 1.0,
                 style_weight: float = 1_000_000.0):
        """
        Parameters
        ----------
        style_img, content_img : torch.Tensor
            Pre-processed image batches fed to the VGG feature extractor.
        content_weight, style_weight : float
            Multipliers applied to the content and style terms of the loss.
        """
        super().__init__()
        self.style_img     = style_img
        self.content_img   = content_img
        self.content_weight = content_weight
        self.style_weight   = style_weight

        # Use VGG-19 convolutional features. Newer torchvision releases
        # deprecate `pretrained=True` in favour of the weights enum; fall
        # back to the old keyword when the enum does not exist.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            vgg = models.vgg19(weights=weights_enum.DEFAULT)
        else:
            vgg = models.vgg19(pretrained=True)
        self.model = vgg.features.eval()

        # The network is only a fixed feature extractor: freezing it avoids
        # computing and accumulating unused weight gradients on every step.
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.content_layers = ['conv_4']
        self.style_layers   = [
            'conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'
        ]

        # Gram-matrix helper
        self.gram = GramMatrix()

        # Filled lazily by _get_targets() on the first forward pass, i.e.
        # after the module has been moved to its final device.
        self._target_cache = None

    def forward(self, input_img):
        """Return content_weight * content_loss + style_weight * style_loss."""
        content_targets, style_targets = self._get_targets()

        # Features of the current (generated) image
        input_features = self.get_features(input_img)

        # Content loss
        content_loss = 0.0
        for layer in self.content_layers:
            content_loss += F.mse_loss(
                input_features[layer],
                content_targets[layer]
            )

        # Style loss
        style_loss = 0.0
        for layer in self.style_layers:
            input_gram  = self.gram(input_features[layer])
            style_loss += F.mse_loss(input_gram, style_targets[layer])

        # Total loss = weighted sum
        total_loss = (
            self.content_weight * content_loss +
            self.style_weight   * style_loss
        )
        return total_loss

    def _get_targets(self):
        """Return (content feature targets, style Gram targets), cached.

        The targets depend only on the fixed content/style images, so they
        are computed once under `torch.no_grad()` instead of on every step.
        The cache is rebuilt when either image object or either layer list
        has been replaced since it was filled.
        """
        layers = (tuple(self.content_layers), tuple(self.style_layers))
        cache = self._target_cache
        stale = (
            cache is None
            or cache["style_img"] is not self.style_img
            or cache["content_img"] is not self.content_img
            or cache["layers"] != layers
        )
        if stale:
            with torch.no_grad():
                style_features   = self.get_features(self.style_img)
                content_features = self.get_features(self.content_img)
                content_targets = {
                    layer: content_features[layer]
                    for layer in self.content_layers
                }
                style_targets = {
                    layer: self.gram(style_features[layer])
                    for layer in self.style_layers
                }
            cache = {
                "style_img": self.style_img,
                "content_img": self.content_img,
                "layers": layers,
                "content": content_targets,
                "style": style_targets,
            }
            self._target_cache = cache
        return cache["content"], cache["style"]

    def get_features(self, image):
        """Extract selected layer activations from VGG-19.

        Runs `image` through every child of `self.model` in order and returns
        a dict mapping 'conv_1' ... 'conv_5' to the output of the children
        named '0', '5', '10', '19' and '28'.
        """
        features = {}
        layer_name_mapping = {
            '0':  'conv_1',
            '5':  'conv_2',
            '10': 'conv_3',
            '19': 'conv_4',
            '28': 'conv_5',
        }

        # NOTE(review): torchvision's VGG is believed to use in-place ReLU
        # layers, in which case the tensors stored here hold post-activation
        # values once the following layer has run. Confirm before stopping
        # this loop early or changing which tensor is captured.
        x = image
        for name, layer in self.model.named_children():
            x = layer(x)
            if name in layer_name_mapping:
                features[layer_name_mapping[name]] = x
        return features

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Gram-matrix helper
# ──────────────────────────────────────────────────────────────────────────────
class GramMatrix(nn.Module):
    """Batched Gram matrix of a 4-D feature tensor, scaled by its size."""

    def forward(self, input):
        n, channels, height, width = input.size()
        positions = height * width

        # Flatten the two trailing dimensions: (n, channels, positions)
        flat = input.view(n, channels, positions)

        # (n, channels, positions) x (n, positions, channels)
        #   -> (n, channels, channels)
        products = torch.bmm(flat, flat.permute(0, 2, 1))

        # Normalise by channels * height * width
        return products / (channels * positions)


# ──────────────────────────────────────────────────────────────────────────────
# Style-transfer routine
# ──────────────────────────────────────────────────────────────────────────────
def style_transfer(
    content_path,
    style_path,
    num_steps: int = 2_000,
    content_weight: float = 1.0,
    style_weight: float = 1_000_000.0,
):
    """Optimise a copy of the content image against a content + style loss.

    Loads both images (the style image is resized to the content image's
    height/width), then runs `num_steps` Adam updates (lr=0.003) on the
    pixels of the generated image. Every 200 steps the loss is printed and
    the content, style and generated images are plotted side by side.

    Returns the generated image tensor.
    """
    # Load images
    content = load_image(content_path).to(device)
    style = load_image(style_path, shape=content.shape[-2:]).to(device)

    # The image being optimised starts out as a copy of the content image
    generated = content.clone().requires_grad_(True).to(device)

    # Loss module and optimiser (only the generated image is optimised)
    criterion = ContentStyleLoss(style, content, content_weight, style_weight).to(device)
    adam = optim.Adam([generated], lr=0.003)

    # Progress bar (tqdm in notebooks, plain range otherwise). The tqdm(...)
    # call stays inside the try so an ImportError raised while constructing
    # the bar also falls back to the plain range.
    try:
        from tqdm.notebook import tqdm
        steps = tqdm(range(num_steps))
    except ImportError:
        steps = range(num_steps)

    # Training loop
    for step in steps:
        adam.zero_grad()
        loss = criterion(generated)
        loss.backward()
        adam.step()

        if step % 200 != 0:
            continue

        print(f"Step {step}, Loss: {loss.item():.4f}")

        # Show intermediate results: one panel per image, left to right
        panels = (
            ("Content", content),
            ("Style", style),
            (f"Generated (Step {step})", generated),
        )
        plt.figure(figsize=(15, 5))
        for position, (title, tensor) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(im_convert(tensor))
            plt.title(title)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    return generated

In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# Save the output image and download it from Colab
# ──────────────────────────────────────────────────────────────────────────────
def save_and_download_image(tensor, filename: str = "style_transfer_result.jpg"):
    """Render `tensor` as an image, save the figure and download it.

    The tensor is converted with `im_convert`, drawn on a 10x10-inch figure
    without axes, written to `filename`, shown, and then handed to Colab's
    `files.download`.
    """
    picture = im_convert(tensor)

    _, axes = plt.subplots(figsize=(10, 10))
    axes.imshow(picture)
    axes.axis("off")
    plt.savefig(filename, bbox_inches="tight", pad_inches=0.1)
    plt.show()

    # Download from Colab
    files.download(filename)
    print(f"Style transfer completed! The result was downloaded as '{filename}'.")


# ──────────────────────────────────────────────────────────────────────────────
# Top-level helper
# ──────────────────────────────────────────────────────────────────────────────
def run_style_transfer():
    """Interactive driver: upload images, ask for parameters, run, download.

    Each parameter prompt accepts an empty reply (the default shown in square
    brackets is used) or a plain number. A reply that cannot be parsed is
    reported and the question is asked again, so a typo does not abort the
    session after the images have already been uploaded.
    """
    print("Please upload one CONTENT image and one STYLE image.")
    print("(Tip: If the filenames contain the words 'content' and 'style', "
          "they will be detected automatically.)")

    content_path, style_path = upload_images()

    # Get parameters from the user
    print("\nAdjust style-transfer parameters:")
    num_steps = _prompt_number(
        "Number of iterations (suggested: 1000–3000) [2000]: ", 2000, int
    )
    content_weight = _prompt_number(
        "Content weight (suggested: 1) [1]: ", 1.0, float
    )
    style_weight = _prompt_number(
        "Style weight   (suggested: 1,000,000) [1000000]: ", 1_000_000.0, float
    )

    print("\nStarting style transfer …")
    result = style_transfer(
        content_path,
        style_path,
        num_steps=num_steps,
        content_weight=content_weight,
        style_weight=style_weight,
    )

    # Save and download the final image
    save_and_download_image(result)


def _prompt_number(prompt, default, cast):
    """Ask until the reply parses with `cast`; an empty reply yields `default`."""
    while True:
        reply = input(prompt).strip()
        if not reply:
            return default
        try:
            return cast(reply)
        except ValueError:
            # Separators are not guessed at: "1,5" could mean 1.5 or 15.
            print(
                f"Could not read '{reply}' as a number. Please enter plain "
                "digits without thousands separators (e.g. 1000000)."
            )


# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────
# Print the banner (title, rule, one-line description), then start the
# interactive workflow.
print("Neural Style Transfer — Google Colab", "=" * 40, sep="\n")
print("This program recreates a CONTENT image in the STYLE of another image.\n")
run_style_transfer()
