# CycleGAN Test: IHC to HE Virtual Staining Evaluation

This notebook provides a comprehensive evaluation of IHC → HE translation using a trained CycleGAN model.

## Evaluation Metrics (Standard for Virtual Staining)
- **PSNR**: Peak Signal-to-Noise Ratio
- **SSIM**: Structural Similarity Index
- **LPIPS**: Learned Perceptual Image Patch Similarity
- **FID**: Fréchet Inception Distance (GAN standard)
- **Color Histogram Similarity**: Critical for staining evaluation

## 1. Import Libraries and Setup

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import inception_v3
from PIL import Image
import torch.utils.data as data
from glob import glob
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
from scipy.stats import gaussian_kde
import pandas as pd
import seaborn as sns
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
import os
warnings.filterwarnings('ignore')

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent and removes the
    check-then-create race of ``os.path.exists`` followed by ``os.makedirs``.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

## 2. Install Required Packages

In [None]:
!pip install lpips scikit-image -q

## 3. Test Configuration

In [None]:
# Evaluation settings (everything a reader might want to tune)
params = dict(
    batch_size=1,
    input_size=512,
    test_sample_count=100,
    model_epoch=99,
    img_form='png',
)

# Input patches, trained weights and output locations
data_dir = '../../data/IHC_HE_Pair_Data_GA_SS/patches'
model_dir = '../../model/HE_IHC_translation/internal_ss/PD-L1'
result_dir = '../../results/HE_IHC_translation/internal_ss/PD-L1/test_results'

# Make sure every output folder exists before anything is written to it
for out_dir in (result_dir, f'{result_dir}/visualizations', f'{result_dir}/generated_images'):
    create_dir(out_dir)

# Echo the configuration so it is recorded in the notebook output
rule = "="*60
print("\n".join([
    rule,
    "TEST CONFIGURATION",
    rule,
    f"Model epoch: {params['model_epoch']}",
    f"Test samples: {params['test_sample_count']}",
    f"Image size: {params['input_size']}x{params['input_size']}",
    f"Results directory: {result_dir}",
    rule,
]))

## 4. Define Generator Architecture

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: conv-norm-ReLU-conv-norm-dropout plus an identity skip.

    The channel count (``features``) and spatial size are unchanged, so the
    block output can be added directly to its input.
    """

    def __init__(self, features):
        super().__init__()
        # Layer order fixes the state_dict keys (block.0 / block.3); keep it stable.
        layers = [
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.Dropout(0.5),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layout: 7x7 stem conv to 64 channels, four stride-2 convs that double the
    channel count each time, ``n_residual_blocks`` residual blocks, four
    stride-2 transposed convs that halve the channel count, and a 7x7 output
    conv followed by Tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        n_residual_blocks: number of ``ResidualBlock`` modules in the bottleneck.
    """

    def __init__(self, input_channels, output_channels, n_residual_blocks=9):
        super().__init__()
        n_scales = 4

        # Stem
        layers = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Encoder: halve the resolution, double the channels
        channels = 64
        for _ in range(n_scales):
            wider = channels * 2
            layers.extend([
                nn.Conv2d(channels, wider, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(wider),
                nn.ReLU(inplace=True),
            ])
            channels = wider

        # Bottleneck
        for _ in range(n_residual_blocks):
            layers.append(ResidualBlock(channels))

        # Decoder: double the resolution, halve the channels
        for _ in range(n_scales):
            narrower = channels // 2
            layers.extend([
                nn.ConvTranspose2d(channels, narrower, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(narrower),
                nn.ReLU(inplace=True),
            ])
            channels = narrower

        # Output head; Tanh bounds the result to [-1, 1]
        layers.append(nn.Conv2d(64, output_channels, kernel_size=7, padding=3))
        layers.append(nn.Tanh())

        # Module positions inside the Sequential define the checkpoint keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

print("✓ Generator architecture defined")

## 5. Load Test Dataset

In [None]:
class DatasetFromFolder(data.Dataset):
    """Paired dataset yielding ``(HE, IHC)`` items from two index-aligned sequences."""

    def __init__(self, HE_image_list, IHC_image_list):
        super().__init__()
        self.HE_image_list, self.IHC_image_list = HE_image_list, IHC_image_list

    def __getitem__(self, index):
        he_item = self.HE_image_list[index]
        ihc_item = self.IHC_image_list[index]
        return he_item, ihc_item

    def __len__(self):
        # The HE sequence defines the dataset length.
        return len(self.HE_image_list)

# Transform: resize to an exact square. Resize(int) only rescales the shorter edge,
# so a non-square patch would not fit the preallocated (3, S, S) tensors below.
transform = transforms.Compose([
    transforms.Resize(size=(params['input_size'], params['input_size'])),
    transforms.ToTensor()
])

# Load test data
print("Loading test data...")
# Sort first: glob order depends on the filesystem and would defeat the seeded sampling.
test_data_HE = sorted(glob(f'{data_dir}/he_mpp1/*.{params["img_form"]}'))
if not test_data_HE:
    raise FileNotFoundError(f"No *.{params['img_form']} images found in {data_dir}/he_mpp1")

test_max_count = min(params['test_sample_count'], len(test_data_HE))
# Private seeded RNG: the same test subset is drawn on every run without touching
# the global `random` state. Override via params['seed'].
test_data_HE = random.Random(params.get('seed', 42)).sample(test_data_HE, test_max_count)
# Derive the IHC partner from the file name; replacing '/he_mpp1/' in the path string
# silently does nothing when the OS uses a different path separator.
test_data_IHC = [os.path.join(data_dir, 'pdl1_mpp1', os.path.basename(f)) for f in test_data_HE]

missing_pairs = [p for p in test_data_IHC if not os.path.isfile(p)]
if missing_pairs:
    raise FileNotFoundError(
        f"{len(missing_pairs)} HE patch(es) have no IHC partner, e.g. {missing_pairs[0]}"
    )

print(f"Found {len(test_data_HE)} test image pairs")

# Preload images into two index-aligned tensors scaled to [-1, 1]
test_image_HE = torch.zeros((len(test_data_HE), 3, params['input_size'], params['input_size']))
test_image_IHC = torch.zeros((len(test_data_IHC), 3, params['input_size'], params['input_size']))

for i in tqdm(range(len(test_data_HE)), desc="Loading test images"):
    img = Image.open(test_data_HE[i]).convert('RGB')
    target = Image.open(test_data_IHC[i]).convert('RGB')
    # ToTensor yields [0, 1]; rescale to the [-1, 1] range that denormalize() inverts
    img = transform(img) * 2. - 1
    target = transform(target) * 2. - 1
    test_image_HE[i] = img
    test_image_IHC[i] = target

test_dataset = DatasetFromFolder(test_image_HE, test_image_IHC)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=params['batch_size'],
    shuffle=False
)

print(f"✓ Test dataset loaded: {len(test_dataset)} image pairs")

## 6. Load Trained Model (Generator F: IHC → HE)

In [None]:
# Initialize and load model
# Build a 3-channel -> 3-channel generator and restore the weights saved for the
# epoch selected in params['model_epoch'].
F = Generator(3, 3).to(device)
model_path = f'{model_dir}/F_{params["model_epoch"]}.pth'
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
F.load_state_dict(torch.load(model_path, map_location=device))
# Inference mode: the Dropout layers inside ResidualBlock become inactive.
F.eval()

print(f"✓ Loaded model: {model_path}")
print(f"✓ Generator F (IHC → HE) ready for testing")

## 7. Define Evaluation Metrics

In [None]:
import lpips

# Initialize the LPIPS perceptual metric (AlexNet backbone) on the evaluation device.
# The single instance is reused by calculate_lpips for every image pair.
lpips_model = lpips.LPIPS(net='alex').to(device)

def denormalize(img):
    """Map values from the [-1, 1] range back to [0, 1]."""
    half = 0.5
    scaled = img * half
    return scaled + half

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio between two 3-D tensors.

    Both tensors are moved to the CPU and transposed with ``(1, 2, 0)``
    (channels moved last) before being compared with ``data_range=1.0``.
    """
    first, second = (t.cpu().numpy().transpose(1, 2, 0) for t in (img1, img2))
    return psnr(first, second, data_range=1.0)

def calculate_ssim(img1, img2):
    """Structural similarity between two 3-D tensors.

    Both tensors are moved to the CPU and transposed with ``(1, 2, 0)`` so the
    channel axis is last (``channel_axis=2``); ``data_range`` is fixed to 1.0.
    """
    first, second = (t.cpu().numpy().transpose(1, 2, 0) for t in (img1, img2))
    # NOTE(review): `multichannel` and `channel_axis` are both passed; the former looks
    # like the legacy spelling of the latter — confirm against the installed scikit-image.
    return ssim(first, second, multichannel=True, data_range=1.0, channel_axis=2)

def calculate_lpips(img1, img2):
    """LPIPS distance between two single images, returned as a Python float.

    A leading batch dimension is added to each image before both are passed
    to the module-level ``lpips_model``; gradients are disabled.
    """
    batch1 = img1.unsqueeze(0)
    batch2 = img2.unsqueeze(0)
    with torch.no_grad():
        distance = lpips_model(batch1, batch2)
    return distance.item()

def calculate_mae(img1, img2):
    """Mean absolute error over all elements, returned as a Python float."""
    return (img1 - img2).abs().mean().item()

def calculate_histogram_similarity(img1, img2, bins=256):
    """Mean Pearson correlation of the per-channel intensity histograms.

    For each of the first three channels, a ``bins``-bin histogram over
    [0, 1] is built for both images, normalized to sum to 1, and the two
    histograms are correlated; the three correlations are averaged.
    """
    def _normalized_hist(plane):
        counts, _ = np.histogram(plane.flatten(), bins=bins, range=(0, 1))
        return counts / counts.sum()

    per_channel = [
        np.corrcoef(_normalized_hist(img1[ch]), _normalized_hist(img2[ch]))[0, 1]
        for ch in range(3)
    ]
    return np.mean(per_channel)

print("✓ Evaluation metrics initialized")
print("  - PSNR, SSIM, LPIPS, MAE")
print("  - Color Histogram Similarity")
print("  - FID (will be calculated separately)")

## 8. Run Evaluation on Test Set

In [None]:
# Storage for per-image metrics (one entry per test pair)
results = {
    'psnr': [],
    'ssim': [],
    'lpips': [],
    'mae': [],
    'hist_sim': []
}

# Storage for ALL generated images (needed for best/worst analysis)
generated_images = []
real_he_images = []
input_ihc_images = []

print("\n" + "="*60)
print("RUNNING EVALUATION ON TEST SET")
print("="*60)

with torch.no_grad():
    for real_he, real_ihc in tqdm(test_loader, desc="Evaluating"):
        real_he = real_he.to(device)
        real_ihc = real_ihc.to(device)

        # Generate fake HE from IHC
        fake_he = F(real_ihc)

        # Score every sample of the batch. Indexing only element 0 would silently
        # drop samples whenever params['batch_size'] > 1.
        for b in range(real_he.size(0)):
            # Denormalize to [0, 1] for the pixel-space metrics
            real_he_denorm = denormalize(real_he[b])
            fake_he_denorm = denormalize(fake_he[b])

            # Calculate metrics (LPIPS is fed the [-1, 1] tensors, the rest the [0, 1] ones)
            results['psnr'].append(calculate_psnr(fake_he_denorm, real_he_denorm))
            results['ssim'].append(calculate_ssim(fake_he_denorm, real_he_denorm))
            results['lpips'].append(calculate_lpips(fake_he[b], real_he[b]))
            results['mae'].append(calculate_mae(fake_he_denorm, real_he_denorm))
            results['hist_sim'].append(calculate_histogram_similarity(fake_he_denorm.cpu(), real_he_denorm.cpu()))

            # Store all images on the CPU for FID and the visualizations
            generated_images.append(fake_he_denorm.cpu())
            real_he_images.append(real_he_denorm.cpu())
            input_ihc_images.append(denormalize(real_ihc[b]).cpu())

print("\n" + "="*60)
print("✓ Evaluation completed!")
print("="*60)

## 9. Calculate FID Score

In [None]:
from scipy.linalg import sqrtm

# Load Inception V3
# Pretrained Inception V3 used as the feature extractor for FID.
# NOTE(review): `pretrained=` is reportedly deprecated in newer torchvision releases in
# favour of `weights=` — confirm against the installed version.
print("Loading Inception V3 model for FID calculation...")
inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
# Replace the classifier head with Identity so the forward pass returns the
# features that would otherwise feed the final fc layer.
inception_model.fc = nn.Identity()
inception_model.eval()

def get_inception_features(images):
    """Run each image through ``inception_model`` and stack the outputs.

    Every image gets a batch dimension, is bilinearly resized to 299x299,
    rescaled with ``x * 2 - 1`` and passed through the network on ``device``.
    The per-image outputs are concatenated along axis 0 into one NumPy array.
    """
    collected = []
    with torch.no_grad():
        for img in tqdm(images, desc="Extracting features"):
            # Inception input resolution
            batch = torch.nn.functional.interpolate(
                img.unsqueeze(0), size=(299, 299), mode='bilinear', align_corners=False
            )
            # [0, 1] -> [-1, 1]
            batch = batch * 2 - 1
            output = inception_model(batch.to(device))
            collected.append(output.cpu().numpy())
    return np.concatenate(collected, axis=0)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        real_features: array of shape (n_real, dim).
        fake_features: array of shape (n_fake, dim).
        eps: diagonal offset added to both covariances when the matrix square
            root of their product is not finite (e.g. rank-deficient covariances
            obtained from fewer samples than feature dimensions).

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if the feature dimensions differ or either set has fewer
            than two samples (the covariance would be undefined).
    """
    real_features = np.atleast_2d(np.asarray(real_features, dtype=np.float64))
    fake_features = np.atleast_2d(np.asarray(fake_features, dtype=np.float64))

    if real_features.shape[1] != fake_features.shape[1]:
        raise ValueError("real and fake features must have the same dimensionality")
    if real_features.shape[0] < 2 or fake_features.shape[0] < 2:
        raise ValueError("FID needs at least two samples per set to estimate a covariance")

    mu1, mu2 = real_features.mean(axis=0), fake_features.mean(axis=0)
    # atleast_2d keeps the 1-dimensional feature case a (1, 1) matrix
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))

    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))

    # A singular product can yield inf/nan; retry with slightly regularized covariances
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a tiny imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)

# Extract features
# Both image lists were filled by the evaluation loop and hold [0, 1] CPU tensors.
print("Extracting features from real HE images...")
real_features = get_inception_features(real_he_images)

print("Extracting features from generated HE images...")
fake_features = get_inception_features(generated_images)

# Calculate FID
# NOTE(review): the covariances are estimated from only len(real_he_images) samples;
# if that is far below the feature dimension the estimate may be unstable — confirm
# the sample count is adequate before quoting this number.
fid_score = calculate_fid(real_features, fake_features)

print("\n" + "="*60)
print(f"FID Score: {fid_score:.2f}")
print("="*60)
# The bands below are the notebook's own rule of thumb for reading the score.
print("Interpretation:")
print("  - FID < 50: Excellent")
print("  - FID 50-100: Good")
print("  - FID 100-200: Moderate")
print("  - FID > 200: Poor")
print("="*60)

## 10. Calculate Statistics

In [None]:
# Summary statistics per metric over all evaluated pairs
statistics = {
    metric_name: {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'median': np.median(values)
    }
    for metric_name, values in results.items()
}

# FID is a single set-level number: store it in the same layout with zero spread
statistics['fid'] = dict(mean=fid_score, std=0, min=fid_score, max=fid_score, median=fid_score)

# Console table (median is kept for the CSV only)
metric_display = {
    'psnr': 'PSNR (dB)',
    'ssim': 'SSIM',
    'lpips': 'LPIPS',
    'mae': 'MAE',
    'hist_sim': 'Hist Corr',
    'fid': 'FID'
}

table_lines = [
    "\n" + "="*70,
    " "*20 + "QUANTITATIVE RESULTS",
    "="*70,
    f"{'Metric':<15} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}",
    "-"*70,
]
for metric_name, display_name in metric_display.items():
    stats = statistics[metric_name]
    table_lines.append(
        f"{display_name:<15} {stats['mean']:<12.4f} {stats['std']:<12.4f} "
        f"{stats['min']:<12.4f} {stats['max']:<12.4f}"
    )
table_lines.append("="*70)
print("\n".join(table_lines))

# Persist the full table (one row per metric) as CSV
csv_path = f'{result_dir}/quantitative_results.csv'
df = pd.DataFrame(statistics).T
df.columns = ['Mean', 'Std', 'Min', 'Max', 'Median']
df.to_csv(csv_path)
print(f"\n✓ Results saved to: {csv_path}")

## 11. Visualize Metric Distributions

In [None]:
# Set style
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# One panel per per-image metric; the sixth panel shows the FID score
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Distribution of Image Quality Metrics (IHC → HE Translation)', 
             fontsize=16, fontweight='bold', y=0.995)

metrics_info = [
    ('psnr', 'PSNR (dB)', 'Higher is better'),
    ('ssim', 'SSIM', 'Higher is better'),
    ('lpips', 'LPIPS', 'Lower is better'),
    ('mae', 'MAE', 'Lower is better'),
    ('hist_sim', 'Histogram Correlation', 'Higher is better')
]

for idx, (metric_name, label, interpretation) in enumerate(metrics_info):
    row = idx // 3
    col = idx % 3
    ax = axes[row, col]
    
    values = results[metric_name]
    
    # Histogram (density-normalized so the KDE shares its scale)
    ax.hist(values, bins=30, alpha=0.6, color='skyblue', edgecolor='black', density=True)
    
    # KDE line: gaussian_kde raises on fewer than two points or zero spread,
    # so it is skipped in those degenerate cases instead of crashing the cell.
    if len(values) > 1 and max(values) > min(values):
        kde = gaussian_kde(values)
        x_range = np.linspace(min(values), max(values), 100)
        ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')
    
    # Mean and median lines
    mean_val = statistics[metric_name]['mean']
    median_val = statistics[metric_name]['median']
    ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
    ax.axvline(median_val, color='orange', linestyle='--', linewidth=2, label=f'Median: {median_val:.4f}')
    
    ax.set_xlabel(label, fontsize=11, fontweight='bold')
    ax.set_ylabel('Density', fontsize=11, fontweight='bold')
    ax.set_title(f'{label}\n({interpretation})', fontsize=12)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

# FID display in last subplot
axes[1, 2].axis('off')
axes[1, 2].text(0.5, 0.5, f'FID Score\n{fid_score:.2f}', 
                ha='center', va='center', fontsize=24, fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

plt.tight_layout()
fig_path = f'{result_dir}/visualizations/metric_distributions.png'
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Metric distributions saved to: {fig_path}")

## 12. Color Histogram Analysis

In [None]:
# Dataset-level colour histograms: real vs generated HE, one panel per channel
def _pooled_channel_hist(images, ch_idx):
    """Sum the 256-bin [0, 1] histograms of one channel over all images and normalize to 1."""
    pooled = np.zeros(256)
    for image in images:
        counts, _ = np.histogram(image[ch_idx].flatten(), bins=256, range=(0, 1))
        pooled += counts
    return pooled / pooled.sum()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('Color Histogram Comparison: Generated vs Real HE', fontsize=14, fontweight='bold')

colors = ['red', 'green', 'blue']
channel_names = ['Red', 'Green', 'Blue']
x = np.linspace(0, 1, 256)

for ch_idx, (ax, color, channel_name) in enumerate(zip(axes, colors, channel_names)):
    real_hist = _pooled_channel_hist(real_he_images, ch_idx)
    gen_hist = _pooled_channel_hist(generated_images, ch_idx)

    # Solid line = real, dashed line = generated, both in the channel's own colour
    ax.plot(x, real_hist, color=color, linewidth=2, label='Real HE', alpha=0.7)
    ax.plot(x, gen_hist, color=color, linewidth=2, label='Generated HE', 
            linestyle='--', alpha=0.7)
    ax.set_title(f'{channel_name} Channel', fontsize=12, fontweight='bold')
    ax.set_xlabel('Intensity', fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
hist_path = f'{result_dir}/visualizations/color_histograms.png'
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Color histogram saved to: {hist_path}")
print(f"\nMean Histogram Correlation: {statistics['hist_sim']['mean']:.4f}")

## 13. Qualitative Results

In [None]:
# Show up to 8 samples: the first entries of the evaluated test set.
# Capping at the number of available images avoids an IndexError on small test sets.
num_samples = min(8, len(generated_images))
# squeeze=False keeps `axes` two-dimensional even when only one row is drawn
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)
fig.suptitle('Qualitative Results: IHC → HE Translation', fontsize=18, fontweight='bold')

for i in range(num_samples):
    # Input IHC
    axes[i, 0].imshow(input_ihc_images[i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title('Input (IHC)', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Generated HE
    axes[i, 1].imshow(generated_images[i].permute(1, 2, 0).numpy())
    axes[i, 1].set_title('Generated (HE)', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Real HE
    axes[i, 2].imshow(real_he_images[i].permute(1, 2, 0).numpy())
    axes[i, 2].set_title('Ground Truth (HE)', fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')

plt.tight_layout()
qual_path = f'{result_dir}/visualizations/qualitative_results.png'
plt.savefig(qual_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Qualitative results saved to: {qual_path}")

## 14. Best and Worst Cases

In [None]:
# Find best and worst based on SSIM
ssim_values = np.array(results['ssim'])
psnr_values = np.array(results['psnr'])

ssim_ranking = np.argsort(ssim_values)       # ascending SSIM
best_indices = ssim_ranking[-4:][::-1]       # four highest, best first
worst_indices = ssim_ranking[:4]             # four lowest, worst first


def _plot_ranked_cases(sample_indices, title, box_color, save_path, n_rows=4):
    """Draw an (input IHC | generated HE | ground-truth HE) grid for the given samples.

    Each row is annotated with the sample's SSIM and PSNR in a box coloured
    ``box_color``; the figure is saved to ``save_path`` and shown. Rows left
    empty when fewer than ``n_rows`` samples exist are blanked out.
    """
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 16))
    fig.suptitle(title, fontsize=16, fontweight='bold')

    columns = (
        (input_ihc_images, 'Input IHC'),
        (generated_images, 'Generated HE'),
        (real_he_images, 'Ground Truth HE'),
    )

    for row, sample_idx in enumerate(sample_indices):
        for col, (images, col_title) in enumerate(columns):
            axes[row, col].imshow(images[sample_idx].permute(1, 2, 0).numpy())
            axes[row, col].set_title(col_title, fontsize=11)
            axes[row, col].axis('off')

        metrics_text = f'SSIM: {ssim_values[sample_idx]:.4f}\nPSNR: {psnr_values[sample_idx]:.2f} dB'
        axes[row, 0].text(-0.15, 0.5, metrics_text, transform=axes[row, 0].transAxes,
                          fontsize=10, fontweight='bold', verticalalignment='center',
                          bbox=dict(boxstyle='round', facecolor=box_color, alpha=0.8))

    # Hide the frames of rows that received no sample
    for row in range(len(sample_indices), n_rows):
        for ax in axes[row]:
            ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


# Best cases
best_path = f'{result_dir}/visualizations/best_cases.png'
_plot_ranked_cases(best_indices, 'Best Translation Results (Highest SSIM)', 'lightgreen', best_path)

# Worst cases
worst_path = f'{result_dir}/visualizations/worst_cases.png'
_plot_ranked_cases(worst_indices, 'Worst Translation Results (Lowest SSIM)', 'lightcoral', worst_path)

print(f"✓ Best cases saved to: {best_path}")
print(f"✓ Worst cases saved to: {worst_path}")

## 15. Generate Paper-Ready Summary

In [None]:
# Create paper-ready table: one "mean ± std" cell per metric (FID is a single value)
paper_table = pd.DataFrame({
    'Metric': ['PSNR (dB) ↑', 'SSIM ↑', 'LPIPS ↓', 'MAE ↓', 'Hist Corr ↑', 'FID ↓'],
    'Mean ± Std': [
        f"{statistics['psnr']['mean']:.2f} ± {statistics['psnr']['std']:.2f}",
        f"{statistics['ssim']['mean']:.4f} ± {statistics['ssim']['std']:.4f}",
        f"{statistics['lpips']['mean']:.4f} ± {statistics['lpips']['std']:.4f}",
        f"{statistics['mae']['mean']:.4f} ± {statistics['mae']['std']:.4f}",
        f"{statistics['hist_sim']['mean']:.4f} ± {statistics['hist_sim']['std']:.4f}",
        f"{fid_score:.2f}"
    ]
})

print("\n" + "="*70)
print(" "*20 + "PAPER-READY RESULTS")
print("="*70)
print(paper_table.to_string(index=False))
print("="*70)
print("Note: ↑ = higher is better, ↓ = lower is better")

# Save LaTeX table
latex_table = paper_table.to_latex(index=False, escape=False)
latex_path = f'{result_dir}/paper_table.tex'
# Explicit UTF-8: the table contains non-ASCII symbols (±, ↑, ↓) that the platform's
# default encoding may be unable to write.
with open(latex_path, 'w', encoding='utf-8') as f:
    f.write(latex_table)

print(f"\n✓ LaTeX table saved to: {latex_path}")

## 16. Final Summary Report

In [None]:
# Generate comprehensive report
# Plain-text summary assembled from values computed in earlier cells: the test
# configuration, mean ± std and [min, max] of every per-image metric, the FID score,
# a coarse quality label per headline metric, and the list of output files.
# The thresholds in the INTERPRETATION section are this notebook's own rule of thumb.
report = f"""
{'='*80}
                    CYCLEGAN VIRTUAL STAINING TEST REPORT
                        IHC to HE Translation Evaluation
{'='*80}

TEST CONFIGURATION
{'-'*80}
Model Path:           {model_path}
Test Samples:         {len(test_dataset)}
Image Size:           {params['input_size']}x{params['input_size']}
Device:               {device}
Model Epoch:          {params['model_epoch']}

QUANTITATIVE RESULTS (Standard Virtual Staining Metrics)
{'-'*80}
Metric                Mean ± Std              Range
{'-'*80}
PSNR (dB)            {statistics['psnr']['mean']:.2f} ± {statistics['psnr']['std']:.2f}           [{statistics['psnr']['min']:.2f}, {statistics['psnr']['max']:.2f}]
SSIM                 {statistics['ssim']['mean']:.4f} ± {statistics['ssim']['std']:.4f}         [{statistics['ssim']['min']:.4f}, {statistics['ssim']['max']:.4f}]
LPIPS                {statistics['lpips']['mean']:.4f} ± {statistics['lpips']['std']:.4f}         [{statistics['lpips']['min']:.4f}, {statistics['lpips']['max']:.4f}]
MAE                  {statistics['mae']['mean']:.4f} ± {statistics['mae']['std']:.4f}         [{statistics['mae']['min']:.4f}, {statistics['mae']['max']:.4f}]
Histogram Corr       {statistics['hist_sim']['mean']:.4f} ± {statistics['hist_sim']['std']:.4f}         [{statistics['hist_sim']['min']:.4f}, {statistics['hist_sim']['max']:.4f}]
FID Score            {fid_score:.2f}
{'-'*80}

INTERPRETATION
{'-'*80}
PSNR: {'Excellent' if statistics['psnr']['mean'] > 25 else 'Good' if statistics['psnr']['mean'] > 20 else 'Poor'} (target: >25 dB)
SSIM: {'Excellent' if statistics['ssim']['mean'] > 0.8 else 'Good' if statistics['ssim']['mean'] > 0.7 else 'Poor'} (target: >0.8)
FID:  {'Excellent' if fid_score < 50 else 'Good' if fid_score < 100 else 'Moderate' if fid_score < 200 else 'Poor'} (target: <50)
Color Match: {'Excellent' if statistics['hist_sim']['mean'] > 0.9 else 'Good' if statistics['hist_sim']['mean'] > 0.8 else 'Moderate'} (target: >0.9)

OUTPUT FILES
{'-'*80}
✓ Quantitative results:     {result_dir}/quantitative_results.csv
✓ LaTeX table:             {result_dir}/paper_table.tex
✓ Metric distributions:    {result_dir}/visualizations/metric_distributions.png
✓ Color histograms:        {result_dir}/visualizations/color_histograms.png
✓ Qualitative results:     {result_dir}/visualizations/qualitative_results.png
✓ Best cases:              {result_dir}/visualizations/best_cases.png
✓ Worst cases:             {result_dir}/visualizations/worst_cases.png

{'='*80}
                    EVALUATION COMPLETED SUCCESSFULLY
{'='*80}
"""

print(report)

# Save report
report_path = f'{result_dir}/evaluation_report.txt'
# Explicit UTF-8: the report contains non-ASCII symbols (✓, ±) that the platform's
# default encoding may be unable to write.
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(report)

print(f"✓ Full report saved to: {report_path}")