In [1]:
import pickle
import torch
import numpy as np
import pandas as pd
import os 
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider
from PIL import Image
import json
from tqdm import tqdm

# Root of the data directory, relative to the notebook's working directory.
# The trailing slash matters: other cells append sub-paths directly to it.
DATA_PATH = "../../../Data.nosync/"
# Directory the e4e inversion outputs (latents.pt, file_paths.pkl) are read from.
# DATA_PATH already ends with "/", so no extra separator is inserted here.
save_path = f"{DATA_PATH}Models/e4e/experiments_default_lr/inversions/"

## Setup

In [2]:
# Load the dresses metadata from one JSON file in two shapes:
#   metadata    - the raw parsed JSON object
#   metadata_df - a DataFrame with one row per top-level JSON key
metadata_file = f"{DATA_PATH}Zalando_Germany_Dataset/dresses/metadata/dresses_metadata.json"

# Use a context manager so the file handle is closed deterministically
# instead of being left to the garbage collector.
with open(metadata_file, 'r') as f:
    metadata = json.load(f)

# read_json turns the top-level keys into columns; transposing makes them rows,
# and reset_index/rename exposes those keys as an explicit 'id' column.
# (presumably the keys are article ids/SKUs - TODO confirm against the file)
metadata_df = (
    pd.read_json(metadata_file)
    .T
    .reset_index()
    .rename(columns={'index': 'id'})
)

In [3]:
# Load the inverted latents and the image file paths they belong to.
latents = torch.load(f"{save_path}latents.pt")
# NOTE(review): pickle.load can execute arbitrary code - only load files you trust.
with open(f"{save_path}file_paths.pkl", 'rb') as f:
    file_paths = pickle.load(f)

# latents[i] is paired with file_paths[i] purely by position, so a length
# mismatch would silently attach latents to the wrong SKU (or fail midway).
assert len(latents) == len(file_paths), (
    f"latents ({len(latents)}) and file_paths ({len(file_paths)}) are misaligned"
)


def _sku_from_path(path):
    """Return the SKU encoded in an image path.

    Takes the file name (text after the last '/'), keeps the part after its
    last '_' and cuts it at the first '.', e.g.
    'some/dir/prefix_ABC123.jpg' -> 'ABC123'.
    """
    return path.split('/')[-1].split('_')[-1].split('.')[0]


# Map SKU -> latent. If two paths yield the same SKU, the later one wins.
latents_dict = {_sku_from_path(path): latents[i] for i, path in enumerate(file_paths)}

In [4]:
# Initialize the original custom SG2-Ada generator.
# The checkpoint is unpickled from inside the stylegan2-ada-pytorch checkout -
# presumably because the pickle references modules of that repo which must be
# importable from the working directory (TODO confirm).
os.chdir("../../stylegan2-ada-pytorch/")
try:
    # Prefer CUDA, then Apple MPS, then CPU. Older torch builds have no
    # torch.backends.mps attribute at all, hence the getattr guard.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    print(f"Device: {device}")

    model_path = "../../Data.nosync/Models/Stylegan2_Ada/Experiments/00003-stylegan2_ada_images-mirror-auto2-kimg1000-resumeffhq512/network-snapshot-000920.pkl"
    # NOTE(review): pickle.load can execute arbitrary code - only load trusted checkpoints.
    with open(model_path, 'rb') as f:
        architecture = pickle.load(f)

    G = architecture['G_ema'].to(device)  # torch.nn.Module
    D = architecture['D'].to(device)
finally:
    # Always switch back to the e4e directory, even if loading failed. Otherwise
    # the kernel stays in the StyleGAN checkout, every relative path in later
    # cells breaks, and re-running this cell would chdir relative to the wrong place.
    os.chdir('../2_Inversion/e4e/')

Device: mps


In [5]:
def generate_from_sku(sku):
    """Reconstruct the image for one SKU from its stored latent.

    Args:
        sku: Key into the module-level ``latents_dict``.

    Returns:
        PIL.Image.Image: RGB image synthesised by the generator ``G``.

    Raises:
        KeyError: If ``sku`` has no entry in ``latents_dict``.
    """
    # Take the first entry of the stored latent and give it back a leading
    # axis of size 1 before moving it to the generator's device.
    latent = latents_dict[sku][0].unsqueeze(0).to(device)
    # Reuse the shared tensor -> PIL conversion instead of duplicating it here.
    return generate_from_latent(latent)

def generate_from_latent(latent):
    """Synthesise an RGB image from a latent tensor.

    Args:
        latent: Tensor passed to ``G.synthesis``. It is moved to ``device``
            first, so callers may pass a tensor that lives on another device.

    Returns:
        PIL.Image.Image: The first image of the synthesised batch, as 8-bit RGB.
    """
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # .to(device) is a no-op when the latent is already on the right device.
        img = G.synthesis(latent.to(device))
    # Move axis 1 to the end (channels-last, as PIL expects for 'RGB') and
    # rescale to bytes. Assumes the generator output is roughly in [-1, 1];
    # anything outside [0, 255] after scaling is clamped.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    # Only the first element of the batch is converted and returned.
    return Image.fromarray(img[0].cpu().numpy(), 'RGB')

### Generate a reconstruction for every SKU and save it to disk

In [6]:
# Output folder for the reconstructed images (DATA_PATH already ends with "/").
save_dir = f"{DATA_PATH}Generated_Images/Zalando_Germany_Reconstructions/"
# img.save does not create missing directories, so make sure the folder exists.
os.makedirs(save_dir, exist_ok=True)

# Set to True to regenerate images that are already on disk.
overwrite_existing = False

# The full run takes a long time; skipping finished SKUs lets an interrupted
# run resume where it stopped instead of starting over.
for sku in tqdm(latents_dict, desc="Generating reconstructions"):
    out_path = f"{save_dir}{sku}.jpg"
    if not overwrite_existing and os.path.exists(out_path):
        continue

    img = generate_from_sku(sku)

    # Write to a temporary name and rename afterwards, so an interrupt during
    # saving never leaves a truncated .jpg that a later run would skip.
    # The format is passed explicitly because ".part" is not an image extension.
    tmp_path = f"{out_path}.part"
    img.save(tmp_path, format="JPEG")
    os.replace(tmp_path, out_path)

  4%|▍         | 609/14060 [01:30<33:25,  6.71it/s]


KeyboardInterrupt: 