# CycleGAN Test: IHC to HE Virtual Staining Evaluation

This notebook provides a comprehensive evaluation of IHC → HE translation using a trained CycleGAN model (generator F: IHC → HE).

## Evaluation Metrics (Appropriate for Unpaired Virtual Staining)

**Note**: IHC and HE are from consecutive tissue sections, NOT pixel-aligned. 
Therefore, pixel-wise metrics (PSNR, SSIM, MAE) are inappropriate.

### Distribution-based Metrics (Used):
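- **FID**: Fréchet Inception Distance - Primary metric for GAN evaluation (biased upward at small sample sizes, so only compare FIDs computed on equal numbers of images)
- **Inception Score (IS)**: Measures quality and diversity of generated images (based on ImageNet classes, so absolute values are not comparable with natural-image benchmarks)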
- **Color Histogram Similarity**: Critical for staining quality assessment
- **Texture Metrics**: GLCM-based Haralick features for tissue structure analysis

### NOT Used (Require pixel alignment):
- ~~PSNR, SSIM, MAE~~ - Inappropriate for unpaired images
- ~~LPIPS~~ - Assumes spatial correspondence

## 1. Import Libraries and Setup

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import inception_v3
from PIL import Image
import torch.utils.data as data
from glob import glob
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
from scipy.stats import gaussian_kde, entropy
import pandas as pd
import seaborn as sns
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2gray
import warnings
import os
# Silence all Python warnings to keep the notebook output compact.
warnings.filterwarnings('ignore')

# Device configuration: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def create_dir(path):
    """Create directory ``path`` (including missing parents) if it does not exist yet.

    ``exist_ok=True`` makes the call idempotent without a separate existence
    check, avoiding the check-then-create race of ``os.path.exists`` + ``makedirs``.
    Raises ``FileExistsError`` if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

## 2. Test Configuration

In [None]:
# Test parameters
params = {
    'batch_size': 1,
    'input_size': 512,
    'test_sample_count': 100,
    'model_epoch': 99,
    'img_form': 'png',
    'seed': 42,  # RNG seed: fixes which test patches are sampled so Restart & Run All is reproducible
}

# Seed every RNG the notebook draws from (Python `random`, NumPy, PyTorch).
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

# Paths
data_dir = '../../data/IHC_HE_Pair_Data_GA_SS/patches'
model_dir = '../../model/HE_IHC_translation/internal_ss/PD-L1'
result_dir = '../../results/HE_IHC_translation/internal_ss/PD-L1/test_results'

# Create directories
create_dir(result_dir)
create_dir(f'{result_dir}/visualizations')
create_dir(f'{result_dir}/generated_images')

print("="*60)
print("TEST CONFIGURATION")
print("="*60)
print(f"Model epoch: {params['model_epoch']}")
print(f"Test samples: {params['test_sample_count']}")
print(f"Image size: {params['input_size']}x{params['input_size']}")
print(f"Random seed: {params['seed']}")
print(f"Results directory: {result_dir}")
print("="*60)

## 3. Define Generator Architecture

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + body(x)``.

    ``body`` is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm -> Dropout(0.5).
    Both convolutions keep the channel count and (with padding=1) the spatial size,
    so the skip connection is a plain element-wise add. The submodule order inside
    ``self.block`` determines the state_dict keys and must stay fixed for saved
    checkpoints to load.
    """

    def __init__(self, features):
        super(ResidualBlock, self).__init__()

        def conv_norm():
            # 3x3 same-size convolution followed by instance normalisation.
            return [
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.InstanceNorm2d(features),
            ]

        self.block = nn.Sequential(
            *conv_norm(),
            nn.ReLU(inplace=True),
            *conv_norm(),
            nn.Dropout(0.5),  # only active in training mode
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style CycleGAN generator.

    Layout of ``self.model`` (every conv except the last is followed by
    InstanceNorm + ReLU):
      * 7x7 stem conv: ``input_channels`` -> 64
      * 4 stride-2 3x3 convs, doubling channels each time (64 -> 1024)
      * ``n_residual_blocks`` ResidualBlocks at 1024 channels
      * 4 stride-2 3x3 transposed convs, halving channels (1024 -> 64)
      * 7x7 output conv: 64 -> ``output_channels``, then Tanh (outputs in [-1, 1])

    Height and width are preserved when they are divisible by 16. The layer order
    fixes the state_dict keys (``model.<index>.*``), so it must not change if saved
    checkpoints are to load.
    """

    def __init__(self, input_channels, output_channels, n_residual_blocks=9):
        super(Generator, self).__init__()

        def conv_block(conv):
            # Every stem/sampling stage is conv -> InstanceNorm -> ReLU.
            return [conv, nn.InstanceNorm2d(conv.out_channels), nn.ReLU(inplace=True)]

        layers = conv_block(nn.Conv2d(input_channels, 64, kernel_size=7, padding=3))

        channels = 64
        for _ in range(4):  # encoder: halve resolution, double channels
            layers.extend(conv_block(
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1)
            ))
            channels *= 2

        layers.extend(ResidualBlock(channels) for _ in range(n_residual_blocks))

        for _ in range(4):  # decoder: double resolution, halve channels
            layers.extend(conv_block(
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1)
            ))
            channels //= 2

        layers.extend([nn.Conv2d(64, output_channels, kernel_size=7, padding=3), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


print("✓ Generator architecture defined")

## 4. Load Test Dataset

In [None]:
class DatasetFromFolder(data.Dataset):
    """Index-aligned view over two preloaded image collections.

    Item ``i`` is the pair ``(HE_image_list[i], IHC_image_list[i])``. The length is
    taken from the HE collection, so both collections are expected to have the same
    length. Despite the name, nothing is read from disk here.
    """

    def __init__(self, HE_image_list, IHC_image_list):
        super(DatasetFromFolder, self).__init__()
        self.HE_image_list, self.IHC_image_list = HE_image_list, IHC_image_list

    def __len__(self):
        return len(self.HE_image_list)

    def __getitem__(self, index):
        he_item = self.HE_image_list[index]
        ihc_item = self.IHC_image_list[index]
        return he_item, ihc_item

# Transform: resize each patch to exactly input_size x input_size (a bare int would only fix
# the shorter side, so non-square patches would not fit the preallocated tensors below),
# then convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize(size=(params['input_size'], params['input_size'])),
    transforms.ToTensor()
])


def _paired_ihc_path(he_path):
    """Map .../he_mpp1/<name> to the consecutive-section IHC patch .../pdl1_mpp1/<name>."""
    he_dir, file_name = os.path.split(he_path)
    return os.path.join(os.path.dirname(he_dir), 'pdl1_mpp1', file_name)


def _load_normalized(path):
    """Load an image as RGB, apply `transform`, and rescale [0, 1] -> [-1, 1] (generator range)."""
    with Image.open(path) as image:
        return transform(image.convert('RGB')) * 2. - 1


# Load test data
print("Loading test data...")
# glob order depends on the filesystem; sorting plus a seeded RNG makes the sampled
# test subset reproducible across runs and machines.
test_data_HE = sorted(glob(f'{data_dir}/he_mpp1/*.{params["img_form"]}'))
if not test_data_HE:
    raise FileNotFoundError(f"No HE test images found in {data_dir}/he_mpp1/")
test_max_count = min(params['test_sample_count'], len(test_data_HE))
test_data_HE = random.Random(params.get('seed', 42)).sample(test_data_HE, test_max_count)
# Build IHC paths from path components rather than string replacement, which also works
# when glob returns OS-specific separators (e.g. '\\' on Windows).
test_data_IHC = [_paired_ihc_path(f) for f in test_data_HE]
missing_IHC = [p for p in test_data_IHC if not os.path.exists(p)]
if missing_IHC:
    raise FileNotFoundError(f"{len(missing_IHC)} paired IHC images are missing, e.g. {missing_IHC[0]}")

print(f"Found {len(test_data_HE)} test image pairs")

# Preload images (values in [-1, 1])
test_image_HE = torch.zeros((len(test_data_HE), 3, params['input_size'], params['input_size']))
test_image_IHC = torch.zeros((len(test_data_IHC), 3, params['input_size'], params['input_size']))

for i in tqdm(range(len(test_data_HE)), desc="Loading test images"):
    test_image_HE[i] = _load_normalized(test_data_HE[i])
    test_image_IHC[i] = _load_normalized(test_data_IHC[i])

test_dataset = DatasetFromFolder(test_image_HE, test_image_IHC)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=params['batch_size'],
    shuffle=False
)

print(f"✓ Test dataset loaded: {len(test_dataset)} image pairs")

## 5. Load Trained Model (Generator F: IHC → HE)

In [None]:
# Generator F translates IHC -> HE; restore the weights saved at the configured epoch.
F = Generator(3, 3).to(device)
model_path = f'{model_dir}/F_{params["model_epoch"]}.pth'
checkpoint_state = torch.load(model_path, map_location=device)
F.load_state_dict(checkpoint_state)
F.eval()  # inference mode: turns off the Dropout inside the residual blocks

print(f"✓ Loaded model: {model_path}")
print(f"✓ Generator F (IHC → HE) ready for testing")

## 6. Define Evaluation Metrics

In [None]:
def denormalize(img):
    """Map values from the generator's [-1, 1] range back to [0, 1].

    Works on anything supporting scalar arithmetic (Python floats, NumPy arrays,
    torch tensors) and returns the same kind of object.
    """
    return 0.5 * (img + 1.0)

def calculate_histogram_similarity(img1, img2, bins=256):
    """Mean Pearson correlation of per-channel intensity histograms.

    Args:
        img1, img2: channel-first images (array-likes indexable as ``img[c]``) with
            at least 3 channels; values are binned over [0, 1].
        bins: number of histogram bins per channel.

    Returns:
        Average over the first 3 channels of the correlation between the two
        normalized histograms (1.0 = identical color distributions).
    """
    def _normalized_hist(channel_values):
        counts, _ = np.histogram(channel_values.flatten(), bins=bins, range=(0, 1))
        return counts / counts.sum()

    per_channel = [
        np.corrcoef(_normalized_hist(img1[c]), _normalized_hist(img2[c]))[0, 1]
        for c in range(3)
    ]
    return np.mean(per_channel)

def calculate_haralick_features(img):
    """Mean GLCM (Haralick-style) texture descriptors of an image.

    Args:
        img: float image, expected in [0, 1]. A leading axis of length 3 is treated
            as RGB in CHW layout and converted to grayscale; anything else is used
            as an already-grayscale 2-D array.

    Returns:
        dict with keys 'contrast', 'dissimilarity', 'homogeneity', 'energy',
        'correlation' and 'ASM', each averaged over pixel distance 1 and the
        0, 45, 90 and 135 degree directions.
    """
    gray = rgb2gray(img.transpose(1, 2, 0)) if img.shape[0] == 3 else img
    # The GLCM needs integer grey levels: scale [0, 1] floats to 0..255 (truncating).
    gray_u8 = (gray * 255).astype(np.uint8)

    glcm = graycomatrix(
        gray_u8,
        distances=[1],
        angles=[0, np.pi/4, np.pi/2, 3*np.pi/4],
        levels=256,
        symmetric=True,
        normed=True,
    )

    property_names = ('contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM')
    return {name: graycoprops(glcm, name).mean() for name in property_names}

def calculate_texture_similarity(img1, img2):
    """Similarity of two images' Haralick texture feature vectors, in [0, 1].

    Computes ``1 - sqrt(sum (f1 - f2)^2 / sum (|f1| + |f2|)^2)`` over the features
    returned by ``calculate_haralick_features``. Because ``|a - b| <= |a| + |b|``
    holds for every feature — including GLCM correlation, which can be negative —
    the score always lies in [0, 1], where 1 means identical features.

    Note: features are not rescaled, so large-magnitude features (e.g. 'contrast')
    dominate the score.

    Returns:
        float similarity (higher is better).
    """
    features1 = calculate_haralick_features(img1)
    features2 = calculate_haralick_features(img2)

    diff_sum = 0.0
    norm_sum = 0.0
    for key, value1 in features1.items():
        value2 = features2[key]
        diff_sum += (value1 - value2) ** 2
        # Absolute values keep the denominator >= the numerator even for negative features;
        # the epsilon avoids 0/0 when both features are zero.
        norm_sum += (abs(value1) + abs(value2)) ** 2 + 1e-10

    return 1 - np.sqrt(diff_sum / norm_sum)

print("\n".join([
    "✓ Evaluation metrics initialized",
    "  - Color Histogram Similarity (distribution-based)",
    "  - Texture Similarity (GLCM Haralick features)",
    "  - FID (will be calculated separately)",
    "  - Inception Score (will be calculated separately)",
    "\nNote: Pixel-wise metrics (PSNR, SSIM, MAE, LPIPS) are NOT used",
    "      because IHC and HE images are from consecutive sections",
    "      and are NOT pixel-aligned.",
]))

## 7. Run Evaluation on Test Set

In [None]:
# Per-image metric values, one entry per test sample.
results = {
    'hist_sim': [],
    'texture_sim': []
}

# Storage for ALL images (CPU tensors in [0, 1]); needed for FID, IS and the figures.
generated_images = []
real_he_images = []
input_ihc_images = []

print("\n" + "="*60)
print("RUNNING EVALUATION ON TEST SET")
print("="*60)

with torch.no_grad():
    for real_he, real_ihc in tqdm(test_loader, desc="Evaluating"):
        # Only the IHC input needs to go through the generator on `device`;
        # the output is brought back to the CPU once per batch.
        fake_he = F(real_ihc.to(device)).cpu()

        # Process every sample in the batch (not just index 0) so batch_size > 1 is safe.
        for real_he_img, real_ihc_img, fake_he_img in zip(real_he, real_ihc, fake_he):
            real_he_denorm = denormalize(real_he_img)
            fake_he_denorm = denormalize(fake_he_img)

            # Color histogram similarity (distribution-based)
            results['hist_sim'].append(calculate_histogram_similarity(fake_he_denorm, real_he_denorm))

            # Texture similarity (GLCM Haralick features)
            results['texture_sim'].append(calculate_texture_similarity(
                fake_he_denorm.numpy(),
                real_he_denorm.numpy()
            ))

            generated_images.append(fake_he_denorm)
            real_he_images.append(real_he_denorm)
            input_ihc_images.append(denormalize(real_ihc_img))

print("\n" + "="*60)
print("✓ Evaluation completed!")
print("="*60)

## 8. Calculate FID Score

In [None]:
from scipy.linalg import sqrtm

# Inception V3 used as a feature extractor for FID: swapping the final fully connected
# layer for Identity makes the forward pass return the penultimate-layer features that
# would normally feed the classifier.
print("Loading Inception V3 model for FID calculation...")
inception_model = inception_v3(pretrained=True, transform_input=False)
inception_model.fc = nn.Identity()
inception_model = inception_model.to(device)
inception_model.eval()

def get_inception_features(images, batch_size=32):
    """Extract Inception-V3 features (the model's fc is replaced by Identity) for FID.

    Args:
        images: sequence of CHW float tensors with values in [0, 1]; they are
            stacked per batch, so all must share the same shape.
        batch_size: images per forward pass. Only affects speed: the model is in
            eval mode and resizing is per-image, so results match one-by-one processing.

    Returns:
        np.ndarray of shape (len(images), feature_dim).

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("get_inception_features needs at least one image")

    features = []
    with torch.no_grad():
        for start in tqdm(range(0, len(images), batch_size), desc="Extracting features"):
            batch = torch.stack(list(images[start:start + batch_size]))
            # Resize to 299x299 for Inception
            batch = torch.nn.functional.interpolate(
                batch, size=(299, 299), mode='bilinear', align_corners=False
            )
            # [0, 1] -> [-1, 1]; the model is built with transform_input=False, so it
            # receives this scaling as-is.
            feats = inception_model((batch * 2 - 1).to(device))
            features.append(feats.cpu().numpy())
    return np.concatenate(features, axis=0)

def calculate_fid(features1, features2, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))

    Args:
        features1, features2: arrays of shape (n_samples, n_features), each with at
            least 2 rows. A single feature column is supported.
        eps: diagonal offset added to both covariances if the matrix square root is
            not finite. This can happen when n_samples <= n_features, which makes the
            sample covariances singular.

    Returns:
        float FID (lower means the two distributions are closer).
    """
    # atleast_2d: np.cov returns a 0-d array for a single feature column.
    mu1, sigma1 = features1.mean(axis=0), np.atleast_2d(np.cov(features1, rowvar=False))
    mu2, sigma2 = features2.mean(axis=0), np.atleast_2d(np.cov(features2, rowvar=False))

    ssdiff = np.sum((mu1 - mu2)**2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a tiny imaginary component; only the real part is meaningful.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))

# Extract features
print("Extracting features from real HE images...")
real_features = get_inception_features(real_he_images)

print("Extracting features from generated HE images...")
fake_features = get_inception_features(generated_images)

n_samples, feature_dim = real_features.shape
# Each half of the split below needs >= 2 samples to estimate a covariance.
assert n_samples >= 4, f"FID needs at least 4 test samples, got {n_samples}"

# Calculate FID scores
print("\n" + "="*60)
print("FID Score Calculation")
print("="*60)
if n_samples <= feature_dim:
    print(f"Warning: {n_samples} samples <= feature dimension {feature_dim}; covariance")
    print("  estimates are singular and FID is biased upward. Only compare FIDs computed")
    print("  on the same number of images.")

# FID depends only on feature means/covariances, so sample order is irrelevant and no
# shuffling is needed. The seeded permutation only draws a random split of the test
# set into two disjoint halves A and B.
split = np.random.RandomState(42).permutation(n_samples)
mid_point = n_samples // 2
half_a, half_b = split[:mid_point], split[mid_point:]

# 1. Real vs Real baseline: real HE half A vs real HE half B.
fid_real_vs_real = calculate_fid(real_features[half_a], real_features[half_b])
print(f"FID (Real vs Real - Baseline): {fid_real_vs_real:.2f}")
print(f"  (random split: {len(half_a)} vs {len(half_b)} real images)")

# 2. Fake vs Real on the full test set (headline number).
fid_fake_vs_real = calculate_fid(fake_features, real_features)
print(f"\nFID (Fake vs Real): {fid_fake_vs_real:.2f}")
print(f"  ({len(fake_features)} generated vs {n_samples} real images)")

# 3. Matched comparison: the generated images of half A vs the same real half B. Both this
#    FID and the baseline use identical sample sizes (and therefore the same finite-sample
#    bias), so their difference isolates the effect of replacing real HE by generated HE.
fid_fake_vs_real_matched = calculate_fid(fake_features[half_a], real_features[half_b])
fid_difference = fid_fake_vs_real_matched - fid_real_vs_real
print(f"\nFID (Fake half A vs Real half B): {fid_fake_vs_real_matched:.2f}")
print(f"FID Difference (Fake-Real, matched halves): {fid_difference:.2f}")

print("\n" + "="*60)
print("Interpretation:")
print(f"  - Baseline FID (Real vs Real): {fid_real_vs_real:.2f}")
print("    Noise floor: non-zero because of finite-sample bias (large for small test sets)")
print(f"  - Fake vs Real FID: {fid_fake_vs_real:.2f}")
if fid_fake_vs_real < 50:
    print("    → Excellent quality")
elif fid_fake_vs_real < 100:
    print("    → Good quality")
elif fid_fake_vs_real < 200:
    print("    → Moderate quality")
else:
    print("    → Poor quality")
print(f"  - Relative FID increase (matched halves): {fid_difference:.2f}")
print("    Lower is better (closer to baseline)")
print("="*60)

## 9. Calculate Inception Score

In [None]:
banner = "="*60
print("\n" + banner)
print("Inception Score Calculation")
print(banner)

# A second Inception V3 that keeps its classification head (fc is NOT replaced),
# because the Inception Score needs class probabilities p(y|x), not features.
inception_classifier = inception_v3(pretrained=True, transform_input=False)
inception_classifier = inception_classifier.to(device)
inception_classifier.eval()

def calculate_inception_score(images, batch_size=32, splits=10):
    """
    Calculate Inception Score for generated images
    IS = exp(E_x[KL(p(y|x) || p(y))])
    Higher is better (measures quality and diversity)

    Args:
        images: sequence of CHW float tensors with values in [0, 1]; stacked per
            batch, so all must share the same shape.
        batch_size: images per Inception forward pass.
        splits: number of disjoint groups. The IS is computed per group and the
            mean/std over groups is returned. It is capped at len(images) so no group
            is empty, and images that do not divide evenly are spread across groups
            instead of being dropped.

    Returns:
        (mean, std) of the per-split Inception Scores.

    Raises:
        ValueError: if ``images`` is empty.
    """
    n_images = len(images)
    if n_images == 0:
        raise ValueError("calculate_inception_score needs at least one image")

    # Get class probabilities p(y|x) for every image
    preds = []
    with torch.no_grad():
        for i in tqdm(range(0, n_images, batch_size), desc="Computing IS"):
            batch_tensor = torch.stack(list(images[i:i + batch_size]))

            # Resize to 299x299 for Inception
            batch_resized = torch.nn.functional.interpolate(
                batch_tensor, size=(299, 299), mode='bilinear', align_corners=False
            )
            # Normalize to [-1, 1]
            batch_normalized = batch_resized * 2 - 1

            logits = inception_classifier(batch_normalized.to(device))
            preds.append(torch.nn.functional.softmax(logits, dim=1).cpu().numpy())

    preds = np.concatenate(preds, axis=0)

    # Per-split IS: exp of the mean KL divergence between p(y|x) and the split's marginal p(y)
    n_splits = max(1, min(splits, n_images))
    split_scores = []
    for part in np.array_split(preds, n_splits):
        py = np.mean(part, axis=0)
        kl = [entropy(pyx, py) for pyx in part]
        split_scores.append(np.exp(np.mean(kl)))

    return np.mean(split_scores), np.std(split_scores)

# Calculate IS for generated images
is_mean, is_std = calculate_inception_score(generated_images)
# Reference IS of the real HE patches. The Inception classes are ImageNet categories, so for
# histology the meaningful yardstick is the score of real images from the same domain.
is_real_mean, is_real_std = calculate_inception_score(real_he_images)

print(f"\nInception Score: {is_mean:.2f} ± {is_std:.2f}")
print(f"Inception Score (real HE reference): {is_real_mean:.2f} ± {is_real_std:.2f}")
print("\nInterpretation:")
if is_mean > 8:
    print("  → Excellent quality and diversity")
elif is_mean > 5:
    print("  → Good quality and diversity")
elif is_mean > 3:
    print("  → Moderate quality")
else:
    print("  → Poor quality or low diversity")
print("\nNote: Higher IS indicates better image quality and diversity")
print("      IS uses ImageNet classes; compare it with the real HE reference above")
print("      rather than with natural-image benchmarks (e.g. IS > 10).")
print("="*60)

## 10. Calculate Statistics

In [None]:
def _constant_stats(value, std=0):
    """Stats row for a metric known only as a single value (FID) or mean/std (IS)."""
    return {'mean': value, 'std': std, 'min': value, 'max': value, 'median': value}


# Per-image metrics: summarize their full distributions.
statistics = {
    name: {
        'mean': np.mean(vals),
        'std': np.std(vals),
        'min': np.min(vals),
        'max': np.max(vals),
        'median': np.median(vals),
    }
    for name, vals in results.items()
}

# Scalar metrics (FID scores, then Inception Score)
statistics['fid_baseline'] = _constant_stats(fid_real_vs_real)
statistics['fid_fake_real'] = _constant_stats(fid_fake_vs_real)
statistics['fid_difference'] = _constant_stats(fid_difference)
statistics['inception_score'] = _constant_stats(is_mean, std=is_std)

# Display results
rule = "="*80
print("\n" + rule)
print(" "*25 + "QUANTITATIVE RESULTS")
print(rule)
print(f"{'Metric':<20} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
print("-"*80)

metric_display = {
    'hist_sim': 'Hist Corr',
    'texture_sim': 'Texture Sim',
    'inception_score': 'Inception Score',
    'fid_baseline': 'FID (Real-Real)',
    'fid_fake_real': 'FID (Fake-Real)',
    'fid_difference': 'FID Difference'
}

for key, label in metric_display.items():
    row = statistics[key]
    cells = [f"{row[field]:<12.4f}" for field in ('mean', 'std', 'min', 'max')]
    print(f"{label:<20} " + " ".join(cells))

print(rule)
print("\nNote: Pixel-wise metrics (PSNR, SSIM, MAE, LPIPS) are NOT reported")
print("      because IHC and HE are from consecutive sections (unpaired).")

# Save to CSV (one row per metric)
df = pd.DataFrame(statistics).T
df.columns = ['Mean', 'Std', 'Min', 'Max', 'Median']
csv_path = f'{result_dir}/quantitative_results.csv'
df.to_csv(csv_path)
print(f"\n✓ Results saved to: {csv_path}")

## 11. Visualize Metric Distributions

In [None]:
# Set style
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# Create subplots for distribution metrics
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle('Distribution-Based Metrics (IHC → HE Translation)', 
             fontsize=16, fontweight='bold', y=0.98)

# Histogram Correlation distribution
ax = axes[0]
values = results['hist_sim']
ax.hist(values, bins=30, alpha=0.6, color='skyblue', edgecolor='black', density=True)
kde = gaussian_kde(values)
x_range = np.linspace(min(values), max(values), 100)
ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')
mean_val = statistics['hist_sim']['mean']
median_val = statistics['hist_sim']['median']
ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
ax.axvline(median_val, color='orange', linestyle='--', linewidth=2, label=f'Median: {median_val:.4f}')
ax.set_xlabel('Histogram Correlation', fontsize=12, fontweight='bold')
ax.set_ylabel('Density', fontsize=12, fontweight='bold')
ax.set_title('Color Histogram Correlation\n(Higher is better)', fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Texture Similarity distribution
ax = axes[1]
values = results['texture_sim']
ax.hist(values, bins=30, alpha=0.6, color='lightcoral', edgecolor='black', density=True)
kde = gaussian_kde(values)
x_range = np.linspace(min(values), max(values), 100)
ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')
mean_val = statistics['texture_sim']['mean']
median_val = statistics['texture_sim']['median']
ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')
ax.axvline(median_val, color='orange', linestyle='--', linewidth=2, label=f'Median: {median_val:.4f}')
ax.set_xlabel('Texture Similarity (GLCM)', fontsize=12, fontweight='bold')
ax.set_ylabel('Density', fontsize=12, fontweight='bold')
ax.set_title('Texture Similarity (Haralick Features)\n(Higher is better)', fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
fig_path = f'{result_dir}/visualizations/metric_distributions.png'
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()

# Create separate figure for FID and IS summary
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.axis('off')

summary_text = 'GAN Evaluation Metrics Summary\n\n'
summary_text += f'FID Scores:\n'
summary_text += f'  Real vs Real (Baseline): {fid_real_vs_real:.2f}\n'
summary_text += f'  Fake vs Real: {fid_fake_vs_real:.2f}\n'
summary_text += f'  Difference: {fid_difference:.2f}\n\n'
if fid_fake_vs_real < 50:
    summary_text += '  FID Quality: Excellent\n\n'
elif fid_fake_vs_real < 100:
    summary_text += '  FID Quality: Good\n\n'
elif fid_fake_vs_real < 200:
    summary_text += '  FID Quality: Moderate\n\n'
else:
    summary_text += '  FID Quality: Poor\n\n'

summary_text += f'Inception Score:\n'
summary_text += f'  IS: {is_mean:.2f} ± {is_std:.2f}\n'
if is_mean > 8:
    summary_text += '  IS Quality: Excellent\n'
elif is_mean > 5:
    summary_text += '  IS Quality: Good\n'
elif is_mean > 3:
    summary_text += '  IS Quality: Moderate\n'
else:
    summary_text += '  IS Quality: Poor\n'

ax.text(0.5, 0.5, summary_text, transform=ax.transAxes,
        ha='center', va='center', fontsize=14, fontweight='bold',
        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8),
        family='monospace')

summary_path = f'{result_dir}/visualizations/gan_metrics_summary.png'
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Metric distributions saved to: {fig_path}")
print(f"✓ GAN metrics summary saved to: {summary_path}")

## 12. Color Histogram Analysis

In [None]:
# Visualize color histograms
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
fig.suptitle('Color Histogram Comparison: Generated vs Real HE', fontsize=14, fontweight='bold')

colors = ['red', 'green', 'blue']
channel_names = ['Red', 'Green', 'Blue']

for ch_idx in range(3):
    ax = axes[ch_idx]
    
    # Calculate average histogram
    real_hist = np.zeros(256)
    gen_hist = np.zeros(256)
    
    for real_img, gen_img in zip(real_he_images, generated_images):
        h1, _ = np.histogram(real_img[ch_idx].flatten(), bins=256, range=(0, 1))
        h2, _ = np.histogram(gen_img[ch_idx].flatten(), bins=256, range=(0, 1))
        real_hist += h1
        gen_hist += h2
    
    real_hist = real_hist / real_hist.sum()
    gen_hist = gen_hist / gen_hist.sum()
    
    x = np.linspace(0, 1, 256)
    ax.plot(x, real_hist, color=colors[ch_idx], linewidth=2, label='Real HE', alpha=0.7)
    ax.plot(x, gen_hist, color=colors[ch_idx], linewidth=2, label='Generated HE', 
            linestyle='--', alpha=0.7)
    ax.set_title(f'{channel_names[ch_idx]} Channel', fontsize=12, fontweight='bold')
    ax.set_xlabel('Intensity', fontsize=10)
    ax.set_ylabel('Frequency', fontsize=10)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
hist_path = f'{result_dir}/visualizations/color_histograms.png'
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Color histogram saved to: {hist_path}")
print(f"\nMean Histogram Correlation: {statistics['hist_sim']['mean']:.4f}")

## 13. Qualitative Results

In [None]:
# Show up to 8 samples. The test set was randomly sampled at load time, so the first
# N entries are already a random subset.
num_samples = min(8, len(generated_images))
# squeeze=False keeps `axes` 2-D, so axes[i, j] indexing also works when num_samples == 1.
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)
fig.suptitle('Qualitative Results: IHC → HE Translation', fontsize=18, fontweight='bold')

# (image list, column title) per column: input, translation, reference section.
panels = [
    (input_ihc_images, 'Input (IHC)'),
    (generated_images, 'Generated (HE)'),
    (real_he_images, 'Ground Truth (HE)'),
]
for i in range(num_samples):
    for col, (images, title) in enumerate(panels):
        axes[i, col].imshow(images[i].permute(1, 2, 0).numpy())
        axes[i, col].set_title(title, fontsize=12, fontweight='bold')
        axes[i, col].axis('off')

plt.tight_layout()
qual_path = f'{result_dir}/visualizations/qualitative_results.png'
plt.savefig(qual_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Qualitative results saved to: {qual_path}")

## 14. Best and Worst Cases

In [None]:
# Rank samples by histogram correlation; show up to 4 best and 4 worst.
hist_values = np.array(results['hist_sim'])
n_show = min(4, len(hist_values))
order = np.argsort(hist_values)
best_indices = order[::-1][:n_show]
worst_indices = order[:n_show]


def plot_cases(sample_indices, title, box_color, out_path):
    """One row per sample (input IHC | generated HE | real HE), annotated with its hist corr.

    Saves the figure to ``out_path`` and shows it.
    """
    n_rows = len(sample_indices)
    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=16, fontweight='bold')

    panels = [
        (input_ihc_images, 'Input IHC'),
        (generated_images, 'Generated HE'),
        (real_he_images, 'Ground Truth HE'),
    ]
    for row, sample_idx in enumerate(sample_indices):
        for col, (images, panel_title) in enumerate(panels):
            axes[row, col].imshow(images[sample_idx].permute(1, 2, 0).numpy())
            axes[row, col].set_title(panel_title, fontsize=11)
            axes[row, col].axis('off')

        metrics_text = f'Hist Corr: {hist_values[sample_idx]:.4f}'
        axes[row, 0].text(-0.15, 0.5, metrics_text, transform=axes[row, 0].transAxes,
                          fontsize=10, fontweight='bold', verticalalignment='center',
                          bbox=dict(boxstyle='round', facecolor=box_color, alpha=0.8))

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


best_path = f'{result_dir}/visualizations/best_cases.png'
plot_cases(best_indices, 'Best Translation Results (Highest Histogram Correlation)',
           'lightgreen', best_path)

worst_path = f'{result_dir}/visualizations/worst_cases.png'
plot_cases(worst_indices, 'Worst Translation Results (Lowest Histogram Correlation)',
           'lightcoral', worst_path)

print(f"✓ Best cases saved to: {best_path}")
print(f"✓ Worst cases saved to: {worst_path}")

## 15. Generate Paper-Ready Summary

In [None]:
# Create paper-ready table
paper_table = pd.DataFrame({
    'Metric': [
        'Hist Corr ↑', 
        'Texture Sim ↑', 
        'Inception Score ↑',
        'FID (Fake-Real) ↓', 
        'FID (Baseline) ↓', 
        'FID Difference ↓'
    ],
    'Value': [
        f"{statistics['hist_sim']['mean']:.4f} ± {statistics['hist_sim']['std']:.4f}",
        f"{statistics['texture_sim']['mean']:.4f} ± {statistics['texture_sim']['std']:.4f}",
        f"{is_mean:.2f} ± {is_std:.2f}",
        f"{fid_fake_vs_real:.2f}",
        f"{fid_real_vs_real:.2f}",
        f"{fid_difference:.2f}"
    ]
})

print("\n" + "="*70)
print(" "*20 + "PAPER-READY RESULTS")
print("="*70)
print(paper_table.to_string(index=False))
print("="*70)
print("Note: ↑ = higher is better, ↓ = lower is better")
print("\nMetrics are distribution-based (appropriate for unpaired images)")
print("  - Color Histogram Correlation: Stain color distribution match")
print("  - Texture Similarity: GLCM Haralick features for tissue structure")
print("  - Inception Score: GAN generation quality and diversity")
print("  - FID: Fréchet distance in Inception feature space")
print("\nPixel-wise metrics (PSNR, SSIM, MAE, LPIPS) are NOT applicable")
print("because IHC and HE are from consecutive tissue sections.")

# Save LaTeX table.
# escape=False passes cell text through verbatim. Whether raw Unicode symbols typeset depends
# on the LaTeX engine/encoding setup, so map them to portable math-mode macros (file only).
LATEX_SYMBOLS = {'↑': r'$\uparrow$', '↓': r'$\downarrow$', '±': r'$\pm$'}


def to_latex_text(text):
    """Return `text` with ↑, ↓ and ± replaced by their LaTeX math-mode equivalents."""
    for symbol, macro in LATEX_SYMBOLS.items():
        text = text.replace(symbol, macro)
    return text


latex_table = paper_table.apply(lambda column: column.map(to_latex_text)).to_latex(index=False, escape=False)
latex_path = f'{result_dir}/paper_table.tex'
# Explicit UTF-8 so writing never depends on the platform's default encoding.
with open(latex_path, 'w', encoding='utf-8') as f:
    f.write(latex_table)

print(f"\n✓ LaTeX table saved to: {latex_path}")

## 16. Final Summary Report

In [None]:
# Generate comprehensive report.
# Plain-text summary built from values computed in earlier cells. The quality labels use
# fixed heuristic thresholds. The f-string is emitted verbatim, so do not re-indent it.
report = f"""
{'='*80}
                    CYCLEGAN VIRTUAL STAINING TEST REPORT
                        IHC to HE Translation Evaluation
{'='*80}

TEST CONFIGURATION
{'-'*80}
Model Path:           {model_path}
Test Samples:         {len(test_dataset)}
Image Size:           {params['input_size']}x{params['input_size']}
Device:               {device}
Model Epoch:          {params['model_epoch']}

IMPORTANT NOTE
{'-'*80}
IHC and HE images are from CONSECUTIVE TISSUE SECTIONS, not pixel-aligned.
Therefore, only DISTRIBUTION-BASED metrics are reported:
  ✓ FID (Fréchet Inception Distance) - Primary GAN metric
  ✓ Inception Score (IS) - Quality and diversity measure
  ✓ Color Histogram Correlation - Stain color distribution
  ✓ Texture Similarity - GLCM Haralick features for tissue structure

Pixel-wise metrics (PSNR, SSIM, MAE, LPIPS) are NOT applicable because
they require spatial correspondence between images.

QUANTITATIVE RESULTS
{'-'*80}
Metric                Value
{'-'*80}
Histogram Corr       {statistics['hist_sim']['mean']:.4f} ± {statistics['hist_sim']['std']:.4f}
Texture Similarity   {statistics['texture_sim']['mean']:.4f} ± {statistics['texture_sim']['std']:.4f}
Inception Score      {is_mean:.2f} ± {is_std:.2f}
FID (Baseline)       {fid_real_vs_real:.2f}
FID (Fake-Real)      {fid_fake_vs_real:.2f}
FID Difference       {fid_difference:.2f}
{'-'*80}

INTERPRETATION
{'-'*80}
FID:  {'Excellent' if fid_fake_vs_real < 50 else 'Good' if fid_fake_vs_real < 100 else 'Moderate' if fid_fake_vs_real < 200 else 'Poor'} (target: <50)
FID Baseline: {fid_real_vs_real:.2f} (noise floor)
Inception Score: {'Excellent' if is_mean > 8 else 'Good' if is_mean > 5 else 'Moderate' if is_mean > 3 else 'Poor'} (heuristic; IS uses ImageNet classes)
Color Match: {'Excellent' if statistics['hist_sim']['mean'] > 0.9 else 'Good' if statistics['hist_sim']['mean'] > 0.8 else 'Moderate'} (target: >0.9)
Texture Match: {'Excellent' if statistics['texture_sim']['mean'] > 0.9 else 'Good' if statistics['texture_sim']['mean'] > 0.8 else 'Moderate'} (target: >0.8)

OUTPUT FILES
{'-'*80}
✓ Quantitative results:     {result_dir}/quantitative_results.csv
✓ LaTeX table:             {result_dir}/paper_table.tex
✓ Metric distributions:    {result_dir}/visualizations/metric_distributions.png
✓ GAN metrics summary:     {result_dir}/visualizations/gan_metrics_summary.png
✓ Color histograms:        {result_dir}/visualizations/color_histograms.png
✓ Qualitative results:     {result_dir}/visualizations/qualitative_results.png
✓ Best cases:              {result_dir}/visualizations/best_cases.png
✓ Worst cases:             {result_dir}/visualizations/worst_cases.png

{'='*80}
                    EVALUATION COMPLETED SUCCESSFULLY
{'='*80}
"""

print(report)

# Save report
report_path = f'{result_dir}/evaluation_report.txt'
with open(report_path, 'w') as f:
    f.write(report)

print(f"✓ Full report saved to: {report_path}")