In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from glob import glob
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np

In [2]:
import pickle
import platform
import os
# Machine-specific locations of the dataset and the repository: a local macOS
# checkout versus the Linux (scratch-workspace) setup.
_system = platform.system()
if _system == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif _system == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    # Fail here with a clear message instead of a NameError on DATA_PATH
    # several cells later.
    raise RuntimeError(
        f"No DATA_PATH/ROOT_PATH configured for platform {_system!r}; "
        "add an entry for this machine."
    )

current_wd = os.getcwd()

In [3]:
def set_device():
    """Pick the best available torch device and report it.

    Preference order is CUDA, then Apple MPS, then CPU. Older torch builds have
    no ``torch.backends.mps`` attribute; that case is treated as "MPS not
    available" instead of being caught with a blanket ``except``.

    Returns:
        One of ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        # getattr avoids the AttributeError on torch versions without MPS support.
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            device = 'mps'
        else:
            device = 'cpu'
    print(f"Using {device} as device")
    return device

device = set_device()

Using cuda as device


### Define Paths

In [4]:
# Real (ground-truth) images that every inversion method tries to reconstruct.
real_path = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/all/"

# Folders holding the reconstructions produced by each inversion method.
# NOTE(review): the trailing "4/" for ReStyle presumably selects one inference
# step/iteration of its output - confirm.
_generated_root = f"{DATA_PATH}/Generated_Images"
generated_path_00003 = f"{_generated_root}/e4e/00003_snapshot_920/"
generated_path_00005 = f"{_generated_root}/e4e/00005_snapshot_1200/"
generated_path_pti = f"{_generated_root}/PTI/"
generated_path_restyle = f"{_generated_root}/restyle/inference_results/4/"

# Pickle caches holding the per-image MS-SSIM scores of each method.
_metrics_root = f"{DATA_PATH}/Metrics/MSSSIM"
scores_save_path_00003 = f"{_metrics_root}/e4e_00003_snapshot_920/msssim_scores_e4e_00003_snapshot_920.pkl"
scores_save_path_00005 = f"{_metrics_root}/e4e_00005_snapshot_1200/msssim_scores_e4e_00005_snapshot_1200.pkl"
scores_save_path_pti = f"{_metrics_root}/PTI/msssim_scores_pti.pkl"
scores_save_path_restyle = f"{_metrics_root}/Restyle/msssim_restyle_scores.pkl"

del _generated_root, _metrics_root

### Define Datasets and Calculation Functions

In [5]:
class PairedDataset(Dataset):
    """Dataset pairing each generated image with its real counterpart.

    Each generated ``*.jpg`` in ``generated_path`` is paired with the real image
    in ``real_path`` that has exactly the same file name. If no such file
    exists, it falls back to the first real image whose path contains that file
    name (the earlier substring-matching behaviour). Items are
    ``(real_image, generated_image, sku)``, where ``sku`` is the real image's
    file name without its extension.

    Args:
        real_path: Directory containing the real ``*.jpg`` images.
        generated_path: Directory containing the generated ``*.jpg`` images.
        transform: Callable applied to both PIL images (e.g. a torchvision
            ``Compose``); if falsy, the PIL images are returned unchanged.

    Raises:
        FileNotFoundError: If a generated image has no matching real image.
    """

    def __init__(self, real_path, generated_path, transform):
        self.real_path = real_path
        self.generated_path = generated_path
        self.transform = transform

        # os.path.join works whether or not the directories end in a separator.
        self.real_images = sorted(glob(os.path.join(real_path, "*.jpg")))
        # Attribute name kept as-is (sic) for backward compatibility.
        self.generate_images = sorted(glob(os.path.join(generated_path, "*.jpg")))

        # Exact file-name lookup: O(1) per image and immune to one SKU being a
        # substring of another (e.g. "A1.jpg" inside "XA1.jpg").
        real_by_name = {os.path.basename(p): p for p in self.real_images}

        self.pairs = []
        # Iterate over generated *.jpg files only, so stray directory entries
        # (".DS_Store", sub-folders, ...) are ignored instead of raising IndexError.
        for generated in self.generate_images:
            name = os.path.basename(generated)
            real = real_by_name.get(name)
            if real is None:
                # Fallback: previous substring match against the full real path.
                candidates = [p for p in self.real_images if name in p]
                if not candidates:
                    raise FileNotFoundError(
                        f"No real image matching {name!r} found in {real_path!r}"
                    )
                real = candidates[0]
            self.pairs.append([real, generated])

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        real, fake = self.pairs[index]
        # splitext keeps dots inside the SKU; split('.')[0] would truncate them
        # and could make two SKUs collide as dict keys downstream.
        sku = os.path.splitext(os.path.basename(real))[0]
        # Context managers close the file handles; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(real) as src:
            img_real = src.convert('RGB')
        with Image.open(fake) as src:
            img_fake = src.convert('RGB')
        if self.transform:
            img_real = self.transform(img_real)
            img_fake = self.transform(img_fake)
        return img_real, img_fake, sku

In [6]:
# Preprocessing shared by real and generated images: resize to 256x256,
# convert to a float tensor (ToTensor scales pixel values to [0, 1]), then map
# each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): confirm the MSSSIM implementation expects inputs in [-1, 1].
_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)
del _preprocessing_steps

In [7]:
# MS-SSIM implementation from the `ms_ssim` module (presumably project-local; not shown here).
from ms_ssim import MSSSIM
# One shared instance, called as loss_fn(real, fake) inside calculate_msssim_losses.
# NOTE(review): despite the "loss" naming, the mean values printed at the end of
# this notebook (~0.83-0.95) suggest MSSSIM returns a similarity (higher = closer
# match), not a loss - confirm against ms_ssim.
loss_fn = MSSSIM()
def calculate_msssim_losses(generated_path, save_path):
    """Return the mean MS-SSIM score of one method, computing and caching it if needed.

    On the first call for ``save_path``, every generated image in
    ``generated_path`` is paired with its real counterpart in the notebook-level
    ``real_path`` (via ``PairedDataset`` and ``transform``). Each pair is scored
    with ``loss_fn`` on ``device``, and the ``{sku: score}`` dict is pickled to
    ``save_path``. Later calls load that pickle instead of recomputing.

    Args:
        generated_path: Directory of generated ``*.jpg`` images to evaluate.
        save_path: Pickle file used as the cache of per-image scores; missing
            parent directories are created.

    Returns:
        The NaN-ignoring mean of the per-image scores, or ``nan`` when there
        are no scores at all.
    """
    if not os.path.exists(save_path):
        print(f'Calculating all MSSSIM-Losses for {save_path}')
        # Create Dataset and Loader
        print("\tConstructing Dataset")
        ds = PairedDataset(real_path, generated_path, transform)
        loader = DataLoader(ds, batch_size=1, shuffle=False)
        # Calculate Losses; evaluation only, so skip autograd bookkeeping.
        msssim_losses = {}
        with torch.no_grad():
            for real, fake, sku in tqdm(loader, desc='\tCalculating Losses'):
                # batch_size=1, so the collated sku batch has exactly one entry.
                msssim_losses[sku[0]] = float(loss_fn(real.to(device), fake.to(device)))

        # Save Losses; the cache directory may not exist yet on a fresh machine.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, 'wb') as handle:
            pickle.dump(msssim_losses, handle, protocol=pickle.HIGHEST_PROTOCOL)

    else:
        print(f'MSSSIM-Losses already calculated for {generated_path}')
        print('Loading pre-calculated losses')
        # NOTE: pickle.load can execute arbitrary code; only load caches this
        # notebook wrote itself.
        with open(save_path, 'rb') as f:
            msssim_losses = pickle.load(f)

    scores = list(msssim_losses.values())
    if not scores:
        # np.nanmean([]) also yields nan, but emits a RuntimeWarning.
        return np.nan
    return np.nanmean(scores)
        

### Calculate all MS-SSIM Losses

In [8]:
# One row per inversion method: (result key, generated-image folder, score cache).
_msssim_runs = [
    ('e4e_00003', generated_path_00003, scores_save_path_00003),  # e4e 00003_snapshot_920
    ('e4e_00005', generated_path_00005, scores_save_path_00005),  # e4e 00005_snapshot_1200
    ('PTI', generated_path_pti, scores_save_path_pti),
    ('Restyle', generated_path_restyle, scores_save_path_restyle),
    # Hyperstyle: no entry yet (placeholder).
]

mean_msssim_losses = {}
for _method, _generated_dir, _cache_file in _msssim_runs:
    mean_msssim_losses[_method] = calculate_msssim_losses(_generated_dir, _cache_file)
del _msssim_runs, _method, _generated_dir, _cache_file
# Hyperstyle

MSSSIM-Losses already calculated for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Generated_Images/e4e/00003_snapshot_920/
Loading pre-calculated losses
MSSSIM-Losses already calculated for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Generated_Images/e4e/00005_snapshot_1200/
Loading pre-calculated losses
MSSSIM-Losses already calculated for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Generated_Images/PTI/
Loading pre-calculated losses
MSSSIM-Losses already calculated for /pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync/Generated_Images/restyle/inference_results/4/
Loading pre-calculated losses


In [9]:
# Persist the per-method mean scores next to the per-image caches, then display them.
_results_file = f"{DATA_PATH}/Metrics/MSSSIM/MSSSIM_Results.pkl"
with open(_results_file, 'wb') as results_handle:
    pickle.dump(mean_msssim_losses, results_handle, protocol=pickle.HIGHEST_PROTOCOL)
mean_msssim_losses

{'e4e_00003': 0.8283094537174898,
 'e4e_00005': 0.8702037330478105,
 'PTI': 0.9450018156766892,
 'Restyle': 0.8918951418669431}