In [None]:
# Cell 1: Imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import pandas as pd
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from scipy import stats

# Global plotting defaults: select the 'seaborn-v0_8-darkgrid' matplotlib style
# sheet and seaborn's "husl" colour palette for the figures drawn below.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably require
# matplotlib >= 3.6 — confirm against the environment this notebook runs in.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

In [None]:
# Cell 2: Load the low-/high-dose arrays from disk and derive the bright-field mask
data_path = Path("../data")
results_path = Path("../results")

print("Loading original data...")
low_dose_file = data_path / "03_denoising_SrTiO3_High_mag_Low_dose.npy"
high_dose_file = data_path / "03_denoising_SrTiO3_High_mag_High_dose.npy"
low_dose = np.load(low_dose_file)
high_dose = np.load(high_dose_file)

# Average the high-dose data over its first two axes. Later cells index those
# two axes as [x, y] and treat what remains as an image.
pacbed_high = high_dose.mean(axis=(0, 1))

# Every pixel brighter than 10% of the averaged pattern's peak belongs to the
# mask that is later passed to compute_metrics.
threshold = pacbed_high.max() * 0.1
bf_mask = threshold < pacbed_high

print(f"Data shape: {low_dose.shape}")
print(f"Bright field pixels: {bf_mask.sum()} / {bf_mask.size}")

In [None]:
# Cell 3: Loading denoising results
# Maps a display name to the denoised array loaded from `results_path`.
denoised_results = {}

# Candidate file names per model, in priority order: the first name that
# exists in `results_path` is loaded and the remaining ones are skipped.
possible_files = [
    ('U-Net', ['denoised_u-net.npy', 'denoised_unet.npy', 'denoised_U-Net.npy', 
               'denoised_original_unet.npy', 'denoised.npy']),
    ('CNN Autoencoder', ['denoised_cnn_autoencoder.npy', 'denoised_cnn.npy', 
                        'denoised_autoencoder.npy', 'denoised_simple_cnn.npy'])
]

print("\nSearching for denoised results...")
for model_name, file_names in possible_files:
    for file_name in file_names:
        file_path = results_path / file_name
        if file_path.exists():
            print(f"Found {model_name} results: {file_name}")
            denoised_results[model_name] = np.load(file_path)
            break

if not denoised_results:
    # Fallback: none of the expected names exist, so take up to five files
    # matching "denoised*.npy" and label them generically.
    print("\nNo denoised results found!")
    print("Looking for any .npy files in results folder...")
    # Path.glob does not promise any particular order; sorting keeps the
    # "Model N" labels attached to the same files on every run and machine.
    npy_files = sorted(results_path.glob("denoised*.npy"))
    for i, file_path in enumerate(npy_files[:5]):
        print(f"  {i+1}. {file_path.name}")
        denoised_results[f"Model {i+1}"] = np.load(file_path)

print(f"\nLoaded {len(denoised_results)} denoised results")

In [None]:
# Cell 4: Model information
# Static description of each known model, keyed by the same display names that
# Cell 3 uses for the loaded results. Cell 7 reads 'parameters' and
# 'has_skip_connections' when it builds the comparison table.
# NOTE(review): the parameter counts look like rounded estimates — confirm
# them against the trained models.
model_info = {
    'U-Net': dict(
        parameters=1.2e6,
        has_skip_connections=True,
        architecture='Encoder-Decoder with skip connections',
    ),
    'CNN Autoencoder': dict(
        parameters=0.2e6,
        has_skip_connections=False,
        architecture='Encoder-Decoder without skip connections',
    ),
}

In [None]:
# Cell 5: Metrics calculation function
def compute_metrics(prediction, target, mask=None):
    """Score one denoised image against its reference image.

    Parameters
    ----------
    prediction : numpy.ndarray
        Denoised image.
    target : numpy.ndarray
        Reference image; expected to have the same shape as ``prediction``.
    mask : numpy.ndarray, optional
        Multiplied element-wise into both images before anything is computed.
        Masked-out pixels therefore become 0 but are still counted when the
        MSE/MAE means are taken over the whole image.

    Returns
    -------
    dict
        Keys ``'MSE'``, ``'MAE'``, ``'PSNR'``, ``'SSIM'`` and
        ``'Correlation'``. MSE, MAE and Correlation use the (masked) raw
        values; PSNR and SSIM use each image divided by its own maximum with
        ``data_range=1.0``. If either masked image has a maximum of exactly 0,
        the placeholder result ``MSE = MAE = inf`` and
        ``PSNR = SSIM = Correlation = 0`` is returned instead.
    """
    if mask is not None:
        prediction = prediction * mask
        target = target * mask

    pred_peak = prediction.max()
    target_peak = target.max()

    # Degenerate input: nothing to normalise by, so report the placeholder
    # values before doing any further work.
    if pred_peak == 0 or target_peak == 0:
        return {
            'MSE': np.inf,
            'MAE': np.inf,
            'PSNR': 0,
            'SSIM': 0,
            'Correlation': 0
        }

    # Each image is scaled by its own peak; an image whose peak is negative is
    # passed through unscaled (same as the un-normalised fallback before).
    pred_norm = prediction / pred_peak if pred_peak > 0 else prediction
    target_norm = target / target_peak if target_peak > 0 else target

    residual = prediction - target

    # np.corrcoef yields NaN (plus a RuntimeWarning) when either input has zero
    # variance, and a single NaN would turn the averaged correlation into NaN.
    # Report 0 instead, consistent with the degenerate branch above.
    if prediction.std() > 0 and target.std() > 0:
        correlation = np.corrcoef(prediction.ravel(), target.ravel())[0, 1]
    else:
        correlation = 0.0

    metrics = {
        'MSE': np.mean(residual ** 2),
        'MAE': np.mean(np.abs(residual)),
        'PSNR': psnr(target_norm, pred_norm, data_range=1.0),
        'SSIM': ssim(target_norm, pred_norm, data_range=1.0),
        'Correlation': correlation
    }

    return metrics


In [None]:
# Cell 6: Calculating metrics for all results
print("\n=== COMPUTING QUALITY METRICS ===")

# One list of per-position metric dicts for every loaded result.
all_metrics = {name: [] for name in denoised_results}
n_positions = 1000  # number of randomly sampled (x, y) positions

# Fixed seed so the same positions are drawn on every run. For each position
# x is drawn before y, which keeps the random stream (and thus the sample)
# unchanged. randint's upper bound is exclusive, so index 0 and the last index
# of each of the two leading axes are never sampled.
np.random.seed(42)
sample_positions = [
    (np.random.randint(1, low_dose.shape[0] - 1),
     np.random.randint(1, low_dose.shape[1] - 1))
    for _ in range(n_positions)
]

print(f"Computing metrics for {n_positions} random positions...")

for x, y in tqdm(sample_positions):
    target = high_dose[x, y]

    for name, denoised_data in denoised_results.items():
        pred = denoised_data[x, y]
        metrics = compute_metrics(pred, target, bf_mask)
        all_metrics[name].append(metrics)

# Summarise each metric per model. Non-finite samples are excluded: the `inf`
# placeholders returned for degenerate images, and also NaN, which would
# otherwise turn the mean/std/median into NaN.
aggregated_metrics = {}
for name, metrics_list in all_metrics.items():
    per_metric = {}
    for metric in metrics_list[0]:
        finite_values = [m[metric] for m in metrics_list if np.isfinite(m[metric])]
        if finite_values:
            per_metric[metric] = {
                'mean': np.mean(finite_values),
                'std': np.std(finite_values),
                'median': np.median(finite_values)
            }
        else:
            # Nothing usable was measured: report NaN explicitly instead of
            # relying on numpy's empty-slice warnings.
            per_metric[metric] = {'mean': np.nan, 'std': np.nan, 'median': np.nan}
    aggregated_metrics[name] = per_metric

In [None]:
# Cell 7: Tabulate model properties next to the measured PSNR
print("\n=== MODEL COMPARISON ===")

# Placeholder row used for results whose name has no entry in model_info
# (for example the generic "Model N" fallback names).
unknown_info = {
    'parameters': 'Unknown',
    'has_skip_connections': 'Unknown',
    'architecture': 'Unknown'
}

model_comparison = []
for name in denoised_results:
    info = model_info.get(name, unknown_info)

    # Numeric parameter counts are shown in millions; anything else
    # (the 'Unknown' placeholder) is passed through as-is.
    n_params = info['parameters']
    if isinstance(n_params, (int, float)):
        params_label = f"{n_params/1e6:.1f}M"
    else:
        params_label = n_params

    psnr_stats = aggregated_metrics[name]['PSNR']
    row = {
        'Model': name,
        'Parameters': params_label,
        'Skip Connections': info['has_skip_connections'],
        'PSNR (dB)': f"{psnr_stats['mean']:.2f}±{psnr_stats['std']:.2f}"
    }
    model_comparison.append(row)

comparison_df = pd.DataFrame(model_comparison)
print("\nModel Comparison:")
print(comparison_df.to_string(index=False))

In [None]:
# Cell 8: PSNR visualisation - the main metric
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

model_names = list(denoised_results.keys())
psnr_means = [aggregated_metrics[name]['PSNR']['mean'] for name in model_names]
psnr_stds = [aggregated_metrics[name]['PSNR']['std'] for name in model_names]

# One viridis colour per model.
colors = plt.cm.viridis(np.linspace(0, 1, len(model_names)))

# --- Left panel: mean PSNR per model, error bars = standard deviation ---
bars = ax1.bar(model_names, psnr_means, yerr=psnr_stds,
               capsize=10, color=colors, alpha=0.8)
ax1.set_ylabel('PSNR (dB)', fontsize=12)
ax1.set_title('Peak Signal-to-Noise Ratio Comparison', fontsize=14)
# Size the axis from the top of the tallest error bar (not just the tallest
# bar) so neither the error bars nor the value labels above them are clipped.
psnr_top = max(mean + std for mean, std in zip(psnr_means, psnr_stds))
ax1.set_ylim([0, psnr_top * 1.2])
ax1.grid(axis='y', alpha=0.3)

# Print each mean just above its error bar.
for bar, mean, std in zip(bars, psnr_means, psnr_stds):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + std,
             f'{mean:.2f}', ha='center', va='bottom', fontsize=10)

# --- Right panel: grouped bars of several metrics per model ---
metrics_names = ['PSNR', 'SSIM', 'Correlation']
x = np.arange(len(model_names))
width = 0.25
# PSNR is divided by this value so it can share an axis with SSIM and
# correlation; the legend entry states the scaling explicitly.
psnr_scale = 40

for i, metric in enumerate(metrics_names):
    values = [aggregated_metrics[name][metric]['mean'] for name in model_names]
    label = metric

    if metric == 'PSNR':
        values = [v / psnr_scale for v in values]
        label = f'PSNR / {psnr_scale} dB'

    # Centre the group of bars on each model's tick.
    offset = (i - (len(metrics_names) - 1) / 2) * width
    ax2.bar(x + offset, values, width, label=label, alpha=0.8)

ax2.set_xlabel('Model', fontsize=12)
ax2.set_ylabel('Normalized Metric Value', fontsize=12)
ax2.set_title('Multi-Metric Comparison', fontsize=14)
ax2.set_xticks(x)
ax2.set_xticklabels(model_names)
ax2.legend()
ax2.grid(axis='y', alpha=0.3)

plt.tight_layout()
# savefig does not create missing folders, so make sure the target exists.
figures_dir = results_path / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'metrics_comparison_from_results.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Cell 9: Visual comparison of results
print("\n=== VISUAL COMPARISON ===")

# (x, y) indices into the first two axes of the data; one figure row each.
# NOTE(review): these indices are hard-coded — confirm they are within the
# first two dimensions of the loaded arrays.
positions = [(100, 100), (20, 20), (30, 30), (40, 40), (60, 60), (70, 70), (80, 80), (90, 90), (110, 110), (120, 120), (130, 130), (140, 140)]

# Derived here rather than reused from the previous cell, so this cell only
# needs the loaded data and results to run.
model_names = list(denoised_results.keys())
n_models = len(model_names)

# Columns: noisy input | one per model | high-dose reference.
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(len(positions), n_models + 2,
                         figsize=(4*(n_models+2), 4*len(positions)),
                         squeeze=False)

for i, (x, y) in enumerate(positions):

    # NOTE(review): this colour range is only applied to the noisy panel; the
    # denoised and high-dose panels are auto-scaled individually. Confirm
    # whether a shared scale was intended.
    vmax = max([low_dose[x, y].max()] +
               [denoised_results[name][x, y].max() for name in model_names])
    vmin = 0

    # Noisy
    axes[i, 0].imshow(low_dose[x, y], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[i, 0].set_title(f'Noisy\n({x}, {y})')
    axes[i, 0].axis('off')

    # Denoised results
    for j, name in enumerate(model_names):
        axes[i, j+1].imshow(denoised_results[name][x, y],
                            cmap='viridis')
        axes[i, j+1].set_title(name)
        axes[i, j+1].axis('off')

    # High dose reference
    axes[i, -1].imshow(high_dose[x, y], cmap='viridis')
    axes[i, -1].set_title('High Dose')
    axes[i, -1].axis('off')

plt.tight_layout()
# savefig does not create missing folders, so make sure the target exists.
figures_dir = results_path / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(figures_dir / 'visual_comparison_results.png', dpi=300, bbox_inches='tight')
plt.show()
