# Using FDA to evaluate image style transfer

This notebook is based upon `FDA: Fourier Domain Adaptation for Semantic Segmentation` from Yanchao Yang.

Before starting, the original data and the transferred images should be stored in a known path.

In [1]:
import torch
import numpy as np

def high_freq_mutate( amp_src, amp_trg, L=0.1 ):
    """Replace the centre window of the shifted source amplitude with the target's.

    Both amplitude tensors are fft-shifted over their last two dimensions, a
    square window around the centre index is copied from the target into the
    source, and the shift is undone again.

    Args:
        amp_src: source amplitude tensor; its shape is unpacked as (c, h, w).
        amp_trg: target amplitude tensor, indexed with the same window.
        L: fraction of ``min(h, w)`` used as the half-width of the window.

    Returns:
        The modified source amplitude with the shift undone.
    """
    shift_dims = (-2, -1)
    shifted_src = torch.fft.fftshift(amp_src, dim=shift_dims)
    shifted_trg = torch.fft.fftshift(amp_trg, dim=shift_dims)

    _, height, width = shifted_src.shape

    # Half-width of the swapped window and the centre index of each axis.
    half = int(np.floor(min(height, width) * L))
    mid_row = height // 2
    mid_col = width // 2

    rows = slice(mid_row - half, mid_row + half + 1)
    cols = slice(mid_col - half, mid_col + half + 1)

    # Overwrite the window in every channel with the target's values.
    shifted_src[:, rows, cols] = shifted_trg[:, rows, cols]

    return torch.fft.ifftshift(shifted_src, dim=shift_dims)

def FDA_source_to_target(src_img, trg_img, L=0.1):
    """Return the source image with part of its amplitude spectrum taken from the target.

    The real 2-D FFT of both images is split into amplitude and phase, the
    source amplitude is modified by ``high_freq_mutate`` using the target
    amplitude, and the result is recombined with the *source* phase and
    transformed back to image space.

    Args:
        src_img: source image tensor; the FFT runs over its last two dimensions.
        trg_img: target image tensor (must yield an amplitude of matching size).
        L: window fraction forwarded to ``high_freq_mutate``.

    Returns:
        Real tensor with the same last-two-dimension size as ``src_img``.
    """
    fft_dims = (-2, -1)

    # get fft of both source and target (rfft2 does not modify its input)
    fft_src = torch.fft.rfft2(src_img, dim=fft_dims)
    fft_trg = torch.fft.rfft2(trg_img, dim=fft_dims)

    # amplitude and phase of the source, amplitude only of the target
    amp_src, pha_src = torch.abs(fft_src), torch.angle(fft_src)
    amp_trg = torch.abs(fft_trg)

    # mutate the amplitude part of source with target
    amp_src_ = high_freq_mutate(amp_src, amp_trg, L=L)

    # mutated fft of source: new amplitude, original source phase
    fft_src_ = amp_src_ * torch.exp(1j * pha_src)

    # A one-sided spectrum does not encode whether the original width was odd
    # or even, so pass the size explicitly; otherwise odd widths lose a column.
    src_in_trg = torch.fft.irfft2(fft_src_, s=tuple(src_img.shape[-2:]), dim=fft_dims)

    return src_in_trg

def FDA_distance_torch( src_img, src2trg_img, L=0.1 , normalize = False, display = False):
    """Measure how far apart two images are in their Fourier amplitude spectra.

    Both images are transformed with a real 2-D FFT over their last two
    dimensions; the amplitudes are handed to ``high_freq_part_torch`` and the
    first tensor it returns (the amplitude difference) is reduced to three norms.

    Args:
        src_img: original image tensor.
        src2trg_img: style-transferred version of the same image.
        L: window fraction forwarded to ``high_freq_part_torch``.
        normalize: forwarded to ``high_freq_part_torch``.
        display: if True, additionally return inverse-transformed tensors
            intended for visual inspection.

    Returns:
        If ``display`` is False: ``(fro, L1, inf)`` norms of the flattened
        amplitude difference, each a 0-dim tensor.
        If ``display`` is True: ``(norms, low_freq_tuple, src_wo_style,
        trg_wo_style)`` where ``low_freq_tuple`` is ``(difference,
        irfft2(difference), irfft2(difference with source phase),
        irfft2(difference with target phase))``.
    """
    fft_dims = (-2, -1)

    # get fft of both images (rfft2 does not modify its input)
    fft_src_torch = torch.fft.rfft2(src_img, dim=fft_dims)
    fft_trg_torch = torch.fft.rfft2(src2trg_img, dim=fft_dims)

    # extract amplitude and phase of both ffts
    amp_src, pha_src = torch.abs(fft_src_torch), torch.angle(fft_src_torch)
    amp_trg, pha_trg = torch.abs(fft_trg_torch), torch.angle(fft_trg_torch)

    # amplitude difference plus the two windowed amplitudes
    low_freq_part, a_src, a_trg = high_freq_part_torch(amp_src, amp_trg, L=L, normalize=normalize)

    # flatten once and reuse for all three norms
    flat_diff = torch.flatten(low_freq_part)
    low_freq_dist = (
        torch.linalg.norm(flat_diff),
        torch.linalg.norm(flat_diff, ord=1),
        torch.linalg.norm(flat_diff, ord=float('inf')),
    )

    if not display:
        return low_freq_dist

    # A one-sided spectrum does not encode whether the original width was odd
    # or even, so pass the size explicitly; otherwise odd widths lose a column.
    spatial_size = tuple(src_img.shape[-2:])

    def _to_image(spectrum):
        return torch.fft.irfft2(spectrum, s=spatial_size, dim=fft_dims)

    # windowed amplitudes recombined with their own phases
    src_wo_style = _to_image(a_src * torch.exp(1j * pha_src))
    trg_wo_style = _to_image(a_trg * torch.exp(1j * pha_trg))

    # amplitude difference on its own, and paired with either phase
    low_freq_part_ifft = _to_image(low_freq_part)
    low_freq_part_src_ = _to_image(low_freq_part * torch.exp(1j * pha_src))
    low_freq_part_trg_ = _to_image(low_freq_part * torch.exp(1j * pha_trg))

    low_freq_tuple = (low_freq_part, low_freq_part_ifft, low_freq_part_src_, low_freq_part_trg_)

    return low_freq_dist, low_freq_tuple, src_wo_style, trg_wo_style

def high_freq_part_torch( amp_src, amp_trg, L=0.1, normalize = False):
    """Difference of two amplitude spectra after zeroing their centre window.

    Both amplitudes are fft-shifted over the last two dimensions, a square
    window of half-width ``floor(min(h, w) * L)`` around the centre index is
    set to zero, and the source-minus-target difference of what remains is
    computed. All three results are shifted back before being returned.

    Args:
        amp_src: source amplitude tensor; its shape is unpacked as (c, h, w).
        amp_trg: target amplitude tensor of the same shape.
        L: fraction of ``min(h, w)`` used as the half-width of the zeroed window.
        normalize: if True, each channel is divided by its own peak amplitude
            (taken before the window is zeroed) and the difference is divided
            by ``(2 * b) ** 2``. If any channel peak of either input is zero,
            the difference is defined as all zeros instead of dividing by zero.

    Returns:
        ``(difference, a_src, a_trg)``, all in (c, h, w) layout: the amplitude
        difference and the two windowed amplitudes.
    """
    shift_dims = (-2, -1)

    # Shift the zero-frequency component to the center of the spectrum.
    a_src = torch.fft.fftshift(amp_src, dim=shift_dims)
    a_trg = torch.fft.fftshift(amp_trg, dim=shift_dims)

    # Per-channel peak amplitude, measured before the window is removed.
    max_src = a_src.amax(dim=(1, 2))
    max_trg = a_trg.amax(dim=(1, 2))

    _, h, w = a_src.shape
    b = int(np.floor(min(h, w) * L))
    c_h, c_w = h // 2, w // 2
    rows = slice(c_h - b, c_h + b + 1)
    cols = slice(c_w - b, c_w + b + 1)

    a_src[:, rows, cols] = 0
    a_trg[:, rows, cols] = 0

    if normalize:
        if bool((max_src == 0).any()) or bool((max_trg == 0).any()):
            # A zero peak would make the per-channel scaling divide by zero.
            low_freq_part = torch.zeros_like(a_src)
        else:
            # Broadcast the per-channel peaks over (h, w) so the result keeps
            # the (c, h, w) layout expected by the shift below and by callers.
            low_freq_part = a_src / max_src[:, None, None] - a_trg / max_trg[:, None, None]
        # NOTE(review): when b == 0 this divides by zero and yields inf/nan.
        low_freq_part = low_freq_part / ((2 * b) * (2 * b))
    else:
        low_freq_part = a_src - a_trg

    a_src = torch.fft.ifftshift(a_src, dim=shift_dims)
    a_trg = torch.fft.ifftshift(a_trg, dim=shift_dims)
    low_freq_part = torch.fft.ifftshift(low_freq_part, dim=shift_dims)

    return low_freq_part, a_src, a_trg


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from datasets.cityscapes_dataset import CityscapesDataset
from datasets.gta_dataset import GTA5Dataset
from datasets.retouch_dataset import Retouch_dataset

def create_dataset(dataset_mode, folder_name, split_name, size):
    """Build the dataset object that matches ``dataset_mode``.

    Args:
        dataset_mode: one of "retouch", "gta5" or "cityscapes".
        folder_name: directory passed to the dataset as its base/root directory.
        split_name: split list passed to the dataset as ``list_dir``/``list_path``.
        size: crop size passed to the dataset.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``dataset_mode`` is not one of the supported names.
    """
    if dataset_mode == "retouch":
        # Use the ``size`` argument; the name ``crop_size`` is not defined here.
        return Retouch_dataset(base_dir=folder_name, list_dir=split_name, size=size)
    if dataset_mode == "gta5":
        return GTA5Dataset(root=folder_name, list_path=split_name, crop_size=size, ignore_label=19)
    if dataset_mode == "cityscapes":
        return CityscapesDataset(root=folder_name, list_path=split_name, crop_size=size, ignore_label=19)

    # Raise instead of exiting so the notebook kernel keeps running and the
    # failing cell shows which value was rejected.
    raise ValueError("Unrecognized dataset! Got dataset_mode={!r}".format(dataset_mode))

## Using the custom data in VISSL

The original data is saved in the `data` directory. The transferred images are stored in the `data/generated_images/#epoch` directory, where `#epoch` is the number of the CycleGAN epoch.

**EXAMPLE 1**: download the retouch data set from [retouch-dataset](https://drive.google.com/file/d/1r8pQCoVzEAHdy9wLW_MUkyfgBBFePMPv/view?usp=sharing) and insert it into the `data/real_images` directory. Download the transferred images from [transferred-retouch-images](https://drive.google.com/file/d/1nMcyF-z2yvPBDY70qBsT2Ydg7NUITpmR/view?usp=sharing) and insert the subfolders with the epoch number into the `data/generated_images` directory.

**EXAMPLE 2**: download the truncated GTAV data set from [gta5-truncated-dataset](https://drive.google.com/file/d/1R9zmrwAKf03KOq9MSfhdPd6xOVRGEtrY/view?usp=sharing) and insert it into the `data/real_images` directory. Download the transferred images from [transferred-gta5-images](https://drive.google.com/file/d/1SLdGNHDi3LZTHXXNMNFDTmAQibAEjj-x/view?usp=sharing) and insert the subfolders with the epoch number into the `data/generated_images` directory. Note that this also works with the whole data set; one only has to change `splits/gta5/gta5.txt` to list the whole data set. The truncated version is used for memory and time efficiency.


In [16]:
import matplotlib.pyplot as plt
import cv2
from torch.utils.data import DataLoader
import csv
from tqdm import tqdm
import os
import argparse

# Run configuration. A plain Namespace is used as an attribute container;
# ArgumentParser is meant for parsing command lines, not for storing values.
opt = argparse.Namespace()
opt.dataset_mode = "retouch"
opt.method = "jigsaw"
opt.real_dir = os.path.join(os.getcwd(), "data/real_images/retouch-dataset")
opt.fake_dir = os.path.join(os.getcwd(), "data/generated_images/OCT_new")
opt.load_epoch = 0

opt.data_list = os.path.join(os.getcwd(), "splits/cirrus_samples.txt")

opt.crop_size = (512, 512)

opt.num_threads = 0
opt.batch_size = 1
opt.no_flip = True
opt.display_id = -1

# Window fractions (the ``L`` argument of FDA_distance_torch) to evaluate.
opt.window_fractions = [0.01]
# Upper bound on the number of image pairs evaluated per epoch; None = all.
opt.max_batches = None

# Every purely numeric sub-directory of fake_dir is treated as one epoch;
# other folders (e.g. notebook checkpoints) are ignored.
transferred_images_dir = os.path.join(os.getcwd(), opt.fake_dir)
epochs = sorted(
    int(name)
    for name in os.listdir(transferred_images_dir)
    if name.isdigit() and os.path.isdir(os.path.join(transferred_images_dir, name))
)

head = os.path.join(os.getcwd(), "results")
os.makedirs(head, exist_ok=True)

source_dataset = create_dataset(opt.dataset_mode, opt.real_dir, opt.data_list, opt.crop_size)
source_loader = DataLoader(source_dataset, batch_size=opt.batch_size, shuffle=False)

# Use the first GPU when present, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for L in opt.window_fractions:

    title = ["epoch", "mean fro", "var fro", "mean L1", "var L1", "mean inf", "var inf"]
    to_write = [title]

    for epoch in epochs:
        print("Starting epoch {} :".format(epoch))
        folder_name = os.path.join(opt.fake_dir, "{}".format(epoch))
        print("Loading from: '" + folder_name + "'")
        source2target_dataset = create_dataset(opt.dataset_mode, folder_name, opt.data_list, opt.crop_size)
        source2target_loader = DataLoader(source2target_dataset, batch_size=opt.batch_size, shuffle=False)

        # Both loaders are unshuffled, so the i-th items are paired by position.
        n_pairs = min(len(source_loader), len(source2target_loader))
        if len(source_loader) != len(source2target_loader):
            print("Warning: loaders differ in length ({} vs {}); only the first {} pairs are compared".format(
                len(source_loader), len(source2target_loader), n_pairs))

        FDA_distances_fro = []
        FDA_distances_L1 = []
        FDA_distances_inf = []

        # Pure evaluation: no gradients are needed.
        with torch.no_grad():
            paired_batches = zip(source_loader, source2target_loader)
            for i, (data, data_s2t) in enumerate(tqdm(paired_batches, total=n_pairs)):
                if opt.max_batches is not None and i >= opt.max_batches:
                    break

                # Only the first image of each batch is compared.
                source_img = data["image"].to(device)[0]
                source2target_img = data_s2t["image"].to(device)[0]

                FDA_distance_fro, FDA_distance_L1, FDA_distance_inf = FDA_distance_torch(
                    src_img=source_img,
                    src2trg_img=source2target_img,
                    L=L, normalize=True)

                FDA_distances_fro.append(FDA_distance_fro.item())
                FDA_distances_L1.append(FDA_distance_L1.item())
                FDA_distances_inf.append(FDA_distance_inf.item())

        FDA_distances_fro = np.array(FDA_distances_fro)
        FDA_distances_L1 = np.array(FDA_distances_L1)
        FDA_distances_inf = np.array(FDA_distances_inf)

        mean_fro, var_fro = np.mean(FDA_distances_fro), np.var(FDA_distances_fro)
        mean_L1, var_L1 = np.mean(FDA_distances_L1), np.var(FDA_distances_L1)
        mean_inf, var_inf = np.mean(FDA_distances_inf), np.var(FDA_distances_inf)

        result = [epoch, mean_fro, var_fro, mean_L1, var_L1, mean_inf, var_inf]

        print("Mean content FT : {} (mean fro), {} (mean L1), {} (mean inf)".format(mean_fro, mean_L1, mean_inf))
        print("Var content FT : {} (var fro), {} (var L1), {} (var inf)".format(var_fro, var_L1, var_inf))
        to_write.append(result)

    str_L = str(L).replace(".", "")

    # Join with the directory so the file is written inside ``results``.
    csv_path = os.path.join(head, "results_content_norm_HFFT_" + str_L + "_{}.csv".format(opt.dataset_mode))
    with open(csv_path, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(to_write)

Starting epoch 3 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/3'


100%|███████████████████████████████████████| 3072/3072 [00:32<00:00, 93.97it/s]


Mean content FT : 0.06787624010879274 (mean fro), 9.06347272824496 (mean L1), 0.011698037359110458 (mean inf)
Var content FT : 0.0003679565172793263 (var fro), 1.0601259926331206 (var L1), 2.9892104834947856e-05 (var inf)
Starting epoch 12 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/12'


100%|███████████████████████████████████████| 3072/3072 [00:32<00:00, 95.42it/s]


Mean content FT : 0.07681800747741363 (mean fro), 8.70493844524026 (mean L1), 0.01619822119391756 (mean inf)
Var content FT : 0.0006075895454700128 (var fro), 1.5175142058558773 (var L1), 5.7841632768268003e-05 (var inf)
Starting epoch 15 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/15'


100%|███████████████████████████████████████| 3072/3072 [00:33<00:00, 92.78it/s]


Mean content FT : 0.06886866446014513 (mean fro), 8.86506050142149 (mean L1), 0.011183873062009297 (mean inf)
Var content FT : 0.0005046855508523822 (var fro), 1.6243900158309474 (var L1), 3.5880192666415044e-05 (var inf)
Starting epoch 18 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/18'


100%|███████████████████████████████████████| 3072/3072 [00:33<00:00, 92.71it/s]


Mean content FT : 0.07480982354294004 (mean fro), 8.185165011168769 (mean L1), 0.016372884328120563 (mean inf)
Var content FT : 0.0007457351358228135 (var fro), 1.7292686174075793 (var L1), 7.345235986229524e-05 (var inf)
Starting epoch 21 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/21'


100%|███████████████████████████████████████| 3072/3072 [00:33<00:00, 91.49it/s]


Mean content FT : 0.07547928482866458 (mean fro), 9.076687966783842 (mean L1), 0.014487132174357006 (mean inf)
Var content FT : 0.0004865532010478301 (var fro), 1.6648424304485099 (var L1), 4.504819816128137e-05 (var inf)
Starting epoch 24 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/24'


100%|███████████████████████████████████████| 3072/3072 [00:33<00:00, 90.83it/s]


Mean content FT : 0.0789804838859709 (mean fro), 9.905114601987103 (mean L1), 0.015266536481779744 (mean inf)
Var content FT : 0.0005895168139869911 (var fro), 1.5073230193015157 (var L1), 5.113223042472342e-05 (var inf)
Starting epoch 27 :
Loading from: '/home/zeju/Documents/Zeiss_Self_Supervised/data/generated_images/OCT_new/27'


100%|███████████████████████████████████████| 3072/3072 [00:34<00:00, 89.40it/s]

Mean content FT : 0.07794486243316594 (mean fro), 8.963636408559978 (mean L1), 0.016056491546957357 (mean inf)
Var content FT : 0.0008772217034051315 (var fro), 2.119382445946462 (var L1), 6.935879316783008e-05 (var inf)



