In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from glob import glob
import os
import pickle
from PIL import Image

In [None]:
import os
import platform

# Per-OS defaults for the data/model store (DATA_PATH) and the code checkout (ROOT_PATH).
_DEFAULT_PATHS = {
    'Darwin': (
        "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync",
        "/Users/maltegenschow/Documents/Uni/Thesis/Thesis",
    ),
    'Linux': (
        "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync",
        "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis",
    ),
}
_default_data_path, _default_root_path = _DEFAULT_PATHS.get(platform.system(), (None, None))

# Environment variables override the per-OS defaults (e.g. when running on a new machine).
DATA_PATH = os.environ.get("THESIS_DATA_PATH", _default_data_path)
ROOT_PATH = os.environ.get("THESIS_ROOT_PATH", _default_root_path)
if DATA_PATH is None or ROOT_PATH is None:
    # Fail here with a clear message instead of a NameError several cells later.
    raise RuntimeError(
        f"No data/root paths configured for platform {platform.system()!r}; "
        "set THESIS_DATA_PATH and THESIS_ROOT_PATH."
    )

# Remember the starting directory; later cells os.chdir() into the model repositories.
current_wd = os.getcwd()

In [None]:
def select_device():
    """Return the best available torch device name: 'cuda', then 'mps', else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    # Older torch builds have no torch.backends.mps; probe for it instead of
    # catching every exception with a bare `except:`.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return 'mps'
    return 'cpu'


device = select_device()
print(f"Using {device} as device")

## 0. Paths Setup

In [None]:
# Trained e4e encoder checkpoint.
e4e_model_path = f"{DATA_PATH}/Models/e4e/00005_snapshot_1200/setup/checkpoints/best_model.pt"
# All train + test images to invert (train first). Each split is sorted because glob
# returns files in arbitrary, filesystem-dependent order; sorting makes the row order
# of the saved latents tensor (and the matching file_paths.pkl) reproducible.
e4e_input_images_paths = [
    path
    for subdir in ['train', 'test']
    for path in sorted(glob(f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/{subdir}/*.jpg"))
]
# Output directories; both end in '/' because later cells append file names directly.
latents_save_dir = f"{DATA_PATH}/Models/e4e/00005_snapshot_1200/inversions/"
reconstructions_save_dir = f"{DATA_PATH}/Generated_Images/e4e/00005_snapshot_1200/"
# StyleGAN2-ADA network snapshot (kimg 1200) used to render the reconstructions.
sg2_generator_model_path = f"{DATA_PATH}/Models/Stylegan2_Ada/Experiments/00005-stylegan2_ada_images-mirror-auto2-kimg5000-resumeffhq512/network-snapshot-001200.pkl"

## 1. Get All Latent Codes

In [None]:
# Get into e4e dir to import inference code.
# NOTE: os.chdir is process-wide and persists for every later cell; the import below
# presumably resolves `scripts.inference` relative to this directory (cwd on sys.path) — confirm.
os.chdir(f'{ROOT_PATH}/encoder4editing/')

# Star import: setup_model, data_configs and get_latents (used below and in
# encode_from_path) are assumed to come from this module.
# TODO(review): replace with explicit imports once the provided names are confirmed.
from scripts.inference import *

# Setup Model: returns the e4e network and its options; `device` is passed through
# (presumably the network is placed on it — confirm in setup_model).
net, opts = setup_model(e4e_model_path, device)
# Decoder of the e4e network, switched to eval mode.
# NOTE(review): `generator` is not used in the later cells; G is loaded separately below.
generator = net.decoder
generator.eval()

# Get transforms: the dataset config selected by the checkpoint's dataset_type
# provides the preprocessing pipelines; 'transform_test' is the one applied to inputs.
dataset_args = data_configs.DATASETS[opts.dataset_type]
transforms_dict = dataset_args['transforms'](opts).get_transforms()

# Define inversion function
@torch.no_grad()
def encode_from_path(image_path):
    """Invert a single image into the e4e latent space.

    Args:
        image_path: path to an image file readable by PIL; it is converted to RGB.

    Returns:
        The tensor returned by ``get_latents`` for a batch of one image (the caller
        stores it into a (1, 16, 512) slot), still on ``device``.
    """
    # Context manager closes the file handle promptly, even when inverting thousands of files.
    with Image.open(image_path) as img_orig:
        img = transforms_dict['transform_test'](img_orig.convert('RGB'))
    # Add a batch dimension of one. unsqueeze avoids the hard-coded 256x256 reshape,
    # which could silently scramble an input of a different size.
    img = img.unsqueeze(0).to(device).float()
    # no_grad (decorator): inference only, so no autograd graph is built per image.
    return get_latents(net, img)

In [None]:
print(f"Found {len(e4e_input_images_paths)} images")

# Row i of all_latents holds the e4e latent of all_paths[i]; each slot is (1, 16, 512).
all_latents = torch.zeros(len(e4e_input_images_paths), 1, 16, 512)
all_paths = []
# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for i, path in enumerate(tqdm(e4e_input_images_paths)):
        latent = encode_from_path(path).cpu().detach()
        all_latents[i] = latent
        all_paths.append(path)

In [None]:
# torch.save / open() do not create missing parent directories, so create the
# output directory first (no-op if it already exists).
os.makedirs(latents_save_dir, exist_ok=True)
# Save latents tensor; row i corresponds to all_paths[i].
torch.save(all_latents, f"{latents_save_dir}latents.pt")
# Save file paths in the same order as the latent rows.
with open(f"{latents_save_dir}file_paths.pkl", 'wb') as handle:
    pickle.dump(all_paths, handle, protocol=pickle.HIGHEST_PROTOCOL)

## 2. Generate All Reconstructions


In [None]:
# Load the saved latents tensor and the image paths stored alongside it (same row order).
# pickle is only used on a file this notebook wrote itself.
latents = torch.load(f"{latents_save_dir}latents.pt")
with open(f"{latents_save_dir}file_paths.pkl", 'rb') as paths_file:
    file_paths = pickle.load(paths_file)

# Map SKU -> latent row. The SKU is the part of the file name after its last '_'
# and before the first '.'; a repeated SKU keeps the latent of its last occurrence.
latents_dict = {
    file.split('/')[-1].split('_')[-1].split('.')[0]: latents[idx]
    for idx, file in enumerate(file_paths)
}

In [None]:
# Initialize the original custom StyleGAN2-ADA generator (and discriminator) from the
# network snapshot. Unpickling presumably needs the stylegan2-ada-pytorch sources
# importable, hence the chdir into that repository first.
os.chdir(f"{ROOT_PATH}/stylegan2-ada-pytorch/")
with open(sg2_generator_model_path, 'rb') as snapshot_file:
    architecture = pickle.load(snapshot_file)
G = architecture['G_ema'].to(device)  # torch.nn.Module
# NOTE(review): D is moved to the device but not used in the later cells.
D = architecture['D'].to(device)

# NOTE: the working directory stays in stylegan2-ada-pytorch/ for the remaining cells;
# call os.chdir(current_wd) to return to the starting directory if needed.

In [None]:
def _tensor_to_pil(img):
    """Convert the first image of a generator output batch to an RGB PIL image.

    Applies x * 127.5 + 128 and clamps to [0, 255], i.e. assumes the generator output
    lies roughly in [-1, 1] (NOTE(review): confirm for this G). Expects NCHW layout
    with 3 channels, as implied by the permute and the 'RGB' mode.
    """
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return Image.fromarray(img[0].cpu().numpy(), 'RGB')


@torch.no_grad()
def generate_from_latent(latent, noise_mode='const'):
    """Synthesize an image from a latent with the synthesis network of G.

    Args:
        latent: tensor passed directly to ``G.synthesis`` (must already be on
            ``device``); for e4e inversions this is a (1, 16, 512) batch.
        noise_mode: forwarded to ``G.synthesis``. Defaults to 'const' so that
            reconstructions are deterministic across calls; pass 'random' to
            sample fresh per-layer noise on every call.

    Returns:
        PIL.Image (RGB) of the first image in the batch.
    """
    return _tensor_to_pil(G.synthesis(latent, noise_mode=noise_mode))


def generate_from_sku(sku, noise_mode='const'):
    """Reconstruct the image for ``sku`` from its stored e4e latent.

    Args:
        sku: key into ``latents_dict``; raises KeyError if absent.
        noise_mode: forwarded to ``generate_from_latent``.

    Returns:
        PIL.Image (RGB) reconstruction.
    """
    # latents_dict values are (1, 16, 512): [0] drops the leading dim and
    # unsqueeze(0) re-adds it as a batch of one.
    latent = latents_dict[sku][0].unsqueeze(0).to(device)
    return generate_from_latent(latent, noise_mode=noise_mode)

In [None]:
# Sanity check: run one random z through the full generator (mapping + synthesis).
# A local, seeded CPU generator makes the check reproducible without touching the
# global RNG state; G.z_dim replaces the hard-coded latent size.
with torch.no_grad():
    z = torch.randn([1, G.z_dim], generator=torch.Generator().manual_seed(0)).to(device)
    img = G(z, None, noise_mode='const')

In [None]:
from IPython.utils import io

# Generate all Images and Save. Create the output directory first, because
# PIL's Image.save does not create missing parent directories.
os.makedirs(reconstructions_save_dir, exist_ok=True)
for sku in tqdm(latents_dict.keys()):
    # Silence output produced during synthesis (presumably plugin/log messages);
    # the captured text is not needed.
    with io.capture_output():
        img = generate_from_sku(sku)
    img.save(f"{reconstructions_save_dir}{sku}.jpg")