# Ground Truth Score vs Denoising Score Matching Comparison

This notebook demonstrates the difference between:
1. **Ground Truth Score Learning**: Training a neural network to directly predict the analytical score function ∇_x log p(x)
2. **Denoising Score Matching**: Training via the EDM denoising objective (standard diffusion training)

We'll use a Gaussian mixture model where we can compute the exact score analytically.

In [None]:
import sys
sys.path.append("/n/home12/binxuwang/Github/DiffusionLearningCurve")

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from core.gaussian_mixture_lib import GaussianMixture
from core.gmm_general_diffusion_lib import gaussian_mixture_score_torch
from core.diffusion_edm_lib import (
    UNetBlockStyleMLP_backbone, 
    EDMPrecondWrapper,
    EDMLoss,
    train_score_model_custom_loss,
    edm_sampler
)

plt.style.use('seaborn-v0_8')

# Seed the global NumPy and PyTorch RNGs so that "Restart Kernel -> Run All"
# reproduces the same numbers for everything that draws from these generators.
# NOTE(review): this assumes the project helpers (GMM sampling, training loop,
# EDM sampler) draw from the global generators -- confirm against their source.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Setup: Create a Gaussian Mixture Model

We'll create a simple 2D Gaussian mixture with a known analytical score function.

In [None]:
# --- Mixture definition: three 2D Gaussian components -----------------------
mus = [
    np.array([-2.0, -1.0]),
    np.array([2.0, 1.0]),
    np.array([0.0, 2.5]),
]
covs = [
    np.array([[0.8, 0.2], [0.2, 0.8]]),
    np.array([[1.2, -0.4], [-0.4, 1.2]]),
    np.array([[0.6, 0.0], [0.0, 0.6]]),
]
weights = [0.4, 0.4, 0.2]
gmm = GaussianMixture(mus, covs, weights)

# --- Training set: draw samples and keep a float32 tensor copy ---------------
n_samples = 5000
# `gmm.sample` returns three values; only the first (the points) is used below.
X_train, components, _ = gmm.sample(n_samples)
X_train_torch = torch.tensor(X_train, dtype=torch.float32)

# --- Evaluation grid and analytical score on that grid -----------------------
x_range = np.linspace(-4, 4, 20)
y_range = np.linspace(-3, 4, 20)
XX, YY = np.meshgrid(x_range, y_range)
grid_points = np.column_stack((XX.ravel(), YY.ravel()))
true_scores = gmm.score(grid_points)

# Reshape the two score components back onto the grid for the quiver plot.
score_u = true_scores[:, 0].reshape(XX.shape)
score_v = true_scores[:, 1].reshape(XX.shape)

# --- Figure: samples on the left, analytical score field on the right --------
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.scatter(X_train[:, 0], X_train[:, 1], alpha=0.6, s=10)
ax1.set(title='Training Data from GMM', xlabel='x1', ylabel='x2')
ax1.axis('equal')

ax2.quiver(XX, YY, score_u, score_v, alpha=0.7)
ax2.set(title='True Score Field ∇log p(x)', xlabel='x1', ylabel='x2')
ax2.axis('equal')

plt.tight_layout()
plt.show()

print(f"Generated {n_samples} samples from {gmm.n_component}-component GMM")

## Method 1: Ground Truth Score Learning

Train a neural network to directly predict the analytical score function.

In [None]:
class GroundTruthScoreLoss:
    """Per-element squared error between a model's output and the analytical score.

    The target is obtained from ``gmm.score`` evaluated on the input batch, so
    the wrapped object only needs to expose a ``score(ndarray) -> ndarray``
    method.
    """
    def __init__(self, gmm):
        # Object providing the reference score via its `score` method.
        self.gmm = gmm

    def __call__(self, model, X):
        """Evaluate the unreduced MSE between ``model(X, 0)`` and the reference score.

        Args:
            model: Callable invoked as ``model(X, t)``; ``t`` is a vector of
                zeros with one entry per row of ``X``.
            X: Tensor batch of points; it is detached and copied to the CPU
                before being handed to ``gmm.score``.

        Returns:
            Tensor of element-wise squared errors (``reduction='none'``), with
            the same shape as the model output.
        """
        # The reference score is computed in NumPy, so move the batch to host memory.
        X_np = X.detach().cpu().numpy()
        true_scores = self.gmm.score(X_np)
        true_scores_torch = torch.tensor(true_scores, dtype=torch.float32, device=X.device)

        # Dummy time input (not used for ground truth score)
        t_dummy = torch.zeros(X.shape[0], device=X.device)
        pred_scores = model(X, t_dummy)

        # Use `nn.functional` (already in scope via `import torch.nn as nn`):
        # the previous `F.mse_loss` raised NameError because `F` was never imported.
        loss = nn.functional.mse_loss(pred_scores, true_scores_torch, reduction='none')
        return loss

# Backbone and loss for the direct (analytical-target) score regression.
gt_model = UNetBlockStyleMLP_backbone(
    ndim=2,
    nlayers=4,
    nhidden=64,
    time_embed_dim=32,
)
gt_loss_fn = GroundTruthScoreLoss(gmm)

print("Training Ground Truth Score Model...")
# Optimisation settings forwarded verbatim to the project training helper.
gt_train_kwargs = dict(lr=0.001, nepochs=1000, batch_size=512, device=device)
gt_model_trained, gt_loss_traj = train_score_model_custom_loss(
    X_train_torch, gt_model, gt_loss_fn, **gt_train_kwargs
)

# Report the last entry of the returned loss trajectory.
gt_final_loss = gt_loss_traj[-1]
print(f"Ground truth model final loss: {gt_final_loss:.6f}")

## Method 2: Denoising Score Matching (EDM)

Train using the standard EDM denoising objective.

In [None]:
# Same backbone architecture as the ground-truth model, wrapped in the EDM
# preconditioner and trained with the EDM loss.
# NOTE(review): sigma_data=0.5 is passed to both the wrapper and the loss;
# whether it should instead match the spread of this GMM is worth confirming.
dsm_sigma_data = 0.5
dsm_model = UNetBlockStyleMLP_backbone(
    ndim=2,
    nlayers=4,
    nhidden=64,
    time_embed_dim=32,
)
dsm_model_precd = EDMPrecondWrapper(dsm_model, sigma_data=dsm_sigma_data)
edm_loss_fn = EDMLoss(P_mean=-1.2, P_std=1.2, sigma_data=dsm_sigma_data)

print("Training Denoising Score Matching Model...")
# Optimisation settings forwarded verbatim to the project training helper.
dsm_train_kwargs = dict(lr=0.001, nepochs=1000, batch_size=512, device=device)
dsm_model_trained, dsm_loss_traj = train_score_model_custom_loss(
    X_train_torch, dsm_model_precd, edm_loss_fn, **dsm_train_kwargs
)

# Report the last entry of the returned loss trajectory.
dsm_final_loss = dsm_loss_traj[-1]
print(f"Denoising model final loss: {dsm_final_loss:.6f}")

## Comparison: Loss Trajectories

In [None]:
# Number of leading trajectory entries hidden in the zoomed (right-hand) panel.
loss_plot_skip = 100

# Explicit axes interface: left panel shows the full trajectories, right panel
# the tail after the first `loss_plot_skip` entries.
loss_fig, (ax_full, ax_tail) = plt.subplots(1, 2, figsize=(10, 6))

ax_full.plot(gt_loss_traj, label='Ground Truth Score Loss', alpha=0.8)
ax_full.plot(dsm_loss_traj, label='Denoising Score Matching Loss', alpha=0.8)
ax_full.set_xlabel('Epoch')
ax_full.set_ylabel('Loss')
ax_full.set_title('Training Loss Comparison')
ax_full.legend()
ax_full.set_yscale('log')

# Plot the tail against its true indices. Plotting the bare slice restarted
# the x-axis at 0, so the "Epoch" axis was shifted by `loss_plot_skip`.
ax_tail.plot(np.arange(loss_plot_skip, len(gt_loss_traj)), gt_loss_traj[loss_plot_skip:],
             label='Ground Truth Score Loss', alpha=0.8)
ax_tail.plot(np.arange(loss_plot_skip, len(dsm_loss_traj)), dsm_loss_traj[loss_plot_skip:],
             label='Denoising Score Matching Loss', alpha=0.8)
ax_tail.set_xlabel('Epoch')
ax_tail.set_ylabel('Loss')
ax_tail.set_title(f'Training Loss (after epoch {loss_plot_skip})')
ax_tail.legend()
ax_tail.set_yscale('log')

plt.tight_layout()
plt.show()

## Comparison: Score Field Predictions

In [None]:
# Noise level at which the denoiser is queried; a small value stands in for σ→0.
sigma_small = 0.01

# Evaluate both trained models on the shared evaluation grid.
with torch.no_grad():
    grid_torch = torch.tensor(grid_points, dtype=torch.float32, device=device)
    n_grid = grid_torch.shape[0]
    t_dummy = torch.zeros(n_grid, device=device)

    # The ground-truth model outputs the score directly.
    gt_pred_scores = gt_model_trained(grid_torch, t_dummy).cpu().numpy()

    # The denoising model's output is treated as a denoised point and converted
    # to a score via score = -(x_noisy - x_clean) / σ².
    sigma_eval = torch.full((n_grid,), sigma_small, device=device)
    dsm_pred_clean = dsm_model_trained(grid_torch, sigma_eval).cpu().numpy()
    grid_np = grid_torch.cpu().numpy()
    dsm_pred_scores = -(grid_np - dsm_pred_clean) / (sigma_small**2)

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Top row: the three vector fields (reference, ground-truth model, denoising model).
quiver_panels = [
    (axes[0, 0], true_scores, 'True Score Field', {}),
    (axes[0, 1], gt_pred_scores, 'Ground Truth Model Predictions', {'color': 'red'}),
    (axes[0, 2], dsm_pred_scores, 'Denoising Model Predictions', {'color': 'green'}),
]
for ax, field, title, extra_kwargs in quiver_panels:
    ax.quiver(
        XX, YY,
        field[:, 0].reshape(XX.shape),
        field[:, 1].reshape(XX.shape),
        alpha=0.7,
        **extra_kwargs,
    )
    ax.set_title(title)
    ax.axis('equal')

# Per-grid-point Euclidean error of each model against the reference score.
gt_error = np.linalg.norm(gt_pred_scores - true_scores, axis=1)
dsm_error = np.linalg.norm(dsm_pred_scores - true_scores, axis=1)

# Bottom row, first two panels: spatial error maps with their own colorbars.
error_panels = [
    (axes[1, 0], gt_error, 'Ground Truth Model Error'),
    (axes[1, 1], dsm_error, 'Denoising Model Error'),
]
for ax, err, title in error_panels:
    err_map = ax.scatter(grid_points[:, 0], grid_points[:, 1], c=err, cmap='viridis', s=20)
    ax.set_title(title)
    plt.colorbar(err_map, ax=ax)
    ax.axis('equal')

# Bottom row, last panel: overlaid error histograms.
hist_ax = axes[1, 2]
hist_ax.hist(gt_error, alpha=0.7, label=f'GT Model (mean: {gt_error.mean():.3f})', bins=30)
hist_ax.hist(dsm_error, alpha=0.7, label=f'DSM Model (mean: {dsm_error.mean():.3f})', bins=30)
hist_ax.set_xlabel('Score Prediction Error')
hist_ax.set_ylabel('Count')
hist_ax.set_title('Error Distribution')
hist_ax.legend()

plt.tight_layout()
plt.show()

print(f"Ground Truth Model Mean Error: {gt_error.mean():.6f}")
print(f"Denoising Model Mean Error: {dsm_error.mean():.6f}")

## Sampling Quality Comparison

Compare the quality of samples generated by both approaches.

In [None]:
# Generate samples using both models
n_gen_samples = 2000

# For ground truth model, we need to implement a simple Langevin sampler
def langevin_sampler(score_model, n_samples, n_steps=1000, step_size=0.01,
                     init_noise_scale=2.0, ndim=2, sample_device=None):
    """Draw samples with unadjusted Langevin dynamics driven by ``score_model``.

    Each step applies ``x <- x + step_size * score(x) + sqrt(2 * step_size) * z``
    with ``z`` standard normal noise.

    Args:
        score_model: Callable invoked as ``score_model(x, t)``; ``t`` is a
            vector of zeros with one entry per sample.
        n_samples: Number of chains (rows of the returned array).
        n_steps: Number of Langevin updates.
        step_size: Step size of each update.
        init_noise_scale: Standard deviation of the Gaussian initialisation.
        ndim: Dimensionality of each sample (default 2, the previous
            hard-coded value).
        sample_device: Device on which to run; ``None`` uses the notebook-level
            ``device``.

    Returns:
        NumPy array of shape ``(n_samples, ndim)`` holding the final states.
    """
    if sample_device is None:
        sample_device = device

    # Loop-invariant noise amplitude, computed once instead of at every step.
    noise_scale = float(np.sqrt(2 * step_size))

    with torch.no_grad():
        # Initialize with noise
        x = torch.randn(n_samples, ndim, device=sample_device) * init_noise_scale
        t_dummy = torch.zeros(n_samples, device=sample_device)

        for _ in tqdm(range(n_steps), desc="Langevin sampling"):
            score = score_model(x, t_dummy)
            noise = torch.randn_like(x) * noise_scale
            x = x + step_size * score + noise

    return x.cpu().numpy()

# --- Draw samples from each trained model ------------------------------------
print("Sampling from Ground Truth Model...")
gt_samples = langevin_sampler(gt_model_trained, n_gen_samples, n_steps=500)

print("Sampling from Denoising Model...")
with torch.no_grad():
    # NOTE(review): the initial noise has unit variance; whether `edm_sampler`
    # rescales it by sigma_max internally is not visible here -- confirm.
    noise_init = torch.randn(n_gen_samples, 2, device=device)
    dsm_samples_torch = edm_sampler(
        dsm_model_trained,
        noise_init,
        num_steps=50,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
    )
    dsm_samples = dsm_samples_torch.cpu().numpy()

# --- Side-by-side scatter plots with shared axis limits -----------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

sample_panels = [
    (X_train, 'Original Training Data', {}),
    (gt_samples, 'Ground Truth Model Samples', {'color': 'red'}),
    (dsm_samples, 'Denoising Model Samples', {'color': 'green'}),
]
for ax, (points, title, extra_kwargs) in zip(axes, sample_panels):
    ax.scatter(points[:, 0], points[:, 1], alpha=0.6, s=10, **extra_kwargs)
    ax.set_title(title)
    ax.axis('equal')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-4, 5)

plt.tight_layout()
plt.show()

## Quantitative Comparison Metrics

In [None]:
from scipy.spatial.distance import cdist
from scipy.stats import wasserstein_distance

def compute_sample_metrics(true_samples, gen_samples):
    """Compare a generated sample set against a reference sample set.

    Args:
        true_samples: Array of shape ``(n_true, d)`` with reference points;
            columns 0 and 1 are used for the marginal Wasserstein distances.
        gen_samples: Array of shape ``(n_gen, d)`` with generated points.

    Returns:
        Dict with the following keys, in this order:
            ``mean_error``: Euclidean norm of the difference of sample means.
            ``cov_frobenius_error``: Frobenius norm of the difference of
                sample covariance matrices.
            ``wasserstein_x`` / ``wasserstein_y``: 1D Wasserstein distances of
                the first and second coordinate marginals.
            ``coverage_mean``: mean distance from each reference point to its
                nearest generated point.
            ``precision_mean``: mean distance from each generated point to its
                nearest reference point.
    """
    metrics = {}

    # 1. Mean and covariance comparison
    true_mean = np.mean(true_samples, axis=0)
    gen_mean = np.mean(gen_samples, axis=0)
    metrics['mean_error'] = np.linalg.norm(true_mean - gen_mean)

    true_cov = np.cov(true_samples.T)
    gen_cov = np.cov(gen_samples.T)
    metrics['cov_frobenius_error'] = np.linalg.norm(true_cov - gen_cov, 'fro')

    # 2. Wasserstein distances (1D marginals)
    metrics['wasserstein_x'] = wasserstein_distance(true_samples[:, 0], gen_samples[:, 0])
    metrics['wasserstein_y'] = wasserstein_distance(true_samples[:, 1], gen_samples[:, 1])

    # 3./4. Nearest-neighbour distances in both directions. The gen->true
    # distance matrix is the transpose of the true->gen one, so compute the
    # (n_true, n_gen) matrix once and reduce along each axis instead of
    # calling `cdist` twice.
    pairwise = cdist(true_samples, gen_samples)
    metrics['coverage_mean'] = np.mean(np.min(pairwise, axis=1))   # per reference point
    metrics['precision_mean'] = np.mean(np.min(pairwise, axis=0))  # per generated point

    return metrics

# Score both generated sample sets against the training data.
gt_metrics = compute_sample_metrics(X_train, gt_samples)
dsm_metrics = compute_sample_metrics(X_train, dsm_samples)

# Side-by-side table, one row per metric.
print("Sample Quality Comparison:")
print("=" * 50)
print(f"{'Metric':<20} {'GT Model':<15} {'DSM Model':<15}")
print("-" * 50)
for key, gt_value in gt_metrics.items():
    print(f"{key:<20} {gt_value:<15.6f} {dsm_metrics[key]:<15.6f}")

# Summary: for both nearest-neighbour metrics a smaller value wins; ties go to
# the denoising model, as the comparison is a strict "less than".
print("\nSummary:")
print("=" * 30)
gt_wins_precision = gt_metrics['precision_mean'] < dsm_metrics['precision_mean']
print(
    "✓ Ground Truth model has better precision (samples closer to true data)"
    if gt_wins_precision
    else "✓ Denoising model has better precision (samples closer to true data)"
)

gt_wins_coverage = gt_metrics['coverage_mean'] < dsm_metrics['coverage_mean']
print(
    "✓ Ground Truth model has better coverage (true data better covered)"
    if gt_wins_coverage
    else "✓ Denoising model has better coverage (true data better covered)"
)

## Key Findings and Conclusions

**Ground Truth Score Learning:**
- Directly learns the score function ∇ log p(x)
- Requires analytical score (only available for simple distributions)
- Can achieve very low prediction error on the score field
- Sampling requires Langevin dynamics or other MCMC methods

**Denoising Score Matching (EDM):**
- Learns score implicitly through denoising at multiple noise levels
- Works for any distribution (no analytical score needed)
- Provides efficient sampling via deterministic ODE solvers or stochastic SDE solvers
- More robust and scalable to high-dimensional data

**Comparison:**
- Ground truth learning can be more accurate when analytical scores are available
- Denoising score matching is more practical and generalizable
- Both approaches can generate high-quality samples when properly trained
- The choice depends on whether analytical scores are available and computational constraints