# In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from scipy.linalg import sqrtm


import lpips
import torch

# Load the LPIPS perceptual-distance model (AlexNet backbone by default;
# 'vgg' and 'squeeze' are the other options). Use the GPU when CUDA is
# available and fall back to the CPU otherwise, so the module also loads on
# machines without a GPU.
lpips_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lpips_model = lpips.LPIPS(net='alex')  # or 'vgg', 'squeeze'
lpips_model = lpips_model.eval().to(lpips_device)

# Load InceptionV3 for FID: no classification head (include_top=False) and
# global average pooling, so predict() yields one feature vector per image.
inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))

def preprocess_for_fid(img_batch):
    """Prepare an image batch for the InceptionV3 FID feature extractor.

    Args:
        img_batch: batch of channels-last images, assumed to hold values in
            [0, 1] (the callers in this file clip to that range before
            extracting features) — TODO confirm for any other caller.

    Returns:
        The batch resized to 299x299, scaled to the [0, 255] pixel range and
        passed through InceptionV3's ``preprocess_input``.
    """
    resized = tf.image.resize(img_batch, (299, 299))
    # preprocess_input expects 0-255 pixel values, so undo the [0, 1] scaling.
    pixel_range = resized * 255.0
    return preprocess_input(pixel_range)

def get_activations(images):
    """Return the InceptionV3 pooled features for a batch of images.

    The batch is resized and normalized with ``preprocess_for_fid`` and then
    run through the module-level ``inception_model`` without progress output.

    Args:
        images: batch of channels-last images, assumed in [0, 1].

    Returns:
        The array produced by ``inception_model.predict``, with one row of
        features per input image.
    """
    return inception_model.predict(preprocess_for_fid(images), verbose=0)

def calculate_fid(act1, act2):
    """Return the Frechet Inception Distance between two sets of activations.

    Each set is modelled as a Gaussian with mean ``mu`` and covariance
    ``sigma``, and the distance is

        ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr((sigma1 @ sigma2)^(1/2)).

    The trace of the matrix square root is computed from the eigenvalues of
    ``sigma1^(1/2) @ sigma2 @ sigma1^(1/2)``. That matrix has the same
    non-zero eigenvalues as ``sigma1 @ sigma2``, but it is symmetric positive
    semi-definite. The result therefore stays finite and accurate when the
    covariances are singular (e.g. fewer samples than feature dimensions). In
    that regime a general ``sqrtm`` can return NaN/inf or large errors.

    Args:
        act1: array-like of shape (n_samples1, n_features).
        act2: array-like of shape (n_samples2, n_features).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if an input is not 2-D, the feature counts differ, or a
            set has fewer than two samples (its covariance is undefined).
    """
    act1 = np.asarray(act1, dtype=np.float64)
    act2 = np.asarray(act2, dtype=np.float64)
    if act1.ndim != 2 or act2.ndim != 2:
        raise ValueError("activations must be 2-D arrays of shape (n_samples, n_features)")
    if act1.shape[1] != act2.shape[1]:
        raise ValueError(
            f"feature dimensions differ: {act1.shape[1]} vs {act2.shape[1]}"
        )
    if act1.shape[0] < 2 or act2.shape[0] < 2:
        raise ValueError("at least two samples per set are needed to estimate a covariance")

    mu1, mu2 = act1.mean(axis=0), act2.mean(axis=0)
    # atleast_2d keeps single-feature inputs (np.cov returns 0-d) usable as matrices.
    sigma1 = np.atleast_2d(np.cov(act1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(act2, rowvar=False))

    diff = mu1 - mu2
    ssdiff = float(diff @ diff)

    # Symmetric PSD square root of sigma1; tiny negative eigenvalues caused by
    # rounding are clipped to zero.
    evals1, evecs1 = np.linalg.eigh(sigma1)
    sqrt_sigma1 = (evecs1 * np.sqrt(np.clip(evals1, 0.0, None))) @ evecs1.T

    inner = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    inner = (inner + inner.T) / 2.0  # drop asymmetric rounding noise before eigvalsh
    tr_covmean = float(np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum())

    return ssdiff + float(np.trace(sigma1)) + float(np.trace(sigma2)) - 2.0 * tr_covmean

def compute_psnr_ssim(generator, val_dataset):
    """Return the mean PSNR and mean SSIM of a generator over a dataset.

    Both the generator output and the HR target are clipped to [0, 1] and
    compared with ``max_val=1.0``. Per-image scores from every batch are
    pooled before averaging.

    Args:
        generator: callable taking ``(lr, training=False)`` and returning a
            batch of super-resolved images.
        val_dataset: iterable of ``(lr, hr)`` batches.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)``.
    """
    psnr_values = []
    ssim_values = []
    for low_res, high_res in val_dataset:
        prediction = tf.clip_by_value(generator(low_res, training=False), 0.0, 1.0)
        reference = tf.clip_by_value(high_res, 0.0, 1.0)

        batch_psnr = tf.image.psnr(reference, prediction, max_val=1.0).numpy()
        batch_ssim = tf.image.ssim(reference, prediction, max_val=1.0).numpy()

        psnr_values.extend(batch_psnr.flatten())
        ssim_values.extend(batch_ssim.flatten())

    mean_psnr = np.mean(psnr_values)
    mean_ssim = np.mean(ssim_values)
    return mean_psnr, mean_ssim

def compute_fid(generator, val_dataset, max_batches=50):
    """Compute FID between HR targets and generator outputs.

    Every image in each of the first ``max_batches`` batches contributes to
    the statistics. Features are extracted one batch at a time, so only a
    single batch of 299x299 resized images is held in memory at once.

    Args:
        generator: callable taking ``(lr, training=False)`` and returning a
            batch of super-resolved images.
        val_dataset: iterable of ``(lr, hr)`` batches.
        max_batches: maximum number of batches to evaluate.

    Returns:
        The FID value returned by ``calculate_fid``.

    Raises:
        ValueError: if ``val_dataset`` yields no batches.
    """
    real_activations, fake_activations = [], []
    for i, (lr, hr) in enumerate(val_dataset):
        if i >= max_batches:
            break
        sr = generator(lr, training=False)
        sr = tf.clip_by_value(sr, 0.0, 1.0)
        hr = tf.clip_by_value(hr, 0.0, 1.0)
        real_activations.append(get_activations(hr))
        fake_activations.append(get_activations(sr))

    if not real_activations:
        raise ValueError("compute_fid: val_dataset yielded no batches")

    act1 = np.concatenate(real_activations, axis=0)
    act2 = np.concatenate(fake_activations, axis=0)
    return calculate_fid(act1, act2)

def compute_lpips(generator, val_dataset, max_batches=50):
    """Return the average LPIPS distance between generator outputs and HR targets.

    Every image in each of the first ``max_batches`` batches is scored. The
    mean over all scored images is printed and returned.

    Args:
        generator: callable taking ``(lr, training=False)`` and returning a
            batch of channels-last (NHWC) images; they are clipped to [0, 1].
        val_dataset: iterable of ``(lr, hr)`` batches.
        max_batches: maximum number of batches to evaluate.

    Returns:
        The mean LPIPS distance as a float.

    Raises:
        ValueError: if ``val_dataset`` yields no batches.
    """
    # Send inputs to whatever device the LPIPS model lives on, so this works
    # whether the model was placed on the GPU or the CPU.
    device = next(lpips_model.parameters()).device

    lpips_scores = []
    for i, (lr, hr) in enumerate(val_dataset):
        if i >= max_batches:
            break

        sr = generator(lr, training=False)
        sr = tf.clip_by_value(sr, 0.0, 1.0)
        hr = tf.clip_by_value(hr, 0.0, 1.0)

        # TensorFlow batches are NHWC; convert to NCHW PyTorch tensors.
        sr_np = sr.numpy().transpose(0, 3, 1, 2)
        hr_np = hr.numpy().transpose(0, 3, 1, 2)
        sr_tensor = torch.from_numpy(sr_np).float().to(device)
        hr_tensor = torch.from_numpy(hr_np).float().to(device)

        # Normalize from [0, 1] to [-1, 1] for LPIPS
        sr_tensor = sr_tensor * 2 - 1
        hr_tensor = hr_tensor * 2 - 1

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            distances = lpips_model(sr_tensor, hr_tensor)
        # Flatten so each image pair's distance is recorded regardless of any
        # singleton dimensions in the model output.
        lpips_scores.extend(distances.flatten().tolist())

    if not lpips_scores:
        raise ValueError("compute_lpips: val_dataset yielded no batches")

    avg_lpips = sum(lpips_scores) / len(lpips_scores)
    print(f"Average LPIPS: {avg_lpips:.4f}")
    return avg_lpips


if __name__ == "__main__":
    # Load your models here (srgan, srwgan and dataset are project-local modules)
    from srgan import Generator  # example import
    from srwgan import GeneratorSRWGAN  # if custom object needed

    wgan = load_model("checkpoints/wgan_generator_epoch_200.h5")
    srgan = Generator(10)
    srgan.load_weights("checkpoints/srgan_generator_epoch_200.h5")
    srwgan = load_model("checkpoints/srwgan_generator_epoch_200.h5", custom_objects={'GeneratorSRWGAN': GeneratorSRWGAN})

    from dataset import val_dataset  # Ensure this exists and is preprocessed
    # Materialize once so every model and metric sees the same batches, even if
    # the dataset pipeline would reshuffle when iterated again.
    val_batches = list(val_dataset)
    for name, model in zip(["WGAN", "SRGAN", "SRWGAN"], [wgan, srgan, srwgan]):
        psnr_value, ssim_value = compute_psnr_ssim(model, val_batches)
        fid_value = compute_fid(model, val_batches)
        # Distinct name so the result does not rebind the imported ``lpips`` module.
        lpips_value = compute_lpips(model, val_batches)
        print(f"{name}: PSNR={psnr_value:.4f}, SSIM={ssim_value:.4f}, FID={fid_value:.4f}, LPIPS={lpips_value:.4f}")
