In [None]:
# NOTE: never commit a personal access token; export it as GITHUB_TOKEN (or use a credential helper) instead.
# !git clone https://${GITHUB_TOKEN}@github.com/NoahVl/Explaining-In-Style-Reproducibility-Study.git
# %cd Explaining-In-Style-Reproducibility-Study
# !git checkout main

In [None]:
# !pip install fire
# !pip install lpips
# !pip install einops
# !pip install kornia
# !pip install vector_quantize_pytorch
# !pip install Pillow
# !pip install pathlib
# !pip install aim

In [None]:
import os
import sys
import h5py
import numpy as np

import torch
from torch.utils.data import DataLoader
import math
import tqdm
import random
import imageio

import multiprocessing
from torchvision.utils import make_grid
from PIL import Image
import ast
import torchvision
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import resize

import requests
from PIL import ImageDraw
from PIL import ImageFont
from io import BytesIO
import IPython.display
from IPython.display import HTML
import matplotlib.pyplot as plt
from shutil import copyfile
import IPython.display as IPython_display
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

In [None]:
# %cd /kaggle/working
# !mkdir ./trained_classifiers
# copyfile("../input/facesattfind-all/faces-classifier.pt", "./trained_classifiers/faces-classifier.pt")

In [None]:
# %cd /kaggle/working
# %cd ../input/facesattfind-all
from resnet_classifier import ResNet
# %cd /kaggle/working

In [None]:
# %cd /kaggle/working
# %cd ../input/facesattfind-all

# TODO: You guys might want to change this to stylex_train_new
from stylex_train import StylEx, Dataset, DistributedSampler, MNIST_1vA
from stylex_train import image_noise, styles_def_to_tensor, make_weights_for_balanced_classes, cycle, default
# %cd /kaggle/working

In [None]:
def load_hdf5_results(data_file, name, threshold):
    """
    Return the leading entries of a stored dataset as a NumPy array.

    :param data_file: Mapping-like container (for example an open ``h5py.File``)
        that yields the dataset stored under ``name``.
    :param name: Key of the dataset to read.
    :param threshold: Exclusive upper bound of the slice taken along the first axis.
    :return: A new ``numpy.ndarray`` holding ``data_file[name][0:threshold]``.
    """
    # Slice first, convert second: only the requested rows are materialised
    # instead of converting the complete dataset and discarding most of it.
    return np.array(data_file[name][0:threshold])

def model_loader(stylex_path,
                 classifier_name,
                 image_size,
                 cuda_rank):
    """
    Build the StylEx model and the classifier used by this notebook.

    :param stylex_path: Path of the checkpoint file; its ``"StylEx"`` entry is
        loaded as the model's state dict.
    :param classifier_name: Forwarded as first argument to ``ResNet``.
    :param image_size: Forwarded to both ``StylEx`` and ``ResNet``.
    :param cuda_rank: Forwarded to ``ResNet`` as ``cuda_rank``.
    :return: Tuple ``(stylex, classifier)``.
    """
    stylex = StylEx(image_size=image_size)

    # Deserialize onto the CPU: without map_location the tensors are restored
    # to the device they were saved from, which fails when that GPU does not
    # exist here. load_state_dict copies the values into the model's own
    # parameters, so the model's device placement is unaffected.
    checkpoint = torch.load(stylex_path, map_location="cpu")
    stylex.load_state_dict(checkpoint["StylEx"])

    classifier = ResNet(classifier_name, cuda_rank=cuda_rank, output_size=2, image_size=image_size)
    return stylex, classifier

def sindex_to_block_idx_and_index(generator, sindex):
    """
    Translate a flat style index into a block index and an index inside that block.

    The blocks of ``generator.blocks`` are walked in order; every block accounts
    for ``block.num_style_coords`` consecutive flat indices.

    :param generator: Object exposing an iterable ``blocks`` attribute whose
        items have a ``num_style_coords`` attribute.
    :param sindex: Flat style index.
    :return: ``(block_idx, index_in_block)``, or ``(None, None)`` when ``sindex``
        lies beyond the last block.
    """
    remaining = sindex

    for block_idx, block in enumerate(generator.blocks):
        if remaining < block.num_style_coords:
            return block_idx, remaining
        # Rebind instead of using ``-=`` so a mutable caller value is never modified.
        remaining = remaining - block.num_style_coords

    # Not found: report both parts as missing instead of leaking the loop counter.
    return None, None

def plot_image(tensor, upscale_res=None):
    """
    Show a batch of images inline as a grid with five images per row.

    :param tensor: Image batch handed to ``make_grid``
        (presumably values in [0, 1], since they are multiplied by 255 - TODO confirm).
    :param upscale_res: Optional size passed to ``resize`` before the grid is built.
    """
    batch = tensor if upscale_res is None else resize(tensor, upscale_res)
    image_grid = make_grid(batch, nrow=5)

    # Multiply by 255 and add 0.5 so the uint8 cast rounds to the nearest integer.
    scaled = image_grid.mul(255).add_(0.5).clamp_(0, 255)
    pixels = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    display(Image.fromarray(pixels))

In [None]:
# --- Paths -------------------------------------------------------------------
data_path = "./data"
stylex_path = "./models/Faces-Resnet-ResizeFix64/model_300.pt"
classifier_name = "resnet-18-64px-unfreezel4.pt"
results_folder = "./"
threshold_folder = "./"

# --- Run settings ------------------------------------------------------------
dataset_name = None  # for any dataset that is not MNIST
cuda_rank = 0  # device index handed to every .cuda(...) call in this notebook

In [None]:
# Number of leading records read from each per-image dataset.
threshold_index = 501

# Read everything inside a context manager so the HDF5 handle is closed once
# the data has been copied into memory (every value below is a NumPy array or
# a tensor built from one, so nothing refers to the file afterwards).
with h5py.File('./style_change_records.hdf5', 'r') as hf:
    style_change_effect = load_hdf5_results(hf, "style_change", threshold_index)
    W_values = load_hdf5_results(hf, "latents", threshold_index)
    base_probs = load_hdf5_results(hf, "base_prob", threshold_index)
    all_style_vectors = load_hdf5_results(hf, "style_coordinates", threshold_index)
    original_images = load_hdf5_results(hf, "original_images", threshold_index)
    discriminator_results = load_hdf5_results(hf, "discriminator", threshold_index)

    saved_noise = torch.Tensor(np.array(hf["noise"])).cuda(cuda_rank)
    style_min = torch.Tensor(np.squeeze(np.array(hf["minima"])))
    style_max = torch.Tensor(np.squeeze(np.array(hf["maxima"])))

# Distance of every style coordinate to its minimum ([..., 0]) and to its
# maximum ([..., 1]). Broadcasting the per-coordinate bounds over the image
# axis replaces the explicit np.tile copies.
style_min_np = style_min.numpy()
style_max_np = style_max.numpy()
all_style_vectors_distances = np.zeros((all_style_vectors.shape[0], all_style_vectors.shape[1], 2))
all_style_vectors_distances[:, :, 0] = all_style_vectors - style_min_np
all_style_vectors_distances[:, :, 1] = style_max_np - all_style_vectors

# Move the bounds to the GPU only after the CPU-side distances are computed.
style_min = style_min.cuda(cuda_rank)
style_max = style_max.cuda(cuda_rank)

In [None]:
# Fixed settings
shift_size = 0.5
batch_size = 1

# Values derived from the loaded records
num_style_coords = len(style_min)
image_size = original_images.shape[-1]  # last axis of the stored images

In [None]:
# Instantiate the StylEx model and the classifier from the configured paths.
stylex, classifier = model_loader(
    stylex_path=stylex_path,
    classifier_name=classifier_name,
    image_size=image_size,
    cuda_rank=cuda_rank,
)

In [None]:
# Predicted class per image: the index of the highest base probability.
all_labels = np.argmax(base_probs, axis=1)
style_effect_classes = {}
W_classes = {}
style_vectors_distances_classes = {}
all_style_vectors_classes = {}

for img_ind in range(2):
    # np.flatnonzero always yields an integer index array, so a class without
    # any image produces empty results instead of an IndexError (an empty
    # np.array([]) is float-typed and cannot be used as an index).
    img_inx = np.flatnonzero(all_labels == img_ind)

    # Fancy indexing copies the selected rows in one step; the float64 cast
    # keeps the dtype that the previous np.zeros-based buffers had.
    style_effect_classes[img_ind] = style_change_effect[img_inx].astype(np.float64)
    W_classes[img_ind] = W_values[img_inx].astype(np.float64)
    style_vectors_distances_classes[img_ind] = all_style_vectors_distances[img_inx].astype(np.float64)
    all_style_vectors_classes[img_ind] = all_style_vectors[img_inx]
    print(f'Class {img_ind}, {len(img_inx)} images.')

In [None]:
def find_significant_styles(style_change_effect,
                            num_indices,
                            class_index,
                            generator,
                            classifier,
                            all_dlatents,
                            style_min,
                            style_max,
                            max_image_effect = 0.2,
                            label_size = 2,
                            sindex_offset = 0):
    """
    Greedily pick the (direction, style index) pairs with the largest mean effect.

    In every round the candidate with the highest mean positive effect is
    chosen, where the mean is taken only over images whose accumulated effect
    is still below ``max_image_effect``. The chosen candidate is then zeroed
    so it cannot be selected again.

    :param style_change_effect: Array indexed as
        ``[image, direction, style, class]``; it is not modified.
    :param num_indices: Maximum number of pairs to return.
    :param class_index: Index on the last axis whose effect is considered.
    :param generator: Unused; kept for interface compatibility.
    :param classifier: Unused; kept for interface compatibility.
    :param all_dlatents: Unused; kept for interface compatibility.
    :param style_min: Unused; kept for interface compatibility.
    :param style_max: Unused; kept for interface compatibility.
    :param max_image_effect: Accumulated effect at which an image stops
        contributing to the selection.
    :param label_size: Unused; kept for interface compatibility.
    :param sindex_offset: Constant added to every returned style index.
    :return: List of ``(direction, style_index + sindex_offset)`` tuples. The
        list is shorter than ``num_indices`` when every image reached
        ``max_image_effect`` before enough pairs were found.
    """
    num_images = style_change_effect.shape[0]
    num_styles = style_change_effect.shape[2]

    # Keep only positive effects and flatten (direction, style) into one axis.
    # np.maximum returns a fresh array, so the zeroing below never touches the input.
    style_effect_direction = np.maximum(0, style_change_effect[:, :, :, class_index].reshape((num_images, -1)))

    images_effect = np.zeros(num_images)
    all_sindices = []

    while len(all_sindices) < num_indices:
        unexplained = images_effect < max_image_effect
        if not unexplained.any():
            # Averaging over an empty selection yields NaN, and argmax of NaNs
            # is always 0, which would append bogus duplicates of candidate 0.
            break

        next_s = np.argmax(np.mean(style_effect_direction[unexplained], axis=0))

        all_sindices.append(next_s)
        images_effect += style_effect_direction[:, next_s]
        style_effect_direction[:, next_s] = 0

    # Undo the flattening: quotient is the direction, remainder the style index.
    return [(x // num_styles, (x % num_styles) + sindex_offset) for x in all_sindices]

In [None]:
label_size_clasifier = 2
num_indices = 10
effect_threshold = 0.5
s_indices_and_signs_dict = {}

# Select the most influential (direction, style index) pairs for each class,
# using only the images that were assigned to that class.
for class_index in (0, 1):
    split_ind = class_index  # alternative that was tried: 1 - class_index
    all_s = style_effect_classes[split_ind]
    all_w = W_classes[split_ind]

    s_indices_and_signs_dict[class_index] = find_significant_styles(
        style_change_effect=all_s,
        num_indices=num_indices,
        class_index=class_index,
        generator=stylex.G,
        classifier=classifier,
        all_dlatents=all_w,
        style_min=style_min,
        style_max=style_max,
        max_image_effect=effect_threshold * 5,
        label_size=label_size_clasifier,
        sindex_offset=0,
    )

# Merge both selections from the point of view of class 0: pairs found for
# class 1 get their direction flipped and are dropped when class 0 already
# selected the same style index.
sindex_class_0 = [pair[1] for pair in s_indices_and_signs_dict[0]]
all_sindex_joined_class_0 = [
    (1 - direction, sindex)
    for direction, sindex in s_indices_and_signs_dict[1]
    if sindex not in sindex_class_0
]
all_sindex_joined_class_0.extend(s_indices_and_signs_dict[0])

# Score each pair by its mean effect towards class 0 in the given direction
# plus its mean effect towards class 1 in the opposite direction.
scores = []
for direction, sindex in all_sindex_joined_class_0:
    other_direction = 0 if direction != 0 else 1
    towards_class_0 = np.mean(style_change_effect[:, direction, sindex, 0])
    towards_class_1 = np.mean(style_change_effect[:, other_direction, sindex, 1])
    scores.append(towards_class_0 + towards_class_1)

# Highest score first.
ranking = np.argsort(scores)[::-1]
s_indices_and_signs = [all_sindex_joined_class_0[i] for i in ranking]

print('Directions and style indices for moving from class 1 to class 0 = ', s_indices_and_signs[:num_indices])
print('Use the other direction to move for class 0 to 1.')

In [None]:
def generate_user_study_img(tensor, upscale_res=None, nrow=2) -> "Image.Image":
    """
    Arrange a batch of images into a grid and return it as a PIL image.

    Unlike ``plot_image`` nothing is displayed; the image is returned so the
    caller can save or annotate it.

    :param tensor: Image batch handed to ``make_grid``
        (presumably values in [0, 1], since they are multiplied by 255 - TODO confirm).
    :param upscale_res: Optional size passed to ``resize`` before the grid is built.
    :param nrow: Number of images per grid row.
    :return: The rendered grid as a ``PIL.Image.Image``.
    """
    if upscale_res is not None:
        tensor = resize(tensor, upscale_res)

    grid = make_grid(tensor, nrow=nrow)
    # mul() is out-of-place, so the in-place ops that follow operate on a fresh
    # tensor and the caller's data is never modified (no defensive clone needed).
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    return Image.fromarray(ndarr)

def get_images(dlatent,
                generator,
                classifier,
                sindex,
                s_style_min,
                s_style_max,
                style_direction_index,
                shift_size,
                label_size,
                noise, 
                cuda_rank):
    """
    Generate an image and a copy of it with one style coordinate shifted.

    The shift is applied by temporarily adding an offset to the bias of the
    style layer that produces coordinate ``sindex``; the bias is restored
    before returning, also when generation fails.

    :param dlatent: Latent vector of one image (array-like, converted to a tensor).
    :param generator: Generator exposing ``blocks`` and callable as
        ``generator(w, noise, get_style_coords=True)``.
    :param classifier: Unused; kept for interface compatibility.
    :param sindex: Flat index of the style coordinate to change.
    :param s_style_min: Minimum of that coordinate (target for direction 0).
    :param s_style_max: Maximum of that coordinate (target for any other direction).
    :param style_direction_index: 0 moves towards ``s_style_min``, anything else
        towards ``s_style_max``.
    :param shift_size: Fraction of the distance to the target that is applied.
    :param label_size: Unused; kept for interface compatibility.
    :param noise: Noise input forwarded to the generator.
    :param cuda_rank: CUDA device index for the tensors created here.
    :return: Tuple ``(generated_image, changed_image)``.
    :raises ValueError: If ``sindex`` does not belong to any generator block.
    """
    block_idx, weight_idx = sindex_to_block_idx_and_index(generator, sindex)
    if block_idx is None:
        raise ValueError("Style index %s is outside the style coordinates of the generator" % sindex)
    block = generator.blocks[block_idx]

    # The first ``input_channels`` coordinates of a block come from to_style1,
    # the remaining ones from to_style2.
    if weight_idx < block.input_channels:
        current_style_layer = block.to_style1
        one_hot = torch.zeros((1, block.input_channels)).cuda(cuda_rank)
    else:
        weight_idx -= block.input_channels
        current_style_layer = block.to_style2
        one_hot = torch.zeros((1, block.filters)).cuda(cuda_rank)

    one_hot[:, weight_idx] = 1

    s_target = s_style_min if style_direction_index == 0 else s_style_max

    # Pure inference: no autograd graph is needed for either forward pass.
    with torch.no_grad():
        # NOTE(review): 5 is presumably the generator's layer count for the
        # image size used here - confirm before using another resolution.
        dlatent = [(torch.unsqueeze(torch.Tensor(dlatent).cuda(cuda_rank), 0), 5)]
        w_latent_tensor = styles_def_to_tensor(dlatent)
        generated_image, style_coords = generator(w_latent_tensor, noise, get_style_coords=True)

        shift = one_hot * ((s_target - style_coords[:, sindex]) * shift_size).unsqueeze(1)
        shift = shift.squeeze(0)

        bias = current_style_layer.bias
        # Keep an exact copy: undoing with ``-=`` accumulates rounding drift
        # over many calls, and would be skipped entirely if generation raised.
        original_bias = bias.detach().clone()
        bias += shift
        try:
            changed_image, _ = generator(w_latent_tensor, noise, get_style_coords=True)
        finally:
            bias.copy_(original_bias)

    return generated_image, changed_image

In [None]:
def user_study1(stylex,
                classifier,
                gender,
                top_attribute,
                other_attribute,
                s_indices_and_signs,
                base_probs,
                W_values,
                style_min,
                style_max,
                shift_size,
                label_size,
                noise,
                cuda_rank,
                upscale_res):
    """
    Build the before/after image grids for one question of user study 1.

    Four random images of the requested class are used. The left column of the
    2x2 grid always shows ``top_attribute``; the right column shows
    ``top_attribute`` and ``other_attribute`` in random order.

    :param stylex: Model whose ``G`` attribute is the generator.
    :param classifier: Forwarded to ``get_images``.
    :param gender: ``"male"`` (class 0) or ``"female"`` (class 1).
    :param top_attribute: Position in ``s_indices_and_signs`` of the attribute under test.
    :param other_attribute: Position in ``s_indices_and_signs`` of the distractor attribute.
    :param s_indices_and_signs: Sequence of ``(direction, style_index)`` pairs.
    :param base_probs: Per-image class probabilities; the argmax selects the class.
    :param W_values: Per-image latent vectors, indexable with a list of image ids.
    :param style_min: Per-coordinate minima, indexed by style index.
    :param style_max: Per-coordinate maxima, indexed by style index.
    :param shift_size: Forwarded to ``get_images``.
    :param label_size: Forwarded to ``get_images``.
    :param noise: Forwarded to ``get_images``.
    :param cuda_rank: Forwarded to ``get_images``.
    :param upscale_res: If not ``None``, both batches are rendered to PIL grids
        of this resolution; otherwise the raw tensors are returned.
    :return: ``([original, changed], indices_and_signs)`` where
        ``indices_and_signs`` holds the pair used for each of the four images.
    :raises ValueError: For an unknown attribute position, an unknown gender or
        fewer than four images of the requested class.
    """
    num_attributes = len(s_indices_and_signs)
    for attribute in (other_attribute, top_attribute):
        if attribute >= num_attributes:
            raise ValueError("That attribute is not included in the attribute list, it only includes %s attributes" % num_attributes)

    if gender not in ("male", "female"):
        raise ValueError("Please use either the male or female gender")
    gender_index = 0 if gender == "male" else 1

    gender_indices = list(np.where(np.argmax(base_probs, axis=1) == gender_index)[0])
    if len(gender_indices) < 4:
        raise ValueError("At least 4 images of gender %s are required, found %s" % (gender, len(gender_indices)))
    image_ids = random.sample(gender_indices, 4)
    four_latents = W_values[image_ids]

    top_pair = s_indices_and_signs[top_attribute]
    other_pair = s_indices_and_signs[other_attribute]
    indices_and_signs = np.array([top_pair] * 2 + random.sample([top_pair, other_pair], 2))
    # Reorder to [top, x, top, y] so that, with two images per row, the left
    # column is always the top attribute and the right column is shuffled.
    indices_and_signs = indices_and_signs[[0, 2, 1, 3]]

    g_images = []
    c_images = []

    for index in range(4):
        direction_index, style_index = indices_and_signs[index]
        # The stored direction is flipped for class 0 and used as-is for class 1.
        if gender_index == 0:
            style_direction = 1 if direction_index == 0 else 0
        else:
            style_direction = direction_index

        generated_image, changed_image = get_images(four_latents[index],
                                                    stylex.G,
                                                    classifier,
                                                    style_index,
                                                    style_min[style_index],
                                                    style_max[style_index],
                                                    style_direction,
                                                    shift_size,
                                                    label_size,
                                                    noise,
                                                    cuda_rank)
        g_images.append(generated_image)
        c_images.append(changed_image)

    g_images = torch.cat(g_images)
    c_images = torch.cat(c_images)

    if upscale_res is not None:
        g_images = generate_user_study_img(g_images, upscale_res=upscale_res, nrow=2)
        c_images = generate_user_study_img(c_images, upscale_res=upscale_res, nrow=2)

    return [g_images, c_images], indices_and_signs

In [None]:
top_k_range = 6
if top_k_range < 2:
    raise ValueError("User study 1 needs at least two attributes to pair a top attribute with a different one")

# Render the images
images_to_save = []
info_of_images = []

# Pair every top attribute with a *different* attribute, using each attribute
# exactly once: reshuffle until no attribute sits at its own position. Popping
# greedily from a shuffled list can leave only the current top attribute for
# the last question, in which case the previous pairing was silently reused.
unique_attributes = list(range(top_k_range))
random.shuffle(unique_attributes)
while any(attr == position for position, attr in enumerate(unique_attributes)):
    random.shuffle(unique_attributes)

for top_attribute in range(top_k_range):
    # Alternate the class of the shown images between questions.
    gender = "male" if top_attribute % 2 == 0 else "female"
    other_attribute = unique_attributes[top_attribute]

    images_to_render, attribute_info = user_study1(stylex = stylex,
                                        classifier = classifier,
                                        gender = gender,
                                        top_attribute = top_attribute,
                                        other_attribute = other_attribute,
                                        s_indices_and_signs = s_indices_and_signs,
                                        base_probs = base_probs,
                                        W_values = W_values,
                                        style_min = style_min,
                                        style_max = style_max,
                                        shift_size = 1,
                                        label_size = 2,
                                        noise = saved_noise,
                                        cuda_rank = cuda_rank,
                                        upscale_res=512)

    images_to_save.append(images_to_render)
    info_of_images.append((attribute_info, (top_attribute, other_attribute)))

In [None]:
from PIL import Image, ImageFont, ImageDraw
import IPython.display as IPython_display

# Save the images of user study 1 to disk
user_study_dir = "./user_study_images"
user_study1_dir = os.path.join(user_study_dir, "study_1")
user_study2_dir = os.path.join(user_study_dir, "study_2")

# Create each study directory independently: checking only the parent leaves
# the sub-directories missing whenever the parent already exists.
os.makedirs(user_study1_dir, exist_ok=True)
os.makedirs(user_study2_dir, exist_ok=True)

# One two-frame GIF (original grid, changed grid) per question.
for index, images in enumerate(images_to_save):
    file_path = os.path.join(user_study1_dir, f"class_study_{index}.gif")
    imageio.mimsave(file_path, images, fps=1 + 1/3)
    display(IPython_display.Image(filename=file_path))

# Save the info_of_images to a text file
with open(os.path.join(user_study1_dir, "info_of_images.txt"), "w") as f:
    for directions_and_sindex, chosen_attributes in info_of_images:
        # Row 0 is the top-left image (always the top attribute); row 1 is the
        # top-right image. If they match, the repeat is top-right, else bottom-right.
        target_attr = directions_and_sindex[0]
        correct_answer = "top-right" if (target_attr == directions_and_sindex[1]).all() else "bottom-right"
        f.write(f"Same transformation in {correct_answer} \n {chosen_attributes} \n {str(directions_and_sindex)} \n\n")

In [None]:
def user_study2(stylex,
                classifier,
                top_attribute,
                s_indices_and_signs,
                W_values,
                style_min,
                style_max,
                shift_size,
                label_size,
                noise,
                cuda_rank,
                upscale_res):
    """
    Build the before/after image rows for one attribute of user study 2.

    Four random images are generated, once unchanged and once with the style
    coordinate of ``top_attribute`` shifted in its stored direction.

    :param stylex: Model whose ``G`` attribute is the generator.
    :param classifier: Forwarded to ``get_images``.
    :param top_attribute: Position in ``s_indices_and_signs`` of the attribute to show.
    :param s_indices_and_signs: Sequence of ``(direction, style_index)`` pairs.
    :param W_values: Per-image latent vectors, indexable with a list of image ids.
    :param style_min: Per-coordinate minima, indexed by style index.
    :param style_max: Per-coordinate maxima, indexed by style index.
    :param shift_size: Forwarded to ``get_images``.
    :param label_size: Forwarded to ``get_images``.
    :param noise: Forwarded to ``get_images``.
    :param cuda_rank: Forwarded to ``get_images``.
    :param upscale_res: If not ``None``, both batches are rendered to PIL grids
        (four images per row) of this resolution; otherwise the raw tensors
        are returned.
    :return: ``[original, changed]``.
    """
    image_ids = random.sample(range(len(W_values)), 4)
    four_latents = W_values[image_ids]

    # All four images use the same attribute, so the pair is looked up once
    # (no per-image table is needed).
    style_direction, style_index = s_indices_and_signs[top_attribute]

    g_images = []
    c_images = []

    for latent in four_latents:
        generated_image, changed_image = get_images(latent,
                                                    stylex.G,
                                                    classifier,
                                                    style_index,
                                                    style_min[style_index],
                                                    style_max[style_index],
                                                    style_direction,
                                                    shift_size,
                                                    label_size,
                                                    noise,
                                                    cuda_rank)

        g_images.append(generated_image)
        c_images.append(changed_image)

    g_images = torch.cat(g_images)
    c_images = torch.cat(c_images)

    if upscale_res is not None:
        g_images = generate_user_study_img(g_images, upscale_res=upscale_res, nrow=4)
        c_images = generate_user_study_img(c_images, upscale_res=upscale_res, nrow=4)

    return [g_images, c_images]

In [None]:
top_k_range = 6

# Arguments shared by every user_study2 call; only the attribute varies.
study2_settings = dict(
    stylex=stylex,
    classifier=classifier,
    s_indices_and_signs=s_indices_and_signs,
    W_values=W_values,
    style_min=style_min,
    style_max=style_max,
    shift_size=1,
    label_size=2,
    noise=saved_noise,
    cuda_rank=cuda_rank,
    upscale_res=512,
)

# Render the images
images_to_save = []
for top_attribute in range(top_k_range):
    images_to_render = user_study2(top_attribute=top_attribute, **study2_settings)
    images_to_save.append(images_to_render)

In [None]:
from PIL import Image, ImageFont, ImageDraw

# The font is the same for every frame, so load it once instead of per image.
font = ImageFont.truetype("./Roboto-Bold.ttf", size=30)

# Make sure the target directory exists even when the parent folder was
# created earlier without it.
os.makedirs(user_study2_dir, exist_ok=True)

for index, image_to_render in enumerate(images_to_save):
    # Label the two frames in place: first the original, then the changed grid.
    for image, text in zip(image_to_render, ["Before", "After"]):
        image_editable = ImageDraw.Draw(image)
        image_editable.text((15,15), text, (255, 0, 0), font=font)

    file_path = os.path.join(user_study2_dir, f"verbal_study_{index}.gif")
    imageio.mimsave(file_path, image_to_render, fps=1)
    display(IPython_display.Image(filename=file_path))