In [4]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from scipy.linalg import sqrtm
from skimage.metrics import structural_similarity
import lpips
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc



class ImageMetricsComparisonSimple:
    """Compare image pairs with PSNR, SSIM and LPIPS and summarise the results.

    Two groups of pairs are evaluated: ``fake`` vs ``inv_fake`` and ``real``
    vs ``inv_real``. Pairs are matched by identical file name. Results can be
    plotted (histograms, ROC curves, correlation matrix) and written to a
    text report.
    """

    # Metric keys in the order they are plotted and reported.
    METRIC_NAMES = ('psnr', 'ssim', 'lpips')

    def __init__(self, fake_dir, inv_fake_dir, real_dir, inv_real_dir, 
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Store the four directories and load the LPIPS (AlexNet) model.

        Args:
            fake_dir: Directory with generated images.
            inv_fake_dir: Directory with the counterparts of ``fake_dir``.
            real_dir: Directory with real images.
            inv_real_dir: Directory with the counterparts of ``real_dir``.
            device: Torch device the LPIPS network is moved to.
        """
        self.fake_dir = fake_dir
        self.inv_fake_dir = inv_fake_dir
        self.real_dir = real_dir
        self.inv_real_dir = inv_real_dir
        self.device = device
        # PIL images are resized to this (width, height) before comparison.
        self.target_size = (1024, 1024)

        self.loss_fn_alex = lpips.LPIPS(net='alex').to(device)

    def _to_resized_array(self, img):
        """Return ``img`` as an array; PIL images are resized to ``target_size`` first.

        Arrays (and anything that is not a PIL image) are returned untouched.
        """
        if isinstance(img, np.ndarray):
            return img
        if isinstance(img, Image.Image):
            img = img.resize(self.target_size, Image.Resampling.LANCZOS)
            return np.array(img)
        return img

    @staticmethod
    def _to_gray_uint8(img):
        """Convert an array to uint8 and average a trailing channel axis away."""
        if img.dtype != np.uint8:
            # Clip, then add 0.5 so the truncating cast rounds to nearest.
            img = (img.clip(0, 255) + 0.5).astype(np.uint8)
        if len(img.shape) == 3:
            img = np.mean(img, axis=2).astype(np.uint8)
        return img

    def calculate_psnr(self, img1, img2):
        """Return the PSNR in dB between two images, assuming a peak value of 255.

        Args:
            img1, img2: PIL images (resized to ``target_size``) or arrays of
                equal shape.

        Returns:
            The PSNR, or ``float('inf')`` when the images are identical.
        """
        img1 = self._to_resized_array(img1).astype(np.float64)
        img2 = self._to_resized_array(img2).astype(np.float64)

        mse = np.mean((img1 - img2) ** 2)
        if mse == 0:
            return float('inf')

        max_pixel = 255.0
        return 20 * np.log10(max_pixel / np.sqrt(mse))

    def calculate_ssim(self, img1, img2):
        """Return the SSIM between the grayscale (channel-mean) versions of two images.

        Args:
            img1, img2: PIL images (resized to ``target_size``) or arrays.
        """
        gray1 = self._to_gray_uint8(self._to_resized_array(img1))
        gray2 = self._to_gray_uint8(self._to_resized_array(img2))
        return structural_similarity(gray1, gray2, data_range=255)

    def calculate_lpips(self, img1, img2):
        """Return the LPIPS (AlexNet) distance between two images.

        Arrays are converted to PIL images first. Each image is resized to
        ``target_size`` and then to 224x224; both steps are kept so the
        resulting values stay identical to earlier runs of this code.
        """
        if not isinstance(img1, Image.Image):
            img1 = Image.fromarray(img1)
        if not isinstance(img2, Image.Image):
            img2 = Image.fromarray(img2)

        img1 = img1.resize(self.target_size, Image.Resampling.LANCZOS)
        img2 = img2.resize(self.target_size, Image.Resampling.LANCZOS)

        # ToTensor yields [0, 1]; the normalisation maps that to [-1, 1].
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        img1_tensor = transform(img1).unsqueeze(0).to(self.device)
        img2_tensor = transform(img2).unsqueeze(0).to(self.device)

        with torch.no_grad():
            lpips_value = self.loss_fn_alex(img1_tensor, img2_tensor)
        return lpips_value.item()

    def calculate_metrics_for_pair(self, dir1, dir2, desc):
        """Compute PSNR, SSIM and LPIPS for every ``.png`` in ``dir1`` that also exists in ``dir2``.

        Args:
            dir1: Directory whose ``.png`` files drive the iteration.
            dir2: Directory expected to hold a file of the same name.
            desc: Label for the progress bar and the skip summary.

        Returns:
            Dict with the parallel lists ``'psnr'``, ``'ssim'``, ``'lpips'``
            and ``'filenames'``; index ``i`` of each list refers to the same
            pair. Files without a counterpart and files that fail to process
            are skipped, and a summary of the skips is printed.
        """
        files = sorted(f for f in os.listdir(dir1) if f.endswith('.png'))
        metrics = {
            'psnr': [], 'ssim': [], 'lpips': [], 'filenames': []
        }
        missing = 0
        failed = 0

        for filename in tqdm(files, desc=desc):
            path1 = os.path.join(dir1, filename)
            path2 = os.path.join(dir2, filename)

            if not os.path.exists(path2):
                missing += 1
                continue

            try:
                # convert() returns a loaded copy, so the file handles can be
                # released as soon as the with-block ends.
                with Image.open(path1) as src1, Image.open(path2) as src2:
                    img1 = src1.convert('RGB')
                    img2 = src2.convert('RGB')

                psnr = self.calculate_psnr(img1, img2)
                ssim = self.calculate_ssim(img1, img2)
                lpips_val = self.calculate_lpips(img1, img2)
            except Exception as exc:
                # Best effort: one bad file must not abort a long run, but the
                # failure is reported instead of being hidden.
                failed += 1
                tqdm.write(f"Skipping {filename}: {exc!r}")
                continue

            if all(x is not None and not np.isnan(x) for x in [psnr, ssim, lpips_val]):
                metrics['psnr'].append(psnr)
                metrics['ssim'].append(ssim)
                metrics['lpips'].append(lpips_val)
                metrics['filenames'].append(filename)

        if missing or failed:
            print(f"{desc}: skipped {missing} file(s) without a counterpart "
                  f"and {failed} file(s) that failed to process.")

        return metrics

    @staticmethod
    def _finite_metric_matrix(fake_metrics, real_metrics):
        """Stack all metrics into a (n_metrics, n_samples) array of jointly finite samples.

        A sample is dropped from every row when any of its metrics is
        non-finite (e.g. an infinite PSNR), so the rows stay aligned.
        Expects every metric list within a dict to have the same length.
        """
        rows = [
            np.concatenate([
                np.asarray(fake_metrics[name], dtype=np.float64),
                np.asarray(real_metrics[name], dtype=np.float64),
            ])
            for name in ImageMetricsComparisonSimple.METRIC_NAMES
        ]
        data = np.stack(rows)
        keep = np.isfinite(data).all(axis=0)
        return data[:, keep]

    @staticmethod
    def _plot_histogram(ax, title, fake_values, real_values):
        """Draw overlaid histograms of one metric for both groups on ``ax``."""
        fake_values = np.asarray(fake_values, dtype=np.float64)
        real_values = np.asarray(real_values, dtype=np.float64)
        fake_values = fake_values[~np.isinf(fake_values)]
        real_values = real_values[~np.isinf(real_values)]

        if len(fake_values) > 0 and len(real_values) > 0:
            ax.hist(fake_values, bins=30, alpha=0.5, label='Fake vs Inv_Fake', color='blue')
            ax.hist(real_values, bins=30, alpha=0.5, label='Real vs Inv_Real', color='red')
            ax.set_title(title)
            ax.legend()
            ax.grid(True, alpha=0.3)

            fake_stats = f'Fake μ={np.mean(fake_values):.3f}, σ={np.std(fake_values):.3f}'
            real_stats = f'Real μ={np.mean(real_values):.3f}, σ={np.std(real_values):.3f}'
            ax.text(0.02, 0.98, f'{fake_stats}\n{real_stats}',
                    transform=ax.transAxes, va='top', fontsize=8)
        else:
            ax.text(0.5, 0.5, 'No data available',
                    ha='center', va='center', transform=ax.transAxes)

    def _plot_roc(self, ax, fake_metrics, real_metrics):
        """Draw one ROC curve per metric, using the metric as score and fake=1 as label."""
        for metric_name in self.METRIC_NAMES:
            scores = np.concatenate([
                np.asarray(fake_metrics[metric_name], dtype=np.float64),
                np.asarray(real_metrics[metric_name], dtype=np.float64),
            ])
            if metric_name == 'lpips':
                # Lower LPIPS is better, so negate it to match the other metrics.
                scores = -scores

            # Labels: 1 for fake, 0 for real.
            labels = np.concatenate([np.ones(len(fake_metrics[metric_name])),
                                     np.zeros(len(real_metrics[metric_name]))])

            # roc_curve cannot rank infinite scores, so drop those samples.
            valid_idx = np.isfinite(scores)
            scores = scores[valid_idx]
            labels = labels[valid_idx]

            if len(scores) > 0 and len(np.unique(labels)) > 1:
                fpr, tpr, _ = roc_curve(labels, scores)
                roc_auc = auc(fpr, tpr)
                ax.plot(fpr, tpr, label=f'{metric_name.upper()} (AUC = {roc_auc:.2f})')

        ax.plot([0, 1], [0, 1], 'k--')
        ax.set_title('ROC Curves')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        # Only add a legend when at least one curve was actually drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    def _plot_correlation(self, ax, fake_metrics, real_metrics):
        """Draw the correlation matrix of the metrics over all jointly finite samples."""
        metric_labels = [name.upper() for name in self.METRIC_NAMES]
        data = self._finite_metric_matrix(fake_metrics, real_metrics)

        # A correlation needs at least two samples.
        if data.shape[1] < 2:
            ax.text(0.5, 0.5, 'No data available',
                    ha='center', va='center', transform=ax.transAxes)
            return

        corr_matrix = np.corrcoef(data)

        im = ax.imshow(corr_matrix, cmap='coolwarm', aspect='auto')
        ax.set_title('Metrics Correlation')

        ax.set_xticks(range(len(metric_labels)))
        ax.set_yticks(range(len(metric_labels)))
        ax.set_xticklabels(metric_labels)
        ax.set_yticklabels(metric_labels)

        for i, row in enumerate(corr_matrix):
            for j, value in enumerate(row):
                ax.text(j, i, f'{value:.2f}', ha='center', va='center')

        plt.colorbar(im, ax=ax)

    def plot_comparative_metrics(self, fake_metrics, real_metrics,
                                 output_path='metrics_comparison.png'):
        """Plot histograms, ROC curves and the metric correlation, then save the text report.

        Args:
            fake_metrics: Result of ``calculate_metrics_for_pair`` for the fake group.
            real_metrics: Result of ``calculate_metrics_for_pair`` for the real group.
            output_path: File the figure is saved to.

        Nothing is written when either group contains no pairs.
        """
        if len(fake_metrics['filenames']) == 0 or len(real_metrics['filenames']) == 0:
            print("Error: No data available to plot metrics. Please check if the input directories contain matching image pairs.")
            return

        fig = plt.figure(figsize=(15, 12))
        gs = plt.GridSpec(3, 2)  # Rows: histograms / histogram + ROC / correlation

        panels = [
            ('PSNR (Higher is better)', 'psnr', gs[0, 0]),
            ('SSIM (Higher is better)', 'ssim', gs[0, 1]),
            ('LPIPS (Lower is better)', 'lpips', gs[1, 0]),
        ]
        for title, metric_name, pos in panels:
            self._plot_histogram(fig.add_subplot(pos), title,
                                 fake_metrics[metric_name], real_metrics[metric_name])

        self._plot_roc(fig.add_subplot(gs[1, 1]), fake_metrics, real_metrics)
        # The correlation matrix spans both columns of the last row.
        self._plot_correlation(fig.add_subplot(gs[2, :]), fake_metrics, real_metrics)

        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        self.save_numerical_results(fake_metrics, real_metrics)

    def save_numerical_results(self, fake_metrics, real_metrics,
                               output_path='metrics_results.txt'):
        """Write mean, std, min and max of every metric for both groups to a text file.

        Infinite values are excluded. When a metric has no finite values for
        a group, a placeholder line is written instead of statistics.

        Args:
            fake_metrics: Metric dict of the fake group.
            real_metrics: Metric dict of the real group.
            output_path: File the report is written to.
        """
        groups = (
            ("Fake vs Inv_Fake", fake_metrics),
            ("Real vs Inv_Real", real_metrics),
        )
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write("=== Numerical Results ===\n\n")

            for metric in self.METRIC_NAMES:
                f.write(f"\n{metric.upper()} Statistics:\n")

                for label, source in groups:
                    values = np.asarray(source[metric], dtype=np.float64)
                    values = values[~np.isinf(values)]

                    f.write(f"{label}:\n")
                    if values.size == 0:
                        # min/max of an empty array would raise.
                        f.write("  No finite values\n")
                        continue
                    f.write(f"  Mean: {np.mean(values):.4f}\n")
                    f.write(f"  Std:  {np.std(values):.4f}\n")
                    f.write(f"  Min:  {np.min(values):.4f}\n")
                    f.write(f"  Max:  {np.max(values):.4f}\n")


def main():
    """Compute metrics for both image groups under the base directory and plot them."""
    base_dir = '/shared/shashmi/inversion/inversion/SD3.5'

    # The four sub-folders live side by side under the base directory.
    fake_dir, inv_fake_dir, real_dir, inv_real_dir = (
        os.path.join(base_dir, sub)
        for sub in ('fake', 'inv_fake', 'real', 'inv_real')
    )

    metrics_calc = ImageMetricsComparisonSimple(
        fake_dir, inv_fake_dir, real_dir, inv_real_dir
    )

    # Evaluate each group against its own counterpart folder.
    fake_metrics = metrics_calc.calculate_metrics_for_pair(
        fake_dir, inv_fake_dir, "Fake vs Inv_Fake"
    )
    real_metrics = metrics_calc.calculate_metrics_for_pair(
        real_dir, inv_real_dir, "Real vs Inv_Real"
    )
    metrics_calc.plot_comparative_metrics(fake_metrics, real_metrics)

# Script entry point: select the 'spawn' start method, then run the comparison.
if __name__ == "__main__":
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        # Ignored on purpose; presumably raised when the start method was
        # already set (e.g. the cell is run a second time) - TODO confirm.
        pass
    main()

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


Loading model from: /shared/shashmi/conda_envs/inverter/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth


Fake vs Inv_Fake: 100%|██████████| 1279/1279 [12:33<00:00,  1.70it/s]
Real vs Inv_Real: 100%|██████████| 1271/1271 [13:32<00:00,  1.56it/s]


In [1]:
import os
import shutil
import random

def get_common_files(dir1, dir2):
    """Return the entry names present in both ``dir1`` and ``dir2``.

    The names come from ``os.listdir``, so sub-directories are included too.
    The order of the returned list is unspecified.
    """
    first_entries = set(os.listdir(dir1))
    second_entries = set(os.listdir(dir2))
    shared = first_entries & second_entries
    return list(shared)

def copy_paired_images(source_dir, dest_dir, num_pairs=10):
    """Copy a random sample of matching image pairs into ``dest_dir``.

    Pairs are files with the same name in ``fake``/``inv_fake`` and in
    ``real``/``inv_real`` under ``source_dir``. The same four folders are
    created under ``dest_dir`` and the pairing is preserved.

    Args:
        source_dir: Directory containing the four source folders.
        dest_dir: Directory receiving the copies; created when missing.
        num_pairs: Number of pairs to sample per group. When fewer pairs
            exist, a warning is printed and all of them are copied.
    """
    groups = (('fake', 'inv_fake'), ('real', 'inv_real'))

    # makedirs also creates dest_dir itself, so no separate existence check.
    for folder in ('fake', 'inv_fake', 'real', 'inv_real'):
        os.makedirs(os.path.join(dest_dir, folder), exist_ok=True)

    # Collect the candidate file names of every group.
    candidates_per_group = []
    for kind, inv_kind in groups:
        src = os.path.join(source_dir, kind)
        inv_src = os.path.join(source_dir, inv_kind)
        # Keep regular files only: a directory present in both folders
        # (e.g. .ipynb_checkpoints) would make shutil.copy2 fail.
        # Sorting makes the sample reproducible under random.seed().
        candidates = sorted(
            name for name in get_common_files(src, inv_src)
            if os.path.isfile(os.path.join(src, name))
            and os.path.isfile(os.path.join(inv_src, name))
        )
        candidates_per_group.append((kind, inv_kind, src, inv_src, candidates))

    # Select random samples.
    selections = []
    for kind, inv_kind, src, inv_src, candidates in candidates_per_group:
        if len(candidates) < num_pairs:
            print(f"Warning: Only {len(candidates)} matching {kind}/{inv_kind} pairs found. Using all of them.")
            chosen = candidates
        else:
            chosen = random.sample(candidates, num_pairs)
        selections.append((kind, inv_kind, src, inv_src, chosen))

    # Copy both halves of every selected pair.
    for kind, inv_kind, src, inv_src, chosen in selections:
        for filename in chosen:
            shutil.copy2(
                os.path.join(src, filename),
                os.path.join(dest_dir, kind, filename)
            )
            shutil.copy2(
                os.path.join(inv_src, filename),
                os.path.join(dest_dir, inv_kind, filename)
            )
            print(f"Copied {kind} pair: {filename}")

# Source directory holding the fake/inv_fake/real/inv_real folders; the demo
# folder "sd3_demo" is created next to it (same parent directory).
current_dir = "/shared/shashmi/inversion/inversion/SD3"
demo_dir = os.path.join(os.path.dirname(current_dir), "sd3_demo")

# Copy up to 100 randomly chosen pairs per group into the demo folder.
copy_paired_images(current_dir, demo_dir, 100)

print("\nProcess completed! Check the sd3_demo folder for the copied image pairs.")

Copied fake pair: people_with_their_head_covered_on_a_motorbike_068db080f5.png
Copied fake pair: a_red_bus_that_has_stuff_wrote_on_the_outside_e8a7ceee81.png
Copied fake pair: A_skier_standing_behind_a_difficulty_level_sign_on_da34a5a4ae.png
Copied fake pair: a_train_on_a_train_track_at_a_station__02de4d0786.png
Copied fake pair: An_unfinished_bathroom_with_a_boarded_window_and_e_1c8628bb4f.png
Copied fake pair: A_small_her_of_sheep_gathered_on_a_lawn_in_front_o_17c7e3c548.png
Copied fake pair: A_city_train_as_it_travels_down_the_tracks__0ed76a84c6.png
Copied fake pair: An_old_brick_city_has_water_and_a_clock_tower__2cdbb626d9.png
Copied fake pair: A_white_bus_driving_down_a_street_next_to_trees__39d06f449a.png
Copied fake pair: A_woman_sitting_on_a_bench_in_a_field_in_front_of__6b172362cc.png
Copied fake pair: A_biplane_takes_flight_on_a_sunny_day_fbf9235db7.png
Copied fake pair: A_Virgin_Records_train_next_to_a_blue__yellow_and__b66ef7639d.png
Copied fake pair: Blurred_image_of_motor