In [28]:
import torch
from torchvision.models import inception_v3
from torchvision import transforms
from torch.nn.functional import adaptive_avg_pool2d

from PIL import Image
import numpy as np
import pandas as pd
import os
from tqdm import tqdm

import torch
_ = torch.manual_seed(123)
from torchmetrics.image.fid import FrechetInceptionDistance

device = torch.device(f'cuda:3' if torch.cuda.is_available() else 'cpu')
fid = FrechetInceptionDistance(feature=64).to(device)

In [41]:
dir = '/media/global_data/fair_neural_compression_data/decoded_rfw/decoded_64x64/mbt2018'
for name in os.listdir(dir):
    print(name)

celebA
fairface


In [38]:
generated_images_path = '/media/global_data/fair_neural_compression_data/decoded_rfw/decoded_64x64/jpeg/q_1'
meta_data_path = '/media/global_data/fair_neural_compression_data/datasets/RFW/clean_metadata/numerical_labels.csv'
clean_images_path = '/media/global_data/fair_neural_compression_data/datasets/RFW/data_64'

meta_data_inf = pd.read_csv(meta_data_path).to_numpy()

In [14]:
def get_images(meta_data_inf, path):
    image_tensors = []
    for meta_data in tqdm(meta_data_inf, total=len(meta_data_inf)):
        file_path = os.path.join(path, meta_data[2], meta_data[1])
        image = Image.open(file_path).convert('RGB')
        image_tensor = transforms.ToTensor()(image)
        image_tensors.append(image_tensor)
    
    image_tensors = torch.stack(image_tensors, dim=0)
    return image_tensors

In [39]:
generated_image_tensors = get_images(meta_data_inf, generated_images_path)

100%|██████████| 40607/40607 [00:08<00:00, 4756.98it/s]


In [30]:
generated_image_tensors = generated_image_tensors.to(device).to(torch.uint8)

In [31]:
clean_image_tensors = get_images(meta_data_inf, clean_images_path)

100%|██████████| 40607/40607 [00:10<00:00, 3957.30it/s]


In [32]:
clean_image_tensors = clean_image_tensors.to(device).to(torch.uint8)

In [33]:
def update_fid_in_batches(fid, images, batch_size=128, real=True):
    for i in range(0, len(images), batch_size):
        batch = images[i:i + batch_size]
        fid.update(batch, real=real)

In [34]:
update_fid_in_batches(fid, clean_image_tensors, batch_size=64, real=True)
update_fid_in_batches(fid, generated_image_tensors, batch_size=64, real=False)


In [48]:
float(fid.compute())

1.4183015082380734e-05

- mbt2018
    - celebA
        - q0001: 6.5600e-05
        - q0009: 2.8195e-05
        - q1:
        - q2:
        - q3: 1.9363e-05
        
    - fairface
        - q0001: 9.0552e-05
        - q0009: 1.4183e-05
        - q1: 
        - q2:
        - q3:


In [4]:
model = inception_v3(pretrained=True, transform_input=False)
model.fc = torch.nn.Identity()  # Remove the final classification layer
model = model.eval()



In [None]:
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_and_preprocess_image(img_path):
    img = Image.open(img_path).convert('RGB')
    img = preprocess(img)
    return img.unsqueeze(0)