In [1]:
# --- 1. IMPORTS AND SETUP ---
import torch
import torch.nn as nn
import torchvision.models as models
from scipy.linalg import sqrtm
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os

print("--- Starting FID Score Calculation ---")

class Config:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers = 2
    dataset_path = '/kaggle/input/dgm-animals/Animals_data/animals/animals'
    image_size = 128
    img_channels = 3
    latent_dim = 100
    embedding_dim = 100
    features_g = 64
    features_d = 64
    batch_size = 64
    num_classes = 90

config = Config()

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, embedding_dim, channels_img, features_g):
        super(ConditionalGenerator, self).__init__()
        self.embed = nn.Embedding(num_classes, embedding_dim)
        self.net = nn.Sequential(
            self._block(latent_dim + embedding_dim, features_g * 16, 4, 1, 0),
            self._block(features_g * 16, features_g * 8, 4, 2, 1),
            self._block(features_g * 8, features_g * 4, 4, 2, 1),
            self._block(features_g * 4, features_g * 2, 4, 2, 1),
            self._block(features_g * 2, features_g, 4, 2, 1),
            nn.ConvTranspose2d(features_g, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh())
    def _block(self, in_c, out_c, k, s, p):
        return nn.Sequential(nn.ConvTranspose2d(in_c, out_c, k, s, p, bias=False), nn.BatchNorm2d(out_c), nn.ReLU(True))
    def forward(self, z, labels):
        embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([z, embedding], dim=1)
        return self.net(x)

experiment_dir_to_evaluate = "/kaggle/input/cgan_200/pytorch/default/1/cgan_D1_G7_200_epochs"

# --- FID Configuration ---
FID_IMG_SIZE = 299
FID_BATCH_SIZE = 32
NUM_SAMPLES = 1000
DEVICE = config.device

# --- FID Transformations ---
fid_transform = transforms.Compose([
    transforms.Resize((FID_IMG_SIZE, FID_IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

print("FID setup complete.")

--- Starting FID Score Calculation ---
FID setup complete.


In [2]:
class InceptionV3FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        # Replace the final classification layer with an identity layer to get features
        self.model.fc = nn.Identity()

    def forward(self, x):
        # In eval mode, InceptionV3 has a single output.
        # In train mode, it has an auxiliary output we would need to handle.
        # Setting .eval() is crucial.
        return self.model(x)

inception_model = InceptionV3FeatureExtractor().to(DEVICE)
inception_model.eval() # Must be in eval mode
print("InceptionV3 model loaded.")

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:00<00:00, 226MB/s] 


InceptionV3 model loaded.


In [3]:
# --- 3. FID CALCULATION LOGIC ---
@torch.no_grad()
def get_activations(dataloader, model, device, max_samples):
    activations = []
    for images, _ in dataloader:
        if len(activations) * FID_BATCH_SIZE >= max_samples: break
        images = images.to(device)
        act = model(images)
        activations.append(act.cpu())
    return torch.cat(activations, dim=0)[:max_samples]

def calculate_fid(act1, act2):
    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean): covmean = covmean.real
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [4]:
# --- 4. DATA PREPARATION AND EXECUTION ---
try:
    gen_final = ConditionalGenerator(config.latent_dim, config.num_classes, config.embedding_dim, config.img_channels, config.features_g).to(DEVICE)
    generator_path = os.path.join(experiment_dir_to_evaluate, 'generator_final.pth')
    gen_final.load_state_dict(torch.load(generator_path))
    gen_final.eval()
    print(f"Loaded trained generator from: {generator_path}")
except FileNotFoundError:
    print(f"ERROR: Generator model not found at {generator_path}.")
    raise

# --- Get REAL image activations ---
real_dataset_fid = datasets.ImageFolder(root=config.dataset_path, transform=fid_transform)
real_dataloader_fid = DataLoader(real_dataset_fid, batch_size=FID_BATCH_SIZE, shuffle=True, num_workers=config.num_workers)
print("Calculating activations for real images...")
real_activations = get_activations(real_dataloader_fid, inception_model, DEVICE, max_samples=NUM_SAMPLES).numpy()

# --- Get FAKE image activations ---
fake_activations_list = []
print(f"Generating {NUM_SAMPLES} fake images and calculating activations...")
with torch.no_grad():
    processed_samples = 0
    while processed_samples < NUM_SAMPLES:
        noise = torch.randn(FID_BATCH_SIZE, config.latent_dim, 1, 1, device=DEVICE)
        labels = torch.randint(0, config.num_classes, (FID_BATCH_SIZE,), device=DEVICE)
        generated_batch = gen_final(noise, labels) # Output is [-1, 1]
        
        # Apply the same FID transform to the generated images
        # The transform expects PIL or tensor in [0,1] range, so we must denormalize
        generated_batch_denorm = generated_batch * 0.5 + 0.5
        
        # We can apply the normalization part of the FID transform directly
        generated_batch_transformed = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(generated_batch_denorm)
        # We assume the generator output is already 299x299 for FID, which isn't true.
        # We must resize it first.
        generated_batch_resized = nn.functional.interpolate(generated_batch_transformed, size=(FID_IMG_SIZE, FID_IMG_SIZE), mode='bilinear', align_corners=False)

        act = inception_model(generated_batch_resized)
        fake_activations_list.append(act.cpu())
        processed_samples += FID_BATCH_SIZE

fake_activations = torch.cat(fake_activations_list, dim=0)[:NUM_SAMPLES].numpy()

# --- Calculate and Print Final Score ---
print("Calculating final FID score...")
fid_score = calculate_fid(real_activations, fake_activations)

print("\n" + "="*40)
print(f"FID Score for '{experiment_dir_to_evaluate}': {fid_score:.2f}")
print("="*40)

Loaded trained generator from: /kaggle/input/cgan_200/pytorch/default/1/cgan_D1_G7_200_epochs/generator_final.pth
Calculating activations for real images...
Generating 1000 fake images and calculating activations...
Calculating final FID score...

FID Score for '/kaggle/input/cgan_200/pytorch/default/1/cgan_D1_G7_200_epochs': 272.63
