In [7]:
FIRST_EPOCH = 10
EPOCH_INTERVAL = 10
END_EPOCH = 160

In [8]:
import os
import random
import numpy as np
from PIL import Image
from scipy.linalg import sqrtm
import torch
from torchvision import models, transforms
import csv
from datetime import datetime

def calculate_fid(real_features, fake_features):
    mu1, sigma1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = np.mean(fake_features, axis=0), np.cov(fake_features, rowvar=False)

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

def extract_features(folder, model, transform, num_images):
    images = os.listdir(folder)
    sampled_images = random.sample(images, num_images)
    features = []

    for img_name in sampled_images:
        img_path = os.path.join(folder, img_name)
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0)

        with torch.no_grad():
            feature = model(img_tensor).squeeze(0).numpy()
        features.append(feature)

    return np.array(features), sampled_images

In [9]:
def main(epoch, csv_file="fid_results.csv"):
    folder1 = f"../dataset/basic"
    folder2 = f"../synthesis_images/step6/epoch{epoch}"
    num_images = 1500

    transform = transforms.Compose([
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    inception_model = models.inception_v3(pretrained=True)
    inception_model.fc = torch.nn.Identity()
    inception_model.eval()

    real_features, real_sampled = extract_features(folder1, inception_model, transform, num_images)
    fake_features, fake_sampled = extract_features(folder2, inception_model, transform, num_images)

    fid_score = calculate_fid(real_features, fake_features)

    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    file_exists = os.path.isfile(csv_file)
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        if not file_exists:
            # ファイルが存在しない場合はヘッダーを追加
            writer.writerow(["Epoch", "Execution Time", "FID Score", "Sampled Images from Original", "Sampled Images from Generated"])
        writer.writerow([epoch, current_time, fid_score, ";".join(real_sampled), ";".join(fake_sampled)])

    print(f"Epoch {epoch} - FID score: {fid_score}")
    print(f"Results saved to {csv_file}")


for epoch in range (FIRST_EPOCH, END_EPOCH+10, EPOCH_INTERVAL):
    print(epoch)
    if __name__ == "__main__":
        main(epoch)

10
Epoch 10 - FID score: 324.67656481192716
Results saved to fid_results.csv
20
Epoch 20 - FID score: 313.90332537107474
Results saved to fid_results.csv
30
Epoch 30 - FID score: 249.0677208566722
Results saved to fid_results.csv
40
Epoch 40 - FID score: 163.81441827407536
Results saved to fid_results.csv
50
Epoch 50 - FID score: 104.42224624181446
Results saved to fid_results.csv
60
Epoch 60 - FID score: 51.56157375202825
Results saved to fid_results.csv
70
Epoch 70 - FID score: 28.54749340460875
Results saved to fid_results.csv
80
Epoch 80 - FID score: 27.385347958660155
Results saved to fid_results.csv
90
Epoch 90 - FID score: 24.281041082509557
Results saved to fid_results.csv
100
Epoch 100 - FID score: 36.1367976772574
Results saved to fid_results.csv
110
Epoch 110 - FID score: 24.29972406328111
Results saved to fid_results.csv
120
Epoch 120 - FID score: 23.96450912726603
Results saved to fid_results.csv
130
Epoch 130 - FID score: 23.1909155582343
Results saved to fid_results.csv
