In [1]:
from pathlib import Path
from typing import Optional

import torch
import torchvision
from torch.utils.data import DataLoader, random_split
from torchmetrics.image.fid import FrechetInceptionDistance

In [5]:
# Hold out 10% of the image folder as the "real" reference sample; the other
# 90% (`train_dataset`) is what later cells feed to the metric as "fake" images.
sample_percentage = 0.1
transforms = torchvision.transforms.Compose([
    # An int size rescales the shorter image edge to 64 px and keeps the
    # aspect ratio (per torchvision's Resize docs); it is not a square resize.
    torchvision.transforms.Resize(64),
    torchvision.transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(Path("data/celeba"), transform=transforms)
# Seed the split so the same images land in each subset on every run;
# without it the real/fake partition changes with each kernel restart.
split_generator = torch.Generator().manual_seed(0)
sample_dataset, train_dataset = random_split(
    dataset,
    (sample_percentage, 1.0 - sample_percentage),
    generator=split_generator,
)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
sample_dataloader = DataLoader(sample_dataset, batch_size=32, shuffle=True)

In [14]:
# Metric object shared by the following cells. Keyword meanings as described
# in the torchmetrics docs: 64-dimensional Inception features, float images
# in [0, 1], and real-image statistics that are kept across `fid.reset()`.
fid = FrechetInceptionDistance(
    feature=64,
    normalize=True,
    reset_real_features=False,
)

In [6]:
# Pull one shuffled batch from the held-out loader for inspection; the
# second element of the batch (the labels) is discarded.
real_images, _ = next(iter(sample_dataloader))

In [10]:
# Smallest pixel value in the batch, to check the lower end of the value range.
torch.min(real_images)

tensor(0.)

In [15]:
# Feed 10 batches of held-out images into the metric as the "real" side.
# The batches are drawn from a single pass over the loader: calling
# `next(iter(loader))` on every step would start a new shuffled pass each
# time, so the batches could share images and the iterator would be rebuilt
# 10 times. zip() stops as soon as the range is exhausted.
for _step, (real_images, _) in zip(range(10), sample_dataloader):
    fid.update(real_images, real=True)

In [17]:
# Add 10 batches from the training split as the "fake" side and print the
# running FID after each one (the fake statistics accumulate across steps).
# The batches come from one pass over the loader rather than from a fresh
# `next(iter(loader))` per step, which would restart a shuffled pass every
# time and could hand back overlapping images.
for _step, (fake_images, _) in zip(range(10), train_dataloader):
    fid.update(fake_images, real=False)
    tmp = fid.compute()
    print(tmp)

tensor(0.0194)
tensor(0.0154)
tensor(0.0191)
tensor(0.0190)
tensor(0.0228)
tensor(0.0253)
tensor(0.0290)
tensor(0.0334)
tensor(0.0310)
tensor(0.0301)


In [18]:
# Number of batches the held-out loader yields per pass.
len(sample_dataloader)

634

In [19]:
# Upper bound on held-out images: 634 batches times a batch size of 32.
634 * 32

20288

In [20]:
# Run one complete pass over the held-out loader, overwriting the printed
# batch index in place so only the final index stays visible.
for batch_idx, _batch in enumerate(sample_dataloader):
    print(batch_idx, end="\r")

633

In [13]:
# FID over everything accumulated in `fid` so far.
# NOTE(review): this cell's execution count (13) is lower than that of the
# cell constructing `fid` (14), so the displayed value presumably came from
# an earlier metric object — re-run in order to confirm.
fid.compute()

tensor(0.0331)

In [34]:
# Rebuild the dataset with a 299 px resize and a fresh metric, then accumulate
# real-image statistics from the first batches of the unshuffled, unsplit data.
# `train_dataloader` / `sample_dataloader` are not rebuilt here and still wrap
# the earlier 64 px dataset.
transforms = torchvision.transforms.Compose([
    # An int size rescales the shorter image edge to 299 px and keeps the
    # aspect ratio (per torchvision's Resize docs).
    torchvision.transforms.Resize(299),
    torchvision.transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(Path("data/celeba"), transform=transforms)

fid = FrechetInceptionDistance(feature=64, normalize=True, reset_real_features=False)

max_real_batches = 100  # cap on how many batches are fed to the metric

whole_dataset = DataLoader(dataset, batch_size=32)
print(f"Batches in dataset: {len(whole_dataset)}")
# zip() stops once the range is exhausted, before asking the loader for
# another batch, so no extra batch is loaded just to be thrown away.
for i, (x, _) in zip(range(max_real_batches), whole_dataset):
    print(f"{i}/{len(whole_dataset)}", end="\r")
    fid.update(x, real=True)

Batches in dataset: 6332
99/6332

In [37]:
# Per-batch FID: score 10 distinct batches, one at a time, against the real
# statistics already stored in `fid`.
# The batches come from a single pass over the loader; `next(iter(loader))`
# on every step would restart a shuffled pass each time and could return
# overlapping images.
# NOTE(review): `train_dataloader` was built from the 64 px dataset while the
# real statistics were accumulated from the 299 px one — confirm that this
# preprocessing mismatch is intended.
for _step, (fake_images, _) in zip(range(10), train_dataloader):
    fid.update(fake_images, real=False)
    tmp = fid.compute()
    print(tmp)
    # Clear the accumulated state so the next score covers one batch only;
    # with reset_real_features=False the real statistics should survive
    # this call (per the torchmetrics docs).
    fid.reset()

tensor(0.0562)
tensor(0.2286)
tensor(0.4195)
tensor(0.2409)
tensor(0.6310)
tensor(0.3374)
tensor(0.2287)
tensor(0.2100)
tensor(0.3285)
tensor(0.2634)


In [40]:
# Score one more batch and keep the result in `x` for the next cell.
# NOTE(review): `fake_images` is not defined in this cell; it is whatever the
# last loop that assigned it left behind, so this only works after that loop ran.
fid.update(fake_images, real=False)
x = fid.compute()

In [47]:
# Unwrap the single-element tensor into a plain Python number.
x.item()

0.26711198687553406

In [31]:
def test(max_iter: Optional[int] = None):
    """Print which branch an optional iteration cap selects.

    Args:
        max_iter: Optional cap. Any value that is not an ``int`` (including
            the ``None`` default) is treated as "no cap".

    Prints ``"Here I am"`` when ``max_iter`` is an integer greater than 23,
    and ``"Else"`` in every other case.
    """
    # isinstance() also accepts int subclasses, which an exact
    # `type(...) == int` comparison rejects. The type check short-circuits,
    # so None never reaches the `<` comparison.
    if isinstance(max_iter, int) and 23 < max_iter:
        print("Here I am")
    else:
        print("Else")

# Exercise both branches: an int above the threshold, then the None default.
test(42)
test(None)

Here I am
Else


In [24]:

# Inspect the structure of the first batch only, then stop. This prints the
# complete batch contents, so the output is long.
for batch_idx, batch in enumerate(whole_dataset):
    progress = f"{batch_idx}/{len(whole_dataset)}"
    print(progress, end="\r")
    print(type(batch))
    print(batch)
    break

<class 'list'>
[tensor([[[[0.9922, 0.9922, 0.9922,  ..., 0.9961, 0.9961, 0.9961],
          [0.9922, 0.9922, 0.9922,  ..., 0.9961, 0.9961, 0.9961],
          [0.9922, 0.9922, 0.9922,  ..., 0.9961, 0.9961, 0.9961],
          ...,
          [0.5098, 0.5255, 0.5451,  ..., 0.4627, 0.4627, 0.4627],
          [0.6000, 0.6392, 0.6980,  ..., 0.4667, 0.4667, 0.4667],
          [0.6588, 0.7137, 0.8000,  ..., 0.4706, 0.4706, 0.4706]],

         [[0.9059, 0.9059, 0.9059,  ..., 0.9333, 0.9333, 0.9333],
          [0.9059, 0.9059, 0.9059,  ..., 0.9333, 0.9333, 0.9333],
          [0.9059, 0.9059, 0.9059,  ..., 0.9333, 0.9333, 0.9333],
          ...,
          [0.2431, 0.2588, 0.2784,  ..., 0.2000, 0.2000, 0.2000],
          [0.3333, 0.3725, 0.4314,  ..., 0.1961, 0.1961, 0.1961],
          [0.3922, 0.4471, 0.5333,  ..., 0.1961, 0.1961, 0.1961]],

         [[0.7608, 0.7608, 0.7608,  ..., 0.8706, 0.8706, 0.8706],
          [0.7608, 0.7608, 0.7608,  ..., 0.8706, 0.8706, 0.8706],
          [0.7608, 0.7608,