### Installations

In [None]:
%%capture
# Install packages needed by the UNIT code and the FID evaluation that are not
# preinstalled on Colab; %%capture suppresses the pip output.
!pip install torchfile 
!pip install tensorboardX
!pip install pytorch-fid 
import os

## Git operations

In [None]:
# Clone git repository 
!git clone 'https://github.com/XkunW/Image_Translation.git'

Cloning into 'Image_Translation'...
remote: Enumerating objects: 381, done.[K
remote: Counting objects: 100% (89/89), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 381 (delta 50), reused 53 (delta 23), pack-reused 292[K
Receiving objects: 100% (381/381), 382.56 MiB | 25.55 MiB/s, done.
Resolving deltas: 100% (157/157), done.


In [None]:
! git pull
# ! git status
# ! git checkout utils.py

fatal: not a git repository (or any of the parent directories): .git


Optional: clone the pytorch-fid repository (not required, because `pytorch-fid` is already installed with pip above)

In [None]:
# !git clone 'https://github.com/mseitzer/pytorch-fid.git'
# ! git pull

## Drive mounting and unzipping data 

In [None]:
# Mount drive
# Mount Google Drive so the project's zipped datasets can be read from it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
data_dir = '/content/drive/MyDrive/CSC2516_Project/Datasets/' #same for Tina and Sophie

# data zip 
# Full paths of the zipped datasets; a later cell unzips one of them into the repo.
summer2winter = data_dir+'summer2winter_yosemite_small_dataset.zip' 
monet2photo = data_dir+'monet2photo_small_dataset.zip'

Mounted at /content/drive


In [None]:
# Work from the UNIT folder of the clone: the later cells import `trainer` /
# `utils` and use relative paths such as 'configs/...'.
# The absolute path keeps the cell re-runnable. A relative
# 'Image_Translation/UNIT' only resolves when the cwd is /content.
os.chdir('/content/Image_Translation/UNIT')
os.getcwd()

'/content/Image_Translation/UNIT'

In [None]:
# Unzipping datasets to the target folder
%%capture
# !unzip "$summer2winter" -d '/content/Image_Translation/UNIT/datasets/'
!unzip "$monet2photo" -d '/content/Image_Translation/UNIT/datasets/'

In [None]:
# copy vgg16 model weights into the models folder in github repo
# (the weights file lives on Drive and is not part of the cloned repository)
!cp "/content/drive/MyDrive/CSC2516_Project/UNIT_colab/VGG_model/vgg16.weight" "/content/Image_Translation/UNIT/models"

## Functions for FID score

In [None]:
import os
import pathlib
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from multiprocessing import cpu_count
import numpy as np
import torch
import torchvision.transforms as TF
from PIL import Image
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
# try:
#     from tqdm import tqdm
# except ImportError:
#     # If tqdm is not available, provide a mock version of it
#     def tqdm(x):
#         return x

from pytorch_fid.inception import InceptionV3

# Lower-case file extensions (without the leading dot) that are treated as
# images when a folder is scanned for FID statistics.
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
                    'tif', 'tiff', 'webp'}

class ImagePathDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of image file paths.

    Each item is the image converted to RGB, passed through `transforms`
    when one was given.
    """

    def __init__(self, files, transforms=None):
        # Kept as public attributes: `files` is the indexable list of paths,
        # `transforms` an optional callable applied to every loaded image.
        self.files = files
        self.transforms = transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        rgb = Image.open(self.files[i]).convert('RGB')
        return rgb if self.transforms is None else self.transforms(rgb)

def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
                    num_workers=None):
    """Compute pooled feature activations of `model` for every image file.

    Args:
        files: sequence of image paths, read in order.
        model: network whose `model(batch)[0]` is a 4-D (N, C, H, W) tensor
            (e.g. the pytorch-fid InceptionV3 wrapper).
        batch_size: images per forward pass; clipped to `len(files)`.
        dims: number of feature channels; sizes the returned array.
        device: device each batch is moved to before the forward pass.
        num_workers: DataLoader worker processes; defaults to `cpu_count()`.

    Returns:
        numpy array of shape (len(files), dims), one row per file.

    Raises:
        ValueError: if `files` is empty.
    """
    if len(files) == 0:
        # Without this check the batch size becomes 0 and DataLoader fails
        # with an unrelated-looking message after a misleading warning.
        raise ValueError('get_activations: no image files were given')

    model.eval()

    if batch_size > len(files):
        print(('Warning: batch size is bigger than the data size. '
               'Setting batch size to data size'))
        batch_size = len(files)

    if num_workers is None:
        num_workers = cpu_count()

    # Only ToTensor (no resizing): images within one batch must share a size
    # for the default collate function to stack them.
    dataset = ImagePathDataset(files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    pred_arr = np.empty((len(files), dims))
    start_idx = 0

    for batch in dataloader:
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
        start_idx += pred.shape[0]

    return pred_arr

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        mu1, mu2: mean vectors (promoted to at least 1-D).
        sigma1, sigma2: covariance matrices (promoted to at least 2-D).
        eps: diagonal jitter added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The squared Frechet distance as a scalar.

    Raises:
        AssertionError: if the means or covariances have mismatched shapes.
        ValueError: if the square root keeps a significant imaginary part.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    mean_delta = mu1 - mu2

    # The covariance product can be near-singular; retry with jitter if so.
    root, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(root).all():
        print(('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps)
        jitter = np.eye(sigma1.shape[0]) * eps
        root = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Round-off can leave a tiny imaginary part; anything larger is an error.
    if np.iscomplexobj(root):
        if not np.allclose(np.diagonal(root).imag, 0, atol=1e-3):
            raise ValueError('Imaginary component {}'.format(np.max(np.abs(root.imag))))
        root = root.real

    return (mean_delta.dot(mean_delta) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * np.trace(root))


def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
                                    device='cpu'):
    """Mean vector and covariance matrix of the model activations for `files`.

    Returns:
        (mu, sigma): the per-feature mean over all images (axis 0) and the
        feature covariance, with each image treated as one observation row.
    """
    features = get_activations(files, model, batch_size, dims, device)
    return np.mean(features, axis=0), np.cov(features, rowvar=False)


def compute_statistics_of_path(path, model, batch_size, dims, device):
    """Return the (mu, sigma) activation statistics for `path`.

    Args:
        path: str or pathlib.Path. Either a precomputed ``.npz`` file holding
            ``mu`` and ``sigma`` arrays, or a directory whose top-level image
            files are evaluated.
        model, batch_size, dims, device: forwarded to
            `calculate_activation_statistics`.

    Image files are selected by extension case-insensitively, so files such as
    ``IMG_01.JPG`` are included; ``Path.glob`` is case-sensitive on Linux and
    used to skip them silently.

    Raises:
        ValueError: if the directory contains no image files.
    """
    # str() so a pathlib.Path argument works as well as a plain string.
    if str(path).endswith('.npz'):
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
        return m, s

    folder = pathlib.Path(path)
    files = sorted(p for p in folder.iterdir()
                   if p.is_file() and p.suffix[1:].lower() in IMAGE_EXTENSIONS)
    if not files:
        raise ValueError('No image files found in %s' % folder)
    return calculate_activation_statistics(files, model, batch_size,
                                           dims, device)


def calculate_fid_given_paths(paths, batch_size, device, dims):
    """Calculates the FID of two paths

    Args:
        paths: two-element sequence; each entry is an image folder or a
            ``.npz`` statistics file.
        batch_size: images per Inception forward pass.
        device: torch device the Inception model runs on.
        dims: feature dimensionality; selects the InceptionV3 output block.

    Raises:
        RuntimeError: naming the first entry of `paths` that does not exist.
    """
    missing = [p for p in paths if not os.path.exists(p)]
    if missing:
        raise RuntimeError('Invalid path: %s' % missing[0])

    inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)

    mu_first, sigma_first = compute_statistics_of_path(
        paths[0], inception, batch_size, dims, device)
    mu_second, sigma_second = compute_statistics_of_path(
        paths[1], inception, batch_size, dims, device)
    return calculate_frechet_distance(mu_first, sigma_first,
                                      mu_second, sigma_second)

def get_fid(batch_size, dims, path):
    """FID between the two folders / ``.npz`` files listed in `path`.

    Uses CUDA when it is available, otherwise the CPU.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    return calculate_fid_given_paths(path, batch_size, device, dims)

## Test code

In [None]:
# test code, to read model and generate images 
from __future__ import print_function
import argparse
import os
import torch
import torchvision.utils as vutils
from PIL import Image
import pandas as pd
from torch.autograd import Variable
from torchvision import transforms
from trainer import UNIT_Trainer
from utils import get_config, pytorch03_to_pytorch04


In [None]:
def generate_save_image(input, style, output_folder, file_name,
                        transform_fn=None, encode_fn=None, decode_fn=None):
    """Translate one image file with a UNIT generator pair and save the result.

    Args:
        input: path of the source image.
        style: style image path or ''. UNIT's translation below never uses a
            style code, so this is accepted only for call compatibility and is
            otherwise ignored.
        output_folder: directory the translated image is written into.
        file_name: name of the output file inside `output_folder`.
        transform_fn: preprocessing applied to the RGB PIL image; defaults to
            the notebook-global `transform`.
        encode_fn: encoder returning `(content, noise)`; defaults to the
            notebook-global `encode`.
        decode_fn: decoder mapping content to an image tensor; defaults to the
            notebook-global `decode`.
    """
    # Fall back to the globals set by the driver cells so existing
    # four-argument calls keep working.
    if transform_fn is None:
        transform_fn = transform
    if encode_fn is None:
        encode_fn = encode
    if decode_fn is None:
        decode_fn = decode

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with Image.open(input) as img:
        image = transform_fn(img.convert('RGB')).unsqueeze(0).to(device)

    # Inference only: never build an autograd graph, even when the caller did
    # not wrap this call in torch.no_grad().
    with torch.no_grad():
        content, _ = encode_fn(image)
        outputs = decode_fn(content)

    outputs = (outputs + 1) / 2.
    path = os.path.join(output_folder, file_name)
    # NOTE(review): normalize=True min-max rescales the saved image, which
    # makes the (x + 1) / 2 shift above redundant and stretches contrast.
    # Kept unchanged so FID values stay comparable with earlier runs.
    vutils.save_image(outputs.data, path, padding=0, normalize=True)

### Generate the FID score table (translated images are saved temporarily)

In [None]:
# Options for the FID sweep over the monet2photo checkpoints.
# To evaluate summer2winter instead, switch `config_file` and `checkpoint`.
# Upstream test-script CLI, for reference:
# --input inputs/gta_example.jpg --output_folder results/gta2city --checkpoint models/unit_gta2city.pt

checkpoint_dir_summer2winter = '/content/drive/MyDrive/CSC2516_Project/UNIT_summer2winter_small/outputs/unit_summer2winter_yosemite256_folder/checkpoints/'
checkpoint_dir_monet2photo = '/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/'

# 'configs/unit_summer2winter_yosemite256_list.yaml'
config_file = 'configs/unit_monet2photo_list.yaml'
# Generated images go to a folder relative to the UNIT working directory.
output_folder = 'results/monet2photo'
# FID.csv is written into this Drive folder.
csv_output_folder = '/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/results/'
checkpoint = checkpoint_dir_monet2photo
output_only = True  # only saving the generated image output (not read by the cells below)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=config_file, help="net configuration")
parser.add_argument('--style', type=str, default='', help="style image path")
parser.add_argument('--seed', type=int, default=10, help="random seed")
parser.add_argument('--num_style', type=int, default=10, help="number of styles to sample")
parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not")
parser.add_argument('--output_folder', type=str, default=output_folder, help="output image path")
parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight")
parser.add_argument('--trainer', type=str, default='UNIT', help="UNIT")
parser.add_argument('-f', default='')  # kept so that `opts.f` still exists
# Parse an explicit empty list. Inside a notebook, sys.argv holds the kernel
# launcher's arguments, not options meant for this script.
opts = parser.parse_args(args=[])

In [None]:
# FID sweep: for every (learning rate, KL weight, iteration) checkpoint,
# translate the test sets in both directions, including the cycle back to the
# source domain, and record four FID scores in FID.csv on Drive.
recon_kl_cyc_w_list = [0.01, 0.1]
lr_values = [0.0001, 0.0005]
iterations = ['00008000', '00016000', '00024000', '00032000', '00040000',
              '00048000', '00056000', '00064000', '00072000', '00080000']
FID_COLUMNS = ['recon_kl_w', 'recon_kl_clc', 'lr_value', 'iteration',
               'fid_val_A2B', 'fid_val_A2B2A', 'fid_val_B2A', 'fid_val_B2A2B']
fid_csv_path = csv_output_folder + 'FID.csv'

# Resume from an existing table. Only a missing file starts a new table; a
# read error is raised instead of being replaced by an empty table, because
# that empty table would later overwrite the saved results.
if os.path.exists(fid_csv_path):
    FID_table = pd.read_csv(fid_csv_path)
else:
    FID_table = pd.DataFrame(columns=FID_COLUMNS)

torch.manual_seed(opts.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(opts.seed)

opts.num_style = 1 if opts.style != '' else opts.num_style

config = get_config(opts.config)
input_folder = config['data_root']
input_folder_A = input_folder + '/testA/'
input_folder_B = input_folder + '/testB/'


def _make_transform(domain_a):
    """Preprocessing for an input image from domain A (True) or B (False)."""
    if 'new_size' in config:
        new_size = config['new_size']
    else:
        new_size = config['new_size_a'] if domain_a else config['new_size_b']
    return transforms.Compose([transforms.Resize(new_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def _load_trainer(checkpoint_path):
    """Build a UNIT_Trainer from `config` in eval mode with both generators loaded."""
    trainer = UNIT_Trainer(config)
    try:
        state_dict = torch.load(checkpoint_path)
        trainer.gen_a.load_state_dict(state_dict['a'])
        trainer.gen_b.load_state_dict(state_dict['b'])
    except (KeyError, RuntimeError):
        # Checkpoints in the old PyTorch 0.3 layout need their keys converted.
        state_dict = pytorch03_to_pytorch04(torch.load(checkpoint_path))
        trainer.gen_a.load_state_dict(state_dict['a'])
        trainer.gen_b.load_state_dict(state_dict['b'])
    if torch.cuda.is_available():
        trainer.cuda()
    trainer.eval()
    return trainer


def _translate_folder(src_folder, dst_folder, file_names):
    """Translate each of `file_names` from `src_folder` into `dst_folder`.

    generate_save_image reads the notebook-global `transform`, `encode` and
    `decode`, so set those before calling.
    """
    for name in file_names:
        generate_save_image(src_folder + name, opts.style, dst_folder, name)


for lr_value in lr_values:
    for recon_kl_cyc_w in recon_kl_cyc_w_list:
        recon_kl_w = recon_kl_cyc_w
        # Output folders depend only on the hyperparameters and are reused
        # (overwritten) for every iteration; each FID is computed right after
        # its images are written.
        param_values = 'kl_w_' + str(recon_kl_w) + 'kl_clc_' + str(recon_kl_cyc_w) + '_lr_valie_' + str(lr_value)
        out_dirs = {name: opts.output_folder + '/' + name + '/' + param_values + '/'
                    for name in ('A2B', 'A2B2A', 'B2A', 'B2A2B')}
        for folder in out_dirs.values():
            os.makedirs(folder, exist_ok=True)

        for iter_name in iterations:
            local_string = ('gen_' + iter_name + '_batch_size_1_recon_kl_w_'
                            + str(recon_kl_w) + '_recon_kl_clc_'
                            + str(recon_kl_cyc_w) + '_lr_value_' + str(lr_value) + '.pt')
            current_checkpoint = os.path.join(checkpoint, local_string)
            print(current_checkpoint)

            # Fresh config per checkpoint, as the trainer is rebuilt from it.
            config = get_config(opts.config)
            config['vgg_model_path'] = '.'
            trainer = _load_trainer(current_checkpoint)
            gen_a, gen_b = trainer.gen_a, trainer.gen_b

            current_fid = {'recon_kl_w': recon_kl_w,
                           'recon_kl_clc': recon_kl_cyc_w,
                           'lr_value': lr_value,
                           'iteration': iter_name}

            with torch.no_grad():
                for a2b in [True, False]:
                    fwd, cyc = ('A2B', 'A2B2A') if a2b else ('B2A', 'B2A2B')
                    src_real = input_folder_A if a2b else input_folder_B
                    dst_real = input_folder_B if a2b else input_folder_A
                    file_names = os.listdir(src_real)

                    # 1. Forward translation: source domain -> other domain.
                    encode, decode = (gen_a.encode, gen_b.decode) if a2b else (gen_b.encode, gen_a.decode)
                    transform = _make_transform(a2b)
                    _translate_folder(src_real, out_dirs[fwd], file_names)

                    # 2. Cycle back with the OPPOSITE generator pair, so A2B2A
                    #    really is A -> B -> A (and B2A2B is B -> A -> B).
                    encode, decode = (gen_b.encode, gen_a.decode) if a2b else (gen_a.encode, gen_b.decode)
                    transform = _make_transform(not a2b)
                    _translate_folder(out_dirs[fwd], out_dirs[cyc], file_names)

                    # Translations are compared with real images of the target
                    # domain; cycle reconstructions with the source domain.
                    current_fid['fid_val_' + fwd] = get_fid(batch_size=1, dims=2048, path=[out_dirs[fwd], dst_real])
                    current_fid['fid_val_' + cyc] = get_fid(batch_size=1, dims=2048, path=[out_dirs[cyc], src_real])

            # DataFrame.append was removed in pandas 2.0; concat a one-row frame.
            FID_table = pd.concat([FID_table, pd.DataFrame([current_fid])], ignore_index=True)
            # Save after every checkpoint so a Colab disconnect loses little work.
            FID_table.to_csv(fid_csv_path, index=False)

/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00008000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt
/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00016000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt
/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00024000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt
/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00032000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt
/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00040000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt
/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_

### Generate images only (single checkpoint, no FID computation)

In [None]:
# Options for regenerating images from a single monet2photo checkpoint.
# To use summer2winter instead, switch `config_file` and `checkpoint`.
# Upstream test-script CLI, for reference:
# --input inputs/gta_example.jpg --output_folder results/gta2city --checkpoint models/unit_gta2city.pt

checkpoint_dir_summer2winter = '/content/drive/MyDrive/CSC2516_Project/UNIT_summer2winter_small/outputs/unit_summer2winter_yosemite256_folder/checkpoints/'
checkpoint_dir_monet2photo = '/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/'

# 'configs/unit_summer2winter_yosemite256_list.yaml'
config_file = 'configs/unit_monet2photo_list.yaml'
# Generated images are written directly to Drive here.
output_folder = '/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/results/'
#csv_output_folder = '/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/results/'
checkpoint = checkpoint_dir_monet2photo
output_only = True  # only saving the generated image output (not read by the cells below)

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default=config_file, help="net configuration")
parser.add_argument('--style', type=str, default='', help="style image path")
parser.add_argument('--seed', type=int, default=10, help="random seed")
parser.add_argument('--num_style', type=int, default=10, help="number of styles to sample")
parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not")
parser.add_argument('--output_folder', type=str, default=output_folder, help="output image path")
parser.add_argument('--output_path', type=str, default='.', help="path for logs, checkpoints, and VGG model weight")
parser.add_argument('--trainer', type=str, default='UNIT', help="UNIT")
parser.add_argument('-f', default='')  # kept so that `opts.f` still exists
# Parse an explicit empty list. Inside a notebook, sys.argv holds the kernel
# launcher's arguments, not options meant for this script.
opts = parser.parse_args(args=[])

In [None]:
# Regenerate translated and cycle-reconstructed test images for ONE chosen
# checkpoint (no FID computation); images are written under opts.output_folder.
recon_kl_cyc_w = 0.01
recon_kl_w = 0.01
lr_value = 0.0001
iteration = '00048000'  # training iteration of the checkpoint to load

torch.manual_seed(opts.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(opts.seed)

config = get_config(opts.config)
config['vgg_model_path'] = '.'
opts.num_style = 1 if opts.style != '' else opts.num_style

input_folder = config['data_root']
input_folder_A = input_folder + '/testA/'
input_folder_B = input_folder + '/testB/'

param_values = 'kl_w_' + str(recon_kl_w) + 'kl_clc_' + str(recon_kl_cyc_w) + '_lr_valie_' + str(lr_value)
current_output_A2B = opts.output_folder + '/A2B/' + param_values + '/'
current_output_A2B2A = opts.output_folder + '/A2B2A/' + param_values + '/'
current_output_B2A = opts.output_folder + '/B2A/' + param_values + '/'
current_output_B2A2B = opts.output_folder + '/B2A2B/' + param_values + '/'
local_string = ('gen_' + iteration + '_batch_size_1_recon_kl_w_'
                + str(recon_kl_w) + '_recon_kl_clc_'
                + str(recon_kl_cyc_w) + '_lr_value_' + str(lr_value) + '.pt')

current_checkpoint = os.path.join(checkpoint, local_string)
print(current_checkpoint)

for folder in (opts.output_folder, current_output_A2B, current_output_A2B2A,
               current_output_B2A, current_output_B2A2B):
    os.makedirs(folder, exist_ok=True)

# Load both generators once; both directions below use the same weights.
trainer = UNIT_Trainer(config)
try:
    state_dict = torch.load(current_checkpoint)
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
except (KeyError, RuntimeError):
    # Checkpoints in the old PyTorch 0.3 layout need their keys converted.
    state_dict = pytorch03_to_pytorch04(torch.load(current_checkpoint))
    trainer.gen_a.load_state_dict(state_dict['a'])
    trainer.gen_b.load_state_dict(state_dict['b'])
if torch.cuda.is_available():
    trainer.cuda()
trainer.eval()


def _domain_transform(domain_a):
    """Preprocessing for an input image from domain A (True) or B (False)."""
    if 'new_size' in config:
        new_size = config['new_size']
    else:
        new_size = config['new_size_a'] if domain_a else config['new_size_b']
    return transforms.Compose([transforms.Resize(new_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


# (real source folder, forward output, cycle output, source is domain A)
directions = [
    (input_folder_A, current_output_A2B, current_output_A2B2A, True),
    (input_folder_B, current_output_B2A, current_output_B2A2B, False),
]

# generate_save_image reads the notebook-global `transform`, `encode` and
# `decode`, which are reassigned before each translation step.
with torch.no_grad():
    for src_folder, fwd_folder, cyc_folder, a2b in directions:
        file_names = os.listdir(src_folder)

        # 1. Forward translation: source domain -> other domain.
        encode, decode = ((trainer.gen_a.encode, trainer.gen_b.decode) if a2b
                          else (trainer.gen_b.encode, trainer.gen_a.decode))
        transform = _domain_transform(a2b)
        for name in file_names:
            generate_save_image(src_folder + name, opts.style, fwd_folder, name)

        # 2. Cycle back with the OPPOSITE generator pair, so A2B2A really is
        #    A -> B -> A (and B2A2B is B -> A -> B).
        encode, decode = ((trainer.gen_b.encode, trainer.gen_a.decode) if a2b
                          else (trainer.gen_a.encode, trainer.gen_b.decode))
        transform = _domain_transform(not a2b)
        for name in file_names:
            generate_save_image(fwd_folder + name, opts.style, cyc_folder, name)

/content/drive/MyDrive/CSC2516_Project/UNIT_monet2photo_small/outputs/unit_monet2photo_folder/checkpoints/gen_00048000_batch_size_1_recon_kl_w_0.01_recon_kl_clc_0.01_lr_value_0.0001.pt


### Done: post-process the saved FID table

In [None]:
# Reload the saved results, write them back with a clean 0..n-1 index and
# display the table.
FID_table = pd.read_csv(csv_output_folder + 'FID.csv')
FID_table.reset_index(drop=True).to_csv(csv_output_folder + 'FID.csv', index=False)
FID_table.reset_index(drop=True)


Unnamed: 0,recon_kl_w,recon_kl_clc,lr_value,iteration,fid_val_A2B,fid_val_A2B2A,fid_val_B2A,fid_val_B2A2B
0,0.01,0.01,0.0001,8000,224.968845,210.204624,197.646804,224.168804
1,0.01,0.01,0.0001,16000,224.968845,210.204624,174.118954,242.698832
2,0.01,0.01,0.0001,24000,224.968845,210.204624,176.418824,264.287935
3,0.01,0.01,0.0001,32000,224.968845,210.204624,169.634462,251.724293
4,0.01,0.01,0.0001,40000,224.968845,210.204624,171.687572,260.475951
...,...,...,...,...,...,...,...,...
95,0.10,0.10,0.0005,48000,325.311955,328.068634,316.374747,370.030933
96,0.10,0.10,0.0005,56000,278.774849,277.374384,227.174077,328.726003
97,0.10,0.10,0.0005,64000,330.886296,308.776254,198.270536,303.341669
98,0.10,0.10,0.0005,72000,302.286218,304.455138,329.108379,361.482108


In [None]:
# Add `avg`: the row-wise mean of the four FID scores of each checkpoint.
FID_table = FID_table.assign(
    avg=FID_table[["fid_val_A2B", "fid_val_A2B2A", "fid_val_B2A", "fid_val_B2A2B"]].mean(axis=1)
)

In [None]:
# Order checkpoints by ascending average FID.
FID_table = FID_table.sort_values('avg')

In [None]:
# Remove exact duplicate rows, e.g. a checkpoint evaluated again by a resumed
# sweep that appended to the existing FID.csv.
FID_table = FID_table.drop_duplicates()


In [None]:
# Save the processed table (with the `avg` column) next to FID.csv on Drive.
FID_table.to_csv(csv_output_folder + 'FID_with_average.csv', index=False)
# Earlier one-off clean-up steps, kept for reference:
# FID_table = FID_table.dropna()
# FID_table = FID_table.drop(columns=['fid_val_A2B.1'])

In [None]:
# Display the processed FID table.
FID_table

Unnamed: 0,recon_kl_w,recon_kl_clc,lr_value,iteration,fid_val_A2B,fid_val_A2B2A,fid_val_B2A,fid_val_B2A2B,avg
70,0.10,0.10,0.0001,8000,203.910638,183.760694,183.209818,212.683878,195.891257
71,0.10,0.10,0.0001,16000,202.958507,200.343928,172.600639,234.167779,202.517713
30,0.10,0.10,0.0001,8000,224.968845,210.204624,183.212693,212.631477,207.754410
92,0.10,0.10,0.0005,24000,204.744435,213.977543,178.736501,242.516149,209.993657
48,0.01,0.01,0.0005,72000,224.968845,210.204624,197.371917,208.524438,210.267456
...,...,...,...,...,...,...,...,...,...
98,0.10,0.10,0.0005,72000,302.286218,304.455138,329.108379,361.482108,324.332961
84,0.01,0.01,0.0005,40000,312.639646,316.967270,341.432377,348.106725,329.786504
85,0.01,0.01,0.0005,48000,303.227149,351.199004,311.179702,358.138596,330.936113
95,0.10,0.10,0.0005,48000,325.311955,328.068634,316.374747,370.030933,334.946567
