# NanoMAD ML - Test Suite

This notebook runs quick smoke tests on the Phase_6.0 components, using small inputs, to verify that each one imports and behaves as expected.

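**What it tests:**
- Module imports
- core_shell.py: Particle creation (including ellipse shapes), strain fields, diffraction computation, ground truth labels, boundary validation
- mad_model.py: CNN architecture (forward pass with dummy data) - *requires PyTorch*
- mad_loss.py: Loss function, F_N reconstruction - *requires PyTorch*
- augmentation.py (formerly data_augmentation.py): Augmentation transforms
- visualization.py: Plotting functions
- ScatteringFactors: f'/f'' data loading
- MAD equation consistency between ground-truth labels and simulated intensities
- generate_data.py (formerly generate_training_data.py): Single particle generation
- train.py: MADDataset loading and batching - *requires PyTorch*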

**What it does NOT do:**
- Train a full model (too slow)
- Generate large datasets (too slow)
- Run 3D inference on large volumes

---

In [1]:
import numpy as np
import sys
from pathlib import Path
import traceback

# Put the project root (the parent of the notebook's working directory) first
# on sys.path so that `src.*` and `data.*` can be imported by the tests below.
sys.path.insert(0, str(Path('.').resolve().parent))

# Probe for PyTorch once; tests registered with requires_torch=True are
# recorded as SKIPPED (rather than FAILED) when it is missing.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available - PyTorch-dependent tests will be skipped")
else:
    TORCH_AVAILABLE = True
    print(f"PyTorch {torch.__version__} available (CUDA: {torch.cuda.is_available()})")

# One (name, status, detail) tuple is appended here per run_test() call.
test_results = []


def run_test(name, func, requires_torch=False):
    """Execute one test callable and record its outcome in ``test_results``.

    Args:
        name: Human-readable test name used in the banner and the record.
        func: Zero-argument callable; raising any ``Exception`` marks a failure.
        requires_torch: If True and PyTorch is unavailable, the test is not
            run and is recorded as SKIPPED.

    The recorded tuple is ``(name, 'PASSED', None)``, ``(name, 'FAILED',
    str(exception))`` or ``(name, 'SKIPPED', 'Requires PyTorch')``. Failures
    are printed with a traceback instead of propagating, so later cells keep
    running.
    """
    banner = '=' * 70
    for line in (f"\n{banner}", f"TEST: {name}", banner):
        print(line)

    if requires_torch and not TORCH_AVAILABLE:
        reason = 'Requires PyTorch'
        print(f"  SKIPPED: {reason}")
        test_results.append((name, 'SKIPPED', reason))
        return

    try:
        func()
    except Exception as exc:
        print(f"\n  FAILED: {name}")
        print(f"  Error: {exc}")
        traceback.print_exc()
        test_results.append((name, 'FAILED', str(exc)))
    else:
        print(f"\n  PASSED: {name}")
        test_results.append((name, 'PASSED', None))

print("\nTest framework ready.")

PyTorch not available - PyTorch-dependent tests will be skipped

Test framework ready.


---
## Test 1: Core Module Imports

In [2]:
def test_core_imports():
    """Test that all core (non-PyTorch) modules can be imported.

    Imports the public names used elsewhere in this notebook from
    ``src.core_shell``, ``src.visualization`` and ``src.augmentation``, then
    loads ``d_fthomson_IT92`` from ``data.thomson_factors`` and asserts it has
    entries for 'Ni' and 'Fe'. Any ImportError/AssertionError propagates to
    ``run_test``, which records the test as FAILED.
    """
    
    # core_shell.py: particle geometry, strain, diffraction, labels, scattering factors
    from src.core_shell import (
        create_particle_with_shape,
        apply_displacement_to_particle,
        create_layered_displacement_field,
        create_random_strain_field,
        compute_diffraction_oversampled_cropped,
        compute_ground_truth_labels,
        extract_label_patches,
        ScatteringFactors,
        get_total_density,
        compute_f0_thomson,
        SPECIES_NI,
        SPECIES_FE,
    )
    print("  core_shell.py imports OK")
    
    # visualization.py: plotting helpers (only imported here, not called)
    from src.visualization import (
        plot_particle,
        plot_diffraction,
        plot_diffraction_multi_energy,
        plot_reconstruction,
        plot_block_analysis,
    )
    print("  visualization.py imports OK")
    
    # augmentation.py (renamed from data_augmentation.py)
    from src.augmentation import (
        augment_sample,
        random_rot90,
        random_flip,
        random_intensity_scale,
        add_poisson_noise,
    )
    print("  augmentation.py imports OK")
    
    # thomson_factors.py (renamed from d_fthomson_IT92.py); the table is
    # indexed by element symbol, and both species of the particle must exist.
    from data.thomson_factors import d_fthomson_IT92
    assert 'Ni' in d_fthomson_IT92, "Missing Ni in Thomson table"
    assert 'Fe' in d_fthomson_IT92, "Missing Fe in Thomson table"
    print("  thomson_factors.py data loaded OK")
    
run_test("Core Module Imports", test_core_imports)


TEST: Core Module Imports
  core_shell.py imports OK
  visualization.py imports OK
  augmentation.py imports OK
  thomson_factors.py data loaded OK

  PASSED: Core Module Imports


---
## Test 2: PyTorch Module Imports

In [3]:
def test_torch_imports():
    """Test that PyTorch-dependent modules can be imported.

    Prints the PyTorch version and CUDA availability, then imports the model
    (``src.mad_model``), loss (``src.mad_loss``) and dataset (``src.train``)
    entry points. Registered with ``requires_torch=True``, so ``run_test``
    skips it when PyTorch is not installed.
    """
    
    import torch
    print(f"  PyTorch {torch.__version__}")
    print(f"  CUDA available: {torch.cuda.is_available()}")
    
    # mad_model.py: CNN architecture and parameter counter
    from src.mad_model import MADNet, count_parameters
    print("  mad_model.py imports OK")
    
    # mad_loss.py: physics-informed loss and F_N reconstruction helper
    from src.mad_loss import MADPhysicsLoss, compute_F_N
    print("  mad_loss.py imports OK")
    
    # train.py (renamed from train_cnn.py)
    from src.train import MADDataset
    print("  train.py imports OK")
    
run_test("PyTorch Module Imports", test_torch_imports, requires_torch=True)


TEST: PyTorch Module Imports
  SKIPPED: Requires PyTorch


---
## Test 3: Scattering Factors Loading

In [4]:
def test_scattering_factors():
    """Test scattering factor data loading and sanity-check values at the Ni K-edge.

    Loads tabulated f'/f'' for Ni and Fe from ``../data`` and checks, at
    E = 8333 eV (the Ni K-edge):

    * Ni: f' is negative (and greater than -10) and f'' is positive.
    * Fe: f'' is positive and |f'| is smaller than Ni's, i.e. Ni is the
      anomalous scatterer at this energy, which is the element-selective
      contrast the MAD analysis relies on.
    * Thomson factors at Q=0 are within 1 electron of the atomic numbers
      (Ni = 28, Fe = 26).
    """
    from src.core_shell import ScatteringFactors, compute_f0_thomson

    # Load scattering factors
    sf = ScatteringFactors(data_dir='../data')
    print("  ScatteringFactors loaded OK")

    # Test f'/f'' at Ni K-edge
    E = 8333  # Ni K-edge
    fp_Ni = sf.get_f_prime('Ni', E)
    fpp_Ni = sf.get_f_double_prime('Ni', E)
    print(f"  Ni at {E} eV: f'={fp_Ni:.3f}, f''={fpp_Ni:.3f}")

    fp_Fe = sf.get_f_prime('Fe', E)
    fpp_Fe = sf.get_f_double_prime('Fe', E)
    print(f"  Fe at {E} eV: f'={fp_Fe:.3f}, f''={fpp_Fe:.3f}")

    # Thomson scattering at Q=0 equals the electron count (atomic number).
    # Signature: compute_f0_thomson(element, q_magnitude)
    f0_Ni = compute_f0_thomson('Ni', 0.0)
    f0_Fe = compute_f0_thomson('Fe', 0.0)
    print(f"  f0(Q=0): Ni={f0_Ni:.1f}, Fe={f0_Fe:.1f}")

    # Verify values are reasonable
    assert -10 < fp_Ni < 0, f"Ni f' should be negative near edge, got {fp_Ni}"
    assert fpp_Ni > 0, f"Ni f'' should be positive, got {fpp_Ni}"
    # Fe values were previously printed but never validated.
    assert fpp_Fe > 0, f"Fe f'' should be positive, got {fpp_Fe}"
    # At the Ni edge the Ni f' dip must dominate, otherwise there is no
    # Ni-selective anomalous signal to learn from.
    assert abs(fp_Fe) < abs(fp_Ni), (
        f"Expected |f'_Fe| < |f'_Ni| at the Ni edge, got Fe={fp_Fe}, Ni={fp_Ni}"
    )
    assert abs(f0_Ni - 28) < 1, f"Ni f0 should be ~28, got {f0_Ni}"
    assert abs(f0_Fe - 26) < 1, f"Fe f0 should be ~26, got {f0_Fe}"
    print("  All scattering factor values reasonable")

run_test("Scattering Factors", test_scattering_factors)


TEST: Scattering Factors
Loading scattering factors for Ni from ../data/Nickel.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
Loading scattering factors for Fe from ../data/Iron.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
  ScatteringFactors loaded OK
  Ni at 8333 eV: f'=-7.941, f''=2.521
  Fe at 8333 eV: f'=-0.831, f''=3.025
  f0(Q=0): Ni=28.0, Fe=26.0
  All scattering factor values reasonable

  PASSED: Scattering Factors


---
## Test 4: Particle Creation (All Types)

In [5]:
def test_particle_creation():
    """Build one core-shell particle per shape type and validate it.

    For 'circle', 'hexagon', 'polygon' and 'polygon_centrosymmetric' the
    returned particle must be a (2, GRID_SIZE, GRID_SIZE) float64/complex128
    array with non-zero Ni and Fe content, and the info dict must expose
    'outer_mask', 'core_mask' and 'shell_mask'.
    """
    from src.core_shell import create_particle_with_shape, get_total_density, SPECIES_NI, SPECIES_FE

    GRID_SIZE = 64  # kept small so the loop is fast

    # Extra shape_params each shape type needs; a plain circle needs none.
    params_for_shape = {
        'circle': {},
        'hexagon': {'anisotropy': 1.1},
        'polygon': {'n_vertices': 6},
        'polygon_centrosymmetric': {'n_vertices': 6},
    }

    for shape_type, params in params_for_shape.items():
        particle, info = create_particle_with_shape(
            grid_size=GRID_SIZE, shape_type=shape_type, outer_radius=15,
            core_fraction=0.5, pixel_size=5.0, shape_params=params, verbose=False,
        )

        # Layout: species axis first, then the square real-space grid.
        assert particle.shape == (2, GRID_SIZE, GRID_SIZE), f"Wrong shape: {particle.shape}"
        # Real before a displacement is applied, complex afterwards.
        assert particle.dtype in (np.float64, np.complex128), f"Wrong dtype: {particle.dtype}"

        ni_total, fe_total = (np.abs(particle[species]).sum() for species in (SPECIES_NI, SPECIES_FE))
        assert ni_total > 0, "No Ni content"
        assert fe_total > 0, "No Fe content"

        for mask_key in ('outer_mask', 'core_mask', 'shell_mask'):
            assert mask_key in info

        print(f"  {shape_type}: shape={particle.shape}, dtype={particle.dtype}, Ni={ni_total:.0f}, Fe={fe_total:.0f}")

run_test("Particle Creation (All Types)", test_particle_creation)


TEST: Particle Creation (All Types)
  circle: shape=(2, 64, 64), dtype=float64, Ni=665, Fe=44
  hexagon: shape=(2, 64, 64), dtype=float64, Ni=143, Fe=9
  polygon: shape=(2, 64, 64), dtype=float64, Ni=410, Fe=30
  polygon_centrosymmetric: shape=(2, 64, 64), dtype=float64, Ni=198, Fe=30

  PASSED: Particle Creation (All Types)


---
## Test 5: Strain Field Generation

In [6]:
def test_strain_fields():
    """Exercise both displacement generators and apply their sum to a particle.

    The layered (interface + surface) field and the correlated random field
    must each come back on the (GRID_SIZE, GRID_SIZE) grid, and applying the
    combined displacement must return a complex128 array with the
    particle's shape.
    """
    from src.core_shell import (
        create_particle_with_shape,
        create_layered_displacement_field,
        create_random_strain_field,
        apply_displacement_to_particle,
    )

    GRID_SIZE, PIXEL_SIZE, Q_BRAGG = 64, 5.0, 3.09
    grid_shape = (GRID_SIZE, GRID_SIZE)

    def value_range(arr):
        """Format min/max of ``arr`` as '[lo, hi]' with two decimals."""
        return f"[{arr.min():.2f}, {arr.max():.2f}]"

    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE, shape_type='circle', outer_radius=15,
        core_fraction=0.5, pixel_size=PIXEL_SIZE, verbose=False,
    )
    print("  Particle created")

    # Deterministic displacement localised at the core/shell interface and surface.
    displacement_analytic, _disp_info = create_layered_displacement_field(
        core_mask=info['core_mask'], outer_mask=info['outer_mask'],
        pixel_size=PIXEL_SIZE, interface_amplitude=1.0, surface_amplitude=0.5,
        verbose=False,
    )
    assert displacement_analytic.shape == grid_shape
    print(f"  Analytic displacement: range {value_range(displacement_analytic)} A")

    # Spatially correlated random displacement over the whole grid.
    displacement_random = create_random_strain_field(
        grid_size=GRID_SIZE, pixel_size=PIXEL_SIZE,
        displacement_amplitude=0.2, correlation_length=0.1, verbose=False,
    )
    assert displacement_random.shape == grid_shape
    print(f"  Random strain: range {value_range(displacement_random)} A")

    particle_strained = apply_displacement_to_particle(
        particle=particle,
        displacement=displacement_analytic + displacement_random,
        q_bragg_magnitude=Q_BRAGG,
    )
    assert particle_strained.shape == particle.shape
    assert particle_strained.dtype == np.complex128
    print("  Displacement applied to particle")

run_test("Strain Field Generation", test_strain_fields)


TEST: Strain Field Generation
  Particle created
  Analytic displacement: range [0.00, 0.62] A
  Random strain: range [-0.19, 0.20] A
  Displacement applied to particle

  PASSED: Strain Field Generation


---
## Test 6: Diffraction Computation

In [7]:
def test_diffraction():
    """Test multi-energy diffraction computation.

    Simulates a strained core-shell particle, computes complex diffraction
    amplitudes at three energies around the Ni K-edge, and checks that:

    * exactly one pattern is returned per requested energy, keyed by energy
      (later tests rely on ``diffractions[E]``);
    * each pattern is a complex (OUTPUT_SIZE, OUTPUT_SIZE) array with
      non-zero intensity;
    * the on-edge intensity differs from the below-edge one, i.e. the
      energy-dependent (anomalous) contrast that MAD relies on is present.
    """
    from src.core_shell import (
        create_particle_with_shape,
        apply_displacement_to_particle,
        create_layered_displacement_field,
        compute_diffraction_oversampled_cropped,
        ScatteringFactors,
    )

    GRID_SIZE = 64
    OUTPUT_SIZE = 32
    PIXEL_SIZE = 5.0
    Q_BRAGG = 3.09
    ENERGIES = [8313, 8333, 8348]  # below / at (Ni K-edge) / above; 3 energies for speed

    sf = ScatteringFactors(data_dir='../data')

    # Create strained particle
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE, shape_type='circle',
        outer_radius=15, core_fraction=0.5,
        pixel_size=PIXEL_SIZE, verbose=False
    )
    displacement, _ = create_layered_displacement_field(
        core_mask=info['core_mask'], outer_mask=info['outer_mask'],
        pixel_size=PIXEL_SIZE, verbose=False
    )
    particle_strained = apply_displacement_to_particle(
        particle, displacement, Q_BRAGG
    )
    print("  Strained particle created")

    # Compute diffraction
    diffractions = compute_diffraction_oversampled_cropped(
        particle=particle_strained,
        energies=ENERGIES,
        pixel_size=PIXEL_SIZE,
        scattering_factors=sf,
        output_size=OUTPUT_SIZE,
        verbose=False
    )

    assert len(diffractions) == len(ENERGIES), f"Wrong number of patterns: {len(diffractions)}"
    # A matching count is not enough: downstream code indexes patterns by energy.
    assert set(diffractions) == set(ENERGIES), (
        f"Pattern keys {sorted(diffractions)} do not match energies {ENERGIES}"
    )

    intensities = {}
    for E, D in diffractions.items():
        assert D.shape == (OUTPUT_SIZE, OUTPUT_SIZE), f"Wrong shape at E={E}: {D.shape}"
        assert np.iscomplexobj(D), f"Diffraction should be complex at E={E}"
        intensity = np.abs(D)**2
        assert intensity.max() > 0, f"Zero intensity at E={E}"
        intensities[E] = intensity
        print(f"  E={E} eV: shape={D.shape}, I_max={intensity.max():.2e}")

    # Without an energy dependence there is nothing for MAD to exploit.
    E_below, E_edge = ENERGIES[0], ENERGIES[1]
    assert not np.allclose(intensities[E_below], intensities[E_edge]), (
        f"No anomalous contrast: intensities at E={E_below} and E={E_edge} eV are identical"
    )
    print(f"  Anomalous contrast present between E={E_below} and E={E_edge} eV")

run_test("Diffraction Computation", test_diffraction)


TEST: Diffraction Computation
Loading scattering factors for Ni from ../data/Nickel.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
Loading scattering factors for Fe from ../data/Iron.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
  Strained particle created
  E=8313 eV: shape=(32, 32), I_max=2.31e+08
  E=8333 eV: shape=(32, 32), I_max=1.92e+08
  E=8348 eV: shape=(32, 32), I_max=2.31e+08

  PASSED: Diffraction Computation


---
## Test 7: Ground Truth Labels

In [8]:
def test_ground_truth():
    """Test ground truth label computation and patch extraction.

    Checks that every expected label map is present, has shape
    (OUTPUT_SIZE, OUTPUT_SIZE) and is finite. It also checks that
    sin^2 + cos^2 = 1 for the phase-difference labels, and that F_N_mag
    matches the law-of-cosines reconstruction from F_T, F_A and delta_phi.
    Finally, ``extract_label_patches`` must return a
    (n, n, PATCH_SIZE, PATCH_SIZE, 4) array with n = OUTPUT_SIZE // PATCH_SIZE.
    """
    from src.core_shell import (
        create_particle_with_shape,
        apply_displacement_to_particle,
        create_layered_displacement_field,
        compute_ground_truth_labels,
        extract_label_patches,
    )

    GRID_SIZE = 64
    OUTPUT_SIZE = 32
    PATCH_SIZE = 16
    PIXEL_SIZE = 5.0
    Q_BRAGG = 3.09

    # Create particle
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE, shape_type='circle',
        outer_radius=15, core_fraction=0.5,
        pixel_size=PIXEL_SIZE, verbose=False
    )
    displacement, _ = create_layered_displacement_field(
        core_mask=info['core_mask'], outer_mask=info['outer_mask'],
        pixel_size=PIXEL_SIZE, verbose=False
    )
    particle_strained = apply_displacement_to_particle(
        particle, displacement, Q_BRAGG
    )

    # Compute ground truth
    labels = compute_ground_truth_labels(
        particle=particle_strained,
        pixel_size=PIXEL_SIZE,
        output_size=OUTPUT_SIZE,
        verbose=False
    )

    # Check all expected keys; NaN/inf would otherwise only surface later as a
    # confusing allclose failure.
    expected_keys = ['F_T_mag', 'F_A_mag', 'F_N_mag', 'delta_phi', 'sin_delta_phi', 'cos_delta_phi']
    for key in expected_keys:
        assert key in labels, f"Missing key: {key}"
        assert labels[key].shape == (OUTPUT_SIZE, OUTPUT_SIZE), f"Wrong shape for {key}"
        assert np.isfinite(labels[key]).all(), f"Non-finite values in {key}"
    print("  All expected label keys present")
    print(f"  Full keys: {list(labels.keys())}")

    # Verify sin^2+cos^2 = 1
    sin2_cos2 = labels['sin_delta_phi']**2 + labels['cos_delta_phi']**2
    assert np.allclose(sin2_cos2, 1.0, atol=1e-6), "sin^2+cos^2 != 1"
    print("  sin^2(dphi) + cos^2(dphi) = 1 verified")

    # Verify F_N derivation (law of cosines; clamp tiny negatives from round-off)
    F_T = labels['F_T_mag']
    F_A = labels['F_A_mag']
    F_N = labels['F_N_mag']
    delta_phi = labels['delta_phi']

    F_N_check = np.sqrt(np.maximum(F_T**2 + F_A**2 - 2*F_T*F_A*np.cos(delta_phi), 0))
    assert np.allclose(F_N, F_N_check, rtol=1e-4), "F_N derivation mismatch"
    print("  F_N = sqrt(F_T^2 + F_A^2 - 2*F_T*F_A*cos(dphi)) verified")

    # Extract patches
    patches = extract_label_patches(
        labels=labels,
        patch_size=PATCH_SIZE,
        label_keys=['F_T_mag', 'F_A_mag', 'sin_delta_phi', 'cos_delta_phi']
    )
    n_patches_per_dim = OUTPUT_SIZE // PATCH_SIZE
    n_patches = n_patches_per_dim ** 2
    expected_shape = (n_patches_per_dim, n_patches_per_dim, PATCH_SIZE, PATCH_SIZE, 4)
    assert patches.shape == expected_shape, f"Patches shape {patches.shape} != expected {expected_shape}"
    print(f"  Extracted patches shape: {patches.shape} ({n_patches} patches)")

run_test("Ground Truth Labels", test_ground_truth)


TEST: Ground Truth Labels
  All expected label keys present
  Full keys: ['F_T_mag', 'F_A_mag', 'delta_phi', 'cos_delta_phi', 'sin_delta_phi', 'F_N_mag', 'F_N_mag_derived', 'f0_Ni', 'f0_Fe']
  sin^2(dphi) + cos^2(dphi) = 1 verified
  F_N = sqrt(F_T^2 + F_A^2 - 2*F_T*F_A*cos(dphi)) verified
  Extracted patches shape: (2, 2, 16, 16, 4)

  PASSED: Ground Truth Labels


---
## Test 8: CNN Architecture

In [9]:
def test_cnn_architecture():
    """Test MADNet construction and a forward pass on random dummy data.

    Checks that the model has parameters, that the output has shape
    (batch, 4, patch, patch), and that it is finite. It also checks the
    activation ranges: channels 0-1 (magnitudes) >= 0 and channels 2-3
    (sin/cos) within [-1, 1]. Inputs are drawn from a local, fixed-seed
    generator so any failure is reproducible and the global torch RNG is
    left untouched.
    """
    import torch
    from src.mad_model import MADNet, count_parameters

    # Create model
    model = MADNet(
        in_channels=8,
        out_channels=4,
        base_filters=32
    )
    model.eval()

    n_params = count_parameters(model)
    assert n_params > 0, f"count_parameters returned {n_params}"
    print(f"  MADNet created: {n_params:,} parameters")

    # Test forward pass
    batch_size = 2
    patch_size = 16
    n_energies = 8

    # Inputs (seeded local generator -> reproducible dummy data)
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(batch_size, n_energies, patch_size, patch_size, generator=gen)
    f_prime = torch.randn(batch_size, n_energies, generator=gen)
    # f'' is physically non-negative; feed realistic values (as the loss test does)
    f_double_prime = torch.randn(batch_size, n_energies, generator=gen).abs()

    with torch.no_grad():
        y = model(x, f_prime, f_double_prime)

    expected_shape = (batch_size, 4, patch_size, patch_size)
    assert y.shape == expected_shape, f"Wrong output shape: {y.shape} vs {expected_shape}"
    assert torch.isfinite(y).all(), "Model output contains NaN/inf"
    print(f"  Forward pass: input {x.shape} -> output {y.shape}")

    # Check output ranges
    # Channels 0,1 (magnitudes) should be >= 0 after softplus
    # Channels 2,3 (sin/cos) should be in [-1, 1] after tanh
    assert (y[:, :2] >= 0).all(), "Magnitudes should be non-negative"
    assert (y[:, 2:].abs() <= 1).all(), "sin/cos should be in [-1, 1]"
    print("  Output ranges correct (magnitudes >= 0, sin/cos in [-1, 1])")

run_test("CNN Architecture", test_cnn_architecture, requires_torch=True)


TEST: CNN Architecture
  SKIPPED: Requires PyTorch


---
## Test 9: Loss Function

In [10]:
def test_loss_function():
    """Test the physics-informed loss and the F_N reconstruction helper.

    The loss on random prediction/target pairs must be a finite, positive
    scalar. ``compute_F_N`` must return a finite, non-negative map with the
    same shape as its F_T input. Finiteness is checked before the sign
    checks: a NaN fails ``> 0`` / ``>= 0`` too and would otherwise be
    misreported as a sign error.
    """
    import torch
    from src.mad_loss import MADPhysicsLoss, compute_F_N

    # Create loss
    loss_fn = MADPhysicsLoss(
        mag_weight=1.0,
        phase_weight=1.0,
        physics_weight=0.1,
        normalize_weight=0.01
    )
    print("  MADPhysicsLoss created")

    # Create dummy predictions and targets from a seeded local generator
    batch_size = 4
    patch_size = 16
    gen = torch.Generator().manual_seed(0)

    pred = torch.rand(batch_size, 4, patch_size, patch_size, generator=gen)
    pred[:, :2] = pred[:, :2] * 5  # Magnitudes in [0, 5]
    pred[:, 2:] = pred[:, 2:] * 2 - 1  # sin/cos in [-1, 1]

    target = torch.rand(batch_size, 4, patch_size, patch_size, generator=gen)
    target[:, :2] = target[:, :2] * 5
    target[:, 2:] = target[:, 2:] * 2 - 1

    f_prime = torch.randn(batch_size, 8, generator=gen)
    f_double_prime = torch.abs(torch.randn(batch_size, 8, generator=gen))

    # Compute loss
    loss = loss_fn(pred, target, f_prime, f_double_prime)

    assert loss.ndim == 0, f"Loss should be scalar, got shape {loss.shape}"
    assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}"
    assert loss.item() > 0, "Loss should be positive"
    print(f"  Loss computation: {loss.item():.4f}")

    # Test F_N computation
    F_T = torch.rand(batch_size, patch_size, patch_size, generator=gen) * 100
    F_A = torch.rand(batch_size, patch_size, patch_size, generator=gen) * 50
    delta_phi = torch.rand(batch_size, patch_size, patch_size, generator=gen) * 2 * np.pi - np.pi

    F_N = compute_F_N(F_T, F_A, delta_phi)

    assert F_N.shape == F_T.shape, f"F_N shape mismatch: {F_N.shape}"
    assert torch.isfinite(F_N).all(), "F_N contains NaN/inf"
    assert (F_N >= 0).all(), "F_N should be non-negative"
    print(f"  compute_F_N: range [{F_N.min():.2f}, {F_N.max():.2f}]")

run_test("Loss Function", test_loss_function, requires_torch=True)


TEST: Loss Function
  SKIPPED: Requires PyTorch


---
## Test 10: Data Augmentation

In [11]:
def test_augmentation():
    """Test physics-preserving augmentation transforms.

    Every transform must leave both the input stack X (patch, patch,
    n_energies) and the label stack Y (patch, patch, 4) with their original
    shapes. ``add_poisson_noise`` must not produce negative intensities.
    The dummy sample comes from a seeded local RNG; the transforms
    themselves still use whatever randomness they draw internally.
    """
    from src.augmentation import (
        augment_sample,
        random_rot90,
        random_flip,
        random_intensity_scale,
        add_poisson_noise,
    )

    # Create dummy sample (square patches, so rot90 cannot change the shape)
    patch_size = 16
    n_energies = 8
    rng = np.random.default_rng(0)

    X = rng.random((patch_size, patch_size, n_energies)).astype(np.float32) * 1000
    Y = rng.random((patch_size, patch_size, 4)).astype(np.float32)
    Y[..., :2] *= 10  # Magnitudes
    Y[..., 2:] = Y[..., 2:] * 2 - 1  # sin/cos

    # Test individual augmentations (labels must stay aligned with inputs)
    X_rot, Y_rot = random_rot90(X.copy(), Y.copy())
    assert X_rot.shape == X.shape, "rot90 changed X shape"
    assert Y_rot.shape == Y.shape, "rot90 changed Y shape"
    print("  random_rot90 OK")

    X_flip, Y_flip = random_flip(X.copy(), Y.copy())
    assert X_flip.shape == X.shape, "flip changed X shape"
    assert Y_flip.shape == Y.shape, "flip changed Y shape"
    print("  random_flip OK")

    X_scaled, Y_scaled = random_intensity_scale(X.copy(), Y.copy())
    assert X_scaled.shape == X.shape, "intensity_scale changed X shape"
    assert Y_scaled.shape == Y.shape, "intensity_scale changed Y shape"
    print("  random_intensity_scale OK")

    # add_poisson_noise signature: (intensity, noise_level=None, noise_range=(0.01, 0.1))
    X_noisy = add_poisson_noise(X.copy(), noise_level=0.05)
    assert X_noisy.shape == X.shape, "poisson_noise changed X shape"
    assert (X_noisy >= 0).all(), "Poisson noise made negative values"
    print("  add_poisson_noise OK")

    # Test full augmentation pipeline
    X_aug, Y_aug = augment_sample(X.copy(), Y.copy())
    assert X_aug.shape == X.shape, "augment_sample changed X shape"
    assert Y_aug.shape == Y.shape, "augment_sample changed Y shape"
    print("  augment_sample pipeline OK")

run_test("Data Augmentation", test_augmentation)


TEST: Data Augmentation
  random_rot90 OK
  random_flip OK
  random_intensity_scale OK
  add_poisson_noise OK
  augment_sample pipeline OK

  PASSED: Data Augmentation


---
## Test 11: MAD Equation Consistency

In [12]:
def test_mad_equation():
    """Check simulated intensities against the MAD equation.

    For each energy E the predicted intensity is

        I = |F_T|^2 + (f'^2 + f''^2) |F_A|^2 / f0^2
            + 2 |F_T| |F_A| / f0 * [f' cos(dphi) + f'' sin(dphi)]

    with Ni f'(E), f''(E) and f0 = f0_Ni(Q=0). The prediction is compared to
    |diffraction|^2 via Pearson correlation, which must exceed 0.5. Exact
    equality is not expected because f0 is taken at Q=0 rather than
    Q-dependent.
    """
    from src.core_shell import (
        create_particle_with_shape,
        apply_displacement_to_particle,
        create_layered_displacement_field,
        compute_diffraction_oversampled_cropped,
        compute_ground_truth_labels,
        ScatteringFactors,
        compute_f0_thomson,
    )

    GRID_SIZE = 64
    OUTPUT_SIZE = 32
    PIXEL_SIZE = 5.0
    Q_BRAGG = 3.09
    ENERGIES = [8313, 8333, 8348]

    sf = ScatteringFactors()  # data directory is auto-detected

    # Strained core-shell particle shared by the simulation and the labels
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE, shape_type='circle', outer_radius=15,
        core_fraction=0.5, pixel_size=PIXEL_SIZE, verbose=False,
    )
    displacement, _ = create_layered_displacement_field(
        core_mask=info['core_mask'], outer_mask=info['outer_mask'],
        pixel_size=PIXEL_SIZE, verbose=False,
    )
    strained = apply_displacement_to_particle(particle, displacement, Q_BRAGG)

    diffractions = compute_diffraction_oversampled_cropped(
        particle=strained, energies=ENERGIES, pixel_size=PIXEL_SIZE,
        scattering_factors=sf, output_size=OUTPUT_SIZE, verbose=False,
    )
    labels = compute_ground_truth_labels(
        particle=strained, pixel_size=PIXEL_SIZE,
        output_size=OUTPUT_SIZE, verbose=False,
    )

    F_T, F_A, delta_phi = (labels[key] for key in ('F_T_mag', 'F_A_mag', 'delta_phi'))
    # Energy-independent pieces of the MAD equation, evaluated once.
    cos_dphi = np.cos(delta_phi)
    sin_dphi = np.sin(delta_phi)

    print("  Testing MAD equation at each energy...")
    # Signature: compute_f0_thomson(element, q_magnitude); Q=0 is an approximation
    f0 = compute_f0_thomson('Ni', 0.0)
    cross_term = 2 * F_T * F_A / f0

    for energy in ENERGIES:
        fp = sf.get_f_prime('Ni', energy)
        fpp = sf.get_f_double_prime('Ni', energy)

        I_mad = F_T**2 + (fp**2 + fpp**2) * (F_A / f0)**2 + cross_term * (fp * cos_dphi + fpp * sin_dphi)
        I_actual = np.abs(diffractions[energy])**2

        corr = np.corrcoef(I_mad.ravel(), I_actual.ravel())[0, 1]
        print(f"    E={energy} eV: correlation = {corr:.4f}")
        assert corr > 0.5, f"MAD equation correlation too low at E={energy}: {corr}"

    print("  MAD equation produces correlated intensities")

run_test("MAD Equation Consistency", test_mad_equation)


TEST: MAD Equation Consistency
Loading scattering factors for Ni from ../data/Nickel.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
Loading scattering factors for Fe from ../data/Iron.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
  Testing MAD equation at each energy...
    E=8313 eV: correlation = 1.0000
    E=8333 eV: correlation = 1.0000
    E=8348 eV: correlation = 1.0000
  MAD equation produces correlated intensities

  PASSED: MAD Equation Consistency


---
## Test 12: Training Data Generation

In [13]:
def test_training_data_generation():
    """Test single particle training data generation using process_particle.

    Samples random particle parameters, runs ``process_particle``, saves the
    result with ``save_particle_data`` into a temporary directory, and
    reloads ``particle_0000.npz``. The reloaded file must contain the keys
    'X', 'Y', 'f_prime', 'f_double_prime' and 'energies'. X and Y must be
    float32 and share the same leading (sample) dimension.
    """
    import tempfile
    from pathlib import Path

    # Import generation functions
    from src.generate_data import process_particle, save_particle_data, sample_parameters, sample_shape_type
    from src.core_shell import ScatteringFactors

    sf = ScatteringFactors()  # Auto-detect data directory

    with tempfile.TemporaryDirectory() as tmpdir:
        output_dir = Path(tmpdir)

        # Sample random parameters for a particle
        shape_type = sample_shape_type()
        params = sample_parameters(shape_type)
        print(f"  Sampled shape_type: {shape_type}")
        print(f"  Params keys: {list(params.keys())}")

        # Generate one particle
        # Signature: process_particle(params, scattering_factors, verbose=False)
        data = process_particle(
            params=params,
            scattering_factors=sf,
            verbose=False
        )

        assert data is not None, "process_particle returned None"
        print(f"  process_particle returned data with keys: {list(data.keys())}")

        # Save and verify
        save_particle_data(data, output_dir, particle_idx=0)

        # Check output file exists
        output_file = output_dir / 'particle_0000.npz'
        assert output_file.exists(), f"Output file not created: {output_file}"
        print(f"  Generated {output_file.name}")

        # Load and verify. np.load on an .npz returns an NpzFile holding an
        # open file handle; close it (via `with`) before TemporaryDirectory
        # cleanup, otherwise removing the directory fails on Windows.
        required_keys = ['X', 'Y', 'f_prime', 'f_double_prime', 'energies']
        with np.load(output_file) as loaded:
            for key in required_keys:
                assert key in loaded, f"Missing key: {key}"
            print(f"  All required keys present: {list(loaded.keys())}")

            # Materialise the arrays while the archive is still open
            X = loaded['X']
            Y = loaded['Y']

        # Check shapes: inputs and labels must describe the same samples
        print(f"  X shape: {X.shape}, Y shape: {Y.shape}")
        assert X.shape[0] == Y.shape[0], (
            f"Sample count mismatch: X has {X.shape[0]}, Y has {Y.shape[0]}"
        )

        # Verify data types
        assert X.dtype == np.float32, f"X dtype wrong: {X.dtype}"
        assert Y.dtype == np.float32, f"Y dtype wrong: {Y.dtype}"
        print("  Data types correct (float32)")

run_test("Training Data Generation", test_training_data_generation)


TEST: Training Data Generation
Loading scattering factors for Ni from ../data/Nickel.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
Loading scattering factors for Fe from ../data/Iron.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
  Sampled shape_type: polygon_centrosymmetric
  Params keys: ['shape_type', 'composition_mode', 'outer_radius', 'core_fraction', 'interface_amplitude', 'surface_amplitude', 'random_amplitude', 'random_correlation', 'n_vertices', 'truncation_fraction', 'truncation_angle', 'core_offset', 'composition_params', 'module_noise_amplitude']
  process_particle returned data with keys: ['X', 'Y', 'f_prime', 'f_double_prime', 'energies', 'shape_type', 'params', 'n_patches']
  Generated particle_0000.npz
  All required keys present: ['X', 'Y', 'f_prime', 'f_double_prime', 'energies', 'shape_type', 'params_json']
  X shape: (64, 16, 16, 8), Y shape: (64, 16, 16, 4)
  Data types correct (float32)

  PASSED: Training Data Gener

---
## Test 13: Dataset Loading

In [14]:
def test_dataset_loading():
    """Test MADDataset loading and batching.

    Generates a few particle files in a temporary directory, builds a
    ``MADDataset`` over them, and checks the following. The dataset is
    non-empty and a single sample unpacks to (X, Y, f', f''). A DataLoader
    batch has min(BATCH_SIZE, len(dataset)) samples whose per-sample shapes
    match the single sample. Previously this test only printed shapes and
    asserted nothing.
    """
    import torch
    from torch.utils.data import DataLoader
    import tempfile
    from pathlib import Path
    from src.train import MADDataset
    from src.generate_data import process_particle, save_particle_data, sample_parameters, sample_shape_type
    from src.core_shell import ScatteringFactors

    N_PARTICLES = 3
    BATCH_SIZE = 4

    sf = ScatteringFactors()  # Auto-detect data directory

    with tempfile.TemporaryDirectory() as tmpdir:
        output_dir = Path(tmpdir)

        # Generate a few particles
        for i in range(N_PARTICLES):
            shape_type = sample_shape_type()
            params = sample_parameters(shape_type)
            data = process_particle(
                params=params,
                scattering_factors=sf,
                verbose=False
            )
            assert data is not None, f"process_particle returned None for particle {i}"
            save_particle_data(data, output_dir, particle_idx=i)
        print(f"  Generated {N_PARTICLES} particle files")

        # Create dataset
        dataset = MADDataset(output_dir)
        assert len(dataset) > 0, "MADDataset is empty"
        print(f"  MADDataset created with {len(dataset)} samples")

        # Test single sample
        X, Y, fp, fpp = dataset[0]
        print(f"  Single sample: X={X.shape}, Y={Y.shape}, fp={fp.shape}, fpp={fpp.shape}")

        # Test DataLoader: default collation stacks samples along a new dim 0
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        X_batch, Y_batch, fp_batch, fpp_batch = next(iter(loader))
        expected_batch = min(BATCH_SIZE, len(dataset))
        assert X_batch.shape[0] == expected_batch, (
            f"Batch size {X_batch.shape[0]} != expected {expected_batch}"
        )
        assert tuple(X_batch.shape[1:]) == tuple(X.shape), (
            f"Batched X sample shape {tuple(X_batch.shape[1:])} != {tuple(X.shape)}"
        )
        assert tuple(Y_batch.shape[1:]) == tuple(Y.shape), (
            f"Batched Y sample shape {tuple(Y_batch.shape[1:])} != {tuple(Y.shape)}"
        )
        print(f"  Batch loading: X={X_batch.shape}, Y={Y_batch.shape}")

run_test("Dataset Loading", test_dataset_loading, requires_torch=True)


TEST: Dataset Loading
  SKIPPED: Requires PyTorch


---
## Test 14: Ellipse Shape Creation

In [15]:
def test_ellipse_shape():
    """Test ellipse shape creation with various aspect ratios and rotations.

    1. ``create_ellipse_mask`` returns a non-empty boolean
       (GRID_SIZE, GRID_SIZE) mask.
    2. A circular mask (a == b) spans about as many rows as columns. An
       unrotated elongated mask (a = 0.35, b = 0.12) spans fewer than 0.6x
       as many rows as columns.
    3. ``create_particle_with_shape(shape_type='ellipse', ...)`` yields a
       (2, GRID_SIZE, GRID_SIZE) particle whose info dict has 'outer_mask'.
    """
    from src.core_shell import create_ellipse_mask, create_particle_with_shape

    GRID_SIZE = 64

    # Note: create_ellipse_mask expects NORMALIZED coordinates (0-1), not pixel coordinates!
    # center: (0.5, 0.5) = center of grid
    # semi_major/semi_minor: fraction of half-grid (0.3 = 30% of half-grid width)

    # Test 1: Basic ellipse mask creation with normalized coordinates
    center_normalized = (0.5, 0.5)  # Center of grid in normalized coords
    semi_major = 0.3   # 30% of half-grid = ~10 pixels for 64x64
    semi_minor = 0.15  # 15% of half-grid = ~5 pixels for 64x64
    rotation = 0.0

    mask, vertices = create_ellipse_mask(
        grid_size=GRID_SIZE,
        center=center_normalized,
        semi_major=semi_major,
        semi_minor=semi_minor,
        rotation_angle=rotation,
        seed=42,
        verbose=False
    )

    assert mask.shape == (GRID_SIZE, GRID_SIZE), f"Wrong mask shape: {mask.shape}"
    assert mask.dtype == bool, f"Wrong mask dtype: {mask.dtype}"
    assert mask.sum() > 0, "Empty mask"
    print(f"  Ellipse mask: {mask.sum()} pixels, a={semi_major}, b={semi_minor} (normalized)")

    # Test 2: Verify aspect ratio affects shape
    # Positional args: grid_size, center, semi_major, semi_minor, rotation_angle, seed, verbose
    mask_circular, _ = create_ellipse_mask(GRID_SIZE, (0.5, 0.5), 0.25, 0.25, 0.0, 42, False)
    mask_elongated, _ = create_ellipse_mask(GRID_SIZE, (0.5, 0.5), 0.35, 0.12, 0.0, 42, False)

    # Guard the ratio computation below: an empty mask would otherwise
    # surface as an unhelpful ZeroDivisionError.
    assert mask_circular.any(), "Empty circular mask"
    assert mask_elongated.any(), "Empty elongated mask"

    # Extent = number of rows (y) / columns (x) containing at least one pixel
    circular_extent_y = np.where(mask_circular.any(axis=1))[0]
    circular_extent_x = np.where(mask_circular.any(axis=0))[0]
    elongated_extent_y = np.where(mask_elongated.any(axis=1))[0]
    elongated_extent_x = np.where(mask_elongated.any(axis=0))[0]

    circular_ratio = len(circular_extent_y) / len(circular_extent_x)
    elongated_ratio = len(elongated_extent_y) / len(elongated_extent_x)

    print(f"  Circular aspect: {circular_ratio:.2f}, Elongated aspect: {elongated_ratio:.2f}")
    assert abs(circular_ratio - 1.0) < 0.1, "Circular should have ~1:1 aspect"
    assert elongated_ratio < 0.6, f"Elongated should be flattened, got {elongated_ratio}"

    # Test 3: Full particle creation with ellipse shape
    # Note: create_particle_with_shape handles the coordinate conversion internally
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE,
        shape_type='ellipse',
        outer_radius=18,
        core_fraction=0.5,
        pixel_size=5.0,
        shape_params={'aspect_ratio': 0.6, 'rotation': np.pi/4},
        verbose=False
    )

    assert particle.shape == (2, GRID_SIZE, GRID_SIZE), f"Wrong particle shape: {particle.shape}"
    assert 'outer_mask' in info, "Missing outer_mask in info"
    print(f"  Ellipse particle created: {info['outer_mask'].sum()} pixels in outer region")

run_test("Ellipse Shape Creation", test_ellipse_shape)


TEST: Ellipse Shape Creation
  Ellipse mask: 140 pixels, a=0.3, b=0.15 (normalized)
  Circular aspect: 1.00, Elongated aspect: 0.36
  Ellipse particle created: 412 pixels in outer region

  PASSED: Ellipse Shape Creation


---
## Test 15: Boundary Validation Functions

In [16]:
def test_boundary_validation():
    """Test boundary validation to ensure particles fit within grid.

    ``validate_particle_bounds(radius, center, grid_size, margin)`` must accept
    particles that stay ``margin`` pixels away from every edge and reject those
    that don't. ``clamp_radius_to_grid(center, grid_size, margin)`` must return
    the distance from ``center`` to the nearest edge minus ``margin``, floored
    at ``MIN_RADIUS``.
    """
    from src.core_shell import validate_particle_bounds, clamp_radius_to_grid

    GRID_SIZE = 64
    margin = 2
    MIN_RADIUS = 10  # floor applied by clamp_radius_to_grid close to edges

    # The explicit `== True` / `== False` comparisons below are deliberate
    # (noqa: E712). They accept bool and numpy.bool_ results, but unlike a bare
    # truthiness check they do not let a `None` return pass the "should NOT
    # fit" cases.

    # Test 1: Particle that fits
    center = (32, 32)  # Center of grid
    radius = 20
    result = validate_particle_bounds(radius, center, GRID_SIZE, margin)
    assert result == True, f"Centered particle should fit: radius={radius}, center={center}"  # noqa: E712
    print(f"  Centered particle (r={radius}) at {center}: fits={result}")

    # Test 2: Particle too big to fit (32 + 35 > 64 - margin)
    radius_large = 35
    result = validate_particle_bounds(radius_large, center, GRID_SIZE, margin)
    assert result == False, f"Large particle should NOT fit: radius={radius_large}"  # noqa: E712
    print(f"  Large particle (r={radius_large}) at {center}: fits={result}")

    # Test 3: Off-center particle that still fits (40 + 15 <= 62)
    center_off = (32, 40)
    radius_small = 15
    result = validate_particle_bounds(radius_small, center_off, GRID_SIZE, margin)
    assert result == True, "Off-center small particle should fit"  # noqa: E712
    print(f"  Off-center particle (r={radius_small}) at {center_off}: fits={result}")

    # Test 4: Off-center particle that doesn't fit (55 + 15 > 62)
    center_edge = (32, 55)
    radius_medium = 15
    result = validate_particle_bounds(radius_medium, center_edge, GRID_SIZE, margin)
    assert result == False, "Particle near edge should NOT fit"  # noqa: E712
    print(f"  Edge particle (r={radius_medium}) at {center_edge}: fits={result}")

    # Test 5: clamp_radius_to_grid at the grid centre
    center = (32, 32)
    max_radius = clamp_radius_to_grid(center, GRID_SIZE, margin)
    expected_max = 32 - margin  # 30
    assert max_radius == expected_max, f"Expected max radius {expected_max}, got {max_radius}"
    print(f"  Max radius at center {center}: {max_radius}")

    # Test 6: clamp for an off-center position (closer to the right edge)
    center_off = (32, 50)
    max_radius = clamp_radius_to_grid(center_off, GRID_SIZE, margin)
    expected_max = GRID_SIZE - margin - 50  # 12
    assert max_radius == expected_max, f"Expected max radius {expected_max}, got {max_radius}"
    print(f"  Max radius at off-center {center_off}: {max_radius}")

    # Test 7: Verify minimum radius enforcement.
    # Only 5 - margin = 3 px are geometrically available at (5, 5), so the
    # result can only come from the MIN_RADIUS floor. A `>= 10` check would
    # also accept a clamp that ignores the position entirely (e.g. always
    # returning 30), so pin the floor exactly.
    center_extreme = (5, 5)  # Very close to corner
    max_radius = clamp_radius_to_grid(center_extreme, GRID_SIZE, margin)
    assert max_radius == MIN_RADIUS, f"Should enforce minimum radius of {MIN_RADIUS}, got {max_radius}"
    print(f"  Max radius at extreme position {center_extreme}: {max_radius} (min enforced)")

run_test("Boundary Validation Functions", test_boundary_validation)


TEST: Boundary Validation Functions
  Centered particle (r=20) at (32, 32): fits=True
  Large particle (r=35) at (32, 32): fits=False
  Off-center particle (r=15) at (32, 40): fits=True
  Edge particle (r=15) at (32, 55): fits=False
  Max radius at center (32, 32): 30
  Max radius at off-center (32, 50): 12
  Max radius at extreme position (5, 5): 10.0 (min enforced)

  PASSED: Boundary Validation Functions


---
## Test 16: Composition Modes

In [17]:
def test_composition_modes():
    """Test all composition modes: sharp, radial_gradient, linear_gradient, janus, multi_shell, uniform.

    For every mode a circular particle is built, and the Ni/Fe split inside
    ``info['outer_mask']`` is measured as the sum of absolute per-species
    values.

    - Uniform modes must reproduce the requested fractions within
      ``TOLERANCE``.
    - Every other mode is a Ni/Fe mixture, so both species must be present.
    """
    from src.core_shell import create_particle_with_shape, SPECIES_NI, SPECIES_FE

    GRID_SIZE = 64
    TOLERANCE = 0.05  # allowed deviation of uniform fractions (averaging effects)

    composition_tests = [
        # (mode, params, description)
        ('sharp', None, 'Traditional core-shell'),
        ('radial_gradient', {'transition_width': 0.3}, 'Smooth center-to-edge'),
        ('linear_gradient', {'gradient_direction': np.pi/4, 'gradient_width': 20.0}, 'Linear gradient'),
        ('janus', {'split_angle': 0.0, 'interface_width': 0.0}, 'Left-right split'),
        ('multi_shell', {'n_shells': 3, 'transition_width': 0.0}, '3-shell onion'),
        ('uniform', {'composition': {'Ni': 0.75, 'Fe': 0.25}}, 'Uniform Ni3Fe'),
        ('uniform', {'composition': {'Ni': 1.0, 'Fe': 0.0}}, 'Pure Ni'),
        ('uniform', {'composition': {'Ni': 0.0, 'Fe': 1.0}}, 'Pure Fe'),
    ]
    # Pad descriptions to the longest one so the printed table stays aligned.
    desc_width = max(len(desc) for _, _, desc in composition_tests)

    for mode, params, desc in composition_tests:
        particle, info = create_particle_with_shape(
            grid_size=GRID_SIZE,
            shape_type='circle',
            outer_radius=20,
            core_fraction=0.5,
            pixel_size=5.0,
            composition_mode=mode,
            composition_params=params,
            verbose=False
        )

        assert particle.shape == (2, GRID_SIZE, GRID_SIZE), f"Wrong shape for {mode}"

        # Total per-species content inside the particle outline.
        outer_mask = info['outer_mask']
        ni_content = np.abs(particle[SPECIES_NI][outer_mask]).sum()
        fe_content = np.abs(particle[SPECIES_FE][outer_mask]).sum()
        total = ni_content + fe_content
        # Without this guard an empty/zero particle would print Ni=0, Fe=0 and
        # pass silently for every non-uniform mode.
        assert total > 0, f"No Ni/Fe content inside the outer mask for {mode} ({desc})"
        ni_frac = ni_content / total
        fe_frac = fe_content / total

        print(f"  {mode:16s} ({desc:{desc_width}s}): Ni={ni_frac:.2f}, Fe={fe_frac:.2f}")

        if mode == 'uniform':
            expected_ni = params['composition']['Ni']
            expected_fe = params['composition']['Fe']
            assert abs(ni_frac - expected_ni) < TOLERANCE, f"Ni fraction wrong: {ni_frac} vs {expected_ni}"
            assert abs(fe_frac - expected_fe) < TOLERANCE, f"Fe fraction wrong: {fe_frac} vs {expected_fe}"
        else:
            # Non-uniform modes blend the two species (all of them showed both
            # Ni and Fe in the recorded run), so neither may vanish entirely.
            assert 0 < fe_frac < 1, f"{mode} should contain both Ni and Fe, got Ni={ni_frac:.2f}, Fe={fe_frac:.2f}"

    print("  All composition modes work correctly")

run_test("Composition Modes", test_composition_modes)


TEST: Composition Modes
  sharp            (Traditional core-shell): Ni=0.94, Fe=0.06
  radial_gradient  (Smooth center-to-edge): Ni=0.86, Fe=0.14
  linear_gradient  (Linear gradient     ): Ni=0.88, Fe=0.12
  janus            (Left-right split    ): Ni=0.88, Fe=0.12
  multi_shell      (3-shell onion       ): Ni=0.80, Fe=0.20
  uniform          (Uniform Ni3Fe       ): Ni=0.75, Fe=0.25
  uniform          (Pure Ni             ): Ni=1.00, Fe=0.00
  uniform          (Pure Fe             ): Ni=0.00, Fe=1.00
  All composition modes work correctly

  PASSED: Composition Modes


---
## Test 17: Winterbottom Truncation

In [18]:
def test_winterbottom_truncation():
    """Test Winterbottom truncation at various fractions and angles.

    Applies ``apply_winterbottom_truncation`` to a filled circle of radius 25
    and checks that:

    - truncation only ever removes pixels;
    - the kept area changes monotonically with ``truncation_fraction``;
    - for a circle, the kept area is roughly independent of
      ``truncation_angle``.

    Finally it checks that ``create_particle_with_shape`` actually truncates
    when given a ``truncation_fraction``.
    """
    from src.core_shell import apply_winterbottom_truncation, create_particle_with_shape

    GRID_SIZE = 64
    ANGLE_SPREAD_TOL = 0.15  # max relative deviation from the mean across angles
    center = (GRID_SIZE // 2, GRID_SIZE // 2)

    # Create a circular mask for testing
    y, x = np.ogrid[:GRID_SIZE, :GRID_SIZE]
    dist_sq = (y - center[0])**2 + (x - center[1])**2
    radius = 25
    base_mask = dist_sq <= radius**2
    base_pixels = base_mask.sum()
    print(f"  Base circular mask: {base_pixels} pixels")

    def assert_subset_of_base(truncated, label):
        """Fail if the truncated mask has any pixel outside the base mask."""
        outside = int(np.logical_and(truncated, np.logical_not(base_mask)).sum())
        assert outside == 0, f"{label}: truncation added {outside} pixels outside the base mask"

    # Test various truncation fractions
    truncation_tests = [0.1, 0.25, 0.5, 0.75]
    kept_counts = []

    for frac in truncation_tests:
        truncated = apply_winterbottom_truncation(
            mask=base_mask.copy(),  # copy in case the function mutates its input
            center=center,
            truncation_fraction=frac,
            truncation_angle=0.0  # reference direction
        )

        truncated_pixels = truncated.sum()
        removed_fraction = 1 - truncated_pixels / base_pixels

        print(f"  Truncation={frac:.2f}: {truncated_pixels} pixels, removed ~{removed_fraction:.2%}")

        assert truncated_pixels < base_pixels, "Truncation should remove pixels"
        assert truncated_pixels > 0, "Truncation shouldn't remove everything"
        assert_subset_of_base(truncated, f"fraction={frac}")
        kept_counts.append(int(truncated_pixels))

    # The kept area must change monotonically with the truncation fraction.
    # NOTE(review): in the recorded run the kept area GROWS with
    # truncation_fraction (0.10 removed ~94%, 0.75 removed ~20%), i.e. the
    # parameter behaves like the fraction *kept*, which contradicts its name.
    # Confirm the intended semantics in core_shell before pinning a direction;
    # only strict monotonicity (in either direction) is asserted here.
    steps = np.diff(kept_counts)
    assert np.all(steps > 0) or np.all(steps < 0), (
        f"Kept pixels should vary monotonically with truncation_fraction, got {kept_counts}"
    )

    # Test truncation at different angles
    angle_tests = [0, np.pi/4, np.pi/2, np.pi]
    angle_counts = []
    for angle in angle_tests:
        truncated = apply_winterbottom_truncation(
            mask=base_mask.copy(),
            center=center,
            truncation_fraction=0.3,
            truncation_angle=angle
        )
        print(f"  Angle={np.degrees(angle):.0f}°: {truncated.sum()} pixels")
        assert_subset_of_base(truncated, f"angle={np.degrees(angle):.0f}°")
        angle_counts.append(int(truncated.sum()))

    # A circle is rotationally symmetric, so the kept area should depend on the
    # angle only through pixelation (recorded spread: 482-518 px, ~7%).
    mean_count = float(np.mean(angle_counts))
    assert mean_count > 0, "Angle truncation removed everything"
    worst = max(abs(c - mean_count) for c in angle_counts)
    assert worst <= ANGLE_SPREAD_TOL * mean_count, (
        f"Kept area should be ~angle-independent for a circle, got {angle_counts}"
    )

    # Test full particle with truncation via create_particle_with_shape,
    # compared against the same particle built without truncation.
    common = dict(
        grid_size=GRID_SIZE,
        shape_type='circle',
        outer_radius=20,
        core_fraction=0.5,
        pixel_size=5.0,
        verbose=False
    )
    _, info_full = create_particle_with_shape(**common)
    particle, info = create_particle_with_shape(
        **common,
        truncation_fraction=0.4,
        truncation_angle=np.pi/6
    )

    assert particle.shape == (2, GRID_SIZE, GRID_SIZE), "Wrong shape"
    assert info['outer_mask'].sum() < info_full['outer_mask'].sum(), (
        "truncation_fraction had no effect in create_particle_with_shape"
    )
    print(f"  Truncated particle created: {info['outer_mask'].sum()} pixels")

run_test("Winterbottom Truncation", test_winterbottom_truncation)


TEST: Winterbottom Truncation
  Base circular mask: 1961 pixels
  Truncation=0.10: 116 pixels, removed ~94.08%
  Truncation=0.25: 385 pixels, removed ~80.37%
  Truncation=0.50: 1006 pixels, removed ~48.70%
  Truncation=0.75: 1576 pixels, removed ~19.63%
  Angle=0°: 518 pixels
  Angle=45°: 482 pixels
  Angle=90°: 510 pixels
  Angle=180°: 503 pixels
  Truncated particle created: 470 pixels

  PASSED: Winterbottom Truncation


---
## Test 18: Off-Center Core

In [19]:
def test_off_center_core():
    """Test off-center (eccentric) core placement.

    ``core_offset`` is passed as (dy, dx) in pixels. The core centroid must move
    by exactly that offset (sign and axis, within ``CENTROID_TOL``). Every core
    mask must lie inside its outer mask.
    """
    from src.core_shell import compute_off_center_core_mask, create_particle_with_shape

    GRID_SIZE = 64
    CENTROID_TOL = 1.5  # px; allows for pixelation of the shifted disc
    center = (GRID_SIZE // 2, GRID_SIZE // 2)

    # Create outer mask
    y, x = np.ogrid[:GRID_SIZE, :GRID_SIZE]
    dist_sq = (y - center[0])**2 + (x - center[1])**2
    outer_radius = 25
    outer_mask = dist_sq <= outer_radius**2

    def assert_inside(core, outer, label):
        """Fail if any core pixel lies outside the outer mask."""
        leaked = int(np.logical_and(core, np.logical_not(outer)).sum())
        assert leaked == 0, f"{label}: {leaked} core pixels lie outside the outer mask"

    def find_centroid(mask):
        """Mean (row, col) of the True pixels; refuses an empty mask (NaN centroid)."""
        y_coords, x_coords = np.where(mask)
        assert y_coords.size > 0, "Core mask is empty"
        return np.mean(y_coords), np.mean(x_coords)

    # Test centered core
    core_mask_centered = compute_off_center_core_mask(
        outer_mask=outer_mask,
        center=center,
        core_fraction=0.5,
        core_offset=(0.0, 0.0),
        shape_type='circle'
    )

    # Test off-center core
    offset = (5.0, 8.0)  # Offset by 5 pixels in y, 8 pixels in x
    core_mask_offset = compute_off_center_core_mask(
        outer_mask=outer_mask,
        center=center,
        core_fraction=0.5,
        core_offset=offset,
        shape_type='circle'
    )

    print(f"  Centered core: {core_mask_centered.sum()} pixels")
    print(f"  Off-center core: {core_mask_offset.sum()} pixels")

    assert_inside(core_mask_centered, outer_mask, "Centered core")
    assert_inside(core_mask_offset, outer_mask, "Off-center core")

    cy_centered, cx_centered = find_centroid(core_mask_centered)
    cy_offset, cx_offset = find_centroid(core_mask_offset)

    print(f"  Centered core centroid: ({cy_centered:.1f}, {cx_centered:.1f})")
    print(f"  Off-center core centroid: ({cy_offset:.1f}, {cx_offset:.1f})")

    # A zero offset must leave the core centred on `center`.
    assert abs(cy_centered - center[0]) < CENTROID_TOL, f"Centered core drifted in y: {cy_centered:.2f}"
    assert abs(cx_centered - center[1]) < CENTROID_TOL, f"Centered core drifted in x: {cx_centered:.2f}"

    # Compare the signed shift with the requested offset. The previous
    # magnitude-only checks (|dx| > 3, |dy| > 2) also passed if the function
    # swapped the (y, x) axes or flipped a sign.
    shift_y = cy_offset - cy_centered
    shift_x = cx_offset - cx_centered
    assert abs(shift_y - offset[0]) < CENTROID_TOL, f"Y shift {shift_y:.2f} != requested {offset[0]}"
    assert abs(shift_x - offset[1]) < CENTROID_TOL, f"X shift {shift_x:.2f} != requested {offset[1]}"

    # Test full particle with off-center core via create_particle_with_shape
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE,
        shape_type='circle',
        outer_radius=20,
        core_fraction=0.5,
        pixel_size=5.0,
        core_offset=(3.0, 4.0),
        verbose=False
    )

    assert particle.shape == (2, GRID_SIZE, GRID_SIZE), "Wrong shape"
    assert_inside(info['core_mask'], info['outer_mask'], "Particle core")
    print(f"  Off-center particle created: core={info['core_mask'].sum()} pixels")

run_test("Off-Center Core", test_off_center_core)


TEST: Off-Center Core
  Centered core: 489 pixels
  Off-center core: 489 pixels
  Centered core centroid: (32.0, 32.0)
  Off-center core centroid: (37.0, 40.0)
  Off-center particle created: core=317 pixels

  PASSED: Off-Center Core


---
## Test 19: Size Variation with Boundary Checking

In [20]:
def test_size_variation():
    """Test particles at various sizes and verify they fit in grid.

    Radii are ``base_radius`` scaled by 0.7-1.3. Every radius accepted by
    ``validate_particle_bounds`` is built, and its outer mask must leave the
    whole ``MARGIN``-pixel border band empty, not only the outermost row/col.
    At least one radius must be built so the test cannot pass vacuously.
    """
    from src.core_shell import create_particle_with_shape, validate_particle_bounds

    GRID_SIZE = 128  # Standard size
    MARGIN = 2       # border (px) that particles must keep clear
    center = (GRID_SIZE // 2, GRID_SIZE // 2)

    # Test various radius sizes (simulating scale factors from 0.7 to 1.3)
    base_radius = 40
    scale_factors = [0.7, 0.85, 1.0, 1.15, 1.3]
    n_created = 0

    for scale in scale_factors:
        outer_radius = int(base_radius * scale)

        # Verify it fits before creating
        fits = validate_particle_bounds(outer_radius, center, GRID_SIZE, margin=MARGIN)
        if not fits:
            print(f"  Scale={scale:.2f}, r={outer_radius}: WOULD NOT FIT (boundary check caught it)")
            continue

        # NOTE(review): no center is passed, so this assumes
        # create_particle_with_shape centres the particle at `center` — confirm.
        particle, info = create_particle_with_shape(
            grid_size=GRID_SIZE,
            shape_type='circle',
            outer_radius=outer_radius,
            core_fraction=0.5,
            pixel_size=5.0,
            verbose=False
        )
        outer_mask = info['outer_mask']

        # Mark every pixel within MARGIN of any edge, then count particle
        # pixels that fall there. (Slicing with `rows - MARGIN` rather than
        # `-MARGIN` keeps this correct if MARGIN is ever set to 0.)
        rows, cols = outer_mask.shape
        border = np.ones((rows, cols), dtype=bool)
        border[MARGIN:rows - MARGIN, MARGIN:cols - MARGIN] = False
        total_edge = int(np.logical_and(outer_mask, border).sum())

        print(f"  Scale={scale:.2f}, r={outer_radius}: {outer_mask.sum()} pixels, edge_pixels={total_edge}")

        assert total_edge == 0, f"Particle enters the {MARGIN}-pixel margin at scale={scale}"
        n_created += 1

    assert n_created > 0, "No scale factor produced a particle that fits the grid"

    # Test that boundary check catches too-large particles
    max_possible = GRID_SIZE // 2 - MARGIN  # 62 for 128 grid
    result = validate_particle_bounds(max_possible + 5, center, GRID_SIZE, margin=MARGIN)
    assert result == False, "Should reject particle larger than grid allows"  # noqa: E712 (rejects None too)
    print(f"  Boundary check correctly rejects r={max_possible + 5}")

run_test("Size Variation with Boundary Checking", test_size_variation)


TEST: Size Variation with Boundary Checking
  Scale=0.70, r=28: 2453 pixels, edge_pixels=0
  Scale=0.85, r=34: 3625 pixels, edge_pixels=0
  Scale=1.00, r=40: 5025 pixels, edge_pixels=0
  Scale=1.15, r=46: 6625 pixels, edge_pixels=0
  Scale=1.30, r=52: 8497 pixels, edge_pixels=0
  Boundary check correctly rejects r=67

  PASSED: Size Variation with Boundary Checking


---
## Test 20: Backwards Compatibility

In [21]:
def test_backwards_compatibility():
    """Test that old-style calls without new parameters still work.

    Builds a hexagon particle with the original argument set (no
    composition_mode, truncation or core offset) and checks that the default
    core-shell composition is produced: a Ni3Fe core (75% Ni / 25% Fe) and a
    nearly pure Ni shell. Also smoke-tests the other legacy shape types.
    """
    from src.core_shell import create_particle_with_shape, SPECIES_NI, SPECIES_FE

    GRID_SIZE = 64
    EXPECTED_CORE_NI = 0.75  # Ni3Fe
    TOLERANCE = 0.05         # same tolerance as the uniform-composition checks

    # Original call style (no composition_mode, no truncation, no offset)
    particle, info = create_particle_with_shape(
        grid_size=GRID_SIZE,
        shape_type='hexagon',
        outer_radius=20,
        core_fraction=0.5,
        pixel_size=5.0,
        shape_params={'anisotropy': 1.2},
        verbose=False
    )

    assert particle.shape == (2, GRID_SIZE, GRID_SIZE), "Wrong shape"
    for key in ('outer_mask', 'core_mask', 'shell_mask'):
        assert key in info, f"Missing {key}"

    # Default composition: Core = Ni3Fe (75% Ni, 25% Fe), Shell = pure Ni (100% Ni)
    core_mask = info['core_mask']
    shell_mask = info['shell_mask']

    fe_in_core = np.abs(particle[SPECIES_FE][core_mask]).sum()
    ni_in_core = np.abs(particle[SPECIES_NI][core_mask]).sum()

    fe_in_shell = np.abs(particle[SPECIES_FE][shell_mask]).sum()
    ni_in_shell = np.abs(particle[SPECIES_NI][shell_mask]).sum()

    print("  Default hexagon particle created")
    print(f"  Core: Ni={ni_in_core:.1f}, Fe={fe_in_core:.1f} (expected ~75% Ni, ~25% Fe)")
    print(f"  Shell: Ni={ni_in_shell:.1f}, Fe={fe_in_shell:.1f} (expected ~100% Ni)")

    # Core is Ni3Fe: Ni-rich, with some Fe
    assert ni_in_core > fe_in_core, "Core should be Ni-rich (Ni3Fe = 75% Ni)"
    assert fe_in_core > 0, "Core should contain some Fe (Ni3Fe = 25% Fe)"

    # The two asserts above only establish "Ni-rich with some Fe". Pin the
    # actual 75/25 ratio that the printout promises. The denominator is > 0
    # because fe_in_core > 0 was just asserted.
    core_ni_frac = ni_in_core / (ni_in_core + fe_in_core)
    assert abs(core_ni_frac - EXPECTED_CORE_NI) < TOLERANCE, (
        f"Core should be ~{EXPECTED_CORE_NI:.0%} Ni (Ni3Fe), got {core_ni_frac:.2%}"
    )

    # Shell should be nearly pure Ni (no Fe)
    assert ni_in_shell > 0, "Shell should contain Ni"
    assert fe_in_shell < ni_in_shell * 0.1, "Shell should be nearly pure Ni (< 10% Fe)"

    print("  Default core-shell composition verified (Ni3Fe core, pure Ni shell)")

    # Test that other old-style shapes still work
    for shape in ['circle', 'polygon', 'polygon_centrosymmetric']:
        params = {'n_vertices': 5} if 'polygon' in shape else {}
        shape_particle, _ = create_particle_with_shape(
            grid_size=GRID_SIZE,
            shape_type=shape,
            outer_radius=18,
            core_fraction=0.5,
            pixel_size=5.0,
            shape_params=params,
            verbose=False
        )
        assert shape_particle.shape == (2, GRID_SIZE, GRID_SIZE), f"Shape {shape} failed"
        print(f"  {shape}: OK")

run_test("Backwards Compatibility", test_backwards_compatibility)


TEST: Backwards Compatibility
  Default hexagon particle created
  Core: Ni=52.5, Fe=17.5 (expected ~75% Ni, ~25% Fe)
  Shell: Ni=210.0, Fe=0.0 (expected ~100% Ni)
  Default core-shell composition verified (Ni3Fe core, pure Ni shell)
  circle: OK
  polygon: OK
  polygon_centrosymmetric: OK

  PASSED: Backwards Compatibility


---
## Test 21: Training Data Generation with New Features

In [22]:
def test_training_data_new_features():
    """Test training data generation with new composition modes and geometry features.

    Checks that:

    - the shape/composition sampling tables are valid probability
      distributions;
    - the samplers only return keys from those tables;
    - ``process_particle`` succeeds for several composition modes.

    It also reports whether truncation and off-center cores occur in a small
    random sample. Sampling is not seeded in this test, so the counts differ
    between runs.
    """
    from src.generate_data import (
        sample_shape_type, sample_parameters, sample_composition_mode,
        sample_composition_parameters, process_particle,
        SHAPE_DISTRIBUTION, COMPOSITION_DISTRIBUTION
    )
    from src.core_shell import ScatteringFactors

    sf = ScatteringFactors()  # Auto-detect data directory

    def assert_is_distribution(dist, label):
        """Probabilities must be non-negative and sum to ~1 (from both sides)."""
        total = sum(dist.values())
        # The former `> 0.99` check alone also accepted tables summing to e.g. 1.5.
        assert abs(total - 1.0) < 0.01, f"{label} should sum to ~1, got {total}"
        assert all(p >= 0 for p in dist.values()), f"{label} has negative probabilities"

    def has_nonzero_offset(offset):
        """True if a core offset is set and non-zero.

        Accepts a tuple, list or array. A list compared to the tuple (0, 0)
        with `!=` is always unequal, and an array comparison is ambiguous
        inside `if`, so convert to an array first.
        """
        return offset is not None and bool(np.any(np.asarray(offset, dtype=float) != 0))

    # Test 1: Verify distributions are defined
    print(f"  SHAPE_DISTRIBUTION: {list(SHAPE_DISTRIBUTION.keys())}")
    assert 'ellipse' in SHAPE_DISTRIBUTION, "Missing ellipse shape"
    assert_is_distribution(SHAPE_DISTRIBUTION, "Shape distribution")

    print(f"  COMPOSITION_DISTRIBUTION: {list(COMPOSITION_DISTRIBUTION.keys())}")
    assert 'uniform' in COMPOSITION_DISTRIBUTION, "Missing uniform mode"
    assert 'radial_gradient' in COMPOSITION_DISTRIBUTION, "Missing radial_gradient mode"
    assert_is_distribution(COMPOSITION_DISTRIBUTION, "Composition distribution")

    # Test 2: Sample various shape types; every sample must be a known shape
    shape_counts = {}
    for _ in range(100):
        shape = sample_shape_type()
        shape_counts[shape] = shape_counts.get(shape, 0) + 1
    print(f"  Sampled shapes (100 trials): {shape_counts}")
    unknown_shapes = set(shape_counts) - set(SHAPE_DISTRIBUTION)
    assert not unknown_shapes, f"sample_shape_type returned unknown shapes: {unknown_shapes}"

    # Test 3: Sample various composition modes; every sample must be a known mode
    mode_counts = {}
    for _ in range(100):
        mode = sample_composition_mode()
        mode_counts[mode] = mode_counts.get(mode, 0) + 1
    print(f"  Sampled modes (100 trials): {mode_counts}")
    unknown_modes = set(mode_counts) - set(COMPOSITION_DISTRIBUTION)
    assert not unknown_modes, f"sample_composition_mode returned unknown modes: {unknown_modes}"

    # Test 4: Generate particles with specific composition modes
    test_modes = ['sharp', 'radial_gradient', 'janus', 'uniform']
    for mode in test_modes:
        shape = sample_shape_type()
        params = sample_parameters(shape, mode)

        # Verify params contain new fields
        assert 'composition_mode' in params, f"Missing composition_mode in params for {mode}"
        assert params['composition_mode'] == mode, f"Wrong mode: {params['composition_mode']}"

        # Generate the particle
        data = process_particle(params, sf, verbose=False)
        assert data is not None, f"process_particle failed for mode={mode}"
        print(f"  {mode}: Generated {data['n_patches']} patches")

    # Test 5: Report whether truncation and off-center cores are sampled.
    # A missing or None truncation fraction is treated as 0.
    has_truncation = False
    has_offset = False
    for _ in range(50):
        shape = sample_shape_type()
        mode = sample_composition_mode()
        params = sample_parameters(shape, mode)
        if (params.get('truncation_fraction') or 0) > 0:
            has_truncation = True
        if has_nonzero_offset(params.get('core_offset')):
            has_offset = True

    print(f"  Truncation sampled in 50 trials: {has_truncation}")
    print(f"  Off-center core sampled in 50 trials: {has_offset}")

run_test("Training Data Generation with New Features", test_training_data_new_features)


TEST: Training Data Generation with New Features
Loading scattering factors for Ni from ../data/Nickel.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
Loading scattering factors for Fe from ../data/Iron.f1f2...
  Loaded 4601 data points
  Energy range: 2000.0 - 25000.0 eV
  SHAPE_DISTRIBUTION: ['hexagon', 'polygon', 'polygon_centrosymmetric', 'circle', 'ellipse']
  COMPOSITION_DISTRIBUTION: ['sharp', 'radial_gradient', 'linear_gradient', 'janus', 'multi_shell', 'uniform']
  Sampled shapes (100 trials): {'hexagon': 35, 'polygon': 21, 'circle': 17, 'polygon_centrosymmetric': 13, 'ellipse': 14}
  Sampled modes (100 trials): {'linear_gradient': 16, 'multi_shell': 4, 'sharp': 42, 'janus': 7, 'radial_gradient': 11, 'uniform': 20}
  sharp: Generated 64 patches
  radial_gradient: Generated 64 patches
  janus: Generated 64 patches
  uniform: Generated 64 patches
  Truncation sampled in 50 trials: True
  Off-center core sampled in 50 trials: True

  PASSED: Training Data Generation with New Features

---
## Results Summary

In [23]:
# Summarize every (name, status, error) record that run_test appended to
# test_results: one line per test, running totals, then an overall verdict.
banner = "=" * 80
print("\n" + banner)
print("TEST RESULTS SUMMARY")
print(banner)

statuses = [status for _, status, _ in test_results]
passed = statuses.count('PASSED')
failed = statuses.count('FAILED')
skipped = statuses.count('SKIPPED')

# Any status other than PASSED/SKIPPED is reported as a failure.
status_symbols = {'PASSED': "PASS", 'SKIPPED': "SKIP"}
for name, status, error in test_results:
    print(f"  [{status_symbols.get(status, 'FAIL')}] {name}")
    if status == 'FAILED' and error:
        # Keep long messages readable: at most 120 characters plus an ellipsis.
        error_display = error if len(error) <= 120 else error[:120] + "..."
        print(f"         Error: {error_display}")

print("\n" + "-" * 80)
print(f"TOTAL: {passed} passed, {failed} failed, {skipped} skipped out of {len(test_results)} tests")
print(banner)

if failed:
    print(f"\n{failed} TEST(S) FAILED - see details above")
elif skipped:
    # run_test only ever skips for a missing PyTorch install.
    print(f"\nALL RUNNABLE TESTS PASSED! ({skipped} skipped due to missing PyTorch)")
else:
    print("\nALL TESTS PASSED!")


TEST RESULTS SUMMARY
  [PASS] Core Module Imports
  [SKIP] PyTorch Module Imports
  [PASS] Scattering Factors
  [PASS] Particle Creation (All Types)
  [PASS] Strain Field Generation
  [PASS] Diffraction Computation
  [PASS] Ground Truth Labels
  [SKIP] CNN Architecture
  [SKIP] Loss Function
  [PASS] Data Augmentation
  [PASS] MAD Equation Consistency
  [PASS] Training Data Generation
  [SKIP] Dataset Loading
  [PASS] Ellipse Shape Creation
  [PASS] Boundary Validation Functions
  [PASS] Composition Modes
  [PASS] Winterbottom Truncation
  [PASS] Off-Center Core
  [PASS] Size Variation with Boundary Checking
  [PASS] Backwards Compatibility
  [PASS] Training Data Generation with New Features

--------------------------------------------------------------------------------
TOTAL: 17 passed, 0 failed, 4 skipped out of 21 tests

ALL RUNNABLE TESTS PASSED! (4 skipped due to missing PyTorch)
