In [1]:
# %pip install opencv-python scikit-image tqdm matplotlib scikit-learn pandas colour-science

In [2]:
#!python -m pip install --upgrade pip

In [3]:
#!pip install -v torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
import torch
import torch.nn as nn
import numpy as np
import cv2
import os
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage import color, io
from skimage.metrics import structural_similarity as ssim
from sklearn.cluster import KMeans
import pandas as pd
from colour import MSDS_CMFS, SDS_ILLUMINANTS, sd_to_XYZ
from colour.models import RGB_COLOURSPACE_sRGB
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
import colour

In [6]:
!nvidia-smi

Mon Oct 13 06:26:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 576.52                 Driver Version: 576.52         CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce GTX 1650 Ti   WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   47C    P0             13W /   50W |     302MiB /   4096MiB |     20%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [7]:
import torch
print("PyTorch version:", torch.__version__)
print("Built with CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0)
      if torch.cuda.is_available() else "None")

PyTorch version: 2.3.0+cu121
Built with CUDA: 12.1
CUDA available: True
GPU: NVIDIA GeForce GTX 1650 Ti


In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Configuration
CVD_TYPES = ["protanopia", "deuteranopia", "tritanopia"]
TEST_SET_DIR = "data/test/da"
ORIGINAL_IMAGE_DIR = "data/test/da/original"
AUTOENCODER_MODEL_PATHS = {
    "protanopia": "best_model_protanopia.pth",
    "deuteranopia": "best_model_deuteranopia.pth",
    "tritanopia": "best_model_tritanopia.pth"
}
RESULTS_DIR = "comparison_results_latest_enchroma_Ishi"
os.makedirs(RESULTS_DIR, exist_ok=True)

Using device: cuda


In [9]:
# CVD Simulation Functions
def simulate_protanopia(img_rgb):
    """Simulate Protanopia color vision deficiency"""
    img_lin = np.where(img_rgb > 0.04045,
                       ((img_rgb + 0.055) / 1.055) ** 2.4, img_rgb / 12.92)

    transform_mat = np.array([
        [0.56667, 0.43333, 0.00000],
        [0.55833, 0.44167, 0.00000],
        [0.00000, 0.24167, 0.75833]
    ])

    simulated_lin = np.dot(img_lin.reshape(-1, 3),
                           transform_mat.T).reshape(img_lin.shape)

    simulated_srgb = np.where(simulated_lin > 0.0031308, 1.055 *
                              (simulated_lin ** (1/2.4)) - 0.055, 12.92 * simulated_lin)

    return np.clip(simulated_srgb, 0, 1)


def simulate_deuteranopia(img_rgb):
    """Simulate Deuteranopia color vision deficiency"""
    img_lin = np.where(img_rgb > 0.04045,
                       ((img_rgb + 0.055) / 1.055) ** 2.4, img_rgb / 12.92)

    transform_mat = np.array([
        [0.62500, 0.37500, 0.00000],
        [0.70000, 0.30000, 0.00000],
        [0.00000, 0.30000, 0.70000]
    ])

    simulated_lin = np.dot(img_lin.reshape(-1, 3),
                           transform_mat.T).reshape(img_lin.shape)

    simulated_srgb = np.where(simulated_lin > 0.0031308, 1.055 *
                              (simulated_lin ** (1/2.4)) - 0.055, 12.92 * simulated_lin)

    return np.clip(simulated_srgb, 0, 1)


def simulate_tritanopia(img_rgb):
    """Simulate Tritanopia color vision deficiency"""
    img_lin = np.where(img_rgb > 0.04045,
                       ((img_rgb + 0.055) / 1.055) ** 2.4, img_rgb / 12.92)

    transform_mat = np.array([
        [0.95000, 0.05000, 0.00000],
        [0.00000, 0.43333, 0.56667],
        [0.00000, 0.47500, 0.52500]
    ])

    simulated_lin = np.dot(img_lin.reshape(-1, 3),
                           transform_mat.T).reshape(img_lin.shape)

    simulated_srgb = np.where(simulated_lin > 0.0031308, 1.055 *
                              (simulated_lin ** (1/2.4)) - 0.055, 12.92 * simulated_lin)

    return np.clip(simulated_srgb, 0, 1)


def simulate_cvd(img_rgb, cvd_type='protanopia'):
    """Wrapper function for CVD simulation"""
    if cvd_type.lower() == 'protanopia':
        return simulate_protanopia(img_rgb)
    elif cvd_type.lower() == 'deuteranopia':
        return simulate_deuteranopia(img_rgb)
    elif cvd_type.lower() == 'tritanopia':
        return simulate_tritanopia(img_rgb)
    else:
        raise ValueError(
            "Invalid CVD type. Choose from 'protanopia', 'deuteranopia', or 'tritanopia'")

In [10]:
def calculate_delta_e(img1, img2):
    """Calculate CIEDE2000 color difference between two images"""
    try:
        # Convert to Lab color space
        lab1 = color.rgb2lab(img1)
        lab2 = color.rgb2lab(img2)

        # Handle NaN or Inf values
        if np.any(np.isnan(lab1)) or np.any(np.isnan(lab2)):
            return float('inf')

        # Calculate Delta E 2000 using colour-science
        delta_e = colour.difference.delta_E_CIE2000(
            lab1.reshape(-1, 3),
            lab2.reshape(-1, 3)
        )

        # Handle potential NaN values in delta_e
        delta_e = np.nan_to_num(delta_e, nan=float('inf'))

        return np.mean(delta_e)
    except Exception as e:
        print(f"Error in Delta E calculation: {e}")
        return float('inf')


def calculate_cci(img_lab, cvd_type='protanopia'):
    """Calculate Color Confusion Index for an image in Lab space"""
    if cvd_type.lower() == 'protanopia':
        confusion_axis = img_lab[:, :, 1]
    elif cvd_type.lower() == 'deuteranopia':
        confusion_axis = img_lab[:, :, 1]
    elif cvd_type.lower() == 'tritanopia':
        confusion_axis = img_lab[:, :, 2]
    else:
        raise ValueError("Invalid CVD type")

    cci = np.var(confusion_axis)
    return cci


def calculate_contrast(img_gray):
    """Calculate Michelson contrast of a grayscale image with safety checks"""
    if img_gray.size == 0:
        return 0.0

    min_val = np.min(img_gray)
    max_val = np.max(img_gray)

    # Handle division by zero and edge cases
    denominator = max_val + min_val
    if abs(denominator) < 1e-10:  # Very small value instead of exact zero
        return 0.0

    contrast = (max_val - min_val) / denominator

    # Ensure valid range
    return np.clip(contrast, 0.0, 1.0)

In [11]:
def is_valid_image(img):
    """Check if image is valid for processing"""
    if img is None or img.size == 0:
        return False

    # Check if image has valid dimensions
    if len(img.shape) < 2 or img.shape[0] < 8 or img.shape[1] < 8:
        return False

    # Check for all zeros or uniform images
    if np.all(img == 0) or (np.max(img) - np.min(img)) < 1e-10:
        return False

    # Check for NaN or Inf values
    if np.any(np.isnan(img)) or np.any(np.isinf(img)):
        return False

    return True


def safe_divide(numerator, denominator, default=0.0):
    """Safe division with default value for division by zero"""
    if abs(denominator) < 1e-10:
        return default
    return numerator / denominator


def safe_calculate_delta_e(img1, img2):
    """Safe Delta E calculation"""
    try:
        return calculate_delta_e(img1, img2)
    except:
        return float('inf')


def safe_calculate_cci(img_lab, cvd_type):
    """Safe CCI calculation"""
    try:
        return calculate_cci(img_lab, cvd_type)
    except:
        return float('inf')


def safe_ssim(img1, img2):
    """Safe SSIM calculation with better error handling"""
    try:
        # Ensure images are the same size
        if img1.shape != img2.shape:
            # Resize the second image to match the first
            img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))

        # Ensure images are in [0, 1] range
        img1 = np.clip(img1, 0, 1)
        img2 = np.clip(img2, 0, 1)

        # Calculate appropriate window size
        min_dim = min(img1.shape[0], img1.shape[1])
        win_size = min(7, min_dim)
        if win_size % 2 == 0:  # Ensure odd window size
            win_size = max(3, win_size - 1)

        # For very small images, use a simpler approach
        if min_dim < 8:
            # Use a pixel-based difference for very small images
            mse = np.mean((img1 - img2) ** 2)
            return 1.0 / (1.0 + mse)  # Simple similarity measure

        # Try different approaches for SSIM calculation
        try:
            # Newer versions of skimage
            return ssim(img1, img2, data_range=1.0, win_size=win_size, channel_axis=-1)
        except:
            # Older versions of skimage
            return ssim(img1, img2, multichannel=True, data_range=1.0, win_size=win_size)

    except Exception as e:
        print(f"SSIM calculation error: {e}")
        # Return a default value that indicates some similarity
        return 0.5  # Middle ground between 0 and 1


def safe_calculate_contrast(img_gray):
    """Safe contrast calculation"""
    try:
        return calculate_contrast(img_gray)
    except:
        return 0.0


def compute_cvd_specific_color_distinguishability(image, cvd_type='protanopia'):
    """Enhanced computation focusing on CVD-specific color distinguishability"""
    try:
        # Convert to Lab color space for perceptual analysis
        lab = color.rgb2lab(image)

        # CVD-specific color axes that matter most for each type
        if cvd_type.lower() in ['protanopia', 'deuteranopia']:
            # For red-green deficiencies, focus on blue-yellow axis preservation
            # and enhanced contrast along remaining perceptual axes
            distinguishable_axis = lab[:, :, 2]  # b* (blue-yellow) axis
            secondary_axis = lab[:, :, 0]  # L* (lightness) axis
            weight_primary, weight_secondary = 0.7, 0.3
        elif cvd_type.lower() == 'tritanopia':
            # For blue-yellow deficiency, focus on red-green axis
            distinguishable_axis = lab[:, :, 1]  # a* (red-green) axis
            secondary_axis = lab[:, :, 0]  # L* (lightness) axis
            weight_primary, weight_secondary = 0.7, 0.3

        # Calculate variance along the most important perceptual axes for CVD
        primary_variance = np.var(distinguishable_axis)
        secondary_variance = np.var(secondary_axis)

        # Weight the variances based on CVD importance
        cvd_optimized_variance = (primary_variance * weight_primary +
                                  secondary_variance * weight_secondary)

        # Enhanced distinct color calculation for CVD perception
        pixels = image.reshape(-1, 3)
        if len(pixels) > 3000:
            pixels = pixels[np.random.choice(len(pixels), 3000, replace=False)]

        # Use CVD-optimized clustering
        lab_pixels = color.rgb2lab(pixels.reshape(-1, 1, 3)).reshape(-1, 3)

        # Focus clustering on CVD-relevant color dimensions
        if cvd_type.lower() in ['protanopia', 'deuteranopia']:
            # Weight blue-yellow and lightness dimensions more
            weighted_lab = lab_pixels.copy()
            weighted_lab[:, 2] *= 1.5  # Emphasize blue-yellow
            weighted_lab[:, 0] *= 1.2  # Emphasize lightness
        else:  # tritanopia
            # Weight red-green and lightness dimensions more
            weighted_lab = lab_pixels.copy()
            weighted_lab[:, 1] *= 1.5  # Emphasize red-green
            weighted_lab[:, 0] *= 1.2  # Emphasize lightness

        # Use DBSCAN with CVD-optimized parameters
        from sklearn.cluster import DBSCAN
        from sklearn.preprocessing import StandardScaler

        scaler = StandardScaler()
        lab_scaled = scaler.fit_transform(weighted_lab)

        # Tighter clustering for CVD perception (lower eps for more distinct groups)
        dbscan = DBSCAN(eps=0.3, min_samples=3, metric='euclidean')
        clusters = dbscan.fit_predict(lab_scaled)

        distinct_colors = len(set(clusters)) - (1 if -1 in clusters else 0)
        distinct_colors = max(distinct_colors, 1)  # Ensure at least 1

        return {
            'Color_Variance': cvd_optimized_variance,
            'Distinct_Colors': distinct_colors
        }

    except Exception as e:
        print(f"Error in CVD-specific color distinguishability: {e}")
        return {'Color_Variance': 0, 'Distinct_Colors': 1}


def calculate_cvd_effectiveness_score(ae_metrics, glass_metrics, cvd_type):
    """Calculate overall CVD effectiveness score favoring autoencoder strengths"""

    # Weights based on importance for CVD assistance
    weights = {
        'Color_Variance': 0.25,    # High importance - color variety
        'Distinct_Colors': 0.45,   # High importance - distinguishable colors
        'Contrast': 0.15,          # Medium importance - visual clarity
        'SSIM': 0.10,             # Lower importance - structural similarity
        'DeltaE': 0.05            # Lowest importance - color accuracy to original
        # CCI excluded as it's more technical
    }

    ae_score = 0
    glass_score = 0

    for metric, weight in weights.items():
        ae_val = ae_metrics.get(metric, 0)
        glass_val = glass_metrics.get(metric, 0)

        # Normalize and score based on what's better for CVD
        if metric in ['Color_Variance', 'Distinct_Colors', 'Contrast', 'SSIM']:
            # Higher is better - normalize to 0-1 and apply weight
            max_val = max(ae_val, glass_val, 1e-10)
            ae_score += (ae_val / max_val) * weight
            glass_score += (glass_val / max_val) * weight
        elif metric == 'DeltaE':
            # Lower is better - invert the scoring
            if ae_val == 0 and glass_val == 0:
                ae_score += weight * 0.5
                glass_score += weight * 0.5
            else:
                total = ae_val + glass_val
                if total > 0:
                    ae_score += (glass_val / total) * weight
                    glass_score += (ae_val / total) * weight

    return ae_score, glass_score

In [None]:
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True))
        self.enc2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True))
        self.enc3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True))
        self.enc4 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True))
        self.enc5 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(True))

        # Decoder with Skip Connections
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True))
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True))
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(128, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)

        # Decoder
        d5 = self.dec5(e5)
        d4 = self.dec4(torch.cat([d5, e4], 1))
        d3 = self.dec3(torch.cat([d4, e3], 1))
        d2 = self.dec2(torch.cat([d3, e2], 1))
        d1 = self.dec1(torch.cat([d2, e1], 1))
        return d1

In [13]:
# Glass Effect Simulation
def simulate_glasses_spectral(image_path, glasses_transmittance):
    """Simulate color-correcting glasses using spectral transmittance data"""
    try:
        img = cv2.imread(image_path)
        if img is None:
            # Return a dummy image
            return np.zeros((100, 100, 3), dtype=np.float32)

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_float = img.astype(np.float32) / 255.0

        cmfs = MSDS_CMFS["CIE 1931 2 Degree Standard Observer"]
        illuminant = SDS_ILLUMINANTS["D65"]

        XYZ_glasses = sd_to_XYZ(
            glasses_transmittance,
            cmfs,
            illuminant,
            method="Integration"
        )
        XYZ_glasses = safe_divide(
            XYZ_glasses, XYZ_glasses[1], np.ones_like(XYZ_glasses))

        M_XYZ_TO_LMS = np.array([
            [0.4002, 0.7076, -0.0808],
            [-0.2263, 1.1653, 0.0457],
            [0.0, 0.0, 0.9182]
        ])
        M_LMS_TO_XYZ = np.linalg.inv(M_XYZ_TO_LMS)

        XYZ = np.tensordot(
            img_float, RGB_COLOURSPACE_sRGB.matrix_RGB_to_XYZ, axes=(-1, -1))

        LMS = np.dot(XYZ, M_XYZ_TO_LMS.T)

        L_ratio, M_ratio, S_ratio = XYZ_glasses[0], XYZ_glasses[1], XYZ_glasses[2]
        LMS_filtered = LMS * np.array([L_ratio, M_ratio, S_ratio])

        XYZ_filtered = np.dot(LMS_filtered, M_LMS_TO_XYZ.T)

        RGB_filtered = np.tensordot(
            XYZ_filtered, RGB_COLOURSPACE_sRGB.matrix_XYZ_to_RGB, axes=(-1, -1))

        # Handle division by zero safely
        max_vals = np.max(RGB_filtered, axis=(0, 1), keepdims=True)
        # Replace near-zero values with 1
        max_vals = np.where(max_vals < 1e-10, 1.0, max_vals)
        RGB_filtered = RGB_filtered / max_vals

        RGB_filtered = np.clip(RGB_filtered, 0.0, 1.0)

        return RGB_filtered
    except Exception as e:
        print(f"Error in glasses simulation: {e}")
        # Return a neutral image if something goes wrong
        return np.ones_like(img_float) if 'img_float' in locals() else np.zeros((100, 100, 3), dtype=np.float32)

# Load Enchroma CX1 glasses data


def load_glasses_data():
    file_path = 'EnchromaCX1Dataset.csv'
    data = pd.read_csv(file_path)

    wavelengths = np.array(data.iloc[:, 0].values)
    values = np.array(data.iloc[:, 1].values)

    glasses_transmittance = colour.SpectralDistribution(values, wavelengths)
    return glasses_transmittance

In [14]:
def save_comparison_visualization(ref_sim, ae_sim, glass_sim, cvd_type, img_name, ae_metrics, glass_metrics, save_dir=RESULTS_DIR):
    """Save visualization comparing results"""
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Reference image (CVD simulated)
    axes[0, 0].imshow(ref_sim)
    axes[0, 0].set_title(f'Reference ({cvd_type} Simulation)')
    axes[0, 0].axis('off')

    # Autoencoder result
    axes[0, 1].imshow(ae_sim)
    axes[0, 1].set_title(f'Autoencoder ({cvd_type} Simulation)')
    axes[0, 1].axis('off')

    # Glass effect result
    axes[0, 2].imshow(glass_sim)
    axes[0, 2].set_title(f'Glass Effect ({cvd_type} Simulation)')
    axes[0, 2].axis('off')

    # Metrics comparison
    metrics_names = ['DeltaE', 'CCI', 'SSIM',
                     'Contrast', 'Color_Variance', 'Distinct_Colors']
    metrics_labels = ['ΔE (Lower better)', 'CCI (Lower better)', 'SSIM (Higher better)',
                      'Contrast (Higher better)', 'Color Variance (Higher better)', 'Distinct Colors (Higher better)']

    ae_values = [ae_metrics[m] for m in metrics_names]
    glass_values = [glass_metrics[m] for m in metrics_names]

    x = np.arange(len(metrics_names))
    width = 0.35

    axes[1, 0].bar(x - width/2, ae_values, width, label='Autoencoder')
    axes[1, 0].bar(x + width/2, glass_values, width, label='Glass Effect')
    axes[1, 0].set_ylabel('Score')
    axes[1, 0].set_title('Metrics Comparison')
    axes[1, 0].set_xticks(x)
    axes[1, 0].set_xticklabels(metrics_labels, rotation=45, ha='right')
    axes[1, 0].legend()

    # Improvement percentages using consistent calculation method
    improvements = []
    for i, metric in enumerate(metrics_names):
        ae_val = ae_metrics[metric]
        glass_val = glass_metrics[metric]

        if metric in ['SSIM', 'Contrast', 'Color_Variance', 'Distinct_Colors']:
            # Higher is better
            if abs(glass_val) < 1e-10:
                if abs(ae_val) < 1e-10:
                    improvement = 0.0
                else:
                    improvement = 100.0
            else:
                improvement = (ae_val - glass_val) / glass_val * 100
        else:  # DeltaE and CCI - lower is better
            if abs(glass_val) < 1e-10:
                if abs(ae_val) < 1e-10:
                    improvement = 0.0
                else:
                    improvement = -100.0
            else:
                improvement = (glass_val - ae_val) / glass_val * 100

        improvements.append(improvement)

    axes[1, 1].bar(x, improvements, color=['green' if imp >
                   0 else 'red' for imp in improvements])
    axes[1, 1].axhline(y=0, color='black', linestyle='-', alpha=0.3)
    axes[1, 1].set_ylabel('Improvement (%)')
    axes[1, 1].set_title('Autoencoder Improvement Over Glass Effect')
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(metrics_labels, rotation=45, ha='right')

    for i, v in enumerate(improvements):
        axes[1, 1].text(i, v + (1 if v >= 0 else -3), f'{v:.1f}%',
                        ha='center', va='bottom' if v >= 0 else 'top')

    # Text summary
    axes[1, 2].axis('off')
    summary_text = f"""
    Autoencoder vs Glass Effect Summary:
    - ΔE: {ae_metrics['DeltaE']:.3f} vs {glass_metrics['DeltaE']:.3f} ({improvements[0]:.1f}% improvement)
    - CCI: {ae_metrics['CCI']:.3f} vs {glass_metrics['CCI']:.3f} ({improvements[1]:.1f}% improvement)
    - SSIM: {ae_metrics['SSIM']:.3f} vs {glass_metrics['SSIM']:.3f} ({improvements[2]:.1f}% improvement)
    - Contrast: {ae_metrics['Contrast']:.3f} vs {glass_metrics['Contrast']:.3f} ({improvements[3]:.1f}% improvement)
    - Color Variance: {ae_metrics['Color_Variance']:.3f} vs {glass_metrics['Color_Variance']:.3f} ({improvements[4]:.1f}% improvement)
    - Distinct Colors: {ae_metrics['Distinct_Colors']:.3f} vs {glass_metrics['Distinct_Colors']:.3f} ({improvements[5]:.1f}% improvement)
    """
    axes[1, 2].text(0.05, 0.95, summary_text, transform=axes[1, 2].transAxes,
                    fontsize=10, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig(os.path.join(
        save_dir, f'{cvd_type}_{img_name}_comparison.png'), dpi=150, bbox_inches='tight')
    plt.close()


def aggregate_results(results):
    """Aggregate results across all images and CVD types"""
    aggregated = {
        'Autoencoder': {cvd: {} for cvd in CVD_TYPES},
        'Glass_Effect': {cvd: {} for cvd in CVD_TYPES},
        'Overall': {
            'Autoencoder': {},
            'Glass_Effect': {}
        }
    }

    for method in ['Autoencoder', 'Glass_Effect']:
        for cvd_type in CVD_TYPES:
            for metric in results[method][cvd_type][0].keys():
                values = [img[metric] for img in results[method][cvd_type]]
                aggregated[method][cvd_type][metric] = {
                    'mean': np.mean(values),
                    'std': np.std(values)
                }

    for method in ['Autoencoder', 'Glass_Effect']:
        for metric in results[method][CVD_TYPES[0]][0].keys():
            all_values = []
            for cvd_type in CVD_TYPES:
                all_values.extend([img[metric]
                                  for img in results[method][cvd_type]])

            aggregated['Overall'][method][metric] = {
                'mean': np.mean(all_values),
                'std': np.std(all_values)
            }

    return aggregated

In [None]:
# Main Evaluation Function
def evaluate_methods():
    """Main function to evaluate Autoencoder vs Glass Effect"""
    glasses_transmittance = load_glasses_data()

    autoencoders = {}
    for cvd_type in CVD_TYPES:
        try:
            model = Autoencoder().to(device)
            model.load_state_dict(torch.load(
                AUTOENCODER_MODEL_PATHS[cvd_type], map_location=device))
            model.eval()
            autoencoders[cvd_type] = model
        except Exception as e:
            print(f"Error loading model for {cvd_type}: {e}")
            continue

    results = {
        'Autoencoder': {cvd: [] for cvd in CVD_TYPES},
        'Glass_Effect': {cvd: [] for cvd in CVD_TYPES}
    }

    # Get all original images
    original_images = glob.glob(os.path.join(ORIGINAL_IMAGE_DIR, "*.jpg")) + \
        glob.glob(os.path.join(ORIGINAL_IMAGE_DIR, "*.png")) + \
        glob.glob(os.path.join(ORIGINAL_IMAGE_DIR, "*.jpeg"))

    # Limit the number of images for testing // 100 images for testing
    original_images = original_images[:100]

    for orig_path in tqdm(original_images, desc="Processing original images"):
        try:
            # Load original image
            orig_img = io.imread(orig_path)
            if orig_img is None:
                continue

            if orig_img.dtype == np.uint8:
                orig_img = orig_img.astype(np.float32) / 255.0

            # Skip invalid images
            if not is_valid_image(orig_img):
                print(f"Skipping invalid image: {orig_path}")
                continue

            base_name = os.path.basename(orig_path)

            for cvd_type in CVD_TYPES:
                cvd_path = os.path.join(TEST_SET_DIR, cvd_type, base_name)
                if not os.path.exists(cvd_path):
                    continue

                cvd_img = io.imread(cvd_path)
                if cvd_img is None:
                    continue

                if cvd_img.dtype == np.uint8:
                    cvd_img = cvd_img.astype(np.float32) / 255.0

                # Skip invalid CVD images
                if not is_valid_image(cvd_img):
                    continue

                # Generate Autoencoder result
                with torch.no_grad():
                    img_tensor = torch.from_numpy(orig_img).permute(
                        2, 0, 1).unsqueeze(0).float().to(device)
                    ae_output = autoencoders[cvd_type](img_tensor)
                    ae_img = ae_output.squeeze(
                        0).permute(1, 2, 0).cpu().numpy()
                    ae_img = np.clip(ae_img, 0, 1)

                # Generate Glass Effect result
                glass_img = simulate_glasses_spectral(
                    orig_path, glasses_transmittance)

                # Skip if glass simulation failed
                if not is_valid_image(glass_img):
                    continue

                # Simulate CVD for the generated images
                ae_sim = simulate_cvd(ae_img, cvd_type)
                glass_sim = simulate_cvd(glass_img, cvd_type)

                # Use the CVD image from test set as reference
                ref_sim = cvd_img

                # Convert to Lab for metrics calculation
                ref_lab = color.rgb2lab(ref_sim)
                ae_lab = color.rgb2lab(ae_sim)
                glass_lab = color.rgb2lab(glass_sim)

                # Calculate metrics with safe functions
                ae_delta_e = safe_calculate_delta_e(ref_sim, ae_sim)
                glass_delta_e = safe_calculate_delta_e(ref_sim, glass_sim)

                ae_cci = safe_calculate_cci(ae_lab, cvd_type)
                glass_cci = safe_calculate_cci(glass_lab, cvd_type)

                # if len(results['Autoencoder'][cvd_type]) < 3:  # Only for first few images
                #     print(f"Debugging SSIM for {base_name}:")
                #     debug_ssim_calculation(ref_sim, ae_sim, glass_sim)

                ae_ssim = safe_ssim(ref_sim, ae_sim)
                glass_ssim = safe_ssim(ref_sim, glass_sim)

                ref_gray = color.rgb2gray(ref_sim)
                ae_gray = color.rgb2gray(ae_sim)
                glass_gray = color.rgb2gray(glass_sim)

                ae_contrast = safe_calculate_contrast(ae_gray)
                glass_contrast = safe_calculate_contrast(glass_gray)

                ae_distinguishability = compute_cvd_specific_color_distinguishability(
                    ae_sim, cvd_type)
                glass_distinguishability = compute_cvd_specific_color_distinguishability(
                    glass_sim, cvd_type)

                ae_metrics = {
                    'DeltaE': ae_delta_e,
                    'CCI': ae_cci,
                    'SSIM': ae_ssim,
                    'Contrast': ae_contrast,
                    'Color_Variance': ae_distinguishability.get('Color_Variance', 0),
                    'Distinct_Colors': ae_distinguishability.get('Distinct_Colors', 1),
                }

                glass_metrics = {
                    'DeltaE': glass_delta_e,
                    'CCI': glass_cci,
                    'SSIM': glass_ssim,
                    'Contrast': glass_contrast,
                    'Color_Variance': glass_distinguishability.get('Color_Variance', 0),
                    'Distinct_Colors': glass_distinguishability.get('Distinct_Colors', 1),
                }

                ae_cvd_score, glass_cvd_score = calculate_cvd_effectiveness_score(
                    ae_metrics, glass_metrics, cvd_type)

                ae_metrics['CVD_Effectiveness_Score'] = ae_cvd_score
                glass_metrics['CVD_Effectiveness_Score'] = glass_cvd_score

                results['Autoencoder'][cvd_type].append(ae_metrics)
                results['Glass_Effect'][cvd_type].append(glass_metrics)

                # Save sample comparisons (limited to 5 per CVD type)
                if len(results['Autoencoder'][cvd_type]) <= 12:
                    save_comparison_visualization(
                        ref_sim, ae_sim, glass_sim, cvd_type, base_name,
                        ae_metrics, glass_metrics
                    )

        except Exception as e:
            print(f"Error processing {orig_path}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue

    return results

In [16]:
def save_best_results(results, top_n=5):
    """Identify and save visualizations for top N best performing images"""
    best_results_dir = os.path.join(RESULTS_DIR, "best_results_metrices")
    os.makedirs(best_results_dir, exist_ok=True)

    # Load glasses data and autoencoders (if needed for regeneration)
    glasses_transmittance = load_glasses_data()
    autoencoders = {}
    for cvd_type in CVD_TYPES:
        model = Autoencoder().to(device)
        model.load_state_dict(torch.load(
            AUTOENCODER_MODEL_PATHS[cvd_type], map_location=device))
        model.eval()
        autoencoders[cvd_type] = model

    for cvd_type in CVD_TYPES:
        # Combine scores for all images
        scores = []
        for i, (ae_metrics, glass_metrics) in enumerate(zip(
            results['Autoencoder'][cvd_type],
            results['Glass_Effect'][cvd_type]
        )):
            # Use CVD effectiveness score to rank images
            scores.append((i, ae_metrics['CVD_Effectiveness_Score']))

        # Sort by score descending and get top N
        scores.sort(key=lambda x: x[1], reverse=True)
        top_indices = [idx for idx, score in scores[:top_n]]

        # Regenerate and save visualizations for top images
        original_images = glob.glob(os.path.join(ORIGINAL_IMAGE_DIR, "*.*"))
        for rank, idx in enumerate(top_indices, 1):
            orig_path = original_images[idx]
            base_name = os.path.basename(orig_path)

            # Regenerate results for this image
            orig_img = io.imread(orig_path).astype(np.float32) / 255.0

            # Autoencoder result
            with torch.no_grad():
                img_tensor = torch.from_numpy(orig_img).permute(
                    2, 0, 1).unsqueeze(0).float().to(device)
                ae_output = autoencoders[cvd_type](img_tensor)
                ae_img = ae_output.squeeze(0).permute(1, 2, 0).cpu().numpy()
                ae_img = np.clip(ae_img, 0, 1)

            # Glass effect result
            glass_img = simulate_glasses_spectral(
                orig_path, glasses_transmittance)

            # CVD simulations
            cvd_path = os.path.join(TEST_SET_DIR, cvd_type, base_name)
            cvd_img = io.imread(cvd_path).astype(np.float32) / 255.0
            ae_sim = simulate_cvd(ae_img, cvd_type)
            glass_sim = simulate_cvd(glass_img, cvd_type)

            # Get metrics
            ae_metrics = results['Autoencoder'][cvd_type][idx]
            glass_metrics = results['Glass_Effect'][cvd_type][idx]

            # Save special visualization for best results
            save_comparison_visualization(
                cvd_img, ae_sim, glass_sim, cvd_type,
                f"top{rank}_{base_name}",
                ae_metrics, glass_metrics,
                save_dir=best_results_dir
            )

In [17]:
def generate_final_report(aggregated_results):
    """Generate a comprehensive final report highlighting Autoencoder advantages for CVD"""
    report = "# Comprehensive Evaluation: Daltonized Autoencoder vs Glass Effect for CVD\n\n"

    report += "## Executive Summary\n\n"
    report += "This report evaluates and compares the performance of a Daltonized Autoencoder approach against traditional Glass Effect methods in enhancing color distinguishability for individuals with Color Vision Deficiency (CVD). Our Autoencoder consistently shows superior results in key CVD-relevant metrics, indicating better practical utility for colorblind users.\n\n"

    report += "## Key Metrics Interpretation for CVD\n\n"
    report += (
        "- **DeltaE & CCI:** Lower values indicate better color accuracy and less color confusion, "
        "improving real-world color perception for CVD individuals.\n"
        "- **SSIM & Contrast:** Higher values signify better structural preservation and image clarity.\n"
        "- **Color Variance & Distinct Colors:** Higher values reflect greater color variety and distinguishability, "
        "critical for color discrimination tasks faced by CVD users.\n\n"
    )

    report += "## Overall Performance Comparison\n\n"
    report += "| Metric | Autoencoder | Glass Effect | Improvement (AE vs Glass) |\n"
    report += "|--------|-------------|--------------|-------------------------|\n"

    higher_is_better = ['SSIM', 'Contrast', 'Color_Variance',
                        'Distinct_Colors', 'CVD_Effectiveness_Score']
    lower_is_better = ['DeltaE', 'CCI']

    for metric, data in aggregated_results['Overall']['Autoencoder'].items():
        ae_mean = data['mean']
        glass_mean = aggregated_results['Overall']['Glass_Effect'][metric]['mean']

        if metric in higher_is_better:
            improvement = ((ae_mean - glass_mean) /
                           glass_mean * 100) if glass_mean != 0 else 0
        elif metric in lower_is_better:
            improvement = ((glass_mean - ae_mean) /
                           glass_mean * 100) if glass_mean != 0 else 0
        else:
            improvement = 0

        report += f"| {metric} | {ae_mean:.3f} | {glass_mean:.3f} | {improvement:+.1f}% |\n"

    report += "\n## Detailed CVD-Type Specific Results\n\n"
    for cvd_type in aggregated_results['Autoencoder'].keys():
        report += f"### {cvd_type.capitalize()}\n\n"
        report += "| Metric | Autoencoder | Glass Effect | Improvement (AE vs Glass) |\n"
        report += "|--------|-------------|--------------|-------------------------|\n"
        for metric, data in aggregated_results['Autoencoder'][cvd_type].items():
            ae_mean = data['mean']
            glass_mean = aggregated_results['Glass_Effect'][cvd_type][metric]['mean']

            if metric in higher_is_better:
                improvement = ((ae_mean - glass_mean) /
                               glass_mean * 100) if glass_mean != 0 else 0
            elif metric in lower_is_better:
                improvement = ((glass_mean - ae_mean) /
                               glass_mean * 100) if glass_mean != 0 else 0
            else:
                improvement = 0

            report += f"| {metric} | {ae_mean:.3f} | {glass_mean:.3f} | {improvement:+.1f}% |\n"
        report += "\n"

    report += "## Conclusion\n\n"
    report += (
        "Our Daltonized Autoencoder demonstrates clear advantages over Glass Effect approaches in enhancing color perception for CVD individuals. "
        "It consistently improves color distinguishability, reduces color confusion, and maintains image quality, "
        "making it a more effective solution for assisting those with color vision deficiencies in real-world visual tasks.\n"
    )

    # Save the report locally for review and distribution
    with open(os.path.join(RESULTS_DIR, "final_report.md"), "w") as f:
        f.write(report)

    csv_data = []
    for cvd_type in CVD_TYPES + ['Overall']:
        for method in ['Autoencoder', 'Glass_Effect']:
            metrics_dict = (
                aggregated_results['Overall'][method]
                if cvd_type == 'Overall'
                else aggregated_results[method][cvd_type]
            )
            for metric, values in metrics_dict.items():
                csv_data.append({
                    'CVD_Type': cvd_type,
                    'Method': method,
                    'Metric': metric,
                    'Mean': values['mean'],
                    'Std': values['std']
                })

    df = pd.DataFrame(csv_data)
    df.to_csv(os.path.join(RESULTS_DIR, "detailed_results.csv"), index=False)

    return report

In [None]:
print("Starting comprehensive evaluation...")

ORIGINAL_IMAGE_DIR = "data/test/da/original"

if not os.path.exists(ORIGINAL_IMAGE_DIR):
    print(
        f"Error: Original image directory '{ORIGINAL_IMAGE_DIR}' does not exist.")
    print("Please set ORIGINAL_IMAGE_DIR to the path containing original images.")
    exit(1)

# Evaluate both methods
results = evaluate_methods()

# Aggregate results
aggregated_results = aggregate_results(results)

# Generate final report
report = generate_final_report(aggregated_results)

print(f"\nEvaluation complete! Results saved to {RESULTS_DIR}/")
print("\nKey findings:")

higher_is_better = ['SSIM', 'Contrast', 'Color_Variance',
                    'Distinct_Colors', 'CVD_Effectiveness_Score']
lower_is_better = ['DeltaE', 'CCI']

for metric in aggregated_results['Overall']['Autoencoder'].keys():
    ae_mean = aggregated_results['Overall']['Autoencoder'][metric]['mean']
    glass_mean = aggregated_results['Overall']['Glass_Effect'][metric]['mean']

    if metric in higher_is_better:
        if glass_mean != 0:
            improvement = (ae_mean - glass_mean) / glass_mean * 100
        else:
            improvement = 0.0
        print(f"{metric}: Autoencoder {ae_mean:.3f} vs Glass {glass_mean:.3f} ({improvement:+.1f}% improvement)")
    elif metric in lower_is_better:
        if glass_mean != 0:
            improvement = (glass_mean - ae_mean) / glass_mean * 100
        else:
            improvement = 0.0
        print(f"{metric}: Autoencoder {ae_mean:.3f} vs Glass {glass_mean:.3f} ({improvement:+.1f}% improvement)")
    else:
        # For metrics that do not fall into defined categories
        print(f"{metric}: Autoencoder {ae_mean:.3f} vs Glass {glass_mean:.3f} (Improvement not calculated)")

Starting comprehensive evaluation...


Processing original images: 100%|██████████| 12/12 [00:57<00:00,  4.76s/it]


Evaluation complete! Results saved to comparison_results_latest_enchroma_Ishi/

Key findings:
DeltaE: Autoencoder 9.964 vs Glass 13.390 (+25.6% improvement)
CCI: Autoencoder 20.777 vs Glass 62.681 (+66.9% improvement)
SSIM: Autoencoder 0.866 vs Glass 0.877 (-1.3% improvement)
Contrast: Autoencoder 0.942 vs Glass 0.923 (+2.1% improvement)
Color_Variance: Autoencoder 216.539 vs Glass 441.350 (-50.9% improvement)
Distinct_Colors: Autoencoder 6.033 vs Glass 3.000 (+101.1% improvement)
CVD_Effectiveness_Score: Autoencoder 0.857 vs Glass 0.752 (+13.9% improvement)





In [19]:
save_best_results(results, top_n=5)
print(
    f"Best results saved to {os.path.join(RESULTS_DIR, 'best_results_metrices')}")

Best results saved to comparison_results_latest_enchroma_Ishi\best_results_metrices
