In [1]:
import os
import threading
import time

import torch
import torchvision
from PIL import Image
from torch.multiprocessing import Pool, cpu_count
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.io import read_image
from torchvision.transforms import transforms as T
from tqdm import tqdm

# Shared PIL-image -> tensor conversion used by the image loaders below.
transform = torchvision.transforms.ToTensor()

2023-11-09 22:07:17.905868: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-09 22:07:17.906845: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-09 22:07:17.914575: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
def process_image(image_file, tile_size=512, stride=256, grid=7):
    """Load an image and cut it into a ``grid x grid`` set of overlapping square tiles.

    With the defaults (7x7 grid, 512-px tiles, 256-px stride) the tiles exactly
    cover a 2048x2048 image and 49 tiles are produced, which is what the
    callers in this notebook allocate for. Larger images are only sampled in
    their top-left region, as before.

    Args:
        image_file: path (or file object) accepted by ``PIL.Image.open``.
        tile_size: side length of each square tile, in pixels.
        stride: offset between the top-left corners of neighbouring tiles.
        grid: number of tiles along each axis.

    Returns:
        Float tensor of shape ``(grid * grid, 3, tile_size, tile_size)``; tile
        ``i * grid + j`` starts at row ``i * stride`` and column ``j * stride``.

    Raises:
        ValueError: if the image is too small to hold the requested grid.
    """
    # Close the file deterministically, and force 3 channels so RGBA, palette
    # or greyscale PNGs do not break the 3-channel tile assignment below.
    with Image.open(image_file) as img:
        image = transform(img.convert("RGB"))

    _, height, width = image.shape
    needed = (grid - 1) * stride + tile_size
    if height < needed or width < needed:
        raise ValueError(
            f"Image {image_file!r} is {height}x{width}; a {grid}x{grid} grid of "
            f"{tile_size}-px tiles with stride {stride} needs at least {needed}x{needed}."
        )

    ret_tensor = torch.zeros(grid * grid, 3, tile_size, tile_size)
    for i in range(grid):
        top = i * stride
        for j in range(grid):
            left = j * stride
            ret_tensor[i * grid + j] = image[:, top:top + tile_size, left:left + tile_size]

    return ret_tensor

def get_fid_score_a(real_image_tensor, gen_image_tensor, feature=2048):
    """Compute FID between a real and a generated image batch in a single pass.

    Args:
        real_image_tensor: batch passed to ``FrechetInceptionDistance.update(..., real=True)``.
            The metric is built with ``normalize=True``, which per the torchmetrics
            docs means float images in [0, 1] (TODO confirm for all callers).
        gen_image_tensor: batch passed to ``update(..., real=False)``.
        feature: Inception feature layer forwarded to ``FrechetInceptionDistance``.

    Returns:
        The tensor returned by ``fid.compute()``. Elapsed time and the score are
        also printed.
    """
    start_time = time.time()
    # A freshly constructed metric starts with empty state, so no reset() is needed.
    fid = FrechetInceptionDistance(feature=feature, normalize=True)
    fid.update(real_image_tensor, real=True)
    fid.update(gen_image_tensor, real=False)
    fid_score = fid.compute()
    print(f"Time taken: {time.time() - start_time}")
    print(f"FID score: {fid_score}")
    return fid_score

def get_fid_score_b(real_image_tensor, gen_image_tensor, feature=2048):
    """Compute a cross-split FID: the mean of two half-versus-half scores.

    Pass 1 compares the first half of the real images with the second half of
    the generated images; pass 2 compares the second real half with the first
    generated half. Each pass's time and score are printed, followed by the mean.

    Args:
        real_image_tensor: sliceable batch of reference images.
        gen_image_tensor: sliceable batch of generated images.
        feature: Inception feature layer forwarded to ``FrechetInceptionDistance``.

    Returns:
        The averaged FID. The original version only printed this value and
        returned ``None``.
    """
    real_mid = len(real_image_tensor) // 2
    gen_mid = len(gen_image_tensor) // 2
    fid = FrechetInceptionDistance(feature=feature, normalize=True)

    # Pair opposite halves so no pass compares a half against its own counterpart split.
    splits = [
        (real_image_tensor[:real_mid], gen_image_tensor[gen_mid:]),
        (real_image_tensor[real_mid:], gen_image_tensor[:gen_mid]),
    ]

    scores = []
    for real_half, gen_half in splits:
        start_time = time.time()
        fid.reset()
        fid.update(real_half, real=True)
        fid.update(gen_half, real=False)
        score = fid.compute()
        print(f"Time taken: {time.time() - start_time}")
        print(f"FID score: {score}")
        scores.append(score)

    fid_score = (scores[0] + scores[1]) / 2
    print(f"Final FID score: {fid_score}")
    return fid_score



In [3]:
def process_datasets_tiled(real_dir, fake_dir, num_images):
    """Tile paired real/fake images from matching sub-folders into two tensors.

    ``real_dir`` and ``fake_dir`` must contain the same set of sub-folder names.
    From each sub-folder, the first ``num_images`` files in sorted filename order
    are loaded and cut into 49 tiles each with ``process_image``.

    Args:
        real_dir: root directory of reference sub-folders.
        fake_dir: root directory of generated sub-folders with the same names.
        num_images: number of images to take from every sub-folder.

    Returns:
        ``(real_tensor, fake_tensor)``, each of shape
        ``(n_folders * num_images * 49, 3, 512, 512)``. Rows for the same
        (folder, image index) line up across the two tensors.

    Raises:
        ValueError: if the sub-folder names differ between the roots, or a
            sub-folder has fewer than ``num_images`` files.
    """
    tiles_per_image = 49  # must match the default 7x7 grid of process_image

    # os.listdir order is arbitrary: compare and iterate in sorted order.
    folder_list = sorted(os.listdir(real_dir))
    if folder_list != sorted(os.listdir(fake_dir)):
        raise ValueError("The real and fake directories must contain the same sub-folders.")

    # Resolve every path first, so a short folder fails before the very large
    # output tensors are allocated.
    pairs = []
    for folder_name in folder_list:
        real_folder_path = os.path.join(real_dir, folder_name)
        fake_folder_path = os.path.join(fake_dir, folder_name)
        real_images = sorted(os.listdir(real_folder_path))
        fake_images = sorted(os.listdir(fake_folder_path))
        if len(real_images) < num_images or len(fake_images) < num_images:
            raise ValueError(
                f"Sub-folder {folder_name!r} has fewer than {num_images} images "
                f"(real: {len(real_images)}, fake: {len(fake_images)})."
            )
        for j in range(num_images):
            pairs.append((
                os.path.join(real_folder_path, real_images[j]),
                os.path.join(fake_folder_path, fake_images[j]),
            ))

    total_tiles = len(pairs) * tiles_per_image
    real_tensor = torch.zeros(total_tiles, 3, 512, 512)
    fake_tensor = torch.zeros(total_tiles, 3, 512, 512)

    for k, (real_image_path, fake_image_path) in enumerate(tqdm(pairs)):
        rows = slice(k * tiles_per_image, (k + 1) * tiles_per_image)
        real_tensor[rows] = process_image(real_image_path)
        fake_tensor[rows] = process_image(fake_image_path)

    return real_tensor, fake_tensor

def process_datasets(real_dir, fake_dir, num_images):
    """Load paired real/fake full-resolution images from matching sub-folders.

    ``real_dir`` and ``fake_dir`` must contain the same set of sub-folder names.
    From each sub-folder, the first ``num_images`` files in sorted filename order
    are loaded as 3-channel tensors.

    Args:
        real_dir: root directory of reference sub-folders.
        fake_dir: root directory of generated sub-folders with the same names.
        num_images: number of images to take from every sub-folder.

    Returns:
        ``(real_tensor, fake_tensor)``, each of shape
        ``(n_folders * num_images, 3, 2048, 2048)``. Every image is assumed to be
        2048x2048 (TODO confirm); other sizes fail on assignment.

    Raises:
        ValueError: if the sub-folder names differ between the roots, or a
            sub-folder has fewer than ``num_images`` files.
    """
    # os.listdir order is arbitrary: compare and iterate in sorted order.
    folder_list = sorted(os.listdir(real_dir))
    if folder_list != sorted(os.listdir(fake_dir)):
        raise ValueError("The real and fake directories must contain the same sub-folders.")

    # Resolve every path first, so a short folder fails before the very large
    # output tensors are allocated.
    pairs = []
    for folder_name in folder_list:
        real_folder_path = os.path.join(real_dir, folder_name)
        fake_folder_path = os.path.join(fake_dir, folder_name)
        real_images = sorted(os.listdir(real_folder_path))
        fake_images = sorted(os.listdir(fake_folder_path))
        if len(real_images) < num_images or len(fake_images) < num_images:
            raise ValueError(
                f"Sub-folder {folder_name!r} has fewer than {num_images} images "
                f"(real: {len(real_images)}, fake: {len(fake_images)})."
            )
        for j in range(num_images):
            pairs.append((
                os.path.join(real_folder_path, real_images[j]),
                os.path.join(fake_folder_path, fake_images[j]),
            ))

    real_tensor = torch.zeros(len(pairs), 3, 2048, 2048)
    fake_tensor = torch.zeros(len(pairs), 3, 2048, 2048)

    for k, (real_image_path, fake_image_path) in enumerate(tqdm(pairs)):
        # Context managers close the files; convert("RGB") guarantees 3 channels.
        with Image.open(real_image_path) as real_img:
            real_tensor[k] = transform(real_img.convert("RGB"))
        with Image.open(fake_image_path) as fake_img:
            fake_tensor[k] = transform(fake_img.convert("RGB"))

    return real_tensor, fake_tensor

In [4]:
# Initial reference (real) and generated (fake) dataset roots.
real, fake = (
    '/scratch/bbut/min_validation_set/bing_urban',
    '/scratch/bbut/axiao/ours_validation_out2048_urban_from10_negative_52334',
)

In [5]:
import os
import torch
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from torchmetrics.image.kid import KernelInceptionDistance

def process_image_kid(image_path):
    """Load one image as a single-item batch of shape ``(1, 3, H, W)``.

    Unlike ``process_image``, this function does not tile; the whole image
    becomes one batch entry.
    """
    # The context manager closes the file handle; convert("RGB") guarantees
    # 3 channels even for RGBA, palette or greyscale PNGs.
    with Image.open(image_path) as img:
        image = transform(img.convert("RGB"))
    return torch.unsqueeze(image, dim=0)

# process_folder runs on many ThreadPoolExecutor workers at once, while
# Metric.update() mutates the shared metric's accumulated state without any
# locking of its own. Serialise those updates so no samples are lost.
_METRIC_UPDATE_LOCK = threading.Lock()


def process_folder(folder, real_or_fake):
    """Tile one sample folder's image and feed every tile to the global metric.

    The image read is ``20.png`` when it exists in the folder, otherwise
    ``10_gt_20.png``.

    Args:
        folder: name of a sub-folder under the global ``real`` or ``fake`` root.
        real_or_fake: ``'real'`` reads from the global ``real`` root and updates
            the metric's real statistics. Any other value reads from the global
            ``fake`` root and updates the generated statistics.

    Side effects:
        Calls ``update`` on the global ``fid`` object, which may be an FID or
        a KID metric depending on the cell that created it.
    """
    is_real = real_or_fake == 'real'
    folder_dir = os.path.join(real if is_real else fake, folder)
    file_name = '20.png' if os.path.exists(os.path.join(folder_dir, '20.png')) else '10_gt_20.png'

    # FID and KID consume the same 512-px tiles, so no per-metric branching is needed.
    tiles = process_image(os.path.join(folder_dir, file_name))

    # Image loading and tiling above run in parallel; only the metric update is serialised.
    with _METRIC_UPDATE_LOCK:
        for tile in tiles:
            fid.update(tile.unsqueeze(0), real=is_real)

In [6]:
# For every metric and reference set: accumulate real-image statistics once,
# then score each generated folder against them.
for metric in ['kid', 'fid']:
    for from_folder in ['20', 'urban']:
        real = '/scratch/bbut/min_validation_set/bing_' + from_folder
        # reset_real_features=False makes fid.reset() below clear only the
        # generated statistics, so real features are computed once per set.
        # The name `fid` is used for both metrics because process_folder reads it.
        if metric == 'fid':
            fid = FrechetInceptionDistance(feature=2048, normalize=True, reset_real_features=False)
        else:
            fid = KernelInceptionDistance(feature=2048, normalize=True, reset_real_features=False)

        # process_folder reads module globals such as `real`, `fake` and `fid`.
        real_folders = os.listdir(real)
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
            # list() drains the lazy map so worker exceptions are re-raised here;
            # tqdm reports progress as folders complete.
            list(tqdm(
                executor.map(process_folder, real_folders, ['real'] * len(real_folders)),
                total=len(real_folders),
                desc=f'real {from_folder}',
            ))

        print(from_folder)

        root = ''    # Add the common prefix of the fake images here (if any)
        suffix = ''  # Add the common suffix of the fake images here (if any)

        for fake_folder in ['0', '52334', 'untuned_base', 'hat', 'liif', 'interpolation']:
            fake = root + fake_folder + suffix
            fid.reset()
            fake_folders = os.listdir(fake)
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                list(tqdm(
                    executor.map(process_folder, fake_folders, ['fake'] * len(fake_folders)),
                    total=len(fake_folders),
                    desc=f'fake {fake_folder}',
                ))

            # KID's compute() returns a (mean, std) pair; FID's returns a single tensor.
            fid_score = fid.compute()
            print(f"* {metric} score: {fid_score}\t - {fake_folder}")



20
* kid score: (tensor(0.0549), tensor(0.0020))	 - 0
* kid score: (tensor(0.0210), tensor(0.0013))	 - 52334
* kid score: (tensor(0.0915), tensor(0.0018))	 - untuned_base
* kid score: (tensor(0.3371), tensor(0.0039))	 - hat
* kid score: (tensor(0.3857), tensor(0.0043))	 - liif
* kid score: (tensor(0.2965), tensor(0.0038))	 - interpolation
urban
* kid score: (tensor(0.1474), tensor(0.0055))	 - 0
* kid score: (tensor(0.0533), tensor(0.0026))	 - 52334
* kid score: (tensor(0.1594), tensor(0.0034))	 - untuned_base
* kid score: (tensor(0.4821), tensor(0.0062))	 - hat
* kid score: (tensor(0.5249), tensor(0.0061))	 - liif
* kid score: (tensor(0.4158), tensor(0.0052))	 - interpolation
