# Style transfer with AdaIN

In [2]:
import os
import glob
import csv
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn
from torchvision import transforms, datasets
from tqdm import tqdm

import util.util_validation as ut_val
from networks.resnet_big import SupCEResNet, SupConResNet, LinearClassifier, model_dict
from util.util_logging import open_csv_file

seaborn.set_theme(style="darkgrid")

In [3]:
# GPU index used by the cells below; it is also the default `cuda_device`
# of adaIN / adaIN_many_individual, where it is exported as CUDA_VISIBLE_DEVICES.
cuda_device = 0

# Alternative configuration (disabled): city classification test set stylized with paintings.
# dataset = "city_classification_original"
# styles_path = "./datasets/adaIN/paintings/selected_for_styletransfer/test/"
# output_path = "./datasets/adaIN/stylized_city_classification/test/"

# Active configuration: name of the content dataset, directory holding the
# style images, and directory that receives the stylized results.
dataset = "animals10_diff_-1"
styles_path = "./datasets/adaIN/textures_animals10_many/"
output_path = "./datasets/adaIN/shape_texture_conflict_animals10_many/"

# Alternative configuration (disabled): animals10 training set stylized with paintings.
# dataset = "animals10_diff_-1"
# styles_path = "./datasets/adaIN/paintings/selected_for_styletransfer/train/"
# output_path = "./datasets/adaIN/stylized_animals10/train/"


# Scratch directories: the stylization loops copy the content ("shape") images
# and the style image of one batch here, run AdaIN on them and empty them again.
tmp_shape = "./datasets/adaIN/tmp/shape/"
tmp_style = "./datasets/adaIN/tmp/style/"

# Interpreter that executes the AdaIN scripts (presumably a dedicated virtual
# environment) and the pytorch-AdaIN directory, both relative to the working directory.
adaIn_venv_interpreter = "./../adain_venv/bin/python"
addIn_path = "./../pytorch-AdaIN/"

# Scripts and model weights inside the pytorch-AdaIN directory.
adaIn_execution_file = os.path.join(addIn_path, "test.py")
adaIn_execution_file_many = os.path.join(addIn_path, "test_many_individual.py")
adaIn_vgg = os.path.join(addIn_path, "models/vgg_normalised.pth")
adaIn_decoder = os.path.join(addIn_path, "models/decoder.pth")

# Project helpers: class (folder) names and the train/test root directories of `dataset`.
classes = ut_val.get_classes(dataset)
root_train, root_test = ut_val.get_root_dataset(dataset)

# torchvision's default image loader; the stylization loops use it to read
# images before re-saving them into the scratch directories.
image_loader = datasets.folder.default_loader

### Compute misclassified shape images

In [None]:
cuda_device = 0

# Models whose predictions decide which images are excluded. Each entry maps a
# model name to a two-element list: the checkpoint path and a second value
# (None for the SupCE models, "" for the SupCon models).
# NOTE(review): the meaning of the second element is defined by
# ut_val.compute_exclude_dict — presumably a classifier checkpoint; confirm there.
models_dict = {"CE_baseline": ["./save/SupCE/animals10/SupCE_animals10_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_baseline_cosine/models/last.pth", None],
               "CE_diffAug": ["./save/SupCE/animals10_diff_-1+4000/SupCE_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_diffAug_cosine/models/last.pth", None],
               "CE_diffAugAllAug": ["./save/SupCE/animals10_diff_-1+4000/SupCE_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_trial_0_diffAugAllAug_cosine/models/last.pth", None],
               "SupCon_baseline": ["./save/SupCon/animals10_diff_-1/SupCon_animals10_diff_-1_resnet18_lr_0.125_decay_0.0001_bsz_26_temp_0.1_trial_0_try3_cosine/models/last.pth", ""],
               "SupCon_diffCSameSAug": ["./save/SupCon/animals10_diff_-1+4000/SupCon_animals10_diff_-1+4000_resnet18_lr_0.125_decay_0.0001_bsz_26_temp_0.1_trial_0_colorAugSameShapeAug_cosine/models/last.pth", ""]}

dataset_stConflict = "./datasets/adaIN/shape_texture_conflict_animals10/"

exclude_original_dict = ut_val.compute_exclude_dict(models_dict, dataset_stConflict, cuda_device)

csv_file = "./datasets/adaIN/experiments/exclude_animals10.csv"
# Create the target directory on a fresh checkout instead of failing in open().
os.makedirs(os.path.dirname(csv_file), exist_ok=True)
# newline='' is required by the csv module; without it the writer's "\r\n"
# line terminator is translated a second time on platforms that rewrite "\n".
with open(csv_file, 'w', newline='') as f:
    # One header row (the dict keys) followed by a single row of values.
    w = csv.DictWriter(f, fieldnames=list(exclude_original_dict))
    w.writeheader()
    w.writerow(exclude_original_dict)

### Create texture shape conflict Dataset

In [11]:
def create_files_dataFrame(root, classes):
    """
    Creates a pandas DataFrame for an image dataset of the form ./root/class/img.png

    Every entry found directly inside ``root/<class>`` becomes one row. Rows are
    grouped by class in the order of ``classes`` and sorted by file name within
    a class, so the result does not depend on the file system's listing order.

    Parameters
    ----------
    root: str
        The path to the dataset
    classes: iterable
        Containing the class names (folder names in the dataset)

    Returns
    ---------
    : pandas.DataFrame
    DataFrame with columns image for the file names of the images
    (without any directory part) and label for the integer class label
    (the position of the class in ``classes``).
    """
    images = []
    labels = []
    for label, class_name in enumerate(classes):
        # glob returns matches in arbitrary order; sort for reproducibility
        class_paths = sorted(glob.glob(os.path.join(root, f"{class_name}/*")))

        # basename is independent of how root was spelled, unlike stripping
        # a re-built "root/class/" prefix with str.replace
        images.extend(os.path.basename(path) for path in class_paths)
        labels.extend([label] * len(class_paths))

    return pd.DataFrame.from_dict({'image': images, 'label': labels})

def adaIN(path_shape, path_style, path_output, size=300,
           adaIn_execution_file=adaIn_execution_file, adaIn_vgg=adaIn_vgg,
           adaIn_decoder=adaIn_decoder, cuda_device=cuda_device):
    """
    Runs the AdaIN style transfer script in a separate process.

    The script is executed with the interpreter ``adaIn_venv_interpreter`` and
    always receives the ``--crop`` flag. The call is best effort: a failing
    run does not raise, it only emits a warning.

    Parameters
    ----------
    path_shape: str
        Content image (passed as ``--content``) or directory of content
        images (passed as ``--content_dir``)
    path_style: str
        Style image (passed as ``--style``) or directory of style images
        (passed as ``--style_dir``)
    path_output: str
        Output directory, created if it does not exist
    size: int
        Value passed as both ``--content_size`` and ``--style_size``
    adaIn_execution_file: str
        Path to the AdaIN script to execute
    adaIn_vgg: str
        Path to the VGG weights (``--vgg``)
    adaIn_decoder: str
        Path to the decoder weights (``--decoder``)
    cuda_device: int
        Value exported to the child process as ``CUDA_VISIBLE_DEVICES``
    """
    import subprocess
    import warnings

    os.makedirs(path_output, exist_ok=True)

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact, which the former os.system string did not.
    arguments = [
        adaIn_venv_interpreter, adaIn_execution_file,
        "--content" if os.path.isfile(path_shape) else "--content_dir", path_shape,
        "--style" if os.path.isfile(path_style) else "--style_dir", path_style,
        "--content_size", size, "--style_size", size, "--crop",
        "--output", path_output, "--vgg", adaIn_vgg, "--decoder", adaIn_decoder,
    ]
    adaIn_call = [str(argument) for argument in arguments]

    # Same effect as the "CUDA_VISIBLE_DEVICES=<n> cmd" shell prefix.
    environment = {**os.environ, "CUDA_VISIBLE_DEVICES": str(cuda_device)}

    try:
        return_code = subprocess.run(adaIn_call, env=environment, check=False).returncode
    except OSError as error:
        # e.g. the interpreter does not exist; stay best effort like os.system
        warnings.warn(f"AdaIN could not be started: {error}")
        return

    if return_code != 0:
        warnings.warn(f"AdaIN exited with status {return_code} for output {path_output}")

def adaIN_many_individual(path_shape, path_style, path_content_style_csv, path_output, size=300,
           adaIn_execution_file_many=adaIn_execution_file_many, adaIn_vgg=adaIn_vgg,
           adaIn_decoder=adaIn_decoder, cuda_device=cuda_device):
    """
    Runs the AdaIN script that stylizes many content images with individual styles.

    The script is executed in a separate process with the interpreter
    ``adaIn_venv_interpreter`` and always receives the ``--crop`` flag. The
    call is best effort: a failing run does not raise, it only emits a warning.

    Parameters
    ----------
    path_shape: str
        Directory of content images (``--content_dir``)
    path_style: str
        Directory of style images (``--style_dir``)
    path_content_style_csv: str
        CSV file passed as ``--content_style_csv``; its format is defined by
        the executed script (presumably one content/style file name pair per row)
    path_output: str
        Output directory, created if it does not exist
    size: int
        Value passed as both ``--content_size`` and ``--style_size``
    adaIn_execution_file_many: str
        Path to the AdaIN script to execute
    adaIn_vgg: str
        Path to the VGG weights (``--vgg``)
    adaIn_decoder: str
        Path to the decoder weights (``--decoder``)
    cuda_device: int
        Value exported to the child process as ``CUDA_VISIBLE_DEVICES``
    """
    import subprocess
    import warnings

    os.makedirs(path_output, exist_ok=True)

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact, which the former os.system string did not.
    arguments = [
        adaIn_venv_interpreter, adaIn_execution_file_many,
        "--content_dir", path_shape,
        "--style_dir", path_style,
        "--content_style_csv", path_content_style_csv,
        "--content_size", size, "--style_size", size, "--crop",
        "--output", path_output, "--vgg", adaIn_vgg, "--decoder", adaIn_decoder,
    ]
    adaIn_call = [str(argument) for argument in arguments]

    # Same effect as the "CUDA_VISIBLE_DEVICES=<n> cmd" shell prefix.
    environment = {**os.environ, "CUDA_VISIBLE_DEVICES": str(cuda_device)}

    try:
        return_code = subprocess.run(adaIn_call, env=environment, check=False).returncode
    except OSError as error:
        # e.g. the interpreter does not exist; stay best effort like os.system
        warnings.warn(f"AdaIN could not be started: {error}")
        return

    if return_code != 0:
        warnings.warn(f"AdaIN exited with status {return_code} for output {path_output}")

------------

- stylize the training data with paintings (such that, within each class, no painting is used more frequently than any other)

In [12]:
# Every entry directly below styles_path is treated as a style image
# ("painting"); a painting's label is its position in painting_paths.
painting_paths = np.array(glob.glob(os.path.join(styles_path, "*")))
painting_labels = np.arange(len(painting_paths))

df_train = create_files_dataFrame(root_train, classes)

# Placeholder columns, filled class by class below.
# NOTE: the column name "painting_lable" is misspelled but must stay as is,
# because the cell that writes the content/style CSV sorts by this exact name.
df_train["painting_lable"] = len(df_train)*[-1]
df_train["painting"] = len(df_train)*[""]

for l, c in enumerate(tqdm(classes)):
    # Row indices of class l in random order.
    index_class = np.array(df_train[df_train.label == l].index)
    index_class = index_class[np.random.permutation(len(index_class))]

    if len(index_class) > len(painting_paths):
        # More images than paintings: every painting is used, each for one of
        # len(painting_paths) chunks whose sizes differ by at most one image.
        index_split = np.array_split(index_class, len(painting_paths))
        painting_labels_select = painting_labels
    else:
        # At most as many images as paintings: one image per chunk, each
        # paired with a different, randomly drawn painting.
        # NOTE(review): for an empty class np.array_split is asked for 0
        # sections and raises — confirm that no class is empty.
        index_split = np.array_split(index_class, len(index_class))
        painting_labels_select = painting_labels[np.random.permutation(len(painting_labels))][:len(index_class)]

    # Store the painting label and the painting's file name for every chunk.
    for i, pl in enumerate(painting_labels_select):
        df_train.loc[index_split[i],"painting_lable"] = pl
        df_train.loc[index_split[i],"painting"] = os.path.split(painting_paths[pl])[-1]

100%|██████████| 10/10 [00:05<00:00,  1.76it/s]


In [26]:
path_content_style_csv = "./datasets/adaIN/tmp/content_style_pair.csv"

for l, c in enumerate(tqdm(classes)):
    # (content image, painting) pairs of the current class, ordered by painting
    class_rows = df_train.query(f"label == {l}")
    class_rows = class_rows.sort_values("painting_lable").reset_index(drop=True)
    content_style_pairs = class_rows[["image", "painting"]]

    # Headerless two-column CSV; the same file is overwritten for every class.
    content_style_pairs.to_csv(path_content_style_csv, index=False, header=False)

    # Stylize the whole class folder in one call of the AdaIN script.
    adaIN_many_individual(
        path_shape=os.path.join(root_train, c),
        path_style=styles_path,
        path_content_style_csv=path_content_style_csv,
        path_output=os.path.join(output_path, c),
        cuda_device=cuda_device,
    )

100%|██████████| 10/10 [12:04<00:00, 72.46s/it]


- resize the output to $300\times300$

In [4]:
# Resize(300) with an int argument rescales the shorter image edge to 300
# pixels (see the torchvision documentation), so the result is exactly
# 300x300 only for square inputs.
resized_dataset = datasets.ImageFolder(root=output_path, transform=transforms.Resize(300))

# Overwrite every stylized image with its resized version.
for position, (img_path, _class_index) in enumerate(tqdm(resized_dataset.imgs)):
    resized_img, _target = resized_dataset[position]
    resized_img.save(img_path)

  0%|          | 0/4218 [00:00<?, ?it/s]

100%|██████████| 4218/4218 [00:11<00:00, 377.50it/s]


------------------

- stylize the dataset with paintings selected at random

In [6]:
# Names of the style sub-directories and paths of all style images inside them.
style_kinds = np.array([kind_dir.split('/')[-1] for kind_dir in glob.glob(os.path.join(styles_path, "*"))])
style_paths = glob.glob(os.path.join(styles_path, "*", "*"))

df_test = create_files_dataFrame(root_test, classes)

num_images = len(df_test)
num_styles = len(style_paths)

# Repeat every style index N times so that the pool covers all images, then
# shuffle the pool; each style is therefore drawn at most N times.
N = int(np.ceil(num_images / num_styles))
style_pool = np.repeat(np.arange(num_styles), repeats=N)
style_index_shuffled = style_pool[np.random.permutation(num_styles * N)]

# The first num_images entries of the shuffled pool are the assignment.
assigned_styles = style_index_shuffled[:num_images]
df_test["style_index"] = assigned_styles
df_test["style_image"] = [style_paths[index].split('/')[-1] for index in assigned_styles]

In [10]:
# The scratch directories must exist before images are saved into them.
os.makedirs(tmp_shape, exist_ok=True)
os.makedirs(tmp_style, exist_ok=True)

for l, c in enumerate(classes):
    out_path = os.path.join(output_path, c)

    for s_path in tqdm(style_paths):
        s_e = s_path.split('/')[-1]

        # A boolean mask (instead of a query string) also works for file
        # names that contain quote characters.
        df_shape_style = df_test[(df_test.label == l) & (df_test.style_image == s_e)]
        images = df_shape_style["image"].values
        if len(images) == 0:
            continue

        try:
            # Copy the style image and all content images assigned to it into
            # the scratch directories, then stylize them in one AdaIN call.
            image_loader(s_path).save(os.path.join(tmp_style, s_e))

            for img in images:
                img_path = os.path.join(root_test, c, img)
                image_loader(img_path).save(os.path.join(tmp_shape, img))

            adaIN(path_shape=tmp_shape, path_style=tmp_style, path_output=out_path, cuda_device=cuda_device)
        finally:
            # Always empty the scratch directories, also after an error or an
            # interrupt; left-over files would be stylized with the next batch.
            tmp_files = glob.glob(os.path.join(tmp_shape, "*"))
            tmp_files.extend(glob.glob(os.path.join(tmp_style, "*")))

            for f in tmp_files:
                os.remove(f)

100%|██████████| 54/54 [04:56<00:00,  5.49s/it]
100%|██████████| 54/54 [04:26<00:00,  4.94s/it]
100%|██████████| 54/54 [04:25<00:00,  4.92s/it]
100%|██████████| 54/54 [03:37<00:00,  4.03s/it]
100%|██████████| 54/54 [03:10<00:00,  3.52s/it]
100%|██████████| 54/54 [03:16<00:00,  3.64s/it]
100%|██████████| 54/54 [01:46<00:00,  1.97s/it]
100%|██████████| 54/54 [02:23<00:00,  2.66s/it]
100%|██████████| 54/54 [03:18<00:00,  3.67s/it]
100%|██████████| 54/54 [03:05<00:00,  3.44s/it]
100%|██████████| 54/54 [01:45<00:00,  1.96s/it]


-----------

- exclude misclassified images
- stylize each shape with a texture of all other classes

In [4]:
# Same path as the CSV written by the cell that computes exclude_original_dict.
csv_file = "./datasets/adaIN/experiments/exclude_animals10.csv"

# open_csv_file is a project helper; the following cell iterates the result
# for class names and indexes it by class name.
loaded_exclude_dict = open_csv_file(csv_file)

In [61]:
# Sorted names of the entries directly below styles_path (the style kinds)
# and the paths of all style images inside them.
style_kinds = np.sort(np.array([style_kind.split('/')[-1] for style_kind in glob.glob(os.path.join(styles_path, "*"))]))
style_paths = glob.glob(os.path.join(styles_path, "*", "*"))
# Every shape gets one style per style kind, minus one if at least one class
# name is also the name of a style kind.
num_styles_per_shape = len(style_kinds) + (-1 if len(set(classes).intersection(style_kinds)) > 0 else 0)

df_test = create_files_dataFrame(root_test, classes)

# File names up to the first '.'; not referenced again in this cell.
image_names = np.array([img.split('.')[0] for img in df_test["image"].values])

# Keep only the images that are not listed for exclusion. Only classes that
# occur in loaded_exclude_dict contribute rows; a class missing there is dropped.
keep_indices = []
for c in loaded_exclude_dict:
    # Integer label of class c; raises IndexError if c is not in classes.
    l = np.where(np.array(classes) == c)[0][0]
    # NOTE(review): if loaded_exclude_dict[c] is a plain string, `not in` is a
    # substring test rather than a list membership test — confirm what
    # open_csv_file returns.
    excl_class_indices = df_test.query(f"label=={l}")[["image"]].map(lambda img: img.split('.')[0] not in loaded_exclude_dict[c]).query("image").index
    keep_indices.extend(excl_class_indices)
# The collected values are index labels; iloc selects the same rows because
# df_test still has its default 0..n-1 index.
df_keep_test = df_test.iloc[keep_indices].copy().reset_index(drop=True)

# Balance the classes: randomly sample as many images per class as the
# smallest class has left.
min_class_count = df_keep_test.groupby("label").count().min()["image"]
df_stylize = df_keep_test.groupby("label").sample(min_class_count).copy().reset_index(drop=True)

# One placeholder column per style that a shape receives.
for i in range(num_styles_per_shape):
    df_stylize[f"style_label_{i}"] = len(df_stylize)*[-1]

# For class l the i-th style label is i, shifted by one from i >= l on, so
# that the value l itself is skipped.
# NOTE(review): this presumably assumes that position l in `classes` and
# position l in the sorted style_kinds name the same category — confirm.
for l,c in enumerate(classes):
    for i in range(num_styles_per_shape):
        class_index = df_stylize[df_stylize.label == l].index
        df_stylize.loc[class_index, f"style_label_{i}"] = df_stylize[df_stylize.label == l][[f"style_label_{i}"]].map(lambda _: i + (1 if i >= l else 0))

# Wide to long format: one row per (image, style label) pair.
df_stylize_list = [
    df_stylize[["image", "label", f"style_label_{i}"]].rename(columns={f"style_label_{i}": "style_label"})
    for i in range(num_styles_per_shape)
]
df_stylize = pd.concat(df_stylize_list).sort_values(["label", "image", "style_label"]).reset_index(drop=True)

# Pick a concrete style image for every row: the rows of one (class, style
# kind) pair are shuffled and split into nearly equal parts, one part per
# example image of that style kind.
df_stylize[f"style_image"] = len(df_stylize)*[""]
for tl,s in enumerate(style_kinds):
    for l,c in enumerate(classes):
        if s != c:
            style_index = df_stylize.query(f"label=={l} & style_label=={tl}").index
            if len(style_index) == 0:
                # Diagnostic output: no rows for this style kind / class pair.
                print(s, c)
            style_index = style_index[np.random.permutation(len(style_index))]

            # NOTE(review): a style kind without example images makes
            # np.array_split raise (0 sections) — confirm none is empty.
            style_examples = [s_e.split('/')[-1] for s_e in glob.glob(os.path.join(styles_path, s, "*"))]
            index_examples = np.array_split(style_index, len(style_examples))
            for j, s_e in enumerate(style_examples):
                df_stylize.loc[index_examples[j],"style_image"] = s_e

In [64]:
# The scratch directories must exist before images are saved into them.
os.makedirs(tmp_shape, exist_ok=True)
os.makedirs(tmp_style, exist_ok=True)

for l, c in enumerate(classes):
    for s_path in tqdm(style_paths):
        s_e = s_path.split('/')[-1]

        # A boolean mask (instead of a query string) also works for file
        # names that contain quote characters.
        df_shape_style = df_stylize[(df_stylize.label == l) & (df_stylize.style_image == s_e)]
        images = df_shape_style["image"].values
        if len(images) == 0:
            continue

        # All selected rows share one style image; the label of the first row
        # names the style kind and thereby the output sub-folder.
        style_label = df_shape_style["style_label"].values[0]
        out_path = os.path.join(output_path, c, style_kinds[style_label])

        try:
            # Copy the style image and all content images assigned to it into
            # the scratch directories, then stylize them in one AdaIN call.
            image_loader(s_path).save(os.path.join(tmp_style, s_e))

            for img in images:
                img_path = os.path.join(root_test, c, img)
                image_loader(img_path).save(os.path.join(tmp_shape, img))

            adaIN(path_shape=tmp_shape, path_style=tmp_style, path_output=out_path, cuda_device=cuda_device)
        finally:
            # Always empty the scratch directories, also after an error or an
            # interrupt; left-over files would be stylized with the next batch.
            tmp_files = glob.glob(os.path.join(tmp_shape, "*"))
            tmp_files.extend(glob.glob(os.path.join(tmp_style, "*")))

            for f in tmp_files:
                os.remove(f)

100%|██████████| 225/225 [15:53<00:00,  4.24s/it]
100%|██████████| 225/225 [15:38<00:00,  4.17s/it]
100%|██████████| 225/225 [14:53<00:00,  3.97s/it]
100%|██████████| 225/225 [15:33<00:00,  4.15s/it]
100%|██████████| 225/225 [15:39<00:00,  4.18s/it]
100%|██████████| 225/225 [14:51<00:00,  3.96s/it]
100%|██████████| 225/225 [15:42<00:00,  4.19s/it]
100%|██████████| 225/225 [15:25<00:00,  4.11s/it]
100%|██████████| 225/225 [16:04<00:00,  4.29s/it]
100%|██████████| 225/225 [14:57<00:00,  3.99s/it]


- resize the output to $300\times300$

In [15]:
# Project dataset over the stylized images; Resize(300) is applied on access.
conflict_dataset = ut_val.shapeTextureConflictDataset(root=output_path, transform=transforms.Resize(300))

# Overwrite every image file with its resized version.
num_samples = len(conflict_dataset)
for position in tqdm(range(num_samples)):
    resized_img = conflict_dataset[position][0]
    target_file = conflict_dataset.paths[position]
    resized_img.save(target_file)

100%|██████████| 23310/23310 [00:53<00:00, 436.26it/s]


--------------

- stylize each shape with one randomly selected texture of another class

In [4]:
# Names of the style sub-directories and paths of all style images inside them.
style_kinds = np.array([kind_dir.split('/')[-1] for kind_dir in glob.glob(os.path.join(styles_path, "*"))])
style_paths = glob.glob(os.path.join(styles_path, "*", "*"))

df_test = create_files_dataFrame(root_test, classes)

# Placeholder columns, filled class by class below.
df_test["style_label"] = [-1] * len(df_test)
df_test["style_image"] = [""] * len(df_test)

for l, c in enumerate(tqdm(classes)):
    # Shuffle the rows of class l and split them into one chunk per style
    # kind, leaving out the kind that carries the class's own name.
    class_rows = np.array(df_test[df_test.label == l].index)
    class_rows = class_rows[np.random.permutation(len(class_rows))]
    num_chunks = len(style_kinds) - int(c in style_kinds)
    row_chunks = np.array_split(class_rows, num_chunks)

    # Style kinds of the other classes together with their integer labels.
    is_foreign = style_kinds != c
    foreign_labels = np.arange(len(style_kinds))[is_foreign]
    foreign_kinds = style_kinds[is_foreign]

    for chunk_rows, kind_label, kind_name in zip(row_chunks, foreign_labels, foreign_kinds):
        df_test.loc[chunk_rows, "style_label"] = kind_label

        # Spread the chunk evenly over the example images of this style kind.
        example_files = [example.split('/')[-1] for example in glob.glob(os.path.join(styles_path, kind_name, "*"))]
        example_rows = np.array_split(chunk_rows, len(example_files))
        for rows, example_file in zip(example_rows, example_files):
            df_test.loc[rows, "style_image"] = example_file

100%|██████████| 10/10 [00:00<00:00, 79.93it/s]


In [10]:
# The scratch directories must exist before images are saved into them.
os.makedirs(tmp_shape, exist_ok=True)
os.makedirs(tmp_style, exist_ok=True)

for l, c in enumerate(classes):
    for s_path in tqdm(style_paths):
        s_e = s_path.split('/')[-1]

        # A boolean mask (instead of a query string) also works for file
        # names that contain quote characters.
        df_shape_style = df_test[(df_test.label == l) & (df_test.style_image == s_e)]
        images = df_shape_style["image"].values
        if len(images) == 0:
            continue

        # All selected rows share one style image; the label of the first row
        # names the style kind and thereby the output sub-folder.
        style_label = df_shape_style["style_label"].values[0]
        out_path = os.path.join(output_path, c, style_kinds[style_label])

        try:
            # Copy the style image and all content images assigned to it into
            # the scratch directories, then stylize them in one AdaIN call.
            image_loader(s_path).save(os.path.join(tmp_style, s_e))

            for img in images:
                img_path = os.path.join(root_test, c, img)
                image_loader(img_path).save(os.path.join(tmp_shape, img))

            adaIN(path_shape=tmp_shape, path_style=tmp_style, path_output=out_path, cuda_device=cuda_device)
        finally:
            # Always empty the scratch directories, also after an error or an
            # interrupt; left-over files would be stylized with the next batch.
            tmp_files = glob.glob(os.path.join(tmp_shape, "*"))
            tmp_files.extend(glob.glob(os.path.join(tmp_style, "*")))

            for f in tmp_files:
                os.remove(f)

100%|██████████| 20/20 [01:26<00:00,  4.33s/it]
100%|██████████| 20/20 [01:13<00:00,  3.69s/it]
100%|██████████| 20/20 [01:07<00:00,  3.36s/it]
100%|██████████| 20/20 [01:10<00:00,  3.55s/it]
100%|██████████| 20/20 [01:16<00:00,  3.82s/it]
100%|██████████| 20/20 [01:08<00:00,  3.42s/it]
100%|██████████| 20/20 [01:08<00:00,  3.40s/it]
100%|██████████| 20/20 [01:08<00:00,  3.42s/it]
100%|██████████| 20/20 [01:25<00:00,  4.25s/it]
100%|██████████| 20/20 [01:08<00:00,  3.45s/it]
