In [1]:
import os
# Environment must be configured before torch initialises CUDA, hence these
# assignments come before `import torch`.
# Expose only physical GPU 6 to this process; inside the process PyTorch
# will number it as device 0 (i.e. "cuda" / "cuda:0").
os.environ["CUDA_VISIBLE_DEVICES"] = "6,"
# Route HTTP(S) traffic through a local proxy — presumably so torch-fidelity
# can download its Inception weights; confirm and remove if not needed.
os.environ["HTTP_PROXY"] = "http://127.0.0.1:7890"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:7890"
import pathlib
from argparse import Namespace

import numpy as np
import torch
# NOTE(review): TF is not referenced anywhere in the visible notebook.
import torchvision.transforms as TF
from PIL import Image
import torch_fidelity


# Lowercase file extensions (without the dot) treated as images when listing
# a directory; matching is done with `Path.glob("*.<ext>")`.
IMAGE_EXTENSIONS = {"bmp", "jpg", "jpeg", "pgm", "png", "ppm", "tif", "tiff", "webp"}


In [2]:
class ImagePathDataset(torch.utils.data.Dataset):
    """Dataset over a list of image files, yielding fixed-size RGB images.

    Each item is opened with PIL, converted to RGB and, unless it already is
    ``image_size`` x ``image_size``, resized to that size with bilinear
    resampling. If ``transforms`` is given it is applied last and its result
    is returned; otherwise the PIL image itself is returned.

    Args:
        files: Indexable sequence of image paths (anything ``PIL.Image.open``
            accepts).
        transforms: Optional callable applied to the resized PIL image.
        image_size: Side length of the square output image (default 224).
    """

    def __init__(self, files, transforms=None, image_size=224):
        self.files = files
        self.transforms = transforms
        self.image_size = image_size

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        # Close the file handle right away; convert() returns a new, fully
        # loaded image that does not depend on the open file.
        with Image.open(path) as src:
            img = src.convert("RGB")
        target_size = (self.image_size, self.image_size)
        if img.size != target_size:
            # Image.BILINEAR triggers a DeprecationWarning on some Pillow
            # releases; use the Image.Resampling enum when it exists.
            resampling = getattr(Image, "Resampling", Image)
            img = img.resize(target_size, resampling.BILINEAR)
        if self.transforms is not None:
            img = self.transforms(img)
        return img


class TransformPILtoRGBTensor:
    """Convert a PIL image into a ``torch.uint8`` tensor of shape (3, H, W).

    Pixel values are kept as raw 0-255 bytes (no scaling or normalization).
    Non-RGB images are converted to RGB first so the result always has
    exactly three channels.
    """

    def __call__(self, img):
        assert isinstance(img, Image.Image), 'Input is not a PIL.Image'
        if img.mode != "RGB":
            img = img.convert("RGB")
        # np.array copies the pixels into a writable, contiguous (H, W, 3)
        # uint8 buffer; this replaces the deprecated
        # torch.ByteStorage.from_buffer path and yields the same bytes.
        hwc = torch.from_numpy(np.array(img, dtype=np.uint8))
        return hwc.permute(2, 0, 1)

In [3]:
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, num_samples=-1):
    """Compute FID and KID between two image directories with torch-fidelity.

    Args:
        paths: Two directory paths. ``paths[0]`` is treated as the reference
            (real) set and ``paths[1]`` as the candidate (generated) set.
        batch_size: Batch size used by torch-fidelity for feature extraction.
        device: ``torch.device`` or device string; torch-fidelity runs on CUDA
            when its type is ``"cuda"`` and on CPU otherwise. ``None`` selects
            CUDA if available.
        dims: Inception feature layer used for both FID and KID
            (torch-fidelity's inception-v3-compat extractor accepts
            64, 192, 768 or 2048).
        num_workers: Not used; torch-fidelity manages its own data loading.
            Kept for signature compatibility.
        num_samples: If > 0, cap each directory at this many images
            (see ``get_files_of_path``).

    Returns:
        The metrics dict returned by ``torch_fidelity.calculate_metrics``.

    Raises:
        ValueError: if either directory contains no matching image files.
    """
    real_files = get_files_of_path(paths[0], num_samples)
    fake_files = get_files_of_path(paths[1], num_samples)
    for directory, files in ((paths[0], real_files), (paths[1], fake_files)):
        if len(files) == 0:
            raise ValueError(f"No image files found in {str(directory)!r}")

    real_dataset = ImagePathDataset(real_files, TransformPILtoRGBTensor())
    fake_dataset = ImagePathDataset(fake_files, TransformPILtoRGBTensor())

    if device is None:
        use_cuda = torch.cuda.is_available()
    else:
        use_cuda = torch.device(device).type == "cuda"

    feature_layer = str(dims)
    # torch-fidelity convention: input1 is the generated set, input2 the
    # reference set.
    metrics_dict = torch_fidelity.calculate_metrics(
        input1=fake_dataset,
        input2=real_dataset,
        cuda=use_cuda,
        batch_size=batch_size,
        feature_layer_fid=feature_layer,
        feature_layer_kid=feature_layer,
        fid=True,
        kid=True,
        verbose=False,
    )

    return metrics_dict

def get_files_of_path(path, num_samples=-1, seed=33):
    """List image files in a directory, optionally capped at ``num_samples``.

    Files are matched non-recursively with ``Path.glob("*.<ext>")`` for each
    extension in ``IMAGE_EXTENSIONS`` and returned sorted by path.

    When ``num_samples > 0``:
      * if the path string contains 'mosaic' or 'bezier' (generated images),
        the first ``num_samples`` files in sorted order are kept;
      * otherwise (real images), ``num_samples`` files are drawn without
        replacement using a private RNG seeded with ``seed``, so the same
        subset is chosen on every call.

    Args:
        path: Directory to scan.
        num_samples: Maximum number of files to return; <= 0 means all.
        seed: Seed for the real-image subsampling (default 33).

    Returns:
        A list of ``pathlib.Path`` objects.

    Raises:
        ValueError: (from NumPy) if a real-image directory has fewer than
            ``num_samples`` files.
    """
    path = pathlib.Path(path)
    files = sorted(
        file for ext in IMAGE_EXTENSIONS for file in path.glob(f"*.{ext}")
    )
    if num_samples <= 0:
        return files

    path_str = str(path)
    if 'mosaic' in path_str or 'bezier' in path_str:
        return files[:num_samples]

    # A dedicated RandomState draws exactly the same subset as
    # np.random.seed(seed) + np.random.choice(...), without resetting the
    # global NumPy RNG as a side effect.
    rng = np.random.RandomState(seed)
    return list(rng.choice(files, num_samples, replace=False))
    

In [4]:
# Per-dataset evaluation settings: directory of real training images, run
# suffixes of the generated-image folders, and images sampled per directory.
_DATASET_CONFIGS = {
    'WSSS4LUAD': {
        'real_image_path': "data/WSSS4LUAD/1.training",
        'runs': ['', '2', '3', '5', '9'],
        'num_samples': 3600,
    },
    'BCSS-WSSS': {
        'real_image_path': "data/BCSS-WSSS/training",
        'runs': ['', '1', '2', '3', '4'],
        'num_samples': 7200,
    },
    'LUAD-HistoSeg': {
        'real_image_path': "data/LUAD-HistoSeg/train",
        'runs': ['', '1', '2', '3', '4'],
        'num_samples': 3600,
    },
}


def _resolve_device(device):
    """Return a ``torch.device``; ``None`` auto-selects CUDA when available."""
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)


def _resolve_num_workers(num_workers):
    """Return ``num_workers`` or, when it is None, min(available CPUs, 8)."""
    if num_workers is not None:
        return num_workers
    try:
        num_cpus = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity is not available under Windows; os.cpu_count
        # may not return the number of CPUs actually available to us.
        num_cpus = os.cpu_count()
    return min(num_cpus, 8) if num_cpus is not None else 0


def _experiment_dirs(data_dir, run):
    """Return (experiment name, image directory) pairs for one run.

    An empty ``run`` selects the un-suffixed folders; otherwise folders carry
    a ``_run<run>`` suffix.
    """
    suffix = f'_run{run}' if run else ''
    mosaic_root = os.path.join(data_dir, f'mosaic_2_112{suffix}')
    bezier_root = os.path.join(data_dir, f'bezier224_5_0.2_0.05_1d1{suffix}')
    return [
        ('Mosaic', os.path.join(mosaic_root, 'img')),
        ('Mosaic Disc', os.path.join(mosaic_root, 'disc_img_r18_e5')),
        ('Bezier', os.path.join(bezier_root, 'img')),
        ('Bezier Disc', os.path.join(bezier_root, 'disc_img_r18_e5')),
    ]


def main():
    """Compute FID/KID for every generated-image variant against real images.

    For each dataset in ``_DATASET_CONFIGS`` and each of its runs, the four
    generated-image directories are compared with the dataset's real images.
    One line per comparison is appended to
    ``./fid_kid_result/results_<dataset>.txt`` (the directory is created if
    missing).
    """
    args = Namespace()
    args.batch_size = 50
    # CUDA_VISIBLE_DEVICES restricts the process to a single GPU, which
    # PyTorch numbers 0; "cuda:6" would not refer to it.
    args.device = "cuda"
    args.dims = 2048
    args.num_workers = 8

    device = _resolve_device(args.device)
    num_workers = _resolve_num_workers(args.num_workers)

    results_dir = './fid_kid_result'
    os.makedirs(results_dir, exist_ok=True)

    for dataset, config in _DATASET_CONFIGS.items():
        data_dir = os.path.join('data', dataset)
        for run in config['runs']:
            for name, dset in _experiment_dirs(data_dir, run):
                metrics_dict = calculate_fid_given_paths(
                    [config['real_image_path'], dset],
                    args.batch_size,
                    device,
                    args.dims,
                    num_workers,
                    num_samples=config['num_samples'],
                )
                result_file = os.path.join(results_dir, f'results_{dataset}.txt')
                with open(result_file, 'a') as f:
                    f.write(f'Run #{run}, {name}: FID, KID = {metrics_dict}\n')
        



In [5]:
# Run every FID/KID comparison defined in main(); results are appended to
# text files under ./fid_kid_result/.
main()

  img = img.resize((224, 224), Image.BILINEAR)
