In [5]:
#######################################################################
# Implementation of 
# A Sliced Wasserstein Loss for Neural Texture Synthesis
# Heitz et al., CVPR 2021
#######################################################################

import numpy as np
import torch
import imageio
from torchvision import transforms

#######################################################################
# scaling factor of the optimized texture 
# wrt the example texture
#######################################################################
SCALING_FACTOR = 1


#######################################################################
# Load example texture 
#######################################################################
STYLE_IMAGE_PATH = 'images/Vassily_Kandinsky_1913_-_Composition_7.jpg'
CONTENT_IMAGE_PATH = 'images/YellowLabradorLooking_new.jpg' 

def saveImage(filename, image):
    imageTMP = np.clip(image * 255.0, 0, 255).astype('uint8')
    imageio.imwrite("SWD/"+filename, imageTMP)

def loadImage(filename, size=128):
    image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values
    image = image[:, :, :3]  # Ensure image is RGB without alpha channel
    image = np.transpose(image, (2, 0, 1))  # Change dimension order to CxHxW
    image = image[np.newaxis, ...]  # Add a batch dimension
    # Convert to tensor
    image = torch.from_numpy(image)
    # Resize image
    resize = transforms.Resize((size, size))
    image = resize(image)
    return image

# Load style and content images
image_style = loadImage(STYLE_IMAGE_PATH)
image_content = loadImage(CONTENT_IMAGE_PATH)

# Ensure both are on the same device
device = torch.device('cpu')  # Use 'cuda' if you have GPU
image_style = image_style.to(device)
image_content = image_content.to(device)

#######################################################################
# Load pretrained VGG19
#######################################################################

class VGG19(torch.nn.Module):

    def __init__(self):
        super(VGG19, self).__init__()

        self.block1_conv1 = torch.nn.Conv2d(3, 64, (3,3), padding=(1,1), padding_mode='reflect')
        self.block1_conv2 = torch.nn.Conv2d(64, 64, (3,3), padding=(1,1), padding_mode='reflect')

        self.block2_conv1 = torch.nn.Conv2d(64, 128, (3,3), padding=(1,1), padding_mode='reflect')
        self.block2_conv2 = torch.nn.Conv2d(128, 128, (3,3), padding=(1,1), padding_mode='reflect')

        self.block3_conv1 = torch.nn.Conv2d(128, 256, (3,3), padding=(1,1), padding_mode='reflect')
        self.block3_conv2 = torch.nn.Conv2d(256, 256, (3,3), padding=(1,1), padding_mode='reflect')
        self.block3_conv3 = torch.nn.Conv2d(256, 256, (3,3), padding=(1,1), padding_mode='reflect')
        self.block3_conv4 = torch.nn.Conv2d(256, 256, (3,3), padding=(1,1), padding_mode='reflect')

        self.block4_conv1 = torch.nn.Conv2d(256, 512, (3,3), padding=(1,1), padding_mode='reflect')
        self.block4_conv2 = torch.nn.Conv2d(512, 512, (3,3), padding=(1,1), padding_mode='reflect')
        self.block4_conv3 = torch.nn.Conv2d(512, 512, (3,3), padding=(1,1), padding_mode='reflect')
        self.block4_conv4 = torch.nn.Conv2d(512, 512, (3,3), padding=(1,1), padding_mode='reflect')

        self.relu = torch.nn.ReLU(inplace=True)
        self.downsampling = torch.nn.AvgPool2d((2,2))

    def forward(self, image):
        
        # RGB to BGR
        image = image[:, [2,1,0], :, :]

        # [0, 1] --> [0, 255]
        image = 255 * image

        # remove average color
        image[:,0,:,:] -= 103.939
        image[:,1,:,:] -= 116.779
        image[:,2,:,:] -= 123.68

        # block1
        block1_conv1 = self.relu(self.block1_conv1(image))
        block1_conv2 = self.relu(self.block1_conv2(block1_conv1))
        block1_pool = self.downsampling(block1_conv2)

        # block2
        block2_conv1 = self.relu(self.block2_conv1(block1_pool))
        block2_conv2 = self.relu(self.block2_conv2(block2_conv1))
        block2_pool = self.downsampling(block2_conv2)

        # block3
        block3_conv1 = self.relu(self.block3_conv1(block2_pool))
        block3_conv2 = self.relu(self.block3_conv2(block3_conv1))
        block3_conv3 = self.relu(self.block3_conv3(block3_conv2))
        block3_conv4 = self.relu(self.block3_conv4(block3_conv3))
        block3_pool = self.downsampling(block3_conv4)

        # block4
        block4_conv1 = self.relu(self.block4_conv1(block3_pool))
        block4_conv2 = self.relu(self.block4_conv2(block4_conv1))
        block4_conv3 = self.relu(self.block4_conv3(block4_conv2))
        block4_conv4 = self.relu(self.block4_conv4(block4_conv3))

        return [block1_conv1, block1_conv2, block2_conv1, block2_conv2, block3_conv1, block3_conv2, block3_conv3, block3_conv4, block4_conv1, block4_conv2, block4_conv3, block4_conv4]


#######################################################################
# Initialize optimized texture
# LBFGS optimization with the slicing loss
#######################################################################

def run_texture_synthesis(lambda_content=1.0, max_iterations=20):
    device = torch.device('cpu')
    image_style = loadImage(STYLE_IMAGE_PATH).to(device)
    image_content = loadImage(CONTENT_IMAGE_PATH).to(device)

    vgg = VGG19().to(device)
    vgg.load_state_dict(torch.load("vgg19.pth", map_location=device))

    image_optimized = torch.nn.Parameter(torch.randn_like(image_style) * 0.01 + image_style.mean(dim=(2, 3), keepdim=True))

    optimizer = torch.optim.LBFGS([image_optimized], lr=1, max_iter=max_iterations, tolerance_grad=0.0)
    def slicing_loss(image_generated, image_style):
        
        # generate VGG19 activations
        list_activations_generated = vgg(image_generated)
        list_activations_example   = vgg(image_style)
        
        # iterate over layers
        loss = 0
        for l in range(len(list_activations_example)):
            # get dimensions
            b = list_activations_example[l].shape[0]
            dim = list_activations_example[l].shape[1]
            n = list_activations_example[l].shape[2]*list_activations_example[l].shape[3]
            # linearize layer activations and duplicate example activations according to scaling factor
            activations_example = list_activations_example[l].view(b, dim, n).repeat(1, 1, SCALING_FACTOR*SCALING_FACTOR)
            activations_generated = list_activations_generated[l].view(b, dim, n*SCALING_FACTOR*SCALING_FACTOR)
            # sample random directions
            Ndirection = dim
            directions = torch.randn(Ndirection, dim).to('cpu')  # After
            directions = directions / torch.sqrt(torch.sum(directions**2, dim=1, keepdim=True))
            # project activations over random directions
            projected_activations_example = torch.einsum('bdn,md->bmn', activations_example, directions)
            projected_activations_generated = torch.einsum('bdn,md->bmn', activations_generated, directions)
            # sort the projections
            sorted_activations_example = torch.sort(projected_activations_example, dim=2)[0]
            sorted_activations_generated = torch.sort(projected_activations_generated, dim=2)[0]
            # L2 over sorted lists
            loss += torch.mean( (sorted_activations_example-sorted_activations_generated)**2 ) 
        return loss

    def content_loss(image_generated, image_content):
        # generate VGG19 activations for the generated and content images
        activations_generated = vgg(image_generated)
        activations_content = vgg(image_content)

        # Choose layers to compare content
        content_layers = [4]  
        
        loss = 0
        for l in content_layers:
            loss += torch.mean((activations_generated[l] - activations_content[l])**2)
        return loss
   
    current_loss = {'total_loss': 0, 'style_loss': 0, 'content_loss': 0}  # A dictionary to hold current loss values

    def closure():
        optimizer.zero_grad()
        style_loss = slicing_loss(image_optimized, image_style)
        cont_loss = content_loss(image_optimized, image_content)
        total_loss = style_loss + lambda_content * cont_loss
        total_loss.backward()
        
        # Store losses in the dictionary to access outside the closure
        current_loss['total_loss'] = total_loss.item()
        current_loss['style_loss'] = style_loss.item()
        current_loss['content_loss'] = cont_loss.item()
        
        return total_loss

    for iteration in range(max_iterations):
        optimizer.step(closure)  # This will execute the closure and update the `current_loss` dictionary

        # Now you can access and print the loss values stored in `current_loss`
        print(f'Iteration {iteration}: lambda: {lambda_content}, Style Loss: {current_loss["style_loss"]}, Content Loss: {current_loss["content_loss"]}, Total Loss: {current_loss["total_loss"]}')

        with torch.no_grad():
            img_np = image_optimized.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
            saveImage(f'optimized_texture_lambda_{lambda_content}_{iteration}.png', img_np)

    print("Optimization Completed")


  image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values


In [7]:
# Example usage:
run_texture_synthesis(lambda_content=100.0, max_iterations=20)
run_texture_synthesis(lambda_content=10.0, max_iterations=20)
run_texture_synthesis(lambda_content=1.0, max_iterations=20)

  image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values


Iteration 0: lambda: 100.0, Style Loss: 5.27976131439209, Content Loss: 0.05321184918284416, Total Loss: 10.600946426391602
Iteration 1: lambda: 100.0, Style Loss: 5.181183338165283, Content Loss: 0.01991446129977703, Total Loss: 7.172629356384277
Iteration 2: lambda: 100.0, Style Loss: 5.299943923950195, Content Loss: 0.012998061254620552, Total Loss: 6.59975004196167
Iteration 3: lambda: 100.0, Style Loss: 5.289909362792969, Content Loss: 0.009743031114339828, Total Loss: 6.264212608337402
Iteration 4: lambda: 100.0, Style Loss: 5.3857035636901855, Content Loss: 0.007918263785541058, Total Loss: 6.177529811859131
Iteration 5: lambda: 100.0, Style Loss: 5.33715295791626, Content Loss: 0.007158598862588406, Total Loss: 6.053012847900391
Iteration 6: lambda: 100.0, Style Loss: 5.279743671417236, Content Loss: 0.006690758280456066, Total Loss: 5.948819637298584
Iteration 7: lambda: 100.0, Style Loss: 5.374063491821289, Content Loss: 0.006612951401621103, Total Loss: 6.035358428955078
Ite

In [8]:
run_texture_synthesis(lambda_content=0.5, max_iterations=20)

  image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values


Iteration 0: lambda: 0.5, Style Loss: 0.26880425214767456, Content Loss: 1.471167802810669, Total Loss: 1.0043880939483643
Iteration 1: lambda: 0.5, Style Loss: 0.21440470218658447, Content Loss: 1.079517126083374, Total Loss: 0.7541632652282715
Iteration 2: lambda: 0.5, Style Loss: 0.20608563721179962, Content Loss: 0.9341393113136292, Total Loss: 0.6731553077697754
Iteration 3: lambda: 0.5, Style Loss: 0.20681165158748627, Content Loss: 0.8561058640480042, Total Loss: 0.6348645687103271
Iteration 4: lambda: 0.5, Style Loss: 0.203163281083107, Content Loss: 0.8292345404624939, Total Loss: 0.6177805662155151
Iteration 5: lambda: 0.5, Style Loss: 0.196353942155838, Content Loss: 0.8037357330322266, Total Loss: 0.5982217788696289
Iteration 6: lambda: 0.5, Style Loss: 0.1987880915403366, Content Loss: 0.7929556369781494, Total Loss: 0.5952659249305725
Iteration 7: lambda: 0.5, Style Loss: 0.19665615260601044, Content Loss: 0.7798060178756714, Total Loss: 0.5865591764450073
Iteration 8: la

In [10]:
run_texture_synthesis(lambda_content=0.1, max_iterations=20)

  image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values


Iteration 0: lambda: 0.1, Style Loss: 0.214693084359169, Content Loss: 2.0820412635803223, Total Loss: 0.42289721965789795
Iteration 1: lambda: 0.1, Style Loss: 0.14281392097473145, Content Loss: 1.8402161598205566, Total Loss: 0.3268355429172516
Iteration 2: lambda: 0.1, Style Loss: 0.12375868111848831, Content Loss: 1.689477562904358, Total Loss: 0.2927064299583435
Iteration 3: lambda: 0.1, Style Loss: 0.11649229377508163, Content Loss: 1.5831913948059082, Total Loss: 0.27481144666671753
Iteration 4: lambda: 0.1, Style Loss: 0.11200540512800217, Content Loss: 1.5115606784820557, Total Loss: 0.26316148042678833
Iteration 5: lambda: 0.1, Style Loss: 0.11082682013511658, Content Loss: 1.4593366384506226, Total Loss: 0.25676047801971436
Iteration 6: lambda: 0.1, Style Loss: 0.10654297471046448, Content Loss: 1.4237242937088013, Total Loss: 0.2489154040813446
Iteration 7: lambda: 0.1, Style Loss: 0.10700571537017822, Content Loss: 1.38753080368042, Total Loss: 0.2457588016986847
Iteration

In [11]:
run_texture_synthesis(lambda_content=0.01, max_iterations=20)

  image = imageio.imread(filename).astype("float32") / 255.0  # Normalize pixel values


Iteration 0: lambda: 0.01, Style Loss: 0.18532966077327728, Content Loss: 2.2850191593170166, Total Loss: 0.20817984640598297
Iteration 1: lambda: 0.01, Style Loss: 0.13927492499351501, Content Loss: 2.2192232608795166, Total Loss: 0.16146716475486755
Iteration 2: lambda: 0.01, Style Loss: 0.11293163150548935, Content Loss: 2.2075960636138916, Total Loss: 0.1350075900554657
Iteration 3: lambda: 0.01, Style Loss: 0.10766863822937012, Content Loss: 2.2095847129821777, Total Loss: 0.12976448237895966
Iteration 4: lambda: 0.01, Style Loss: 0.09975562244653702, Content Loss: 2.1932382583618164, Total Loss: 0.12168800830841064
Iteration 5: lambda: 0.01, Style Loss: 0.0953354462981224, Content Loss: 2.185986042022705, Total Loss: 0.11719530820846558
Iteration 6: lambda: 0.01, Style Loss: 0.09478658437728882, Content Loss: 2.1789073944091797, Total Loss: 0.11657565832138062
Iteration 7: lambda: 0.01, Style Loss: 0.09295006096363068, Content Loss: 2.1751182079315186, Total Loss: 0.1147012412548