In [None]:
import os
import re
import math
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image
import matplotlib.pyplot as plt

import sys
sys.path.insert(0, './')
from DISTS import DISTS

In [2]:
# Pick the compute device once for the whole notebook: CUDA when torch can see
# a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Instantiate the DISTS model imported from the local DISTS module and move it
# to the selected device. run_syn calls it as dists_model(exemplar, output) to
# score every synthesised texture against its exemplar.
dists_model = DISTS().to(device)

In [4]:
# Side length (in pixels) of the square images the notebook works with;
# a much smaller size is used when no GPU is available.
if torch.cuda.is_available():
    imsize = 256
else:
    imsize = 64

# PIL image -> tensor pipeline
loader = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),  # rescale to imsize x imsize
        transforms.ToTensor(),                # convert to a torch tensor
    ]
)

# tensor -> PIL image (inverse direction, used for display and saving)
unloader = transforms.ToPILImage()


def image_loader(image_name):
    """Load an image file as a batched float tensor on ``device``.

    Args:
        image_name: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The image after the module-level ``loader`` transform, with a leading
        batch dimension of size 1 added, moved to ``device`` as ``torch.float``.
    """
    # The context manager releases the file handle as soon as the pixels are decoded.
    with Image.open(image_name) as handle:
        # Force three channels: greyscale, palette or RGBA files would otherwise
        # produce 1- or 4-channel tensors, which the ImageNet-pretrained
        # extractors used in this notebook cannot consume.
        image = handle.convert('RGB')
    # fake batch dimension required to fit network's input dimensions
    batch = loader(image).unsqueeze(0)
    return batch.to(device, torch.float)


def tensor_to_np(tensor):
    """Convert a batched image tensor into a NumPy array (via a PIL image)."""
    # operate on a CPU copy so the caller's tensor is left unchanged
    cpu_copy = tensor.cpu().clone()
    # drop the fake batch dimension, then go tensor -> PIL -> ndarray
    pil_image = unloader(cpu_copy.squeeze(0))
    return np.array(pil_image)


def image_show(tensor, title=None):
    """Display a batched image tensor with matplotlib, optionally with a title."""
    # clone on the CPU so the caller's tensor is never modified, then strip
    # the fake batch dimension before converting to a PIL image
    pil_image = unloader(tensor.cpu().clone().squeeze(0))
    plt.imshow(pil_image)
    if title is not None:
        plt.title(title)
    # short pause so that the figure gets a chance to refresh
    plt.pause(0.001)


def image_save(tensor, folder, name):
    """Save a batched image tensor as ``<folder>/<name>.jpg``.

    Args:
        tensor: image tensor with a leading batch dimension of size 1.
        folder: existing directory to write into.
        name: file name without extension; ``.jpg`` is appended.

    Floating-point inputs are clamped to [0, 1] before conversion. Callers in
    this notebook pass unbounded Gaussian-noise seeds, and without the clamp
    the 8-bit conversion performed by ``ToPILImage`` wraps out-of-range values
    around instead of saturating them.
    """
    # detach + clone: never touch the caller's tensor or its autograd graph
    image = tensor.detach().cpu().clone()
    image = image.squeeze(0)      # remove the fake batch dimension
    if image.is_floating_point():
        image = image.clamp(0, 1)
    image = unloader(image)
    image.save(os.path.join(folder, name) + '.jpg', 'JPEG')


def create_checkerboard(shape, square_size=8, device=None):
    """Create a binary checkerboard image tensor.

    Args:
        shape: target shape, either (B, C, H, W) or (C, H, W). Only C, H and W
            are used; the batch size of the result is always 1.
        square_size: side length of each square in pixels (must be >= 1).
        device: optional torch device to place the result on.

    Returns:
        A float32 tensor of shape (1, C, H, W) with values 0 and 1. The
        top-left square is 0 and every channel holds the same pattern.

    Raises:
        ValueError: if ``square_size`` is smaller than 1.
    """
    if square_size < 1:
        raise ValueError(f"square_size must be >= 1, got {square_size}")

    # Pull channels from the right position for both layouts. Reading shape[1]
    # unconditionally would pick up the height for a (C, H, W) shape.
    if len(shape) == 4:  # (B, C, H, W)
        _, channels, rows, cols = shape
    else:  # (C, H, W)
        channels, rows, cols = shape

    # Index of the square each row / column falls into
    row_block = np.arange(rows) // square_size
    col_block = np.arange(cols) // square_size

    # A pixel is "on" when its row-block and column-block parities differ
    checkerboard = (row_block[:, None] + col_block[None, :]) % 2

    # Convert to a PyTorch tensor and adjust device and data type
    checkerboard_tensor = torch.from_numpy(checkerboard).float()
    if device is not None:
        checkerboard_tensor = checkerboard_tensor.to(device)

    # Repeat across channel dimension and add a batch dimension
    checkerboard_tensor = checkerboard_tensor.repeat(channels, 1, 1).unsqueeze(0)

    return checkerboard_tensor


def create_lines(shape, line_width=4, orientation='horizontal', device=None):
    """Create a binary stripe-pattern image tensor.

    Args:
        shape: target shape, either (B, C, H, W) or (C, H, W). Only C, H and W
            are used; the batch size of the result is always 1.
        line_width: width in pixels of each stripe (must be >= 1). The pattern
            has period ``2 * line_width`` and starts with an "on" stripe.
        orientation: ``'horizontal'`` makes the pattern alternate along the
            width axis (every row is identical). Any other value makes it
            alternate along the height axis (every column is identical).
        device: optional torch device to place the result on.

    Returns:
        A float32 tensor of shape (1, C, H, W) with values 0 and 1.

    Raises:
        ValueError: if ``line_width`` is smaller than 1.
    """
    if line_width < 1:
        raise ValueError(f"line_width must be >= 1, got {line_width}")

    # Pull channels from the right position for both layouts. Reading shape[1]
    # unconditionally would pick up the height for a (C, H, W) shape.
    if len(shape) == 4:  # (B, C, H, W)
        _, channels, rows, cols = shape
    else:  # (C, H, W)
        channels, rows, cols = shape

    period = line_width * 2

    if orientation == 'horizontal':
        # 1-D on/off profile along the width, copied to every row
        profile = (np.arange(cols) % period < line_width).astype(np.float64)
        pattern = np.tile(profile, (rows, 1))
    else:  # vertical
        # 1-D on/off profile along the height, copied to every column
        profile = (np.arange(rows) % period < line_width).astype(np.float64)
        pattern = np.tile(profile[:, None], (1, cols))

    # Convert to a PyTorch tensor and adjust device and data type
    lines_tensor = torch.from_numpy(pattern).float()
    if device is not None:
        lines_tensor = lines_tensor.to(device)

    # Repeat across channel dimension and add a batch dimension
    lines_tensor = lines_tensor.repeat(channels, 1, 1).unsqueeze(0)

    return lines_tensor

In [5]:
# ResNet feature map model
class ResNetFeatureMapExtractor(nn.Module):
    """Expose intermediate activations of an ImageNet-pretrained ResNet-34.

    Forward hooks are attached to every top-level ``nn.ReLU`` child of the
    backbone (stored under the key ``'relu'``) and to every child of each
    top-level ``nn.Sequential`` (stored under ``'<sequential name>_<child name>'``).
    Calling the module runs the backbone and returns the captured activations
    as a dict in execution order.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # layer name -> activation captured during the most recent forward pass
        self.selected_out = {}
        self.pretrained = models.resnet34(weights='IMAGENET1K_V1').to(device).eval()
        # handles of the registered hooks, kept so they can be removed later
        self.fhooks = []

        for name, module in list(self.pretrained.named_children()):
            # forward hook for first ReLU layer
            if isinstance(module, nn.ReLU):
                self.fhooks.append(module.register_forward_hook(self.forward_hook('relu')))
            # attach forward hooks to every residual block
            if isinstance(module, nn.Sequential):
                for n, block in module.named_children():
                    self.fhooks.append(block.register_forward_hook(self.forward_hook(f'{name}_{n}')))

    def forward_hook(self, layer_name):
        """Return a hook that stores a module's output under ``layer_name``."""
        def hook(module, input, output):
            self.selected_out[layer_name] = output
        return hook

    def remove_hooks(self):
        """Detach every registered forward hook from the backbone."""
        for handle in self.fhooks:
            handle.remove()
        self.fhooks = []

    def forward(self, x):
        """Run the backbone on ``x`` and return the hooked activations."""
        # A fresh dict per call, so that a dict returned by an earlier call
        # (e.g. the exemplar's maps) is not overwritten by this one.
        self.selected_out = {}
        # The hook handles in self.fhooks are kept, not reset: clearing the
        # list would not unregister the hooks, it would only make them
        # impossible to remove afterwards.
        self.pretrained(x)
        return self.selected_out
    

# VGG feature map model
class VGGFeatureMapExtractor(nn.Module):
    """Expose intermediate activations of an ImageNet-pretrained VGG-19.

    Only the convolutional ``features`` part of the network is used. Every
    ``nn.MaxPool2d`` in it is replaced by ``nn.AvgPool2d(2)``. Forward hooks
    store the output of each ReLU under ``'relu_<index>'`` and of each
    (replacement) pooling layer under ``'pool_<index>'``, where ``<index>`` is
    the layer's position in the ``features`` sequential. Calling the module
    returns these activations as a dict in execution order.
    """

    def __init__(self, *args):
        super().__init__(*args)
        # layer name -> activation captured during the most recent forward pass
        self.selected_out = {}
        self.pretrained = models.vgg19(weights='IMAGENET1K_V1').features.to(device).eval()
        # handles of the registered hooks, kept so they can be removed later
        self.fhooks = []

        # list(...) snapshots the children so the sequential can be edited while iterating
        for name, module in list(self.pretrained.named_children()):
            if isinstance(module, nn.ReLU):
                self.fhooks.append(module.register_forward_hook(self.forward_hook(f'relu_{name}')))
            if isinstance(module, nn.MaxPool2d):
                avg_pool = nn.AvgPool2d(2)
                self.pretrained[int(name)] = avg_pool
                # Hook the replacement layer. A hook on the discarded MaxPool2d
                # would never fire because that module is no longer part of
                # the network, and the pool_* maps would silently be missing.
                self.fhooks.append(avg_pool.register_forward_hook(self.forward_hook(f'pool_{name}')))

    def forward_hook(self, layer_name):
        """Return a hook that stores a module's output under ``layer_name``."""
        def hook(module, input, output):
            self.selected_out[layer_name] = output
        return hook

    def remove_hooks(self):
        """Detach every registered forward hook from the backbone."""
        for handle in self.fhooks:
            handle.remove()
        self.fhooks = []

    def forward(self, x):
        """Run the feature stack on ``x`` and return the hooked activations."""
        # A fresh dict per call, so that a dict returned by an earlier call
        # (e.g. the exemplar's maps) is not overwritten by this one.
        self.selected_out = {}
        # The hook handles in self.fhooks are kept, not reset: clearing the
        # list would not unregister the hooks, it would only make them
        # impossible to remove afterwards.
        self.pretrained(x)
        return self.selected_out

In [6]:
def gram(map):
    """Compute the (un-normalised) Gram matrix of a 4-D feature-map tensor.

    The first two dimensions are flattened into rows and the last two into
    columns, giving a matrix ``F`` of shape (d0*d1, d2*d3); the result is
    ``F @ F.T`` with shape (d0*d1, d0*d1). Note that when the first (batch)
    dimension is larger than 1, rows from different batch items end up in the
    same Gram matrix.

    Args:
        map: 4-D tensor of feature maps. It may be non-contiguous.

    Returns:
        2-D tensor holding the inner products between all flattened maps.
    """
    b, f, w, h = map.size()
    # flatten feature maps into matrix; reshape (unlike view) also accepts
    # non-contiguous inputs such as permuted or sliced tensors
    features = map.reshape(b * f, w * h)
    # calculate Gram matrix using transpose
    return torch.mm(features, features.t())


def synthesise(model, seed, exemplar, max_iter=200, stop_crit=1, _lr=1):
    """Optimise ``seed`` so that its feature-map Gram matrices match the exemplar's.

    Args:
        model: callable returning a dict of feature maps for an image tensor
            (values are 4-D tensors). Its parameters are frozen by this call.
        seed: image tensor to optimise. It is modified in place and has
            ``requires_grad`` switched on.
        exemplar: target image tensor.
        max_iter: maximum number of ``optimizer.step`` calls.
        stop_crit: stop once the average loss decrease per recorded evaluation,
            measured over the last 5 recorded losses, drops below this value.
        _lr: learning rate passed to L-BFGS.

    Returns:
        ``(seed, loss_plot)``: the optimised seed (same tensor object, clamped
        to [0, 1]) and the list of loss values, one per closure evaluation
        (L-BFGS may evaluate the closure several times per step).

    Raises:
        ValueError: if ``model`` returns no feature maps for the exemplar.
    """
    # optimise the seed and not the model parameters
    seed.requires_grad_(True)
    model.requires_grad_(False)

    # init LBFGS
    optimizer = optim.LBFGS([seed], lr=_lr)

    # The exemplar never changes, so its Gram matrices are computed once here
    # instead of being recomputed on every closure evaluation.
    with torch.no_grad():
        exemplar_grams = [gram(exemplar_map) for exemplar_map in model(exemplar).values()]
    if not exemplar_grams:
        raise ValueError("model returned no feature maps; nothing to optimise against")

    loss_plot = []
    window_size = 5  # Number of points to consider for calculating the gradient

    # L-BFGS step
    def closure():
        # correct the values of updated input image
        with torch.no_grad():
            seed.clamp_(0, 1)

        # reset gradient
        optimizer.zero_grad()

        # forward pass
        seed_maps = model(seed)

        # calculate loss: squared Gram difference per layer, normalised by the
        # layer's map count (squared) and map size
        loss = 0
        for seed_map, exemplar_gram in zip(seed_maps.values(), exemplar_grams):
            _, f, w, h = seed_map.size()
            seed_gram = gram(seed_map)
            loss += torch.sum(torch.square(torch.sub(exemplar_gram,
                              seed_gram))) / (f**2*(w*h))
        loss_plot.append(loss.item())

        # backprop
        loss.backward()

        return loss

    for e in range(max_iter):
        optimizer.step(closure)

        # Calculate the gradient of the loss
        if len(loss_plot) > window_size:
            recent_loss = loss_plot[-window_size:]
            gradient = (recent_loss[0] - recent_loss[-1]) / window_size

            # Termination condition based on loss gradient
            if abs(gradient) < stop_crit:
                print('low loss gradient')
                break

        # Additional termination condition: converged or diverged to NaN
        if (loss_plot[-1] < 1e-5) or (math.isnan(loss_plot[-1])):
            break

    # final correction
    with torch.no_grad():
        seed.clamp_(0, 1)

    return seed, loss_plot

In [11]:
# Batch texture synthesis over a folder of exemplars, plus its private helpers.
def _make_seed(seed_type, exemplar, blurrer, geo_w):
    """Build the start image for one exemplar; unknown types fall back to Gaussian noise."""
    target = exemplar.device
    # Noise is always drawn first so the random stream advances identically
    # for every seed type.
    noise = torch.randn(exemplar.shape).to(target)
    if seed_type == "blur":
        return blurrer(exemplar)
    if seed_type == "lfnoise":
        return blurrer(noise)
    # Geometric seeds are created directly on the exemplar's device. Left on
    # the CPU they could not be fed to a model living on the GPU.
    if seed_type == "square":
        return create_checkerboard(exemplar.shape, geo_w, device=target)
    if seed_type == "hline":
        return create_lines(exemplar.shape, geo_w, 'horizontal', device=target)
    if seed_type == "vline":
        return create_lines(exemplar.shape, geo_w, 'vertical', device=target)
    return noise


def _save_loss_plot(loss_plot, title, path):
    """Write a log-scale loss curve with at most ~10 evenly spaced x ticks to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(loss_plot)
    ax.set(xlabel='iterations', ylabel='loss', title=title)
    ax.grid()
    ax.set_yscale('log')

    # Determine the number of ticks and their spacing
    n_points = len(loss_plot)
    max_ticks = 10
    if n_points <= max_ticks:
        step = 1
    else:
        step = np.ceil(n_points / max_ticks)
        # Adjust step to ensure an even number of ticks
        if (n_points // step) % 2 != 0:
            step += 1

    ax.set_xticks(np.arange(0, n_points, step))
    fig.savefig(path)
    # Release the figure; one is created per texture and they would otherwise
    # all stay alive in pyplot's registry.
    plt.close(fig)


def _summarise_scores(syn_scores):
    """Return (individual, per-category lists, per-category averages sorted ascending)."""
    ind_score = {}
    cat_score = {}
    for entry in syn_scores:
        ind_score[entry['name']] = entry['score']
        # Category = everything before the last underscore; a name without an
        # underscore is treated as its own category.
        head, sep, _ = entry['name'].rpartition('_')
        cat = head if sep else entry['name']
        cat_score.setdefault(cat, []).append(entry['score'])

    # Average over the images actually present in each category
    avg_score = {cat: sum(values) / len(values) for cat, values in cat_score.items()}
    sorted_score = dict(sorted(avg_score.items(), key=lambda item: item[1]))
    return ind_score, cat_score, sorted_score


def run_syn(cnn, dataset, save_location, seed_type, iterations, stop_crit=0.01, kernel=(3,3), sigmaval=1, geo_w=2):
    """Synthesise a texture for every exemplar image in ``dataset`` and score it with DISTS.

    ``dataset`` must directly contain one sub-folder per texture category, each
    holding image files named like ``<category>_<id>.jpg``:

        dataset
          / texturecategory
              / texturecategory_001.jpg

    Args:
        cnn: feature-map extractor passed to ``synthesise`` (the VGG or ResNet model).
        dataset: folder containing the category sub-folders.
        save_location: output root; images and loss plots go to
            ``<save_location>/images``, pickled score dicts to ``<save_location>/scores``.
        seed_type: start image: ``"blur"`` (blurred exemplar), ``"lfnoise"``
            (blurred noise), ``"square"`` (checkerboard), ``"hline"`` / ``"vline"``
            (stripes); any other value gives Gaussian noise.
        iterations: maximum iteration count handed to ``synthesise`` for every texture.
        stop_crit: loss-gradient threshold handed to ``synthesise``.
        kernel, sigmaval: Gaussian blur settings for the blurred seed types.
        geo_w: square / stripe width for the geometric seed types.

    Returns:
        None. Results are written to disk.
    """
    # create folder for synthesised images
    syn_img_path = os.path.join(save_location, 'images')
    os.makedirs(syn_img_path, exist_ok=True)

    # create folder for DISTS scores
    dists_score_path = os.path.join(save_location, 'scores')
    os.makedirs(dists_score_path, exist_ok=True)

    # seed blur setting
    blurrer = transforms.GaussianBlur(kernel_size=kernel, sigma=sigmaval)

    # final DISTS scores
    syn_scores = []

    # get all exemplar category subfolders
    dtd_subfolders = [f.path for f in os.scandir(dataset) if f.is_dir()]

    # for every category
    for texture_dir in dtd_subfolders:

        # get category title from path
        name = os.path.basename(os.path.normpath(texture_dir))

        for texture in os.listdir(texture_dir):
            full_path = os.path.join(texture_dir, texture)
            # nested folders are not exemplars
            if not os.path.isfile(full_path):
                continue
            # specific texture name
            texture_name = os.path.splitext(texture)[0]

            # load exemplar as tensor and build the start image
            exemplar = image_loader(full_path)
            seed = _make_seed(seed_type, exemplar, blurrer, geo_w)

            # save seed image
            image_save(seed, syn_img_path, texture_name + "_seed")
            # save exemplar
            image_save(exemplar, syn_img_path, texture_name + "_example")
            # synthesise texture; `iterations` is deliberately never reassigned
            # so every texture gets the same iteration budget
            output, loss_plot = synthesise(cnn, seed, exemplar, iterations, stop_crit)
            # calculate DISTS score (pure evaluation, no autograd graph needed)
            with torch.no_grad():
                similarity_score = dists_model(exemplar, output)
            score_value = similarity_score.item()
            # Save synthesised image
            image_save(output, syn_img_path, texture_name + "_score=" + str(score_value))

            print(texture_name)

            _save_loss_plot(loss_plot, name,
                            os.path.join(save_location, "images", texture_name + "_lossplot.png"))

            # save DISTS score
            syn_scores.append({'name': texture_name, 'score': score_value})

    # group, average and sort the scores
    ind_score, cat_score, sorted_score = _summarise_scores(syn_scores)

    # save score dicts as files
    with open(os.path.join(dists_score_path, 'individual_scores'), 'wb') as f:
        pickle.dump(ind_score, f)
    with open(os.path.join(dists_score_path, 'categorical_scores'), 'wb') as f:
        pickle.dump(cat_score, f)
    with open(os.path.join(dists_score_path, 'averaged_scores'), 'wb') as f:
        pickle.dump(sorted_score, f)

In [8]:
# Build both feature extractors. Each constructor loads ImageNet-pretrained
# weights (weights='IMAGENET1K_V1') and already moves its backbone to `device`
# in eval mode; the extra .to(device) here moves the wrapper module as a whole.
ResNet = ResNetFeatureMapExtractor().to(device)
VGG = VGGFeatureMapExtractor().to(device)

In [None]:
# Synthesise every texture under ./dtd_small/ with the VGG extractor, starting
# from Gaussian noise; outputs land in ./images and ./scores.
run_syn(
    cnn=VGG,
    dataset='./dtd_small/',
    save_location='./',
    seed_type='noise',
    iterations=100,
    stop_crit=0.1,
)