In [None]:
import torch
import torch.nn as nn
class Generator(nn.Module):
    """DCGAN-style generator that maps latent vectors to images.

    The latent vector is viewed as a ``latent_dim x 1 x 1`` feature map and
    passed through four 4x4 transposed-convolution stages (1x1 -> 4x4 -> 8x8
    -> 16x16 -> 32x32), each followed by batch norm and ReLU. A final 3x3,
    stride-1 transposed convolution projects to ``image_channels`` and a
    ``Tanh`` squashes the output into [-1, 1].

    Args:
        latent_dim: Number of entries in each latent vector.
        ngf: Base channel width; hidden stages use 8x, 4x, 2x and 1x this value.
        image_channels: Number of channels in the generated image.
    """

    def __init__(self, latent_dim, ngf, image_channels):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each 4x4 upsampling stage.
        stages = [
            (latent_dim, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        ]

        layers = []
        for in_ch, out_ch, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output head: keeps the spatial size, maps to image channels, bounds values with tanh.
        layers.append(nn.ConvTranspose2d(ngf, image_channels, kernel_size=3, stride=1, padding=1, bias=False))
        layers.append(nn.Tanh())

        # Attribute name and layer order determine the state_dict keys ("main.<idx>.*").
        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from a batch of latent vectors.

        Args:
            z: Tensor whose first two dimensions are (batch, latent_dim); it is
                reshaped to (batch, latent_dim, 1, 1) before the network runs.

        Returns:
            The output of the layer stack for the reshaped latent batch.
        """
        latent_map = z.view(z.size(0), z.size(1), 1, 1)
        return self.main(latent_map)

In [None]:
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np
from scipy.linalg import sqrtm
# Seed PyTorch's global RNG so the latent vectors drawn with torch.randn below are
# repeatable. NOTE(review): the shuffled DataLoader order presumably follows this seed
# too - confirm if exact reproducibility of the real-image subset matters.
torch.manual_seed(42)

# Create InceptionV3 feature extractor
class InceptionV3FeatureExtractor(nn.Module):
    """Pooled-feature extractor assembled from torchvision's pretrained Inception V3.

    The network's layers up to and including ``Mixed_7c`` are regrouped into five
    sequential blocks, with the max-pool / global-average-pool layers recreated
    explicitly. The classifier and any other layers of the original model are
    not part of this module. The whole module is put in eval mode on
    construction and ``forward`` always runs without gradient tracking.
    """

    def __init__(self):
        super(InceptionV3FeatureExtractor, self).__init__()
        inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)

        # Group the backbone into stages; the list order fixes the "blocks.<i>" indices.
        self.blocks = nn.ModuleList([
            # Stem, part 1
            nn.Sequential(
                inception.Conv2d_1a_3x3,
                inception.Conv2d_2a_3x3,
                inception.Conv2d_2b_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ),
            # Stem, part 2
            nn.Sequential(
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ),
            # Mixed_5 stage
            nn.Sequential(
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
            ),
            # Mixed_6 stage
            nn.Sequential(
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ),
            # Mixed_7 stage followed by global average pooling
            nn.Sequential(
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d((1, 1)),
            ),
        ])

        # Put the wrapper and every block in eval mode (recurses into all children),
        # so the module's own ``training`` flag agrees with its blocks.
        self.eval()

    def forward(self, x):
        """Return one flattened, pooled feature vector per input image.

        Args:
            x: 4-D image batch (indexed as batch, channels, height, width). Inputs
                that are not 299x299 spatially are resized with bilinear
                interpolation first.

        Returns:
            2-D tensor with one row of features per image, computed under
            ``torch.no_grad()``.
        """
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = torch.nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        with torch.no_grad():
            for block in self.blocks:
                x = block(x)

            features = torch.flatten(x, 1)

        return features

def preprocess_for_inception(images):
    """Bring an image batch into the [-1, 1] range and 299x299 size used for Inception V3.

    The value range is inferred from the batch itself:

    * min >= 0 and max <= 1  -> treated as [0, 1] data and rescaled to [-1, 1];
    * anything else          -> returned with its values unchanged (this covers
      data that is already in [-1, 1]).

    The [0, 1] test has to come first: every [0, 1] batch also lies inside
    [-1, 1], so checking the wider range first would make the rescaling branch
    unreachable. Because this is a data-dependent heuristic, a [-1, 1] batch
    that happens to contain no negative value at all would be rescaled too.

    Args:
        images: 4-D float tensor indexed as (batch, channels, height, width).

    Returns:
        The (possibly rescaled) batch, bilinearly resized to 299x299 when its
        spatial size differs.
    """
    lowest, highest = images.min(), images.max()

    if lowest >= 0 and highest <= 1:
        # [0, 1] -> [-1, 1]
        images = images * 2 - 1

    if images.shape[2] != 299 or images.shape[3] != 299:
        images = torch.nn.functional.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)

    return images

def _accumulate_feature_stats(features, running_sum, running_outer):
    """Add one batch of features to the running sum and sum of outer products."""
    batch_sum = features.sum(axis=0)
    batch_outer = np.dot(features.T, features)
    if running_sum is None:
        return batch_sum, batch_outer
    return running_sum + batch_sum, running_outer + batch_outer


def _mean_and_covariance(feature_sum, feature_outer_sum, count):
    """Convert accumulated sums into the sample mean and unbiased sample covariance."""
    mu = feature_sum / count
    # sum_i (x_i - mu)(x_i - mu)^T == sum_i x_i x_i^T - n * mu mu^T; divide by n - 1.
    sigma = (feature_outer_sum - count * np.outer(mu, mu)) / (count - 1)
    return mu, sigma


def _frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # The product can be numerically singular; nudge both diagonals and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # sqrtm may return a complex array carrying only round-off imaginary parts.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)


def calculate_inception_fid(generator, dataset, num_samples=5000, batch_size=32, latent_dim=100, device='cpu'):
    """Compute the FID between a dataset and a generator using Inception V3 features.

    Images are processed batch by batch and only running sums (feature sum and
    sum of outer products) are kept, so no image or per-image feature is stored.
    Sums are accumulated in float64, and exactly ``num_samples`` real images
    (fewer only if the dataset is smaller) and ``num_samples`` generated images
    contribute to the statistics. Covariances use the unbiased (n - 1) estimator.

    Args:
        generator: Callable mapping a ``(batch, latent_dim)`` tensor on ``device``
            to an image batch. It is used as-is; put it in eval mode beforehand
            if it contains batch norm or dropout.
        dataset: Dataset whose items are ``(image, label)`` pairs; labels are ignored.
        num_samples: Number of real and of generated images to use (at least 2).
        batch_size: Batch size for both the data loader and generation.
        latent_dim: Size of the latent vectors sampled with ``torch.randn``.
        device: Device on which the feature extractor and batches are placed.

    Returns:
        The FID score as a NumPy floating-point scalar.

    Raises:
        ValueError: If ``num_samples`` is below 2 or the dataset yields fewer
            than 2 images, since a covariance cannot be estimated.
    """
    if num_samples < 2:
        raise ValueError("num_samples must be at least 2 to estimate a covariance")

    feature_extractor = InceptionV3FeatureExtractor().to(device)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    real_sum = None
    real_outer = None
    real_count = 0

    print("Processing real images...")
    with torch.no_grad():
        for images, _ in dataloader:
            # Trim the final batch so real and generated statistics use the same sample count.
            images = images[: num_samples - real_count].to(device)
            images = preprocess_for_inception(images)

            # float64 keeps the second-moment accumulation from losing precision.
            features = feature_extractor(images).cpu().numpy().astype(np.float64)
            real_sum, real_outer = _accumulate_feature_stats(features, real_sum, real_outer)
            real_count += features.shape[0]

            if real_count >= num_samples:
                break

    if real_count < 2:
        raise ValueError("dataset must provide at least 2 images to estimate a covariance")

    mu1, sigma1 = _mean_and_covariance(real_sum, real_outer, real_count)

    fake_sum = None
    fake_outer = None
    fake_count = 0

    print("Processing generated images...")
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            current_batch = min(batch_size, num_samples - start)
            z = torch.randn(current_batch, latent_dim).to(device)
            fake_batch = preprocess_for_inception(generator(z))

            features = feature_extractor(fake_batch).cpu().numpy().astype(np.float64)
            fake_sum, fake_outer = _accumulate_feature_stats(features, fake_sum, fake_outer)
            fake_count += features.shape[0]

    mu2, sigma2 = _mean_and_covariance(fake_sum, fake_outer, fake_count)

    print("Calculating final FID score...")
    fid = _frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid


# Evaluate a trained CIFAR-10 generator: pick a device, load real data and the
# checkpoint, then report the Inception V3 FID.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Normalize each channel with mean 0.5 / std 0.5, i.e. map ToTensor's [0, 1] output to [-1, 1].
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar10_transform)

latent_dim = 100
generator = Generator(latent_dim=latent_dim, ngf=64, image_channels=3).to(device)
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on this one;
# weights_only=True restricts unpickling to tensors/plain containers instead of
# arbitrary pickled objects.
generator.load_state_dict(
    torch.load(
        'models/cifar10_subset_90_percent/generator.pth',
        map_location=device,
        weights_only=True,
    )
)
generator.eval()

fid = calculate_inception_fid(generator, train_dataset, num_samples=10000, latent_dim=latent_dim, device=device)
print(f"Inception V3 FID Score: {fid:.4f}")

Using device: cuda
Files already downloaded and verified


  generator.load_state_dict(torch.load('models/cifar10_subset_90_percent/generator.pth'))


Processing real images...
Processing generated images...
Calculating final FID score...
Inception V3 FID Score: 31.4908
