In [None]:
###
# This notebook computes the pairwise FID (Fréchet Inception Distance) between hospital image sets and visualizes the resulting matrix as a heatmap.
###

In [None]:
import numpy as np
import torch
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import logging
from torch.nn.functional import adaptive_avg_pool2d
import torchvision.models as models
from PIL import Image
import torchvision.transforms as transforms
import glob
import os
from tqdm import tqdm
import psutil
import gc

# Configure root logging (INFO and above, timestamped) and create the module-level
# `logger` that the cells below use for progress and warnings.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

In [None]:
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# NOTE: CUDA reads these variables when it initialises, so changing them after CUDA
# is already in use in this kernel is unlikely to have any effect.
os.environ.update({
    "CUDA_DEVICE_ORDER": 'PCI_BUS_ID',
    "CUDA_VISIBLE_DEVICES": '0',
})

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Using device: {device}")

In [None]:
# Root folder with one sub-directory of images per hospital. Set FID_BASE_DIR in the
# environment to point at the data without editing this cell.
base_dir = os.environ.get('FID_BASE_DIR', '/PATH/TO/YOUR/DIRECTORY')

# Keep directories only: a stray file in base_dir would otherwise become an empty
# "hospital" that shows up as a NaN row/column in the FID matrix.
hospital_paths = sorted(
    path for path in glob.glob(os.path.join(base_dir, '*')) if os.path.isdir(path)
)
hospital_names = [os.path.basename(path) for path in hospital_paths]

In [None]:
def load_images_from_folder(folder: str, max_images: int = None, image_size: Tuple[int, int] = (299, 299)) -> torch.Tensor:
    """Load the ``*.png`` images in ``folder`` as one normalized tensor stack.

    Each image is converted to RGB, resized to ``image_size`` and normalized with the
    ImageNet mean/std given below.

    Args:
        folder: Directory to scan (non-recursively) for ``*.png`` files.
        max_images: If not None, load at most this many images. The first ones are taken
            in sorted-path order, so the subset is reproducible.
        image_size: Target (height, width) passed to ``transforms.Resize``.

    Returns:
        Tensor of shape (N, 3, H, W), or an empty tensor if no image could be loaded.
        Images that fail to load are logged and skipped.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # glob returns files in arbitrary filesystem order; sort for a deterministic
    # image order and a deterministic max_images subset.
    image_paths = sorted(glob.glob(os.path.join(folder, '*.png')))
    if max_images is not None:
        image_paths = image_paths[:max_images]

    images = []
    for img_path in tqdm(image_paths, desc=f"Loading images from {folder}"):
        try:
            # The context manager closes the file handle promptly. `convert` forces the
            # pixel data to load, so the converted copy stays valid after closing.
            with Image.open(img_path) as img:
                images.append(transform(img.convert('RGB')))
        except Exception as e:
            # Best-effort loading: one corrupt file should not abort the whole hospital.
            logger.warning(f"Error loading image {img_path}: {str(e)}")

    return torch.stack(images) if images else torch.empty(0)

def get_inception_features(images: torch.Tensor, model: torch.nn.Module, device: torch.device, batch_size: int = 32) -> np.ndarray:
    """Run ``model`` over ``images`` in mini-batches and return the outputs as NumPy.

    Args:
        images: Tensor whose first dimension indexes samples (e.g. the (N, 3, H, W)
            stack produced by ``load_images_from_folder``).
        model: Feature extractor. It is switched to eval mode before use.
        device: Device each batch is moved to before the forward pass.
        batch_size: Number of samples per forward pass (must be >= 1).

    Returns:
        Concatenated model outputs for all samples, in input order. An empty input
        yields an empty array of shape (0, 0) instead of raising.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    if len(images) == 0:
        # np.concatenate([]) raises; callers such as bootstrap_fid check shape[0] == 0.
        return np.empty((0, 0), dtype=np.float32)

    features_list = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            features_list.append(model(batch).cpu().numpy())

    return np.concatenate(features_list, axis=0)

def extract_all_features(hospital_images: Dict[str, torch.Tensor], inception: torch.nn.Module, device: torch.device) -> Dict[str, np.ndarray]:
    """Compute model features for every hospital's image stack.

    Args:
        hospital_images: Mapping hospital name -> image tensor.
        inception: Feature extractor forwarded to ``get_inception_features``.
        device: Device forwarded to ``get_inception_features``.

    Returns:
        Mapping hospital name -> feature array. Hospitals whose tensor has no elements
        are omitted and a warning is logged for each.
    """
    extracted: Dict[str, np.ndarray] = {}
    for name, batch in hospital_images.items():
        if batch.nelement() != 0:
            logger.info(f"Extracting features for {name}")
            extracted[name] = get_inception_features(batch, inception, device)
        else:
            logger.warning(f"Empty dataset for {name}, skipping")
    return extracted

def calculate_fid(mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray, eps: float = 1e-6) -> float:
    """Fréchet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 @ sigma2))

    Args:
        mu1, mu2: Mean vectors (flattened to 1-D).
        sigma1, sigma2: Covariance matrices (a scalar is promoted to a 1x1 matrix).
            The caller's arrays are never modified.
        eps: Ridge added to both covariance diagonals to stabilize ``sqrtm``.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: If the two means, or the two covariances, differ in shape.
    """
    mu1 = np.asarray(mu1, dtype=np.float64).ravel()
    mu2 = np.asarray(mu2, dtype=np.float64).ravel()
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))

    if mu1.shape != mu2.shape:
        raise ValueError(f"Mean vectors have different shapes: {mu1.shape} vs {mu2.shape}")
    if sigma1.shape != sigma2.shape:
        raise ValueError(f"Covariances have different shapes: {sigma1.shape} vs {sigma2.shape}")

    # Out-of-place addition: np.asarray does not copy float64 inputs, so `+=` would add
    # eps to the caller's covariance matrices on every call. It would also fail for int dtypes.
    ridge = np.eye(sigma1.shape[0]) * eps
    sigma1 = sigma1 + ridge
    sigma2 = sigma2 + ridge

    diff = mu1 - mu2
    covmean = sqrtm(sigma1.dot(sigma2))

    # sqrtm can return a complex result from round-off even for PSD inputs. Only the
    # trace is used, so check the diagonal and warn when the imaginary part is large.
    if np.iscomplexobj(covmean):
        diag_imag = np.diagonal(covmean).imag
        if not np.allclose(diag_imag, 0, atol=1e-3):
            logger.warning(
                f"Imaginary component in covmean is not negligible "
                f"(max |imag| on diagonal: {np.max(np.abs(diag_imag)):.3g})"
            )
        covmean = covmean.real

    fid = np.sum(diff ** 2) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)

def bootstrap_fid(hospital_features: dict, num_samples: int = 1000, sample_size: int = 100,
                  seed: int = None, out_dir: str = "."):
    """Bootstrap per-hospital feature means and covariances, streaming them to disk.

    Despite the name, this does not compute FID itself. For each hospital it draws
    ``num_samples`` resamples (with replacement) of ``sample_size`` feature rows and
    records each resample's mean vector and covariance matrix. ``calculate_fid_matrix``
    later computes pairwise FID from these statistics.

    Statistics are written to float32 ``np.memmap`` files:
    - ``<out_dir>/<name>_mu.dat``, shape (num_samples, D);
    - ``<out_dir>/<name>_sigma.dat``, shape (num_samples, D, D).
    D is the feature dimension. The sigma file alone takes num_samples * D * D * 4
    bytes. Existing files with these names are overwritten (memmap mode "w+").

    Args:
        hospital_features: Mapping hospital name -> 2-D feature array (rows are samples).
            Hospitals with zero rows are skipped.
        num_samples: Number of bootstrap resamples per hospital (must be >= 1).
        sample_size: Rows drawn per resample (must be >= 2 to estimate a covariance).
        seed: Optional seed for a dedicated ``np.random.default_rng``. When None,
            NumPy's global random state is used, so ``np.random.seed`` still applies.
        out_dir: Directory for the memmap files (default: current working directory).

    Returns:
        Dict mapping hospital name -> {'mu': memmap, 'sigma': memmap}.

    Raises:
        ValueError: If ``num_samples`` < 1 or ``sample_size`` < 2.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if sample_size < 2:
        raise ValueError(f"sample_size must be >= 2 to estimate a covariance, got {sample_size}")

    # The np.random module and a Generator both provide .choice(n, size, replace=...).
    rng = np.random if seed is None else np.random.default_rng(seed)

    hospital_statistics = {}
    for name, features in tqdm(hospital_features.items(), desc="Bootstrapping FID"):
        if features.shape[0] == 0:
            continue

        n_rows, n_dims = features.shape[0], features.shape[1]
        mu_array = np.memmap(os.path.join(out_dir, f"{name}_mu.dat"), dtype=np.float32,
                             mode="w+", shape=(num_samples, n_dims))
        sigma_array = np.memmap(os.path.join(out_dir, f"{name}_sigma.dat"), dtype=np.float32,
                                mode="w+", shape=(num_samples, n_dims, n_dims))

        # Write each resample straight into the memmaps so that only one D x D covariance
        # is held in RAM at a time. Collecting all of them in Python lists first would
        # need num_samples * D * D float64 values in memory.
        for k in range(num_samples):
            sampled_indices = rng.choice(n_rows, sample_size, replace=True)
            sampled_features = np.take(features, sampled_indices, axis=0)
            mu_array[k] = np.mean(sampled_features, axis=0)
            # np.cov returns a 0-d array when there is a single feature column.
            sigma_array[k] = np.atleast_2d(np.cov(sampled_features, rowvar=False))

        mu_array.flush()
        sigma_array.flush()
        hospital_statistics[name] = {'mu': mu_array, 'sigma': sigma_array}

    return hospital_statistics

def calculate_fid_matrix(fid_results: Dict, hospital_names: List[str]) -> np.ndarray:
    """Build the symmetric matrix of pairwise FID scores between hospitals.

    Each hospital's bootstrap statistics are collapsed to a single mean vector and a
    mean covariance (averaged over the resample axis). FID is computed from those.

    Args:
        fid_results: Mapping hospital name -> {'mu': ..., 'sigma': ...}, as returned
            by ``bootstrap_fid`` (resamples along axis 0).
        hospital_names: Row/column order of the output matrix.

    Returns:
        (n, n) float array where entry [i, j] is the FID between hospitals i and j.
        The diagonal is left at 0. Pairs that cannot be computed (missing statistics,
        numerical failure) are NaN and a warning is logged.
    """
    n_hospitals = len(hospital_names)
    fid_matrix = np.zeros((n_hospitals, n_hospitals))

    # Reduce each hospital's bootstrap stack once instead of once per pair. The stacks
    # may be large disk-backed memmaps. Accumulate in float64 for accuracy.
    mean_stats = {}
    for name in dict.fromkeys(hospital_names):
        try:
            stats = fid_results[name]
            mean_stats[name] = (
                np.mean(stats['mu'], axis=0, dtype=np.float64),
                np.mean(stats['sigma'], axis=0, dtype=np.float64),
            )
        except Exception as e:
            logger.warning(f"Error computing mean statistics for {name}: {str(e)}")

    for i in range(n_hospitals):
        for j in range(i + 1, n_hospitals):
            name1, name2 = hospital_names[i], hospital_names[j]
            if name1 not in mean_stats or name2 not in mean_stats:
                fid_matrix[i, j] = fid_matrix[j, i] = np.nan
                continue
            mu1, sigma1 = mean_stats[name1]
            mu2, sigma2 = mean_stats[name2]
            try:
                # Pass copies so the cached statistics stay intact even if
                # calculate_fid modifies its arguments.
                fid_value = calculate_fid(mu1.copy(), sigma1.copy(), mu2.copy(), sigma2.copy())
            except Exception as e:
                logger.warning(f"Error calculating FID between {name1} and {name2}: {str(e)}")
                fid_value = np.nan
            fid_matrix[i, j] = fid_matrix[j, i] = fid_value

    return fid_matrix

In [None]:
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT).to(device)
# Replace the classifier head with Identity so the forward pass returns the
# pre-classifier features instead of class logits.
inception.fc = torch.nn.Identity()
inception.eval()

# Load and featurize one hospital at a time so only a single hospital's image tensor
# is held in memory, rather than every hospital's images at once.
features_dict = {}
for name, path in zip(hospital_names, hospital_paths):
    hospital_images = {name: load_images_from_folder(path)}
    features_dict.update(extract_all_features(hospital_images, inception, device))
    del hospital_images

In [None]:
# Per-hospital bootstrap statistics (resampled means and covariances); the pairwise
# FID matrix is computed from these in the next cell.
fid_results = bootstrap_fid(hospital_features=features_dict)

In [None]:
fid_matrix = calculate_fid_matrix(fid_results=fid_results, hospital_names=hospital_names)
fid_matrix  # shown as the cell output

In [None]:
# Off-diagonal entries only: the diagonal is never computed and stays 0.
off_diagonal = fid_matrix[~np.eye(fid_matrix.shape[0], dtype=bool)]

print("FID Matrix:")
print(fid_matrix)
print("\nMatrix statistics:")
if off_diagonal.size and not np.isnan(off_diagonal).all():
    # nan-aware reductions: pairs whose FID failed are stored as NaN and would
    # otherwise turn every summary statistic into nan.
    print(f"Min value (excluding diagonal): {np.nanmin(off_diagonal)}")
    print(f"Max value: {np.nanmax(off_diagonal)}")
    print(f"Mean value: {np.nanmean(off_diagonal)}")
else:
    print("No valid off-diagonal FID values to summarize.")

mask = np.isnan(fid_matrix)
if fid_matrix.size:
    fig, ax = plt.subplots(figsize=(12, 10))
    sns.heatmap(fid_matrix,
                xticklabels=hospital_names,
                yticklabels=hospital_names,
                mask=mask,
                cmap='viridis',
                annot=True,
                fmt='.2f',
                ax=ax)
    ax.set_title("FID Heatmap Between Hospitals")
    fig.tight_layout()
    fig.savefig("fid_heatmap.png", dpi=300, bbox_inches='tight')
    plt.show()

np.save("fid_matrix.npy", fid_matrix)
logger.info("FID analysis completed successfully")