### Setup Paths

In [None]:
import pickle
import platform
import os
# Select the data and repository roots for the machine this notebook runs on:
# a macOS machine ('Darwin') or the Linux workspace under /pfs.
_system = platform.system()
if _system == 'Darwin':
    DATA_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Data.nosync"
    ROOT_PATH = "/Users/maltegenschow/Documents/Uni/Thesis/Thesis"
elif _system == 'Linux':
    DATA_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Data.nosync"
    ROOT_PATH = "/pfs/work7/workspace/scratch/tu_zxmav84-thesis/Thesis"
else:
    # Fail here with a clear message instead of a NameError on DATA_PATH
    # in a later cell.
    raise RuntimeError(
        f"Unsupported platform {_system!r}: set DATA_PATH and ROOT_PATH manually."
    )

current_wd = os.getcwd()

In [None]:
# Each image set has an input folder (images fed to pytorch_fid) and a
# features file (the .npz statistics cached by calculate_fid_features).

# Real Data
real_data_input = f"{DATA_PATH}/Zalando_Germany_Dataset/dresses/images/e4e_images/all/"
real_data_features = f"{DATA_PATH}/Metrics/FID/features_real_dataset.npz"

# SG2Ada 00003 Snapshot 920
sg2_00003_input = f"{DATA_PATH}/Generated_Images/SG2Ada/00003_snapshot_920/"
sg2_00003_features = f"{DATA_PATH}/Metrics/FID/features_generated_00003.npz"

# SG2Ada 00005 Snapshot 1200
# (previously pointed at 00003_snapshot_1200, which contradicts the run id in
# the name/comment and the e4e folder for the same network below)
sg2_00005_input = f"{DATA_PATH}/Generated_Images/SG2Ada/00005_snapshot_1200/"
sg2_00005_features = f"{DATA_PATH}/Metrics/FID/features_generated_00005.npz"

# e4e from 00003_snapshot_920
e4e_00003_input = f"{DATA_PATH}/Generated_Images/e4e/00003_snapshot_920/"
e4e_00003_features = f"{DATA_PATH}/Metrics/FID/features_e4e_00003_snapshot_920.npz"

# e4e from 00005_snapshot_1200
e4e_00005_input = f"{DATA_PATH}/Generated_Images/e4e/00005_snapshot_1200/"
e4e_00005_features = f"{DATA_PATH}/Metrics/FID/features_e4e_00005_snapshot_1200.npz"

# PTI
pti_input = f"{DATA_PATH}/Generated_Images/PTI/"
pti_features = f"{DATA_PATH}/Metrics/FID/features_pti.npz"

# Restyle
restyle_input = f"{DATA_PATH}/Generated_Images/restyle/inference_results/4/"
# NOTE: the variable name is misspelled ("feautures") but is kept as-is
# because later cells reference it by this name.
restyle_feautures = f"{DATA_PATH}/Metrics/FID/features_restyle.npz"

### Define FID Functions

In [None]:
def calculate_fid_features(input_path, output_path, device="cuda:0"):
    """Compute and cache pytorch_fid statistics for a folder of images.

    Does nothing (besides printing a notice) when ``output_path`` already
    exists, so re-running the notebook only processes new image sets.

    Args:
        input_path: folder of images passed to ``pytorch_fid --save-stats``.
        output_path: destination file for the saved statistics (``.npz``).
        device: torch device string forwarded to ``--device``; defaults to
            the previously hard-coded ``"cuda:0"``.

    Raises:
        RuntimeError: if the ``pytorch_fid`` process exits with a non-zero
            status (its stderr is included in the message).
    """
    import subprocess
    import sys

    if os.path.exists(output_path):
        print("Features already calculated and ready to use.")
        return

    print(f"Calculating Features for Images in folder {input_path}")
    # Argument list (no shell) so paths containing spaces are passed intact;
    # sys.executable runs pytorch_fid with the kernel's own interpreter.
    cmd = [
        sys.executable, "-m", "pytorch_fid",
        "--save-stats", input_path, output_path,
        "--device", device,
    ]
    # Output is captured so it can be shown in the cell and, on failure,
    # surfaced in the exception (progress output is therefore not streamed).
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.stdout:
        print(result.stdout, end="")
    if result.returncode != 0:
        raise RuntimeError(
            f"pytorch_fid --save-stats failed (exit code {result.returncode}) "
            f"for {input_path}:\n{result.stderr.strip()}"
        )

In [None]:
import subprocess
import sys


def calculate_fid(dataset1, dataset2):
    """Return the FID between two datasets as reported by the pytorch_fid CLI.

    Args:
        dataset1, dataset2: paths passed straight to ``python -m pytorch_fid``
            (in this notebook, the ``.npz`` statistics files written by
            ``calculate_fid_features``).

    Returns:
        The FID value parsed from the CLI's ``FID: <value>`` output line.

    Raises:
        RuntimeError: if the CLI exits with a non-zero status, or its output
            contains no ``FID:`` line.
    """
    # sys.executable runs pytorch_fid with the kernel's own interpreter.
    cmd = [sys.executable, "-m", "pytorch_fid", dataset1, dataset2]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"pytorch_fid failed (exit code {result.returncode}) comparing "
            f"{dataset1} and {dataset2}:\n{result.stderr.strip()}"
        )
    # Look for the labelled line instead of assuming the value is the second
    # whitespace token of all stdout.
    for line in result.stdout.splitlines():
        label, sep, value = line.partition(":")
        if sep and label.strip() == "FID":
            return float(value)
    raise RuntimeError(
        f"No 'FID:' line in pytorch_fid output:\n{result.stdout.strip()}"
    )

### Calculate and Cache Feature Statistics for All Image Sets (for FID)

In [None]:
# Compute (or reuse cached) pytorch_fid statistics for every image set, in a
# fixed order: the real data first, then each generated / inverted set.
_feature_jobs = [
    (real_data_input, real_data_features),  # Real Data
    (sg2_00003_input, sg2_00003_features),  # SG2Ada 00003
    (sg2_00005_input, sg2_00005_features),  # SG2Ada 00005
    (e4e_00003_input, e4e_00003_features),  # e4e 00003
    (e4e_00005_input, e4e_00005_features),  # e4e 00005
    (pti_input, pti_features),              # PTI
    (restyle_input, restyle_feautures),     # Restyle
]
for _images_dir, _stats_file in _feature_jobs:
    calculate_fid_features(_images_dir, _stats_file)

### Calculate FIDs

In [None]:
# FID of each generated / inverted image set against the real-data statistics.
fid_results = {}
fid_results['SG2ADA_00003_snapshot_920'] = calculate_fid(real_data_features, sg2_00003_features)
fid_results['SG2ADA_00005_snapshot_1200'] = calculate_fid(real_data_features, sg2_00005_features)
fid_results['e4e_00003'] = calculate_fid(real_data_features, e4e_00003_features)
fid_results['e4e_00005'] = calculate_fid(real_data_features, e4e_00005_features)
fid_results['PTI'] = calculate_fid(real_data_features, pti_features)
# Fixed: this previously called calculate_fid_features, which returns None
# and never computes an FID score.
fid_results['Restyle'] = calculate_fid(real_data_features, restyle_feautures)
fid_results

In [None]:
# Persist the FID scores dict to disk for later loading.
with open(f"{DATA_PATH}/Metrics/FID/FID_Results.pkl", mode='wb') as results_file:
    pickle.dump(fid_results, results_file, pickle.HIGHEST_PROTOCOL)