## Imports and initialization

In [8]:
# %%
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    evaluate_spot_metrics,
    format_metric_report,
    save_prediction_diagnostics,
)
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    save_mode_triptych,
)

print("✓ All modules imported successfully")

# %%


✓ All modules imported successfully


## Random seed and device

In [9]:
# %%
SEED = 424242

# torch.use_deterministic_algorithms(True) (below) makes cuBLAS raise a RuntimeError on the
# first CUDA matmul unless a fixed workspace configuration is set (CUDA >= 10.2). It must be
# in the environment before cuBLAS initialises, i.e. before any GPU work, so it is set here
# next to the other determinism switches. setdefault keeps a value the user already exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Seed every RNG the notebook uses (Python, NumPy, PyTorch CPU and all CUDA devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Trade speed for run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

print(f"✓ Random seed set to {SEED}")

# %%
# GPU index preferred on the original multi-GPU workstation. On machines with fewer GPUs
# fall back to cuda:0 instead of crashing with "invalid device ordinal" in get_device_name.
PREFERRED_CUDA_INDEX = 2

if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{cuda_index}')
    print(f'Using Device: {device}')
    print(f'GPU Name: {torch.cuda.get_device_name(device)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB')
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

# %%


✓ Random seed set to 424242
Using Device: CPU


## Parameters and helper functions

In [None]:
# %%
# --- Optical geometry (presumably SI metres: 1568e-9 reads as a 1568 nm wavelength) ---
z_layers = 40e-6          # spacing between consecutive diffractive layers
pixel_size = 1e-6         # sampling pitch of the simulated fields / phase masks
z_prop = 120e-6           # distance from the last layer to the output plane
wavelength = 1568e-9
z_input_to_first = 40e-6  # distance from the input field to the first layer

# --- Grid sizes (pixels) ---
field_size = 25   # side of the input mode fields (the loaded modes are 25x25)
layer_size = 110  # side of each diffractive layer; also used as the label canvas size
num_modes = 6     # leading fiber modes used. NOTE(review): the saved outputs below show 3 modes — re-run to refresh them.

# --- Detector geometry (pixels), selected by label_pattern_mode in the label cell ---
circle_focus_radius = 5        # circle-pattern radius for "circle" labels
circle_detectsize = 10         # detector window size for "circle" labels
eigenmode_focus_radius = 12.5  # focus radius used when labels are eigenmode patterns
eigenmode_detectsize = 15      # detector window size for eigenmode-pattern labels

# --- Optimisation ---
batch_size = 16
epochs = 1000

# --- Training data ---
training_dataset_mode = "eigenmode"  # "eigenmode" or "superposition"
phase_option = 4                     # 4: one pure mode per sample; 1/2/3/5: generate_complex_weights schemes

# --- Evaluation data ---
evaluation_mode = "superposition"    # "eigenmode" (reuse training set) or "superposition"
num_superposition_eval_samples = 1000
superposition_eval_seed = 20240116

# --- Label layout ---
label_pattern_mode = "circle"  # "circle" or "eigenmode" detector patterns
pred_case = 1                  # 1: one detector per mode (the only case the later cells implement)

# --- Diagnostics ---
show_detection_overlap_debug = True
detection_overlap_label_index = 0  # which training label to draw under the detector overlay
prop_slices_per_segment = 10       # propagation snapshots between consecutive layers
prop_output_slices = 10            # propagation snapshots between last layer and output plane

# Layer counts swept by the training loop (one model trained per entry).
num_layer_option = [2, 3, 4, 5, 6]

print("Parameters configured")

# %%
# Per-model accumulators filled by the layer sweep. The generator yields a fresh list for
# each name, so no two names alias the same list object.
(
    all_losses,
    all_phase_masks,
    model_metrics,
    all_amplitudes_diff,
    all_average_amplitudes_diff,
    all_amplitudes_relative_diff,
    all_complex_weights_pred,
    all_image_data_pred,
    all_cc_real,
    all_cc_imag,
    all_cc_recon_amp,
    all_cc_recon_phase,
    all_training_summaries,
) = ([] for _ in range(13))

print("✓ Storage variables initialized")

# %%
def build_mode_context(base_modes, num_modes, phase_option=None, num_weight_samples=1000):
    """Select and normalise the leading fiber modes and build the base mode weights.

    Parameters
    ----------
    base_modes : np.ndarray
        Complex mode stack indexed ``[row, col, mode]`` (modes along the last axis).
    num_modes : int
        Number of leading modes to keep.
    phase_option : int, optional
        Weight scheme. ``4`` gives one pure mode per sample (identity amplitudes, zero
        phases). ``1, 2, 3, 5`` delegate to ``generate_complex_weights``. When omitted,
        the notebook-level ``phase_option`` setting is used (the original behaviour).
    num_weight_samples : int, optional
        Number of weight vectors requested from ``generate_complex_weights``. The default
        1000 is the previously hard-coded value. Ignored for ``phase_option == 4``.

    Returns
    -------
    dict
        ``mmf_data_np``: complex array ``(num_modes, rows, cols)``. Amplitudes are min–max
        normalised jointly over all kept modes; phases are unchanged.
        ``mmf_data_ts``: the same data as a torch tensor (shares memory).
        ``base_amplitudes`` / ``base_phases``: the per-sample weights.

    Raises
    ------
    ValueError
        If ``base_modes`` has fewer than ``num_modes`` modes, or ``phase_option`` is not
        supported.
    """
    if phase_option is None:
        # The parameter shadows the notebook global of the same name, so read it explicitly.
        phase_option = globals()["phase_option"]

    if base_modes.shape[2] < num_modes:
        raise ValueError(
            f"Requested {num_modes} modes, but source file only has {base_modes.shape[2]}."
        )

    # Mode-first layout: (num_modes, rows, cols).
    mmf_data = base_modes[:, :, :num_modes].transpose(2, 0, 1)
    mmf_data_amp = np.abs(mmf_data)
    amp_min = mmf_data_amp.min()
    amp_range = mmf_data_amp.max() - amp_min
    if amp_range > 0:
        mmf_data_amp_norm = (mmf_data_amp - amp_min) / amp_range
    else:
        # Constant amplitude: min-max scaling would be 0/0 (NaN everywhere). Map a constant
        # non-zero field to 1 and an all-zero field to 0 instead.
        mmf_data_amp_norm = np.full_like(mmf_data_amp, 1.0 if amp_min > 0 else 0.0)
    # Re-attach the original phase to the normalised amplitude.
    mmf_data = mmf_data_amp_norm * np.exp(1j * np.angle(mmf_data))

    if phase_option in [1, 2, 3, 5]:
        base_amplitudes_local, base_phases_local = generate_complex_weights(
            num_weight_samples, num_modes, phase_option
        )
    elif phase_option == 4:
        # Row i excites only mode i with zero phase.
        base_amplitudes_local = np.eye(num_modes, dtype=np.float32)
        base_phases_local = np.zeros((num_modes, num_modes), dtype=np.float32)
    else:
        raise ValueError(f"Unsupported phase_option: {phase_option}")

    return {
        "mmf_data_np": mmf_data,
        "mmf_data_ts": torch.from_numpy(mmf_data),
        "base_amplitudes": base_amplitudes_local,
        "base_phases": base_phases_local,
    }


def build_uniform_fractions(count):
    """Return ``count`` evenly spaced interior fractions of the unit interval.

    The values are ``k / (count + 1)`` for ``k = 1..count``, so both endpoints 0 and 1 are
    excluded. A non-positive ``count`` yields an empty tuple. Elements are Python floats.
    """
    if count > 0:
        denom = count + 1
        evenly_spaced = np.linspace(1.0 / denom, count / denom, count, dtype=float)
        return tuple(map(float, evenly_spaced))
    return ()

print("✓ Helper functions defined")

# %%


Parameters configured
✓ Storage variables initialized
✓ Helper functions defined


## Data loading and label generation

In [11]:
# %%
print("Loading eigenmode data...")
# Complex mode fields stored under the 'modes_field' key of the .mat file.
eigenmodes_OM4 = load_complex_modes_from_mat('mmf_103modes_25_PD_1.15.mat', key='modes_field')

print(f"✓ Loaded modes shape: {eigenmodes_OM4.shape}")
print(f"  Data type: {eigenmodes_OM4.dtype}")

# %%
print(f"\nBuilding mode context (using {num_modes} modes)...")
mode_context = build_mode_context(eigenmodes_OM4, num_modes)

# Expose the context entries under the notebook-level names used by later cells.
MMF_data, MMF_data_ts, base_amplitudes, base_phases = (
    mode_context[key]
    for key in ("mmf_data_np", "mmf_data_ts", "base_amplitudes", "base_phases")
)

print("✓ Mode context built")
print(f"  MMF data shape: {MMF_data.shape}")

# %%
# Amplitude (top row) and phase (bottom row) of each normalised eigenmode.
# squeeze=False keeps `axes` 2-D even when num_modes == 1; otherwise subplots(2, 1) returns
# a 1-D array and axes[0, i] raises an IndexError.
os.makedirs('results', exist_ok=True)
fig, axes = plt.subplots(2, num_modes, figsize=(3*num_modes, 6), squeeze=False)

for i in range(num_modes):
    im_amp = axes[0, i].imshow(np.abs(MMF_data[i]), cmap='hot')
    axes[0, i].set_title(f'Mode {i+1} Amplitude')
    axes[0, i].axis('off')
    fig.colorbar(im_amp, ax=axes[0, i], fraction=0.046)

    im_phase = axes[1, i].imshow(np.angle(MMF_data[i]), cmap='hsv', vmin=-np.pi, vmax=np.pi)
    axes[1, i].set_title(f'Mode {i+1} Phase')
    axes[1, i].axis('off')
    fig.colorbar(im_phase, ax=axes[1, i], fraction=0.046)

fig.tight_layout()
fig.savefig('results/eigenmodes_visualization.png', dpi=150, bbox_inches='tight')
plt.close(fig)
print("✓ Eigenmode visualization saved")

# %%
print("\nGenerating labels...")
label_size = layer_size
focus_radius = circle_focus_radius
detectsize = circle_detectsize

# Only the one-detector-per-mode layout is implemented. Without this guard, the cells below
# fail later with a NameError on num_detector / pattern_stack / layout_radius.
if pred_case != 1:
    raise ValueError(f"Unsupported pred_case: {pred_case} (only pred_case == 1 is implemented)")

num_detector = num_modes

if label_pattern_mode == "eigenmode":
    print("  Using eigenmode patterns...")
    # Stack of mode amplitudes with the mode index last: (rows, cols, num_modes).
    pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
    pattern_h, pattern_w, _ = pattern_stack.shape
    if pattern_h > label_size or pattern_w > label_size:
        raise ValueError(
            f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
        )
    layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
    detector_focus_radius = eigenmode_focus_radius
    detector_detectsize = eigenmode_detectsize

elif label_pattern_mode == "circle":
    print("  Using circular patterns...")
    circle_radius = circle_focus_radius
    # Odd side length so the circle has a well-defined centre pixel.
    pattern_size = circle_radius * 2
    if pattern_size % 2 == 0:
        pattern_size += 1
    pattern_stack = generate_detector_patterns(
        pattern_size, pattern_size, num_detector, shape="circle"
    )
    layout_radius = circle_radius
    detector_focus_radius = circle_radius
    detector_detectsize = circle_detectsize
else:
    raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

# The detector geometry used downstream follows the chosen pattern mode.
focus_radius = detector_focus_radius
detectsize = detector_detectsize

print(f"  Layout radius: {layout_radius}")

# %%
# Only the detector centre coordinates are needed from the layout helper.
centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)

print("\nDetector centers:")
for detector_no, center in enumerate(centers, start=1):
    cx, cy = center
    print(f"  Detector {detector_no}: ({cx:.1f}, {cy:.1f})")

# %%
# One label map per detector: detector k (1-based Index) gets its pattern placed at its centre.
mode_label_maps = []
for detector_index in range(1, num_detector + 1):
    mode_label_maps.append(
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=detector_index,
            visualize=False,
        )
    )

# Stack along a trailing detector axis: (label_size, label_size, num_detector), float32.
stacked_label_maps = np.stack(mode_label_maps, axis=2).astype(np.float32)
MMF_Label_data = torch.from_numpy(stacked_label_maps)

print(f"✓ Label data generated, shape: {MMF_Label_data.shape}")

# %%
# One panel per mode's label map. The output directory is created here so this cell no longer
# depends on the eigenmode-plot cell having run first. squeeze=False keeps indexing uniform
# for any num_modes, replacing the num_modes == 1 special case.
os.makedirs('results', exist_ok=True)
fig, axes = plt.subplots(1, num_modes, figsize=(4*num_modes, 4), squeeze=False)

for i in range(num_modes):
    ax = axes[0, i]
    im = ax.imshow(mode_label_maps[i], cmap='inferno')
    ax.set_title(f'Label for Mode {i+1}')
    ax.axis('off')
    fig.colorbar(im, ax=ax, fraction=0.046)

fig.tight_layout()
fig.savefig('results/labels_visualization.png', dpi=150, bbox_inches='tight')
plt.close(fig)
print("✓ Label visualization saved")

# %%


Loading eigenmode data...
✓ Loaded modes shape: (25, 25, 103)
  Data type: complex64

Building mode context (using 3 modes)...
✓ Mode context built
  MMF data shape: (3, 25, 25)
✓ Eigenmode visualization saved

Generating labels...
  Using circular patterns...
  Layout radius: 5
相邻图案边缘间距： 行=50.00, 列=20.00
相邻图案中心间距： 行=60.00, 列=30.00
中心坐标： [(55, 25), (55, 55), (55, 85)]

Detector centers:
  Detector 1: (55.0, 25.0)
  Detector 2: (55.0, 55.0)
  Detector 3: (55.0, 85.0)
✓ Label data generated, shape: torch.Size([110, 110, 3])
✓ Label visualization saved


## Training dataset

In [12]:
# %%
print(f"\nBuilding training dataset (mode: {training_dataset_mode})...")

# Two ways to build the training set:
#   "eigenmode"     - inputs synthesised from base_amplitudes/base_phases (pure modes when phase_option == 4)
#   "superposition" - random mode superpositions from build_superposition_eval_context
if training_dataset_mode == "eigenmode":
    if phase_option == 4:
        # Identity weights: sample i excites only mode i.
        num_train_samples = num_modes
        amplitudes = base_amplitudes[:num_train_samples]
        phases = base_phases[:num_train_samples]
    else:
        amplitudes = base_amplitudes
        phases = base_phases
        num_train_samples = amplitudes.shape[0]

    print(f"  Number of training samples: {num_train_samples}")
    
    # Amplitudes followed by phases (first column dropped), with phases scaled by 1/(2*pi).
    amplitudes_phases = np.hstack((amplitudes, phases[:, 1:] / (2 * np.pi)))
    
    # Label = per-mode energy (|amplitude|^2) weighted sum of the per-detector label maps.
    label_data = torch.zeros([num_train_samples, 1, layer_size, layer_size])
    amplitude_weights = torch.from_numpy(amplitudes_phases[:, 0:num_modes]).float()
    energy_weights = amplitude_weights**2
    combined_labels = (
        energy_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels

    # Input fields: complex mode weights applied to the mode basis via generate_fields_ts.
    complex_weights = amplitudes * np.exp(1j * phases)
    complex_weights_ts = torch.from_numpy(complex_weights.astype(np.complex64))
    image_data = generate_fields_ts(
        complex_weights_ts, MMF_data_ts, num_train_samples, num_modes, field_size
    ).to(torch.complex64)

    # NOTE(review): prepare_sample presumably pads/resizes each field to layer_size — confirm in odnn_processing.
    train_dataset = [
        prepare_sample(image_data[i], label_data[i], layer_size) 
        for i in range(num_train_samples)
    ]
    # Transpose the list of per-sample tuples into per-field stacked tensors.
    train_tensor_data = TensorDataset(
        *[torch.stack(tensors) for tensors in zip(*train_dataset)]
    )
    
elif training_dataset_mode == "superposition":
    # Fixed sample count and RNG seed for the training superpositions (distinct from the
    # evaluation seed configured above).
    num_train_samples = 100
    print(f"  Number of training samples: {num_train_samples}")
    
    super_train_ctx = build_superposition_eval_context(
        num_train_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
        rng_seed=20240115,
    )
    train_dataset = super_train_ctx["dataset"]
    train_tensor_data = super_train_ctx["tensor_dataset"]
    image_data = super_train_ctx["image_data"]
    # The label tensor is the second tensor of the TensorDataset.
    label_data = train_tensor_data.tensors[1]
    amplitudes = super_train_ctx["amplitudes"]
    phases = super_train_ctx["phases"]
    amplitudes_phases = super_train_ctx["amplitudes_phases"]
else:
    raise ValueError(f"Unknown training_dataset_mode: {training_dataset_mode}")

# Default test data = training data; the test-dataset cell overwrites image_test_data in
# both of its evaluation modes.
label_test_data = label_data
image_test_data = image_data

print(f"✓ Training dataset built")
print(f"  Image data shape: {image_data.shape}")
print(f"  Label data shape: {label_data.shape}")

# %%
# Preview (up to) the first three training pairs: input amplitude on top, label below.
n_preview = min(3, num_train_samples)
if n_preview > 0:
    # Create the output directory here instead of relying on an earlier cell.
    os.makedirs('results', exist_ok=True)
    # squeeze=False keeps `axes` 2-D for any n_preview (replaces the 1-sample reshape).
    fig, axes = plt.subplots(2, n_preview, figsize=(12, 8), squeeze=False)

    for i in range(n_preview):
        axes[0, i].imshow(np.abs(image_data[i, 0].cpu().numpy()), cmap='hot')
        axes[0, i].set_title(f'Sample {i+1} Input')
        axes[0, i].axis('off')

        axes[1, i].imshow(label_data[i, 0].cpu().numpy(), cmap='inferno')
        axes[1, i].set_title(f'Sample {i+1} Label')
        axes[1, i].axis('off')

    fig.tight_layout()
    fig.savefig('results/training_samples_visualization.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("✓ Training samples visualization saved")
else:
    # plt.subplots(2, 0) would raise; report instead.
    print("⚠ No training samples to visualize")

# %%



Building training dataset (mode: eigenmode)...
  Number of training samples: 3
✓ Training dataset built
  Image data shape: torch.Size([3, 1, 25, 25])
  Label data shape: torch.Size([3, 1, 110, 110])
✓ Training samples visualization saved


## Test dataset

In [13]:
# %%
print(f"\nBuilding test dataset (mode: {evaluation_mode})...")

# Seeded generator so the shuffled batch order of the training loader is reproducible.
g = torch.Generator()
g.manual_seed(SEED)
train_loader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True, generator=g)

print(f"  Training batches: {len(train_loader)}")

superposition_eval_ctx = None

if evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")

    print(f"  Generating {num_superposition_eval_samples} superposition samples...")
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
        rng_seed=superposition_eval_seed,
    )
    superposition_eval_ctx = super_ctx
    test_dataset, test_tensor_data, test_loader = (
        super_ctx["dataset"], super_ctx["tensor_dataset"], super_ctx["loader"]
    )
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]

elif evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, in a fixed (unshuffled) order.
    print("  Using eigenmode evaluation...")
    test_dataset, test_tensor_data = train_dataset, train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes, eval_amplitudes_phases, eval_phases = amplitudes, amplitudes_phases, phases
    image_test_data = image_data

else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

print("✓ Test dataset built")
print(f"  Test samples: {len(test_dataset)}")
print(f"  Test batches: {len(test_loader)}")

# %%



Building test dataset (mode: superposition)...
  Training batches: 1
  Generating 1000 superposition samples...
✓ Test dataset built
  Test samples: 1000
  Test batches: 63


## Detector region generation

In [14]:
# %%
print("\nGenerating detection regions...")

# Only pred_case == 1 (one detector per mode) defines evaluation_regions. Fail here with a
# clear message instead of a NameError in the overlap/training cells below.
if pred_case != 1:
    raise ValueError(f"Unsupported pred_case: {pred_case} (only pred_case == 1 is implemented)")

# Each region is unpacked as (x0, x1, y0, y1) in pixel coordinates of the layer_size canvas.
evaluation_regions = create_evaluation_regions(
    layer_size, layer_size, num_detector, focus_radius, detectsize
)

print("✓ Detection regions generated:")
for detector_no, (x0, x1, y0, y1) in enumerate(evaluation_regions, start=1):
    print(f"  Detector {detector_no}: x=[{x0}, {x1}], y=[{y0}, {y1}], size={x1-x0}×{y1-y0}")

# %%
if show_detection_overlap_debug:
    print("\nChecking detector overlap...")

    detection_debug_dir = Path("results/detection_region_debug")
    detection_debug_dir.mkdir(parents=True, exist_ok=True)

    # Coverage map: each pixel counts how many detector windows contain it. Regions are
    # unpacked as (x0, x1, y0, y1) and rasterised as rows y0:y1, columns x0:x1. Bounds are
    # clipped to the canvas, so a window past an edge is truncated instead of wrapping
    # around through negative-index slicing.
    overlap_map = np.zeros((layer_size, layer_size), dtype=np.float32)
    for (x0, x1, y0, y1) in evaluation_regions:
        row_lo, row_hi = max(0, y0), min(layer_size, y1)
        col_lo, col_hi = max(0, x0), min(layer_size, x1)
        if row_hi > row_lo and col_hi > col_lo:
            overlap_map[row_lo:row_hi, col_lo:col_hi] += 1.0

    overlap_pixels = int(np.count_nonzero(overlap_map > 1.0 + 1e-6))
    max_overlap = float(overlap_map.max()) if overlap_map.size else 0.0

    # Background image: one training label (index clamped into range), if any exist.
    label_sample_np = None
    sample_idx = None
    if isinstance(label_data, torch.Tensor) and label_data.shape[0] > 0:
        sample_idx = min(max(0, detection_overlap_label_index), label_data.shape[0] - 1)
        label_sample_np = label_data[sample_idx, 0].detach().cpu().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    if label_sample_np is not None:
        im0 = axes[0].imshow(label_sample_np, cmap="inferno")
        fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)
        axes[0].set_title(f"Label sample #{sample_idx + 1} with detectors")
    else:
        axes[0].imshow(np.zeros((layer_size, layer_size), dtype=np.float32), cmap="Greys")
        axes[0].set_title("Detector layout (no label sample)")
    axes[0].set_axis_off()

    # Dashed circle of radius focus_radius at each window centre. A dedicated name is used so
    # this debug cell does not overwrite the label cell's global `circle_radius`.
    detector_circle_radius = focus_radius
    for idx_region, (x0, x1, y0, y1) in enumerate(evaluation_regions):
        color = plt.cm.tab20(idx_region % 20)

        rect = Rectangle(
            (x0, y0), x1 - x0, y1 - y0,
            linewidth=1.0, edgecolor=color, facecolor='none'
        )
        axes[0].add_patch(rect)

        center_x = (x0 + x1) / 2.0
        center_y = (y0 + y1) / 2.0
        circle = Circle(
            (center_x, center_y),
            radius=detector_circle_radius,
            linewidth=1.0,
            edgecolor=color,
            linestyle="--",
            fill=False,
        )
        axes[0].add_patch(circle)

        axes[0].text(
            x0 + 1, y0 + 4,
            f"M{idx_region + 1}",
            color=color,
            fontsize=8,
            weight="bold",
            ha="left",
            va="bottom",
            bbox=dict(boxstyle="round,pad=0.2", facecolor="black", alpha=0.4, edgecolor="none"),
        )

    im1 = axes[1].imshow(overlap_map, cmap="viridis")
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)
    axes[1].set_title("Detector coverage count (overlap map)")
    axes[1].set_axis_off()

    overlap_plot_path = detection_debug_dir / f"detection_overlap_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    fig.tight_layout()
    fig.savefig(overlap_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    if overlap_pixels > 0:
        # max_overlap = the largest number of detector windows covering any single pixel.
        print(f"⚠ Detection regions overlap: {overlap_pixels} pixels (max coverage: {max_overlap:.0f} detectors)")
    else:
        print("✔ No overlap detected")
    print(f"✔ Debug plot saved -> {overlap_plot_path}")

# %%



Generating detection regions...
✓ Detection regions generated:
  Detector 1: x=[20, 30], y=[50, 60], size=10×10
  Detector 2: x=[50, 60], y=[50, 60], size=10×10
  Detector 3: x=[80, 90], y=[50, 60], size=10×10

Checking detector overlap...


  plt.show()


✔ No overlap detected
✔ Debug plot saved -> results/detection_region_debug/detection_overlap_20260205_155543.png


## Training loop, evaluation and visualization

In [15]:
# %%
print("\n" + "="*60)
print("Starting model training loop")
print("="*60)

# Settings shared by every layer count in the sweep. They are defined once so the logged
# configuration and the checkpoint metadata cannot drift from the values actually used.
learning_rate = 1.99
lr_decay_gamma = 0.99
padding_ratio = 0.5

# Train, analyse, checkpoint and evaluate one D2NN per entry of num_layer_option.
for num_layer in num_layer_option:
    print(f"\n{'='*60}")
    print(f"Training D2NN with {num_layer} layers")
    print(f"{'='*60}\n")

    # ---- Model construction ----
    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=padding_ratio,
        z_input_to_first=z_input_to_first,
    ).to(device)

    print(D2NN)

    total_params = sum(p.numel() for p in D2NN.parameters())
    trainable_params = sum(p.numel() for p in D2NN.parameters() if p.requires_grad)
    print("\nModel Statistics:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(D2NN.parameters(), lr=learning_rate)
    # Multiplicative per-epoch decay: lr after epoch e = learning_rate * lr_decay_gamma**e.
    scheduler = ExponentialLR(optimizer, gamma=lr_decay_gamma)

    print("\nTraining Configuration:")
    print(f"  Loss: MSE, Optimizer: Adam, LR: {learning_rate}, Decay: {lr_decay_gamma}")

    # ---- Training ----
    losses = []
    epoch_durations = []
    training_start_time = time.time()

    print("\nStarting training...")
    print(f"{'Epoch':<8} {'Loss':<20} {'Time (s)':<10} {'LR':<10}")
    print("-" * 50)

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0.0

        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        scheduler.step()
        # Mean of per-batch losses (not sample-weighted if the last batch is smaller).
        avg_loss = epoch_loss / len(train_loader)
        losses.append(avg_loss)

        # CUDA work is asynchronous; synchronise so the measured wall-clock time is real.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            # LR shown is the value after this epoch's scheduler step.
            current_lr = optimizer.param_groups[0]['lr']
            print(f"{epoch:<8} {avg_loss:<20.10f} {epoch_duration:<10.2f} {current_lr:<10.6f}")

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time

    print("\n✓ Training completed")
    print(f"  Total time: {total_training_time:.2f} seconds ({total_training_time / 60:.2f} minutes)")
    print(f"  Final loss: {losses[-1]:.10f}")

    all_losses.append(losses)

    # Everything below is inference/analysis; put the model in eval mode.
    D2NN.eval()

    # ---- Training curves: plots + .mat export ----
    # (These were previously separated by column-0 "# %%" markers that sat inside this
    # indented loop body; splitting cells there raised an IndentationError.)
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs_array, losses, label="Training Loss", linewidth=2)
    ax.set_xlabel("Epoch", fontsize=12)
    ax.set_ylabel("Loss", fontsize=12)
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)", fontsize=14)
    ax.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
    ax.legend(fontsize=11)
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_m{num_modes}_ls{layer_size}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    fig_time, ax_time = plt.subplots(figsize=(10, 6))
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time", linewidth=2, color='tab:orange')
    ax_time.set_xlabel("Epoch", fontsize=12)
    ax_time.set_ylabel("Time (seconds)", fontsize=12)
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)", fontsize=14)
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)
    ax_time.legend(fontsize=11)
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_m{num_modes}_ls{layer_size}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    mat_path = training_output_dir / f"training_curves_layers{num_layer}_m{num_modes}_ls{layer_size}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print("\n✔ Saved training plots and data")

    # ---- Propagation snapshots through the trained stack ----
    print("\nCapturing propagation...")
    propagation_dir = Path("results/propagation_slices")
    # Visualise mode index 2 (third mode), or the last mode if fewer are available.
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    layer_fractions = [build_uniform_fractions(prop_slices_per_segment) for _ in range(num_layer)]
    output_fractions = build_uniform_fractions(prop_output_slices)

    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
        fractions_between_layers=layer_fractions,
        output_fractions=output_fractions,
    )

    print("✔ Saved propagation visualization")

    energies = np.asarray(propagation_summary.get("energies", []), dtype=np.float64)
    if energies.size > 0 and energies[0] != 0:
        energy_drop_pct = (energies[0] - energies[-1]) / energies[0] * 100.0
        print(f"   Energy drop: {energy_drop_pct:.2f}%")

    # ---- Per-mode triptychs (eigenmode evaluation only) ----
    mode_triptych_records = []

    if evaluation_mode == "eigenmode":
        print("\nSaving mode triptychs...")
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_m{num_modes}_{timestamp_tag}"

        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            # Target for a pure mode = that mode's own detector label map. label_data[mode_idx]
            # equals this only for eigenmode training with phase_option == 4; otherwise it is
            # a mixed-mode label that does not match the pure-mode input used here.
            label_tensor = MMF_Label_data[:, :, mode_idx].contiguous()
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
                evaluation_regions=evaluation_regions,
                detect_radius=detectsize,
                show_mask_overlays=True,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(f"✔ Saved mode {mode_idx + 1} triptych")

    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )

    # ---- Checkpoint ----
    print("\nSaving model checkpoint...")
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers": len(D2NN.layers),
            "layer_size": layer_size,
            "z_layers": z_layers,
            "z_prop": z_prop,
            "pixel_size": pixel_size,
            "wavelength": wavelength,
            "padding_ratio": padding_ratio,
            "field_size": field_size,
            "num_modes": num_modes,
            "z_input_to_first": z_input_to_first,
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers_m{num_modes}_ls{layer_size}.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model checkpoint")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Phase masks wrapped into [0, 2*pi) for visualisation.
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # ---- Evaluation on the test set ----
    print("\nEvaluating model...")
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    print("\nSaving prediction diagnostics...")
    diag_dir = Path("results/prediction_viz") / f"main_L{num_layer}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    diag_paths = save_prediction_diagnostics(
        D2NN,
        test_dataset,
        evaluation_regions=evaluation_regions,
        layer_size=layer_size,
        detect_radius=detectsize,
        num_samples=3,
        output_dir=diag_dir,
        device=device,
        tag=f"main_L{num_layer}",
    )
    if diag_paths:
        print(f"✔ Saved {len(diag_paths)} prediction diagnostics")

    # Collect per-model metrics; missing keys become empty arrays / NaN.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print("\n" + "="*60)
    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )
    print("="*60)

# %%
print("\n" + "="*60)
print("Generating performance comparison")
print("="*60 + "\n")


def _nan_mean_std(values):
    """Return ``(nanmean, nanstd)`` of ``values`` as Python floats, or ``(nan, nan)`` if empty."""
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return float("nan"), float("nan")
    return float(np.nanmean(arr)), float(np.nanstd(arr))


if model_metrics:
    metrics_dir = Path("results/metrics_analysis")
    metrics_dir.mkdir(parents=True, exist_ok=True)
    metrics_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # One row per trained model, aligned with the layer counts actually swept.
    layer_counts = np.asarray(num_layer_option[:len(model_metrics)], dtype=np.int32)
    n_rows = len(layer_counts)
    amp_err = np.asarray(all_average_amplitudes_diff[:n_rows], dtype=np.float64)
    amp_err_rel = np.asarray(all_amplitudes_relative_diff[:n_rows], dtype=np.float64)

    # Mean ± std of the reconstructed-amplitude correlation coefficients per model.
    cc_stats = [_nan_mean_std(cc_arr) for cc_arr in all_cc_recon_amp[:n_rows]]
    cc_amp_mean = np.asarray([mean for mean, _std in cc_stats], dtype=np.float64)
    cc_amp_std = np.asarray([std for _mean, std in cc_stats], dtype=np.float64)

    fig, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)
    ax_abs, ax_rel, ax_cc = axes

    ax_abs.plot(layer_counts, amp_err, marker="o", linewidth=2, markersize=8)
    ax_abs.set_ylabel("Average Amplitude Error", fontsize=12)
    ax_abs.grid(True, alpha=0.3, linestyle='--')
    ax_abs.set_title("Performance Metrics vs. Number of Layers", fontsize=14, pad=20)

    ax_rel.plot(layer_counts, amp_err_rel, marker="o", color="tab:orange", linewidth=2, markersize=8)
    ax_rel.set_ylabel("Average Relative Error", fontsize=12)
    ax_rel.grid(True, alpha=0.3, linestyle='--')

    ax_cc.errorbar(
        layer_counts,
        cc_amp_mean,
        yerr=cc_amp_std,
        marker="o",
        color="tab:green",
        ecolor="tab:green",
        capsize=5,
        linewidth=2,
        markersize=8,
    )
    ax_cc.set_xlabel("Number of Layers", fontsize=12)
    ax_cc.set_ylabel("Correlation Coefficient\n(mean ± std)", fontsize=12)
    ax_cc.grid(True, alpha=0.3, linestyle='--')

    fig.tight_layout()

    metrics_plot_path = metrics_dir / f"metrics_vs_layers_{metrics_tag}.png"
    fig.savefig(metrics_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    metrics_mat_path = metrics_dir / f"metrics_vs_layers_{metrics_tag}.mat"
    metrics_payload = {
        "layers": layer_counts.astype(np.float64),
        "avg_amp_error": amp_err,
        "avg_relative_amp_error": amp_err_rel,
        "cc_amp_mean": cc_amp_mean,
        "cc_amp_std": cc_amp_std,
    }
    savemat(str(metrics_mat_path), metrics_payload)

    print("✔ Metrics comparison saved")

    rule = "="*60
    print("\n" + rule)
    print("Performance Summary Table")
    print(rule)
    print(f"{'Layers':<8} {'Amp Error':<15} {'Rel Error':<15} {'CC Mean':<15}")
    print("-"*60)
    for row, n_layer in enumerate(layer_counts):
        print(f"{n_layer:<8} {amp_err[row]:<15.6f} {amp_err_rel[row]:<15.6f} {cc_amp_mean[row]:<15.6f}")
    print(rule)

# %%
print("\n" + "="*60)
print("Training Summary")
print("="*60 + "\n")

# One short report per trained model: layer count, wall time, and triptych count if any.
for model_no, summary in enumerate(all_training_summaries, start=1):
    seconds = summary['total_time']
    triptychs = summary['mode_triptychs']
    print(f"Model {model_no}: {summary['num_layers']} layers")
    print(f"  Time: {seconds:.2f}s ({seconds/60:.2f}min)")
    if triptychs:
        print(f"  Triptychs: {len(triptychs)} saved")
    print()

banner = "="*60
print(banner)
print("All training completed successfully!")
print(banner)

# %%
def visualize_phase_masks(all_phase_masks, num_layer_option, output_dir='results'):
    """Save one PNG per trained model showing each layer's phase mask.

    Parameters
    ----------
    all_phase_masks : list of list of 2-D arrays
        ``all_phase_masks[m][k]`` is layer ``k`` of model ``m``. The colour scale is fixed to
        [0, 2*pi], matching the ``np.remainder(..., 2*pi)`` wrapping done in the training loop.
    num_layer_option : sequence of int
        Layer-count label for each model, in the same order as ``all_phase_masks``. Only the
        first ``len(all_phase_masks)`` entries are used.
    output_dir : str or Path, optional
        Destination directory, created if missing. Defaults to ``'results'``, the previously
        hard-coded location.
    """
    out_dir = Path(output_dir)
    # Create the directory here instead of relying on an earlier cell having made it.
    out_dir.mkdir(parents=True, exist_ok=True)

    for phase_masks, layer_label in zip(all_phase_masks, num_layer_option):
        n_layers = len(phase_masks)
        if n_layers == 0:
            # plt.subplots(1, 0) would raise; nothing to draw for this model.
            continue
        # squeeze=False keeps a 2-D axes grid even for a single layer.
        fig, axes = plt.subplots(1, n_layers, figsize=(4*n_layers, 4), squeeze=False)

        for layer_idx, phase in enumerate(phase_masks):
            ax = axes[0, layer_idx]
            im = ax.imshow(phase, cmap='hsv', vmin=0, vmax=2*np.pi)
            ax.set_title(f'Layer {layer_idx+1}')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, label='Phase (rad)')

        fig.suptitle(f'Phase Masks - {layer_label} Layers', fontsize=14)
        fig.tight_layout()
        fig.savefig(out_dir / f'phase_masks_{layer_label}layers.png', dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"✓ Saved phase masks for {layer_label}-layer model")

if all_phase_masks:
    visualize_phase_masks(all_phase_masks, num_layer_option[:len(all_phase_masks)])

# %%





Starting model training loop

Training D2NN with 2 layers

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0): DiffractionLayer()
    (1): DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)

Model Statistics:
  Total parameters: 24,200
  Trainable parameters: 24,200

Training Configuration:
  Loss: MSE, Optimizer: Adam, LR: 1.99, Decay: 0.99

Starting training...
Epoch    Loss                 Time (s)   LR        
--------------------------------------------------
1        0.0066927727         0.06       1.970100  
100      0.0011214772         0.01       0.728404  
200      0.0011037089         0.02       0.266620  


KeyboardInterrupt: 