In [14]:
import os 
import json
import torch
from build_fid_metrics import build_real_stats_per_class
from torchvision.utils import save_image
from cleanfid import fid

# Compute a per-class FID score for class-conditional CIFAR samples.
#
# Generated samples/labels are loaded from `samples_dir`. Each class's images
# are written as PNGs to `<samples_dir>/fid/gen_c<class>/`, and each folder is
# scored with clean-fid against that class's precomputed real-data statistics.
# Scores are written to `<samples_dir>/per_class_fid.json`. Note that JSON
# object keys are strings, so class ids appear as "0", "5", ... in that file.

# Cache directory passed to build_real_stats_per_class for the real stats.
CIFAR_FID_CACHE = "/vols/bitbucket/saravanan/distributional-mf/cache"
# Run directory; every other path below is derived from it, so switching runs
# only requires editing this one line.
samples_dir = "/vols/bitbucket/saravanan/distributional-mf/outputs/default/cifar/class_conditional/iwae_500_False_4"
samples_path = os.path.join(samples_dir, "samples.pt")
labels_path = os.path.join(samples_dir, "labels.pt")
samples = torch.load(samples_path)
labels = torch.load(labels_path)

print("loaded samples", samples.shape)
print("loaded labels", labels.shape)

# Each label must belong to the sample at the same position; a length
# mismatch would otherwise silently assign images to the wrong class.
if samples.shape[0] != labels.shape[0]:
    raise ValueError(
        f"samples ({samples.shape[0]}) and labels ({labels.shape[0]}) "
        "differ in length along dim 0"
    )

# NOTE(review): indexed by integer class id below; assumes the helper returns
# a dict/sequence mapping class id -> clean-fid custom stats name. Confirm.
stats_names = build_real_stats_per_class(CIFAR_FID_CACHE)
per_class_fid = {}
fid_path = os.path.join(samples_dir, "fid")

os.makedirs(fid_path, exist_ok=True)
classes = torch.unique(labels).tolist()

for c in classes:
    cdir = os.path.join(fid_path, f"gen_c{c}")
    os.makedirs(cdir, exist_ok=True)
    # compute_fid scores every image it finds in the folder, so PNGs left
    # over from a previous run (e.g. one with more samples per class) would
    # contaminate this run's score. Clear them before writing new ones.
    for fname in os.listdir(cdir):
        if fname.endswith(".png"):
            os.remove(os.path.join(cdir, fname))

    samples_cls = samples[labels == c]
    for i, img in enumerate(samples_cls):
        # assumes samples are already scaled to [0, 1]; clamp only guards
        # against small overshoot — TODO confirm the generator's output range.
        save_image(img.clamp(0, 1), os.path.join(cdir, f"{i:06d}.png"))

    fid_score = fid.compute_fid(cdir,
                                dataset_name=stats_names[c],
                                dataset_split="custom",
                                mode="clean",
                                dataset_res=32)
    # Store a plain Python float so the JSON file and printout are clean.
    per_class_fid[c] = float(fid_score)

with open(os.path.join(samples_dir, "per_class_fid.json"), "w", encoding="utf-8") as f:
    json.dump(per_class_fid, f, indent=4)

print(per_class_fid)

loaded samples torch.Size([4096, 3, 32, 32])
loaded labels torch.Size([4096])
compute FID of a folder with cifar10_train_clean_class0 statistics
Found 1024 images in the folder /vols/bitbucket/saravanan/distributional-mf/outputs/default/cifar/class_conditional/iwae_500_False_4/fid/gen_c0


FID gen_c0 : 100%|██████████| 32/32 [00:13<00:00,  2.37it/s]


compute FID of a folder with cifar10_train_clean_class5 statistics
Found 1024 images in the folder /vols/bitbucket/saravanan/distributional-mf/outputs/default/cifar/class_conditional/iwae_500_False_4/fid/gen_c5


FID gen_c5 : 100%|██████████| 32/32 [00:10<00:00,  3.10it/s]


compute FID of a folder with cifar10_train_clean_class7 statistics
Found 1024 images in the folder /vols/bitbucket/saravanan/distributional-mf/outputs/default/cifar/class_conditional/iwae_500_False_4/fid/gen_c7


FID gen_c7 : 100%|██████████| 32/32 [00:12<00:00,  2.50it/s]


compute FID of a folder with cifar10_train_clean_class9 statistics
Found 1024 images in the folder /vols/bitbucket/saravanan/distributional-mf/outputs/default/cifar/class_conditional/iwae_500_False_4/fid/gen_c9


FID gen_c9 : 100%|██████████| 32/32 [00:11<00:00,  2.85it/s]


{0: np.float64(80.2431764587813), 5: np.float64(94.63106822062451), 7: np.float64(101.59386659014154), 9: np.float64(82.17014458699393)}
