In [1]:
import platform
import os
if platform.system() == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif platform.system() == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"

current_wd = os.getcwd()

In [2]:
import pandas as pd
import numpy as np
import torch
import pickle
from PIL import Image
import matplotlib.pyplot as plt
from glob import glob
import warnings
warnings.filterwarnings('ignore')
from matplotlib import rc
rc('text', usetex=True)

In [3]:
meta = pd.read_csv('disentangled_typicality_scores.csv')

### Helper Functions

In [None]:
def tensor2im(var):
    var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
    var = ((var + 1) / 2)
    var[var < 0] = 0
    var[var > 1] = 1
    var = var * 255
    return Image.fromarray(var.astype('uint8'))

In [None]:
def setup_generator():
    os.chdir(f"{ROOT_PATH}/stylegan2-ada-pytorch")
    # Load model architecture
    experiment_path = f"{DATA_PATH}/Models/Stylegan2_Ada/Experiments/00005-stylegan2_ada_images-mirror-auto2-kimg5000-resumeffhq512/"
    model_name = "network-snapshot-001200.pkl"
    model_path = experiment_path + model_name
    with open(model_path, 'rb') as f:
        architecture = pickle.load(f)
        G = architecture['G_ema']
        D = architecture['D']
    os.chdir(current_wd)
    return G

G = setup_generator()

In [None]:
def generate_from_latent(latent):
    img = G.synthesis(latent, force_fp32=True, noise_mode = 'const')
    img = tensor2im(img.squeeze(0))
    return img

In [None]:
os.chdir(f"{ROOT_PATH}/2_Inversion/PTI/")
from pti_utils import load_pti
os.chdir(current_wd)

def generate_pti(latent, G_PTI):
    gen = G_PTI.synthesis(latent, noise_mode='const', force_fp32=True)
    img = tensor2im(gen.squeeze(0))
    return img

Using cpu as device


### Typicality Helper Functions

### InterFaceGAN Helper Functions

In [None]:
def get_interpolations(latent_code, start_distance, end_distance, steps, boundaries_base_dir):
    linspace = np.linspace(start_distance, end_distance, steps)

    # Repeat Latent for num_steps
    latent_code = latent_code.repeat(steps, 1, 1)

    boundary = np.load(f"{boundaries_base_dir}boundary_dim.npy")
    boundary = torch.tensor(boundary)
    for i in range(steps):
        latent_code[i, :, :] = latent_code[i, :, :] + linspace[i] * boundary
    
    return latent_code

In [7]:
def run_example(n, embedding_type = 'disentangled_embeddings_concat', direction = 'more', steps=5, distance=15, sku = None, generator = 'SG2'):

    boundaries_base_dir = f"{DATA_PATH}/Models/InterfaceGAN/Outputs/disentangled_typicality/{embedding_type}/{n}/"

    # Load in training stats
    training_stats = pd.read_csv(f"{boundaries_base_dir}summary_stats.csv")

    # Load in original Latents
    latents = torch.load(f"{DATA_PATH}/Models/e4e/00005_snapshot_1200/inversions/latents_dict.pt")

    # Filter to include only SKU for which PTI exists
    pti_skus = list(glob('/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Models/PTI/experiments/embeddings/zalando_germany/PTI/*'))
    pti_skus = [elem.split('/')[-1] for elem in pti_skus]
    df = meta[meta.sku.isin(pti_skus)]
    global chosen_sku

    if generator == 'PTI':
        # Sample one latent
        sample = df.sample(1)
        if sku:
            sample = df[df.sku == sku]
        else: 
            sample = df.sample(1)
            chosen_sku = sample.sku.item()

        device = torch.device('cpu')
        G_PTI, latent = load_pti(chosen_sku)

        # Freeze and set to eval
        G_PTI.eval()
        for param in G_PTI.parameters():
            param.requires_grad = False

        # Send to CPU
        G_PTI = G_PTI.to(device)
        latent_code = latent.to(device)
        latent_code =latent_code.squeeze(0).flatten()

    elif generator == 'SG2':
        if sku:
            sample = df[df.sku == sku]
        else: 
            sample = df.sample(1)
            chosen_sku = sample.sku.item()
        latent_code = latents[sample.sku.item()].squeeze(0).flatten()

    # Generate Interpolations
    if direction == 'more': 
        start_distance, end_distance = 0, distance
    elif direction == 'less': 
        start_distance, end_distance = 0,  -distance

    interpolations = get_interpolations(latent_code, start_distance, end_distance, steps, boundaries_base_dir)


    global imgs
    if generator == 'SG2':
        imgs = [generate_from_latent(interpolations[i, :, :].squeeze(0).reshape(1,16,512)) for i in range(steps)]
    elif generator == 'PTI':
        imgs = [generate_pti(interpolations[i, :, :].squeeze(0).reshape(1,16,512), G_PTI) for i in range(steps)]



    # Add real image for comparison of first inversion
    real = Image.open(f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/square_images/{chosen_sku}.jpg")

    imgs = [real] + imgs

    # Calculate typicality scores of generated images
    #scores = [calculate_typicality(img) for img in imgs]
    scores = [torch.tensor(0)] + [torch.tensor(0) for _ in range(steps)]

    fig, ax = plt.subplots(1, steps+1, figsize=(20, 5))
    ax = ax.ravel()
    for i in range(steps+1):
        ax[i].imshow(imgs[i])
        ax[i].axis('off')
            
        if i == 0: 
            ax[i].set_title(f'Original')
        elif i == 1: 
            ax[i].set_title(f'Inversion\nTypicality: {np.round(scores[i].item(), 2)}')
        else: 
            step = np.linspace(start_distance, end_distance, steps)[i-1]
            ax[i].set_title(f"{'+' if step > 0 else ''}{step}\nTypicality: {np.round(scores[i].item(), 2)}")

    
    fig.suptitle(f"SKU: {sample.sku.item()}\nDirection: {direction} typical\nGenerator type: " + r"\textbf{" + generator + "}")
    plt.show()

In [8]:
run_example(1000)

ValueError: a must be greater than 0 unless no samples are taken