### **Imports**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

In [None]:
# Use the GPU when one is available and make it the default device, so that
# tensors created later without an explicit `device=` are allocated there.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
torch.set_default_device(device)

In [None]:
# Mount the user's Google Drive (Colab only). Input images and all results
# are read from / written to folders below `root_dir` on the mounted drive.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')
root_dir = '/content/gdrive/MyDrive/NN for Images/Ex4'

In [None]:
# Reproducibility: seed both Python's `random` module and PyTorch with the
# same fixed value.
import random

manualSeed = 999
random.seed(manualSeed)  # Python-level RNG
torch.manual_seed(manualSeed)  # PyTorch RNG

W & B import

In [None]:
# API key: 62bc107823f2e5d51fd5f21b1081d81dd76a07db
!pip install wandb -qU
import wandb
wandb.login(key="62bc107823f2e5d51fd5f21b1081d81dd76a07db")

# **Some Constants, Operators and Functions**

### Constants

In [None]:
# use small size if no GPU to save time
imsize = 512 if torch.cuda.is_available() else 128

# Import the feature-extraction part of VGG-19 (`.features`) in eval mode.
# `pretrained=True` is deprecated since torchvision 0.13; the weights enum
# selects the same ImageNet checkpoint that `pretrained=True` used to load.
# The network is moved to `device` explicitly instead of relying on the
# global default device.
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(device).eval()

# Per-channel mean of the data VGG was pre-trained on
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])

# Per-channel STD of the data VGG was pre-trained on
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

# List of layers to compute the Content-loss on.
# Names follow get_style_model_and_losses: 'conv_k', 'relu_k', 'pool_k', 'bn_k'
content_layers_default = ['conv_4']

# List of layers to compute the Style-loss on (same naming scheme)
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']

### Image Utilities

In [None]:
# Image -> tensor pipeline used by image_loader: resize (an int size makes the
# shorter image side equal `imsize`), then convert to a float CHW tensor.
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

# Tensor -> PIL image, the inverse direction (used for plotting)
unloader = transforms.ToPILImage()

# Load an image into a tensor
def image_loader(im_path):
    """Load the image at ``im_path`` as a ``(1, 3, H, W)`` float tensor on ``device``.

    The image is converted to RGB first, so grayscale, palette or RGBA files
    (e.g. PNGs with an alpha channel) still yield the 3 channels that the
    VGG normalization and the loss modules expect. The file handle is closed
    as soon as the pixels have been read.
    """
    with Image.open(im_path) as image:
        tensor = loader(image.convert("RGB")).unsqueeze(0)
    return tensor.to(device, torch.float)

# Plot an image from a tensor
def imshow(tensor, title=None):
    """Draw an image tensor on the current matplotlib axes.

    The tensor is copied to the CPU and its leading dimension of size 1 (the
    batch dimension) is squeezed away before conversion to a PIL image, so
    the caller's tensor is left untouched.
    """
    pil_image = unloader(tensor.cpu().clone().squeeze(0))
    plt.imshow(pil_image)
    if title is None:
        return
    plt.title(title)

def show_and_save_image(plt, img, title, path):
    """Plot ``img`` in a new figure titled ``title`` and save it as ``<path>/<title>.jpg``.

    Args:
        plt: the ``matplotlib.pyplot`` module (passed in by the callers).
        img: image tensor accepted by ``imshow``.
        title: figure title, also used as the file name.
        path: output directory. It is created when missing, so writing into a
            fresh results folder no longer makes ``savefig`` fail.
    """
    import os

    os.makedirs(path, exist_ok=True)
    plt.figure()
    imshow(img, title=title)
    plt.axis('off')
    plt.savefig(os.path.join(path, "{}.jpg".format(title)))

# crop two images to fit their dimensions
def trim_images_to_fit_sizes(im1, im2):
    """Center-crop two ``(N, C, H, W)`` tensors to their common spatial size.

    The output height is ``min(H1, H2)`` and the output width ``min(W1, W2)``.
    Each image is cropped around its centre; when the surplus is odd, the
    extra row/column is removed from the bottom/right.

    Returns:
        tuple: ``(cropped_im1, cropped_im2)``, slicing views of the inputs.
    """
    out_h = min(im1.shape[2], im2.shape[2])
    out_w = min(im1.shape[3], im2.shape[3])

    def _center_crop(img):
        # Top/left offsets split the surplus evenly (floor division).
        top = (img.shape[2] - out_h) // 2
        left = (img.shape[3] - out_w) // 2
        return img[:, :, top:top + out_h, left:left + out_w]

    return _center_crop(im1), _center_crop(im2)

# create an input image: either a copy of content_img or noise of same size
def get_input_image(content_img, noise=False):
    """Return the starting image for the optimization.

    Args:
        content_img: reference image tensor.
        noise: if True, return standard-normal noise shaped like
            ``content_img``; otherwise return a copy of ``content_img``.

    The result is always a fresh leaf tensor with ``content_img``'s shape,
    dtype and device (``randn_like`` rather than the global default device).
    It is detached from any autograd graph, so callers can safely call
    ``requires_grad_(True)`` on it.
    """
    if noise:
        return torch.randn_like(content_img)
    return content_img.detach().clone()

### Gram Matrix & Mean Var Calculators


In [None]:
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D activation tensor.

    ``input`` is unpacked as ``(batch, channels, height, width)``. Every
    feature map is flattened into a row, and the result is the matrix of
    pairwise row inner products, ``(batch*channels, batch*channels)``. It is
    divided by the total number of elements so the scale does not grow with
    the layer size.
    """
    batch, channels, height, width = input.size()
    flat = input.view(batch * channels, height * width)
    return (flat @ flat.T) / (batch * channels * height * width)

In [None]:
def meanVar(input):
    """Return per-feature-map statistics of a 4-D activation tensor.

    ``input`` is unpacked as ``(batch, maps, height, width)``. Each of the
    ``batch*maps`` feature maps is flattened. The result is a 1-D tensor:
    all the means, followed by all the (unbiased, the ``var`` default)
    variances.
    """
    n_batch, n_maps, h, w = input.size()
    per_map = input.view(n_batch * n_maps, h * w)
    # TODO: decide whether these statistics need normalizing.
    return torch.cat((per_map.mean(1), per_map.var(1)))

### Some Sub-Modules
Those are sub-modules that will help to build the complete Style-Loss.

***Content Loss***: Initialized with a target, and computes the MSE loss between the target and a given input.

***Style Loss***: Initialized with a target, and computes the MSE between the Gram matrices of the target and a given input.

In both modules, the target and the input should be pure activation-layer outputs.

***Normalization***: Initialized with a Mean and Std, and normalizes a given input by those values.

In [None]:
# Defining the Content-loss module
class ContentLoss(nn.Module):
    """Pass-through layer that records the MSE between its input and a fixed target.

    ``forward`` stores the loss in ``self.loss`` and returns the input
    unchanged, so the module can be inserted anywhere in a Sequential model.
    """

    def __init__(self, target):
        super().__init__()
        # Detached: the target is a constant, gradients never flow into it.
        self.target = target.detach()

    def forward(self, input):
        mse = F.mse_loss(input, self.target)
        self.loss = mse
        return input

# Defining the Style-loss module based on Gram matrix
class StyleLoss(nn.Module):
    """Pass-through layer recording the MSE between Gram matrices of the input and a style target.

    The target's Gram matrix is computed once at construction. ``forward``
    stores the loss in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input

# Defining the Style-loss module based on MeanVar matrix
class MeanVarLoss(nn.Module):
    """Pass-through layer recording the MSE between per-map mean/variance statistics.

    The statistics of the target (see ``meanVar``) are computed once at
    construction. ``forward`` stores the loss in ``self.loss`` and returns
    the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_stats = meanVar(target_feature)
        self.target = target_stats.detach()

    def forward(self, input):
        self.loss = F.mse_loss(meanVar(input), self.target)
        return input

# Defining the normalization module
class Normalization(nn.Module):
    """Normalize an image tensor channel-wise: ``(img - mean) / std``.

    ``mean`` and ``std`` may be sequences or tensors with one value per
    channel. They are reshaped to ``(C, 1, 1)`` so they broadcast over
    ``(N, C, H, W)`` (or ``(C, H, W)``) inputs.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # as_tensor(...).detach().clone() accepts lists *and* tensors and gives
        # the module its own copy, without the "copy construct from a tensor"
        # UserWarning that torch.tensor(tensor) emits. Registering them as
        # non-persistent buffers makes `.to(device)` move them along with the
        # module while leaving state_dict unchanged.
        self.register_buffer(
            "mean", torch.as_tensor(mean).detach().clone().view(-1, 1, 1), persistent=False)
        self.register_buffer(
            "std", torch.as_tensor(std).detach().clone().view(-1, 1, 1), persistent=False)

    def forward(self, img):
        return (img - self.mean) / self.std

# Normalization instance built from the statistics of VGG's pre-training data;
# it is the default `normalization` argument of the model builders below.
VggNormalization = Normalization(mean=cnn_normalization_mean, std=cnn_normalization_std)

### Style-Loss Module Definition
This function builds a Style-Loss model.

The Style-Loss model copies the architecture of a given net (VGG in our case) up to some point, and computes the content loss and the style loss at particular given layers. It receives:

***Content-layers*** List of `cnn`'s layers on which to compute the Content Loss with respect to `content_img`. Each such layer is followed by a `ContentLoss` module whose target is the output of the current Style-Loss model (up to this layer) on `content_img`. These modules are responsible for ***maintaining the required content*** on the fly.

 ***Style-layers*** List of `cnn`'s layers on which to compute the Style Loss with respect to `style_img`. Each such layer is followed by a `StyleLoss`/`MeanVarLoss` module whose target is the output of the current Style-Loss model (up to this layer) on `style_img`; the Gram matrices (or the mean/variance statistics) are computed inside the module itself. These modules are responsible for ***maintaining the required style*** on the fly.

 So, given a network *cnn*, list of *Content-layers* and *Style-layers*, the final architecture of the Style-Loss Module is as follows (Let M be the model):

 1. `Normalization` (default is the VggNormalization)
 2. For **L** in *cnn [first layer : last given layer]*:

  * `L`
  * if L is in Content-layers => add `ContentLoss(target=M(content_img))` layer
  * if L is in Style-layers => add `StyleLoss(target=M(style_img))` layer (or `MeanVarLoss`)





In [None]:
def get_style_model_and_losses(cnn, style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default,
                               normalization=VggNormalization,
                               use_mean_var=False):
    """Build the loss model: ``normalization`` + a prefix of ``cnn`` with loss probes.

    The children of ``cnn`` are appended one by one (the same module objects,
    except ReLUs, which are replaced). They are named ``conv_k``, ``relu_k``,
    ``pool_k`` or ``bn_k``, where ``k`` counts the Conv2d layers seen so far.

    - After a layer whose name is in ``content_layers``, a ``ContentLoss`` is
      inserted. Its target is the current model output on ``content_img``.
    - After one in ``style_layers``, a ``StyleLoss`` is inserted, or a
      ``MeanVarLoss`` when ``use_mean_var`` is True. Its target is computed
      from ``style_img``.
    - Layers after the last loss module are dropped.

    ``content_img`` / ``style_img`` are only evaluated when a requested layer
    name matches, so ``None`` may be passed for an unused one.

    Returns:
        tuple: ``(style_loss_model, style_losses, content_losses)``.

    Raises:
        RuntimeError: if ``cnn`` contains a layer other than
            Conv2d / ReLU / MaxPool2d / BatchNorm2d.
    """
    content_losses, style_losses = [], []
    loss_types = (ContentLoss, StyleLoss, MeanVarLoss)

    # Initializes Sequential model with normalization
    style_loss_model = nn.Sequential(normalization)

    i = 0  # increment every time we see a conv
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # An in-place ReLU would modify the tensor that the preceding loss
            # module just used, so use an out-of-place one.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        # add the layer to the model
        style_loss_model.add_module(name, layer)

        # Loss modules are named after the full layer name (e.g.
        # "content_loss_conv_4"). With only the conv index, requesting both
        # conv_4 and relu_4 would reuse one key, and add_module would
        # silently replace the first loss module, which then never runs.
        if name in content_layers:
            # Targets are constants: no autograd graph needs to be recorded.
            with torch.no_grad():
                target = style_loss_model(content_img)
            content_loss = ContentLoss(target)
            style_loss_model.add_module("content_loss_{}".format(name), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = style_loss_model(style_img)
            style_loss = MeanVarLoss(target_feature) if use_mean_var else StyleLoss(target_feature)
            style_loss_model.add_module("style_loss_{}".format(name), style_loss)
            style_losses.append(style_loss)

    # Trim off the layers after the last content/style loss module
    # (only the normalization is kept when there is no loss module at all).
    last = 0
    for idx in range(len(style_loss_model) - 1, -1, -1):
        if isinstance(style_loss_model[idx], loss_types):
            last = idx
            break

    return style_loss_model[:last + 1], style_losses, content_losses

# **Style Transfer Basic Algorithm**
Arguments:


*   ***cnn***: A convolutional neural network whose architecture is copied. In our case it is VGG-19.
*   ***content_img***: The image whose style we want to change
* ***style_img***: The image that contains the style we want to apply
* ***input_img***: The image we optimize, this will be the result image. Can either be a copy of *content_image* or noise
* ***content_layers***: The layers to perform ContentLoss on
* ***style_layers***: The layers to perform StyleLoss on
* ***normalization***: Normalization module to fit the data to the cnn



In [None]:
#@title Original Algorithm
def run_style_transfer(cnn, content_img, style_img, input_img,
                       content_layers=content_layers_default,
                       style_layers=style_layers_default,
                       normalization=VggNormalization,
                       num_steps=300,
                       style_weight=1000000,
                       content_weight=1,
                       content_converg=False,
                       use_mean_var=False):
    """Optimize ``input_img`` with L-BFGS so it mixes the content and the style.

    Args:
        cnn: feature network whose layers are used to build the loss model.
        content_img, style_img: reference images.
        input_img: image being optimized; modified in place and returned.
        content_layers, style_layers: layer names to attach losses to.
        normalization: first module of the loss model.
        num_steps: keep stepping while at most ``num_steps`` closure
            evaluations have run (ignored when ``content_converg`` is True).
        style_weight, content_weight: loss weights.
        content_converg: if True, only the convergence check below stops
            the loop.
        use_mean_var: use ``MeanVarLoss`` instead of the Gram-matrix
            ``StyleLoss`` for the style layers.

    After every L-BFGS step the loop also stops when the weighted content loss
    changed by less than 1e-4 between the last two closure evaluations, or
    after more than 2000 evaluations. This check runs in both modes.

    Returns:
        ``input_img``, clamped to [0, 1].
    """
    print('Building the style transfer Style-loss model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, style_img, content_img, content_layers, style_layers,
        normalization, use_mean_var=use_mean_var)

    # we want to optimize the input image and not the model's parameters
    input_img.requires_grad_(True)
    model.eval()
    model.requires_grad_(False)

    optimizer = optim.LBFGS([input_img])

    # this allows us to iterate until content loss converges
    condition = lambda x: x <= num_steps if not content_converg else True

    print('Optimizing..')
    run = [0]  # closure-evaluation counter (a list so the closure can mutate it)
    last_two_content_losses = [0, 1000]
    while condition(run[0]):

        def closure():
            # correct the values of updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)  # populates .loss on every loss module
            style_score = 0
            content_score = 0

            for sl in style_losses:
                style_score += sl.loss
            for cl in content_losses:
                content_score += cl.loss

            style_score *= style_weight
            content_score *= content_weight
            # float() also works when there are no content layers (score is 0)
            last_two_content_losses[0] = last_two_content_losses[1]
            last_two_content_losses[1] = float(content_score)

            loss = style_score + content_score
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                # print the counter value, not the list wrapper ("run [50]:")
                print("run {}:".format(run[0]))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                    float(style_score), float(content_score)))
                print()

            return loss

        optimizer.step(closure)

        # check if content loss has converged and stop if need to.
        # limited to 2000 runs
        diff = abs(last_two_content_losses[1] - last_two_content_losses[0])
        if diff < 0.0001 or run[0] > 2000:
            print("Convergence at {} with diff {}".format(run[0], diff))
            break

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

In [None]:
#@title A Simple Style Transfer Example
def style_transfer_example():
    """Run the basic (Gram-matrix) style transfer of dancing.jpg in Picasso's style.

    Shows the content, style and initial (noise) images, each in its own
    figure, then saves the stylized result under ``<root_dir>/Results``.
    """
    # load a style reference and a content image
    style_img = image_loader(root_dir + r"/Images/Style/picasso.jpg")
    content_img = image_loader(root_dir + r"/Images/Content/dancing.jpg")
    content_img_name = "dancing"
    style_img_name = "picasso"

    # fit their dimensions
    if style_img.size() != content_img.size():
        style_img, content_img = trim_images_to_fit_sizes(style_img, content_img)

    # get input image: clone or noise
    input_img = get_input_image(content_img, noise=True)

    # imshow draws on the *current* axes. Without a new figure per preview,
    # the three images were drawn on top of each other and only the last
    # one was visible.
    for img, img_title in ((content_img, "Content"), (style_img, "Style"), (input_img, "Input")):
        plt.figure()
        imshow(img, img_title)

    # run the style transfer algorithm
    output = run_style_transfer(cnn, content_img, style_img, input_img, content_weight=3, content_converg=True)

    # plot the results
    show_and_save_image(plt, output, "{} in {} style".format(content_img_name,
                                      style_img_name), path=root_dir+r"/Results")

In [None]:
# Run the basic Gram-matrix style-transfer demo (reads its images from root_dir on the mounted Drive).
style_transfer_example()

# **Part 1**

## **A. Feature inversion**

# Modify the algorithm to invert features from different stages of the VGG-19 network (which is the pretrained feature-extraction network used by the algorithm). More specifically, given a target image $x$, record its feature map $F_l(x)$ at the output of one of the layers $l$ of VGG-19. Next, initialize an input image $\hat{x}$ to noise, and optimize it until its feature map for that layer, $F_l(\hat{x})$, closely approximates $F_l(x)$. Experiment with inversion of layers of different types and at different depths of VGG-19.

**Implementation**: We build a Content-Loss model using the parameters:

1. `content_layers` = [L], where L is a layer of VGG-19. We investigate convolutional and ReLU layers at different depths.
2. `style_layers` = []. This way, the model won't make any style comparisons.
3. `input_img` = Noise

Apart from the content weight ($10^6$) and the number of optimization steps (500), we used the default hyperparameter values for the learning loop, and recorded the losses for all the different layers. The results are in the PDF file.

In [None]:
#@title Part1_A Algorithm (w/o style loss)
def Part1_A_algo(cnn, content_img, style_img, input_img,
                       content_layers=content_layers_default,
                       style_layers=style_layers_default,
                       normalization=VggNormalization,
                       num_steps=300,
                       style_weight=1000000,
                       content_weight=1):
    """Feature inversion: make ``input_img``'s activations match ``content_img``'s.

    Only the content losses at ``content_layers`` are optimized (L-BFGS).
    ``style_weight`` is accepted for signature compatibility with
    ``run_style_transfer`` but is unused; pass ``style_layers=[]`` so no
    style losses are built. The weighted content loss of every closure
    evaluation is logged to the active W&B run.

    Returns:
        ``input_img``, optimized in place and clamped to [0, 1].
    """
    print('Content layer {}'.format(content_layers[0]))
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
                                      style_img, content_img, content_layers,
                                      style_layers, normalization)

    # we want to optimize the input image and not the model's parameters
    input_img.requires_grad_(True)
    model.eval()
    model.requires_grad_(False)

    optimizer = optim.LBFGS([input_img])

    print('Optimizing..')
    run = [0]  # closure-evaluation counter (a list so the closure can mutate it)
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)  # populates .loss on every loss module
            content_score = 0

            for cl in content_losses:
                content_score += cl.loss
            content_score *= content_weight

            loss = content_score
            loss.backward()

            # Plot the loss
            metrics = {"Content_loss": content_score.item()}
            wandb.log(metrics)

            run[0] += 1
            if run[0] % 50 == 0:
                # print the counter value, not the list wrapper ("run [50]:")
                print("run {}:".format(run[0]))
                print('Content Loss: {:4f}'.format(content_score.item()))
                print()

            return content_score

        optimizer.step(closure)

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

In [None]:
#@title Question Implementation

def Part1_A():
    """Part 1A: invert VGG-19 features one layer at a time.

    For each of conv_1..conv_13 and relu_1..relu_13, start from noise and
    optimize it so its activations at that single layer match those of the
    content image. The loss curve is logged to a dedicated W&B run, and the
    resulting image is saved under ``<root_dir>/Part1/A``.
    """
    out_dir = root_dir + r"/Part1/A"
    depths = range(1, 14)  # the first 13 conv and the first 13 ReLU layers
    candidate_layers = ["conv_%d" % k for k in depths] + ["relu_%d" % k for k in depths]

    content_img = image_loader(root_dir + r"/Images/Content/dancing.jpg")

    for layer in candidate_layers:
        input_img = get_input_image(content_img, noise=True)
        layer_type, layer_depth = layer.split("_")  # e.g. ("conv", "12")
        wandb.init(
            project="Ex4_Part1_A",
            config={'name': layer, 'layer': layer, 'type': layer_type, 'depth': layer_depth},
        )
        config = wandb.config

        # Content loss only: no style layers, and a large content weight.
        output = Part1_A_algo(cnn, content_img, None, input_img,
                              content_layers=[config.layer], style_layers=[],
                              style_weight=0, content_weight=1000000, num_steps=500)
        wandb.finish()

        # save the inverted image
        show_and_save_image(plt, output, layer, out_dir)

In [None]:
# Part 1A: feature inversion for each conv/ReLU layer (loss curves go to W&B).
Part1_A()

## **B. Texture synthesis**

# modify the algorithm to generate a texture inspired by the style input, and demonstrate how using Gram matrices from different network layers affects the characteristics of the generated texture.


**Implementation**: We build a Style-Loss model using the parameters:

1. `content_layers` = [] This way, the model won't make any content comparisons.
2. `style_layers` = [L], where L is a layer of VGG-19. We investigate the Gram matrices of convolutional and ReLU layers at different depths.
3. `input_img` = Noise

Apart from the number of optimization steps (500), we used the default hyperparameter values for the learning loop, and recorded the losses for all the different layers. The results are in the PDF file.

In [None]:
#@title Part1_B Algorithm (w/o content loss)
def Part1_B_algo(cnn, content_img, style_img, input_img,
                       content_layers=content_layers_default,
                       style_layers=style_layers_default,
                       normalization=VggNormalization,
                       num_steps=300,
                       style_weight=1000000,
                       content_weight=1):
    """Texture synthesis: match ``input_img``'s style statistics to ``style_img``'s.

    Only the style losses at ``style_layers`` are optimized (L-BFGS).
    ``content_weight`` is accepted for signature compatibility with
    ``run_style_transfer`` but is unused; pass ``content_layers=[]`` so no
    content losses are built. The weighted style loss of every closure
    evaluation is logged to the active W&B run.

    Returns:
        ``input_img``, optimized in place and clamped to [0, 1].
    """
    print('style layer {}'.format(style_layers[0]))
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
                                      style_img, content_img, content_layers,
                                      style_layers, normalization)

    # we want to optimize the input image and not the model's parameters
    input_img.requires_grad_(True)
    model.eval()
    model.requires_grad_(False)

    optimizer = optim.LBFGS([input_img])

    print('Optimizing..')
    run = [0]  # closure-evaluation counter (a list so the closure can mutate it)
    while run[0] <= num_steps:

        def closure():
            # correct the values of updated input image
            with torch.no_grad():
                input_img.clamp_(0, 1)

            optimizer.zero_grad()
            model(input_img)  # populates .loss on every loss module
            style_score = 0

            for sl in style_losses:
                style_score += sl.loss

            style_score *= style_weight

            loss = style_score
            loss.backward()

            # Plot the loss
            metrics = {"Style_loss": style_score.item()}
            wandb.log(metrics)

            run[0] += 1
            if run[0] % 50 == 0:
                # print the counter value, not the list wrapper ("run [50]:")
                print("run {}:".format(run[0]))
                print('Style Loss: {:4f}'.format(style_score.item()))
                print()

            return style_score

        optimizer.step(closure)

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

In [None]:
#@title Question Implementation
def Part1_B():
    """Part 1B: synthesize textures from single-layer Gram-matrix style losses.

    For each of conv_1..conv_13 and relu_1..relu_13, start from noise and
    optimize it so its Gram matrix at that layer matches the style image's.
    The loss curve is logged to a dedicated W&B run, and the resulting
    texture is saved under ``<root_dir>/Part1/B``.
    """
    out_dir = root_dir + r"/Part1/B"
    # TODO: decide which layers to use; currently the first 13 conv and ReLU layers.
    depths = range(1, 14)
    candidate_layers = ["conv_%d" % k for k in depths] + ["relu_%d" % k for k in depths]

    style_img = image_loader(root_dir + r"/Images/Style/picasso.jpg")

    for layer in candidate_layers:
        input_img = get_input_image(style_img, noise=True)
        layer_type, layer_depth = layer.split("_")  # e.g. ("relu", "7")
        wandb.init(
            project="Ex4_Part1_B",
            config={'name': layer, 'layer': layer, 'type': layer_type, 'depth': layer_depth},
        )
        config = wandb.config

        # Style loss only: no content layers, and a large style weight.
        output = Part1_B_algo(cnn, None, style_img, input_img,
                              content_layers=[], style_layers=[config.layer],
                              style_weight=1000000, content_weight=0, num_steps=500)
        wandb.finish()

        # save the synthesized texture
        show_and_save_image(plt, output, layer, out_dir)

In [None]:
# Part 1B: texture synthesis from each single style layer (loss curves go to W&B).
Part1_B()

## **C. MeanVar style loss**

### Implement a “MeanVar” style loss that computes and compares the mean and the variance of each activation channel, in a given layer, instead of the Gram matrix of the layer. Compare the style transfer results obtained with this style loss to those obtained by the original implementation that uses the Gram matrices.


**Implementation**: We have implemented a style transfer using several style-loss layers and one content-loss layer (the default values declared at the top). The algorithm is the 'Original Algorithm' defined above, except that the Style-Loss model here is based on MeanVar (per-channel mean and variance) style comparisons rather than on the original Gram matrices.

The definition of `MeanVarLoss` is located in the section "Some Sub-Modules".

In [None]:
#@title Question Implementation
def Part1_C():
    """Part 1C: style transfer of dancing.jpg in Picasso's style using the MeanVar loss.

    Starts from a copy of the content image, shows the content, style and
    input images (each in its own figure), then saves the result under
    ``<root_dir>/Results``. The style weight is part of the file name.
    """
    # load a style reference and a content image
    style_img = image_loader(root_dir + r"/Images/Style/picasso.jpg")
    content_img = image_loader(root_dir + r"/Images/Content/dancing.jpg")
    content_img_name = "dancing"
    style_img_name = "picasso"
    # fit their dimensions
    if style_img.size() != content_img.size():
        style_img, content_img = trim_images_to_fit_sizes(style_img, content_img)

    # get input image: clone or noise
    input_img = get_input_image(content_img, noise=False)

    # imshow draws on the *current* axes. Without a new figure per preview,
    # the three images overlapped and only the last one was visible.
    for img, img_title in ((content_img, "Content"), (style_img, "Style"), (input_img, "Input")):
        plt.figure()
        imshow(img, img_title)

    # run the style transfer algorithm with the MeanVar style loss
    style_weight = 10000
    output = run_style_transfer(cnn, content_img, style_img, input_img,
                                style_weight=style_weight, content_weight=3,
                                content_converg=True, use_mean_var=True)

    # plot the results
    show_and_save_image(plt, output, "{} in {} style - meanVar Loss style_weight{}".format(content_img_name,
                                      style_img_name, style_weight), path=root_dir+r"/Results")

In [None]:
# Part 1C: style transfer using the MeanVar style loss.
Part1_C()

# **Part 2**

# Requirements and installations

In [None]:
!pip install transformers
!pip install diffusers

import itertools
from tqdm.auto import tqdm
from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler,UniPCMultistepScheduler

# Load the Stable Diffusion v1.4 components. They all come from the same
# Hugging Face repository, each from its own subfolder.
scheduler = UniPCMultistepScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet")

In [None]:
# Fall back to the CPU when no GPU is present (a hard-coded "cuda" crashes on
# CPU-only runtimes), consistent with the device selection at the top.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The three networks are used for inference only. Freezing their weights stops
# the style-guidance backward pass (which goes through vae.decode) from
# allocating and accumulating .grad tensors for the VAE on every step.
vae.to(device).eval().requires_grad_(False)
text_encoder.to(device).eval().requires_grad_(False)
unet.to(device).eval().requires_grad_(False)


In [None]:
# Preview the timesteps of a 50-step denoising schedule (shown as cell output).
scheduler.set_timesteps(num_inference_steps=50, device=device)
scheduler.timesteps

# Text embeddings

In [None]:
def text_embedding(prompt):
    """Encode ``prompt`` for classifier-free guidance.

    Args:
        prompt: a list of prompt strings. A single ``str`` is also accepted
            and treated as a one-element batch.

    Returns:
        tuple: ``(text_embeddings, batch_size)``. ``text_embeddings`` is the
        empty-prompt (unconditional) embeddings concatenated with the prompt
        embeddings along dim 0, i.e. ordered ``[uncond, cond]``.

    Side effect: re-seeds PyTorch's global RNG with 217, so the next random
    draw (e.g. the initial latents in ``get_latent_vec``) is reproducible.
    """
    # len() of a bare string counts characters, which would create one
    # unconditional embedding per character and a batch-size mismatch.
    if isinstance(prompt, str):
        prompt = [prompt]

    torch.manual_seed(217)  # seeds the initial latent noise drawn afterwards
    batch_size = len(prompt)
    text_input = tokenizer(
        prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
    )
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")

    # Both encodings are constants for the sampler. Run them without autograd
    # (the unconditional pass used to record a graph over the text encoder).
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings, batch_size

In [None]:
def get_latent_vec(batch_size):
    """Draw the initial latent noise for ``batch_size`` images.

    The shape is ``(batch_size, in_channels, height // 8, width // 8)``, using
    the global ``height``/``width`` and the UNet's configured input channels.
    The noise is scaled by the scheduler's ``init_noise_sigma``.
    """
    # Read in_channels from the config: accessing config values as direct
    # model attributes is deprecated in diffusers and emits a FutureWarning.
    shape = (batch_size, unet.config.in_channels, height // 8, width // 8)
    # Drawn on the default device (as before) and then moved to `device`.
    latents = torch.randn(shape).to(device)
    return latents * scheduler.init_noise_sigma

In [None]:
def get_style_model_for_diffusion(style_img):
    """Build a style-only loss model (Gram-matrix losses, no content loss) for ``style_img``.

    Returns:
        tuple: ``(model, style_losses)``.
    """
    # No content layers are requested, so the content image is never used;
    # pass None explicitly instead of depending on a global `content_img`.
    model, style_losses, _content_losses = get_style_model_and_losses(
        cnn, style_img, None, content_layers=[], use_mean_var=False)
    model.eval()
    # Only gradients w.r.t. the decoded image are needed. Freezing the (shared)
    # VGG weights avoids accumulating .grad on them at every guidance step.
    model.requires_grad_(False)
    return model, style_losses

In [None]:
def show_img(latents):
    """Decode ``latents`` with the VAE and return the first image as a PIL image."""
    # Undo the latent scaling factor before decoding.
    scaled = 1 / 0.18215 * latents
    with torch.no_grad():
        decoded = vae.decode(scaled).sample
    # x / 2 + 0.5 maps [-1, 1] to [0, 1]; then build 8-bit NHWC arrays.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    as_uint8 = (arrays * 255).round().astype("uint8")
    return Image.fromarray(as_uint8[0])

In [None]:
def sl_grad(clean_latent_estimate, model, style_losses):
    """Backpropagate the style loss of the decoded latent estimate.

    Decodes ``clean_latent_estimate`` with the VAE, maps it with
    ``x / 2 + 0.5`` and runs it through the style-loss ``model``. It then
    calls ``backward()`` on the sum of the ``style_losses``, so gradients
    accumulate on the leaf tensors the estimate was computed from.

    Returns:
        The decoded (rescaled) image tensor.
    """
    vae.eval()
    scaled_latent_estimate = 1 / 0.18215 * clean_latent_estimate
    decoded = vae.decode(scaled_latent_estimate).sample / 2 + 0.5

    model(decoded)  # sets .loss on every style-loss module
    total_style_loss = sum(sl.loss for sl in style_losses)
    total_style_loss.backward()
    return decoded

In [None]:
def optimize_latents(latents, scheduler, num_inference_steps, style_img,
                     text_embeddings, guidance_scale, style_weight, use_StyleLoss=True,
                     style_start_step=20, boost_step=50, boost_factor=1.5):
    """Denoise ``latents`` with classifier-free guidance plus optional style guidance.

    Args:
        latents: initial noisy latents.
        scheduler: diffusers scheduler. The style branch uses its
            ``convert_model_output`` and ``sigma_t``, as the UniPC scheduler
            loaded above provides.
        num_inference_steps: number of denoising steps.
        style_img: style reference tensor for the Gram-matrix style loss.
        text_embeddings: ``[uncond, cond]`` embeddings from ``text_embedding``.
        guidance_scale: classifier-free guidance scale.
        style_weight: strength of the style-gradient correction.
        use_StyleLoss: if False, run plain classifier-free-guided sampling.
        style_start_step: apply style guidance only at steps ``i > style_start_step``.
        boost_step: at steps ``i > boost_step``, multiply the style gradient
            by ``boost_factor``.
        boost_factor: multiplier for the late-step style gradient.

    Returns:
        The final latents.
    """
    scheduler.set_timesteps(num_inference_steps)
    if use_StyleLoss:
        model, style_losses = get_style_model_for_diffusion(style_img)
    for i, t in enumerate(tqdm(scheduler.timesteps)):
        # One batch holds both the unconditional and the conditional inputs.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        if use_StyleLoss and i > style_start_step:
            # Gradient of the style loss of the predicted clean image w.r.t.
            # a detached copy of the guided model output.
            z = noise_pred.clone().detach()
            z.requires_grad = True
            clean_latent_estimate = scheduler.convert_model_output(z, t, latents)
            sl_grad(clean_latent_estimate, model, style_losses)
            with torch.no_grad():
                # Normalize by the mean absolute gradient. The clamp only
                # matters for an all-zero gradient, which would otherwise
                # divide by zero and inject NaNs into the latents.
                grad = z.grad / torch.mean(torch.abs(z.grad)).clamp_min(1e-12)
            if i > boost_step:
                grad *= boost_factor
            sig_t = scheduler.sigma_t[t]
            noise_pred -= style_weight * sig_t * grad

        latents = scheduler.step(noise_pred, t, latents).prev_sample
    torch.cuda.empty_cache()

    return latents

In [None]:
# Part 2A: text-to-image generation with Stable Diffusion, guided towards the
# style of each reference painting by the Gram-matrix style loss.
random.seed(212)
torch.manual_seed(212)
path = root_dir + r"/Part2/A"
# style_images = ["monet_portrait.jpg", "vangoh_portrait.jpg", "picasso_portrait.jpg", "francis_portrait.jpg"]
style_images = ["picasso.jpg"]

text = ["female actress on stage bow to audience"]
height, width = 512, 512  # default output resolution of Stable Diffusion
torch.cuda.empty_cache()
content_img = None  # Part 2 has no content image
num_inference_steps = 100  # number of denoising steps
style_weights = [0.8]  # style-guidance strengths to sweep
guidance_scales = [7.5]  # classifier-free guidance scales to sweep

for style_im in style_images:
    input_im = image_loader(root_dir + r"/Images/Style/{}".format(style_im))
    style_name = style_im.split(".")[0]
    for style_s, guide_s in itertools.product(style_weights, guidance_scales):
        torch.cuda.empty_cache()
        # text_embedding() re-seeds torch's RNG, so every combination starts
        # from the same initial latent noise.
        text_embeddings, batch_size = text_embedding(text)
        latents = optimize_latents(get_latent_vec(batch_size), scheduler, num_inference_steps,
                                   input_im, text_embeddings, guide_s,
                                   style_s, True)
        im = show_img(latents)
        title = "{}_in_style_{}".format(text[0], style_name)
        show_and_save_image(plt, loader(im).unsqueeze(0), title, path)

In [None]:
def optimize_latents(latents, scheduler, num_inference_steps, style_img,
                     text_embeddings, guidance_scale, style_weight, use_StyleLoss=True):
    """Sampler variant that applies style guidance at *every* denoising step.

    This redefines ``optimize_latents``. It uses the same classifier-free
    guidance, but the normalized style gradient (scaled by ``style_weight``
    and the scheduler's ``sigma_t``) is subtracted from the guided noise
    prediction at all timesteps, with no late-step boost.

    Returns:
        The final latents.
    """
    scheduler.set_timesteps(num_inference_steps)
    if use_StyleLoss:
        model, style_losses = get_style_model_for_diffusion(style_img)

    for t in tqdm(scheduler.timesteps):
        doubled = scheduler.scale_model_input(torch.cat([latents, latents]), timestep=t)

        with torch.no_grad():
            eps = unet(doubled, t, encoder_hidden_states=text_embeddings).sample
        eps_uncond, eps_text = eps.chunk(2)
        guided = eps_uncond + guidance_scale * (eps_text - eps_uncond)

        if use_StyleLoss:
            # Style-loss gradient w.r.t. a detached copy of the guided prediction.
            z = guided.clone().detach()
            z.requires_grad = True
            sl_grad(scheduler.convert_model_output(z, t, latents), model, style_losses)
            with torch.no_grad():
                grad = z.grad / torch.mean(torch.abs(z.grad))
            guided -= style_weight * scheduler.sigma_t[t] * grad

        latents = scheduler.step(guided, t, latents).prev_sample

    torch.cuda.empty_cache()
    return latents