# Style transfer with AdaIN

In [1]:
import os
import glob
import shutil
import subprocess

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn
from torchvision import transforms, datasets
from tqdm import tqdm

import util.util_validation as ut_val
from networks.resnet_big import SupCEResNet, SupConResNet, LinearClassifier, model_dict

# Switch seaborn/matplotlib to the "darkgrid" theme for all subsequent plots.
seaborn.set_theme(
    style="darkgrid",
)

In [5]:
# --- run configuration ---------------------------------------------------------
# GPU index exported as CUDA_VISIBLE_DEVICES for the AdaIN subprocess.
cuda_device = 0

# Dataset whose images provide the content ("shape") of the generated images.
dataset = "animals10_diff_-1"
# Texture images, organised as one sub-folder per style kind.
styles_path = "./datasets/adaIN/textures_animals10"
# Destination folder for the generated shape/texture-conflict images.
output_path = "./datasets/adaIN/shape_texture_conflict_animals10"

# Scratch folders that hold the content / style images of one AdaIN call.
tmp_shape = "./datasets/adaIN/tmp/shape/"
tmp_style = "./datasets/adaIN/tmp/style/"

# --- external pytorch-AdaIN checkout -------------------------------------------
adaIn_venv_interpreter = "./../adain_venv/bin/python"
addIn_path = "./../pytorch-AdaIN/"

# Script and pretrained weights, all resolved relative to the AdaIN checkout.
adaIn_execution_file, adaIn_vgg, adaIn_decoder = (
    os.path.join(addIn_path, relative)
    for relative in ("test.py", "models/vgg_normalised.pth", "models/decoder.pth")
)

# --- dataset metadata ----------------------------------------------------------
classes = ut_val.get_classes(dataset)
# Only the second root returned by get_root_dataset (used as root_test) is needed.
_, root_test = ut_val.get_root_dataset(dataset)

image_loader = datasets.folder.default_loader

In [3]:
def create_files_dataFrame(root, classes):
    """
    Creates a pandas DataFrame for an image dataset of the form ./root/class/img.png

    Parameters
    ----------
    root: str
        The path to the dataset
    classes: iterable
        Containing the class names (folder names in the dataset). The position
        of a name in this iterable is used as its integer label.

    Returns
    ---------
    : pandas.DataFrame
        DataFrame with a column ``image`` holding the file names of the images
        (relative to their class folder) and a column ``label`` holding the
        integer class label. Rows are grouped by class in the order of
        ``classes`` and sorted by file name within each class, so the result
        does not depend on the file system's listing order.
    """
    images = []
    labels = []
    for label, class_name in enumerate(classes):
        # Escape the folder part: class names may contain glob metacharacters
        # (e.g. "[") that would otherwise be interpreted as a pattern.
        class_dir = glob.escape(os.path.join(root, str(class_name)))
        # glob returns entries in arbitrary order; sort for reproducibility.
        for img_path in sorted(glob.glob(os.path.join(class_dir, "*"))):
            images.append(os.path.basename(img_path))
            labels.append(label)

    return pd.DataFrame.from_dict({'image': images, 'label': labels})

def adaIN(path_shape, path_style, path_output, size=300,
          adaIn_execution_file=adaIn_execution_file, adaIn_vgg=adaIn_vgg,
          adaIn_decoder=adaIn_decoder, cuda_device=cuda_device,
          adaIn_venv_interpreter=adaIn_venv_interpreter):
    """
    Run the external pytorch-AdaIN script on content ("shape") and style images.

    Parameters
    ----------
    path_shape: str
        A single content image file, or a folder of content images.
    path_style: str
        A single style image file, or a folder of style images.
    path_output: str
        Folder the script writes its results to; created if missing.
    size: int
        Value passed to both ``--content_size`` and ``--style_size``.
    adaIn_execution_file, adaIn_vgg, adaIn_decoder: str
        Paths of the AdaIN script and its VGG / decoder weights.
    cuda_device: int or str
        Exported as ``CUDA_VISIBLE_DEVICES`` for the subprocess only.
    adaIn_venv_interpreter: str
        Python interpreter used to run the script.

    Raises
    ------
    subprocess.CalledProcessError
        If the AdaIN script exits with a non-zero status.
    """
    os.makedirs(path_output, exist_ok=True)

    # A single file goes to --content/--style, a folder to --content_dir/--style_dir.
    content_flag = "--content" if os.path.isfile(path_shape) else "--content_dir"
    style_flag = "--style" if os.path.isfile(path_style) else "--style_dir"

    cmd = [
        str(adaIn_venv_interpreter), str(adaIn_execution_file),
        content_flag, str(path_shape),
        style_flag, str(path_style),
        "--content_size", str(size), "--style_size", str(size), "--crop",
        "--output", str(path_output),
        "--vgg", str(adaIn_vgg), "--decoder", str(adaIn_decoder),
    ]
    # Restrict only the child process to the requested GPU.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(cuda_device))

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact; check=True reports failures that os.system's
    # ignored exit code used to hide.
    subprocess.run(cmd, env=env, check=True)

In [4]:
# Give every test image a texture ("style") image from a style kind whose name
# differs from the image's class. A class's images are spread evenly over the
# eligible style kinds, and within a kind evenly over that kind's examples.

# Fixed seed so the shape/texture assignment can be reproduced.
random_seed = 0
rng = np.random.default_rng(random_seed)

# Sorted so the integer style labels (positions in style_kinds) do not depend
# on the file system's directory listing order.
style_kinds = np.array(sorted(
    os.path.basename(style_kind) for style_kind in glob.glob(os.path.join(styles_path, "*"))
))
style_paths = sorted(glob.glob(os.path.join(styles_path, "*", "*")))

# Example file names per style kind, listed once instead of once per class.
style_examples_of = {
    kind: sorted(
        os.path.basename(s_e)
        for s_e in glob.glob(os.path.join(glob.escape(os.path.join(styles_path, kind)), "*"))
    )
    for kind in style_kinds
}

df_test = create_files_dataFrame(root_test, classes)

df_test["style_label"] = -1
df_test["style_image"] = ""

for l, c in enumerate(tqdm(classes)):
    index_class = np.array(df_test[df_test.label == l].index)
    index_class = index_class[rng.permutation(len(index_class))]

    # Style kinds named like the class are excluded so shape and texture
    # always conflict; kinds without example images cannot be used at all.
    style_labels = np.array([
        k for k, kind in enumerate(style_kinds)
        if kind != c and style_examples_of[kind]
    ], dtype=int)
    if len(style_labels) == 0:
        # np.array_split cannot split into zero parts.
        continue
    index_split = np.array_split(index_class, len(style_labels))

    for index_style, style_label in zip(index_split, style_labels):
        df_test.loc[index_style, "style_label"] = style_label

        style_examples = style_examples_of[style_kinds[style_label]]
        index_examples = np.array_split(index_style, len(style_examples))
        for index_example, s_e in zip(index_examples, style_examples):
            df_test.loc[index_example, "style_image"] = s_e

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 124.05it/s]


In [10]:
# Stylise every (class, style example) group of test images with AdaIN and
# write the results to output_path/<class>/<style kind>.

# AdaIN is run on the whole scratch folders, so they must exist and start empty:
# files left over from an interrupted run would otherwise be styled as well.
os.makedirs(tmp_shape, exist_ok=True)
os.makedirs(tmp_style, exist_ok=True)


def _clear_tmp_dirs():
    """Delete every file in the scratch content and style folders."""
    for tmp_dir in (tmp_shape, tmp_style):
        for f in glob.glob(os.path.join(tmp_dir, "*")):
            os.remove(f)


_clear_tmp_dirs()

# Style kind name -> integer style label (its position in style_kinds).
style_label_of = {kind: label for label, kind in enumerate(style_kinds)}

for l, c in enumerate(classes):
    for s_path in tqdm(style_paths):
        s_kind = os.path.basename(os.path.dirname(s_path))
        s_e = os.path.basename(s_path)
        # Boolean masks instead of DataFrame.query, which breaks on file names
        # containing quotes. Matching the style label too keeps same-named
        # example files from different style kinds apart.
        mask = (
            (df_test["label"] == l)
            & (df_test["style_label"] == style_label_of[s_kind])
            & (df_test["style_image"] == s_e)
        )
        images = df_test.loc[mask, "image"].values
        if len(images) == 0:
            continue

        out_path = os.path.join(output_path, c, s_kind)

        try:
            image_loader(s_path).save(os.path.join(tmp_style, s_e))
            for img in images:
                img_path = os.path.join(root_test, c, img)
                image_loader(img_path).save(os.path.join(tmp_shape, img))

            adaIN(path_shape=tmp_shape, path_style=tmp_style, path_output=out_path, cuda_device=cuda_device)
        finally:
            # Empty the scratch folders even on failure so the next group does
            # not pick up stale images.
            _clear_tmp_dirs()

100%|██████████| 20/20 [01:26<00:00,  4.33s/it]
100%|██████████| 20/20 [01:13<00:00,  3.69s/it]
100%|██████████| 20/20 [01:07<00:00,  3.36s/it]
100%|██████████| 20/20 [01:10<00:00,  3.55s/it]
100%|██████████| 20/20 [01:16<00:00,  3.82s/it]
100%|██████████| 20/20 [01:08<00:00,  3.42s/it]
100%|██████████| 20/20 [01:08<00:00,  3.40s/it]
100%|██████████| 20/20 [01:08<00:00,  3.42s/it]
100%|██████████| 20/20 [01:25<00:00,  4.25s/it]
100%|██████████| 20/20 [01:08<00:00,  3.45s/it]
