In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np

In [2]:
import pickle
import platform
import os
# Machine-specific roots for the data and the repository checkout.
_system = platform.system()
if _system == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif _system == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    # Fail here with a clear message instead of a NameError on DATA_PATH
    # several cells later.
    raise RuntimeError(
        f"No DATA_PATH/ROOT_PATH configured for platform {_system!r}; "
        "add a branch for this machine."
    )

current_wd = os.getcwd()

In [3]:
def set_device():
    """Pick the torch device to run on, preferring CUDA, then Apple MPS, then CPU.

    Prints the chosen device as a side effect.

    Returns:
        str: one of ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    # Older torch builds have no ``torch.backends.mps``. Probe for it instead of
    # wrapping everything in a bare ``except:``, which would also swallow
    # KeyboardInterrupt and unrelated errors.
    elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        device = 'mps'
    else:
        device = 'cpu'
    print(f"Using {device} as device")
    return device

device = set_device()

Using cpu as device


  return torch._C._cuda_getDeviceCount() > 0


### Define Paths

In [4]:
# Directory paths end with "/" because PairedDataset builds its glob pattern
# by plain concatenation (f"{path}*.jpg").

# Real reference images, shared by every method.
real_path = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/all/"

# e4e, checkpoint 00003_snapshot_920: generated images and per-image score cache.
generated_path_00003 = f"{DATA_PATH}/Generated_Images/e4e/00003_snapshot_920/"
scores_save_path_00003 = f"{DATA_PATH}/Metrics/L2/e4e_00003_snapshot_920/l2_scores_e4e_00003_snapshot_920.pkl"

# e4e, checkpoint 00005_snapshot_1200.
generated_path_00005 = f"{DATA_PATH}/Generated_Images/e4e/00005_snapshot_1200/"
scores_save_path_00005 = f"{DATA_PATH}/Metrics/L2/e4e_00005_snapshot_1200/l2_scores_e4e_00005_snapshot_1200.pkl"

# PTI.
generated_path_pti = f"{DATA_PATH}/Generated_Images/PTI/"
scores_save_path_pti = f"{DATA_PATH}/Metrics/L2/PTI/l2_scores_pti.pkl"

# ReStyle (outputs taken from the inference_results/4/ subfolder).
generated_path_restyle = f"{DATA_PATH}/Generated_Images/restyle/inference_results/4/"
scores_save_path_restyle = f"{DATA_PATH}/Metrics/L2/Restyle/l2_restyle_scores.pkl"

### Define Datasets and Calculation Functions

In [5]:
class PairedDataset(Dataset):
    """Dataset of (real, generated) image pairs matched by file name.

    Every ``*.jpg`` in ``generated_path`` is paired with the ``*.jpg`` in
    ``real_path`` that has exactly the same file name. Both directory arguments
    are concatenated directly with ``*.jpg``, so they must end with a path
    separator.

    Items are ``(img_real, img_fake, sku)``. ``sku`` is the real image's file
    name up to its first ``.``. The images are RGB PIL images, or whatever
    ``transform`` returns when a transform is given.
    """

    def __init__(self, real_path, generated_path, transform=None):
        self.real_path = real_path
        self.generated_path = generated_path
        self.transform = transform

        # Sorted so the pair order (and hence the cached dict order) is deterministic.
        self.real_images = sorted(glob(f"{real_path}*.jpg"))
        self.generate_images = sorted(glob(f"{generated_path}*.jpg"))

        # Exact file-name lookup. This avoids the O(n^2) scan, and it avoids false
        # matches from substring tests ("A.jpg" is a substring of ".../BA.jpg").
        real_by_name = {os.path.basename(path): path for path in self.real_images}

        self.pairs = []
        for generated in self.generate_images:
            name = os.path.basename(generated)
            real = real_by_name.get(name)
            if real is None:
                raise FileNotFoundError(
                    f"No real image named {name!r} found in {real_path}"
                )
            self.pairs.append([real, generated])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        real, fake = self.pairs[index]
        sku = os.path.basename(real).split('.')[0]
        # convert() returns a fully loaded copy, so the file handles can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(real) as img:
            img_real = img.convert('RGB')
        with Image.open(fake) as img:
            img_fake = img.convert('RGB')
        if self.transform is not None:
            img_real = self.transform(img_real)
            img_fake = self.transform(img_fake)
        return img_real, img_fake, sku

In [6]:
# Shared preprocessing for both images of a pair:
#   1. resize to 256x256;
#   2. convert to a float tensor in [0, 1];
#   3. map each channel to [-1, 1] via (x - 0.5) / 0.5.
# The L2 losses are therefore measured on the [-1, 1] scale, which is 4x the
# MSE on [0, 1] pixels.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [7]:
loss_fn = torch.nn.MSELoss()
def calculate_l2_losses(generated_path, save_path, batch_size=1, num_workers=0):
    """Return the mean per-image L2 (MSE) loss between real and generated images.

    Per-image losses are cached as a ``{sku: loss}`` dict pickled at
    ``save_path``. If that file already exists it is loaded instead of
    recomputed. The function relies on the notebook globals ``real_path``,
    ``transform``, ``device``, ``loss_fn`` and ``PairedDataset``.

    Args:
        generated_path: directory (with trailing separator) of generated ``*.jpg`` images.
        save_path: pickle file used as the per-image score cache.
        batch_size: DataLoader batch size. Losses are still computed one image
            at a time with ``loss_fn``, so the result does not depend on it.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        Mean of the per-image losses, as returned by ``np.mean``.

    Raises:
        ValueError: if no generated images are found, rather than caching an
            empty result that would later load as NaN.
    """
    if os.path.exists(save_path):
        print(f'L2-Losses already calculated for {generated_path}')
        print('Loading pre-calculated losses')
        # pickle can execute code on load; this is only safe because the cache
        # is written by this function.
        with open(save_path, 'rb') as f:
            l2_losses = pickle.load(f)
    else:
        print(f'Calculating all L2-Losses for {save_path}')
        # Create the cache folder first, so a missing directory fails immediately
        # rather than after minutes of computation.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Create Dataset and Loader
        print("\tConstructing Dataset")
        ds = PairedDataset(real_path, generated_path, transform)
        if len(ds) == 0:
            raise ValueError(f"No generated images found in {generated_path}")
        loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        # Calculate Losses (no autograd graph is needed for evaluation)
        l2_losses = {}
        with torch.no_grad():
            for real, fake, skus in tqdm(loader, desc='\tCalculating Losses'):
                real, fake = real.to(device), fake.to(device)
                for sku, img_real, img_fake in zip(skus, real, fake):
                    l2_losses[sku] = float(loss_fn(img_real, img_fake))

        # Write to a temp file, then rename. An interrupted dump can then never
        # leave a truncated cache that later runs would treat as complete.
        tmp_path = f"{save_path}.tmp"
        with open(tmp_path, 'wb') as handle:
            pickle.dump(l2_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, save_path)

    return np.mean(list(l2_losses.values()))
        

### Calculate all L2 Losses

In [8]:
# Mean L2 loss per inversion method. Each call reuses its per-image cache when
# one already exists. Rows: (result key, generated-images dir, per-image score cache).
mean_l2_losses = {}
for method, generated_dir, cache_file in (
    ('e4e_00003', generated_path_00003, scores_save_path_00003),  # e4e 00003_snapshot_920
    ('e4e_00005', generated_path_00005, scores_save_path_00005),  # e4e 00005_snapshot_1200
    ('PTI', generated_path_pti, scores_save_path_pti),
    ('Restyle', generated_path_restyle, scores_save_path_restyle),
):
    mean_l2_losses[method] = calculate_l2_losses(generated_dir, cache_file)
# TODO: add Hyperstyle once its generated images are available.

Calculating all L2-Losses for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Metrics/L2/e4e_00003_snapshot_920/l2_scores_e4e_00003_snapshot_920.pkl
	Constructing Dataset


	Calculating Losses: 100%|██████████| 14060/14060 [06:34<00:00, 35.64it/s]


Calculating all L2-Losses for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Metrics/L2/e4e_00005_snapshot_1200/l2_scores_e4e_00005_snapshot_1200.pkl
	Constructing Dataset


	Calculating Losses: 100%|██████████| 14060/14060 [04:55<00:00, 47.54it/s]


Calculating all L2-Losses for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Metrics/L2/PTI/l2_scores_pti.pkl
	Constructing Dataset


	Calculating Losses: 100%|██████████| 500/500 [00:06<00:00, 81.57it/s]


Calculating all L2-Losses for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Metrics/L2/Restyle/l2_restyle_scores.pkl
	Constructing Dataset


	Calculating Losses: 100%|██████████| 14060/14060 [04:08<00:00, 56.67it/s]


In [10]:
# Persist the per-method mean L2 losses next to the per-image score caches,
# then display them.
with open(f"{DATA_PATH}/Metrics/L2/L2_Results.pkl", mode='wb') as results_file:
    pickle.dump(mean_l2_losses, results_file, pickle.HIGHEST_PROTOCOL)
mean_l2_losses

{'e4e_00003': 0.03234630992132584,
 'e4e_00005': 0.020278500808366007,
 'PTI': 0.008501449547009543,
 'Restyle': 0.014734626814979194}