# Using FDA to evaluate image style transfer

This notebook is based upon `FDA: Fourier Domain Adaptation for Semantic Segmentation` from Yanchao Yang.

Before starting, the original data and the transferred images should be stored in a known path.

In [1]:
import torch
import numpy as np

def high_freq_mutate( amp_src, amp_trg, L=0.1 ):
    a_src = torch.fft.fftshift( amp_src, dim =(-2, -1) )
    a_trg = torch.fft.fftshift( amp_trg, dim =(-2, -1) )

    c, h, w = a_src.shape
    b = (  np.floor(np.amin((h,w))*L)  ).astype(int)
    c_h = np.floor(h/2.0).astype(int)
    c_w = np.floor(w/2.0).astype(int)

    h1 = c_h-b
    h2 = c_h+b+1
    w1 = c_w-b
    w2 = c_w+b+1

    a_src[:,h1:h2,w1:w2] = a_trg[:,h1:h2,w1:w2]
    a_src = torch.fft.ifftshift( a_src, dim =(-2, -1) )
    return a_src

def FDA_source_to_target(src_img, trg_img, L=0.1):
    # exchange magnitude
    # input: src_img, trg_img

    src_img_torch = src_img.clone() #.cpu().numpy()
    trg_img_torch = trg_img.clone() #.cpu().numpy()

    # get fft of both source and target
    fft_src_torch = torch.fft.rfft2( src_img_torch, dim=(-2, -1) )
    fft_trg_torch = torch.fft.rfft2( trg_img_torch, dim=(-2, -1) )

    # extract amplitude and phase of both ffts
    amp_src, pha_src = torch.abs(fft_src_torch), torch.angle(fft_src_torch)
    amp_trg, pha_trg = torch.abs(fft_trg_torch), torch.angle(fft_trg_torch)

    # mutate the amplitude part of source with target
    amp_src_ = high_freq_mutate( amp_src, amp_trg, L=L )

    # mutated fft of source
    fft_src_ = amp_src_ * torch.exp( 1j * pha_src )

    # get the mutated image
    src_in_trg = torch.fft.irfft2( fft_src_, dim=(-2, -1) )
    #src_in_trg = torch.real(src_in_trg)

    return src_in_trg

def FDA_distance_torch( src_img, src2trg_img, L=0.1 , normalize = False, display = False):
    # exchange magnitude
    # input: src_img, trg_img

    src_img_torch = src_img.clone() #.cpu().numpy()
    src2trg_img_torch = src2trg_img.clone() #.cpu().numpy()

    # get fft of both source and target
    fft_src_torch = torch.fft.rfft2( src_img_torch, dim=(-2, -1) )
    fft_trg_torch = torch.fft.rfft2( src2trg_img_torch, dim=(-2, -1) )

    # extract amplitude and phase of both ffts
    amp_src, pha_src = torch.abs(fft_src_torch), torch.angle(fft_src_torch)
    amp_trg, pha_trg = torch.abs(fft_trg_torch), torch.angle(fft_trg_torch)

    # mutate the amplitude part of source with target
    low_freq_part, a_src, a_trg = high_freq_part_torch( amp_src, amp_trg, L=L, normalize = normalize )

    low_freq_dist_fro = torch.linalg.norm(torch.flatten(low_freq_part))

    low_freq_dist_L1 = torch.linalg.norm(torch.flatten(low_freq_part), ord = 1)

    low_freq_dist_inf = torch.linalg.norm(torch.flatten(low_freq_part), ord = float('inf'))
    
    low_freq_dist = (low_freq_dist_fro, low_freq_dist_L1, low_freq_dist_inf)
    
    if display:

        # mutated fft of source
        fft_src_ = a_src * torch.exp( 1j * pha_src )
        # mutated fft of source
        fft_trg_ = a_trg * torch.exp( 1j * pha_trg )
        
        low_freq_part_src_ = low_freq_part * torch.exp( 1j * pha_src )
        low_freq_part_trg_ = low_freq_part * torch.exp( 1j * pha_trg )
        
        src_wo_style = torch.fft.irfft2( fft_src_, dim=(-2, -1) )
        trg_wo_style = torch.fft.irfft2( fft_trg_, dim=(-2, -1) )
        low_freq_part_ifft = torch.fft.irfft2( low_freq_part, dim=(-2, -1) )
        low_freq_part_src_ = torch.fft.irfft2( low_freq_part_src_, dim=(-2, -1) )
        low_freq_part_trg_ = torch.fft.irfft2( low_freq_part_trg_, dim=(-2, -1) )
        
        low_freq_tuple = (low_freq_part, low_freq_part_ifft, low_freq_part_src_, low_freq_part_trg_)
        
        return low_freq_dist, low_freq_tuple, src_wo_style, trg_wo_style
        
    else:
        
        return low_freq_dist

def high_freq_part_torch( amp_src, amp_trg, L=0.1, normalize = False):
    a_src = torch.fft.fftshift( amp_src, dim =(-2, -1) )
    a_trg = torch.fft.fftshift( amp_trg, dim =(-2, -1) )
    
    
    #print(a_src.shape)
    #print(a_trg.shape)
    
    max_src_temp = a_src.max(dim = 1)
    max_trg_temp = a_trg.max(dim = 1)
    max_src = max_src_temp.values.max(dim = 1)
    max_trg = max_trg_temp.values.max(dim = 1)
    #print(max_src.values)
    #print(max_trg.values)
    #print(a_src.max())
    #print(a_trg.max())

    c, h, w = a_src.shape
    b = (  np.floor(np.amin((h,w))*L)  ).astype(int)
    c_h = np.floor(h/2.0).astype(int)
    c_w = np.floor(w/2.0).astype(int)

    h1 = c_h-b
    h2 = c_h+b+1
    w1 = c_w-b
    w2 = c_w+b+1
    
    a_src[:,h1:h2,w1:w2] = a_src[:,h1:h2,w1:w2] * 0
    a_trg[:,h1:h2,w1:w2] = a_trg[:,h1:h2,w1:w2] * 0
    
    if normalize:
    
        if 0 not in max_src.values:
            low_freq_part = torch.div(a_src.permute((1, 2, 0)), max_src.values) - torch.div(a_trg.permute((1, 2, 0)), max_trg.values)
        else:
            low_freq_part = a_src * 0
            
        #print(low_freq_part.shape)
            
        low_freq_part = low_freq_part / ((2 * b) * (2 * b))
        
    else:
        
        low_freq_part = a_src - a_trg
        
    
    a_src = torch.fft.ifftshift( a_src, dim =(-2, -1) )
    a_trg = torch.fft.ifftshift( a_trg, dim =(-2, -1) )
    
    low_freq_part = torch.fft.ifftshift( low_freq_part, dim =(-2, -1) )
    
    return low_freq_part, a_src, a_trg

In [2]:
# gta5 dataset (source)

import os
import os.path as osp
import sys

import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

'''
labels = [
    #       name                     id    trainId   category            catId     hasInstances   ignoreInEval   color
    Label(  'unlabeled'            ,  0 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'ego vehicle'          ,  1 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'rectification border' ,  2 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'out of roi'           ,  3 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'static'               ,  4 ,      255 , 'void'            , 0       , False        , True         , (  0,  0,  0) ),
    Label(  'dynamic'              ,  5 ,      255 , 'void'            , 0       , False        , True         , (111, 74,  0) ),
    Label(  'ground'               ,  6 ,      255 , 'void'            , 0       , False        , True         , ( 81,  0, 81) ),
    Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (128, 64,128) ),
    Label(  'sidewalk'             ,  8 ,        1 , 'flat'            , 1       , False        , False        , (244, 35,232) ),
    Label(  'parking'              ,  9 ,      255 , 'flat'            , 1       , False        , True         , (250,170,160) ),
    Label(  'rail track'           , 10 ,      255 , 'flat'            , 1       , False        , True         , (230,150,140) ),
    Label(  'building'             , 11 ,        2 , 'construction'    , 2       , False        , False        , ( 70, 70, 70) ),
    Label(  'wall'                 , 12 ,        3 , 'construction'    , 2       , False        , False        , (102,102,156) ),
    Label(  'fence'                , 13 ,        4 , 'construction'    , 2       , False        , False        , (190,153,153) ),
    Label(  'guard rail'           , 14 ,      255 , 'construction'    , 2       , False        , True         , (180,165,180) ),
    Label(  'bridge'               , 15 ,      255 , 'construction'    , 2       , False        , True         , (150,100,100) ),
    Label(  'tunnel'               , 16 ,      255 , 'construction'    , 2       , False        , True         , (150,120, 90) ),
    Label(  'pole'                 , 17 ,        5 , 'object'          , 3       , False        , False        , (153,153,153) ),
    Label(  'polegroup'            , 18 ,      255 , 'object'          , 3       , False        , True         , (153,153,153) ),
    Label(  'traffic light'        , 19 ,        6 , 'object'          , 3       , False        , False        , (250,170, 30) ),
    Label(  'traffic sign'         , 20 ,        7 , 'object'          , 3       , False        , False        , (220,220,  0) ),
    Label(  'vegetation'           , 21 ,        8 , 'nature'          , 4       , False        , False        , (107,142, 35) ),
    Label(  'terrain'              , 22 ,        9 , 'nature'          , 4       , False        , False        , (152,251,152) ),
    Label(  'sky'                  , 23 ,       10 , 'sky'             , 5       , False        , False        , ( 70,130,180) ),
    Label(  'person'               , 24 ,       11 , 'human'           , 6       , True         , False        , (220, 20, 60) ),
    Label(  'rider'                , 25 ,       12 , 'human'           , 6       , True         , False        , (255,  0,  0) ),
    Label(  'car'                  , 26 ,       13 , 'vehicle'         , 7       , True         , False        , (  0,  0,142) ),
    Label(  'truck'                , 27 ,       14 , 'vehicle'         , 7       , True         , False        , (  0,  0, 70) ),
    Label(  'bus'                  , 28 ,       15 , 'vehicle'         , 7       , True         , False        , (  0, 60,100) ),
    Label(  'caravan'              , 29 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0, 90) ),
    Label(  'trailer'              , 30 ,      255 , 'vehicle'         , 7       , True         , True         , (  0,  0,110) ),
    Label(  'train'                , 31 ,       16 , 'vehicle'         , 7       , True         , False        , (  0, 80,100) ),
    Label(  'motorcycle'           , 32 ,       17 , 'vehicle'         , 7       , True         , False        , (  0,  0,230) ),
    Label(  'bicycle'              , 33 ,       18 , 'vehicle'         , 7       , True         , False        , (119, 11, 32) ),
    Label(  'license plate'        , -1 ,       -1 , 'vehicle'         , 7       , False        , True         , (  0,  0,142) ),
]
'''

class GTA5Dataset(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(256, 256), mean=(128, 128, 128), ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if not max_iters==None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                              26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

        self.id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        #self.id_to_trainid = {7: 1, 24: 2, 26: 3} #Road/car/people
        self.id_to_trainid = {11: 1, 24: 2, 21: 3} #Building/car/vegetation
        #self.ignore_label = 0

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        name = self.img_ids[index]
        
        image = Image.open(osp.join(self.root, "images/%s" % name)).convert('RGB')
        label = Image.open(osp.join(self.root, "labels/%s" % name))
        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        label = label.resize(self.crop_size, Image.NEAREST)

        image = np.asarray(image, np.float32)
        label = np.asarray(label, np.int8)

        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.float32)
        #for k, v in self.id_to_trainid.items():
        for k, v in self.id2label.items():
            label_copy[label == k] = v
        size = image.shape
        #image = image[:, :, ::-1]  # change to BGR
        #image -= self.mean
        image = image.transpose((2, 0, 1))
        sample = {'image': image.copy(),
                  'label': label_copy.copy()}

        return sample
    
class GTA5Dataset1(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(256, 256), mean=(128, 128, 128), ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if not max_iters==None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                              26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

        self.id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        #self.id_to_trainid = {7: 1, 24: 2, 26: 3} #Road/car/people
        self.id_to_trainid = {11: 1, 24: 2, 21: 3} #Building/car/vegetation
        #self.ignore_label = 0

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        name = self.img_ids[index]
                
        name = name[:-4] + ".jpg"
        
        image = Image.open(osp.join(self.root, "images/%s" % name)).convert('RGB')
        
        name = name[:-4] + ".png"
        label = Image.open(osp.join(self.root, "labels/%s" % name))
        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        label = label.resize(self.crop_size, Image.NEAREST)

        image = np.asarray(image, np.float32)
        label = np.asarray(label, np.int8)

        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.float32)
        #for k, v in self.id_to_trainid.items():
        for k, v in self.id2label.items():
            label_copy[label == k] = v
        size = image.shape
        #image = image[:, :, ::-1]  # change to BGR
        #image -= self.mean
        image = image.transpose((2, 0, 1))
        sample = {'image': image.copy(),
                  'label': label_copy.copy()}

        return sample


In [3]:
# Cityscapes dataset (target)

import os
import os.path as osp
import sys

import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
from torch.utils import data
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
            220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
            0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
        'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
        'bicycle']

    
class CityscapesDataset(data.Dataset):
    def __init__(self, root, list_path, max_iters=None, crop_size=(256, 256), mean=(128, 128, 128), ignore_label=255):
        self.root = root
        self.list_path = list_path
        self.crop_size = crop_size
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if not max_iters==None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []
        
        self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
                              26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18}

        self.id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}

        #self.id_to_trainid = {7: 1, 24: 2, 26: 3} #Road/car/people
        self.id_to_trainid = {11: 1, 24: 2, 21: 3} #Building/car/vegetation
        #self.ignore_label = 0

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        name = self.img_ids[index]
        image_root = osp.join(self.root, 'cityscapes')
        label_root = osp.join(self.root, 'gtFine')
        image = Image.open(osp.join(image_root, "%s" % name)).convert('RGB')
        label = Image.open(osp.join(label_root, "%s" % name.replace("leftImg8bit", "gtFine_labelIds")))
        # resize
        image = image.resize(self.crop_size, Image.BICUBIC)
        label = label.resize(self.crop_size, Image.NEAREST)

        image = np.asarray(image, np.float32)
        label = np.asarray(label, np.int8)

        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.float32)
        #for k, v in self.id_to_trainid.items():
        for k, v in self.id2label.items():
            label_copy[label == k] = v
        size = image.shape
        #image = image[:, :, ::-1]  # change to BGR
        #image -= self.mean
        image = image.transpose((2, 0, 1))
        sample = {'image': image.copy(),
                  'label': label_copy.copy()}

        return sample
    

In [4]:
import skimage.io as io
import matplotlib.pylab as plt
import numpy as np

from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
# from medpy.io import load
import random

import torch
import os
from PIL import Image  # using pillow-simd for increased speed

import cv2
from torchvision import transforms


def pil_loader(path):
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert("L")

class Retouch_dataset(Dataset):
    def __init__(self,
                 base_dir,
                 list_dir,
                 size=(512, 512),
                 split='train',
                 is_train=False,
                 transform=None,
                 ext='.png'):
        self.transform = transform  # using transform in torch!
        self.split = split
        if split == '':
            self.sample_list = open(list_dir).readlines()
        else:
            self.sample_list = open(os.path.join(list_dir +
                                                 self.split + '.txt')).readlines()
        self.data_dir = base_dir
        self.loader = pil_loader
        self.to_tensor = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.ToTensor(),
                #transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ])

        self.is_train = is_train
        self.transform = transform
        self.ext = ext

    def augment(self, data, label):
        data_label = torch.cat((data, label), dim=0)
        data_label_aug = self.transform(data_label)
        data_aug = data_label_aug[0, :, :].unsqueeze(0)
        label_aug = data_label_aug[1, :, :].unsqueeze(0)
        return data_aug, label_aug

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        sample_name = self.sample_list[idx].strip('\n')

        vendor = sample_name.split(' ')[0]
        slice_name = sample_name.split(' ')[1]
        slice_idx = sample_name.split(' ')[2].zfill(3)

        data_path = os.path.join(self.data_dir,
                                 vendor,
                                 slice_name,
                                 'images',
                                 slice_idx + self.ext)
        '''

        label_path = os.path.join(self.data_dir,
                                  vendor,
                                  slice_name,
                                  'labels',
                                  slice_idx + '.npy')

        label = torch.from_numpy(np.load(label_path))
        label_idx = torch.argmax(label, dim=0, keepdim=True)
        
        '''
        
        data = self.to_tensor(self.loader(data_path))
        transform_avaliable = self.transform is not None and self.is_train
        do_aug = transform_avaliable and random.random() > 0.5

        if do_aug:
            # data, label_idx = self.augment(data, label_idx)
            data = self.augment(data)

        # label_idx = label_idx.squeeze(0).long()

        sample = {'image': data,
                  'case_name': sample_name}
        # print((label_idx==0).sum()/512**2)
        return sample

# Test Unit
# flip = transforms.RandomHorizontalFlip(p=0.5)
# base_dir = 'Retouch-dataset_test/pre_processed/'
# list_dir = ''
# dataset = Retouch_dataset(base_dir, list_dir, transform=flip)
# l = dataset[3]['label']
# d = dataset[3]['image']
#
# print(l.shape, d.shape)
#
# img = d.permute(1, 2, 0).numpy()
# print((img[:, :, 0] == img[:, :, 2]).all())
# print(dataset[3]['case_name'])
# plt.figure()
# plt.imshow(img)


  from scipy.ndimage.interpolation import zoom


In [5]:
def create_dataset(dataset_mode, folder_name, split_name, split, size):
    if dataset_mode == "retouch":
        source_dataset = Retouch_dataset(base_dir=folder_name, list_dir=split_name, split='', size = crop_size)
    elif dataset_mode == "gta5":
        source_dataset = GTA5Dataset(root=folder_name, list_path=split_name, crop_size=size, ignore_label=19)
    elif dataset_mode == "cityscapes":
        source_dataset = CityscapesDataset(root=folder_name, list_path=split_name, crop_size=size, ignore_label=19)
    else:
        print("Unrecognized dataset!")
        sys.exit()
        
    return source_dataset

## Using the custom data in VISSL

The original data is saved in the `data` directory. The transferred images are saved in such a way, that they are stored in the `data/transferred/#epoch` directory (`#epoch` is the number of CycleGAN epoch).

**EXAMPLE 1**: download the retouch data set from [retouch-dataset](https://drive.google.com/file/d/1r8pQCoVzEAHdy9wLW_MUkyfgBBFePMPv/view?usp=sharing) and insert it into the `data` directory. Download the transferred images from [transferred-retouch-images](https://drive.google.com/file/d/1nMcyF-z2yvPBDY70qBsT2Ydg7NUITpmR/view?usp=sharing) and isert the subfolders with the epoch number into the `data/transferred` directory.

**EXAMPLE 2**: download the truncated retouch GTAV data set from [gta5-truncated-dataset](https://drive.google.com/file/d/1R9zmrwAKf03KOq9MSfhdPd6xOVRGEtrY/view?usp=sharing) and insert it into the `data` directory. Download the transferred images from [transferred-gta5-images](https://drive.google.com/file/d/1SLdGNHDi3LZTHXXNMNFDTmAQibAEjj-x/view?usp=sharing) and isert the subfolders with the epoch number into the `data/transferred` directory. Note, it also works with the whole data set, one only has to change the `splits/gta5/gta5.txt` to the whole dataset. The truncated version is used due to memory and time efficiency.


In [10]:
# OCT torch implementation

import matplotlib.pyplot as plt
import cv2
from torch.utils.data import DataLoader
import csv
from tqdm import tqdm
import os
import data

dataset_mode ="retouch" # retouch / gta5

transferred_images_dir = os.path.join(os.getcwd(), "data/transferred")
epochs = [int(f) for f in os.listdir(transferred_images_dir) if os.path.isdir(os.path.join(transferred_images_dir, f))]
epochs.sort()

head = "results/{}/".format(dataset_mode)

if not os.path.exists(head):
    os.makedirs(head)
    
folder_name = "data/{}".format(dataset_mode)
split_name = "splits/{}/{}.txt".format(dataset_mode, dataset_mode)
split = ""
crop_size = (256, 256)
batch_size = 1
batch = 50000

source_dataset = create_dataset(dataset_mode, folder_name, split_name, split, crop_size)
source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=False)

#split_name = "./splits/gta5/gta5_list.txt"
#crop_size = (512, 512)
#batch_size = 1

device = torch.device('cuda:{}'.format(0))

for L in [0.01]:
#for L in [0.05]:

    to_write = []
    title = ["epoch", "mean fro", "var fro", "mean L1", "var L1", "mean inf", "var inf"]
    to_write.append(title)

    for epoch in epochs:
        print("Starting epoch {} :".format(epoch))
        folder_name = "data/fourier/{}/".format(epoch)  # Replace with your produced GTA5 dir !
        print("Loading from: '" + folder_name + "'")
        source2target_dataset = create_dataset(dataset_mode, folder_name, split_name, split, crop_size)
        source2target_loader = DataLoader(source2target_dataset, batch_size=batch_size, shuffle=False)
        source2target_loader_iter = iter(source2target_loader)
        FDA_distances_fro = []
        FDA_distances_L1 = []
        FDA_distances_inf = []

        for i, data in enumerate(tqdm(source_loader)):
            if i < batch:
                source = data["image"].cuda().to(device)
                source_img = source[0]

                source2target = source2target_loader_iter.next()["image"].cuda().to(device)
                source2target_img = source2target[0]

                FDA_distance = FDA_distance_torch(src_img = source_img, 
                                                  src2trg_img = source2target_img, 
                                                  L = L, normalize = True)

                (FDA_distance_fro, FDA_distance_L1, FDA_distance_inf) = FDA_distance
                
                FDA_distances_fro.append(FDA_distance_fro.item())
                FDA_distances_L1.append(FDA_distance_L1.item())
                FDA_distances_inf.append(FDA_distance_inf.item())

            else:
                break

        FDA_distances_fro = np.array(FDA_distances_fro)
        FDA_distances_L1 = np.array(FDA_distances_L1)
        FDA_distances_inf = np.array(FDA_distances_inf)

        result = [epoch, np.mean(FDA_distances_fro), np.var(FDA_distances_fro),
                  np.mean(FDA_distances_L1), np.var(FDA_distances_L1), 
                  np.mean(FDA_distances_inf), np.var(FDA_distances_inf)]

        print("Mean content FT : {} (mean fro), {} (mean L1), {} (mean inf)".format(np.mean(FDA_distances_fro), 
                                                                     np.mean(FDA_distances_L1), np.mean(FDA_distances_inf)))

        print("Var content FT : {} (var fro), {} (var L1), {} (var inf)".format(np.var(FDA_distances_fro), 
                                                                     np.var(FDA_distances_L1), np.var(FDA_distances_inf)))
        to_write.append(result)
    
    str_L = str(L).replace(".", "")
    
    with open(head+"results_content_norm_HFFT_" + str_L +"_{}.csv".format(dataset_mode), "w") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(to_write)

Starting epoch 0 :
Loading from: 'data/fourier/0/'


100%|███████████████████████████████████████| 2497/2497 [00:54<00:00, 45.54it/s]


Mean content FT : 0.005243747097311947 (mean fro), 0.6335962659294815 (mean L1), 0.0009937944220080135 (mean inf)
Var content FT : 2.9616228238586528e-05 (var fro), 0.21553011664875726 (var L1), 1.621874765788993e-06 (var inf)
Starting epoch 2 :
Loading from: 'data/fourier/2/'


100%|███████████████████████████████████████| 2497/2497 [00:52<00:00, 47.66it/s]


Mean content FT : 0.08665429896380025 (mean fro), 8.078563885255294 (mean L1), 0.016400589254074182 (mean inf)
Var content FT : 0.0005044099326271049 (var fro), 1.8853676441818012 (var L1), 4.1156961754610546e-05 (var inf)
Starting epoch 4 :
Loading from: 'data/fourier/4/'


100%|███████████████████████████████████████| 2497/2497 [00:52<00:00, 47.93it/s]


Mean content FT : 0.06765801955758928 (mean fro), 6.309012919556966 (mean L1), 0.013660774840277716 (mean inf)
Var content FT : 0.0004790678863079971 (var fro), 1.5331150619064098 (var L1), 4.1240441315447285e-05 (var inf)
Starting epoch 6 :
Loading from: 'data/fourier/6/'


100%|███████████████████████████████████████| 2497/2497 [00:52<00:00, 47.50it/s]


Mean content FT : 0.06665206196821233 (mean fro), 6.242947189150021 (mean L1), 0.012878599438743952 (mean inf)
Var content FT : 0.0003017592231936984 (var fro), 1.0474466145288612 (var L1), 2.5153750087352576e-05 (var inf)
Starting epoch 8 :
Loading from: 'data/fourier/8/'


100%|███████████████████████████████████████| 2497/2497 [00:53<00:00, 47.04it/s]


Mean content FT : 0.06708000184245286 (mean fro), 6.003102911440048 (mean L1), 0.014205187024745582 (mean inf)
Var content FT : 0.0003052561902349813 (var fro), 0.8679299992215921 (var L1), 3.182760228679323e-05 (var inf)
Starting epoch 10 :
Loading from: 'data/fourier/10/'


100%|███████████████████████████████████████| 2497/2497 [00:54<00:00, 45.81it/s]


Mean content FT : 0.07998240082369931 (mean fro), 6.301456551003752 (mean L1), 0.017078530196470466 (mean inf)
Var content FT : 0.0004496712571328838 (var fro), 0.8839361258553314 (var L1), 3.984291919705258e-05 (var inf)
Starting epoch 12 :
Loading from: 'data/fourier/12/'


100%|███████████████████████████████████████| 2497/2497 [00:56<00:00, 44.51it/s]


Mean content FT : 0.07259896933856395 (mean fro), 6.943597739249838 (mean L1), 0.014590108144467396 (mean inf)
Var content FT : 0.00034857091743702764 (var fro), 1.368007667142607 (var L1), 3.355426685897053e-05 (var inf)
Starting epoch 14 :
Loading from: 'data/fourier/14/'


100%|███████████████████████████████████████| 2497/2497 [00:54<00:00, 45.87it/s]


Mean content FT : 0.082871966961358 (mean fro), 6.566706064800573 (mean L1), 0.017814391563599128 (mean inf)
Var content FT : 0.0004836313141310321 (var fro), 1.1570412048222216 (var L1), 4.486242933471048e-05 (var inf)
Starting epoch 16 :
Loading from: 'data/fourier/16/'


100%|███████████████████████████████████████| 2497/2497 [00:52<00:00, 47.57it/s]

Mean content FT : 0.07060923156786142 (mean fro), 6.372371600587988 (mean L1), 0.014368668416030869 (mean inf)
Var content FT : 0.0004509931211690666 (var fro), 1.375853377443662 (var L1), 4.106327613230687e-05 (var inf)



