In [85]:
import os
import shutil
import random
from datetime import datetime
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.utils as torch_utils
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import numpy as np
from PIL import Image
from scipy.linalg import sqrtm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [86]:
# Generator
class Generator(nn.Module):
    def __init__(self, classes):
        super(Generator, self).__init__()
        self.classes = classes

        self.embedding = nn.Linear(classes, 8*8)

        self.latent_vector = nn.Sequential(
            nn.Linear(100, 512*8*8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        upsample_1 = self.upsample_block(513, 256, 1)
        upsample_2 = self.upsample_block(256, 128, 1)
        upsample_3 = self.upsample_block(128, 64, 1)

        self.conv_model = nn.Sequential(
            upsample_1,
            upsample_2,
            upsample_3,
            nn.Conv2d(64, 3, (1, 1), 1, 0),
            nn.Tanh()
        )
    
    def upsample_block(self, in_channels, out_channels, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, (4, 4), 2, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    
    def forward(self, x, label):
        latent_vector = self.latent_vector(x).view(-1, 512, 8, 8)
        label_embedding = self.embedding(label).view(-1, 1, 8, 8)
        comb_latent_vector = torch.concat((latent_vector, label_embedding), dim = 1)
        output = self.conv_model(comb_latent_vector)
        return output

In [87]:
def save_images(images, output_dir):
    print('Saving Images...')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        
    for idx, img in enumerate(images):
        img = img.permute(1, 2, 0)
        img = img.numpy()
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        filename = f'image_{timestamp}_{idx}.png'
        filepath = os.path.join(output_dir, filename)
        plt.imsave(filepath, img)
    print("Done!")
    return True

def load_images_as_tensors(image_paths):
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert PIL Image to Tensor
    ])
    
    image_tensors = []
    
    for path in image_paths:
        image = Image.open(path).convert('RGB')  # Open image and convert to RGB
        image_tensors.append(image)
    
    return image_tensors

In [88]:
def sample_images_from_directories(folder_path, sample_size=10):
    sampled_image_paths = []
    
    # List all directories in the given folder
    for root, dirs, files in os.walk(folder_path):
        for dir_name in dirs:
            dir_path = os.path.join(root, dir_name)
            # List all files in the directory
            all_files = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
            # Sample 10 files from the list
            sampled_files = random.sample(all_files, min(sample_size, len(all_files)))
            # Add sampled file paths to the result list
            sampled_image_paths.extend(sampled_files)
    
    return sampled_image_paths

In [89]:
def eval(classes, model_path, num_images, output_dir, grid=True):
    fixed_noise = torch.randn(classes*10, 100, device=device)
    fixed_labels = []
    for i in range(classes):
        lab = [0 if j != i else 1 for j in range(classes)]
        lab = lab*10
        fixed_labels.append(lab)
    fixed_labels = torch.Tensor(fixed_labels).view(classes*10, classes).float().to(device)

    gen_net = Generator(classes)
    gen_net.to(device)
    gen_net.load_state_dict(torch.load(model_path))

    iters = 0
    
    gen_net.eval()
    images = []
    print("Starting Inference Loop...")
    for image in range(num_images):
        with torch.no_grad():
            fake = gen_net(fixed_noise, fixed_labels).detach().cpu()
        
        if grid:
            images.append(torch_utils.make_grid(fake, padding=2, nrow=10, normalize=True))
        else:
            for i in range(fake.size(0)):
                img = fake[i, :, :, :].squeeze(0)* 255
                images.append(img.type(torch.uint8))
             
    return save_images(images, output_dir)

In [90]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]) 

# Define a custom dataset class
class CustomImageDataset(Dataset):
    def __init__(self, image_tensors):
        self.image_tensors = image_tensors

    def __len__(self):
        return len(self.image_tensors)

    def __getitem__(self, idx):
        return self.image_tensors[idx], 0  # Dummy label

# Function to extract features using InceptionV3
def get_features(dataloader, model, device):
    model.eval()
    features = []
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            pred = model(images)
            features.extend(pred.cpu().numpy())
    return np.array(features)

# Function to calculate FID
def calculate_fid(real_features, generated_features):
    # calculate mean and covariance statistics
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = generated_features.mean(axis=0), np.cov(generated_features, rowvar=False)
    # calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2)
    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [91]:
# Load pre-trained InceptionV3 model
inception_model = models.inception_v3(weights='DEFAULT', transform_input=False).to(device)
inception_model.fc = torch.nn.Identity()  # Remove the final classification layer

In [92]:
shutil.rmtree('/kaggle/working/shoes_eval')

In [93]:
eval(3, '/kaggle/input/cgan-training/cgan_shoe.pt', 100, '/kaggle/working/shoes_eval', False)

Starting Inference Loop...
Saving Images...
Done!


True

In [94]:
# Example usage:
dataset_path = '/kaggle/input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset'
generated_path = '/kaggle/working/shoes_eval'
real_images = load_images_as_tensors(sample_images_from_directories(dataset_path, 300))
generated_images = load_images_as_tensors([os.path.join(generated_path, x) for x in os.listdir(generated_path)])

In [95]:
real_images = torch.stack([transform(img) for img in real_images])
generated_images = torch.stack([transform(img) for img in generated_images])

real_dataset = CustomImageDataset(real_images)
generated_dataset = CustomImageDataset(generated_images)

real_loader = DataLoader(real_dataset, batch_size=32, shuffle=False)
generated_loader = DataLoader(generated_dataset, batch_size=32, shuffle=False)

# Extract features
real_features = get_features(real_loader, inception_model, device)
generated_features = get_features(generated_loader, inception_model, device)

# Calculate FID
fid_score = calculate_fid(real_features, generated_features)
print(f"CGAN Shoes FID Score: {fid_score}")

CGAN Shoes FID Score: 24.234657


In [96]:
shutil.rmtree('/kaggle/working/flowers_eval')

In [97]:
eval(5, '/kaggle/input/cgan-training/cgan_flowers.pt', 50, '/kaggle/working/flowers_eval', False)

Starting Inference Loop...
Saving Images...
Done!


True

In [98]:
dataset_path = '/kaggle/input/flower-classification-5-classes-roselilyetc/Flower Classification/Flower Classification'
generated_path = '/kaggle/working/flowers_eval'
real_images = load_images_as_tensors(sample_images_from_directories(dataset_path, 500))
generated_images = load_images_as_tensors([os.path.join(generated_path, x) for x in os.listdir(generated_path)])

In [99]:
real_images = torch.stack([transform(img) for img in real_images])
generated_images = torch.stack([transform(img) for img in generated_images])

real_dataset = CustomImageDataset(real_images)
generated_dataset = CustomImageDataset(generated_images)

real_loader = DataLoader(real_dataset, batch_size=32, shuffle=False)
generated_loader = DataLoader(generated_dataset, batch_size=32, shuffle=False)

# Extract features
real_features = get_features(real_loader, inception_model, device)
generated_features = get_features(generated_loader, inception_model, device)

# Calculate FID
fid_score = calculate_fid(real_features, generated_features)
print(f"CGAN Flowers FID Score: {fid_score}")

CGAN Flowers FID Score: 30.6876865


In [100]:
shutil.rmtree('/kaggle/working/mnist_eval')

In [101]:
eval(10, '/kaggle/input/cgan-training/cgan_mnist.pt', 100, '/kaggle/working/mnist_eval', False)

Starting Inference Loop...
Saving Images...
Done!


True

In [102]:
dataset_path = '/kaggle/input/mnistasjpg/trainingSet/trainingSet'
generated_path = '/kaggle/working/mnist_eval'
real_images = load_images_as_tensors(sample_images_from_directories(dataset_path, 1000))
generated_images = load_images_as_tensors([os.path.join(generated_path, x) for x in os.listdir(generated_path)])

In [103]:
real_images = torch.stack([transform(img) for img in real_images])
generated_images = torch.stack([transform(img) for img in generated_images])

real_dataset = CustomImageDataset(real_images)
generated_dataset = CustomImageDataset(generated_images)

real_loader = DataLoader(real_dataset, batch_size=32, shuffle=False)
generated_loader = DataLoader(generated_dataset, batch_size=32, shuffle=False)

# Extract features
real_features = get_features(real_loader, inception_model, device)
generated_features = get_features(generated_loader, inception_model, device)

# Calculate FID
fid_score = calculate_fid(real_features, generated_features)
print(f"CGAN MNIST FID Score: {fid_score}")

CGAN MNIST FID Score: 2.4678997
