In [1]:
from glob import glob
from tqdm import tqdm
import pickle
import numpy as np

### Define paths

In [2]:
import platform
import os

# Pick data / repo roots for the machine we run on: local macOS laptop or the Linux cluster.
system_name = platform.system()
if system_name == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif system_name == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    # Fail loudly here instead of with a NameError on DATA_PATH in a later cell.
    raise RuntimeError(
        f"Unsupported platform {system_name!r}: add DATA_PATH/ROOT_PATH for this machine."
    )

current_wd = os.getcwd()

In [3]:
# Original (real) images that every reconstruction is compared against.
# Directory paths end in "/" because later cells append file names / glob patterns directly.
real_path = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/all/"

# e4e, checkpoint 00003_snapshot_920: reconstructions and their LPIPS cache file.
generated_path_00003 = f"{DATA_PATH}/Generated_Images/e4e/00003_snapshot_920/"
scores_save_path_00003 = f"{DATA_PATH}/Metrics/LPIPS/00003_snapshot_920/lpips_scores_00003_snapshot_920.pkl"

# e4e, checkpoint 00005_snapshot_1200: reconstructions and their LPIPS cache file.
generated_path_00005 = f"{DATA_PATH}/Generated_Images/e4e/00005_snapshot_1200/"
scores_save_path_00005 = f"{DATA_PATH}/Metrics/LPIPS/00005_snapshot_1200/lpips_scores_00005_snapshot_1200.pkl"

# PTI: reconstructions and their LPIPS cache file.
generated_path_pti = f"{DATA_PATH}/Generated_Images/PTI/"
scores_save_path_pti = f"{DATA_PATH}/Metrics/LPIPS/PTI/lpips_scores_pti.pkl"

# ReStyle: results from the `4` subfolder of inference_results (presumably the last
# refinement step — TODO confirm). NOTE: the cache file name differs in pattern from the
# others; it is kept as-is so existing caches are still found.
generated_path_restyle = f"{DATA_PATH}/Generated_Images/restyle/inference_results/4/"
scores_save_path_restyle = f"{DATA_PATH}/Metrics/LPIPS/Restyle/lpips_restyle_scores.pkl"

### Set up LPIPS

In [4]:
import lpips
import matplotlib.pyplot as plt
# LPIPS with the AlexNet trunk (noted as giving the best forward scores), placed on the GPU
# once so every later comparison runs there.
loss_fn_alex = lpips.LPIPS(net='alex').to('cuda')

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /pfs/work7/workspace/scratch/tu_zxmav84-thesis/miniconda3/envs/thesis/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth


In [5]:
def get_lpips(sku, generated_path, real_path):
    """Return the LPIPS score between a real image and its reconstruction.

    Both images are loaded by the same file name ``sku`` from their respective
    directories and compared with the module-level model ``loss_fn_alex``.

    Args:
        sku: Image file name present in both ``real_path`` and ``generated_path``.
        generated_path: Directory containing the generated/reconstructed images.
        real_path: Directory containing the original images.

    Returns:
        The LPIPS score as a Python float.
    """
    import torch  # lpips is built on torch, so it is always installed alongside it

    # Follow the model's device instead of hard-coding CUDA.
    device = next(loss_fn_alex.parameters()).device
    # os.path.join works whether or not the directory strings end in "/".
    real = lpips.im2tensor(lpips.load_image(os.path.join(real_path, sku))).to(device)
    fake = lpips.im2tensor(lpips.load_image(os.path.join(generated_path, sku))).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        score = loss_fn_alex(real, fake)
    return score.item()

In [6]:
# File names (SKUs) of all real images; each model's reconstruction is looked up by the same name.
# Sorted because glob returns files in arbitrary order, which would make the scoring order
# (and the cached dict order) vary between runs.
skus = sorted(path.split('/')[-1] for path in glob(f"{real_path}*.jpg"))
# An empty list would silently cache empty score dicts and yield a NaN mean later on.
assert skus, f"No .jpg images found in {real_path}"

### Calculate LPIPS Scores for e4e from 00003_snapshot_920

In [7]:
# Per-image scores are cached as a {sku: lpips_score} pickle; recompute only when no cache exists.
# NOTE: the cache is not invalidated if `skus` changes — delete the file to force a recompute.
if os.path.exists(scores_save_path_00003):
    # Only ever load caches written by this notebook (pickle can execute arbitrary code).
    with open(scores_save_path_00003, 'rb') as f:
        lpips_scores_00003 = pickle.load(f)
else:
    lpips_scores_00003 = {
        sku: get_lpips(sku, generated_path_00003, real_path)
        for sku in tqdm(skus, desc="LPIPS e4e 00003")
    }
    # Create the output folder first; otherwise the dump fails only after the slow loop.
    os.makedirs(os.path.dirname(scores_save_path_00003), exist_ok=True)
    with open(scores_save_path_00003, 'wb') as handle:
        pickle.dump(lpips_scores_00003, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Calculate LPIPS Scores for e4e from 00005_snapshot_1200

In [8]:
# Per-image scores are cached as a {sku: lpips_score} pickle; recompute only when no cache exists.
# NOTE: the cache is not invalidated if `skus` changes — delete the file to force a recompute.
if os.path.exists(scores_save_path_00005):
    # Only ever load caches written by this notebook (pickle can execute arbitrary code).
    with open(scores_save_path_00005, 'rb') as f:
        lpips_scores_00005 = pickle.load(f)
else:
    lpips_scores_00005 = {
        sku: get_lpips(sku, generated_path_00005, real_path)
        for sku in tqdm(skus, desc="LPIPS e4e 00005")
    }
    # Create the output folder first; otherwise the dump fails only after the slow loop.
    os.makedirs(os.path.dirname(scores_save_path_00005), exist_ok=True)
    with open(scores_save_path_00005, 'wb') as handle:
        pickle.dump(lpips_scores_00005, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Calculate LPIPS Scores for Restyle


In [9]:
# Per-image scores are cached as a {sku: lpips_score} pickle; recompute only when no cache exists.
# NOTE: the cache is not invalidated if `skus` changes — delete the file to force a recompute.
if os.path.exists(scores_save_path_restyle):
    # Only ever load caches written by this notebook (pickle can execute arbitrary code).
    with open(scores_save_path_restyle, 'rb') as f:
        lpips_scores_restyle = pickle.load(f)
else:
    lpips_scores_restyle = {
        sku: get_lpips(sku, generated_path_restyle, real_path)
        for sku in tqdm(skus, desc="LPIPS ReStyle")
    }
    # Create the output folder first; otherwise the dump fails only after the slow loop.
    os.makedirs(os.path.dirname(scores_save_path_restyle), exist_ok=True)
    with open(scores_save_path_restyle, 'wb') as handle:
        pickle.dump(lpips_scores_restyle, handle, protocol=pickle.HIGHEST_PROTOCOL)

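### Calculate LPIPS Scores for PTI

PTI outputs exist only for the images in its own output folder (500 at the time of this run), so the SKU list is rebuilt from that folder before scoring.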

In [10]:
# PTI outputs exist only for the images in its own output folder, so rebuild the SKU list from there.
# NOTE: this rebinds `skus`, replacing the full real-image list from above. Re-running the
# e4e/ReStyle cells after this one (without a cache) would score only these SKUs.
# Sorted for a reproducible scoring order (glob order is arbitrary).
skus = sorted(path.split('/')[-1] for path in glob(f"{generated_path_pti}*.jpg"))
assert skus, f"No .jpg images found in {generated_path_pti}"
len(skus)

500

In [11]:
# Per-image scores are cached as a {sku: lpips_score} pickle; recompute only when no cache exists.
# `skus` here is the PTI file list built in the previous cell.
# NOTE: the cache is not invalidated if `skus` changes — delete the file to force a recompute.
if os.path.exists(scores_save_path_pti):
    # Only ever load caches written by this notebook (pickle can execute arbitrary code).
    with open(scores_save_path_pti, 'rb') as f:
        lpips_scores_pti = pickle.load(f)
else:
    lpips_scores_pti = {
        sku: get_lpips(sku, generated_path_pti, real_path)
        for sku in tqdm(skus, desc="LPIPS PTI")
    }
    # Create the output folder first; otherwise the dump fails only after the slow loop.
    os.makedirs(os.path.dirname(scores_save_path_pti), exist_ok=True)
    with open(scores_save_path_pti, 'wb') as handle:
        pickle.dump(lpips_scores_pti, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Calculate Mean LPIPS Score per Model

In [12]:
# Mean per-image LPIPS score for each model.
# NOTE(review): PTI is averaged over its own (smaller) SKU list, while the other models
# are averaged over the full `skus` list of real images. These means are therefore not
# computed over the same image set; compare with care.
per_model_scores = {
    'e4e_00003': lpips_scores_00003,
    'e4e_00005': lpips_scores_00005,
    'PTI': lpips_scores_pti,
    'Restyle': lpips_scores_restyle,
}
lpips_results = {
    model_name: np.mean(list(scores.values()))
    for model_name, scores in per_model_scores.items()
}
lpips_results

{'e4e_00003': 0.14481622185478593,
 'e4e_00005': 0.11404780544934114,
 'PTI': 0.050234128189273176,
 'Restyle': 0.1137585677018397}

In [13]:
# Persist the per-model mean scores for use outside this notebook.
lpips_results_file = f"{DATA_PATH}/Metrics/LPIPS/LPIPS_Results.pkl"
with open(lpips_results_file, 'wb') as results_handle:
    pickle.dump(lpips_results, results_handle, pickle.HIGHEST_PROTOCOL)