## Install VisionLens

In [None]:
# Start from an empty Kaggle working directory, clone VisionLens, and move the
# repository contents up one level so the `visionlens` package sits next to the
# notebook (presumably the cwd is /kaggle/working — confirm before running elsewhere).
# WARNING: the `rm -rf` deletes everything previously written to /kaggle/working.
# Note: the shell glob `*` skips dotfiles, so hidden files (e.g. `.git`) are
# neither removed by `rm` nor moved by `mv`.
!rm -rf /kaggle/working/*
!git clone https://github.com/SKT27182/VisionLens.git
!mv  /kaggle/working/VisionLens/* .

Cloning into 'VisionLens'...
remote: Enumerating objects: 216, done.
remote: Counting objects: 100% (216/216), done.
remote: Compressing objects: 100% (147/147), done.
remote: Total 216 (delta 110), reused 167 (delta 64), pack-reused 0 (from 0)
Receiving objects: 100% (216/216), 12.58 MiB | 17.60 MiB/s, done.
Resolving deltas: 100% (110/110), done.


In [None]:
!pip3 install einops==0.8.0

Collecting einops==0.8.0
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


## Imports

In [None]:
import torch
import numpy as np

from typing import Union

import einops
from matplotlib import pyplot as plt
from torchvision.transforms.functional import to_pil_image


from visionlens.models import InceptionV1
from visionlens.objectives import objective_wrapper
from visionlens.optimize import Visualizer
from visionlens import objectives
from visionlens.images import get_images
from visionlens.utils import device

In [None]:
# Pretrained InceptionV1 wrapped by VisionLens; `.eval()` returns the same module
# switched to inference mode, and rebinding keeps the cell from echoing it.
model = InceptionV1(pretrained=True)
model = model.eval()

## Style Transfer

In [None]:
class StyleTransfer:
    """Neural style transfer built on VisionLens objectives.

    The image being optimised, the content image and the style image are stacked
    into one batch at positions ``TRANSFER_INDEX``, ``CONTENT_INDEX`` and
    ``STYLE_INDEX``; per-layer activations are then indexed by those positions.
    The loss pulls the optimised image's activations towards the content image's
    activations on the content layers, and its Gram matrices towards the style
    image's Gram matrices on the style layers.
    """

    def __init__(self, model, content_image: Union[str, torch.Tensor], style_image: Union[str, torch.Tensor]):
        """
        Args:
            model: Model handed to ``Visualizer`` when optimising.
            content_image: Path to an image file, or an already-loaded tensor
                laid out as (H, W, C).
            style_image: Path to an image file, or an already-loaded tensor
                laid out as (H, W, C).
        """
        self.model = model
        self.content_image = content_image if isinstance(content_image, torch.Tensor) else self.load_image(content_image)
        self.style_image = style_image if isinstance(style_image, torch.Tensor) else self.load_image(style_image)

        # Batch positions used by ``get_style_transfer_params`` when stacking images.
        self.TRANSFER_INDEX = 0
        self.CONTENT_INDEX = 1
        self.STYLE_INDEX = 2

    def load_image(self, file_path: str) -> torch.Tensor:
        """
        Load an image file as a float32 tensor in [0, 1].

        ``plt.imread`` returns PNGs as floats in [0, 1] but other formats (e.g.
        JPEG) as integer arrays; integer data is rescaled by its dtype's maximum
        so every image lives on the same scale. Greyscale images are expanded to
        three channels and an alpha channel is dropped, so the result can be
        stacked with the (presumably RGB) optimised image.

        Args:
            file_path: Path to the image file.

        Returns:
            Float32 tensor of shape (H, W, 3) for greyscale, RGB and RGBA inputs.
        """
        img = np.asarray(plt.imread(file_path))

        if np.issubdtype(img.dtype, np.integer):
            scale = np.iinfo(img.dtype).max
            img = img.astype(np.float32) / scale

        if img.ndim == 2:
            # Greyscale (H, W) -> (H, W, 3).
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[-1] == 4:
            # RGBA -> RGB; the alpha channel carries no colour information.
            img = img[..., :3]

        return torch.tensor(img, dtype=torch.float32)

    def get_style_transfer_params(self, content_image, style_image, decorrelate=True, fft=True):
        """
        Build the image parameterisation and the batch-producing function.

        Args:
            content_image (torch.Tensor | numpy.ndarray): Content image, (H, W, C).
            style_image (torch.Tensor | numpy.ndarray): Style image, (H, W, C).
                It is cropped to the content size; if it is smaller than the
                content image in either dimension, the crop is bilinearly resized
                to the content size so the batch can be stacked.
            decorrelate (bool, optional): Forwarded to ``get_images``. Defaults to True.
            fft (bool, optional): Forwarded to ``get_images``. Defaults to True.

        Returns:
            tuple: ``(params, inner)`` where ``params`` are the optimisable
            parameters from ``get_images`` and ``inner()`` returns the stacked
            batch ``[optimised image, content image, style image]``, each (C, H, W).
        """
        content_h, content_w = content_image.shape[:2]
        params, image = get_images(content_h, content_w, decorrelate=decorrelate, fft=fft)

        # The content and style targets are constant during optimisation, so
        # build them once instead of re-copying them on every step.
        # torch.as_tensor accepts tensors and ndarrays without the copy-construct
        # warning that torch.tensor(<tensor>) emits.
        content_input = einops.rearrange(
            torch.as_tensor(content_image, dtype=torch.float32), "h w c -> c h w"
        ).to(device)

        style_input = einops.rearrange(
            torch.as_tensor(style_image, dtype=torch.float32)[:content_h, :content_w, :],
            "h w c -> c h w",
        )
        if tuple(style_input.shape[-2:]) != (content_h, content_w):
            # Style image smaller than the content image: resize so torch.stack works.
            style_input = torch.nn.functional.interpolate(
                style_input.unsqueeze(0),
                size=(content_h, content_w),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
        style_input = style_input.to(device)

        def inner():
            style_transfer_input = image()[0]
            return torch.stack([style_transfer_input, content_input, style_input])

        return params, inner

    @staticmethod
    def gram_matrix(features, normalize=True):
        """
        Compute the Gram matrix (channel-by-channel inner products) of a feature map.

        Args:
            features: Tensor of shape (..., C, H, W); any leading dims are kept.
            normalize: If True, divide by the number of spatial positions H * W.

        Returns:
            Tensor of shape (..., C, C).
        """
        height, width = features.shape[-2:]
        flat = features.flatten(start_dim=-2)  # (..., C, H*W)
        gram = flat @ flat.transpose(-1, -2)

        if normalize:
            gram = gram / (height * width)
        return gram

    @staticmethod
    def mean_L1_loss(x, y):
        """
        Compute the mean L1 loss between two tensors.

        Args:
            x (torch.Tensor): The first tensor.
            y (torch.Tensor): The second tensor.

        Returns:
            torch.Tensor: Scalar mean of ``|x - y|``.
        """
        return torch.mean(torch.abs(x - y))

    @objective_wrapper
    def get_activations_difference(
        self,
        layer_names,
        difference_to,
        loss_type=None,
        obj_name="activation_difference",
        transform_f=None,
    ):
        """
        Build an objective comparing the optimised image's activations to a reference image's.

        Args:
            layer_names: Names of the layers whose activations are compared.
            difference_to: Batch index of the reference image
                (``CONTENT_INDEX`` or ``STYLE_INDEX``).
            loss_type: Callable ``(optimised, reference) -> scalar tensor``;
                defaults to ``mean_L1_loss``.
            obj_name: Prefix of the objective name; ``"_activations_difference"``
                is appended to it.
            transform_f: Optional callable applied to every activation before
                comparison (e.g. ``gram_matrix`` for the style loss).

        Returns:
            ``(loss_fn, name)`` handed to ``objective_wrapper``; ``loss_fn`` takes
            the activation accessor and returns the loss summed over layers.
        """
        obj_name = f"{obj_name}_activations_difference"

        loss_type = StyleTransfer.mean_L1_loss if loss_type is None else loss_type

        def get_activation_loss(act_dict):
            # act_dict(layer) is indexed by batch position to pick one image.
            image_activations = [
                act_dict(layer_name)[difference_to] for layer_name in layer_names
            ]
            if transform_f is not None:
                image_activations = [transform_f(act) for act in image_activations]

            optimization_activations = [
                act_dict(layer_name)[self.TRANSFER_INDEX] for layer_name in layer_names
            ]
            if transform_f is not None:
                optimization_activations = [transform_f(act) for act in optimization_activations]

            losses = [
                loss_type(optimization_act, image_act)
                for optimization_act, image_act in zip(optimization_activations, image_activations)
            ]
            return torch.stack(losses).sum()

        return get_activation_loss, obj_name

    def style_transfer(
        self,
        style_layers,
        content_layers,
        content_weight=200,
        style_weight=1,
        decorrelate=True,
        fft=True,
        threshold=(5, 50, 512),
        lr=0.1,
        **kwargs,
    ):
        """
        Optimise an image matching the content image's content and the style image's style.

        Args:
            style_layers: Layers whose Gram matrices are matched to the style image.
            content_layers: Layers whose activations are matched to the content image.
            content_weight: Weight of the content loss term.
            style_weight: Weight of the style loss term.
            decorrelate: Forwarded to ``get_images``.
            fft: Forwarded to ``get_images``.
            threshold: Forwarded to ``Visualizer.visualize``.
            lr: Learning rate forwarded to ``Visualizer.visualize`` (default 0.1).
            **kwargs: Further keyword arguments forwarded to ``Visualizer.visualize``.

        Returns:
            Whatever ``Visualizer.visualize`` returns.
        """

        def param_f():
            return self.get_style_transfer_params(
                self.content_image, self.style_image, decorrelate=decorrelate, fft=fft
            )

        content_obj = self.get_activations_difference(
            content_layers, difference_to=self.CONTENT_INDEX, obj_name="content_loss"
        )
        style_obj = self.get_activations_difference(
            style_layers, transform_f=self.gram_matrix, difference_to=self.STYLE_INDEX, obj_name="style_loss"
        )

        objective = content_weight * content_obj + style_weight * style_obj

        viz = Visualizer(self.model, objective)
        images = viz.visualize(param_f, lr=lr, threshold=threshold, **kwargs)

        return images

In [None]:
# Layers whose Gram matrices define the style loss, and the layer whose raw
# activations define the content loss.
STYLE_LAYERS = ["conv2d2", "mixed3a", "mixed4a", "mixed4b", "mixed4c"]
CONTENT_LAYERS = ["mixed3b"]

# Big Ben painted in the style of Van Gogh.
content_image_pth = "images/transfer_big_ben.png"
style_image_pth = "images/transfer_vangogh.png"

style_transfer = StyleTransfer(model, content_image=content_image_pth, style_image=style_image_pth)
images = style_transfer.style_transfer(
    style_layers=STYLE_LAYERS,
    content_layers=CONTENT_LAYERS,
    content_weight=200,
    style_weight=1,
)

In [None]:
# Render the last entry of `images` returned by the Van Gogh run above.
from visionlens.img_utils import display_images_in_table
display_images_in_table(images[-1])

In [None]:
# Same layer choice as the Van Gogh run: Gram-matrix style layers plus a single
# content layer compared on raw activations.
STYLE_LAYERS = ["conv2d2", "mixed3a", "mixed4a", "mixed4b", "mixed4c"]
CONTENT_LAYERS = ["mixed3b"]

# Big Ben painted in the style of Picasso.
content_image_pth = "images/transfer_big_ben.png"
style_image_pth = "images/transfer_picasso.png"

style_transfer = StyleTransfer(model, content_image=content_image_pth, style_image=style_image_pth)
images = style_transfer.style_transfer(
    style_layers=STYLE_LAYERS,
    content_layers=CONTENT_LAYERS,
    content_weight=200,
    style_weight=1,
)