In [4]:
import os
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import lpips

# LPIPS perceptual-distance models (their construction prints the "Setting up [LPIPS]" lines).
loss_fn_alex = lpips.LPIPS(net='alex')  # best forward scores
loss_fn_vgg = lpips.LPIPS(net='vgg')  # closer to "traditional" perceptual loss, when used for optimization

# Dataset roots: real images (A) and FastGAN samples (B).
_jovyan_home = os.path.join('/home', 'jovyan')
path_a = os.path.join(_jovyan_home, 'datasets', 'drive')
print(path_a)
path_b = os.path.join(_jovyan_home, 'FastGAN', 'eval_40000')
print(path_b)
image_size = 256


def _tensor_and_scale():
    """Return fresh ToTensor + Normalize(0.5, 0.5) steps (per-channel (x - 0.5) / 0.5)."""
    return [transforms.ToTensor(), transforms.Normalize([0.5] * 3, [0.5] * 3)]


# Deterministic pipeline: square resize, then tensor conversion and scaling.
transform = transforms.Compose(
    [transforms.Resize((image_size, image_size))] + _tensor_and_scale()
)

# Random pipeline: horizontal flip plus an 80-100% area crop resized to image_size.
transform_augmented = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=image_size, scale=(0.8, 1.0)),
    ]
    + _tensor_and_scale()
)


def _dataset_and_loader(root, tfm, label):
    """Build an ImageFolder over ``root``, print ``label`` with its size, and wrap it in a DataLoader."""
    dataset = ImageFolder(root, tfm)
    print(label, len(dataset))
    return dataset, DataLoader(dataset, batch_size=32, num_workers=4)


dset_a, loader_a = _dataset_and_loader(path_a, transform, "Number of real images:")
dset_b, loader_b = _dataset_and_loader(path_b, transform, "Number of synthetic images:")
augmented_images, loader_aug = _dataset_and_loader(
    path_a, transform_augmented, "Number of augmented real images:"
)


def normalize_batch(loader):
    """L2-normalize every image batch yielded by ``loader`` and return the results.

    Each batch is passed through ``nn.functional.normalize`` with its defaults
    (``p=2``, ``dim=1``). That rescales vectors along dim 1 to unit L2 norm. It is
    NOT the mean/std standardization performed by ``transforms.Normalize``.

    Args:
        loader: iterable of ``(images, labels)`` pairs, e.g. a ``DataLoader``.

    Returns:
        List of normalized image tensors, one per batch, in loader order.
        Note that all batches are held in memory at once.
    """
    normalized_batches = []
    for batch_no, (img_batch, _) in enumerate(loader, start=1):
        # Keep the result: F.normalize is out-of-place and returns a new tensor.
        normalized_batches.append(nn.functional.normalize(img_batch))
        print(f"Batch {batch_no} normalized")
    return normalized_batches
        
def verify_norm(loader):
    """Print and return the average per-image mean/std after L2-normalizing each image.

    Every image in every batch is passed through ``nn.functional.normalize``
    (defaults ``p=2``, ``dim=1``). The mean and standard deviation of each result
    are then averaged over all images the loader actually yields. Note this is the
    mean of per-image statistics, not the pooled statistics of the whole dataset.

    Args:
        loader: iterable of ``(images, labels)`` pairs, e.g. a ``DataLoader``;
            ``images`` is iterated along its first (batch) axis.

    Returns:
        Tuple ``(mean, std)`` of Python floats.

    Raises:
        ValueError: if the loader yields no images.
    """
    mean_total = 0.0
    std_total = 0.0
    n_images_seen = 0
    for images, _ in loader:
        for image in images:
            normalized_image = nn.functional.normalize(image)
            mean_total += normalized_image.mean().item()
            std_total += normalized_image.std().item()
            n_images_seen += 1

    # Divide by the number of images this loader produced, not by a fixed
    # dataset size, so the result is correct for any loader passed in.
    if n_images_seen == 0:
        raise ValueError("verify_norm: loader yielded no images")
    mean = mean_total / n_images_seen
    std = std_total / n_images_seen
    print(f"Mean: {mean}")
    print(f"Std: {std}")
    return mean, std

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /opt/conda/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /opt/conda/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth
/home/jovyan/datasets/drive
/home/jovyan/FastGAN/eval_40000
Number of real images: 170
Number of synthetic images: 2006
Number of augmented real images: 170


In [12]:
import lpips
import torch

# LPIPS distance metric with the AlexNet trunk. .eval() is explicit in case the
# installed lpips version does not already put the model in eval mode.
lpips_distance = lpips.LPIPS(net="alex").eval()

# Load the datasets (fresh instances, so this cell does not depend on the loaders above)
real_images = ImageFolder(path_a, transform)
synth_images = ImageFolder(path_b, transform)
augmented_images = ImageFolder(path_a, transform_augmented)

# Number of images to sample from each dataset: bounded by the smallest one
n_images = min(len(real_images), len(synth_images), len(augmented_images))

# Seed torch's RNG right before sampling, so the random subsets (randperm) and the
# random flips/crops applied by transform_augmented are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Sample the same number of images from each dataset.
# NOTE(review): indices_aug is drawn independently of indices_real, so real image k
# is compared with the augmentation of a (generally) different real image. The
# real-vs-augmented score therefore mixes inter-image diversity with augmentation
# strength. Use indices_aug = indices_real if you want augmentation strength alone.
indices_real = torch.randperm(len(real_images))[:n_images]
indices_synth = torch.randperm(len(synth_images))[:n_images]
indices_aug = torch.randperm(len(augmented_images))[:n_images]

images_real = [real_images[i][0] for i in indices_real]
images_synth = [synth_images[i][0] for i in indices_synth]
images_aug = [augmented_images[i][0] for i in indices_aug]

# Pure inference: no_grad avoids building an autograd graph over the whole stacked batch.
with torch.no_grad():
    batch_real = torch.stack(images_real)
    distance_real_aug = lpips_distance(batch_real, torch.stack(images_aug))
    distance_real_synth = lpips_distance(batch_real, torch.stack(images_synth))

mean_distance_real_aug = distance_real_aug.mean().item()
mean_distance_real_aut = mean_distance_real_aug  # original (misspelled) name kept for later cells

print(f"Mean distance between real and augmented: {mean_distance_real_aug}")

mean_distance_real_synth = distance_real_synth.mean().item()

print(f"Mean distance between real and synth: {mean_distance_real_synth}")


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /opt/conda/lib/python3.7/site-packages/lpips/weights/v0.1/alex.pth
Mean distance between real and augmented: 0.30347901582717896
Mean distancebetween real and synth: 0.20380167663097382
