Written By Devorah Rotman 316472026 and Carmit Kaye 346038169



In [None]:
from tqdm import tqdm
import numpy as np

import warnings
import os
from scipy import linalg
import torch
from torchvision import models, transforms
from STGAN.Data_loader_STGAN import get_content_loader, get_style_loader
from STGAN.models_STGAN import Generator
from torchvision.utils import save_image
from skimage.exposure import match_histograms
import warnings
from sklearn.metrics.pairwise import cosine_similarity

warnings.filterwarnings("ignore")


In [2]:
def unnormalize_tanh(img_tensor):
    return (img_tensor + 1) / 2

def unnormalize_vgg(img):

    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1,3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1,3, 1, 1)
    return torch.clamp(img * std + mean, 0, 1)

def tensor_to_numpy(img):
    return img.permute(1, 2, 0).detach().cpu().numpy()

def numpy_to_tensor(img_np):
    return torch.tensor(img_np).permute(2, 0, 1).float()

def gamma_correct(img, gamma=1.2):
    return img.pow(gamma).clamp(0, 1)

def match_to_content(generated_img, content_img):
    """
    Match histogram of generated image to content image.
    Both tensors should be [3, H, W], values in [0, 1]
    """
    gen_np = generated_img.permute(1, 2, 0).detach().cpu().numpy()
    con_np = content_img.permute(1, 2, 0).detach().cpu().numpy()

    matched_np = match_histograms(gen_np, con_np, channel_axis=-1)
    matched = torch.tensor(matched_np).permute(2, 0, 1).float().to(generated_img.device)
    return matched.clamp(0, 1)


class inference_and_metric():
    def __init__(self, content_path, style_path, model_path, device, in_style, channels, batch_size, generated_directory_path):
        self.content_loader = get_content_loader(content_path, batch_size)
        self.style_loader = get_style_loader(style_path, batch_size)
        self.model_path = model_path
        self.device = device
        #models - vgg, generator, feature extractor
        self.gen = Generator(in_style, channels).to(device)
        self.load_trained_model()
        self.vgg = models.vgg19(pretrained=True).features.to(self.device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        self.feature_extractor =  models.inception_v3(pretrained=True, transform_input=False).to(device)
        self.feature_extractor.fc = torch.nn.Identity()
        self.save_dir = generated_directory_path


    def load_trained_model(self):
        ckpt = torch.load(self.model_path, map_location=self.device)
        for i in range(ckpt['grow_rank']):
            self.gen.grow()
            self.gen.cuda()
        self.gen.load_state_dict(ckpt['generator'])
        for param in self.gen.parameters():
            param.requires_grad = False
        del ckpt

    def get_activations_from_model(self):
        """Generate fake images from model and compute their Inception features."""
        self.vgg.eval()
        self.gen.eval()
        self.feature_extractor.eval()
        activations = []
        os.makedirs(self.save_dir, exist_ok=True) if self.save_dir is not None else None
        counter = 0
        transform = transforms.ColorJitter(contrast=0.5)

        with torch.no_grad():
            for inputs, _ in tqdm(self.content_loader):
                inputs = inputs.to(self.device)

                # Infer fake images using cGAN
                feats = self.vgg(inputs)
                generated = self.gen(feats)# assumes output shape [B, 3, H, W]

                #possibly save the generated images
                if self.save_dir is not None:
                    for i, img in enumerate(generated):
                        save_path = os.path.join(self.save_dir, f"{counter:05}.png")
                        #orig = inputs[i]
                        img = unnormalize_tanh(img)**1.8
                        ##img = transform(img)
                        #orig = unnormalize_vgg(orig)
                        #matched = match_to_content(img, orig)
                        #img = contrast_stretch(img)
                        #img = gamma_correct(img)
                        save_image(img.clamp( 0, 1), save_path)
                        #save_image(matched.clamp(0, 1), save_path)
                        #save_image(matched, save_path)
                        counter += 1
                #get features for FID
                if generated.shape[-1] != 299:
                    generated = torch.nn.functional.interpolate(generated, size=(299, 299), mode='bilinear', align_corners=False)

                feats = self.feature_extractor(generated)
                activations.append(feats.cpu().numpy())

        return np.concatenate(activations, axis=0)

    def get_activations_from_real(self):
        """Compute Inception features for real images (e.g., style images)."""
        self.feature_extractor.eval()
        activations = []

        with torch.no_grad():
            for images, _ in tqdm(self.style_loader):
                images = images.to(self.device)
                if images.shape[-1] != 299:
                    images = torch.nn.functional.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)

                feats = self.feature_extractor(images)
                activations.append(feats.cpu().numpy())

        return np.concatenate(activations, axis=0)

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2):
        """Standard FID (MiFID) computation."""
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)

    def compute_mifid(self, generated_feats, real_feats):
        """
        Compute Memorization-Informed FID (MiFID).
        Args:
            generated_feats (np.ndarray): shape (N_gen, 2048)
            real_feats (np.ndarray): shape (N_real, 2048)
        Returns:
            float: MiFID score
        """
        similarities = cosine_similarity(generated_feats, real_feats)  # Shape: [N_gen, N_real]
        nearest_indices = np.argmax(similarities, axis=1)              # Most similar real image for each generated one
        matched_real_feats = real_feats[nearest_indices]               # Shape: [N_gen, 2048]

        mu_gen = np.mean(generated_feats, axis=0)
        sigma_gen = np.cov(generated_feats, rowvar=False)

        mu_real = np.mean(matched_real_feats, axis=0)
        sigma_real = np.cov(matched_real_feats, rowvar=False)

        return self.calculate_frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)

    def compute_mifid_from_dataloaders(self):
        # InceptionV3 up to pool3 (2048D)

        print("Extracting features from generated images...")
        gen_acts = self.get_activations_from_model()
        print("Extracting features from real style images...")
        style_acts = self.get_activations_from_real()


        mifid_score = self.compute_mifid(gen_acts, style_acts)
        print("MiFID score:", mifid_score)
        return mifid_score






In [3]:
batch_size = 32
style_path = 'data/monet_jpg'
content_path = 'data/photo_jpg'
#model_path = 'STGAN/Samples/real_run_1/model.pt'
model_path = 'STGAN/checkpoints/last.pt'
in_style = 512 #size of w
channels = [512, 512, 512, 256, 128, 64, 32] # layer channel sizes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#generated_directory_path = 'STGAN/Samples/real_run_1/generated'
generated_directory_path = None

inferer = inference_and_metric(content_path, style_path, model_path, device, in_style, channels, batch_size, generated_directory_path)

inferer.compute_mifid_from_dataloaders()

Extracting features from generated images...


100%|██████████| 219/219 [00:54<00:00,  3.98it/s]


Extracting features from real style images...


100%|██████████| 9/9 [00:05<00:00,  1.70it/s]


MiFID score: 109.8110748614309


np.float64(109.8110748614309)