In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm
from PIL import Image, ImageOps
import cv2
from mpl_toolkits.axes_grid1 import ImageGrid
import random

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.autograd import Variable


import torchvision
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
import torchvision.models as models

import numpy as np
import seaborn as sns

import sys

import os

import gc

In [None]:
def get_features_eq(image, model, layers):
    """Run ``image`` through ``model.features`` and collect selected intermediate outputs.

    The modules in ``model.features._modules`` are applied one after another, in
    registration order. After each module, its output is recorded if the module's
    name is a key of ``layers``. For purely numeric names such as ``'21'``, the
    integer form (``21``) is also accepted as a key.

    Args:
        image: input passed to the first module of ``model.features``.
        model: object exposing a ``features`` attribute whose ``_modules`` maps
            module names to callables (e.g. an ``nn.Sequential``).
        layers: mapping ``module name (str, or int for numeric names) -> output key``.

    Returns:
        dict mapping each matched output key to the tensor produced by that module.
        Keys whose module name is never seen are simply absent.

    NOTE: if a later module in ``model.features`` modifies its input in place
    (e.g. ``ReLU(inplace=True)``), a stored tensor reflects that modification.
    """
    features = {}
    x = image
    for name, layer in model.features._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
        elif name.isdigit() and int(name) in layers:
            # Module names are always strings; accept int keys such as {21: 'conv4_2'}.
            features[layers[int(name)]] = x
    return features

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model (EquivariantVGG19 is defined outside this view; presumably a
# D8-equivariant VGG19 variant, per the variable name) and move it to `device`.
d8_vgg = EquivariantVGG19(num_classes=100).to(device)

In [None]:
# Equivariant Style Transfer
def _feature_distance(a, b, l2_loss):
    """Mean squared difference of ``a`` and ``b`` if ``l2_loss`` is True, else mean absolute difference."""
    diff = a - b
    if l2_loss:
        return torch.mean(diff ** 2)
    return torch.mean(torch.abs(diff))


def transfer_style_eq(model=None,
                      show_every=250,
                      optimizer_lr=0.005,
                      tot_steps=2000,
                      content_alpha=1,
                      style_beta=1e7,
                      content_layers={21: 'conv4_2'},
                      style_weights={'conv1_1': 0.2, 'conv2_1': 0.2, 'conv3_1': 0.2, 'conv4_1': 0.2, 'conv5_1': 0.2},
                      content_pth=None,
                      style_pth=None,
                      l2_loss=True,
                      gram_function=gram_matrix):
    """Optimize a copy of the content image so its Gram statistics match the style image.

    The target starts as a clone of the content image and is optimized with Adam.
    The total loss is ``content_alpha * content_loss + style_beta * style_loss``.

    Args:
        model: network passed to ``get_features_eq``; must not be None. Its
            parameters are frozen while the target is optimized, and their
            original ``requires_grad`` flags are restored afterwards.
        show_every: print the loss and display the target every this many steps.
            A falsy value (e.g. 0) disables the periodic display.
        optimizer_lr: Adam learning rate for the target image.
        tot_steps: number of optimization steps.
        content_alpha: weight of the content loss.
        style_beta: weight of the style loss.
        content_layers: mapping ``layer key -> feature name`` for the content
            representation. Only the first entry is used for the content loss.
        style_weights: mapping ``feature name -> weight``. The names are also
            used as layer keys for ``get_features_eq``.
        content_pth: path of the content image, loaded via ``load_image(..., 224)``.
        style_pth: path of the style image, loaded via ``load_image(..., 224)``.
        l2_loss: use mean squared error if True, mean absolute error otherwise.
        gram_function: callable returning the Gram matrix of a feature map.

    Returns:
        list of ``(step, layer_name, gram)`` tuples, one per style layer per
        step. ``gram`` is the detached Gram matrix of the target at that step.

    Raises:
        ValueError: if ``content_pth`` or ``style_pth`` is missing, ``model`` is
            None, or ``content_layers`` is empty.
    """
    if content_pth is None or style_pth is None:
        raise ValueError("Both content_pth and style_pth must be provided")
    if model is None:
        raise ValueError("model must be provided")
    if not content_layers:
        raise ValueError("content_layers must contain at least one layer")

    # NOTE(review): one detached Gram matrix is kept per style layer per step,
    # on the same device as the features. For long runs this can use a lot of memory.
    gram_tensors = []  # List to store Gram matrices with their layer and step

    # Load content and style images
    content = load_image(content_pth, 224).to(device)
    style = load_image(style_pth, 224).to(device)

    # Initialize the target image. `content` is already on `device`, so the
    # clone is a leaf tensor that Adam can optimize directly.
    target = content.clone().requires_grad_(True)
    plt.imshow(im_convert(target))
    plt.show()

    style_layers = {layer: layer for layer in style_weights}
    content_layer = next(iter(content_layers.values()))  # only the first content layer is used
    target_layers = {**content_layers, **style_layers}

    # Content features and style Grams are fixed targets. Computing them without
    # autograd keeps their graphs out of every backward() call, so no
    # retain_graph is needed and their memory is not held for the whole run.
    with torch.no_grad():
        content_features = get_features_eq(content, model, layers=content_layers)
        style_features = get_features_eq(style, model, layers=style_layers)
        style_grams = {layer: gram_function(style_features[layer]) for layer in style_features}

    # Iteration hyperparameters
    optimizer = optim.Adam([target], lr=optimizer_lr)

    num_layers = len(style_weights)
    print(f"Number of style layers: {num_layers}")

    # Only `target` is optimized. Freeze the network so backward() neither computes
    # nor accumulates weight gradients; restore the flags even if the loop raises.
    param_flags = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in param_flags:
        p.requires_grad_(False)
    try:
        for ii in range(1, tot_steps + 1):
            # Get the features from your target image
            target_features = get_features_eq(target, model, layers=target_layers)

            # Debug statement to print the target feature keys
            if ii == 1:
                print("Target feature keys:", target_features.keys())

            content_loss = _feature_distance(target_features[content_layer],
                                             content_features[content_layer],
                                             l2_loss)

            style_loss = 0
            for layer in style_weights:
                target_feature = target_features[layer]
                target_gram = gram_function(target_feature)

                # Append the step, layer, and computed Gram matrix to the list
                gram_tensors.append((ii, layer, target_gram.detach()))

                _, d, h, w = target_feature.shape
                layer_style_loss = style_weights[layer] * _feature_distance(
                    target_gram, style_grams[layer], l2_loss)
                style_loss += layer_style_loss / (d * h * w)

            total_loss = content_alpha * content_loss + style_beta * style_loss

            # Update the target image
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            torch.cuda.empty_cache()

            if show_every and ii % show_every == 0:
                print('Total loss: ', total_loss.item())
                plt.imshow(im_convert(target))
                plt.show()
    finally:
        for p, flag in param_flags:
            p.requires_grad_(flag)

    expected_length = tot_steps * num_layers
    print(f"Expected length of gram_tensors: {expected_length}")

    return gram_tensors