<a href="https://colab.research.google.com/github/RohithThiru/AI-projects/blob/main/Quality-of%20GAN-Generated-Images-Using-FID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 1: installs & imports
!pip install -q tensorflow==2.12.0 tensorflow-hub==0.14.0 pillow matplotlib tqdm

import os, io, zipfile, math, random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input as inception_preprocess
from tensorflow.keras.preprocessing import image as kimage
from google.colab import files

print("TensorFlow version:", tf.__version__)


In [None]:
# Cell 2: utilities
def load_image_paths_from_dir(dir_path, exts=("png","jpg","jpeg")):
    """Recursively collect image file paths under ``dir_path``.

    Args:
        dir_path: Directory to walk (sub-directories are included).
        exts: Accepted file extensions, with or without a leading dot.
            Matching is case-insensitive. A single string is treated as
            one extension.

    Returns:
        A sorted list of paths whose extension is in ``exts``.
    """
    if isinstance(exts, str):
        exts = (exts,)
    # Normalise to the ".ext" form returned by os.path.splitext.
    wanted = {"." + e.lower().lstrip(".") for e in exts}

    # Named `found` (not `files`) so it cannot be confused with the
    # google.colab `files` module imported at the top of the notebook.
    found = []
    for root, _, fnames in os.walk(dir_path):
        for fname in fnames:
            # splitext only reports a real extension, so a file literally
            # named "png" (no dot) is not mistaken for an image.
            if os.path.splitext(fname)[1].lower() in wanted:
                found.append(os.path.join(root, fname))
    found.sort()
    return found

def load_and_preprocess(img_path, target_size=(299,299)):
    """Load one image as an RGB float32 array resized to ``target_size``.

    Args:
        img_path: Path of the image file to read.
        target_size: Size tuple passed straight to ``Image.resize``.

    Returns:
        A float32 NumPy array of the resized RGB image. Pixel values are
        not rescaled here; they keep the 0-255 range of the decoded image.
    """
    # The context manager closes the underlying file deterministically;
    # convert() yields an independent image that stays valid afterwards.
    with Image.open(img_path) as src:
        rgb = src.convert("RGB")
    rgb = rgb.resize(target_size, Image.BICUBIC)
    return np.asarray(rgb).astype("float32")

def batch_load_images(paths, target_size=(299,299), max_images=None):
    """Load several images into one stacked array.

    Args:
        paths: Iterable of image file paths.
        target_size: Size forwarded to ``load_and_preprocess`` for each image.
        max_images: If not ``None``, only the first ``max_images`` paths are
            loaded. ``0`` is honoured as "zero images" rather than being
            treated as "no limit".

    Returns:
        The per-image arrays stacked along a new leading axis.

    Raises:
        ValueError: If there is nothing to load. ``np.stack`` would raise the
            same exception type, but with a message that does not name the
            cause.
    """
    paths = list(paths)
    if max_images is not None:
        paths = paths[:max_images]
    if not paths:
        raise ValueError("No images to load: the list of image paths is empty.")
    imgs = [load_and_preprocess(p, target_size) for p in paths]
    return np.stack(imgs, axis=0)

def show_grid(real_imgs, gen_imgs, n=8, title=""):
    """Show the first ``n`` real images above the first ``n`` generated ones.

    Args:
        real_imgs: Batch of real images, values in [0, 255] or [0, 1].
        gen_imgs: Batch of generated images, same conventions.
        n: Number of columns requested. It is reduced to the number of images
            actually available so short batches no longer raise IndexError.
        title: Figure-level title.
    """
    def norm(x, count):
        # The [0,255]-vs-[0,1] decision is still taken on the whole batch,
        # but only the images that will be drawn are clipped and rescaled,
        # which avoids full-size temporaries for large batches.
        x = np.asarray(x)
        needs_scaling = x.size > 0 and x.max() > 1.0
        shown = np.clip(x[:count], 0, 255)
        if needs_scaling:
            shown = shown / 255.0
        return shown

    n = max(0, min(n, len(real_imgs), len(gen_imgs)))
    real_imgs = norm(real_imgs, n)
    gen_imgs = norm(gen_imgs, n)
    plt.figure(figsize=(16, 4))
    for i in range(n):
        # top row: real samples
        plt.subplot(2, n, i+1)
        plt.imshow(real_imgs[i])
        plt.axis("off")
        if i == 0:
            plt.title("Real")
        # bottom row: generated samples
        plt.subplot(2, n, n+i+1)
        plt.imshow(gen_imgs[i])
        plt.axis("off")
        if i == 0:
            plt.title("Generated")
    plt.suptitle(title)
    plt.show()


In [None]:
# Cell 3: FID computation utilities
from scipy import linalg

# InceptionV3 feature extractor used for FID: classifier head removed,
# global average pooling applied to the final feature map.
def get_inception_model():
    """Build the ImageNet-pretrained InceptionV3 used to embed images.

    The network is created without its top classification layers and with
    average pooling, for 299x299 RGB inputs.
    """
    return InceptionV3(
        weights='imagenet',
        include_top=False,
        pooling='avg',
        input_shape=(299, 299, 3),
    )


# Built once at cell execution and shared by get_activations().
inception_model = get_inception_model()

def get_activations(images, batch_size=32):
    """Return InceptionV3 pooled features for a batch of images.

    Args:
        images: Array of shape (N, H, W, 3) with values in [0, 255] or
            [0, 1]. The caller's array is never modified.
        batch_size: Batch size forwarded to ``inception_model.predict``.

    Returns:
        The array produced by ``inception_model.predict`` for the
        preprocessed images.

    Raises:
        ValueError: If ``images`` contains no elements.
    """
    x = np.asarray(images)
    if x.size == 0:
        raise ValueError("get_activations received an empty image array.")
    # astype() already returns a fresh array by default, so a separate
    # .copy() would only double peak memory on an already large tensor.
    # NOTE(review): the private copy matters because Keras preprocess_input
    # may overwrite float inputs in place - confirm against the Keras docs.
    x = x.astype("float32")
    if x.max() <= 1.0:
        # Inputs given in [0,1] are brought back to [0,255] first.
        x *= 255.0
    # preprocess as Inception expects
    x = inception_preprocess(x)
    return inception_model.predict(x, batch_size=batch_size, verbose=0)

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2))``.

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Diagonal offset added to both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: If the mean vectors or the covariance matrices do not
            have matching shapes.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    # Fail loudly instead of letting NumPy broadcast mismatched moments.
    if mu1.shape != mu2.shape:
        raise ValueError(
            f"Mean vectors have different shapes: {mu1.shape} vs {mu2.shape}"
        )
    if sigma1.shape != sigma2.shape:
        raise ValueError(
            f"Covariance matrices have different shapes: {sigma1.shape} vs {sigma2.shape}"
        )

    diff = mu1 - mu2

    # Both sqrtm calls now use the same plain form, which returns only the
    # matrix. NOTE(review): `disp=` appears to be deprecated in recent SciPy
    # releases - confirm against the SciPy release notes.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: nudge both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # sqrtm may return a complex array; only the real part is used.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    fid = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2*tr_covmean
    return float(fid)

def compute_fid_from_images(real_images, gen_images, batch_size=32):
    """Compute FID between a set of real and a set of generated images.

    Args:
        real_images: Batch of real images, as accepted by ``get_activations``.
        gen_images: Batch of generated images, same conventions.
        batch_size: Batch size used when extracting Inception features.

    Returns:
        The Frechet distance between the Gaussians fitted to the two sets
        of activations.

    Raises:
        ValueError: If either set has fewer than 2 images, because a
            covariance cannot be estimated from a single sample.
    """
    # Checked up front so the expensive Inception pass is not wasted.
    for label, batch in (("real", real_images), ("generated", gen_images)):
        if len(batch) < 2:
            raise ValueError(
                f"FID needs at least 2 {label} images to estimate a covariance, "
                f"got {len(batch)}."
            )

    acts1 = get_activations(real_images, batch_size=batch_size)
    acts2 = get_activations(gen_images, batch_size=batch_size)
    # rowvar=False: each row is one image, each column one feature.
    mu1, sigma1 = acts1.mean(axis=0), np.cov(acts1, rowvar=False)
    mu2, sigma2 = acts2.mean(axis=0), np.cov(acts2, rowvar=False)
    return calculate_frechet_distance(mu1, sigma1, mu2, sigma2)


In [None]:
# Cell 4: run configuration and real reference data
#
# MODE selects where the generated images come from:
#   "UPLOAD" -> you upload a zip of generated images (recommended)
#   "DEMO"   -> the notebook trains a tiny DCGAN to produce samples (demo only)
MODE = "UPLOAD"  # change to "DEMO" to run demo generator training

# Real dataset to compare against. Only "CIFAR10" is handled below;
# any other value is rejected.
REAL_DATASET = "CIFAR10"

if REAL_DATASET != "CIFAR10":
    raise ValueError("Unsupported dataset")

# The full CIFAR-10 training split is kept here as float32; the cells that
# follow draw their own random subset from it.
(xtr, _), (xte, _) = tf.keras.datasets.cifar10.load_data()
real_images_full = xtr.astype("float32")
print("CIFAR10 loaded, total real images available:", real_images_full.shape[0])


In [None]:
# Cell 5A: Upload generated images (zipped) - run if MODE=="UPLOAD"
#
# Produces `gen_images` (float32, 299x299) and `real_resized` (299x299),
# each holding N_GEN images, for the FID cell.
if MODE != "UPLOAD":
    print("Skipping upload cell (MODE != UPLOAD)")
else:
    import shutil

    print("Please upload a zip file of generated images (images inside root of zip).")
    uploaded = files.upload()  # uses browser file chooser
    # use the first uploaded file that looks like a zip archive
    zip_path = None
    for fn in uploaded:
        if fn.lower().endswith(".zip"):
            zip_path = fn
            break
    if zip_path is None:
        raise RuntimeError("Please upload a zip file of generated images (.zip).")
    extract_dir = "/content/gen_images"
    # Start from an empty directory: without this, images extracted by an
    # earlier run of this cell would be mixed into the new evaluation set.
    shutil.rmtree(extract_dir, ignore_errors=True)
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(extract_dir)
    print("Extracted to", extract_dir)
    gen_paths = load_image_paths_from_dir(extract_dir)
    print("Found generated images:", len(gen_paths))
    if not gen_paths:
        # Fail here with a clear message instead of deep inside np.stack.
        raise RuntimeError("No .png/.jpg/.jpeg images were found in the uploaded zip.")
    # load a subset (or all if small)
    N_GEN = min(len(gen_paths), 5000)
    gen_images = batch_load_images(gen_paths, target_size=(299,299), max_images=N_GEN)
    # sample same number of real images randomly (without replacement)
    real_sample_indices = np.random.choice(real_images_full.shape[0], N_GEN, replace=False)
    real_subset = real_images_full[real_sample_indices]
    # resize real images to 299x299 so both sets match the Inception input size
    real_resized = np.stack([np.asarray(Image.fromarray(img.astype('uint8')).resize((299,299), Image.BICUBIC)) for img in real_subset], axis=0)
    print("Prepared real and generated arrays:", real_resized.shape, gen_images.shape)


In [None]:
# Cell 5B: Quick demo GAN training to generate sample images (MODE == "DEMO")
#
# Trains a small generator/discriminator pair on the CIFAR10 images loaded
# earlier, samples N_GEN images from the generator, and prepares
# `gen_images` / `real_resized` (both resized to 299x299) for the FID cell.
if MODE != "DEMO":
    print("Skipping demo GAN training (MODE != DEMO)")
else:
    # Simple DCGAN on CIFAR10 -- very small and fast for demonstration only
    import tensorflow.keras as keras
    IMG_SHAPE = (32,32,3)
    LATENT_DIM = 100

    # prepare CIFAR images scaled [-1,1] (matches the generator's tanh output)
    x_real = (real_images_full / 127.5) - 1.0
    dataset = tf.data.Dataset.from_tensor_slices(x_real).shuffle(10000).batch(128)

    # generator
    def build_generator():
        """Build the generator: latent vector -> image with tanh output."""
        # Dense output is reshaped to an 8x8x128 map, then upsampled twice.
        # NOTE(review): presumably each UpSampling2D doubles the spatial size
        # (8 -> 16 -> 32) to match IMG_SHAPE - confirm the layer default.
        model = keras.Sequential([
            keras.layers.Dense(8*8*128, input_dim=LATENT_DIM),
            keras.layers.Reshape((8,8,128)),
            keras.layers.BatchNormalization(),
            keras.layers.UpSampling2D(),
            keras.layers.Conv2D(128, kernel_size=3, padding="same", activation="relu"),
            keras.layers.BatchNormalization(),
            keras.layers.UpSampling2D(),
            keras.layers.Conv2D(64, kernel_size=3, padding="same", activation="relu"),
            keras.layers.Conv2D(3, kernel_size=3, padding="same", activation="tanh")
        ])
        return model

    # discriminator
    def build_discriminator():
        """Build the discriminator: IMG_SHAPE image -> single sigmoid output."""
        model = keras.Sequential([
            keras.layers.Conv2D(64, 3, strides=2, input_shape=IMG_SHAPE, padding="same"),
            keras.layers.LeakyReLU(0.2),
            keras.layers.Dropout(0.3),
            keras.layers.Conv2D(128, 3, strides=2, padding="same"),
            keras.layers.LeakyReLU(0.2),
            keras.layers.Dropout(0.3),
            keras.layers.Flatten(),
            keras.layers.Dense(1, activation='sigmoid')
        ])
        return model

    gen = build_generator()
    disc = build_discriminator()
    # Loss is built with its default arguments; the discriminator already
    # applies a sigmoid, so its outputs are probabilities.
    bce = keras.losses.BinaryCrossentropy()
    # NOTE(review): positional args are presumably learning rate and beta_1 - confirm.
    gen_opt = keras.optimizers.Adam(0.0002, 0.5)
    disc_opt = keras.optimizers.Adam(0.0002, 0.5)

    # training loop (very short)
    EPOCHS = 6
    for epoch in range(EPOCHS):
        for real_batch in dataset.take(300):  # limit batches so it's quick
            bs = real_batch.shape[0]
            noise = tf.random.normal((bs, LATENT_DIM))
            # Generated outside the tape: the discriminator step only takes
            # gradients w.r.t. the discriminator's variables.
            fake = gen(noise, training=True)
            # train disc
            with tf.GradientTape() as tape:
                # Named "*_logits" but these are sigmoid outputs of `disc`.
                real_logits = disc(real_batch, training=True)
                fake_logits = disc(fake, training=True)
                # real images are labelled 1, generated images 0
                d_loss = bce(tf.ones_like(real_logits), real_logits) + bce(tf.zeros_like(fake_logits), fake_logits)
            grads = tape.gradient(d_loss, disc.trainable_variables)
            disc_opt.apply_gradients(zip(grads, disc.trainable_variables))
            # train gen
            noise = tf.random.normal((bs, LATENT_DIM))
            with tf.GradientTape() as tape2:
                gen_imgs = gen(noise, training=True)
                logits = disc(gen_imgs, training=True)
                # generator is rewarded when the discriminator outputs 1 for fakes
                g_loss = bce(tf.ones_like(logits), logits)
            # only the generator's variables are updated in this step
            grads2 = tape2.gradient(g_loss, gen.trainable_variables)
            gen_opt.apply_gradients(zip(grads2, gen.trainable_variables))
        print(f"Epoch {epoch+1}/{EPOCHS} done")

    # generate samples
    N_GEN = 1000
    noise = tf.random.normal((N_GEN, LATENT_DIM))
    gen_out = gen.predict(noise, batch_size=64)
    # gen_out in [-1,1] -> convert to [0,255] uint8
    gen_out = ((gen_out + 1.0) * 127.5).astype('uint8')
    # resize to 299x299 for Inception
    gen_images = np.stack([np.asarray(Image.fromarray(img).resize((299,299), Image.BICUBIC)) for img in gen_out], axis=0)
    # sample same number of real images (random, without replacement)
    real_subset_idx = np.random.choice(real_images_full.shape[0], N_GEN, replace=False)
    real_subset = real_images_full[real_subset_idx]
    # real images get the same 299x299 bicubic resize as the generated ones
    real_resized = np.stack([np.asarray(Image.fromarray(img.astype('uint8')).resize((299,299), Image.BICUBIC)) for img in real_subset], axis=0)
    print("Generated demo images:", gen_images.shape)


In [None]:

# Cell 6: choose the evaluation subset, preview it, and compute FID
# An explicit raise is used instead of `assert`, which is skipped under -O.
if 'gen_images' not in globals() or 'real_resized' not in globals():
    raise RuntimeError("gen_images or real_resized not prepared. Run upload or demo steps.")

N = min(real_resized.shape[0], gen_images.shape[0], 2000)
print("Using N images per set for FID:", N)
if N < 2:
    raise RuntimeError("Need at least 2 images per set to compute FID.")

# Slice first, then convert: only the N images actually evaluated are turned
# into float32, instead of the full arrays (the full-size float copies
# `gen_images_arr` / `real_images_arr` are therefore no longer created).
real_subset_small = real_resized[:N].astype("float32")
gen_subset_small = gen_images[:N].astype("float32")

# Only the previewed images are rescaled, and the preview never asks for
# more images than are available.
n_preview = min(8, N)
show_grid(real_subset_small[:n_preview]/255.0, gen_subset_small[:n_preview]/255.0, n=n_preview, title="Real vs Generated samples (first 8)")


print("Computing FID (this may take some time)...")
fid_value = compute_fid_from_images(real_subset_small, gen_subset_small, batch_size=32)
print("FID:", fid_value)


In [None]:
# Cell 7: write a three-page PDF report (summary, sample grid, conclusions)
# Reads MODE, REAL_DATASET, N, fid_value and the evaluated image subsets
# defined by the earlier cells.
pdf_path = "fid_report.pdf"
with PdfPages(pdf_path) as pdf:
    # Page 1: summary text (A4 portrait, sizes in inches)
    fig = plt.figure(figsize=(8.27, 11.69))
    plt.axis("off")
    plt.text(0.01, 0.92, "FID Evaluation Report", fontsize=18, weight='bold')
    plt.text(0.01, 0.86, f"Mode: {MODE}    Real dataset: {REAL_DATASET}", fontsize=10)
    # "Fréchet" was previously stored as mis-decoded bytes and rendered as garbage.
    plt.text(0.01, 0.80, "Metric: Fréchet Inception Distance (FID)\n\nFID measures the distance between multivariate Gaussians fit to Inception features of real and generated images. Lower is better (closer to real data distribution).", fontsize=10)
    plt.text(0.01, 0.55, f"Computed FID (N={N}): {fid_value:.4f}", fontsize=12)
    plt.text(0.01, 0.45, "Notes:\n- Use at least ~1k-5k images for stable FID.\n- This report contains sample visuals and the computed FID.", fontsize=10)
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)

    # Page 2: real (top row) vs generated (bottom row) samples
    fig = plt.figure(figsize=(11,5))
    # Never index past the end when fewer than 8 images were evaluated.
    M = min(8, len(real_subset_small), len(gen_subset_small))
    for i in range(M):
        plt.subplot(2, M, i+1)
        plt.imshow(real_subset_small[i].astype('uint8'))
        plt.axis("off")
        if i == 0:
            plt.title("Real")
        plt.subplot(2, M, M+i+1)
        plt.imshow(gen_subset_small[i].astype('uint8'))
        plt.axis("off")
        if i == 0:
            plt.title("Generated")
    plt.suptitle("Real vs Generated samples (first 8)")
    plt.tight_layout()
    pdf.savefig(fig, bbox_inches='tight', dpi=150)
    plt.close(fig)

    # Page 3: conclusions
    fig = plt.figure(figsize=(8.27, 11.69))
    plt.axis("off")
    conclusion = (
        f"Conclusions & interpretation\n\n- Computed FID: {fid_value:.4f}\n\n"
        "- Interpretation: lower FID indicates generated images are closer to the real distribution in Inception feature space.\n"
        "- Use-case notes: for serious evaluation, compute FID with ~5k images and multiple runs.\n"
        "- Visual inspection: check the visuals page for obvious artifacts, mode collapse, or diversity issues.\n"
    )
    plt.text(0.01, 0.95, "Conclusions", fontsize=14, weight='bold')
    plt.text(0.01, 0.6, conclusion, fontsize=10)
    pdf.savefig(fig, bbox_inches='tight')
    plt.close(fig)

print("Saved report to", pdf_path)
