In [1]:
import os
import numpy as np
import sklearn.metrics
import sklearn
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vgg16
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.preprocessing import normalize

In [None]:
class RandomEmbeddingR64(nn.Module):
    """Randomly initialised VGG16-based network mapping images to 64-d embeddings.

    Used as a fixed, untrained feature extractor for the PRDC metrics: the
    convolutional trunk of VGG16 is followed by an MLP head whose last ``Linear``
    outputs 64 features. Every conv/linear weight is re-drawn in
    ``_initialize_weights``, so the embedding depends on the torch RNG state at
    construction time.

    Args:
        input_channels: number of channels of the input images. Defaults to 3,
            which matches VGG16's first convolution; any other value swaps that
            convolution for one with the requested number of input channels.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        # pretrained=False: only the architecture is reused; all conv weights are
        # re-initialised by _initialize_weights at the end of __init__ anyway.
        self.backbone = vgg16(pretrained=False).features

        if input_channels != 3:
            self._modify_first_conv(input_channels)

        # Pool the trunk output to a fixed 7x7 grid (as torchvision's own VGG does)
        # so the flattened size always matches the first Linear layer. For 224x224
        # inputs the trunk already emits 7x7, so this layer is a no-op there; it
        # has no parameters and does not change the state_dict.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 64),
            # NOTE(review): the trailing ReLU makes embeddings non-negative; the
            # Dropout layers are inert once the model is switched to eval().
            nn.ReLU(True),
            nn.Dropout(p=0.5),
        )

        self._initialize_weights()

    def _modify_first_conv(self, in_channels):
        """Replace the first convolution so the trunk accepts ``in_channels`` inputs.

        Every other hyper-parameter is copied from the layer being replaced; the
        new weights are overwritten afterwards by ``_initialize_weights``.
        """
        old_conv = self.backbone[0]
        self.backbone[0] = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            dilation=old_conv.dilation,
            groups=old_conv.groups,
            bias=old_conv.bias is not None,
            padding_mode=old_conv.padding_mode,
        )

    def _initialize_weights(self):
        """Re-draw all conv and linear parameters.

        Conv layers get Kaiming-normal weights (fan_out, ReLU gain) and zero bias;
        linear layers get N(0, 0.01) weights and zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a batch of images to embeddings.

        Args:
            x: image batch, assumed [batch, input_channels, H, W].
        Returns:
            Tensor of shape [batch, 64].
        """
        x = self.backbone(x)          # [batch, 512, h, w]
        x = self.avgpool(x)           # [batch, 512, 7, 7]
        x = torch.flatten(x, 1)       # [batch, 512*7*7]
        return self.classifier(x)     # [batch, 64]

In [None]:
class GeneratedImageDataset(Dataset):
    """Unlabelled dataset over the image files found directly inside one folder.

    Each item is the image converted to single-channel ("L") mode and, if given,
    passed through ``transform``. Items are bare images (no label), so batches
    from a DataLoader over this dataset are not ``(data, label)`` pairs.
    """

    def __init__(self, folder_path, transform=None, extensions=(".png", ".jpg")):
        """
        Args:
            folder_path: directory scanned (non-recursively) for image files.
            transform: optional callable applied to each PIL image.
            extensions: accepted file suffix(es), matched case-insensitively.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sort so the sample order (and hence the order of extracted features)
        # does not depend on the filesystem's arbitrary os.listdir ordering.
        self.image_paths = [
            os.path.join(folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if name.lower().endswith(suffixes)
        ]
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager releases the file handle once the pixels are converted.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("L")
        if self.transform:
            image = self.transform(image)
        return image

In [None]:

# Preprocessing shared by the real and the generated images, so both go through
# an identical pipeline before being embedded.
transform = transforms.Compose([
    # An explicit (H, W) guarantees 224x224 output even for non-square inputs; an
    # int would only fix the shorter edge. 224x224 is what yields the 7x7 feature
    # map that the embedding's 512*7*7 classifier input expects.
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    # NOTE(review): 0.1307 / 0.3081 are the commonly quoted MNIST statistics, not
    # FashionMNIST's; kept unchanged since real and generated images share this
    # transform — confirm this is intended.
    transforms.Normalize(mean=[0.1307], std=[0.3081])
])

In [None]:

from torchvision.datasets import FashionMNIST

# Reference ("real") set: the FashionMNIST test split, preprocessed with the same
# transform as the generated images so both live in the same input space.
real_dataset = FashionMNIST(root='./fashiondata', train=False, download=True, transform=transform)
real_loader = DataLoader(real_dataset, batch_size=64, shuffle=False)


def _make_generated_loader(folder_path):
    """Wrap a folder of generated images in a dataset plus an unshuffled loader."""
    dataset = GeneratedImageDataset(folder_path=folder_path, transform=transform)
    return dataset, DataLoader(dataset, batch_size=64, shuffle=False)


# Folder names (including the "fasion" spelling) are the on-disk directory names.
hmvae_generated_dataset, hmvae_generated_loader = _make_generated_loader(
    "generated_images_hmvae_fasion")
memvae_generated_dataset, memvae_generated_loader = _make_generated_loader(
    "generated_images_MEMvae_1layers_fasion")
vae_generated_dataset, vae_generated_loader = _make_generated_loader(
    "generated_images_vae_fasion")

In [None]:
def _resolve_device(model, device):
    """Return ``device`` as a torch.device, defaulting to wherever ``model`` lives."""
    if device is not None:
        return torch.device(device)
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _extract_features(loader, model, device):
    """Run ``model`` (in eval mode, without grad) over ``loader`` and stack outputs.

    Batches may be bare tensors or ``(inputs, labels, ...)`` sequences; in the
    latter case only the first element is fed to the model.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = _resolve_device(model, device)
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in loader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            chunks.append(model(inputs.to(device)).cpu().numpy())
    if not chunks:
        raise ValueError("loader yielded no batches; nothing to extract features from")
    return np.concatenate(chunks, axis=0)


def extract_realfeatures(loader, model, device=None):
    """Embed every batch of a labelled loader (``(data, label)`` batches).

    Args:
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: torch module mapping an input batch to a feature batch.
        device: where to run; defaults to the device of the model's parameters
            (CPU for a parameterless model).
    Returns:
        numpy array of the model outputs for all batches, concatenated along axis 0.
    """
    return _extract_features(loader, model, device)


def extract_fakefeatures(loader, model, device=None):
    """Embed every batch of an unlabelled loader (bare input-tensor batches).

    Args:
        loader: iterable of input batches (sequence batches use their first item).
        model: torch module mapping an input batch to a feature batch.
        device: where to run; defaults to the device of the model's parameters
            (CPU for a parameterless model).
    Returns:
        numpy array of the model outputs for all batches, concatenated along axis 0.
    """
    return _extract_features(loader, model, device)

# The embedding network is randomly initialised, so fix the seed; otherwise each
# kernel restart yields a different embedding and different PRDC numbers.
SEED = 0
torch.manual_seed(SEED)

# Prefer the second GPU as originally configured, but fall back to the first GPU
# or the CPU instead of crashing on machines without a "cuda:1".
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model_r64 = RandomEmbeddingR64(input_channels=1).to(device)

# Pass the device explicitly so inputs always land where the model lives.
real_features = extract_realfeatures(real_loader, model_r64, device=device)
hmvae_generated_features = extract_fakefeatures(hmvae_generated_loader, model_r64, device=device)
memvae_generated_features = extract_fakefeatures(memvae_generated_loader, model_r64, device=device)
vae_generated_features = extract_fakefeatures(vae_generated_loader, model_r64, device=device)

In [7]:
def compute_pairwise_distance(data_x, data_y=None):
    """Euclidean distances between every row of ``data_x`` and every row of ``data_y``.

    Args:
        data_x: numpy.ndarray([N, feature_dim]).
        data_y: numpy.ndarray([M, feature_dim]); when omitted, ``data_x`` is
            compared against itself.
    Returns:
        numpy.ndarray([N, M]) whose entry (i, j) is the Euclidean distance
        between ``data_x[i]`` and ``data_y[j]`` (computed with 8 parallel jobs).
    """
    other = data_x if data_y is None else data_y
    return sklearn.metrics.pairwise_distances(
        data_x, other, metric='euclidean', n_jobs=8)


def get_kth_value(unsorted, k, axis=-1):
    """Return the k-th smallest value (1-indexed) along ``axis``.

    Args:
        unsorted: numpy.ndarray of any dimensionality (array-likes are converted).
        k: 1-based rank; ``k=1`` is the minimum, ``k=n`` the maximum along ``axis``.
        axis: axis along which to rank values.
    Returns:
        Array with ``axis`` removed, holding the k-th smallest value of each slice.
    Raises:
        ValueError: if ``k`` is outside ``[1, unsorted.shape[axis]]``.
    """
    values = np.asarray(unsorted)
    n = values.shape[axis]
    if not 1 <= k <= n:
        raise ValueError(f"k must be in [1, {n}] along axis {axis}, got {k}")
    # np.partition places the (k-1)-th order statistic at index k-1 along `axis`;
    # np.take then selects it along that same axis (works for any axis, and for
    # k == n, which argpartition(kth=k) would reject).
    partitioned = np.partition(values, k - 1, axis=axis)
    return np.take(partitioned, k - 1, axis=axis)


def compute_nearest_neighbour_distances(input_features, nearest_k):
    """Radius of each sample's k-nearest-neighbour ball within its own set.

    Args:
        input_features: numpy.ndarray([N, feature_dim]).
        nearest_k: number of neighbours defining the ball.
    Returns:
        For every row, the (nearest_k + 1)-th smallest distance to the rows of
        the same set, i.e. the distance to its nearest_k-th neighbour when the
        sample itself is counted as the closest one.
    """
    # A row's distance to itself is (approximately) zero and therefore ranks
    # first, so request rank nearest_k + 1 to step past the self-match.
    self_distances = compute_pairwise_distance(input_features)
    return get_kth_value(self_distances, k=nearest_k + 1, axis=-1)


def compute_prdc(real_features, fake_features, nearest_k):
    """Precision, recall, density and coverage between a real and a fake manifold.

    Each manifold is approximated by balls around its samples whose radius is the
    distance to that sample's ``nearest_k``-th neighbour in its own set.

    Args:
        real_features: numpy.ndarray([N_real, feature_dim]).
        fake_features: numpy.ndarray([N_fake, feature_dim]).
        nearest_k: int, neighbourhood size used for the ball radii.
    Returns:
        dict with keys ``precision``, ``recall``, ``density`` and ``coverage``.
    """
    print('Num real: {} Num fake: {}'
          .format(real_features.shape[0], fake_features.shape[0]))

    real_radii = compute_nearest_neighbour_distances(real_features, nearest_k)
    fake_radii = compute_nearest_neighbour_distances(fake_features, nearest_k)
    real_to_fake = compute_pairwise_distance(real_features, fake_features)  # [N_real, N_fake]

    # inside_real_ball[i, j]: fake sample j lies strictly inside real sample i's ball.
    inside_real_ball = real_to_fake < real_radii[:, np.newaxis]
    # inside_fake_ball[i, j]: real sample i lies strictly inside fake sample j's ball.
    inside_fake_ball = real_to_fake < fake_radii[np.newaxis, :]

    # Fraction of fakes covered by at least one real ball.
    precision = inside_real_ball.any(axis=0).mean()
    # Fraction of reals covered by at least one fake ball.
    recall = inside_fake_ball.any(axis=1).mean()
    # Mean number of real balls containing each fake, normalised by k.
    density = (1. / float(nearest_k)) * inside_real_ball.sum(axis=0).mean()
    # Fraction of real balls that contain their nearest fake sample.
    coverage = (real_to_fake.min(axis=1) < real_radii).mean()

    return dict(precision=precision, recall=recall,
                density=density, coverage=coverage)

In [None]:
# PRDC of the hmvae-generated samples against the real test split (k = 10 neighbours).
prdc_metrics = compute_prdc(real_features, hmvae_generated_features, nearest_k=10)
for label, key in (("Precision", "precision"), ("Recall", "recall"),
                   ("Density", "density"), ("Coverage", "coverage")):
    print(f"{label}: {prdc_metrics[key]:.4f}")

In [None]:
# PRDC of the memvae-generated samples against the real test split (k = 10 neighbours).
prdc_metrics = compute_prdc(real_features, memvae_generated_features, nearest_k=10)
for label, key in (("Precision", "precision"), ("Recall", "recall"),
                   ("Density", "density"), ("Coverage", "coverage")):
    print(f"{label}: {prdc_metrics[key]:.4f}")

In [None]:
# PRDC of the vae-generated samples against the real test split (k = 10 neighbours).
prdc_metrics = compute_prdc(real_features, vae_generated_features, nearest_k=10)
for label, key in (("Precision", "precision"), ("Recall", "recall"),
                   ("Density", "density"), ("Coverage", "coverage")):
    print(f"{label}: {prdc_metrics[key]:.4f}")