# CycleGAN Test: IHC to HE Translation
This notebook evaluates IHC → HE translation on paired test patches using the trained CycleGAN generator F.

## Evaluation Metrics
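- **PSNR (Peak Signal-to-Noise Ratio)**: Pixel-wise fidelity to the ground-truth HE image, in dB (higher is better)
- **SSIM (Structural Similarity Index)**: Similarity of local luminance, contrast, and structure (higher is better)
- **FID (Fréchet Inception Distance)**: Distance between the feature distributions of real and generated images (lower is better); listed for reference, not computed in the cells below
- **LPIPS (Learned Perceptual Image Patch Similarity)**: Perceptual distance in a deep-feature space, AlexNet backbone here (lower is better)
- **MAE / MSE (Mean Absolute / Squared Error)**: Pixel-wise errors on images scaled to [0, 1] (lower is better)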
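FID compares feature statistics of whole image sets rather than individual image pairs, so it does not fit a per-image evaluation loop. A minimal sketch of the distance itself follows, assuming you have already extracted Inception-v3 pool features (2048-d in `pytorch-fid`) for the real and generated HE patches as `(N, D)` NumPy arrays; `frechet_distance` is a hypothetical helper, not part of any library. It implements FID = ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^{1/2}) and, instead of calling `scipy.linalg.sqrtm`, takes the trace of the square root from the eigenvalues of Σ₁Σ₂, which are real and non-negative because Σ₁Σ₂ is similar to the positive semi-definite matrix Σ₁^{1/2}Σ₂Σ₁^{1/2}.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two (N, D) feature sets.

    FID = ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2})
    """
    feats_a = np.asarray(feats_a, dtype=np.float64)
    feats_b = np.asarray(feats_b, dtype=np.float64)
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    # rowvar=False: rows are samples, columns are feature dimensions
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)

    # Eigenvalues of S_a S_b are real and >= 0 up to round-off; the trace of
    # the matrix square root is the sum of their square roots.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0, None)).sum()

    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2 * tr_sqrt)
```

With only 100 test patches, FID estimates are biased upward and vary between subsets, so they are best compared across models evaluated on the same images. The `pytorch-fid` package installed in section 2 also provides a command-line entry point, `python -m pytorch_fid <real_dir> <fake_dir>`, which expects real and generated images in two separate folders.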

## 1. Import Libraries and Setup

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import torch.utils.data as data
from glob import glob
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
from scipy.stats import entropy
import pandas as pd
import seaborn as sns
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

import os

def create_dir(path):
    """Create a directory (including missing parents) if it does not exist."""
    os.makedirs(path, exist_ok=True)

## 2. Install Required Packages

In [None]:
# Install lpips for perceptual similarity metric
!pip install lpips pytorch-fid scikit-image

## 3. Test Parameters and Data Path

In [None]:
# Test parameters
params = {
    'batch_size': 1,  # For detailed evaluation
    'input_size': 512,
    'test_sample_count': 100,  # Number of test samples
    'model_epoch': 99,  # Which epoch model to load
    'img_form': 'png'
}

# Data and model paths
data_dir = '../../data/IHC_HE_Pair_Data_GA_SS/patches'
model_dir = '../../model/HE_IHC_translation/internal_ss/PD-L1'
result_dir = '../../results/HE_IHC_translation/internal_ss/PD-L1/test_results'

create_dir(result_dir)
create_dir(f'{result_dir}/visualizations')
create_dir(f'{result_dir}/generated_images')

print("Test Configuration:")
print(f"  - Model epoch: {params['model_epoch']}")
print(f"  - Test samples: {params['test_sample_count']}")
print(f"  - Image size: {params['input_size']}x{params['input_size']}")
print(f"  - Results directory: {result_dir}")

## 4. Load Test Dataset

In [None]:
class DatasetFromFolder(data.Dataset):
    def __init__(self, HE_image_list, IHC_image_list):
        super(DatasetFromFolder, self).__init__()
        self.HE_image_list = HE_image_list
        self.IHC_image_list = IHC_image_list
    
    def __getitem__(self, index):
        return self.HE_image_list[index], self.IHC_image_list[index]
    
    def __len__(self):
        return len(self.HE_image_list)

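# Transform: resize both sides explicitly so every patch is input_size x input_size
# (an int size would only resize the shorter side and keep the aspect ratio)
transform = transforms.Compose([
    transforms.Resize(size=(params['input_size'], params['input_size'])),
    transforms.ToTensor()
])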

# Load test data
print("Loading test data...")
random.seed(0)  # fixed seed so the same test subset is drawn on every run
test_data_HE = sorted(glob(f'{data_dir}/he_mpp1/*.{params["img_form"]}'))  # glob order is arbitrary
test_max_count = min(params['test_sample_count'], len(test_data_HE))
test_data_HE = random.sample(test_data_HE, test_max_count)
test_data_IHC = [f.replace('/he_mpp1/', '/pdl1_mpp1/') for f in test_data_HE]

print(f"Found {len(test_data_HE)} test image pairs")

# Preload images
test_image_HE = torch.zeros((len(test_data_HE), 3, params['input_size'], params['input_size']))
test_image_IHC = torch.zeros((len(test_data_IHC), 3, params['input_size'], params['input_size']))

for i in tqdm(range(len(test_data_HE)), desc="Loading test images"):
    img = Image.open(test_data_HE[i]).convert('RGB')
    target = Image.open(test_data_IHC[i]).convert('RGB')
    img = transform(img) * 2. - 1
    target = transform(target) * 2. - 1
    test_image_HE[i] = img
    test_image_IHC[i] = target

test_dataset = DatasetFromFolder(test_image_HE, test_image_IHC)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=params['batch_size'],
    shuffle=False
)

print(f"Test dataset loaded: {len(test_dataset)} image pairs")
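Pairing by path substitution assumes every HE patch has an IHC counterpart with the same file name under `pdl1_mpp1/`; a missing file only surfaces as an error inside the loading loop. A hedged sketch of a stricter pairing step (the helper name `sample_paired_paths` is hypothetical) keeps only pairs whose IHC file exists and draws the subset reproducibly by sorting first and sampling with a local, seeded RNG:

```python
import os
import random


def sample_paired_paths(he_paths, n, seed=0,
                        he_dir='/he_mpp1/', ihc_dir='/pdl1_mpp1/'):
    """Return up to n (HE, IHC) path pairs whose IHC counterpart exists.

    Sorting before sampling plus a fixed seed makes the subset reproducible,
    regardless of the order in which the paths were listed.
    """
    pairs = [(p, p.replace(he_dir, ihc_dir)) for p in sorted(he_paths)]
    kept = [(he, ihc) for he, ihc in pairs if os.path.exists(ihc)]
    if len(kept) < len(pairs):
        print(f"Skipping {len(pairs) - len(kept)} HE patches without an IHC counterpart")
    rng = random.Random(seed)  # local RNG: leaves the global random state untouched
    return rng.sample(kept, min(n, len(kept)))
```

It could stand in for the glob/sample lines, for example `pairs = sample_paired_paths(glob(f'{data_dir}/he_mpp1/*.png'), params['test_sample_count'])` followed by `test_data_HE, test_data_IHC = map(list, zip(*pairs))`.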

## 5. Define Generator Architecture

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.InstanceNorm2d(features),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_channels, output_channels, n_residual_blocks=9):
        super(Generator, self).__init__()
        model = [
            nn.Conv2d(input_channels, 64, kernel_size=7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(4):
            model += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(4):
            model += [
                nn.ConvTranspose2d(in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [nn.Conv2d(64, output_channels, kernel_size=7, padding=3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

print("Generator architecture defined")
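With four stride-2 downsampling convolutions, the generator above shrinks a 512 × 512 input to a 32 × 32 bottleneck with 1024 channels, and the four transposed convolutions grow it back. The sketch below applies the standard PyTorch output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1) to trace those shapes without building the model; the helper names are illustrative only.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial size after nn.Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Spatial size after nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_generator_shapes(size, n_down=4, base=64):
    """Return [(stage, channels, spatial_size)] for the Generator defined above."""
    shapes = [("stem 7x7", base, size)]  # kernel 7, padding 3, stride 1 keeps the size
    ch = base
    for i in range(n_down):
        ch, size = ch * 2, conv_out(size)
        shapes.append((f"down {i + 1}", ch, size))
    # Residual blocks (kernel 3, padding 1) keep both channels and size
    for i in range(n_down):
        ch, size = ch // 2, conv_transpose_out(size)
        shapes.append((f"up {i + 1}", ch, size))
    return shapes  # the final 7x7 output conv also keeps the size
```

For example, `trace_generator_shapes(500)` ends at 512 rather than 500, so an input size that is not a multiple of 2⁴ = 16 would make the generated image larger than the ground truth and break the PSNR/SSIM comparison.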

## 6. Load Trained Model (Generator F: IHC → HE)

In [None]:
# Initialize Generator F (IHC -> HE)
F = Generator(3, 3).to(device)

# Load trained weights
model_path = f'{model_dir}/F_{params["model_epoch"]}.pth'
F.load_state_dict(torch.load(model_path, map_location=device))
F.eval()

print(f"✓ Loaded model from: {model_path}")
print(f"✓ Model set to evaluation mode")
print(f"✓ Generator F (IHC → HE) ready for testing")
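The loading cell assumes the checkpoint keys match `Generator` exactly. If `load_state_dict` instead reports unexpected keys that start with `module.`, the checkpoint was probably saved from a model wrapped in `nn.DataParallel`; whether that applies to these checkpoints is an assumption. A minimal helper (the name `strip_prefix` is hypothetical) removes such a prefix before loading:

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Drop a leading prefix (e.g. added by nn.DataParallel) from state-dict keys."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

Usage would be `F.load_state_dict(strip_prefix(torch.load(model_path, map_location=device)))`; keys without the prefix pass through unchanged, so the call is harmless for checkpoints that already match.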

## 7. Define Evaluation Metrics

In [None]:
import lpips

# Initialize LPIPS model
lpips_model = lpips.LPIPS(net='alex').to(device)

def denormalize(img):
    """Denormalize from [-1, 1] to [0, 1]"""
    return img * 0.5 + 0.5

def calculate_psnr(img1, img2):
    """Calculate PSNR between two images"""
    img1_np = img1.cpu().numpy().transpose(1, 2, 0)
    img2_np = img2.cpu().numpy().transpose(1, 2, 0)
    return psnr(img1_np, img2_np, data_range=1.0)

def calculate_ssim(img1, img2):
    """Calculate SSIM between two images"""
    img1_np = img1.cpu().numpy().transpose(1, 2, 0)
    img2_np = img2.cpu().numpy().transpose(1, 2, 0)
    # `multichannel` is deprecated in scikit-image in favor of `channel_axis`; pass only the latter
    return ssim(img1_np, img2_np, data_range=1.0, channel_axis=2)

def calculate_lpips(img1, img2):
    """Calculate LPIPS between two images"""
    # LPIPS expects [-1, 1] range
    with torch.no_grad():
        lpips_value = lpips_model(img1.unsqueeze(0), img2.unsqueeze(0))
    return lpips_value.item()

def calculate_mae(img1, img2):
    """Calculate Mean Absolute Error"""
    return torch.mean(torch.abs(img1 - img2)).item()

def calculate_mse(img1, img2):
    """Calculate Mean Squared Error"""
    return torch.mean((img1 - img2) ** 2).item()

print("✓ Evaluation metrics defined")
print("  - PSNR: Peak Signal-to-Noise Ratio")
print("  - SSIM: Structural Similarity Index")
print("  - LPIPS: Learned Perceptual Image Patch Similarity")
print("  - MAE: Mean Absolute Error")
print("  - MSE: Mean Squared Error")
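For reference, `peak_signal_noise_ratio` computes PSNR = 10·log₁₀(MAX² / MSE), where MAX is `data_range`. Because both images are in [0, 1] here, MAX = 1 and each image's PSNR is simply −10·log₁₀(MSE). A NumPy sketch of that formula (the name `psnr_numpy` is illustrative) makes the relationship explicit:

```python
import numpy as np


def psnr_numpy(pred, target, data_range=1.0):
    """PSNR in dB: 10 * log10(data_range^2 / MSE)."""
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10 * np.log10(data_range ** 2 / mse))
```

One consequence: averaging per-image PSNR, as the statistics cell does, is not the same as converting the average MSE to dB. Since −log is convex, Jensen's inequality makes the mean of per-image PSNR values at least as large as the PSNR of the mean MSE, so a paper should state which convention it uses.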

## 8. Run Evaluation on Test Set

In [None]:
# Storage for metrics
results = {
    'psnr': [],
    'ssim': [],
    'lpips': [],
    'mae': [],
    'mse': []
}

# Storage for generated images
generated_images = []
real_he_images = []
input_ihc_images = []

print("Running evaluation on test set...")
print("=" * 60)

with torch.no_grad():
    for idx, (real_he, real_ihc) in enumerate(tqdm(test_loader, desc="Evaluating")):
        real_he = real_he.to(device)
        real_ihc = real_ihc.to(device)
        
        # Generate fake HE from IHC
        fake_he = F(real_ihc)
        
        # Denormalize for metric calculation
        real_he_denorm = denormalize(real_he[0])
        fake_he_denorm = denormalize(fake_he[0])
        
        # Calculate metrics
        psnr_value = calculate_psnr(fake_he_denorm, real_he_denorm)
        ssim_value = calculate_ssim(fake_he_denorm, real_he_denorm)
        lpips_value = calculate_lpips(fake_he[0], real_he[0])
        mae_value = calculate_mae(fake_he_denorm, real_he_denorm)
        mse_value = calculate_mse(fake_he_denorm, real_he_denorm)
        
        results['psnr'].append(psnr_value)
        results['ssim'].append(ssim_value)
        results['lpips'].append(lpips_value)
        results['mae'].append(mae_value)
        results['mse'].append(mse_value)
        
        # Store ALL images for visualization (needed for best/worst case analysis)
        generated_images.append(fake_he_denorm.cpu())
        real_he_images.append(real_he_denorm.cpu())
        input_ihc_images.append(denormalize(real_ihc[0]).cpu())

print("\n" + "=" * 60)
print("Evaluation completed!")
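The loop above indexes `[0]` of every batch, which is correct only because `batch_size` is 1; with a larger batch, the remaining samples in each batch would be silently skipped. A sketch of per-sample MAE and MSE that reduces over every axis except the batch axis (NumPy here, with the illustrative name `per_sample_errors`; in PyTorch the analogue would be `(fake - real).abs().mean(dim=(1, 2, 3))`):

```python
import numpy as np


def per_sample_errors(pred, target):
    """Per-sample MAE and MSE for (B, C, H, W) arrays in the same value range."""
    diff = np.asarray(pred, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    axes = tuple(range(1, diff.ndim))  # every axis except the batch axis
    return np.abs(diff).mean(axis=axes), (diff ** 2).mean(axis=axes)
```

Each returned array has one entry per image, so `results['mae'].extend(...)` would record every sample regardless of batch size.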

## 9. Calculate and Display Statistics

In [None]:
# Calculate statistics
statistics = {}
for metric_name, values in results.items():
    statistics[metric_name] = {
        'mean': np.mean(values),
        'std': np.std(values),
        'min': np.min(values),
        'max': np.max(values),
        'median': np.median(values)
    }

# Create results table
print("\n" + "=" * 70)
print(" " * 20 + "QUANTITATIVE RESULTS")
print("=" * 70)
print(f"{'Metric':<12} {'Mean':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
print("-" * 70)

for metric_name, stats in statistics.items():
    print(f"{metric_name.upper():<12} "
          f"{stats['mean']:<12.4f} "
          f"{stats['std']:<12.4f} "
          f"{stats['min']:<12.4f} "
          f"{stats['max']:<12.4f}")

print("=" * 70)

# Create DataFrame for better visualization
df = pd.DataFrame(statistics).T
df.columns = ['Mean', 'Std', 'Min', 'Max', 'Median']
print("\n📊 Detailed Statistics Table:")
print(df.to_string())

# Save results to CSV
csv_path = f'{result_dir}/quantitative_results.csv'
df.to_csv(csv_path)
print(f"\n✓ Results saved to: {csv_path}")
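The statistics cell uses `np.std` with its default `ddof=0` (population standard deviation). For a random sample of test patches, papers often report the sample standard deviation (`ddof=1`) or a confidence interval for the mean instead. A minimal percentile-bootstrap sketch (the name `bootstrap_mean_ci` is illustrative; the percentile bootstrap is one of several CI methods):

```python
import numpy as np


def bootstrap_mean_ci(values, n_boot=2000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for the mean of `values`."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample with replacement: each row is one bootstrap replicate
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    lo, hi = np.percentile(boot_means, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return float(lo), float(hi)
```

For instance, `bootstrap_mean_ci(results['ssim'])` gives a 95% interval for the mean SSIM that could sit next to the mean in the results table.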

## 10. Visualize Metric Distributions (for Paper)

In [None]:
# Set style for publication-quality figures
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("husl")

# Create figure with subplots for each metric
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Distribution of Image Quality Metrics (IHC → HE Translation)', 
             fontsize=16, fontweight='bold', y=0.995)

metrics_info = [
    ('psnr', 'PSNR (dB)', 'Higher is better'),
    ('ssim', 'SSIM', 'Higher is better'),
    ('lpips', 'LPIPS', 'Lower is better'),
    ('mae', 'MAE', 'Lower is better'),
    ('mse', 'MSE', 'Lower is better')
]

for idx, (metric_name, label, interpretation) in enumerate(metrics_info):
    row = idx // 3
    col = idx % 3
    ax = axes[row, col]
    
    values = results[metric_name]
    
    # Histogram with KDE
    ax.hist(values, bins=30, alpha=0.6, color='skyblue', edgecolor='black', density=True)
    
    # Add KDE line
    from scipy.stats import gaussian_kde
    kde = gaussian_kde(values)
    x_range = np.linspace(min(values), max(values), 100)
    ax.plot(x_range, kde(x_range), 'r-', linewidth=2, label='KDE')
    
    # Add mean line
    mean_val = statistics[metric_name]['mean']
    ax.axvline(mean_val, color='green', linestyle='--', linewidth=2, 
               label=f'Mean: {mean_val:.4f}')
    
    # Add median line
    median_val = statistics[metric_name]['median']
    ax.axvline(median_val, color='orange', linestyle='--', linewidth=2, 
               label=f'Median: {median_val:.4f}')
    
    ax.set_xlabel(label, fontsize=11, fontweight='bold')
    ax.set_ylabel('Density', fontsize=11, fontweight='bold')
    ax.set_title(f'{label}\n({interpretation})', fontsize=12)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

# Remove extra subplot
axes[1, 2].axis('off')

plt.tight_layout()
fig_path = f'{result_dir}/visualizations/metric_distributions.png'
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Metric distributions saved to: {fig_path}")
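The histogram cell overlays `scipy.stats.gaussian_kde`, whose default bandwidth follows Scott's rule: in one dimension the kernel standard deviation is the sample standard deviation times n^(−1/5). It raises an error on data with zero variance (singular covariance), which can happen for a metric on a tiny or degenerate test set. The sketch below reproduces that 1-D estimate in NumPy with an explicit guard for the constant case; `gaussian_kde_1d` is a hypothetical name, and returning `None` is a design choice, not SciPy behavior.

```python
import numpy as np


def gaussian_kde_1d(values, grid):
    """1-D Gaussian KDE with Scott's rule, mirroring scipy's default bandwidth."""
    values = np.asarray(values, dtype=float)
    grid = np.asarray(grid, dtype=float)
    n = values.size
    sigma = values.std(ddof=1) * n ** (-1 / 5)  # Scott's factor in 1-D
    if sigma == 0:
        return None  # constant data: no meaningful density, skip the curve
    z = (grid[:, None] - values[None, :]) / sigma  # (grid points, samples)
    return np.exp(-0.5 * z ** 2).sum(axis=1) / (n * sigma * np.sqrt(2 * np.pi))
```

In the plotting loop, a `None` result would simply mean drawing the histogram and mean/median lines without the KDE curve.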

## 11. Box Plot Visualization

In [None]:
# Create box plots for all metrics
fig, axes = plt.subplots(1, 5, figsize=(20, 5))
fig.suptitle('Box Plot of Image Quality Metrics', fontsize=16, fontweight='bold')

colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']

for idx, (metric_name, label, _) in enumerate(metrics_info):
    ax = axes[idx]
    
    bp = ax.boxplot([results[metric_name]], 
                     patch_artist=True,
                     widths=0.6,
                     medianprops=dict(color='red', linewidth=2),
                     boxprops=dict(facecolor=colors[idx], alpha=0.7),
                     whiskerprops=dict(linewidth=1.5),
                     capprops=dict(linewidth=1.5))
    
    # Add mean marker
    mean_val = statistics[metric_name]['mean']
    ax.plot(1, mean_val, 'D', color='darkblue', markersize=10, label=f'Mean: {mean_val:.4f}')
    
    ax.set_ylabel(label, fontsize=12, fontweight='bold')
    ax.set_xticks([1])
    ax.set_xticklabels([metric_name.upper()])
    ax.grid(True, alpha=0.3, axis='y')
    ax.legend(fontsize=9)

plt.tight_layout()
box_path = f'{result_dir}/visualizations/metric_boxplots.png'
plt.savefig(box_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Box plots saved to: {box_path}")
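By default, matplotlib's `boxplot` draws the box from the 25th to the 75th percentile, extends each whisker to the most extreme data point within 1.5 × IQR of the box (`whis=1.5`), and plots anything beyond as an individual outlier point. A small NumPy sketch of that rule (the name `box_stats` is illustrative, and it simplifies matplotlib's handling of edge cases) lists which test patches the plots flag as outliers:

```python
import numpy as np


def box_stats(values, whis=1.5):
    """Quartiles, whisker ends and outliers following the 1.5 x IQR rule."""
    v = np.sort(np.asarray(values, dtype=float))
    q1, median, q3 = np.percentile(v, [25, 50, 75])
    iqr = q3 - q1
    lo_fence, hi_fence = q1 - whis * iqr, q3 + whis * iqr
    inside = v[(v >= lo_fence) & (v <= hi_fence)]
    outliers = v[(v < lo_fence) | (v > hi_fence)]
    return {"q1": q1, "median": median, "q3": q3,
            "whiskers": (inside.min(), inside.max()),
            "outliers": outliers.tolist()}
```

For example, `box_stats(results['ssim'])['outliers']` lists the SSIM values that the box plot draws as individual points.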

## 12. Qualitative Results - Side-by-Side Comparison

In [None]:
# Create qualitative comparison figure
num_samples = min(8, len(generated_images))
fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples))
fig.suptitle('Qualitative Results: IHC → HE Translation', fontsize=18, fontweight='bold', y=0.995)

for i in range(num_samples):
    # Input IHC
    axes[i, 0].imshow(input_ihc_images[i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title('Input (IHC)', fontsize=12, fontweight='bold')
    axes[i, 0].axis('off')
    
    # Generated HE
    axes[i, 1].imshow(generated_images[i].permute(1, 2, 0).numpy())
    axes[i, 1].set_title('Generated (HE)', fontsize=12, fontweight='bold')
    axes[i, 1].axis('off')
    
    # Real HE
    axes[i, 2].imshow(real_he_images[i].permute(1, 2, 0).numpy())
    axes[i, 2].set_title('Ground Truth (HE)', fontsize=12, fontweight='bold')
    axes[i, 2].axis('off')
    
    # Add sample number on the left
    axes[i, 0].text(-0.1, 0.5, f'Sample {i+1}', 
                    transform=axes[i, 0].transAxes,
                    fontsize=12, fontweight='bold',
                    verticalalignment='center',
                    rotation=90)

plt.tight_layout()
qual_path = f'{result_dir}/visualizations/qualitative_results.png'
plt.savefig(qual_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Qualitative results saved to: {qual_path}")

## 13. Best and Worst Cases Analysis

In [None]:
# Find best and worst cases based on SSIM
ssim_values = np.array(results['ssim'])
psnr_values = np.array(results['psnr'])

# Get indices for best and worst cases
best_ssim_indices = np.argsort(ssim_values)[-4:][::-1]
worst_ssim_indices = np.argsort(ssim_values)[:4]

# Create figure for best cases
fig, axes = plt.subplots(4, 3, figsize=(12, 16))
fig.suptitle('Best Translation Results (Highest SSIM)', fontsize=16, fontweight='bold')

for idx, sample_idx in enumerate(best_ssim_indices):
    if sample_idx < len(input_ihc_images):
        axes[idx, 0].imshow(input_ihc_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 0].set_title(f'Input IHC', fontsize=11)
        axes[idx, 0].axis('off')
        
        axes[idx, 1].imshow(generated_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 1].set_title(f'Generated HE', fontsize=11)
        axes[idx, 1].axis('off')
        
        axes[idx, 2].imshow(real_he_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 2].set_title(f'Ground Truth HE', fontsize=11)
        axes[idx, 2].axis('off')
        
        # Add metrics as text
        metrics_text = f'SSIM: {ssim_values[sample_idx]:.4f}\nPSNR: {psnr_values[sample_idx]:.2f} dB'
        axes[idx, 0].text(-0.15, 0.5, metrics_text, 
                         transform=axes[idx, 0].transAxes,
                         fontsize=10, fontweight='bold',
                         verticalalignment='center',
                         bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

plt.tight_layout()
best_path = f'{result_dir}/visualizations/best_cases.png'
plt.savefig(best_path, dpi=300, bbox_inches='tight')
plt.show()

# Create figure for worst cases
fig, axes = plt.subplots(4, 3, figsize=(12, 16))
fig.suptitle('Worst Translation Results (Lowest SSIM)', fontsize=16, fontweight='bold')

for idx, sample_idx in enumerate(worst_ssim_indices):
    if sample_idx < len(input_ihc_images):
        axes[idx, 0].imshow(input_ihc_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 0].set_title(f'Input IHC', fontsize=11)
        axes[idx, 0].axis('off')
        
        axes[idx, 1].imshow(generated_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 1].set_title(f'Generated HE', fontsize=11)
        axes[idx, 1].axis('off')
        
        axes[idx, 2].imshow(real_he_images[sample_idx].permute(1, 2, 0).numpy())
        axes[idx, 2].set_title(f'Ground Truth HE', fontsize=11)
        axes[idx, 2].axis('off')
        
        # Add metrics as text
        metrics_text = f'SSIM: {ssim_values[sample_idx]:.4f}\nPSNR: {psnr_values[sample_idx]:.2f} dB'
        axes[idx, 0].text(-0.15, 0.5, metrics_text, 
                         transform=axes[idx, 0].transAxes,
                         fontsize=10, fontweight='bold',
                         verticalalignment='center',
                         bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.8))

plt.tight_layout()
worst_path = f'{result_dir}/visualizations/worst_cases.png'
plt.savefig(worst_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Best cases saved to: {best_path}")
print(f"✓ Worst cases saved to: {worst_path}")
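The best/worst selection above hard-codes four rows and ranks only by SSIM, where higher is better. A hedged generalization (hypothetical helper `best_worst_indices`) caps k at the number of samples and also handles lower-is-better metrics such as LPIPS; combined with `plt.subplots(k, 3, squeeze=False)`, which always returns a 2-D axes array, the plotting loop would also work for k = 1.

```python
import numpy as np


def best_worst_indices(scores, k=4, higher_is_better=True):
    """Indices of the k best (best first) and k worst (worst first) samples."""
    scores = np.asarray(scores, dtype=float)
    k = min(k, scores.size)  # small test sets would otherwise break the plots
    order = np.argsort(scores, kind="stable")  # ascending scores
    if not higher_is_better:
        order = order[::-1]  # now worst-to-best for lower-is-better metrics
    return order[::-1][:k], order[:k]
```

For example, `best_worst_indices(results['lpips'], higher_is_better=False)` picks the patches with the lowest and highest perceptual distance.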

## 14. Correlation Analysis between Metrics

In [None]:
# Create correlation matrix
results_df = pd.DataFrame(results)
correlation_matrix = results_df.corr()

# Plot correlation heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, fmt='.3f', cmap='coolwarm', 
            center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8})
plt.title('Correlation Matrix of Image Quality Metrics', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()

corr_path = f'{result_dir}/visualizations/metric_correlation.png'
plt.savefig(corr_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Correlation matrix saved to: {corr_path}")
print("\n📊 Correlation Matrix:")
print(correlation_matrix.to_string())
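`results_df.corr()` computes Pearson correlation, which measures linear association. Because each image's PSNR equals −10·log₁₀(MSE) when `data_range=1`, PSNR and MSE are perfectly monotonically related (up to floating-point differences), so their rank correlation is −1 while the Pearson value is only close to −1; a strong PSNR/MSE entry in the heatmap reflects the metric definitions rather than an empirical finding. pandas offers `results_df.corr(method='spearman')` for rank correlation; the NumPy sketch below (illustrative name `spearman`, assuming no tied values) shows what it computes.

```python
import numpy as np


def spearman(x, y):
    """Spearman rank correlation: Pearson correlation of ranks (assumes no ties)."""
    rank_x = np.argsort(np.argsort(x))  # rank of each element
    rank_y = np.argsort(np.argsort(y))
    return float(np.corrcoef(rank_x, rank_y)[0, 1])
```

Ties would need average ranks, which the pandas method handles.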

## 15. Generate Paper-Ready Summary Table

In [None]:
# Create paper-ready table
paper_table = pd.DataFrame({
    'Metric': ['PSNR (dB) ↑', 'SSIM ↑', 'LPIPS ↓', 'MAE ↓', 'MSE ↓'],
    'Mean ± Std': [
        f"{statistics['psnr']['mean']:.2f} ± {statistics['psnr']['std']:.2f}",
        f"{statistics['ssim']['mean']:.4f} ± {statistics['ssim']['std']:.4f}",
        f"{statistics['lpips']['mean']:.4f} ± {statistics['lpips']['std']:.4f}",
        f"{statistics['mae']['mean']:.4f} ± {statistics['mae']['std']:.4f}",
        f"{statistics['mse']['mean']:.6f} ± {statistics['mse']['std']:.6f}"
    ],
    'Median': [
        f"{statistics['psnr']['median']:.2f}",
        f"{statistics['ssim']['median']:.4f}",
        f"{statistics['lpips']['median']:.4f}",
        f"{statistics['mae']['median']:.4f}",
        f"{statistics['mse']['median']:.6f}"
    ],
    'Range': [
        f"[{statistics['psnr']['min']:.2f}, {statistics['psnr']['max']:.2f}]",
        f"[{statistics['ssim']['min']:.4f}, {statistics['ssim']['max']:.4f}]",
        f"[{statistics['lpips']['min']:.4f}, {statistics['lpips']['max']:.4f}]",
        f"[{statistics['mae']['min']:.4f}, {statistics['mae']['max']:.4f}]",
        f"[{statistics['mse']['min']:.6f}, {statistics['mse']['max']:.6f}]"
    ]
})

print("\n" + "=" * 100)
print(" " * 35 + "PAPER-READY RESULTS TABLE")
print("=" * 100)
print(paper_table.to_string(index=False))
print("=" * 100)
print("\nNote: ↑ indicates higher is better, ↓ indicates lower is better")

# Save to LaTeX format. pdfLaTeX typically rejects the Unicode arrows, so the
# .tex copy uses math-mode commands for ±, ↑ and ↓ (escape=False keeps them).
latex_df = paper_table.copy()
for symbol, command in {'±': r'$\pm$', '↑': r'$\uparrow$', '↓': r'$\downarrow$'}.items():
    latex_df.columns = [c.replace(symbol, command) for c in latex_df.columns]
    for col in latex_df.columns:
        latex_df[col] = latex_df[col].str.replace(symbol, command, regex=False)
latex_table = latex_df.to_latex(index=False, escape=False)
latex_path = f'{result_dir}/paper_table.tex'
with open(latex_path, 'w', encoding='utf-8') as f:
    f.write(latex_table)

print(f"\n✓ LaTeX table saved to: {latex_path}")
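The table cell repeats one format string per metric and statistic. A compact alternative, sketched here with a hypothetical helper `summary_row`, computes each cell from the raw per-image values in `results` and emits LaTeX math (`$\pm$`) directly, so the same strings work in the `.tex` file; its `ddof` argument makes the standard-deviation convention explicit.

```python
import numpy as np


def summary_row(values, digits=4, ddof=0):
    """Mean ± std, median and range of one metric as LaTeX-safe strings.

    ddof=0 matches np.std's default used in the statistics cell; pass
    ddof=1 for the sample standard deviation.
    """
    v = np.asarray(values, dtype=float)
    spec = f".{digits}f"  # e.g. ".4f"
    return {
        "Mean $\\pm$ Std": f"{v.mean():{spec}} $\\pm$ {v.std(ddof=ddof):{spec}}",
        "Median": f"{np.median(v):{spec}}",
        "Range": f"[{v.min():{spec}}, {v.max():{spec}}]",
    }
```

For example, `pd.DataFrame({'SSIM': summary_row(results['ssim'])}).T` yields one table row, and a dict comprehension over the metric keys would build the whole table.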

## 16. Save Individual Generated Images

In [None]:
# Save all generated images individually.
# torchvision's save_image clamps to [0, 1] and rounds to 8-bit
# (a plain astype(np.uint8) would truncate instead of rounding).
print("Saving individual generated images...")
out_dir = f'{result_dir}/generated_images'
for i in tqdm(range(len(generated_images)), desc="Saving images"):
    save_image(generated_images[i], f'{out_dir}/generated_{i:03d}.png')  # generated HE
    save_image(real_he_images[i], f'{out_dir}/groundtruth_{i:03d}.png')  # ground-truth HE
    save_image(input_ihc_images[i], f'{out_dir}/input_{i:03d}.png')      # input IHC

print(f"✓ Saved {len(generated_images)} image sets to: {out_dir}/")
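The same CHW-to-HWC conversion is needed wherever a tensor is handed to PIL or `imshow`. Writing `(x * 255).astype(np.uint8)` truncates rather than rounds (0.5 becomes 127, not 128) and does not guard against values outside [0, 1]. A NumPy sketch of a safer conversion, with the illustrative name `chw_float_to_uint8`:

```python
import numpy as np


def chw_float_to_uint8(img):
    """Convert a (C, H, W) float array in [0, 1] to an (H, W, C) uint8 array.

    Values are rounded and clipped; astype(np.uint8) alone truncates.
    """
    hwc = np.transpose(np.asarray(img, dtype=np.float64), (1, 2, 0))
    return np.clip(np.round(hwc * 255), 0, 255).astype(np.uint8)
```

A CPU tensor can be passed via `.numpy()`, e.g. `Image.fromarray(chw_float_to_uint8(generated_images[0].numpy()))`.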

## 17. Generate Final Summary Report

In [None]:
# Generate comprehensive summary report
report = f"""
{'='*80}
                    CYCLEGAN TEST REPORT
                IHC to HE Translation Performance Evaluation
{'='*80}

TEST CONFIGURATION
{'-'*80}
Model Path:           {model_path}
Test Samples:         {len(test_dataset)}
Image Size:           {params['input_size']}x{params['input_size']}
Device:               {device}
Model Epoch:          {params['model_epoch']}

QUANTITATIVE RESULTS
{'-'*80}
Metric          Mean ± Std              Median      Min         Max
{'-'*80}
PSNR (dB)       {statistics['psnr']['mean']:.2f} ± {statistics['psnr']['std']:.2f}        {statistics['psnr']['median']:.2f}      {statistics['psnr']['min']:.2f}      {statistics['psnr']['max']:.2f}
SSIM            {statistics['ssim']['mean']:.4f} ± {statistics['ssim']['std']:.4f}      {statistics['ssim']['median']:.4f}    {statistics['ssim']['min']:.4f}    {statistics['ssim']['max']:.4f}
LPIPS           {statistics['lpips']['mean']:.4f} ± {statistics['lpips']['std']:.4f}      {statistics['lpips']['median']:.4f}    {statistics['lpips']['min']:.4f}    {statistics['lpips']['max']:.4f}
MAE             {statistics['mae']['mean']:.4f} ± {statistics['mae']['std']:.4f}      {statistics['mae']['median']:.4f}    {statistics['mae']['min']:.4f}    {statistics['mae']['max']:.4f}
MSE             {statistics['mse']['mean']:.6f} ± {statistics['mse']['std']:.6f}  {statistics['mse']['median']:.6f}  {statistics['mse']['min']:.6f}  {statistics['mse']['max']:.6f}
{'-'*80}

KEY FINDINGS
{'-'*80}
• Mean PSNR: {statistics['psnr']['mean']:.2f} dB (higher is better)
• Mean SSIM: {statistics['ssim']['mean']:.4f} (higher is better; 1.0 = identical structure)
• Mean LPIPS: {statistics['lpips']['mean']:.4f} (lower is better; 0 = perceptually identical)
• Mean MAE: {statistics['mae']['mean']:.4f} on a [0, 1] intensity scale (lower is better)

OUTPUT FILES
{'-'*80}
✓ Quantitative results:    {result_dir}/quantitative_results.csv
✓ LaTeX table:             {result_dir}/paper_table.tex
✓ Metric distributions:    {result_dir}/visualizations/metric_distributions.png
✓ Box plots:               {result_dir}/visualizations/metric_boxplots.png
✓ Qualitative results:     {result_dir}/visualizations/qualitative_results.png
✓ Best cases:              {result_dir}/visualizations/best_cases.png
✓ Worst cases:             {result_dir}/visualizations/worst_cases.png
✓ Correlation matrix:      {result_dir}/visualizations/metric_correlation.png
✓ Generated images:        {result_dir}/generated_images/

{'='*80}
                    EVALUATION COMPLETED SUCCESSFULLY
{'='*80}
"""

print(report)

# Save report to file
report_path = f'{result_dir}/evaluation_report.txt'
with open(report_path, 'w') as f:
    f.write(report)

print(f"\n✓ Full report saved to: {report_path}")

## Summary

This notebook provides a comprehensive evaluation of the IHC → HE translation model, covering:

1. **Quantitative Metrics**: PSNR, SSIM, LPIPS, MAE, MSE with detailed statistics
2. **Distribution Analysis**: Histograms and KDE plots for all metrics
3. **Box Plot Visualization**: Statistical distribution visualization
4. **Qualitative Comparison**: Side-by-side visual comparison of input, generated, and ground truth
5. **Best/Worst Case Analysis**: Identification of strongest and weakest translations
6. **Correlation Analysis**: Inter-metric relationships
7. **Paper-Ready Outputs**: LaTeX tables and high-resolution figures (300 DPI)

All results are saved in: `../../results/HE_IHC_translation/internal_ss/PD-L1/test_results/`

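## 18. Check Different Epoch Models

In [None]:
# List the available checkpoints of generator F to choose epochs for comparison
import os
import re

def checkpoint_epoch(path):
    """Extract the epoch number from a checkpoint name such as F_99.pth (-1 if absent)."""
    match = re.search(r'F_(\d+)\.pth$', os.path.basename(path))
    return int(match.group(1)) if match else -1

# Sort numerically (a plain string sort would place F_10 before F_2)
available_models = sorted(glob(f'{model_dir}/F_*.pth'), key=checkpoint_epoch)
print(f"Found {len(available_models)} model checkpoints")
print("\nAvailable epochs:")
for model in available_models[:10]:  # Show first 10
    print(f"  - Epoch {checkpoint_epoch(model)}")

if len(available_models) > 10:
    print(f"  ... and {len(available_models) - 10} more")

print("\n💡 Recommendation: Try testing epochs 50, 75, 99 to see progression")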

## 🔴 Performance Analysis & Recommendations

### Current Results (IHC → HE)
- **PSNR: 13.43 dB** ⚠️ Very Low (target: 25-30+ dB)
- **SSIM: 0.074** ⚠️ Critically Low (target: 0.8+)
- **LPIPS: 0.455** ⚠️ High perceptual distance (target: <0.2)

### Possible Issues:
1. **Insufficient Training**: Model may need more epochs or better convergence
2. **Architecture**: May need a deeper/wider generator for the complex IHC→HE transformation
3. **Data Alignment**: Check if IHC and HE patches are properly aligned
4. **Loss Weights**: Cycle consistency weights might need tuning
5. **Color Space**: IHC has brown staining, HE has purple/pink - large domain gap

### Recommendations:
1. Check training loss curves - did they converge?
2. Try loading different epoch models (earlier or later)
3. Consider using Pix2Pix if data is paired (supervised learning)
4. Increase model capacity (more filters, more residual blocks)
5. Try different normalization (BatchNorm vs InstanceNorm)
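The data-alignment item above can be checked roughly with phase correlation, assuming the HE and IHC patches differ mainly by a small integer translation and that grayscale versions of both stains share enough structure (IHC slides are commonly counterstained with hematoxylin, so nuclei often appear in both, but this is an assumption for this dataset). The sketch below, with the illustrative name `estimate_shift`, recovers the translation from the peak of the inverse FFT of the normalized cross-power spectrum:

```python
import numpy as np


def estimate_shift(reference, moving):
    """Estimate the integer (dy, dx) translation of `moving` relative to
    `reference` by phase correlation, so that moving ~ np.roll(reference, (dy, dx)).

    Both inputs are 2-D grayscale arrays of the same shape.
    """
    reference = np.asarray(reference, dtype=float)
    moving = np.asarray(moving, dtype=float)
    f_ref = np.fft.fft2(reference - reference.mean())
    f_mov = np.fft.fft2(moving - moving.mean())
    cross = f_mov * np.conj(f_ref)
    cross /= np.abs(cross) + 1e-12  # keep only the phase
    corr = np.abs(np.fft.ifft2(cross))
    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
    h, w = corr.shape
    # Peaks past the midpoint correspond to negative shifts (FFT wrap-around)
    if dy > h // 2:
        dy -= h
    if dx > w // 2:
        dx -= w
    return int(dy), int(dx)
```

Grayscale inputs could come from `np.asarray(Image.open(path).convert('L'), dtype=float)`. Consistent non-zero shifts across many test pairs would suggest registration errors, which depress pixel-wise metrics such as PSNR and SSIM even for a good translation model.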