In [None]:
# Import Libraries:
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import torch
import torch.optim as optim
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

import torchvision
from torchvision import transforms
import torchvision.models as models

In [None]:
# Model setup: choose the compute device and build the feature extractor.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Keep only the `.features` sub-module of the pretrained VGG-19 and move it to `device`.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.to(device)
# For non-equivariant: switch off gradient tracking on every VGG weight.
for param in vgg.parameters():
  param.requires_grad = False

In [None]:
# Utility Functions
def load_image(path, max_size=112, shape=(112, 112)):
    """Load an image file and return it as a normalised, batched tensor.

    Args:
        path: Location of the image, passed straight to ``Image.open``.
        max_size: Cap on the square side length; only used when ``shape`` is None.
        shape: Size handed to ``transforms.Resize`` and ``transforms.CenterCrop``.
            Defaults to ``(112, 112)``. Pass ``None`` to use a square whose side
            is the image's longest edge, capped at ``max_size``.

    Returns:
        The transformed image with a leading batch dimension of 1 added by
        ``unsqueeze(0)``; channels are normalised with the mean/std constants
        given below.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been copied into the converted RGB image.
    with Image.open(path) as source:
        image = source.convert('RGB')

    # Determine the size to resize the image to
    if shape is None:
        side = min(max(image.size), max_size)
        size = (side, side)
    else:
        size = shape

    transform = transforms.Compose([
        transforms.Resize(size),
        # Crop to the requested size rather than a fixed 112x112, so that
        # non-default `shape` / `max_size` values are actually honoured.
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    return transform(image).unsqueeze(0)  # Add batch dimension

In [None]:
def im_convert(tensor):
  """Turn a normalised image tensor into a channel-last array ready for plotting.

  The tensor is copied to the CPU and detached, size-1 dimensions are dropped,
  the first axis is moved to the end, the normalisation is undone with the
  std/mean constants below, and the result is clipped to [0, 1].
  """
  std = np.array((0.229, 0.224, 0.225))
  mean = np.array((0.485, 0.456, 0.406))

  # Work on a detached CPU copy so the caller's tensor is left untouched.
  pixels = tensor.cpu().clone().detach().numpy().squeeze()
  pixels = pixels.transpose(1, 2, 0)
  return (pixels * std + mean).clip(0, 1)

In [None]:
def get_features(image, model, layers = None):
  """ Feed `image` through `model` one sub-module at a time and collect the
      outputs of selected sub-modules.

      `layers` maps a sub-module name in `model._modules` to the key under
      which its output is stored; when omitted, the mapping below is used
      (it includes 'conv4_2', the content-representation layer).
      Returns a dict from those keys to the corresponding outputs.
      Reference: Gatys et al (2016)

      NOTE(review): the stored object is the sub-module's output tensor itself,
      not a copy - if a later sub-module works in place, the stored value
      changes with it. Confirm against the model in use.
  """
  if layers is None:
    layers = {'0' : 'conv1_1',
              '5' : 'conv2_1',
              '10' : 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',
              '28': 'conv5_1'}

  collected = {}
  activation = image
  for module_name, module in model._modules.items():
    # Every sub-module is applied in order, whether or not its output is kept.
    activation = module(activation)
    if module_name in layers:
      collected[layers[module_name]] = activation

  return collected

In [None]:
# Gram Matrix Definitions
def gram_matrix(tensor):
  """Compute the Gram matrix of a single feature map.

  Args:
    tensor: 4-D tensor unpacked as (batch, depth, height, width). The batch
      dimension is ignored when flattening, so it must be 1; larger batches
      fail in the reshape.

  Returns:
    A (depth, depth) tensor: the product of the (depth, height * width)
    flattened features with their own transpose.
  """
  # get batch_size, depth, height, width of tensor
  _, d, h, w = tensor.size()
  # Flatten the spatial dimensions. `reshape` (unlike `view`) also accepts
  # non-contiguous inputs, e.g. feature maps that were transposed or sliced.
  flat = tensor.reshape(d, h * w)
  # calc. gram matrix
  return torch.mm(flat, flat.t())

In [None]:
def _show_feature_map(feature, title):
    """Plot the channel-averaged activation of a batched feature tensor."""
    # Feature maps carry many channels, so they cannot go through im_convert
    # (it multiplies by 3-element RGB mean/std arrays). Collapse the channel
    # axis of the first batch entry into a single 2-D heat map instead.
    activation = feature[0].detach().mean(dim=0).cpu().numpy()
    plt.imshow(activation, cmap='viridis')
    plt.title(title)
    plt.show()


def _mean_distance(a, b, squared):
    """Mean squared difference if `squared` is truthy, else mean absolute difference."""
    diff = a - b
    return torch.mean(diff ** 2) if squared else torch.mean(torch.abs(diff))


def transfer_style(show_every=250, 
                   optimizer_lr=0.005, 
                   tot_steps=2000, 
                   content_alpha=1, 
                   style_beta=1e7, 
                   content_layer='conv4_2', 
                   style_weights={'conv1_1': 0.2, 'conv2_1': 0.2, 'conv3_1': 0.2, 'conv4_1': 0.2, 'conv5_1': 0.2}, 
                   content_pth=None, 
                   style_pth=None, 
                   l2_loss=True,
                   gram_function=gram_matrix,
                   record_every=1):
    """Optimise a copy of the content image so its features match a style image.

    Uses the module-level `vgg` and `device`, plots intermediate results with
    matplotlib, and records the target's Gram matrices during optimisation.

    Args:
        show_every: Print the loss and show the target image every this many steps.
        optimizer_lr: Learning rate of the Adam optimiser that updates the target.
        tot_steps: Number of optimisation steps.
        content_alpha: Weight of the content loss in the total loss.
        style_beta: Weight of the style loss in the total loss.
        content_layer: Key (as produced by `get_features` with its default
            layer mapping) of the feature used for the content loss.
        style_weights: Mapping from feature key to that layer's style weight.
            Read only, never mutated.
        content_pth: Path of the content image (required).
        style_pth: Path of the style image (required).
        l2_loss: Truthy for mean squared differences, falsy for mean absolute
            differences, in both the content and style losses.
        gram_function: Callable mapping a feature tensor to its Gram matrix.
        record_every: Record the target Gram matrices only on steps divisible
            by this value. The default of 1 records every step; larger values
            reduce the memory held by the returned list.

    Returns:
        List of ``(step, layer, gram)`` tuples with detached Gram matrices.

    Raises:
        ValueError: If either image path is missing or `record_every` < 1.
        KeyError: If `content_layer` is not among the extracted feature keys.
    """
    if content_pth is None or style_pth is None:
        raise ValueError("Both content_pth and style_pth must be provided")
    if record_every < 1:
        raise ValueError("record_every must be a positive integer")

    gram_tensors = []  # List to store Gram matrices with their layer and step

    # Load content and style images (style is resized to the content's H x W)
    content = load_image(content_pth).to(device)
    style = load_image(style_pth, shape=(content.shape[2], content.shape[3])).to(device)

    # The target starts as a copy of the content image and is the only tensor
    # being optimised. `content` already lives on `device`, so the clone does too.
    target = content.clone().requires_grad_(True)
    plt.imshow(im_convert(target))
    plt.show()

    # Extract the full default layer set so that any of those layers can be
    # selected through `content_layer`, not only conv4_2.
    content_features = get_features(content, vgg)
    if content_layer not in content_features:
        raise KeyError(
            f"content_layer {content_layer!r} is not one of {sorted(content_features)}")

    # Display the content feature used for the content loss
    print(f"Content feature at {content_layer}")
    _show_feature_map(content_features[content_layer], f"Content feature at {content_layer}")

    # Display the style features
    style_features = get_features(style, vgg)
    for layer, feature in style_features.items():
        print(f"Style feature at {layer}")
        _show_feature_map(feature, f"Style feature at {layer}")

    # Calculate gram matrices for each layer of our style representation
    style_grams = {layer: gram_function(style_features[layer]) for layer in style_features}

    # Iteration hyperparameters
    optimizer = optim.Adam([target], lr=optimizer_lr)

    num_layers = len(style_weights)
    print(f"Number of style layers: {num_layers}")

    for ii in range(1, tot_steps + 1):
        # Get the features from the current target image
        target_features = get_features(target, vgg)

        # Compute the content loss
        content_loss = _mean_distance(target_features[content_layer],
                                      content_features[content_layer],
                                      l2_loss)

        # Accumulate the style loss over the weighted layers
        style_loss = 0
        for layer, weight in style_weights.items():
            target_feature = target_features[layer]
            target_gram = gram_function(target_feature)

            # Keep a detached snapshot of the Gram matrix for later inspection
            if ii % record_every == 0:
                gram_tensors.append((ii, layer, target_gram.detach()))

            _, d, h, w = target_feature.shape

            # Per-layer loss, weighted and normalised by the feature size
            layer_style_loss = weight * _mean_distance(target_gram, style_grams[layer], l2_loss)
            style_loss += layer_style_loss / (d * h * w)

        # Calculate the total loss
        total_loss = content_alpha * content_loss + style_beta * style_loss

        # Update the target image.
        # No torch.cuda.empty_cache() here: emptying the allocator cache on
        # every step only forces memory to be re-acquired on the next one.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if ii % show_every == 0:
            print('Total loss: ', total_loss.item())
            plt.imshow(im_convert(target))
            plt.show()

    expected_length = (tot_steps // record_every) * num_layers
    print(f"Expected length of gram_tensors: {expected_length}")

    return gram_tensors  # Return the list of Gram matrices with their corresponding step and layer

In [None]:
def visualize_gram_matrices(gram_tensors, 
                            filter_layer=None, 
                            step_interval=100, 
                            content="content", 
                            style="style", 
                            symmetries="symmetries"):
    """Plot recorded Gram matrices as heat maps and save the figure as a PDF.

    Args:
        gram_tensors: Iterable of ``(step, layer, gram)`` tuples, where ``gram``
            is a tensor (it is moved to the CPU and converted to NumPy here).
        filter_layer: If truthy, only entries whose layer equals it are plotted.
        step_interval: Only entries whose step is a multiple of this are plotted.
        content, style, symmetries: Joined with spaces to form the PDF file name
            ``"{content} {style} {symmetries}.pdf"``.

    Returns:
        None. When no entry passes the filters a message is printed and nothing
        is plotted or saved.
    """
    # Select the entries matching the layer filter and the step interval
    selected = [
        (step, layer, gram)
        for step, layer, gram in gram_tensors
        if (not filter_layer or layer == filter_layer) and step % step_interval == 0
    ]

    # A grid with zero rows cannot be created, so bail out before touching matplotlib
    if not selected:
        print("No Gram matrices match the requested layer/step filter; nothing to plot.")
        return

    cols = 2  # Number of columns in the subplot grid
    rows = (len(selected) + cols - 1) // cols  # Ceiling division

    # squeeze=False keeps `axes` a 2-D array regardless of the grid shape
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 5), squeeze=False)
    axes = axes.flatten()  # Flatten the axes array for easy iteration

    for ax, (step, layer, gram) in zip(axes, selected):
        gram_np = gram.cpu().detach().numpy()  # Ensure the tensor is detached and moved to CPU
        sns.heatmap(gram_np, cmap='viridis', ax=ax, cbar=True)
        ax.set_title(f'Gram Matrix at Step {step}, Layer {layer}')
        ax.set_xlabel('Feature Maps')
        ax.set_ylabel('Feature Maps')

    # Remove any empty subplots (at most one, when the count is odd)
    for ax in axes[len(selected):]:
        fig.delaxes(ax)

    # Adjust layout and save the figure as a PDF
    fig.tight_layout()
    pdf_title = f"{content} {style} {symmetries}.pdf"
    fig.savefig(pdf_title)
    plt.show()