# 07 - Comprehensive Evaluation

Compare all trained models: Baseline, Soft, Hard, and Hybrid.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from skimage.metrics import structural_similarity as ssim

# Run on the GPU when PyTorch reports a usable CUDA device, otherwise on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load Models

In [None]:
# One independent NeRF instance per training variant, all placed on `device`.
model_baseline, model_soft, model_hard, model_hybrid = (
    NeRF().to(device) for _ in range(4)
)

# Pair every network with the checkpoint written by its training run.
_checkpoints = (
    (model_baseline, 'results/baseline/model_baseline.pth'),
    (model_soft, 'results/soft/model_soft.pth'),
    (model_hard, 'results/hard/model_hard.pth'),
    (model_hybrid, 'results/hybrid/model_hybrid.pth'),
)

# Restore the trained weights, mapping stored tensors onto the active device.
for _net, _ckpt_path in _checkpoints:
    _net.load_state_dict(torch.load(_ckpt_path, map_location=device))

# Put every network into evaluation mode before any rendering happens.
for _net, _ckpt_path in _checkpoints:
    _net.eval()

print('✅ All models loaded')

## Evaluation Functions

In [None]:
def compute_psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Both inputs are converted to float64 before differencing so that
    integer images (e.g. uint8) cannot wrap around when subtracted or
    squared.

    Args:
        img1: First image (array-like); must broadcast against ``img2``.
        img2: Second image (array-like).
        data_range: Difference between the largest and smallest possible
            pixel value. The default of 1.0 matches images scaled to
            [0, 1]; pass 255 for 8-bit images.

    Returns:
        The PSNR in dB, capped at 100.0 when the normalised mean squared
        error falls below 1e-10 (i.e. the images are effectively identical).
    """
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)

    # Normalise by the squared peak so the result is a true PSNR for any range.
    mse = np.mean((a - b) ** 2) / float(data_range) ** 2

    # Avoid log10(0) and cap the score for near-identical images.
    if mse < 1e-10:
        return 100.0
    return -10 * np.log10(mse)

def compute_ssim(img1, img2):
    """Return the structural similarity (SSIM) between two colour images.

    Thin wrapper around scikit-image's ``structural_similarity`` for
    images whose channels are on axis 2 and whose values span a range
    of 1.0.

    Args:
        img1: First image; channels are taken from axis 2.
        img2: Second image, same layout as ``img1``.

    Returns:
        The SSIM value reported by scikit-image.
    """
    # `multichannel` was deprecated in favour of `channel_axis`; passing both
    # is redundant, so only the current keyword is used here.
    return ssim(img1, img2, data_range=1.0, channel_axis=2)

@torch.no_grad()
def render_full_image(model, c2w, H, W, focal, chunk=4096,
                      near=2.0, far=6.0, n_samples=128):
    """Render a complete H x W image from one camera pose.

    Rays for every pixel are generated with ``get_rays`` and pushed through
    ``render_rays`` in slices of ``chunk`` rays so that the whole image never
    has to be evaluated in a single call. Gradients are disabled.

    Args:
        model: Network passed as the first argument to ``render_rays``.
        c2w: Camera pose, as a NumPy array or a torch tensor; it is converted
            to a float32 tensor on the global ``device``.
        H: Image height in pixels.
        W: Image width in pixels.
        focal: Focal length forwarded to ``get_rays``.
        chunk: Maximum number of rays rendered per ``render_rays`` call.
        near: Near bound forwarded to ``render_rays``.
        far: Far bound forwarded to ``render_rays``.
        n_samples: Sample count forwarded to ``render_rays``.

    Returns:
        A NumPy array of shape (H, W, 3) holding the rendered colours.

    Raises:
        ValueError: If ``chunk`` is smaller than 1.
    """
    if chunk < 1:
        raise ValueError(f'chunk must be a positive integer, got {chunk}')

    # as_tensor accepts both ndarrays and tensors, unlike torch.from_numpy.
    pose = torch.as_tensor(c2w, dtype=torch.float32, device=device)

    rays_o, rays_d = get_rays(H, W, focal, pose)
    # reshape (rather than view) also works if get_rays returns
    # non-contiguous tensors.
    rays_o = rays_o.reshape(-1, 3)
    rays_d = rays_d.reshape(-1, 3)

    rgb_chunks = []
    for start in range(0, rays_o.shape[0], chunk):
        stop = start + chunk
        # perturb=False is passed so evaluation renders do not request
        # sample perturbation from render_rays.
        rgb, _ = render_rays(model, rays_o[start:stop], rays_d[start:stop],
                             near=near, far=far, n_samples=n_samples,
                             perturb=False)
        rgb_chunks.append(rgb)

    # Stitch the per-chunk colours back into image layout on the CPU.
    return torch.cat(rgb_chunks, 0).reshape(H, W, 3).cpu().numpy()

## Run Evaluation on Validation Set

In [None]:
# Display name -> trained network, in the order they are reported.
models = dict(
    Baseline=model_baseline,
    Soft=model_soft,
    Hard=model_hard,
    Hybrid=model_hybrid,
)

# Per-model lists of per-image scores, filled in by the loop below.
results = {}
for name in models:
    results[name] = {'psnr': [], 'ssim': [], 'mse': []}

print('Evaluating on validation set...')

for i in tqdm(range(len(imgs_val))):
    gt_img, pose = imgs_val[i], poses_val[i]

    # Render the same view with every model and score it against ground truth.
    for name, model in models.items():
        pred_img = render_full_image(model, pose, H, W, focal)

        scores = {
            'psnr': compute_psnr(pred_img, gt_img),
            'ssim': compute_ssim(pred_img, gt_img),
            'mse': np.mean((pred_img - gt_img) ** 2),
        }
        for metric, value in scores.items():
            results[name][metric].append(value)

print('✅ Evaluation complete!')

## Summary Statistics

In [None]:
print('\n=== EVALUATION RESULTS ===\n')

# One report per model: mean ± std for PSNR and SSIM, mean only for MSE.
for name in models:
    psnr_scores = np.asarray(results[name]['psnr'])
    ssim_scores = np.asarray(results[name]['ssim'])
    mse_scores = np.asarray(results[name]['mse'])

    psnr_mean, psnr_std = psnr_scores.mean(), psnr_scores.std()
    ssim_mean, ssim_std = ssim_scores.mean(), ssim_scores.std()
    mse_mean = mse_scores.mean()

    report = [
        f'{name}:',
        f'  PSNR: {psnr_mean:.2f} ± {psnr_std:.2f} dB',
        f'  SSIM: {ssim_mean:.4f} ± {ssim_std:.4f}',
        f'  MSE:  {mse_mean:.6f}',
        '',  # trailing blank line between models
    ]
    print('\n'.join(report))

## Visualization

In [None]:
# Four panels: per-image PSNR, per-image SSIM, mean PSNR, mean SSIM.
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

metrics = ['psnr', 'ssim']
titles = ['PSNR (dB)', 'SSIM']
names = list(models.keys())

# Panels 0-1: one line per model showing the score on each validation image.
for ax, metric, title in zip(axes[:2], metrics, titles):
    for name in names:
        ax.plot(results[name][metric], label=name, alpha=0.7)
    ax.set_xlabel('Validation Image')
    ax.set_ylabel(title)
    ax.set_title(f'{title} Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Panels 2-3: bar charts of each model's average score.
psnr_means = [np.mean(results[n]['psnr']) for n in names]
ssim_means = [np.mean(results[n]['ssim']) for n in names]
bar_panels = (
    (axes[2], psnr_means, 'PSNR (dB)', 'Average PSNR'),
    (axes[3], ssim_means, 'SSIM', 'Average SSIM'),
)
for ax, heights, y_label, heading in bar_panels:
    ax.bar(names, heights)
    ax.set_ylabel(y_label)
    ax.set_title(heading)
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('results/evaluation_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print('✅ Evaluation plots saved to results/evaluation_comparison.png')

## Save Results

In [None]:
# Persist the per-image metrics. `results` is a nested dict, so NumPy stores it
# as a pickled object array; read it back with
# np.load(path, allow_pickle=True).item().
results_path = 'results/evaluation_results.npy'
np.save(results_path, results)
print(f'✅ Results saved to {results_path}')