In [1]:
FIRST_EPOCH = 10
EPOCH_INTERVAL = 10
END_EPOCH = 500

In [2]:
import os
import random
import numpy as np
from PIL import Image
from scipy.linalg import sqrtm
import torch
from torchvision import models, transforms
import csv
from datetime import datetime

def calculate_fid(real_features, fake_features):
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(fake_features, axis=0), np.cov(fake_features, rowvar=False)

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

def extract_features(folder, model, transform, num_images):
    images = os.listdir(folder)
    sampled_images = random.sample(images, num_images)
    features = []

    for img_name in sampled_images:
        img_path = os.path.join(folder, img_name)
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)

        with torch.no_grad():
            feature = model(img_tensor).squeeze(0).numpy()
        features.append(feature)

    return np.array(features), sampled_images

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def main(epoch, csv_file="fid_results.csv"):
    folder1 = f"../dataset/basic"
    folder2 = f"../synthesis_images/epoch{epoch}"
    num_images = 1500

    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    inception_model = models.inception_v3(pretrained=True)
    inception_model.fc = torch.nn.Identity()
    inception_model.eval()

    real_features, real_sampled = extract_features(folder1, inception_model, transform, num_images)
    fake_features, fake_sampled = extract_features(folder2, inception_model, transform, num_images)

    fid_score = calculate_fid(real_features, fake_features)

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    file_exists = os.path.isfile(csv_file)
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not file_exists:
            # ファイルが存在しない場合はヘッダーを追加
            writer.writerow(["Epoch", "Execution Time", "FID Score", "Sampled Images from Original", "Sampled Images from Generated"])
        writer.writerow([epoch, current_time, fid_score, ";".join(real_sampled), ";".join(fake_sampled)])

    print(f"Epoch {epoch} - FID score: {fid_score}")
    print(f"Results saved to {csv_file}")


for epoch in range (FIRST_EPOCH, END_EPOCH+10, EPOCH_INTERVAL):
    print(epoch)
    if __name__ == "__main__":
        main(epoch)

10




Epoch 10 - FID score: 315.16121474226
Results saved to fid_results.csv
20
Epoch 20 - FID score: 266.62087187097006
Results saved to fid_results.csv
30
Epoch 30 - FID score: 265.4476128005759
Results saved to fid_results.csv
40
Epoch 40 - FID score: 233.06468928557948
Results saved to fid_results.csv
50
Epoch 50 - FID score: 239.02401043221403
Results saved to fid_results.csv
60
Epoch 60 - FID score: 248.23393446001643
Results saved to fid_results.csv
70
Epoch 70 - FID score: 217.04723641140333
Results saved to fid_results.csv
80
Epoch 80 - FID score: 172.44870712093655
Results saved to fid_results.csv
90
Epoch 90 - FID score: 189.2273254147801
Results saved to fid_results.csv
100
Epoch 100 - FID score: 180.45469884138524
Results saved to fid_results.csv
110
Epoch 110 - FID score: 182.53063314749772
Results saved to fid_results.csv
120
Epoch 120 - FID score: 167.75917689249843
Results saved to fid_results.csv
130
Epoch 130 - FID score: 163.08355481245997
Results saved to fid_results.csv