已连接到 odnn_venv (Python 3.13.5)

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    visualize_model_slices,
)

# Seed every RNG used by the notebook (Python, NumPy, PyTorch CPU and CUDA).
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN and other operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU ordinal for this run. torch.device() accepts any ordinal, so an
# ordinal that does not exist on this host would only fail later, when the
# first module or tensor is moved to it. Validate it here and fall back to
# cuda:0 instead.
PREFERRED_CUDA_INDEX = 1
if torch.cuda.is_available():
    visible_gpus = torch.cuda.device_count()
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < visible_gpus else 0
    if cuda_index != PREFERRED_CUDA_INDEX:
        print(
            f'cuda:{PREFERRED_CUDA_INDEX} is not available '
            f'({visible_gpus} GPU(s) visible); falling back to cuda:0'
        )
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:1


In [None]:
# Run configuration: geometry, dataset size, evaluation switches, per-model result
# accumulators and optical parameters. Later cells read these as kernel globals.
field_size = 25  # side length (pixels) of each mode field; the loaded modes print as 25x25 (the old "50 pixels" note was stale)
layer_size = 100 # side length (pixels) of the square layer / label canvas; other values noted: 400, 300, 100
num_data = 1000 # sample count for randomly weighted data; replaced by num_modes when phase_option == 4
num_modes = 6 # number of MMF modes used; values noted: 3, 6, 10
circle_focus_radius = 13 # radius when using uniform circular detectors
circle_detectsize = 26  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults only: the label-generation cell overwrites both with the values of the chosen pattern mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "eigenmode"  # options: "eigenmode", "circle"
# Layer counts to train; one D2NN model is trained per entry.
num_layer_option = [3]   # other layer counts (e.g. 4) can be appended here
# Per-model accumulators, appended to once per trained model.
all_losses = [] # per-epoch average loss of each model
all_phase_masks = [] # phase masks (wrapped to [0, 2*pi)) of each model
all_predictions = [] # output light field of each model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
# Optical geometry passed to D2NNModel (presumably in metres -- TODO confirm against the model code).
z_layers   = 40e-6        # 40 um (previous value: 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # last layer to camera (previous value: 16.74e-2). NOTE(review): the old note said "60 um plus 40" = 100 um, but the value is 120 um -- confirm
wavelength = 1568e-9      # 1568 nm. NOTE(review): the old note said "1550 nm" -- confirm which wavelength is intended
z_input_to_first = 40e-6  # 40 um; added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex fibre eigenmodes; the array is laid out as (H, W, M).
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Keep the first num_modes modes and move the mode axis first: (H, W, M_sel) -> (M_sel, H, W).
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over the whole stack (computed once instead of
# four times) and re-attach each pixel's original phase.
mode_amplitude = np.abs(MMF_data)
amplitude_min = np.min(mode_amplitude)
amplitude_max = np.max(mode_amplitude)
MMF_data_amp_norm = (mode_amplitude - amplitude_min) / (amplitude_max - amplitude_min)

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# Open question from the author: if option 4 becomes the only one used,
# should the other options be removed?
# phase_option 1: (0,0,...,0)
# phase_option 2: (0,2pi,...,2pi)
# phase_option 3: (0,pi,...,2pi)
# phase_option 4: eigenmodes
# phase_option 5: (0,pi,...,pi)
phase_option = 4

if phase_option in (1, 2, 3, 5):
    amplitudes, phases = generate_complex_weights(num_data, num_modes, phase_option)
elif phase_option == 4:
    # Train on the pure eigenmodes: one sample per mode, one-hot amplitudes.
    num_data = num_modes
    amplitudes = np.eye(num_modes)
    # NOTE(review): with phases = eye, each mode's own weight becomes exp(1j * 1),
    # i.e. a 1 rad offset (off-diagonal phases are multiplied by zero amplitude).
    # Confirm this is intended rather than np.zeros((num_modes, num_modes)).
    phases = np.eye(num_modes)
else:
    # Previously an unsupported option fell through and failed below with a
    # NameError on `amplitudes`; fail here with a clear message instead.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# Per-sample targets: all amplitudes followed by the phases of modes 2..M
# (raw values in *_ori, divided by 2*pi in amplitudes_phases).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:] / (2 * np.pi)))

# Complex mode weights from the amplitudes and phases, then the input fields built from them.
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test inputs are the training inputs; the evaluation-mode cell may replace them.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
# pred_case selects what the network is trained to predict:
#   pred_case = 1: only amplitudes prediction
#   pred_case = 2: only phases prediction
#   pred_case = 3: amplitudes and phases prediction
#   pred_case = 4: amplitudes and phases prediction (extra energy phase area)
pred_case = 1
label_size = layer_size
label_data = torch.zeros([num_data, 1, layer_size, layer_size])

if pred_case == 1:
    num_detector = num_modes
    # Start from the configured defaults; the selected pattern mode overrides them below.
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize

    if label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        # Bump an even pattern size to the next odd value.
        pattern_size = 2 * circle_radius
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(
            pattern_size, pattern_size, num_detector, shape="circle"
        )
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    elif label_pattern_mode == "eigenmode":
        # Mode amplitudes rearranged from (M, H, W) to (H, W, M).
        pattern_stack = np.abs(MMF_data).transpose(1, 2, 0)
        pattern_h, pattern_w, _ = pattern_stack.shape
        if max(pattern_h, pattern_w) > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)

    # One label map per detector; the helper is called with 1-based indices.
    mode_label_maps = []
    for detector_index in range(1, num_detector + 1):
        mode_label_maps.append(
            compose_labels_from_patterns(
                label_size,
                label_size,
                pattern_stack,
                centers,
                Index=detector_index,
                visualize=False,
            )
        )
    stacked_label_maps = np.stack(mode_label_maps, axis=2).astype(np.float32)
    MMF_Label_data = torch.from_numpy(stacked_label_maps)

    # Each sample's label is the amplitude-weighted sum of the per-mode maps.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, :num_modes]).float()
    weighted_maps = amplitude_weights[:, None, None, :] * MMF_Label_data[None, ...]
    combined_labels = weighted_maps.sum(dim=3)
    label_data[:, 0] = combined_labels

    # Publish the detector geometry actually used so later cells pick it up.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

label_test_data = label_data

相邻图案边缘间距： 行=16.00, 列=5.50
相邻图案中心间距： 行=42.00, 列=31.50
中心坐标： [(29, 18), (29, 50), (29, 82), (71, 18), (71, 50), (71, 82)]


In [None]:
# Build one prepared (image, label) sample per label, then stack each field into a tensor.
train_dataset = []
for sample_index in range(len(label_data)):
    train_dataset.append(
        prepare_sample(image_data[sample_index], label_data[sample_index], layer_size)
    )
stacked_fields = [torch.stack(field) for field in zip(*train_dataset)]
train_tensor_data = TensorDataset(*stacked_fields)

# Dedicated, seeded generator so the shuffle order is fixed.
g = torch.Generator()
g.manual_seed(SEED)
train_loader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True, generator=g)

if evaluation_mode not in ("eigenmode", "superposition"):
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

if evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # The helper builds the whole evaluation set and returns it as a dict.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
else:
    # "eigenmode": evaluate on the training samples themselves, in fixed order.
    test_dataset, test_tensor_data = train_dataset, train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes, eval_phases = amplitudes, phases
    eval_amplitudes_phases = amplitudes_phases
    image_test_data = image_data

# Detection regions used by the evaluation metrics (amplitude-only case).
if pred_case == 1:
    evaluation_regions = create_evaluation_regions(
        layer_size, layer_size, num_detector, focus_radius, detectsize
    )
    print("Detection Regions:", evaluation_regions)

Detection Regions: [(6, 32, 16, 42), (37, 63, 16, 42), (68, 94, 16, 42), (6, 32, 58, 84), (37, 63, 58, 84), (68, 94, 58, 84)]


In [None]:
# Train, checkpoint and evaluate one D2NN per requested layer count.
# NOTE(review): re-running this cell without re-running the seeding cell continues
# from the current RNG / generator state, so the losses differ from the first run
# (compare the recorded outputs of repeated executions).
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    # Single source for the value given to the model and stored in the checkpoint
    # metadata, so the two cannot drift apart.
    padding_ratio = 0.5

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=padding_ratio,
        z_input_to_first=z_input_to_first,
    ).to(device)

    print(D2NN)

    # Training setup: MSE between the network output and the label map,
    # Adam with an exponentially decaying learning rate.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99)
    scheduler = ExponentialLR(optimizer, gamma=0.99)
    epochs = 1000
    losses = []

    for epoch in range(epochs):
        # perf_counter is the monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        D2NN.train()
        epoch_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # average loss over the batches of this epoch
        losses.append(avg_loss)
        elapsed_time = time.perf_counter() - start_time

        if epoch % 100 == 0:
            # Report the real duration of this epoch. The previous version printed
            # elapsed_time * 100 while labelling it "seconds", inflating it 100x.
            print(f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, Time: {elapsed_time:.4f} seconds')
    all_losses.append(losses)  # keep the loss curve of this model

    # === after training ===
    # The file name depends only on the layer count, so a rerun overwrites the
    # previous checkpoint for the same depth.
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     padding_ratio,
            "field_size":        field_size,
            "num_modes":         num_modes,
            "z_input_to_first":  z_input_to_first,
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Release cached, currently unused GPU memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache the phase masks, wrapped into [0, 2*pi), for later visualization/export.
    phase_masks = [
        np.remainder(layer.phase.detach().cpu().numpy(), 2 * np.pi)
        for layer in D2NN.layers
    ]
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model.
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Append to the per-model accumulators; missing metrics become empty arrays / NaN.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [0/1000], Loss: 0.007294103503227234, Time: 8.50 seconds
Epoch [100/1000], Loss: 0.003736137878149748, Time: 1.00 seconds
Epoch [200/1000], Loss: 0.003663138253614306, Time: 1.01 seconds
Epoch [300/1000], Loss: 0.003646797034889460, Time: 1.04 seconds
Epoch [400/1000], Loss: 0.003643352771177888, Time: 1.06 seconds
Epoch [500/1000], Loss: 0.003642342519015074, Time: 1.05 seconds
Epoch [600/1000], Loss: 0.003641963470727205, Time: 1.08 seconds
Epoch [700/1000], Loss: 0.003641816787421703, Time: 1.09 seconds
Epoch [800/1000], Loss: 0.003641761373728514, Time: 1.02 seconds
Epoch [900/1000], Loss: 0.003641740186139941, Time: 1.08 seconds
✔ Saved model -> checkpoints/odnn_3layers.pth
3 layers: modes=6, phase_opt=4, pred_case=1
  amp_err=0.023463, amp_err_rel=0.057473
  snr_full

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    visualize_model_slices,
)

# Seed every RNG used by the notebook (Python, NumPy, PyTorch CPU and CUDA).
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN and other operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU ordinal for this run. torch.device() accepts any ordinal, so an
# ordinal that does not exist on this host would only fail later, when the
# first module or tensor is moved to it. Validate it here and fall back to
# cuda:0 instead.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    visible_gpus = torch.cuda.device_count()
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < visible_gpus else 0
    if cuda_index != PREFERRED_CUDA_INDEX:
        print(
            f'cuda:{PREFERRED_CUDA_INDEX} is not available '
            f'({visible_gpus} GPU(s) visible); falling back to cuda:0'
        )
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# Run configuration: geometry, dataset size, evaluation switches, per-model result
# accumulators and optical parameters. Later cells read these as kernel globals.
field_size = 25  # side length (pixels) of each mode field; the loaded modes print as 25x25 (the old "50 pixels" note was stale)
layer_size = 100 # side length (pixels) of the square layer / label canvas; other values noted: 400, 300, 100
num_data = 1000 # sample count for randomly weighted data; replaced by num_modes when phase_option == 4
num_modes = 6 # number of MMF modes used; values noted: 3, 6, 10
circle_focus_radius = 13 # radius when using uniform circular detectors
circle_detectsize = 26  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults only: the label-generation cell overwrites both with the values of the chosen pattern mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "eigenmode"  # options: "eigenmode", "circle"
# Layer counts to train; one D2NN model is trained per entry.
num_layer_option = [3]   # other layer counts (e.g. 4) can be appended here
# Per-model accumulators, appended to once per trained model.
all_losses = [] # per-epoch average loss of each model
all_phase_masks = [] # phase masks (wrapped to [0, 2*pi)) of each model
all_predictions = [] # output light field of each model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
# Optical geometry passed to D2NNModel (presumably in metres -- TODO confirm against the model code).
z_layers   = 40e-6        # 40 um (previous value: 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # last layer to camera (previous value: 16.74e-2). NOTE(review): the old note said "60 um plus 40" = 100 um, but the value is 120 um -- confirm
wavelength = 1568e-9      # 1568 nm. NOTE(review): the old note said "1550 nm" -- confirm which wavelength is intended
z_input_to_first = 40e-6  # 40 um; added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex fibre eigenmodes; the array is laid out as (H, W, M).
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Keep the first num_modes modes and move the mode axis first: (H, W, M_sel) -> (M_sel, H, W).
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over the whole stack (computed once instead of
# four times) and re-attach each pixel's original phase.
mode_amplitude = np.abs(MMF_data)
amplitude_min = np.min(mode_amplitude)
amplitude_max = np.max(mode_amplitude)
MMF_data_amp_norm = (mode_amplitude - amplitude_min) / (amplitude_max - amplitude_min)

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# Open question from the author: if option 4 becomes the only one used,
# should the other options be removed?
# phase_option 1: (0,0,...,0)
# phase_option 2: (0,2pi,...,2pi)
# phase_option 3: (0,pi,...,2pi)
# phase_option 4: eigenmodes
# phase_option 5: (0,pi,...,pi)
phase_option = 4

if phase_option in (1, 2, 3, 5):
    amplitudes, phases = generate_complex_weights(num_data, num_modes, phase_option)
elif phase_option == 4:
    # Train on the pure eigenmodes: one sample per mode, one-hot amplitudes.
    num_data = num_modes
    amplitudes = np.eye(num_modes)
    # NOTE(review): with phases = eye, each mode's own weight becomes exp(1j * 1),
    # i.e. a 1 rad offset (off-diagonal phases are multiplied by zero amplitude).
    # Confirm this is intended rather than np.zeros((num_modes, num_modes)).
    phases = np.eye(num_modes)
else:
    # Previously an unsupported option fell through and failed below with a
    # NameError on `amplitudes`; fail here with a clear message instead.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# Per-sample targets: all amplitudes followed by the phases of modes 2..M
# (raw values in *_ori, divided by 2*pi in amplitudes_phases).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:] / (2 * np.pi)))

# Complex mode weights from the amplitudes and phases, then the input fields built from them.
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test inputs are the training inputs; the evaluation-mode cell may replace them.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
# pred_case selects what the network is trained to predict:
#   pred_case = 1: only amplitudes prediction
#   pred_case = 2: only phases prediction
#   pred_case = 3: amplitudes and phases prediction
#   pred_case = 4: amplitudes and phases prediction (extra energy phase area)
pred_case = 1
label_size = layer_size
label_data = torch.zeros([num_data, 1, layer_size, layer_size])

if pred_case == 1:
    num_detector = num_modes
    # Start from the configured defaults; the selected pattern mode overrides them below.
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize

    if label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        # Bump an even pattern size to the next odd value.
        pattern_size = 2 * circle_radius
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(
            pattern_size, pattern_size, num_detector, shape="circle"
        )
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    elif label_pattern_mode == "eigenmode":
        # Mode amplitudes rearranged from (M, H, W) to (H, W, M).
        pattern_stack = np.abs(MMF_data).transpose(1, 2, 0)
        pattern_h, pattern_w, _ = pattern_stack.shape
        if max(pattern_h, pattern_w) > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)

    # One label map per detector; the helper is called with 1-based indices.
    mode_label_maps = []
    for detector_index in range(1, num_detector + 1):
        mode_label_maps.append(
            compose_labels_from_patterns(
                label_size,
                label_size,
                pattern_stack,
                centers,
                Index=detector_index,
                visualize=False,
            )
        )
    stacked_label_maps = np.stack(mode_label_maps, axis=2).astype(np.float32)
    MMF_Label_data = torch.from_numpy(stacked_label_maps)

    # Each sample's label is the amplitude-weighted sum of the per-mode maps.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, :num_modes]).float()
    weighted_maps = amplitude_weights[:, None, None, :] * MMF_Label_data[None, ...]
    combined_labels = weighted_maps.sum(dim=3)
    label_data[:, 0] = combined_labels

    # Publish the detector geometry actually used so later cells pick it up.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

label_test_data = label_data

相邻图案边缘间距： 行=16.00, 列=5.50
相邻图案中心间距： 行=42.00, 列=31.50
中心坐标： [(29, 18), (29, 50), (29, 82), (71, 18), (71, 50), (71, 82)]


In [None]:
# Build one prepared (image, label) sample per label, then stack each field into a tensor.
train_dataset = []
for sample_index in range(len(label_data)):
    train_dataset.append(
        prepare_sample(image_data[sample_index], label_data[sample_index], layer_size)
    )
stacked_fields = [torch.stack(field) for field in zip(*train_dataset)]
train_tensor_data = TensorDataset(*stacked_fields)

# Dedicated, seeded generator so the shuffle order is fixed.
g = torch.Generator()
g.manual_seed(SEED)
train_loader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True, generator=g)

if evaluation_mode not in ("eigenmode", "superposition"):
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

if evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # The helper builds the whole evaluation set and returns it as a dict.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
else:
    # "eigenmode": evaluate on the training samples themselves, in fixed order.
    test_dataset, test_tensor_data = train_dataset, train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes, eval_phases = amplitudes, phases
    eval_amplitudes_phases = amplitudes_phases
    image_test_data = image_data

# Detection regions used by the evaluation metrics (amplitude-only case).
if pred_case == 1:
    evaluation_regions = create_evaluation_regions(
        layer_size, layer_size, num_detector, focus_radius, detectsize
    )
    print("Detection Regions:", evaluation_regions)

Detection Regions: [(6, 32, 16, 42), (37, 63, 16, 42), (68, 94, 16, 42), (6, 32, 58, 84), (37, 63, 58, 84), (68, 94, 58, 84)]


In [None]:
# Train, checkpoint and evaluate one D2NN per requested layer count.
# NOTE(review): re-running this cell without re-running the seeding cell continues
# from the current RNG / generator state, so the losses differ from the first run
# (compare the recorded outputs of repeated executions).
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    # Single source for the value given to the model and stored in the checkpoint
    # metadata, so the two cannot drift apart.
    padding_ratio = 0.5

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=padding_ratio,
        z_input_to_first=z_input_to_first,
    ).to(device)

    print(D2NN)

    # Training setup: MSE between the network output and the label map,
    # Adam with an exponentially decaying learning rate.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99)
    scheduler = ExponentialLR(optimizer, gamma=0.99)
    epochs = 1000
    losses = []

    for epoch in range(epochs):
        # perf_counter is the monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        D2NN.train()
        epoch_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # average loss over the batches of this epoch
        losses.append(avg_loss)
        elapsed_time = time.perf_counter() - start_time

        if epoch % 100 == 0:
            # Report the real duration of this epoch. The previous version printed
            # elapsed_time * 100 while labelling it "seconds", inflating it 100x.
            print(f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, Time: {elapsed_time:.4f} seconds')
    all_losses.append(losses)  # keep the loss curve of this model

    # === after training ===
    # The file name depends only on the layer count, so a rerun overwrites the
    # previous checkpoint for the same depth.
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     padding_ratio,
            "field_size":        field_size,
            "num_modes":         num_modes,
            "z_input_to_first":  z_input_to_first,
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Release cached, currently unused GPU memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache the phase masks, wrapped into [0, 2*pi), for later visualization/export.
    phase_masks = [
        np.remainder(layer.phase.detach().cpu().numpy(), 2 * np.pi)
        for layer in D2NN.layers
    ]
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model.
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Append to the per-model accumulators; missing metrics become empty arrays / NaN.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [0/1000], Loss: 0.007294103503227234, Time: 3.97 seconds
Epoch [100/1000], Loss: 0.003736137878149748, Time: 1.03 seconds
Epoch [200/1000], Loss: 0.003663138253614306, Time: 1.12 seconds
Epoch [300/1000], Loss: 0.003646797034889460, Time: 1.11 seconds
Epoch [400/1000], Loss: 0.003643352771177888, Time: 1.05 seconds
Epoch [500/1000], Loss: 0.003642342519015074, Time: 1.10 seconds
Epoch [600/1000], Loss: 0.003641963470727205, Time: 1.00 seconds
Epoch [700/1000], Loss: 0.003641816787421703, Time: 1.01 seconds
Epoch [800/1000], Loss: 0.003641761373728514, Time: 1.21 seconds
Epoch [900/1000], Loss: 0.003641740186139941, Time: 1.12 seconds
✔ Saved model -> checkpoints/odnn_3layers.pth
3 layers: modes=6, phase_opt=4, pred_case=1
  amp_err=0.023463, amp_err_rel=0.057473
  snr_full

In [None]:
# Train, checkpoint and evaluate one D2NN per requested layer count.
# NOTE(review): re-running this cell without re-running the seeding cell continues
# from the current RNG / generator state, so the losses differ from the first run
# (compare the recorded outputs of repeated executions).
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    # Single source for the value given to the model and stored in the checkpoint
    # metadata, so the two cannot drift apart.
    padding_ratio = 0.5

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=padding_ratio,
        z_input_to_first=z_input_to_first,
    ).to(device)

    print(D2NN)

    # Training setup: MSE between the network output and the label map,
    # Adam with an exponentially decaying learning rate.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99)
    scheduler = ExponentialLR(optimizer, gamma=0.99)
    epochs = 1000
    losses = []

    for epoch in range(epochs):
        # perf_counter is the monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        D2NN.train()
        epoch_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # average loss over the batches of this epoch
        losses.append(avg_loss)
        elapsed_time = time.perf_counter() - start_time

        if epoch % 100 == 0:
            # Report the real duration of this epoch. The previous version printed
            # elapsed_time * 100 while labelling it "seconds", inflating it 100x.
            print(f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, Time: {elapsed_time:.4f} seconds')
    all_losses.append(losses)  # keep the loss curve of this model

    # === after training ===
    # The file name depends only on the layer count, so a rerun overwrites the
    # previous checkpoint for the same depth.
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     padding_ratio,
            "field_size":        field_size,
            "num_modes":         num_modes,
            "z_input_to_first":  z_input_to_first,
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Release cached, currently unused GPU memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache the phase masks, wrapped into [0, 2*pi), for later visualization/export.
    phase_masks = [
        np.remainder(layer.phase.detach().cpu().numpy(), 2 * np.pi)
        for layer in D2NN.layers
    ]
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model.
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Append to the per-model accumulators; missing metrics become empty arrays / NaN.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [0/1000], Loss: 0.007294465322047472, Time: 1.13 seconds
Epoch [100/1000], Loss: 0.003753445576876402, Time: 1.05 seconds
Epoch [200/1000], Loss: 0.003712210571393371, Time: 1.02 seconds
Epoch [300/1000], Loss: 0.003702156711369753, Time: 1.07 seconds
Epoch [400/1000], Loss: 0.003698547370731831, Time: 1.19 seconds
Epoch [500/1000], Loss: 0.003697032574564219, Time: 1.02 seconds
Epoch [600/1000], Loss: 0.003696394385769963, Time: 1.04 seconds
Epoch [700/1000], Loss: 0.003696147352457047, Time: 1.06 seconds
Epoch [800/1000], Loss: 0.003696054453030229, Time: 1.18 seconds
Epoch [900/1000], Loss: 0.003696019528433681, Time: 1.03 seconds
✔ Saved model -> checkpoints/odnn_3layers.pth
3 layers: modes=6, phase_opt=4, pred_case=1
  amp_err=0.022938, amp_err_rel=0.056186
  snr_full

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    visualize_model_slices,
)

SEED = 424242
# Seed every RNG used by the notebook (Python, NumPy, PyTorch CPU and CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN / operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU for this notebook. The index is validated against the number of
# visible GPUs so the notebook still runs on hosts with fewer devices instead
# of failing later when the model is moved to a non-existent device.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    if PREFERRED_CUDA_INDEX < torch.cuda.device_count():
        cuda_index = PREFERRED_CUDA_INDEX
    else:
        cuda_index = 0
        print(f'cuda:{PREFERRED_CUDA_INDEX} is not available; falling back to cuda:{cuda_index}')
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# Global configuration shared by every later cell of the notebook.

# --- Geometry and dataset size ---
field_size = 25  # input field size in pixels; NOTE(review): an older note said the eigenmodes_OM4 fields are 50 pixels, but the file loaded in this notebook prints shape (25, 25, 6)
layer_size = 100 # size passed to D2NNModel and used for the label canvas; alternatives noted previously: 400, 300
num_data = 1000 # number of samples for randomly weighted data; overwritten with num_modes when phase_option == 4 (eigenmode training)
num_modes = 6 # number of MMF modes (3, 6 or 10 were noted as options)
circle_focus_radius = 13 # radius when using uniform circular detectors
circle_detectsize = 26  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults only: the label-construction cell overwrites both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "eigenmode"  # options: "eigenmode", "circle"
# One model is trained per entry of this list (entry = number of diffractive layers).
num_layer_option = [3]   # other layer counts (e.g. 4) can be appended to compare models
# Per-model result accumulators; the training loop appends one entry per trained model.
all_losses = [] # per-epoch loss history of each model
all_phase_masks = [] # phase masks of each model
all_predictions = [] # output light field of each model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# Optical geometry (section originally labelled "SLM"); presumably all lengths are in metres — TODO confirm.
z_layers   = 40e-6        # 40 um (previous value: 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # 120 um, last layer to camera (previous value: 16.74e-2); NOTE(review): the old note said "60 um plus 40", i.e. 100 um — confirm which is intended
wavelength = 1568e-9      # 1568 nm; NOTE(review): the old note mentioned 1550 nm, which does not match this value — confirm
z_input_to_first = 40e-6  # 40 um; newly added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex MMF eigenmode fields from disk.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# Indexed below as (H, W, M): the mode index is the last axis.
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Fail early: slicing past the available modes would silently return fewer
# modes and only break later with a shape mismatch.
if eigenmodes_OM4.shape[2] < num_modes:
    raise ValueError(
        f"Requested num_modes={num_modes} but the file only provides {eigenmodes_OM4.shape[2]} modes."
    )

# Keep the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W).
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)
# Min-max normalise the amplitude over all selected modes together (one global
# min/max, not per mode); the phase of every pixel is kept unchanged.
MMF_data_amp_norm = (np.abs(MMF_data) - np.min(np.abs(MMF_data))) / (np.max(np.abs(MMF_data)) - np.min(np.abs(MMF_data)))

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# If option 4 becomes the permanent choice, consider removing the other options.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in [1, 2, 3, 5]:
    amplitudes, phases = generate_complex_weights(num_data, num_modes, phase_option)
elif phase_option == 4:
    num_data = num_modes  # train on the eigenmodes themselves: one sample per mode
    amplitudes = np.eye(num_modes)  # one-hot rows, e.g. [[1,0,0],[0,1,0],[0,0,1]]
    # NOTE(review): the identity matrix gives each active mode a phase of 1 rad
    # (off-diagonal phases multiply zero amplitudes); confirm zeros were not intended.
    phases = np.eye(num_modes)
else:
    # Previously an unknown option fell through and failed later with a NameError.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# Per-sample target vectors: all amplitudes followed by the phases of modes 2..M
# (the first phase column is dropped).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # phases left as given
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # phases divided by 2*pi

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test set is the training set; the dataset cell may replace it.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
# Prediction target selector:
#   pred_case = 1: only amplitudes prediction
#   pred_case = 2: only phases prediction
#   pred_case = 3: amplitudes and phases prediction
#   pred_case = 4: amplitudes and phases prediction (extra energy phase area)
pred_case = 1
label_data = torch.zeros([num_data, 1, layer_size, layer_size])
label_size = layer_size

if pred_case == 1:
    # One detector pattern per mode.
    num_detector = num_modes
    # Start from the configured defaults; each pattern mode replaces them below.
    detector_focus_radius, detector_detectsize = focus_radius, detectsize

    if label_pattern_mode == "circle":
        # Uniform circular detectors on an odd-sized square patch.
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(
            pattern_size, pattern_size, num_detector, shape="circle"
        )
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    elif label_pattern_mode == "eigenmode":
        # Use the amplitude of each eigenmode as its own detector pattern: (H, W, M).
        pattern_stack = np.abs(MMF_data).transpose(1, 2, 0)
        pattern_h, pattern_w, _ = pattern_stack.shape
        longest_side = max(pattern_h, pattern_w)
        if longest_side > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(longest_side / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Place the patterns on the label canvas; only the centers are needed here.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)

    # One label map per detector (Index is 1-based).
    mode_label_maps = []
    for detector_number in range(1, num_detector + 1):
        mode_label_maps.append(
            compose_labels_from_patterns(
                label_size,
                label_size,
                pattern_stack,
                centers,
                Index=detector_number,
                visualize=False,
            )
        )
    label_stack_np = np.stack(mode_label_maps, axis=2).astype(np.float32)
    MMF_Label_data = torch.from_numpy(label_stack_np)

    # Label of a sample = amplitude-weighted sum of the per-mode label maps.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, :num_modes]).float()
    weighted_maps = amplitude_weights[:, None, None, :] * MMF_Label_data[None]
    combined_labels = weighted_maps.sum(dim=-1)
    label_data[:, 0] = combined_labels

    # Publish the detector geometry that matches the chosen pattern mode.
    focus_radius, detectsize = detector_focus_radius, detector_detectsize

label_test_data = label_data

相邻图案边缘间距： 行=16.00, 列=5.50
相邻图案中心间距： 行=42.00, 列=31.50
中心坐标： [(29, 18), (29, 50), (29, 82), (71, 18), (71, 50), (71, 82)]


In [None]:
# Pair every input field with its label map, one prepared sample per entry.
train_dataset = []
for sample_idx in range(len(label_data)):
    train_dataset.append(
        prepare_sample(image_data[sample_idx], label_data[sample_idx], layer_size)
    )

# Stack each component of the samples into one tensor per component.
stacked_columns = [torch.stack(column) for column in zip(*train_dataset)]
train_tensor_data = TensorDataset(*stacked_columns)

# Dedicated seeded generator so the shuffle order is fixed across runs.
g = torch.Generator()
g.manual_seed(SEED)
train_loader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True, generator=g)

# Select what the trained models are evaluated on.
if evaluation_mode not in ("eigenmode", "superposition"):
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

if evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, in a fixed order.
    test_dataset, test_tensor_data = train_dataset, train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes, eval_amplitudes_phases, eval_phases = amplitudes, amplitudes_phases, phases
    image_test_data = image_data
else:
    # Evaluate on a separately built superposition context.
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]

# Generate detection regions using existing function
if pred_case == 1:
    evaluation_regions = create_evaluation_regions(
        layer_size, layer_size, num_detector, focus_radius, detectsize
    )
    print("Detection Regions:", evaluation_regions)

Detection Regions: [(6, 32, 16, 42), (37, 63, 16, 42), (68, 94, 16, 42), (6, 32, 58, 84), (37, 63, 58, 84), (68, 94, 58, 84)]


In [None]:
# Train, log, checkpoint and evaluate one D2NN per entry of num_layer_option.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    # Single source of truth for the padding ratio so the value stored in the
    # checkpoint metadata can never drift from the one used to build the model.
    padding_ratio = 0.5

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=padding_ratio,
        z_input_to_first=z_input_to_first,
    ).to(device)

    print(D2NN)

    # --- Training setup ---
    criterion = nn.MSELoss()  # MSE between the model output and the label map
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99)
    scheduler = ExponentialLR(optimizer, gamma=0.99)  # learning rate decays by 1% per epoch
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # mean batch loss of this epoch
        losses.append(avg_loss)
        # CUDA kernels run asynchronously; wait for queued work (e.g. the last
        # optimizer step) so the measured epoch time covers it.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Report the first epoch, every 100th epoch and the last epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    # Same reasoning as above: flush pending GPU work before reading the clock.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # keep the loss history of this model

    # --- Persist training curves (PNG plots + .mat data) ---
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # Timestamp keeps the artefacts of different runs from overwriting each other.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)  # release the figure; one is created per model

    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
        }
    )

    # --- Checkpoint: weights plus the settings used to build this model ---
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     padding_ratio,
            "field_size":        field_size,
            "num_modes":         num_modes,
            "z_input_to_first":  z_input_to_first,
        }
    }
    # The file name carries no timestamp, so a later run with the same layer
    # count overwrites this checkpoint.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export, wrapped into [0, 2*pi).
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model.
    # NOTE(review): detect_radius receives detectsize, which the configuration
    # comments describe as a square window size — confirm this is intended.
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Store the metrics; missing entries fall back to an empty array / NaN so
    # every accumulator keeps one entry per trained model.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [1/1000], Loss: 0.007294103503227234, Epoch Time: 0.01 seconds
Epoch [100/1000], Loss: 0.003737372811883688, Epoch Time: 0.01 seconds
Epoch [200/1000], Loss: 0.003663503332063556, Epoch Time: 0.01 seconds
Epoch [300/1000], Loss: 0.003646859433501959, Epoch Time: 0.01 seconds
Epoch [400/1000], Loss: 0.003643369767814875, Epoch Time: 0.01 seconds
Epoch [500/1000], Loss: 0.003642348339781165, Epoch Time: 0.01 seconds
Epoch [600/1000], Loss: 0.003641965799033642, Epoch Time: 0.01 seconds
Epoch [700/1000], Loss: 0.003641817718744278, Epoch Time: 0.01 seconds
Epoch [800/1000], Loss: 0.003641761373728514, Epoch Time: 0.01 seconds
Epoch [900/1000], Loss: 0.003641740418970585, Epoch Time: 0.01 seconds
Epoch [1000/1000], Loss: 0.003641732502728701, Epoch Time: 0.01 seconds
Total tra

In [None]:
# Recap of every trained model: wall-clock time and where its artefacts were written.
if all_training_summaries:
    print("\nTraining duration summary:")
for summary in all_training_summaries:
    minutes = summary["total_time"] / 60
    duration_line = (
        f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
        f"(~{minutes:.2f} min)"
    )
    print(duration_line)
    print(f"   Loss curve: {summary['loss_plot']}")
    print(f"   Time curve: {summary['time_plot']}")
    print(f"   Data (.mat): {summary['mat_path']}")


Training duration summary:
 - 3 layers: 10.53 s (~0.18 min)
   Loss curve: results/training_analysis/loss_curve_layers3_20251103_155057.png
   Time curve: results/training_analysis/epoch_time_layers3_20251103_155057.png
   Data (.mat): results/training_analysis/training_curves_layers3_20251103_155057.mat


In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    visualize_model_slices,
)

SEED = 424242
# Seed every RNG used by the notebook (Python, NumPy, PyTorch CPU and CUDA).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN / operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU for this notebook. The index is validated against the number of
# visible GPUs so the notebook still runs on hosts with fewer devices instead
# of failing later when the model is moved to a non-existent device.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    if PREFERRED_CUDA_INDEX < torch.cuda.device_count():
        cuda_index = PREFERRED_CUDA_INDEX
    else:
        cuda_index = 0
        print(f'cuda:{PREFERRED_CUDA_INDEX} is not available; falling back to cuda:{cuda_index}')
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

ImportError: cannot import name 'capture_eigenmode_propagation' from 'odnn_training_visualization' (/home/ydzhang/Desktop/odnn_code/odnn_training_visualization.py)

In [None]:
# Global configuration shared by every later cell of the notebook.

# --- Geometry and dataset size ---
field_size = 25  # input field size in pixels; NOTE(review): an older note said the eigenmodes_OM4 fields are 50 pixels — confirm against the shape printed when the modes file is loaded
layer_size = 100 # size passed to D2NNModel and used for the label canvas; alternatives noted previously: 400, 300
num_data = 1000 # number of samples for randomly weighted data; overwritten with num_modes when phase_option == 4 (eigenmode training)
num_modes = 6 # number of MMF modes (3, 6 or 10 were noted as options)
circle_focus_radius = 13 # radius when using uniform circular detectors
circle_detectsize = 26  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults only: the label-construction cell overwrites both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "eigenmode"  # options: "eigenmode", "circle"
# One model is trained per entry of this list (entry = number of diffractive layers).
num_layer_option = [3]   # other layer counts (e.g. 4) can be appended to compare models
# Per-model result accumulators; the training loop appends one entry per trained model.
all_losses = [] # per-epoch loss history of each model
all_phase_masks = [] # phase masks of each model
all_predictions = [] # output light field of each model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# Optical geometry (section originally labelled "SLM"); presumably all lengths are in metres — TODO confirm.
z_layers   = 40e-6        # 40 um (previous value: 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # 120 um, last layer to camera (previous value: 16.74e-2); NOTE(review): the old note said "60 um plus 40", i.e. 100 um — confirm which is intended
wavelength = 1568e-9      # 1568 nm; NOTE(review): the old note mentioned 1550 nm, which does not match this value — confirm
z_input_to_first = 40e-6  # 40 um; newly added: propagation distance from the input plane to the first layer

: 

In [None]:
# Load the complex MMF eigenmode fields from disk.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# Indexed below as (H, W, M): the mode index is the last axis.
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Fail early: slicing past the available modes would silently return fewer
# modes and only break later with a shape mismatch.
if eigenmodes_OM4.shape[2] < num_modes:
    raise ValueError(
        f"Requested num_modes={num_modes} but the file only provides {eigenmodes_OM4.shape[2]} modes."
    )

# Keep the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W).
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)
# Min-max normalise the amplitude over all selected modes together (one global
# min/max, not per mode); the phase of every pixel is kept unchanged.
MMF_data_amp_norm = (np.abs(MMF_data) - np.min(np.abs(MMF_data))) / (np.max(np.abs(MMF_data)) - np.min(np.abs(MMF_data)))

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# If option 4 becomes the permanent choice, consider removing the other options.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in [1, 2, 3, 5]:
    amplitudes, phases = generate_complex_weights(num_data, num_modes, phase_option)
elif phase_option == 4:
    num_data = num_modes  # train on the eigenmodes themselves: one sample per mode
    amplitudes = np.eye(num_modes)  # one-hot rows, e.g. [[1,0,0],[0,1,0],[0,0,1]]
    # NOTE(review): the identity matrix gives each active mode a phase of 1 rad
    # (off-diagonal phases multiply zero amplitudes); confirm zeros were not intended.
    phases = np.eye(num_modes)
else:
    # Previously an unknown option fell through and failed later with a NameError.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# Per-sample target vectors: all amplitudes followed by the phases of modes 2..M
# (the first phase column is dropped).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # phases left as given
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # phases divided by 2*pi

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test set is the training set; the dataset cell may replace it.
image_test_data = image_data

: 

In [None]:
# Prediction target selector:
#   pred_case = 1: only amplitudes prediction
#   pred_case = 2: only phases prediction
#   pred_case = 3: amplitudes and phases prediction
#   pred_case = 4: amplitudes and phases prediction (extra energy phase area)
pred_case = 1
label_data = torch.zeros([num_data, 1, layer_size, layer_size])
label_size = layer_size

if pred_case == 1:
    # One detector pattern per mode.
    num_detector = num_modes
    # Start from the configured defaults; each pattern mode replaces them below.
    detector_focus_radius, detector_detectsize = focus_radius, detectsize

    if label_pattern_mode == "circle":
        # Uniform circular detectors on an odd-sized square patch.
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(
            pattern_size, pattern_size, num_detector, shape="circle"
        )
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    elif label_pattern_mode == "eigenmode":
        # Use the amplitude of each eigenmode as its own detector pattern: (H, W, M).
        pattern_stack = np.abs(MMF_data).transpose(1, 2, 0)
        pattern_h, pattern_w, _ = pattern_stack.shape
        longest_side = max(pattern_h, pattern_w)
        if longest_side > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(longest_side / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Place the patterns on the label canvas; only the centers are needed here.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)

    # One label map per detector (Index is 1-based).
    mode_label_maps = []
    for detector_number in range(1, num_detector + 1):
        mode_label_maps.append(
            compose_labels_from_patterns(
                label_size,
                label_size,
                pattern_stack,
                centers,
                Index=detector_number,
                visualize=False,
            )
        )
    label_stack_np = np.stack(mode_label_maps, axis=2).astype(np.float32)
    MMF_Label_data = torch.from_numpy(label_stack_np)

    # Label of a sample = amplitude-weighted sum of the per-mode label maps.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, :num_modes]).float()
    weighted_maps = amplitude_weights[:, None, None, :] * MMF_Label_data[None]
    combined_labels = weighted_maps.sum(dim=-1)
    label_data[:, 0] = combined_labels

    # Publish the detector geometry that matches the chosen pattern mode.
    focus_radius, detectsize = detector_focus_radius, detector_detectsize

label_test_data = label_data

: 

In [None]:
# Pair every input field with its label map, one prepared sample per entry.
train_dataset = []
for sample_idx in range(len(label_data)):
    train_dataset.append(
        prepare_sample(image_data[sample_idx], label_data[sample_idx], layer_size)
    )

# Stack each component of the samples into one tensor per component.
stacked_columns = [torch.stack(column) for column in zip(*train_dataset)]
train_tensor_data = TensorDataset(*stacked_columns)

# Dedicated seeded generator so the shuffle order is fixed across runs.
g = torch.Generator()
g.manual_seed(SEED)
train_loader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True, generator=g)

# Select what the trained models are evaluated on.
if evaluation_mode not in ("eigenmode", "superposition"):
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

if evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, in a fixed order.
    test_dataset, test_tensor_data = train_dataset, train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes, eval_amplitudes_phases, eval_phases = amplitudes, amplitudes_phases, phases
    image_test_data = image_data
else:
    # Evaluate on a separately built superposition context.
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]

# Generate detection regions using existing function
if pred_case == 1:
    evaluation_regions = create_evaluation_regions(
        layer_size, layer_size, num_detector, focus_radius, detectsize
    )
    print("Detection Regions:", evaluation_regions)

: 

In [None]:
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # NEW
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # Define loss function (对比的是loss)
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    scheduler = ExponentialLR(optimizer, gamma=0.99)  
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # Calculate average loss for the epoch
        losses.append(avg_loss)  # the loss for each model
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
        }
    )
   
    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )

: 

In [None]:
# Print, for every trained model, its wall-clock training time and the artefact
# paths that the training cell recorded in all_training_summaries.
if all_training_summaries:
    print("\nTraining duration summary:")
    for summary in all_training_summaries:
        minutes = summary["total_time"] / 60
        print(
            f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
            f"(~{minutes:.2f} min)"
        )
        print(f"   Loss curve: {summary['loss_plot']}")
        print(f"   Time curve: {summary['time_plot']}")
        print(f"   Data (.mat): {summary['mat_path']}")
        print(f"   Propagation plot: {summary['propagation_fig']}")
        print(f"   Propagation data (.mat): {summary['propagation_mat']}")

# One amplitude-comparison grid per trained depth; the result lists are indexed
# in the same order as num_layer_option.
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
for idx, num_layer in enumerate(num_layer_option):
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=os.path.join(save_dir, f"Amp_{num_layer}layers.png"),
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# Visual check of how the network output differs from the label.
# The requested indices are fixed, but the evaluation set can be smaller than 6
# samples (in eigenmode training there is one sample per mode), so indices that
# do not exist are skipped instead of raising part-way through the loop.
requested_sample_indices = [0, 1, 2, 5]
num_eval_samples = len(test_dataset)
skipped_sample_indices = [i for i in requested_sample_indices if i >= num_eval_samples]
if skipped_sample_indices:
    print(
        f"Skipping sample indices {skipped_sample_indices}: "
        f"only {num_eval_samples} evaluation samples are available"
    )
for s in requested_sample_indices:
    if s >= num_eval_samples:
        continue
    # NOTE(review): D2NN is whichever model the training loop built last, while
    # the reconstruction plot below reads model_idx=0; these coincide only when
    # a single depth is trained - confirm if several depths are used.
    plot_sys_vs_label_strict(
        D2NN,
        test_dataset,
        sample_idx=s,
        evaluation_regions=evaluation_regions,
        detect_radius=detectsize,
        save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
        device=device,
        use_big_canvas=False,
        sys_scale="bg_pct",
        sys_pct=99.5,
        clip_pct=99.5,
        mask_roi_for_scale=True,
        show_signed=True,
    )
    plot_reconstruction_vs_input(
        image_test_data=image_test_data,
        reconstructed_fields=all_image_data_pred,
        sample_idx=s,
        model_idx=0,
        save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
    )

#

: 

In [None]:
# The slice visualisation below draws its single input field from the evaluation dataset.
temp_dataset = test_dataset
# Index of the sample to visualise; it is wrapped modulo the dataset length where it is used.
FIXED_E_INDEX = 4

def get_fixed_input(dataset, idx, device):
    """Fetch the input field of sample ``idx`` and move it to ``device``.

    ``dataset`` is either a plain list of per-sample tuples (the field is the
    first element of each tuple) or an object exposing ``.tensors`` whose first
    entry holds all fields stacked along dimension 0. A leading dimension of
    size 1 is squeezed away before the field is returned.
    """
    field = dataset[idx][0] if isinstance(dataset, list) else dataset.tensors[0][idx]
    return field.squeeze(0).to(device)


# Pick one fixed evaluation sample and, for every cached set of phase masks,
# export propagation slices, a light-weight .mat bundle and per-layer mask files.
num_available_samples = len(temp_dataset)
assert num_available_samples > 0, "test_dataset 为空"
temp_E = get_fixed_input(temp_dataset, FIXED_E_INDEX % num_available_samples, device)

# Scan settings: z_start and z_step are only recorded in the metadata here and
# z_step is also forwarded to the slice visualiser.
z_start, z_step = 0.0, 5e-6
z_prop_plus = z_prop

save_root = Path("results_MD")
save_root.mkdir(parents=True, exist_ok=True)
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename_prefix = f"ODNN_vis_{run_stamp}"

for i_model, phase_masks in enumerate(all_phase_masks, start=1):
    model_tag = f"m{i_model}"
    model_dir = save_root / model_tag

    # NOTE(review): D2NN is the model left over from the last training
    # iteration, while phase_masks changes per entry - confirm that
    # visualize_model_slices takes the masks from its second argument.
    slice_kwargs = dict(
        output_dir=model_dir,
        sample_tag=model_tag,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop_plus=z_prop_plus,
        z_step=z_step,
        pixel_size=pixel_size,
        wavelength=wavelength,
    )
    scans, camera_field = visualize_model_slices(D2NN, phase_masks, temp_E, **slice_kwargs)

    masks_as_float32 = [np.asarray(mask, dtype=np.float32) for mask in phase_masks]
    phase_stack = np.stack(masks_as_float32, axis=0)

    # NOTE(review): z_input_to_first is used for the scan above but is not
    # written into this metadata - confirm whether it should be.
    meta = dict(
        z_start=float(z_start),
        z_step=float(z_step),
        z_layers=float(z_layers),
        z_prop=float(z_prop),
        z_prop_plus=float(z_prop_plus),
        pixel_size=float(pixel_size),
        wavelength=float(wavelength),
        layer_size=int(layer_size),
        padding_ratio=0.5,
    )

    mat_path = model_dir / f"{filename_prefix}_LIGHT_m{i_model}.mat"
    save_to_mat_light_plus(
        mat_path,
        phase_stack=phase_stack,
        input_field=temp_E.detach().cpu().numpy(),
        scans=scans,
        camera_field=camera_field,
        sample_stacks_kmax=20,
        save_amplitude_only=False,
        meta=meta,
    )
    print("Saved ->", mat_path)

    save_masks_one_file_per_layer(
        phase_masks,
        out_dir=model_dir,
        base_name=f"{filename_prefix}_MASK",
        save_degree=False,
        use_xlsx=True,
    )

: 

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    save_mode_triptych,
    visualize_model_slices,
)

# Reproducibility setup: seed every RNG in use, force deterministic kernels,
# then choose the compute device.

# cuBLAS only behaves deterministically on CUDA >= 10.2 when this variable is
# set; without it torch.use_deterministic_algorithms(True) makes cuBLAS-backed
# ops raise at runtime. setdefault keeps any value the user already exported.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Route cuDNN and other operators through their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
if torch.cuda.is_available():
    # Preferred GPU for this run; fall back to GPU 0 when the host exposes
    # fewer devices, instead of failing later at the first .to(device).
    preferred_gpu_index = 5
    visible_gpu_count = torch.cuda.device_count()
    if preferred_gpu_index >= visible_gpu_count:
        print(
            f"cuda:{preferred_gpu_index} is not available "
            f"({visible_gpu_count} GPU(s) visible); falling back to cuda:0"
        )
        preferred_gpu_index = 0
    device = torch.device(f'cuda:{preferred_gpu_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

ImportError: cannot import name 'capture_eigenmode_propagation' from 'odnn_training_visualization' (/home/ydzhang/Desktop/odnn_code/odnn_training_visualization.py)

In [None]:
# ---- Data and detector geometry (all sizes in pixels) ----
# NOTE(review): the previous comment here said "the field size in eigenmodes_OM4
# is 50 pixels", which disagrees with the value 25 - confirm which is right.
field_size = 25  # side length passed to the field generator and the evaluation helpers
layer_size = 100  # side length of each diffractive layer and of the label canvas (earlier trials: 400, 300)
num_data = 1000  # sample count for random-weight data; replaced by num_modes when eigenmodes are the training set
num_modes = 6  # number of MMF modes used (values tried: 3, 6, 10)
circle_focus_radius = 5 # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults; the label-building cell overwrites both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"

# ---- Models to train and per-model result accumulators ----
# One D2NN is trained per entry; each list below receives one item per trained model.
num_layer_option = [3]  # layer counts to train (add e.g. 4 to compare depths)
all_losses = []  # per-epoch loss history of each model
all_phase_masks = []  # phase masks of each model, wrapped to [0, 2*pi)
all_predictions = []  # output light field of each model (presumably unused in the cells shown - verify)
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []

# ---- Optical system (SI units, metres) ----
z_layers   = 40e-6        # layer-to-layer spacing: 40 um (previously 47.571e-3)
pixel_size = 1e-6
# NOTE(review): the value is 120 um, but the old comment read "60 um plus 40" - confirm.
z_prop     = 120e-6        # last layer to camera (previously 16.74e-2)
# NOTE(review): the value is 1568 nm, but the old comment read "1550 nm" - confirm.
wavelength = 1568e-9
z_input_to_first = 40e-6  # propagation distance from the input plane to the first layer: 40 um

: 

In [None]:
# Load the complex MMF eigenmodes, normalise their amplitudes to [0, 1], and
# synthesise the complex input fields used for training and evaluation.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# (H, W, M)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Slicing past the last mode would silently return fewer than num_modes modes
# and only fail later with a shape mismatch, so check up front.
available_modes = eigenmodes_OM4.shape[2]
if num_modes > available_modes:
    raise ValueError(
        f"Requested num_modes={num_modes} but the mode file only contains {available_modes} modes."
    )

# Take the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over all selected modes; the phase is kept.
mode_amplitude = np.abs(MMF_data)
amplitude_min = np.min(mode_amplitude)
amplitude_span = np.max(mode_amplitude) - amplitude_min
if amplitude_span == 0:
    raise ValueError("Mode amplitudes are constant; cannot min-max normalise them.")
MMF_data_amp_norm = (mode_amplitude - amplitude_min) / amplitude_span

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# TODO: if option 4 becomes the only one in use, consider dropping the others.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in (1, 2, 3, 5):
    amplitudes, phases = generate_complex_weights(num_data, num_modes, phase_option)
elif phase_option == 4:
    num_data = num_modes  # use the eigenmodes to train ODNN: one sample per mode
    amplitudes = np.eye(num_modes)  # one-hot rows, e.g. [[1,0,0],[0,1,0],[0,0,1]]
    # NOTE(review): np.eye puts a phase of 1 rad (not 0) on each active mode;
    # with a single mode per sample that is only a global phase - confirm intent.
    phases = np.eye(num_modes)
else:
    # Without this branch an unsupported value would leave amplitudes/phases
    # undefined (or stale from an earlier run of the cell).
    raise ValueError(f"Unknown phase_option: {phase_option}")

amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
image_test_data = image_data

: 

In [None]:
# pred_case selects what the network is trained to predict:
#   pred_case = 1: only amplitudes prediction
#   pred_case = 2: only phases prediction
#   pred_case = 3: amplitudes and phases prediction
#   pred_case = 4: amplitudes and phases prediction (extra energy phase area)
# Only pred_case == 1 is implemented in this cell.
pred_case = 1
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

if pred_case == 1: # 3
    # One detector region per mode; each label is the amplitude-weighted sum of
    # the per-mode detector maps.
    num_detector = num_modes
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use the mode amplitude profiles themselves as detector patterns: (M, H, W) -> (H, W, M)
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        # Bump an even pattern size to the next odd number.
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One full-canvas label map per mode (Index is 1-based).
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # Weight each mode's label map by that sample's mode amplitude and sum over modes.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Publish the detector geometry actually used so later cells evaluate with it.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize
else:
    # Fail fast: otherwise label_data would stay all zeros and num_detector /
    # MMF_Label_data would be undefined (or stale), surfacing only in later cells.
    raise NotImplementedError(
        f"pred_case={pred_case} is not implemented; only pred_case == 1 builds detector labels."
    )

label_test_data = label_data

: 

In [None]:
# Build the training set (one prepared (field, label) pair per sample), its
# shuffled loader, and the evaluation set selected by evaluation_mode.
train_dataset = []
for sample_idx in range(len(label_data)):
    train_dataset.append(
        prepare_sample(image_data[sample_idx], label_data[sample_idx], layer_size)
    )
stacked_columns = [torch.stack(column) for column in zip(*train_dataset)]
train_tensor_data = TensorDataset(*stacked_columns)

# A dedicated, seeded generator fixes the shuffle order of the training loader.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,
    generator=g,
)

if evaluation_mode not in ("eigenmode", "superposition"):
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

if evaluation_mode == "eigenmode":
    # Evaluate on the very samples used for training (the base modes).
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
else:
    # Superposition evaluation: everything comes from the helper-built context.
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]

# Generate detection regions using existing function
if pred_case == 1:
    evaluation_regions = create_evaluation_regions(
        layer_size, layer_size, num_detector, focus_radius, detectsize
    )
    print("Detection Regions:", evaluation_regions)

: 

In [None]:
# Train one D2NN per entry of num_layer_option. For each model this loop:
#   1. trains for a fixed number of epochs with MSE loss, Adam and an exponential LR decay,
#   2. saves loss/time curves (PNG) and the raw curves (.mat),
#   3. exports a propagation capture for one eigenmode and, in eigenmode
#      evaluation mode, one triptych per mode,
#   4. writes a checkpoint, caches the wrapped phase masks,
#   5. evaluates the model and appends its metrics to the global result lists.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # distance from the input plane to the first layer
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # loss function: MSE between the network output and the label
    # NOTE(review): lr=1.99 is unusually large for Adam; presumably deliberate
    # for phase-valued parameters - confirm.
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    scheduler = ExponentialLR(optimizer, gamma=0.99)  # multiplies the LR by 0.99 on every scheduler.step()
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()  # one LR decay step per epoch
        avg_loss = epoch_loss / len(train_loader)  # Calculate average loss for the epoch
        losses.append(avg_loss)  # the loss for each model
        # Wait for queued GPU work so the measured epoch time is not cut short.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Log the first epoch, every 100th epoch, and the last epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # The timestamp keeps artefacts of different runs from overwriting each other.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loss-vs-epoch curve.
    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Cumulative wall-clock time vs epoch.
    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    # Raw curves for post-processing in MATLAB.
    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # Propagation capture for a single eigenmode: the third mode (index 2), or
    # the last available mode when fewer than three modes are loaded.
    # NOTE(review): D2NN.train() was the last mode switch above and no eval()
    # is visible here; presumably the helpers below handle that - confirm.
    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # One triptych per mode, only when evaluating on the base eigenmodes.
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Everything the summary cell later prints for this model.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )
   
    # === after training ===
    # Checkpoint: weights plus the parameters used to construct the model.
    # The file name depends only on the layer count, so a later run with the
    # same depth writes to the same path.
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))  # wrap into [0, 2*pi)
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Fan the metrics out into the per-model lists; a missing key becomes an
    # empty array (or NaN for the scalar averages).
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )

: 

In [None]:
# Console report of every trained model's duration and saved artefacts,
# followed by one amplitude-comparison grid per trained depth.
if all_training_summaries:
    print("\nTraining duration summary:")

# Iterating an empty list is a no-op, so the loop needs no extra guard.
for summary in all_training_summaries:
    minutes = summary["total_time"] / 60
    print(
        f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
        f"(~{minutes:.2f} min)"
    )
    print(f"   Loss curve: {summary['loss_plot']}")
    print(f"   Time curve: {summary['time_plot']}")
    print(f"   Data (.mat): {summary['mat_path']}")
    print(f"   Propagation plot: {summary['propagation_fig']}")
    print(f"   Propagation data (.mat): {summary['propagation_mat']}")

    # Triptych records are optional: summaries may lack the key or hold an empty list.
    mode_triptychs = summary.get("mode_triptychs", [])
    if not mode_triptychs:
        continue
    print("   Mode triptychs:")
    for trip in mode_triptychs:
        print(
            f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
        )

save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
# The per-model result lists share their ordering with num_layer_option.
for idx in range(len(num_layer_option)):
    num_layer = num_layer_option[idx]
    grid_path = os.path.join(save_dir, f"Amp_{num_layer}layers.png")
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=grid_path,
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# # Visual check of how the network output differs from the label (currently disabled)
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #

: 

已重启 odnn_venv (Python 3.13.5)

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_visuals,
    save_mode_triptych,
    visualize_model_slices,
)

# Seed every random number generator used by the notebook so runs are repeatable.
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN / PyTorch operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU ordinal for this workstation. When the host exposes fewer GPUs,
# fall back to GPU 0 instead of failing later with "invalid device ordinal" on
# the first `.to(device)` call.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# --- Experiment configuration (read by the cells below) ---
field_size = 25  # side length (pixels) of one mode field; NOTE(review): the old comment said eigenmodes_OM4 is 50 pixels, but the modes loaded below print as (25, 25, 6) — confirm
layer_size = 100  # side length of the label canvas and of the D2NN layers; earlier trials: 400, 300, 100
num_data = 1000  # sample count for the random-weight phase options; overwritten with num_modes when phase_option == 4
num_modes = 6  # number of MMF modes used (3, 6 or 10)
circle_focus_radius = 5  # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
focus_radius = circle_focus_radius  # defaults; the label cell rebinds these per label_pattern_mode
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "eigenmode"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"
# Layer counts to train; one D2NN model is trained per entry.
num_layer_option = [3]   # other layer counts (e.g. 3, 4) can be appended here
# Per-model accumulators, appended to once per trained model.
all_losses = []  # the loss for each epoch of each ODNN model
all_phase_masks = []  # the phase masks of each ODNN model
all_predictions = []  # the output light field of each ODNN model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# SLM / optical geometry passed to D2NNModel
z_layers   = 40e-6        # 40 µm (previously 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # previously 16.74e-2; NOTE(review): old note said "60 µm plus 40 (last layer to camera)" but the value is 120 µm — confirm
wavelength = 1568e-9      # previously 1568; NOTE(review): old note mentioned 1550 nm but the value is 1568 nm — confirm
z_input_to_first = 40e-6  # 40 µm; newly added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex mode fields; the array is indexed below as (H, W, M).
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Keep the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over the whole mode stack (computed once),
# then re-attach each pixel's original phase.
mode_amplitude = np.abs(MMF_data)
amp_min = np.min(mode_amplitude)
amp_range = np.max(mode_amplitude) - amp_min
if amp_range == 0:
    # A constant-amplitude stack would otherwise silently turn into NaN/inf fields.
    raise ValueError("Mode amplitudes are constant; min-max normalisation is undefined.")
MMF_data_amp_norm = (mode_amplitude - amp_min) / amp_range

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# TODO (author's note): if option 4 becomes the permanent choice, consider removing the others.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in [1, 2, 3, 5]:
    amplitudes,phases = generate_complex_weights(num_data,num_modes,phase_option)
elif phase_option == 4:
    num_data = num_modes # use the eigenmodes to train ODNN
    amplitudes = np.eye(num_modes)  # one-hot rows: sample i excites only mode i
    # NOTE(review): an identity matrix here gives the excited mode a phase of 1 rad
    # (not 0) in complex_weights below — confirm this is intended.
    phases = np.eye(num_modes)
else:
    # Fail here with a clear message instead of a NameError on `amplitudes` below.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# All amplitude columns followed by the phase columns from index 1 onward
# (the first phase column is dropped).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
# Build the input fields from the weights and mode stack (see ODNN_functions.generate_fields_ts).
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
# Build the training label maps: one detector pattern per mode, placed on a
# layer_size x layer_size canvas and weighted by each sample's mode amplitudes.
'''
pred_case = 1: only amplitudes prediction
pred_case = 2: only phases prediction
pred_case = 3: amplitudes and phases prediction
pred_case = 4: amplitudes and phases prediction (extra energy phase area)
'''
# Only pred_case == 1 is implemented in this cell; other values leave label_data all zeros.
pred_case = 1
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

if pred_case == 1:
    num_detector = num_modes
    # Start from the configured defaults; the selected pattern mode overrides them below.
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use the mode amplitude images themselves as detector patterns: (M, H, W) -> (H, W, M).
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        # Bump an even size to the next odd number (presumably so the pattern has a centre pixel).
        if pattern_size % 2 == 0:
            pattern_size += 1  
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Only the first returned value (the detector centres) is used here.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One label map per mode; Index is 1-based.
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    # Stack the per-mode maps along a new last axis — assumes each map is 2-D (H, W); TODO confirm.
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # First num_modes columns of amplitudes_phases are the mode amplitudes.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    # Amplitude-weighted sum of the per-mode maps over the mode axis -> one label per sample.
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Rebind the notebook-level detector settings so later cells use the selected mode's values.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

label_test_data = label_data

相邻图案边缘间距： 行=26.67, 列=17.50
相邻图案中心间距： 行=36.67, 列=27.50
中心坐标： [(32, 22), (32, 50), (32, 78), (68, 22), (68, 50), (68, 78)]


In [None]:
# Build the training set, its shuffled loader, and the evaluation set selected by evaluation_mode.
# One prepare_sample(...) result per (field, label) pair; see odnn_processing for what it does with layer_size.
train_dataset = [
    prepare_sample(image_data[i], label_data[i], layer_size) for i in range(len(label_data))
]
# Transpose the list of per-sample tuples and stack each component into one tensor.
train_tensor_data = TensorDataset(*[torch.stack(tensors) for tensors in zip(*train_dataset)])
# Dedicated generator so the shuffle order is reproducible from SEED.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,               # order is pinned by generator g
    generator=g,                # fixed shuffling

)

if evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, unshuffled.
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
elif evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # The helper returns a dict; the keys read below are the only ones this cell relies on.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

# Generate detection regions using existing function
# (evaluation_regions is only defined when pred_case == 1).
if pred_case ==1:
    evaluation_regions = create_evaluation_regions(layer_size, layer_size, num_detector, focus_radius, detectsize)
    print("Detection Regions:", evaluation_regions)

Detection Regions: [(17, 27, 27, 37), (45, 55, 27, 37), (73, 83, 27, 37), (17, 27, 63, 73), (45, 55, 63, 73), (73, 83, 63, 73)]


In [None]:
# Train one D2NN per entry of num_layer_option, then save training curves, propagation
# slices, per-mode triptychs, a checkpoint, the phase masks and the evaluation metrics.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,  # must stay in sync with the value stored in the checkpoint meta below
        z_input_to_first=z_input_to_first,   # NEW
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # loss function: MSE between the model output and the label map
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    scheduler = ExponentialLR(optimizer, gamma=0.99)  # stepped once per epoch below
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # mean of the per-batch losses for this epoch
        losses.append(avg_loss)  # the loss for each model
        # Wait for queued GPU work so the measured epoch time is accurate.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Log the first epoch, every 100th epoch and the last epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # Shared tag so every artefact written for this model carries the same timestamp.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loss-versus-epoch figure.
    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Cumulative wall-clock time figure.
    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    # Raw training curves saved as a MATLAB .mat file.
    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # Propagation capture for a single mode: index 2 (the third mode) when it exists,
    # otherwise the last available mode.
    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # Per-mode triptychs are only produced in eigenmode evaluation, where label_data rows map 1:1 to modes.
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Collected for the summary cell that follows the training loop.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )
   
    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    # Weights plus the constructor settings used above.
    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    # NOTE: the file name has no timestamp, so a rerun with the same layer count overwrites it.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))  # wrap into [0, 2*pi)
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Missing metric keys fall back to an empty array / NaN so every accumulator stays index-aligned.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [1/1000], Loss: 0.008090289309620857, Epoch Time: 0.10 seconds
Epoch [100/1000], Loss: 0.002760213334113359, Epoch Time: 0.01 seconds
Epoch [200/1000], Loss: 0.002684592502191663, Epoch Time: 0.01 seconds
Epoch [300/1000], Loss: 0.002658451907336712, Epoch Time: 0.01 seconds
Epoch [400/1000], Loss: 0.002648203633725643, Epoch Time: 0.01 seconds
Epoch [500/1000], Loss: 0.002644307911396027, Epoch Time: 0.01 seconds
Epoch [600/1000], Loss: 0.002642895793542266, Epoch Time: 0.01 seconds
Epoch [700/1000], Loss: 0.002642374718561769, Epoch Time: 0.01 seconds
Epoch [800/1000], Loss: 0.002642178907990456, Epoch Time: 0.01 seconds
Epoch [900/1000], Loss: 0.002642104867845774, Epoch Time: 0.01 seconds
Epoch [1000/1000], Loss: 0.002642077626660466, Epoch Time: 0.01 seconds
Total tra

  plt.tight_layout(rect=[0, 0, 1, 0.97])


✔ Saved eigenmode-3 propagation plot -> results/propagation_slices/propagation_mode3_layers3_20251103_164848.png
✔ Saved eigenmode-3 propagation data (.mat) -> results/propagation_slices/propagation_mode3_layers3_20251103_164848.mat
✔ Saved mode 1 triptych -> results/mode_triptychs/mode1_layers3_20251103_164848.png
  MAT -> results/mode_triptychs/mode1_layers3_20251103_164848.mat
✔ Saved mode 2 triptych -> results/mode_triptychs/mode2_layers3_20251103_164848.png
  MAT -> results/mode_triptychs/mode2_layers3_20251103_164848.mat
✔ Saved mode 3 triptych -> results/mode_triptychs/mode3_layers3_20251103_164848.png
  MAT -> results/mode_triptychs/mode3_layers3_20251103_164848.mat
✔ Saved mode 4 triptych -> results/mode_triptychs/mode4_layers3_20251103_164848.png
  MAT -> results/mode_triptychs/mode4_layers3_20251103_164848.mat
✔ Saved mode 5 triptych -> results/mode_triptychs/mode5_layers3_20251103_164848.png
  MAT -> results/mode_triptychs/mode5_layers3_20251103_164848.mat
✔ Saved mode 6 tr

In [None]:
# Recap of every trained model: duration plus the artefact paths recorded during training.
# The header is emitted together with the first entry, so an empty list prints nothing.
for _summary_pos, summary in enumerate(all_training_summaries):
    if _summary_pos == 0:
        print("\nTraining duration summary:")
    minutes = summary["total_time"] / 60
    print(
        f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
        f"(~{minutes:.2f} min)"
    )
    print(f"   Loss curve: {summary['loss_plot']}")
    print(f"   Time curve: {summary['time_plot']}")
    print(f"   Data (.mat): {summary['mat_path']}")
    print(f"   Propagation plot: {summary['propagation_fig']}")
    print(f"   Propagation data (.mat): {summary['propagation_mat']}")
    mode_triptychs = summary.get("mode_triptychs", [])
    if not mode_triptychs:
        # Nothing more to list for this model.
        continue
    print("   Mode triptychs:")
    for trip in mode_triptychs:
        print(
            f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
        )

# Amplitude comparison grid (real vs. predicted) for each trained model.
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
for idx, num_layer in enumerate(num_layer_option):
    _predicted_fields = all_image_data_pred[idx]
    _recon_amp_cc = all_cc_recon_amp[idx]
    _grid_path = os.path.join(save_dir, f"Amp_{num_layer}layers.png")
    plot_amplitude_comparison_grid(
        image_test_data,
        _predicted_fields,
        _recon_amp_cc,
        max_samples=num_samples_to_display,
        save_path=_grid_path,
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# #直观的看看输出和label的差异
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #


Training duration summary:
 - 3 layers: 10.79 s (~0.18 min)
   Loss curve: results/training_analysis/loss_curve_layers3_20251103_164848.png
   Time curve: results/training_analysis/epoch_time_layers3_20251103_164848.png
   Data (.mat): results/training_analysis/training_curves_layers3_20251103_164848.mat
   Propagation plot: results/propagation_slices/propagation_mode3_layers3_20251103_164848.png
   Propagation data (.mat): results/propagation_slices/propagation_mode3_layers3_20251103_164848.mat
   Mode triptychs:
     Mode 1: fig=results/mode_triptychs/mode1_layers3_20251103_164848.png, mat=results/mode_triptychs/mode1_layers3_20251103_164848.mat
     Mode 2: fig=results/mode_triptychs/mode2_layers3_20251103_164848.png, mat=results/mode_triptychs/mode2_layers3_20251103_164848.mat
     Mode 3: fig=results/mode_triptychs/mode3_layers3_20251103_164848.png, mat=results/mode_triptychs/mode3_layers3_20251103_164848.mat
     Mode 4: fig=results/mode_triptychs/mode4_layers3_20251103_164848.p

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_triptych,
    save_mode_triptych,
    visualize_model_slices,
)

# Seed every random number generator used by the notebook so runs are repeatable.
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN / PyTorch operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU ordinal for this workstation. When the host exposes fewer GPUs,
# fall back to GPU 0 instead of failing later with "invalid device ordinal" on
# the first `.to(device)` call.
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

ImportError: cannot import name 'save_superposition_triptych' from 'odnn_training_visualization' (/home/ydzhang/Desktop/odnn_code/odnn_training_visualization.py)

In [None]:
# --- Experiment configuration (read by the cells below) ---
field_size = 25  # side length (pixels) of one mode field; NOTE(review): the old comment said eigenmodes_OM4 is 50 pixels while the value is 25 — confirm
layer_size = 100  # side length of the label canvas; earlier trials: 400, 300, 100
num_data = 1000  # sample count for the random-weight phase options; overwritten with num_modes when phase_option == 4
num_modes = 6  # number of MMF modes used (3, 6 or 10)
circle_focus_radius = 5  # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
focus_radius = circle_focus_radius  # defaults; the label cell rebinds these per label_pattern_mode
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "superposition"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
num_superposition_visual_samples = 20
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
run_misalignment_robustness = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"
# Layer counts to evaluate (one ODNN model per entry).
num_layer_option = [3]   # other layer counts (e.g. 3, 4) can be appended here
# Per-model accumulators.
all_losses = []  # the loss for each epoch of each ODNN model
all_phase_masks = []  # the phase masks of each ODNN model
all_predictions = []  # the output light field of each ODNN model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# SLM / optical geometry
z_layers   = 40e-6        # 40 µm (previously 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # previously 16.74e-2; NOTE(review): old note said "60 µm plus 40 (last layer to camera)" but the value is 120 µm — confirm
wavelength = 1568e-9      # previously 1568; NOTE(review): old note mentioned 1550 nm but the value is 1568 nm — confirm
z_input_to_first = 40e-6  # 40 µm; newly added: propagation distance from the input plane to the first layer

: 

In [None]:
# Load the complex mode fields; the array is indexed below as (H, W, M).
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Keep the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over the whole mode stack (computed once),
# then re-attach each pixel's original phase.
mode_amplitude = np.abs(MMF_data)
amp_min = np.min(mode_amplitude)
amp_range = np.max(mode_amplitude) - amp_min
if amp_range == 0:
    # A constant-amplitude stack would otherwise silently turn into NaN/inf fields.
    raise ValueError("Mode amplitudes are constant; min-max normalisation is undefined.")
MMF_data_amp_norm = (mode_amplitude - amp_min) / amp_range

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# TODO (author's note): if option 4 becomes the permanent choice, consider removing the others.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in [1, 2, 3, 5]:
    amplitudes,phases = generate_complex_weights(num_data,num_modes,phase_option)
elif phase_option == 4:
    num_data = num_modes # use the eigenmodes to train ODNN
    amplitudes = np.eye(num_modes)  # one-hot rows: sample i excites only mode i
    # NOTE(review): an identity matrix here gives the excited mode a phase of 1 rad
    # (not 0) in complex_weights below — confirm this is intended.
    phases = np.eye(num_modes)
else:
    # Fail here with a clear message instead of a NameError on `amplitudes` below.
    raise ValueError(f"Unknown phase_option: {phase_option}")

# All amplitude columns followed by the phase columns from index 1 onward
# (the first phase column is dropped).
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
# Build the input fields from the weights and mode stack (see ODNN_functions.generate_fields_ts).
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
image_test_data = image_data

: 

In [None]:
# Build the label maps: one detector pattern per mode, placed on a
# layer_size x layer_size canvas and weighted by each sample's mode amplitudes.
'''
pred_case = 1: only amplitudes prediction
pred_case = 2: only phases prediction
pred_case = 3: amplitudes and phases prediction
pred_case = 4: amplitudes and phases prediction (extra energy phase area)
'''
# Only pred_case == 1 is implemented in this cell; other values leave label_data all zeros.
pred_case = 1
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

if pred_case == 1:
    num_detector = num_modes
    # Start from the configured defaults; the selected pattern mode overrides them below.
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use the mode amplitude images themselves as detector patterns: (M, H, W) -> (H, W, M).
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        # Bump an even size to the next odd number (presumably so the pattern has a centre pixel).
        if pattern_size % 2 == 0:
            pattern_size += 1  
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Only the first returned value (the detector centres) is used here.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One label map per mode; Index is 1-based.
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    # Stack the per-mode maps along a new last axis — assumes each map is 2-D (H, W); TODO confirm.
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # First num_modes columns of amplitudes_phases are the mode amplitudes.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    # Amplitude-weighted sum of the per-mode maps over the mode axis -> one label per sample.
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Rebind the notebook-level detector settings so later cells use the selected mode's values.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

label_test_data = label_data

: 

In [None]:
# Build the training set, its shuffled loader, and the evaluation set selected by evaluation_mode.
# One prepare_sample(...) result per (field, label) pair; see odnn_processing for what it does with layer_size.
train_dataset = [
    prepare_sample(image_data[i], label_data[i], layer_size) for i in range(len(label_data))
]
# Transpose the list of per-sample tuples and stack each component into one tensor.
train_tensor_data = TensorDataset(*[torch.stack(tensors) for tensors in zip(*train_dataset)])
# Dedicated generator so the shuffle order is reproducible from SEED.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,               # order is pinned by generator g
    generator=g,                # fixed shuffling

)

# Stays None unless the superposition branch below runs.
superposition_eval_ctx: dict | None = None
if evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, unshuffled.
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
elif evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # The helper returns a dict; the keys read below are the only ones this cell relies on.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
    # Keep the whole context around for later cells.
    superposition_eval_ctx = super_ctx
else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

# Generate detection regions using existing function
# (evaluation_regions is only defined when pred_case == 1).
if pred_case ==1:
    evaluation_regions = create_evaluation_regions(layer_size, layer_size, num_detector, focus_radius, detectsize)
    print("Detection Regions:", evaluation_regions)


def shift_complex_batch(batch: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
    """
    Shift every complex field of a 4-D batch by whole pixels, filling vacated pixels with zeros.

    Args:
        batch: Complex tensor whose shape unpacks into four dimensions; the last two
            are treated as (height, width).
        shift_y: Row offset in pixels; positive values move the content downward.
        shift_x: Column offset in pixels; positive values move the content to the right.

    Returns:
        ``batch`` itself when both offsets are zero, an all-zero tensor when either
        offset reaches the field extent, otherwise a new shifted complex tensor.
    """

    def _axis_slices(offset: int, extent: int) -> tuple[slice, slice]:
        """Return the (source, destination) slices describing a shift along one axis."""
        if offset >= 0:
            return slice(0, extent - offset), slice(offset, extent)
        return slice(-offset, extent), slice(0, extent + offset)

    if (shift_y, shift_x) == (0, 0):
        return batch

    _, _, height, width = batch.shape
    if not (abs(shift_y) < height and abs(shift_x) < width):
        # The whole field has been pushed outside the canvas.
        return torch.zeros_like(batch)

    src_rows, dst_rows = _axis_slices(shift_y, height)
    src_cols, dst_cols = _axis_slices(shift_x, width)

    # Work on the (real, imag) view so the copy is a plain real-valued slice assignment.
    planes = torch.view_as_real(batch)
    moved = torch.zeros_like(planes)
    moved[:, :, dst_rows, dst_cols, :] = planes[:, :, src_rows, src_cols, :]
    return torch.view_as_complex(moved)


def compute_amp_relative_error_with_shift(
    model: torch.nn.Module,
    loader,
    *,
    shift_y_px: int,
    shift_x_px: int,
    evaluation_regions,
    pred_case: int,
    num_modes: int,
    eval_amplitudes: np.ndarray,
    eval_amplitudes_phases: np.ndarray,
    eval_phases: np.ndarray,
    phase_option: int,
    mmf_modes: torch.Tensor,
    field_size: int,
    image_test_data: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Run the model on inputs translated by (shift_y_px, shift_x_px) and score its detector read-out.

    For every sample, channel 0 of the model output is averaged inside each
    ``(x0, x1, y0, y1)`` window of ``evaluation_regions``. The resulting vector is
    L2-normalised (only its first ``num_modes`` entries when ``pred_case == 3`` and
    enough entries exist). All vectors are then handed, together with the remaining
    arguments, to ``compute_model_prediction_metrics``.

    Returns:
        Whatever ``compute_model_prediction_metrics`` returns for the collected weights.
    """
    model.eval()
    predicted_weights: list[np.ndarray] = []

    with torch.no_grad():
        for fields, _ in loader:
            fields = fields.to(device, dtype=torch.complex64, non_blocking=True)
            outputs = model(shift_complex_batch(fields, shift_y_px, shift_x_px))
            output_maps = outputs.detach().cpu().numpy()

            for row in range(output_maps.shape[0]):
                channel0 = output_maps[row, 0]
                region_means = np.asarray(
                    [float(channel0[y0:y1, x0:x1].mean()) for (x0, x1, y0, y1) in evaluation_regions],
                    dtype=np.float64,
                )

                # Pick the slice to normalise; a basic slice is a view, so the
                # in-place division below also updates region_means.
                if pred_case == 3 and num_modes <= len(region_means):
                    to_scale = region_means[:num_modes]
                else:
                    to_scale = region_means
                magnitude = np.linalg.norm(to_scale)
                if magnitude > 0:
                    to_scale /= magnitude

                predicted_weights.append(region_means)

    return compute_model_prediction_metrics(
        predicted_weights,
        eval_amplitudes,
        eval_amplitudes_phases,
        eval_phases,
        phase_option,
        pred_case,
        num_modes,
        mmf_modes,
        field_size,
        image_test_data,
    )

: 

In [None]:
# Train, log, checkpoint and evaluate one D2NN model per entry of num_layer_option.
# Results are appended to the module-level all_* lists; D2NN keeps referring to the
# model of the last iteration once the loop finishes.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # distance from the input plane to the first layer
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # mean-squared error between the model output and the label map
    # NOTE(review): lr=1.99 is unusually large for Adam; presumably deliberate for the
    # trainable phase parameters — confirm.
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    scheduler = ExponentialLR(optimizer, gamma=0.99)  # learning rate is multiplied by 0.99 after every epoch
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # mean of the per-batch losses for this epoch
        losses.append(avg_loss)  # per-epoch loss history of the current model
        # Wait for queued GPU work so the measured epoch time is accurate.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Report the first epoch, every 100th epoch and the last epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # keep the loss history of every trained model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # One timestamp tags every artefact written for this model.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loss-versus-epoch curve.
    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Cumulative wall-clock time versus epoch.
    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    # Persist the raw curves so they can be re-plotted outside Python.
    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # Record how one eigenmode propagates through the trained model.
    # The third mode (index 2) is used, clamped to the last available mode.
    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # Per-mode triptychs are only produced when evaluating on the eigenmodes themselves.
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Summary entry consumed by the reporting cell that follows the training loop.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )
   
    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    # Checkpoint = weights plus the constructor/config values used above.
    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    # The file name depends only on the layer count, so a rerun with the same count overwrites it.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export (wrapped into [0, 2*pi)).
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Missing metric keys fall back to an empty array / NaN so every list stays
    # aligned with num_layer_option.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )

: 

In [None]:
# Print a per-model recap of training time and the artefacts written by the training loop.
if all_training_summaries:
    print("\nTraining duration summary:")
    for summary in all_training_summaries:
        minutes = summary["total_time"] / 60
        print(
            f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
            f"(~{minutes:.2f} min)"
        )
        print(f"   Loss curve: {summary['loss_plot']}")
        print(f"   Time curve: {summary['time_plot']}")
        print(f"   Data (.mat): {summary['mat_path']}")
        print(f"   Propagation plot: {summary['propagation_fig']}")
        print(f"   Propagation data (.mat): {summary['propagation_mat']}")
        # The list is empty unless the training loop ran in "eigenmode" evaluation mode.
        mode_triptychs = summary.get("mode_triptychs", [])
        if mode_triptychs:
            print("   Mode triptychs:")
            for trip in mode_triptychs:
                print(
                    f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
                )

# One amplitude comparison grid (reference vs. predicted fields) per trained model.
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
# idx addresses the per-model result lists, which were filled in num_layer_option order.
for idx, num_layer in enumerate(num_layer_option):
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=os.path.join(save_dir, f"Amp_{num_layer}layers.png"),
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# # Visually inspect the difference between the model output and the label
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #

: 

In [None]:
# Dataset from which the single visualisation input is taken (the evaluation set).
temp_dataset = test_dataset
# Index of the sample to visualise; it is reduced modulo the dataset length where it is used.
FIXED_E_INDEX = 4

def get_fixed_input(dataset, idx, device):
    """
    Fetch the input field of sample ``idx`` and move it to ``device``.

    A plain ``list`` is indexed as ``dataset[idx][0]``; any other dataset is expected
    to expose ``.tensors`` (as ``TensorDataset`` does) and is read via
    ``dataset.tensors[0][idx]``. A leading dimension of size 1 is squeezed away.
    """
    field = dataset[idx][0] if isinstance(dataset, list) else dataset.tensors[0][idx]
    return field.squeeze(0).to(device)


# Pick one fixed input field; the index wraps around if the dataset is shorter than FIXED_E_INDEX.
assert len(temp_dataset) > 0, "test_dataset 为空"
temp_E = get_fixed_input(temp_dataset, FIXED_E_INDEX % len(temp_dataset), device)

# Scan parameters. z_start is only recorded in the exported metadata below;
# z_step is also read by the superposition debug cell further down.
z_start = 0.0
z_step = 5e-6
z_prop_plus = z_prop

save_root = Path("results_MD")
save_root.mkdir(parents=True, exist_ok=True)
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename_prefix = f"ODNN_vis_{run_stamp}"

# One output folder (m1, m2, ...) per trained model's phase masks.
# NOTE(review): D2NN is whichever model the training loop bound last, while phase_masks
# iterates over every trained model — with several entries in num_layer_option, confirm
# that visualize_model_slices relies on phase_masks rather than on D2NN's own layers.
for i_model, phase_masks in enumerate(all_phase_masks, start=1):
    model_dir = save_root / f"m{i_model}"
    scans, camera_field = visualize_model_slices(
        D2NN,
        phase_masks,
        temp_E,
        output_dir=model_dir,
        sample_tag=f"m{i_model}",
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop_plus=z_prop_plus,
        z_step=z_step,
        pixel_size=pixel_size,
        wavelength=wavelength,
    )

    # Stack the per-layer masks along a new leading axis, stored as float32.
    phase_stack = np.stack([np.asarray(mask, dtype=np.float32) for mask in phase_masks], axis=0)
    # Plain-Python metadata stored alongside the exported fields.
    meta = {
        "z_start": float(z_start),
        "z_step": float(z_step),
        "z_layers": float(z_layers),
        "z_prop": float(z_prop),
        "z_prop_plus": float(z_prop_plus),
        "pixel_size": float(pixel_size),
        "wavelength": float(wavelength),
        "layer_size": int(layer_size),
        "padding_ratio": 0.5,
    }

    # NOTE: rebinds the global mat_path that the training loop also assigns.
    mat_path = model_dir / f"{filename_prefix}_LIGHT_m{i_model}.mat"
    save_to_mat_light_plus(
        mat_path,
        phase_stack=phase_stack,
        input_field=temp_E.detach().cpu().numpy(),
        scans=scans,
        camera_field=camera_field,
        sample_stacks_kmax=20,
        save_amplitude_only=False,
        meta=meta,
    )
    print("Saved ->", mat_path)

    # Additionally export every phase mask of this model through the per-layer writer.
    save_masks_one_file_per_layer(
        phase_masks,
        out_dir=model_dir,
        base_name=f"{filename_prefix}_MASK",
        save_degree=False,
        use_xlsx=True,
    )

: 

In [None]:
# Optional debug pass: push freshly generated superposition samples through the most
# recently trained model (D2NN) and save one triptych per sample.
if pred_case == 1 and run_superposition_debug:
    super_dir = Path("results_superposition")
    super_dir.mkdir(parents=True, exist_ok=True)
    super_tag = datetime.now().strftime("%Y%m%d_%H%M%S")
    super_records: list[dict[str, str | int]] = []
    # Input of the first generated sample, kept for the slice export after the loop.
    slice_reference_input: torch.Tensor | None = None

    for sample_idx in range(num_superposition_visual_samples):
        super_sample = generate_superposition_sample(
            num_modes=num_modes,
            field_size=field_size,
            layer_size=layer_size,
            mmf_modes=MMF_data_ts,
            mmf_label_data=MMF_Label_data,
        )
        super_output_map = infer_superposition_output(
            D2NN,
            super_sample["padded_image"],
            device,
        )

        sample_tag = f"{super_tag}_s{sample_idx:02d}"
        triptych_paths = save_superposition_triptych(
            input_field=super_sample["padded_image"][0],
            output_intensity_map=super_output_map,
            amplitudes=super_sample["amplitudes"],
            phases=super_sample["phases"],
            complex_weights=super_sample["complex_weights"],
            label_map=super_sample["padded_label"][0],
            evaluation_regions=evaluation_regions,
            detect_radius=detectsize,
            output_dir=super_dir,
            tag=sample_tag,
            save_plot=save_superposition_plots,
        )
        # The figure path is only reported when it is truthy; the MAT path is always printed.
        if triptych_paths["fig_path"]:
            print(
                f"Superposition sample {sample_idx + 1}/{num_superposition_visual_samples} -> "
                f"{triptych_paths['fig_path']}"
            )
        print(f"  MAT saved -> {triptych_paths['mat_path']}")

        # NOTE(review): triptych_paths is already indexed unconditionally above, so the
        # `if triptych_paths else ""` fallbacks below cannot trigger for a None result.
        super_records.append(
            {
                "index": sample_idx,
                "tag": sample_tag,
                "fig": triptych_paths["fig_path"] if triptych_paths else "",
                "mat": triptych_paths["mat_path"] if triptych_paths else "",
            }
        )

        # Remember only the first sample's input field.
        if slice_reference_input is None:
            slice_reference_input = (
                super_sample["padded_image"].squeeze(0).to(device, dtype=torch.complex64)
            )

    # z_step is not assigned in this cell; it comes from the slice-visualisation cell
    # (z_step = 5e-6), so that cell has to run before this one.
    if save_superposition_slices and all_phase_masks and slice_reference_input is not None:
        slices_root = super_dir / f"slices_{super_tag}"
        export_superposition_slices(
            D2NN,
            all_phase_masks,
            slice_reference_input,
            slices_root,
            sample_tag="superposition",
            z_input_to_first=z_input_to_first,
            z_layers=z_layers,
            z_prop=z_prop,
            z_step=z_step,
            pixel_size=pixel_size,
            wavelength=wavelength,
        )

    # Recap of everything written for the generated samples.
    if super_records:
        print("\nSuperposition sample outputs:")
        for record in super_records:
            print(
                f" - Sample {record['index'] + 1:02d} ({record['tag']}): "
                f"fig={record['fig']}, mat={record['mat']}"
            )

: 

已重启 odnn_venv (Python 3.13.5)

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_triptych,
    save_mode_triptych,
    visualize_model_slices,
)

# Seed every random number generator used by the notebook so runs are repeatable.
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN / operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Preferred GPU on the shared host; edit this to target another card (e.g. 0).
PREFERRED_CUDA_INDEX = 5
if torch.cuda.is_available():
    visible_gpu_count = torch.cuda.device_count()
    if PREFERRED_CUDA_INDEX < visible_gpu_count:
        cuda_index = PREFERRED_CUDA_INDEX
    else:
        # Without this check an out-of-range index only fails later, at the first
        # `.to(device)` call, far away from the configuration that caused it.
        print(
            f'Requested cuda:{PREFERRED_CUDA_INDEX} but only {visible_gpu_count} GPU(s) '
            f'are visible; falling back to cuda:0'
        )
        cuda_index = 0
    device = torch.device(f'cuda:{cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# --- Data / geometry -------------------------------------------------------------
# Side length of the mode fields in pixels. The recorded output of the loading cell
# shows 25x25 modes; the earlier "50 pixels" remark looks stale — TODO confirm.
field_size = 25
layer_size = 100  # side length of each layer / label canvas in pixels (400 and 300 were tried before)
num_data = 1000  # sample count for randomly weighted data; the loading cell overrides it when eigenmodes are used
num_modes = 6  # number of MMF modes to use (3, 6 and 10 were considered)
circle_focus_radius = 5 # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27    # square window size for eigenmode patterns
# Defaults; the label cell reassigns both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "superposition"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
num_superposition_visual_samples = 20
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
run_misalignment_robustness = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"
# Define multiple D2NN models 
num_layer_option = [3]   # layer counts to train; one model is trained per entry
# Per-model result accumulators; the training loop appends one entry per trained model.
all_losses = []  # per-epoch loss history of each ODNN model
all_phase_masks = []  # phase masks of each ODNN model
all_predictions = []  # output light field of each ODNN model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# SLM / propagation geometry (values presumably in metres — TODO confirm)
z_layers   = 40e-6        # previously 47.571e-3; now 40 μm
pixel_size = 1e-6
# Last layer to camera. Previously 16.74e-2; the old remark said "60 μm plus 40".
# NOTE(review): the value here is 120e-6 — confirm which is intended.
z_prop     = 120e-6
# NOTE(review): the old remark said "1568 -> 1550 nm", but the value here is 1568e-9 — confirm.
wavelength = 1568e-9
z_input_to_first = 40e-6  # 40 μm; newly added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex MMF eigenmodes from the .mat file.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# Layout is (H, W, M): the mode index is the last axis.
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Keep the first num_modes modes and move the mode axis first: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)

# Min-max normalise the amplitude over all selected modes (computed once instead of
# re-evaluating np.abs for every term); the phase is left untouched.
mode_amplitude = np.abs(MMF_data)
amp_min = np.min(mode_amplitude)
amp_max = np.max(mode_amplitude)
MMF_data_amp_norm = (mode_amplitude - amp_min) / (amp_max - amp_min)

MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# If option 4 turns out to be the final choice, consider removing the other options.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

if phase_option in [1, 2, 3, 5]:
    amplitudes,phases = generate_complex_weights(num_data,num_modes,phase_option)
elif phase_option == 4:
    num_data = num_modes  # train the ODNN on the eigenmodes themselves, one sample per mode
    amplitudes = np.eye(num_modes)  # one-hot amplitudes, e.g. [[1,0,0],[0,1,0],[0,0,1]]
    # NOTE(review): np.eye puts 1 on the diagonal, so each eigenmode gets a constant
    # phase of 1 rad below — confirm that zeros were not intended.
    phases = np.eye(num_modes)
else:
    # Fail here with a clear message instead of a NameError on `amplitudes` below.
    raise ValueError(f"Unknown phase_option: {phase_option}")

amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default evaluation inputs; the dataset cell may replace them.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
'''
pred_case = 1: only amplitudes prediction
pred_case = 2: only phases prediction
pred_case = 3: amplitudes and phases prediction
pred_case = 4: amplitudes and phases prediction (extra energy phase area)
'''
#
pred_case = 1
# One single-channel label map per sample, filled in below.
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

# Only the amplitude-only case builds labels here; for any other pred_case the names
# num_detector and MMF_Label_data stay undefined in this cell.
if pred_case == 1: # 3
    num_detector = num_modes  # one detector region per mode
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use the mode amplitude profiles as detector patterns, reordered to (H, W, M).
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        # Bump an even size to the next odd number.
        if pattern_size % 2 == 0:
            pattern_size += 1  
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Only the first returned value (the detector centres) is used.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One label map per detector; Index is passed 1-based.
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    # Stack the per-detector maps along a new last axis.
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # Amplitude part of every sample's weight vector (the first num_modes columns).
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    # Amplitude-weighted sum of the per-mode label maps, summed over the mode axis.
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Publish the detector geometry that matches the chosen pattern mode.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

label_test_data = label_data

相邻图案边缘间距： 行=26.67, 列=17.50
相邻图案中心间距： 行=36.67, 列=27.50
中心坐标： [(32, 22), (32, 50), (32, 78), (68, 22), (68, 50), (68, 78)]


In [None]:
# Build one training sample per label entry; prepare_sample returns a tuple of tensors
# (the loops that consume the loaders unpack each batch as (images, labels)).
train_dataset = [
    prepare_sample(image_data[i], label_data[i], layer_size) for i in range(len(label_data))
]
# Stack each tuple position across all samples so TensorDataset holds batched tensors.
train_tensor_data = TensorDataset(*[torch.stack(tensors) for tensors in zip(*train_dataset)])
# Dedicated generator seeded with SEED so the shuffle order is repeatable.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,               # shuffle order is pinned by generator g
    generator=g,                # seeded generator -> reproducible shuffling
   
)

# Stays None unless the superposition evaluation branch below fills it in.
superposition_eval_ctx: dict | None = None
if evaluation_mode == "eigenmode":
    # Evaluate on exactly the samples used for training (no shuffling).
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
elif evaluation_mode == "superposition":
    # This branch is only wired up for amplitude-only prediction.
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # The helper returns a dict; only the keys read below are relied upon here.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
    superposition_eval_ctx = super_ctx
else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

# Generate detection regions using the existing helper.
# NOTE(review): evaluation_regions is only bound when pred_case == 1, yet later cells
# read it unconditionally — confirm other pred_case values are not expected here.
if pred_case ==1:
    evaluation_regions = create_evaluation_regions(layer_size, layer_size, num_detector, focus_radius, detectsize)
    print("Detection Regions:", evaluation_regions)


def shift_complex_batch(batch: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
    """
    Shift every complex field of a 4-D batch by whole pixels, filling vacated pixels with zeros.

    Args:
        batch: Complex tensor whose shape unpacks into four dimensions; the last two
            are treated as (height, width).
        shift_y: Row offset in pixels; positive values move the content downward.
        shift_x: Column offset in pixels; positive values move the content to the right.

    Returns:
        ``batch`` itself when both offsets are zero, an all-zero tensor when either
        offset reaches the field extent, otherwise a new shifted complex tensor.
    """

    def _axis_slices(offset: int, extent: int) -> tuple[slice, slice]:
        """Return the (source, destination) slices describing a shift along one axis."""
        if offset >= 0:
            return slice(0, extent - offset), slice(offset, extent)
        return slice(-offset, extent), slice(0, extent + offset)

    if (shift_y, shift_x) == (0, 0):
        return batch

    _, _, height, width = batch.shape
    if not (abs(shift_y) < height and abs(shift_x) < width):
        # The whole field has been pushed outside the canvas.
        return torch.zeros_like(batch)

    src_rows, dst_rows = _axis_slices(shift_y, height)
    src_cols, dst_cols = _axis_slices(shift_x, width)

    # Work on the (real, imag) view so the copy is a plain real-valued slice assignment.
    planes = torch.view_as_real(batch)
    moved = torch.zeros_like(planes)
    moved[:, :, dst_rows, dst_cols, :] = planes[:, :, src_rows, src_cols, :]
    return torch.view_as_complex(moved)


def compute_amp_relative_error_with_shift(
    model: torch.nn.Module,
    loader,
    *,
    shift_y_px: int,
    shift_x_px: int,
    evaluation_regions,
    pred_case: int,
    num_modes: int,
    eval_amplitudes: np.ndarray,
    eval_amplitudes_phases: np.ndarray,
    eval_phases: np.ndarray,
    phase_option: int,
    mmf_modes: torch.Tensor,
    field_size: int,
    image_test_data: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Run the model on laterally shifted inputs and score the detector read-outs.

    The model is switched to eval mode and evaluated without gradients.  Every
    batch from ``loader`` is moved to ``device`` as complex64, translated by
    (shift_y_px, shift_x_px) with ``shift_complex_batch`` and passed through the
    model.  For each sample, channel 0 of the output is averaged inside every
    ``(x0, x1, y0, y1)`` rectangle of ``evaluation_regions`` to form a weight
    vector.  That vector is L2-normalised over its first ``num_modes`` entries
    when ``pred_case == 3`` and it has at least ``num_modes`` entries, and over
    all entries otherwise; a zero-norm vector is left unchanged.

    Returns:
        The value returned by ``compute_model_prediction_metrics`` for the
        collected per-sample weight vectors and the supplied reference data.
    """
    model.eval()
    predicted_weights: list[np.ndarray] = []

    with torch.no_grad():
        for batch_fields, _ in loader:
            batch_fields = batch_fields.to(device, dtype=torch.complex64, non_blocking=True)
            outputs = model(shift_complex_batch(batch_fields, shift_y_px, shift_x_px))
            outputs_np = outputs.detach().cpu().numpy()

            for sample_output in outputs_np:
                detector_plane = sample_output[0]
                region_means = np.asarray(
                    [
                        float(detector_plane[y0:y1, x0:x1].mean())
                        for (x0, x1, y0, y1) in evaluation_regions
                    ],
                    dtype=np.float64,
                )

                # Slicing gives a view, so the in-place division below also
                # updates region_means itself.
                if pred_case == 3 and num_modes <= len(region_means):
                    to_normalise = region_means[:num_modes]
                else:
                    to_normalise = region_means

                scale = np.linalg.norm(to_normalise)
                if scale > 0:
                    to_normalise /= scale

                predicted_weights.append(region_means)

    return compute_model_prediction_metrics(
        predicted_weights,
        eval_amplitudes,
        eval_amplitudes_phases,
        eval_phases,
        phase_option,
        pred_case,
        num_modes,
        mmf_modes,
        field_size,
        image_test_data,
    )

Detection Regions: [(17, 27, 27, 37), (45, 55, 27, 37), (73, 83, 27, 37), (17, 27, 63, 73), (45, 55, 63, 73), (73, 83, 63, 73)]


In [None]:
# Train, log, checkpoint and evaluate one D2NN per entry of num_layer_option.
# Each iteration rebinds the global D2NN, so after the loop D2NN is the model
# built for the LAST layer count; later cells rely on that binding.
# Results are appended to accumulator lists (all_losses, all_training_summaries,
# all_phase_masks, model_metrics, all_* metric lists) that are created outside
# this cell.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    # padding_ratio=0.5 is repeated in the checkpoint metadata further down;
    # keep the two values in sync when changing it.
    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # NEW
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # Define loss function (the loss is the quantity being compared)
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    scheduler = ExponentialLR(optimizer, gamma=0.99)  
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # The learning rate decays once per epoch, not once per batch.
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # Calculate average loss for the epoch
        losses.append(avg_loss)  # the loss for each model
        # Wait for queued CUDA work so the wall-clock epoch time is meaningful.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Log the first epoch, every 100th epoch and the final epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # One timestamp tag is shared by every artefact written for this model.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loss-versus-epoch figure.
    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Cumulative wall-clock time figure.
    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    # Raw training curves as a MATLAB file.
    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # Propagation capture for one eigenmode: index 2 (reported as mode 3),
    # clamped to the last available mode when fewer than three exist.
    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # Per-mode triptychs are only produced in eigenmode evaluation; in any
    # other mode the record list stays empty.
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Summary record read back by the reporting cell.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )
   
    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    # Checkpoint = weights plus the constructor/geometry values used above.
    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    # The file name depends only on the layer count, so a rerun with the same
    # depth overwrites the previous checkpoint.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        # Wrap the learned phase into [0, 2*pi).
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Missing metric keys fall back to an empty array or NaN so every
    # accumulator list still gains exactly one entry per model.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [1/1000], Loss: 0.008090289309620857, Epoch Time: 0.10 seconds
Epoch [100/1000], Loss: 0.002760213334113359, Epoch Time: 0.01 seconds
Epoch [200/1000], Loss: 0.002684592502191663, Epoch Time: 0.01 seconds
Epoch [300/1000], Loss: 0.002658451907336712, Epoch Time: 0.01 seconds
Epoch [400/1000], Loss: 0.002648203633725643, Epoch Time: 0.01 seconds
Epoch [500/1000], Loss: 0.002644307911396027, Epoch Time: 0.01 seconds
Epoch [600/1000], Loss: 0.002642895793542266, Epoch Time: 0.01 seconds
Epoch [700/1000], Loss: 0.002642374718561769, Epoch Time: 0.01 seconds
Epoch [800/1000], Loss: 0.002642178907990456, Epoch Time: 0.01 seconds
Epoch [900/1000], Loss: 0.002642104867845774, Epoch Time: 0.01 seconds
Epoch [1000/1000], Loss: 0.002642077626660466, Epoch Time: 0.01 seconds
Total tra

  plt.tight_layout(rect=[0, 0, 1, 0.97])


✔ Saved eigenmode-3 propagation plot -> results/propagation_slices/propagation_mode3_layers3_20251103_181746.png
✔ Saved eigenmode-3 propagation data (.mat) -> results/propagation_slices/propagation_mode3_layers3_20251103_181746.mat
✔ Saved model -> checkpoints/odnn_3layers.pth
3 layers: modes=6, phase_opt=4, pred_case=1
  amp_err=0.085485, amp_err_rel=0.209394
  snr_full=0.590977, snr_crop=0.831166, throughput=0.710960
  cc_amp=0.963570±0.021786, cc_phase=0.793870±0.137667, cc_real=0.969919±0.025222, cc_imag=0.970586±0.030134


In [None]:
# Print a per-model recap of the artefacts recorded in all_training_summaries
# (one dict per trained model, appended by the training cell).
if all_training_summaries:
    print("\nTraining duration summary:")
    for summary in all_training_summaries:
        minutes = summary["total_time"] / 60
        print(
            f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
            f"(~{minutes:.2f} min)"
        )
        print(f"   Loss curve: {summary['loss_plot']}")
        print(f"   Time curve: {summary['time_plot']}")
        print(f"   Data (.mat): {summary['mat_path']}")
        print(f"   Propagation plot: {summary['propagation_fig']}")
        print(f"   Propagation data (.mat): {summary['propagation_mat']}")
        # The triptych list is empty unless evaluation ran in eigenmode mode.
        mode_triptychs = summary.get("mode_triptychs", [])
        if mode_triptychs:
            print("   Mode triptychs:")
            for trip in mode_triptychs:
                print(
                    f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
                )

# Amplitude comparison grid: one figure per trained model.
# idx is the position of the model in num_layer_option and is used to index the
# per-model accumulator lists, so those lists must hold exactly one entry per
# option in the same order.
# NOTE(review): this presumably breaks if the training cell is rerun without
# resetting the accumulators; confirm where they are initialised.
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
for idx, num_layer in enumerate(num_layer_option):
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=os.path.join(save_dir, f"Amp_{num_layer}layers.png"),
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# # Visual check of how the model output differs from the label
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #


Training duration summary:
 - 3 layers: 10.28 s (~0.17 min)
   Loss curve: results/training_analysis/loss_curve_layers3_20251103_181746.png
   Time curve: results/training_analysis/epoch_time_layers3_20251103_181746.png
   Data (.mat): results/training_analysis/training_curves_layers3_20251103_181746.mat
   Propagation plot: results/propagation_slices/propagation_mode3_layers3_20251103_181746.png
   Propagation data (.mat): results/propagation_slices/propagation_mode3_layers3_20251103_181746.mat
✔ Saved: /home/ydzhang/Desktop/odnn_code/results/plots/Amp_3layers.png


In [None]:
# Source of the single field used for the propagation visualisation below.
temp_dataset = test_dataset
# Sample index to visualise; it is wrapped modulo the dataset length where it is
# used, so it stays valid for datasets with fewer than five samples.
FIXED_E_INDEX = 4

def get_fixed_input(dataset, idx, device):
    """
    Fetch the input field of sample ``idx`` and place it on ``device``.

    ``dataset`` is either a plain list of ``(input, ...)`` records or an object
    exposing ``.tensors`` whose first tensor stacks the inputs.  A leading
    singleton dimension, if present, is squeezed away before the move.
    """
    field = dataset[idx][0] if isinstance(dataset, list) else dataset.tensors[0][idx]
    return field.squeeze(0).to(device)


# Pick one fixed input field, then for every cached set of phase masks render
# propagation slices and export the masks plus a light .mat bundle.
assert len(temp_dataset) > 0, "test_dataset 为空"
# The modulo keeps the index valid for datasets shorter than FIXED_E_INDEX + 1.
temp_E = get_fixed_input(temp_dataset, FIXED_E_INDEX % len(temp_dataset), device)

# Axial scan settings. z_start is only recorded in the exported metadata; it is
# not passed to visualize_model_slices.
z_start = 0.0
z_step = 5e-6
z_prop_plus = z_prop

save_root = Path("results_MD")
save_root.mkdir(parents=True, exist_ok=True)
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename_prefix = f"ODNN_vis_{run_stamp}"

# NOTE(review): D2NN here is whichever model the training loop built last,
# while phase_masks iterates over every trained model. Confirm that
# visualize_model_slices takes its phases from phase_masks and not from D2NN.
for i_model, phase_masks in enumerate(all_phase_masks, start=1):
    model_dir = save_root / f"m{i_model}"
    scans, camera_field = visualize_model_slices(
        D2NN,
        phase_masks,
        temp_E,
        output_dir=model_dir,
        sample_tag=f"m{i_model}",
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop_plus=z_prop_plus,
        z_step=z_step,
        pixel_size=pixel_size,
        wavelength=wavelength,
    )

    # Stack the per-layer masks into one float32 array along a new first axis.
    phase_stack = np.stack([np.asarray(mask, dtype=np.float32) for mask in phase_masks], axis=0)
    meta = {
        "z_start": float(z_start),
        "z_step": float(z_step),
        "z_layers": float(z_layers),
        "z_prop": float(z_prop),
        "z_prop_plus": float(z_prop_plus),
        "pixel_size": float(pixel_size),
        "wavelength": float(wavelength),
        "layer_size": int(layer_size),
        "padding_ratio": 0.5,
    }

    mat_path = model_dir / f"{filename_prefix}_LIGHT_m{i_model}.mat"
    save_to_mat_light_plus(
        mat_path,
        phase_stack=phase_stack,
        input_field=temp_E.detach().cpu().numpy(),
        scans=scans,
        camera_field=camera_field,
        sample_stacks_kmax=20,
        save_amplitude_only=False,
        meta=meta,
    )
    print("Saved ->", mat_path)

    # One file per layer mask, written next to the .mat bundle.
    save_masks_one_file_per_layer(
        phase_masks,
        out_dir=model_dir,
        base_name=f"{filename_prefix}_MASK",
        save_degree=False,
        use_xlsx=True,
    )

  plt.tight_layout()


Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_input.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer1.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer2.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer3.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_to_camera.png
Saved (v5 plus): results_MD/m1/ODNN_vis_20251103_181752_LIGHT_m1.mat
Saved -> results_MD/m1/ODNN_vis_20251103_181752_LIGHT_m1.mat


In [None]:
# Optional debug pass: generate superposition samples, run them through the
# current D2NN, save a triptych per sample and, if requested, propagation
# slices for the first sample.
# Depends on names bound by earlier cells: D2NN (the last trained model),
# evaluation_regions, all_phase_masks and z_step (set in the visualisation cell).
if pred_case == 1 and run_superposition_debug:
    super_dir = Path("results_superposition")
    super_dir.mkdir(parents=True, exist_ok=True)
    super_tag = datetime.now().strftime("%Y%m%d_%H%M%S")
    super_records: list[dict[str, str | int]] = []
    # Input of the first generated sample, kept for the slice export below.
    slice_reference_input: torch.Tensor | None = None

    for sample_idx in range(num_superposition_visual_samples):
        super_sample = generate_superposition_sample(
            num_modes=num_modes,
            field_size=field_size,
            layer_size=layer_size,
            mmf_modes=MMF_data_ts,
            mmf_label_data=MMF_Label_data,
        )
        super_output_map = infer_superposition_output(
            D2NN,
            super_sample["padded_image"],
            device,
        )

        sample_tag = f"{super_tag}_s{sample_idx:02d}"
        triptych_paths = save_superposition_triptych(
            input_field=super_sample["padded_image"][0],
            output_intensity_map=super_output_map,
            amplitudes=super_sample["amplitudes"],
            phases=super_sample["phases"],
            complex_weights=super_sample["complex_weights"],
            label_map=super_sample["padded_label"][0],
            evaluation_regions=evaluation_regions,
            detect_radius=detectsize,
            output_dir=super_dir,
            tag=sample_tag,
            save_plot=save_superposition_plots,
        )
        # The figure path is only reported when it is truthy.
        if triptych_paths["fig_path"]:
            print(
                f"Superposition sample {sample_idx + 1}/{num_superposition_visual_samples} -> "
                f"{triptych_paths['fig_path']}"
            )
        print(f"  MAT saved -> {triptych_paths['mat_path']}")

        # NOTE(review): triptych_paths is indexed unconditionally just above, so
        # the `if triptych_paths else ""` guards below can only matter for an
        # empty mapping, which would already have raised KeyError. Confirm the
        # intended contract of save_superposition_triptych.
        super_records.append(
            {
                "index": sample_idx,
                "tag": sample_tag,
                "fig": triptych_paths["fig_path"] if triptych_paths else "",
                "mat": triptych_paths["mat_path"] if triptych_paths else "",
            }
        )

        if slice_reference_input is None:
            slice_reference_input = (
                super_sample["padded_image"].squeeze(0).to(device, dtype=torch.complex64)
            )

    # Slice export needs the flag, at least one cached mask set and a reference
    # input (i.e. at least one sample was generated).
    if save_superposition_slices and all_phase_masks and slice_reference_input is not None:
        slices_root = super_dir / f"slices_{super_tag}"
        export_superposition_slices(
            D2NN,
            all_phase_masks,
            slice_reference_input,
            slices_root,
            sample_tag="superposition",
            z_input_to_first=z_input_to_first,
            z_layers=z_layers,
            z_prop=z_prop,
            z_step=z_step,
            pixel_size=pixel_size,
            wavelength=wavelength,
        )

    if super_records:
        print("\nSuperposition sample outputs:")
        for record in super_records:
            print(
                f" - Sample {record['index'] + 1:02d} ({record['tag']}): "
                f"fig={record['fig']}, mat={record['mat']}"
            )

Superposition sample 1/20 -> results_superposition/super_triptych_20251103_181758_s00.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s00.mat
Superposition sample 2/20 -> results_superposition/super_triptych_20251103_181758_s01.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s01.mat
Superposition sample 3/20 -> results_superposition/super_triptych_20251103_181758_s02.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s02.mat
Superposition sample 4/20 -> results_superposition/super_triptych_20251103_181758_s03.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s03.mat
Superposition sample 5/20 -> results_superposition/super_triptych_20251103_181758_s04.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s04.mat
Superposition sample 6/20 -> results_superposition/super_triptych_20251103_181758_s05.png
  MAT saved -> results_superposition/super_triptych_20251103_181758_s05.mat
Supe

  plt.tight_layout()


Saved figure -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1/superposition_m1_scan_input.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1/superposition_m1_scan_layer1.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1/superposition_m1_scan_layer2.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1/superposition_m1_scan_layer3.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1/superposition_m1_scan_to_camera.png
Superposition slices saved -> /home/ydzhang/Desktop/odnn_code/results_superposition/slices_20251103_181758/m1

Superposition sample outputs:
 - Sample 01 (20251103_181758_s00): fig=results_superposition/super_triptych_20251103_181758_s00.png, mat=results_superposition/super_triptych_20251103_181758_s00.mat
 - Sample 02 (20251103_181758_s01): fig=resul

In [None]:
# Misalignment robustness sweep: shift the input field on a square (dx, dy)
# grid, re-evaluate the current D2NN at every offset and save the resulting
# relative-amplitude-error surface as a 3-D plot and a .mat file.
if run_misalignment_robustness and pred_case == 1:
    robustness_dir = Path("results/robustness_analysis")
    robustness_dir.mkdir(parents=True, exist_ok=True)
    robustness_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Single source of truth for the sweep grid. The log line and the .mat
    # metadata are derived from these two values so they cannot drift apart
    # from the grid that is actually evaluated.
    sweep_range_um = 200.0
    sweep_step_um = 10.0
    # The half-step margin keeps the +range endpoint inside the grid.
    dx_um_values = np.arange(
        -sweep_range_um,
        sweep_range_um + 0.5 * sweep_step_um,
        sweep_step_um,
        dtype=np.float32,
    )
    dy_um_values = dx_um_values.copy()
    amp_err_surface = np.zeros((len(dy_um_values), len(dx_um_values)), dtype=np.float64)

    def um_to_pixels(shift_um: float) -> int:
        """Convert a shift in micrometres to the nearest whole number of pixels."""
        return int(round((shift_um * 1e-6) / pixel_size))

    # Compute the pixel shifts once with the same helper used for the sweep, so
    # the values stored in the .mat file are exactly the shifts that were applied.
    dx_px_values = np.array([um_to_pixels(float(v)) for v in dx_um_values], dtype=np.int32)
    dy_px_values = np.array([um_to_pixels(float(v)) for v in dy_um_values], dtype=np.int32)

    print(
        f"\nRunning misalignment robustness sweep "
        f"(±{sweep_range_um:g} µm, {sweep_step_um:g} µm steps)..."
    )
    for iy, dy_um in enumerate(dy_um_values):
        shift_y_px = int(dy_px_values[iy])
        for ix, dx_um in enumerate(dx_um_values):
            shift_x_px = int(dx_px_values[ix])
            metrics = compute_amp_relative_error_with_shift(
                D2NN,
                test_loader,
                shift_y_px=shift_y_px,
                shift_x_px=shift_x_px,
                evaluation_regions=evaluation_regions,
                pred_case=pred_case,
                num_modes=num_modes,
                eval_amplitudes=eval_amplitudes,
                eval_amplitudes_phases=eval_amplitudes_phases,
                eval_phases=eval_phases,
                phase_option=phase_option,
                mmf_modes=MMF_data_ts,
                field_size=field_size,
                image_test_data=image_test_data,
                device=device,
            )
            # A missing metric is stored as NaN instead of aborting the sweep.
            amp_err_surface[iy, ix] = float(metrics.get("avg_relative_amp_err", float("nan")))
        print(f"  Completed shift row {iy + 1}/{len(dy_um_values)} (Δy = {dy_um:.1f} µm)")

    DX, DY = np.meshgrid(dx_um_values, dy_um_values)
    fig = plt.figure(figsize=(9, 7))
    ax = fig.add_subplot(111, projection="3d")
    surf = ax.plot_surface(DX, DY, amp_err_surface, cmap="viridis")
    ax.set_xlabel("Δx (µm)")
    ax.set_ylabel("Δy (µm)")
    ax.set_zlabel("Relative amplitude error")
    ax.set_title("Amplitude error vs. input-mask misalignment")
    fig.colorbar(surf, shrink=0.6, aspect=12)
    fig.tight_layout()

    robustness_fig_path = robustness_dir / f"misalignment_surface_{robustness_tag}.png"
    fig.savefig(robustness_fig_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    robustness_mat_path = robustness_dir / f"misalignment_surface_{robustness_tag}.mat"
    savemat(
        str(robustness_mat_path),
        {
            "dx_um": dx_um_values.astype(np.float32),
            "dy_um": dy_um_values.astype(np.float32),
            "dx_pixels": dx_px_values,
            "dy_pixels": dy_px_values,
            "relative_amp_error": amp_err_surface.astype(np.float64),
            "pixel_size_m": np.array([pixel_size], dtype=np.float64),
            "step_um": np.array([sweep_step_um], dtype=np.float32),
            "range_um": np.array([sweep_range_um], dtype=np.float32),
        },
    )

    # Attach the artefact paths to the most recent training summary.
    if all_training_summaries:
        all_training_summaries[-1]["robustness_fig"] = str(robustness_fig_path)
        all_training_summaries[-1]["robustness_mat"] = str(robustness_mat_path)

    print(f"\n✔ Misalignment robustness surface saved -> {robustness_fig_path}")
    print(f"✔ Misalignment robustness data (.mat) -> {robustness_mat_path}")


Running misalignment robustness sweep (±200 µm, 5 µm steps)...


  c /= stddev[:, None]
  c /= stddev[None, :]


  Completed shift row 1/41 (Δy = -200.0 µm)
  Completed shift row 2/41 (Δy = -190.0 µm)
  Completed shift row 3/41 (Δy = -180.0 µm)
  Completed shift row 4/41 (Δy = -170.0 µm)
  Completed shift row 5/41 (Δy = -160.0 µm)
  Completed shift row 6/41 (Δy = -150.0 µm)
  Completed shift row 7/41 (Δy = -140.0 µm)
  Completed shift row 8/41 (Δy = -130.0 µm)
  Completed shift row 9/41 (Δy = -120.0 µm)
  Completed shift row 10/41 (Δy = -110.0 µm)
  Completed shift row 11/41 (Δy = -100.0 µm)
  Completed shift row 12/41 (Δy = -90.0 µm)
  Completed shift row 13/41 (Δy = -80.0 µm)
  Completed shift row 14/41 (Δy = -70.0 µm)
  Completed shift row 15/41 (Δy = -60.0 µm)
  Completed shift row 16/41 (Δy = -50.0 µm)
  Completed shift row 17/41 (Δy = -40.0 µm)
  Completed shift row 18/41 (Δy = -30.0 µm)
  Completed shift row 19/41 (Δy = -20.0 µm)
  Completed shift row 20/41 (Δy = -10.0 µm)
  Completed shift row 21/41 (Δy = 0.0 µm)
  Completed shift row 22/41 (Δy = 10.0 µm)
  Completed shift row 23/41 (Δy =

In [None]:
# Fine-grained misalignment robustness sweep: shift the input field on a square
# (dx, dy) grid, re-evaluate the current D2NN at every offset and save the
# resulting relative-amplitude-error surface as a 3-D plot and a .mat file.
if run_misalignment_robustness and pred_case == 1:
    robustness_dir = Path("results/robustness_analysis")
    robustness_dir.mkdir(parents=True, exist_ok=True)
    robustness_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Single source of truth for the sweep grid. The log line and the .mat
    # metadata are derived from these two values so they cannot drift apart
    # from the grid that is actually evaluated.
    sweep_range_um = 20.0
    sweep_step_um = 2.0
    # np.arange excludes its stop value; the half-step margin keeps the +range
    # endpoint so the grid is symmetric about zero (-20 ... +20 µm).
    dx_um_values = np.arange(
        -sweep_range_um,
        sweep_range_um + 0.5 * sweep_step_um,
        sweep_step_um,
        dtype=np.float32,
    )
    dy_um_values = dx_um_values.copy()
    amp_err_surface = np.zeros((len(dy_um_values), len(dx_um_values)), dtype=np.float64)

    def um_to_pixels(shift_um: float) -> int:
        """Convert a shift in micrometres to the nearest whole number of pixels."""
        return int(round((shift_um * 1e-6) / pixel_size))

    # Compute the pixel shifts once with the same helper used for the sweep, so
    # the values stored in the .mat file are exactly the shifts that were applied.
    dx_px_values = np.array([um_to_pixels(float(v)) for v in dx_um_values], dtype=np.int32)
    dy_px_values = np.array([um_to_pixels(float(v)) for v in dy_um_values], dtype=np.int32)

    print(
        f"\nRunning misalignment robustness sweep "
        f"(±{sweep_range_um:g} µm, {sweep_step_um:g} µm steps)..."
    )
    for iy, dy_um in enumerate(dy_um_values):
        shift_y_px = int(dy_px_values[iy])
        for ix, dx_um in enumerate(dx_um_values):
            shift_x_px = int(dx_px_values[ix])
            metrics = compute_amp_relative_error_with_shift(
                D2NN,
                test_loader,
                shift_y_px=shift_y_px,
                shift_x_px=shift_x_px,
                evaluation_regions=evaluation_regions,
                pred_case=pred_case,
                num_modes=num_modes,
                eval_amplitudes=eval_amplitudes,
                eval_amplitudes_phases=eval_amplitudes_phases,
                eval_phases=eval_phases,
                phase_option=phase_option,
                mmf_modes=MMF_data_ts,
                field_size=field_size,
                image_test_data=image_test_data,
                device=device,
            )
            # A missing metric is stored as NaN instead of aborting the sweep.
            amp_err_surface[iy, ix] = float(metrics.get("avg_relative_amp_err", float("nan")))
        print(f"  Completed shift row {iy + 1}/{len(dy_um_values)} (Δy = {dy_um:.1f} µm)")

    DX, DY = np.meshgrid(dx_um_values, dy_um_values)
    fig = plt.figure(figsize=(9, 7))
    ax = fig.add_subplot(111, projection="3d")
    surf = ax.plot_surface(DX, DY, amp_err_surface, cmap="viridis")
    ax.set_xlabel("Δx (µm)")
    ax.set_ylabel("Δy (µm)")
    ax.set_zlabel("Relative amplitude error")
    ax.set_title("Amplitude error vs. input-mask misalignment")
    fig.colorbar(surf, shrink=0.6, aspect=12)
    fig.tight_layout()

    robustness_fig_path = robustness_dir / f"misalignment_surface_{robustness_tag}.png"
    fig.savefig(robustness_fig_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    robustness_mat_path = robustness_dir / f"misalignment_surface_{robustness_tag}.mat"
    savemat(
        str(robustness_mat_path),
        {
            "dx_um": dx_um_values.astype(np.float32),
            "dy_um": dy_um_values.astype(np.float32),
            "dx_pixels": dx_px_values,
            "dy_pixels": dy_px_values,
            "relative_amp_error": amp_err_surface.astype(np.float64),
            "pixel_size_m": np.array([pixel_size], dtype=np.float64),
            "step_um": np.array([sweep_step_um], dtype=np.float32),
            "range_um": np.array([sweep_range_um], dtype=np.float32),
        },
    )

    # Attach the artefact paths to the most recent training summary.
    if all_training_summaries:
        all_training_summaries[-1]["robustness_fig"] = str(robustness_fig_path)
        all_training_summaries[-1]["robustness_mat"] = str(robustness_mat_path)

    print(f"\n✔ Misalignment robustness surface saved -> {robustness_fig_path}")
    print(f"✔ Misalignment robustness data (.mat) -> {robustness_mat_path}")


Running misalignment robustness sweep (±200 µm, 5 µm steps)...
  Completed shift row 1/20 (Δy = -20.0 µm)
  Completed shift row 2/20 (Δy = -18.0 µm)
  Completed shift row 3/20 (Δy = -16.0 µm)
  Completed shift row 4/20 (Δy = -14.0 µm)
  Completed shift row 5/20 (Δy = -12.0 µm)
  Completed shift row 6/20 (Δy = -10.0 µm)
  Completed shift row 7/20 (Δy = -8.0 µm)
  Completed shift row 8/20 (Δy = -6.0 µm)
  Completed shift row 9/20 (Δy = -4.0 µm)
  Completed shift row 10/20 (Δy = -2.0 µm)
  Completed shift row 11/20 (Δy = 0.0 µm)
  Completed shift row 12/20 (Δy = 2.0 µm)
  Completed shift row 13/20 (Δy = 4.0 µm)
  Completed shift row 14/20 (Δy = 6.0 µm)
  Completed shift row 15/20 (Δy = 8.0 µm)
  Completed shift row 16/20 (Δy = 10.0 µm)
  Completed shift row 17/20 (Δy = 12.0 µm)
  Completed shift row 18/20 (Δy = 14.0 µm)
  Completed shift row 19/20 (Δy = 16.0 µm)
  Completed shift row 20/20 (Δy = 18.0 µm)

✔ Misalignment robustness surface saved -> results/robustness_analysis/misalignment

已连接到 odnn_venv (Python 3.13.5)

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_triptych,
    save_mode_triptych,
    visualize_model_slices,
)

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and CUDA).
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN and other operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Select the compute device. The preferred GPU index is hard-coded for the
# training host; on a machine exposing fewer GPUs an out-of-range index would
# only fail later with an "invalid device ordinal" error, so fall back to
# cuda:0 (with a visible message) instead.
if torch.cuda.is_available():
    preferred_cuda_index = 5  # or 0
    visible_gpu_count = torch.cuda.device_count()
    if preferred_cuda_index >= visible_gpu_count:
        print(
            f'cuda:{preferred_cuda_index} is not available '
            f'({visible_gpu_count} GPU(s) visible); falling back to cuda:0'
        )
        preferred_cuda_index = 0
    device = torch.device(f'cuda:{preferred_cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# ---- Global experiment configuration (read by the cells below) ----
field_size = 25  # side length (pixels) of each mode field; the loaded modes print as (25, 25, 6). NOTE(review): the earlier comment said 50 pixels - confirm
layer_size = 100  # side length (pixels) of the ODNN layers and of the label canvas; other values noted here before: 400, 300, 100
num_data = 1000  # sample count for the random-weight phase options; overwritten with num_modes when phase_option == 4
num_modes = 6  # number of MMF modes used (3, 6 or 10)
circle_focus_radius = 5  # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27  # square window size for eigenmode patterns
# Defaults; the label-generation cell overwrites both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "superposition"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
num_superposition_visual_samples = 20
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
run_misalignment_robustness = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"
# Define multiple D2NN models
num_layer_option = [3]  # layer counts to train, one model per entry (e.g. add 3, 4)
# Per-model accumulators; the training cell appends one entry per trained model.
all_losses = []  # per-epoch loss history of each ODNN model
all_phase_masks = []  # phase masks of each ODNN model
all_predictions = []  # output light field of each ODNN model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# SLM / optical geometry
z_layers   = 40e-6        # layer spacing: 40 um (previously 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # last layer to camera (previously 16.74e-2). NOTE(review): the earlier note said "60 um plus 40" but the value is 120 um - confirm
wavelength = 1568e-9      # NOTE(review): value is 1568 nm while the earlier note said "-> 1550 nm" - confirm which is intended
z_input_to_first = 40e-6  # 40 um; newly added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex mode fields stored under the 'modes_field' key of the .mat file.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# Array layout is (H, W, M)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Take the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)
# Min-max normalise the amplitude with a single global min/max taken over all
# selected modes (not per mode).
MMF_data_amp_norm = (np.abs(MMF_data) - np.min(np.abs(MMF_data))) / (np.max(np.abs(MMF_data)) - np.min(np.abs(MMF_data)))

# Recombine the normalised amplitude with the original phase.
MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# If option 4 ends up being the only one used, consider removing the other options.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

# Options 1/2/3/5: amplitudes and phases come from generate_complex_weights.
if phase_option in [1, 2, 3, 5]:
    amplitudes,phases = generate_complex_weights(num_data,num_modes,phase_option)

# Option 4: one training sample per eigenmode (identity amplitude matrix).
# NOTE(review): any other phase_option value leaves `amplitudes`/`phases`
# undefined and the hstack below fails with a NameError.
if phase_option == 4:
    num_data = num_modes # use the eigenmodes to train ODNN
    amplitudes = np.eye(num_modes)#[[1,0,0][0,1,0][0,0,1]]
    # NOTE(review): the diagonal phase entries are 1 (radian), so each eigenmode
    # sample is multiplied by exp(1j) below - confirm this is intended rather
    # than zero phase.
    phases = np.eye(num_modes)

# Both arrays keep every amplitude column and drop the first phase column.
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
# Build the input fields from the weights and modes, cast to complex64.
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test set is the training set; the dataset cell may replace it.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
'''
pred_case = 1: only amplitudes prediction
pred_case = 2: only phases prediction
pred_case = 3: amplitudes and phases prediction
pred_case = 4: amplitudes and phases prediction (extra energy phase area)
'''
#
pred_case = 1
# One single-channel layer_size x layer_size label map per training sample,
# initialised to zero and filled in below (only for pred_case == 1).
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

if pred_case == 1: # 3
    # One detector per mode; start from the configured defaults and override
    # them per label_pattern_mode below.
    num_detector = num_modes
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use the mode amplitudes themselves as detector patterns, (H, W, M).
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        # Layout radius is half of the larger pattern side, rounded up.
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        # Pattern side is the diameter, bumped to the next odd number when even.
        pattern_size = circle_radius * 2
        if pattern_size % 2 == 0:
            pattern_size += 1
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Only the first of the three return values (the centres) is used.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One label map per detector; Index is passed 1-based.
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    # Stack the per-mode maps along a trailing mode axis and cast to float32.
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # Per-sample amplitude weights: the first num_modes columns of amplitudes_phases.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    # Label of a sample = amplitude-weighted sum of the per-mode label maps.
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Publish the detector geometry chosen above for the later cells.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

# Default test labels are the training labels.
label_test_data = label_data

相邻图案边缘间距： 行=26.67, 列=17.50
相邻图案中心间距： 行=36.67, 列=27.50
中心坐标： [(32, 22), (32, 50), (32, 78), (68, 22), (68, 50), (68, 78)]


In [None]:
# Build the training set: prepare_sample is applied to every (field, label) pair.
train_dataset = [
    prepare_sample(image_data[i], label_data[i], layer_size) for i in range(len(label_data))
]
# Stack the i-th element of every prepared sample into one tensor per position.
train_tensor_data = TensorDataset(*[torch.stack(tensors) for tensors in zip(*train_dataset)])
# Dedicated, seeded generator so the shuffle order is reproducible.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,               # the order is fixed by generator g
    generator=g,                # deterministic shuffling

)

# Stays None unless the superposition evaluation mode is selected.
superposition_eval_ctx: dict | None = None
if evaluation_mode == "eigenmode":
    # Evaluate on the training samples themselves, without shuffling.
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
elif evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # Build the superposition evaluation context from the modes and the
    # per-mode label maps, then unpack the entries used by later cells.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
    superposition_eval_ctx = super_ctx
else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

# Generate detection regions using existing function
# (only defined for pred_case == 1; later cells that read evaluation_regions rely on that).
if pred_case ==1:
    evaluation_regions = create_evaluation_regions(layer_size, layer_size, num_detector, focus_radius, detectsize)
    print("Detection Regions:", evaluation_regions)


def shift_complex_batch(batch: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
    """
    Translate complex fields by (shift_y, shift_x) pixels with zero padding.

    The last two dimensions of ``batch`` are treated as (height, width). Any
    number of leading dimensions is accepted, so the (N, C, H, W) batches used
    by the loaders as well as single (H, W) or (C, H, W) fields can be shifted.
    Positive shift_y moves downward; positive shift_x moves right.

    Args:
        batch: complex tensor whose trailing dimensions are (H, W).
        shift_y: vertical shift in pixels (may be negative).
        shift_x: horizontal shift in pixels (may be negative).

    Returns:
        A tensor of the same shape and dtype. The input object itself is
        returned when both shifts are zero; an all-zero tensor is returned
        when the shift moves the field completely out of the frame.
    """
    if shift_y == 0 and shift_x == 0:
        return batch

    height, width = batch.shape[-2:]
    if abs(shift_y) >= height or abs(shift_x) >= width:
        return torch.zeros_like(batch)

    # Operate on the (..., H, W, 2) real view and rebuild the complex tensor at the end.
    real_imag = torch.view_as_real(batch)
    shifted = torch.zeros_like(real_imag)

    if shift_y >= 0:
        src_y = slice(0, height - shift_y)
        dst_y = slice(shift_y, height)
    else:
        src_y = slice(-shift_y, height)
        dst_y = slice(0, height + shift_y)

    if shift_x >= 0:
        src_x = slice(0, width - shift_x)
        dst_x = slice(shift_x, width)
    else:
        src_x = slice(-shift_x, width)
        dst_x = slice(0, width + shift_x)

    # Ellipsis keeps every leading dimension untouched; only (H, W) is shifted.
    shifted[..., dst_y, dst_x, :] = real_imag[..., src_y, src_x, :]
    return torch.view_as_complex(shifted)


def compute_amp_relative_error_with_shift(
    model: torch.nn.Module,
    loader,
    *,
    shift_y_px: int,
    shift_x_px: int,
    evaluation_regions,
    pred_case: int,
    num_modes: int,
    eval_amplitudes: np.ndarray,
    eval_amplitudes_phases: np.ndarray,
    eval_phases: np.ndarray,
    phase_option: int,
    mmf_modes: torch.Tensor,
    field_size: int,
    image_test_data: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Run the model on inputs translated by (shift_y_px, shift_x_px) and score the result.

    Every batch from ``loader`` is shifted with ``shift_complex_batch``, pushed
    through ``model`` (in eval mode, without gradients), and reduced to one
    weight per evaluation region: the mean of channel 0 of the prediction over
    the region ``[y0:y1, x0:x1]``. The weight vector is L2-normalised (only its
    first ``num_modes`` entries when ``pred_case == 3`` and enough regions
    exist) and all vectors are handed to ``compute_model_prediction_metrics``,
    whose return value is returned unchanged.
    """
    model.eval()
    predicted_weights: list[np.ndarray] = []

    with torch.no_grad():
        for field_batch, _ in loader:
            field_batch = field_batch.to(device, dtype=torch.complex64, non_blocking=True)
            outputs = model(shift_complex_batch(field_batch, shift_y_px, shift_x_px))
            outputs_np = outputs.detach().cpu().numpy()

            # Channel 0 of each sample is the detector-plane map.
            for row in range(outputs_np.shape[0]):
                plane = outputs_np[row, 0]
                region_means = np.asarray(
                    [float(plane[y0:y1, x0:x1].mean()) for (x0, x1, y0, y1) in evaluation_regions],
                    dtype=np.float64,
                )

                # Select what gets normalised; the slice is a view, so the
                # in-place division below writes through to region_means.
                if pred_case == 3 and num_modes <= len(region_means):
                    to_normalise = region_means[:num_modes]
                else:
                    to_normalise = region_means
                scale = np.linalg.norm(to_normalise)
                if scale > 0:
                    to_normalise /= scale

                predicted_weights.append(region_means)

    return compute_model_prediction_metrics(
        predicted_weights,
        eval_amplitudes,
        eval_amplitudes_phases,
        eval_phases,
        phase_option,
        pred_case,
        num_modes,
        mmf_modes,
        field_size,
        image_test_data,
    )

Detection Regions: [(17, 27, 27, 37), (45, 55, 27, 37), (73, 83, 27, 37), (17, 27, 63, 73), (45, 55, 63, 73), (73, 83, 63, 73)]


In [None]:
# Train, log, checkpoint and evaluate one D2NN per entry of num_layer_option.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # NEW
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # Define loss function (the loss is the quantity being compared)
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99)
    scheduler = ExponentialLR(optimizer, gamma=0.99)  # learning rate is multiplied by 0.99 after every epoch
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # Calculate average loss for the epoch
        losses.append(avg_loss)  # the loss for each model
        # Wait for queued GPU work so the wall-clock epoch time is meaningful.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Report on the first epoch, every 100th epoch and the last epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    # ---- Persist training curves (PNG figures + .mat data) ----
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # Timestamp shared by every artefact written for this model.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # ---- Propagation capture for a single eigenmode ----
    propagation_dir = Path("results/propagation_slices")
    # Use the third mode (index 2), or the last available one if there are fewer.
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # ---- Per-mode triptychs (only in eigenmode evaluation mode) ----
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Record where this model's artefacts were written; the summary cell prints these.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )

    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    # Checkpoint = weights plus the settings used to construct the model.
    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,
            "field_size":        field_size,
            "num_modes":         num_modes,
            "z_input_to_first":  z_input_to_first,
        }
    }
    # The file name depends only on the layer count, so a later run with the
    # same layer count overwrites this file.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        # Wrap the learned phase into [0, 2*pi).
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Append this model's metrics to the per-model accumulators; missing keys
    # fall back to an empty array or NaN.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [1/1000], Loss: 0.008090289309620857, Epoch Time: 0.10 seconds
Epoch [100/1000], Loss: 0.002760213334113359, Epoch Time: 0.01 seconds
Epoch [200/1000], Loss: 0.002684592502191663, Epoch Time: 0.01 seconds
Epoch [300/1000], Loss: 0.002658451907336712, Epoch Time: 0.01 seconds
Epoch [400/1000], Loss: 0.002648203633725643, Epoch Time: 0.01 seconds
Epoch [500/1000], Loss: 0.002644307911396027, Epoch Time: 0.01 seconds
Epoch [600/1000], Loss: 0.002642895793542266, Epoch Time: 0.01 seconds
Epoch [700/1000], Loss: 0.002642374718561769, Epoch Time: 0.01 seconds
Epoch [800/1000], Loss: 0.002642178907990456, Epoch Time: 0.01 seconds
Epoch [900/1000], Loss: 0.002642104867845774, Epoch Time: 0.01 seconds
Epoch [1000/1000], Loss: 0.002642077626660466, Epoch Time: 0.01 seconds
Total tra

In [None]:
# Print where each trained model's artefacts were saved.
# Requires the config and training cells to have run in this kernel:
# all_training_summaries is created by the config cell and filled by the
# training loop (running this cell on a fresh kernel raises NameError).
if all_training_summaries:
    print("\nTraining duration summary:")
    for summary in all_training_summaries:
        minutes = summary["total_time"] / 60
        print(
            f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
            f"(~{minutes:.2f} min)"
        )
        print(f"   Loss curve: {summary['loss_plot']}")
        print(f"   Time curve: {summary['time_plot']}")
        print(f"   Data (.mat): {summary['mat_path']}")
        print(f"   Propagation plot: {summary['propagation_fig']}")
        print(f"   Propagation data (.mat): {summary['propagation_mat']}")
        # Triptych records are only present when evaluation_mode == "eigenmode".
        mode_triptychs = summary.get("mode_triptychs", [])
        if mode_triptychs:
            print("   Mode triptychs:")
            for trip in mode_triptychs:
                print(
                    f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
                )

# Amplitude comparison grid, one PNG per trained model.
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
# NOTE(review): idx indexes the accumulator lists by position in
# num_layer_option; this assumes the training cell ran exactly once since the
# accumulators were (re)initialised - confirm.
for idx, num_layer in enumerate(num_layer_option):
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=os.path.join(save_dir, f"Amp_{num_layer}layers.png"),
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# #直观的看看输出和label的差异
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #

NameError: name 'all_training_summaries' is not defined

In [None]:
# Dataset and sample index used for the slice visualisation below; the index
# is wrapped modulo the dataset length where it is used.
temp_dataset = test_dataset
FIXED_E_INDEX = 4

def get_fixed_input(dataset, idx, device):
    """
    Fetch the input field of sample ``idx`` and move it to ``device``.

    ``dataset`` is either a list of per-sample tuples (the field is element 0
    of each tuple) or an object exposing ``.tensors`` whose first tensor holds
    all input fields. A leading singleton dimension, if present, is squeezed
    away before the transfer.
    """
    field = dataset[idx][0] if isinstance(dataset, list) else dataset.tensors[0][idx]
    return field.squeeze(0).to(device)


# Pick one fixed input field, then for every cached set of phase masks export
# propagation slices, a .mat bundle and per-layer mask files.
assert len(temp_dataset) > 0, "test_dataset 为空"
temp_E = get_fixed_input(temp_dataset, FIXED_E_INDEX % len(temp_dataset), device)

# z_start and z_step are passed on (z_step) and recorded in the metadata below.
z_start = 0.0
z_step = 5e-6
z_prop_plus = z_prop

save_root = Path("results_MD")
save_root.mkdir(parents=True, exist_ok=True)
# One timestamp shared by all files written in this run.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename_prefix = f"ODNN_vis_{run_stamp}"

for i_model, phase_masks in enumerate(all_phase_masks, start=1):
    model_dir = save_root / f"m{i_model}"
    # NOTE(review): D2NN is whichever model the training loop built last,
    # while phase_masks iterates over every trained model - confirm that
    # visualize_model_slices copes when num_layer_option has several entries.
    scans, camera_field = visualize_model_slices(
        D2NN,
        phase_masks,
        temp_E,
        output_dir=model_dir,
        sample_tag=f"m{i_model}",
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop_plus=z_prop_plus,
        z_step=z_step,
        pixel_size=pixel_size,
        wavelength=wavelength,
    )

    # Stack the per-layer masks along a new leading axis as float32.
    phase_stack = np.stack([np.asarray(mask, dtype=np.float32) for mask in phase_masks], axis=0)
    # Geometry/settings stored alongside the data.
    meta = {
        "z_start": float(z_start),
        "z_step": float(z_step),
        "z_layers": float(z_layers),
        "z_prop": float(z_prop),
        "z_prop_plus": float(z_prop_plus),
        "pixel_size": float(pixel_size),
        "wavelength": float(wavelength),
        "layer_size": int(layer_size),
        "padding_ratio": 0.5,
    }

    mat_path = model_dir / f"{filename_prefix}_LIGHT_m{i_model}.mat"
    save_to_mat_light_plus(
        mat_path,
        phase_stack=phase_stack,
        input_field=temp_E.detach().cpu().numpy(),
        scans=scans,
        camera_field=camera_field,
        sample_stacks_kmax=20,
        save_amplitude_only=False,
        meta=meta,
    )
    print("Saved ->", mat_path)

    # Write the phase masks out, one file per layer.
    save_masks_one_file_per_layer(
        phase_masks,
        out_dir=model_dir,
        base_name=f"{filename_prefix}_MASK",
        save_degree=False,
        use_xlsx=True,
    )

: 

已重启 odnn_venv (Python 3.13.5)

In [None]:
import json
import math
import os
import random
import time
from datetime import datetime
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from scipy.io import savemat
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset

from ODNN_functions import (
    create_evaluation_regions,
    generate_complex_weights,
    generate_fields_ts,
)
from odnn_generate_label import (
    compute_label_centers,
    compose_labels_from_patterns,
    generate_detector_patterns,
)
from odnn_io import load_complex_modes_from_mat
from odnn_model import D2NNModel
from odnn_processing import prepare_sample
from odnn_training_eval import (
    build_superposition_eval_context,
    compute_model_prediction_metrics,
    evaluate_spot_metrics,
    format_metric_report,
    generate_superposition_sample,
    infer_superposition_output,
)
from odnn_training_io import save_masks_one_file_per_layer, save_to_mat_light_plus
from odnn_training_visualization import (
    capture_eigenmode_propagation,
    export_superposition_slices,
    plot_amplitude_comparison_grid,
    plot_reconstruction_vs_input,
    plot_sys_vs_label_strict,
    save_superposition_triptych,
    save_mode_triptych,
    visualize_model_slices,
)

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and CUDA).
SEED = 424242
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Force cuDNN and other operators onto their deterministic code paths.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# Select the compute device. The preferred GPU index is hard-coded for the
# training host; on a machine exposing fewer GPUs an out-of-range index would
# only fail later with an "invalid device ordinal" error, so fall back to
# cuda:0 (with a visible message) instead.
if torch.cuda.is_available():
    preferred_cuda_index = 5  # or 0
    visible_gpu_count = torch.cuda.device_count()
    if preferred_cuda_index >= visible_gpu_count:
        print(
            f'cuda:{preferred_cuda_index} is not available '
            f'({visible_gpu_count} GPU(s) visible); falling back to cuda:0'
        )
        preferred_cuda_index = 0
    device = torch.device(f'cuda:{preferred_cuda_index}')
    print('Using Device:', device)
else:
    device = torch.device('cpu')
    print('Using Device: CPU')

Using Device: cuda:5


In [None]:
# ---- Global experiment configuration (read by the cells below) ----
field_size = 25  # side length (pixels) of each mode field; the loaded modes print as (25, 25, 6). NOTE(review): the earlier comment said 50 pixels - confirm
layer_size = 100  # side length (pixels) of the ODNN layers and of the label canvas; other values noted here before: 400, 300, 100
num_data = 1000  # sample count for the random-weight phase options; overwritten with num_modes when phase_option == 4
num_modes = 6  # number of MMF modes used (3, 6 or 10)
circle_focus_radius = 5  # radius when using uniform circular detectors
circle_detectsize = 10  # square window size for circular detectors
eigenmode_focus_radius = 12.5  # radius when using eigenmode patterns
eigenmode_detectsize = 27  # square window size for eigenmode patterns
# Defaults; the label-generation cell overwrites both according to label_pattern_mode.
focus_radius = circle_focus_radius
detectsize = circle_detectsize
batch_size = 16

# Evaluation selection: "eigenmode" uses the base modes, "superposition" samples random mixtures
evaluation_mode = "superposition"  # options: "eigenmode", "superposition"
num_superposition_eval_samples = 1000
num_superposition_visual_samples = 20
run_superposition_debug = True
save_superposition_plots = True
save_superposition_slices = True
run_misalignment_robustness = True
label_pattern_mode = "circle"  # options: "eigenmode", "circle"
# Define multiple D2NN models
num_layer_option = [3]  # layer counts to train, one model per entry (e.g. add 3, 4)
# Per-model accumulators; the training cell appends one entry per trained model.
all_losses = []  # per-epoch loss history of each ODNN model
all_phase_masks = []  # phase masks of each ODNN model
all_predictions = []  # output light field of each ODNN model
model_metrics: list[dict] = []
all_amplitudes_diff: list[np.ndarray] = []
all_average_amplitudes_diff: list[float] = []
all_amplitudes_relative_diff: list[float] = []
all_complex_weights_pred: list[np.ndarray] = []
all_image_data_pred: list[np.ndarray] = []
all_cc_real: list[np.ndarray] = []
all_cc_imag: list[np.ndarray] = []
all_cc_recon_amp: list[np.ndarray] = []
all_cc_recon_phase: list[np.ndarray] = []
all_training_summaries: list[dict] = []
# SLM / optical geometry
z_layers   = 40e-6        # layer spacing: 40 um (previously 47.571e-3)
pixel_size = 1e-6
z_prop     = 120e-6        # last layer to camera (previously 16.74e-2). NOTE(review): the earlier note said "60 um plus 40" but the value is 120 um - confirm
wavelength = 1568e-9      # NOTE(review): value is 1568 nm while the earlier note said "-> 1550 nm" - confirm which is intended
z_input_to_first = 40e-6  # 40 um; newly added: propagation distance from the input plane to the first layer

In [None]:
# Load the complex mode fields stored under the 'modes_field' key of the .mat file.
eigenmodes_OM4 = load_complex_modes_from_mat(
    'mmf_6modes_25_PD_1.15.mat',
    key='modes_field'
)
# Array layout is (H, W, M)
print("Loaded modes shape:", eigenmodes_OM4.shape, "dtype:", eigenmodes_OM4.dtype)

# Take the first num_modes modes: (H, W, M_sel) -> (M_sel, H, W)
MMF_data = eigenmodes_OM4[:, :, :num_modes].transpose(2, 0, 1)
# Min-max normalise the amplitude with a single global min/max taken over all
# selected modes (not per mode).
MMF_data_amp_norm = (np.abs(MMF_data) - np.min(np.abs(MMF_data))) / (np.max(np.abs(MMF_data)) - np.min(np.abs(MMF_data)))

# Recombine the normalised amplitude with the original phase.
MMF_data = MMF_data_amp_norm * np.exp(1j * np.angle(MMF_data))

# If option 4 ends up being the only one used, consider removing the other options.
phase_option = 4
#phase_option 1: (0,0,...,0)
#phase_option 2: (0,2pi,...,2pi)
#phase_option 3: (0,pi,...,2pi)
#phase_option 4: eigenmodes
#phase_option 5: (0,pi,...,pi)

# Options 1/2/3/5: amplitudes and phases come from generate_complex_weights.
if phase_option in [1, 2, 3, 5]:
    amplitudes,phases = generate_complex_weights(num_data,num_modes,phase_option)

# Option 4: one training sample per eigenmode (identity amplitude matrix).
# NOTE(review): any other phase_option value leaves `amplitudes`/`phases`
# undefined and the hstack below fails with a NameError.
if phase_option == 4:
    num_data = num_modes # use the eigenmodes to train ODNN
    amplitudes = np.eye(num_modes)#[[1,0,0][0,1,0][0,0,1]]
    # NOTE(review): the diagonal phase entries are 1 (radian), so each eigenmode
    # sample is multiplied by exp(1j) below - confirm this is intended rather
    # than zero phase.
    phases = np.eye(num_modes)

# Both arrays keep every amplitude column and drop the first phase column.
amplitudes_phases_ori = np.hstack((amplitudes[:, :], phases[:, 1:]))  # amplitudes (l2 norm) phases
amplitudes_phases = np.hstack((amplitudes[:, :], phases[:, 1:]/(2*np.pi)))  # amplitudes (l2 norm) phases (0-1)

# Generate complex weights vector with specified amplitudes and phases
complex_weights = amplitudes * np.exp(1j * phases)
MMF_data_ts = torch.from_numpy(MMF_data)
complex_weights_ts = torch.from_numpy(complex_weights)
# Build the input fields from the weights and modes, cast to complex64.
image_data = generate_fields_ts(complex_weights_ts, MMF_data_ts, num_data, num_modes, field_size).to(torch.complex64)
# Default test set is the training set; the dataset cell may replace it.
image_test_data = image_data

Loaded modes shape: (25, 25, 6) dtype: complex64


In [None]:
'''
pred_case = 1: only amplitudes prediction
pred_case = 2: only phases prediction
pred_case = 3: amplitudes and phases prediction
pred_case = 4: amplitudes and phases prediction (extra energy phase area)
'''
# Build the target (label) intensity maps for training.
# Only pred_case == 1 is implemented in this cell; for any other value label_data
# stays all zeros and MMF_Label_data / num_detector are never defined here.
pred_case = 1
# One single-channel layer_size x layer_size label canvas per training sample.
label_data = torch.zeros([num_data,1,layer_size,layer_size])
label_size = layer_size

if pred_case == 1: # 3
    # One detector region per mode.
    num_detector = num_modes
    # Defaults; both branches below overwrite them with mode-specific values.
    detector_focus_radius = focus_radius
    detector_detectsize = detectsize
    if label_pattern_mode == "eigenmode":
        # Use each mode's amplitude profile as its detector pattern:
        # (M, H, W) -> (H, W, M).
        pattern_stack = np.transpose(np.abs(MMF_data), (1, 2, 0))
        pattern_h, pattern_w, _ = pattern_stack.shape
        if pattern_h > label_size or pattern_w > label_size:
            raise ValueError(
                f"Eigenmode pattern size ({pattern_h}x{pattern_w}) exceeds label canvas {label_size}."
            )
        # Half of the larger pattern side, rounded up, is used as the layout radius.
        layout_radius = math.ceil(max(pattern_h, pattern_w) / 2)
        detector_focus_radius = eigenmode_focus_radius
        detector_detectsize = eigenmode_detectsize
    elif label_pattern_mode == "circle":
        circle_radius = circle_focus_radius
        pattern_size = circle_radius * 2
        # Force an odd pattern size (presumably so the pattern has a centre pixel -- confirm).
        if pattern_size % 2 == 0:
            pattern_size += 1  
        pattern_stack = generate_detector_patterns(pattern_size, pattern_size, num_detector, shape="circle")
        layout_radius = circle_radius
        detector_focus_radius = circle_radius
        detector_detectsize = circle_detectsize
    else:
        raise ValueError(f"Unknown label_pattern_mode: {label_pattern_mode}")

    # Place the detector centres on the label canvas; the other two return values are unused.
    centers, _, _ = compute_label_centers(label_size, label_size, num_detector, layout_radius)
    # One label map per detector; Index is 1-based.
    mode_label_maps = [
        compose_labels_from_patterns(
            label_size,
            label_size,
            pattern_stack,
            centers,
            Index=i + 1,
            visualize=False,
        )
        for i in range(num_detector)
    ]
    # Stack the per-mode maps along a trailing axis (one slice per detector).
    MMF_Label_data = torch.from_numpy(
        np.stack(mode_label_maps, axis=2).astype(np.float32)
    )
    # The first num_modes columns of amplitudes_phases hold the mode amplitudes.
    amplitude_weights = torch.from_numpy(amplitudes_phases[:num_data, 0:num_modes]).float()
    # Per-sample label = amplitude-weighted sum of the per-mode label maps.
    combined_labels = (
        amplitude_weights[:, None, None, :] * MMF_Label_data.unsqueeze(0)
    ).sum(dim=3)
    label_data[:, 0, :, :] = combined_labels
    # Publish the detector geometry actually used so later cells see consistent values.
    focus_radius = detector_focus_radius
    detectsize = detector_detectsize

# Default evaluation labels are the training labels.
label_test_data = label_data

相邻图案边缘间距： 行=26.67, 列=17.50
相邻图案中心间距： 行=36.67, 列=27.50
中心坐标： [(32, 22), (32, 50), (32, 78), (68, 22), (68, 50), (68, 78)]


In [None]:
# Build the training set: one prepare_sample() result per (input field, label) pair.
train_dataset = [
    prepare_sample(image_data[i], label_data[i], layer_size) for i in range(len(label_data))
]
# Stack each component of the per-sample tuples into a batch tensor.
train_tensor_data = TensorDataset(*[torch.stack(tensors) for tensors in zip(*train_dataset)])
# Dedicated generator so the shuffle order is reproducible.
g = torch.Generator()
g.manual_seed(SEED)

train_loader = DataLoader(
    train_tensor_data,
    batch_size=batch_size,
    shuffle=True,               # the order is pinned by generator g
    generator=g,                # fixed shuffling
   
)

# Stays None unless the superposition evaluation branch below fills it in.
superposition_eval_ctx: dict | None = None
if evaluation_mode == "eigenmode":
    # Evaluate on exactly the samples used for training (the base modes).
    test_dataset = train_dataset
    test_tensor_data = train_tensor_data
    test_loader = DataLoader(test_tensor_data, batch_size=batch_size, shuffle=False)
    eval_amplitudes = amplitudes
    eval_amplitudes_phases = amplitudes_phases
    eval_phases = phases
    image_test_data = image_data
elif evaluation_mode == "superposition":
    if pred_case != 1:
        raise ValueError("Superposition evaluation mode currently supports pred_case == 1 only.")
    # Evaluation data for mode mixtures; the returned mapping is unpacked below.
    super_ctx = build_superposition_eval_context(
        num_superposition_eval_samples,
        num_modes=num_modes,
        field_size=field_size,
        layer_size=layer_size,
        mmf_modes=MMF_data_ts,
        mmf_label_data=MMF_Label_data,
        batch_size=batch_size,
        second_mode_half_range=True,
    )
    test_dataset = super_ctx["dataset"]
    test_tensor_data = super_ctx["tensor_dataset"]
    test_loader = super_ctx["loader"]
    image_test_data = super_ctx["image_data"]
    eval_amplitudes = super_ctx["amplitudes"]
    eval_amplitudes_phases = super_ctx["amplitudes_phases"]
    eval_phases = super_ctx["phases"]
    superposition_eval_ctx = super_ctx
else:
    raise ValueError(f"Unknown evaluation_mode: {evaluation_mode}")

# Generate detection regions using existing function
# (only defined for pred_case == 1; later cells that read evaluation_regions rely on that).
if pred_case ==1:
    evaluation_regions = create_evaluation_regions(layer_size, layer_size, num_detector, focus_radius, detectsize)
    print("Detection Regions:", evaluation_regions)


def shift_complex_batch(batch: torch.Tensor, shift_y: int, shift_x: int) -> torch.Tensor:
    """
    Translate every complex field in a 4-D batch by whole pixels, zero-filling the gap.

    Args:
        batch: complex tensor whose last two axes are (height, width).
        shift_y: vertical offset in pixels; positive values move the content downward.
        shift_x: horizontal offset in pixels; positive values move the content right.

    Returns:
        The input tensor itself when both offsets are zero; an all-zero tensor when an
        offset reaches or exceeds the corresponding extent; otherwise a new tensor
        holding the shifted content.
    """
    if (shift_y, shift_x) == (0, 0):
        return batch

    _batch_dim, _channel_dim, rows, cols = batch.shape
    if not (abs(shift_y) < rows and abs(shift_x) < cols):
        return torch.zeros_like(batch)

    def _spans(offset: int, extent: int) -> tuple[slice, slice]:
        # (source, destination) index ranges along one axis for the given offset.
        source = slice(max(0, -offset), extent - max(0, offset))
        destination = slice(max(0, offset), extent + min(0, offset))
        return source, destination

    src_rows, dst_rows = _spans(shift_y, rows)
    src_cols, dst_cols = _spans(shift_x, cols)

    # Work on the real view (trailing axis of size 2 = real/imag) and view back afterwards.
    as_real = torch.view_as_real(batch)
    canvas = torch.zeros_like(as_real)
    canvas[:, :, dst_rows, dst_cols, :] = as_real[:, :, src_rows, src_cols, :]
    return torch.view_as_complex(canvas)


def compute_amp_relative_error_with_shift(
    model: torch.nn.Module,
    loader,
    *,
    shift_y_px: int,
    shift_x_px: int,
    evaluation_regions,
    pred_case: int,
    num_modes: int,
    eval_amplitudes: np.ndarray,
    eval_amplitudes_phases: np.ndarray,
    eval_phases: np.ndarray,
    phase_option: int,
    mmf_modes: torch.Tensor,
    field_size: int,
    image_test_data: torch.Tensor,
    device: torch.device,
) -> dict:
    """
    Score the model on inputs translated by (shift_y_px, shift_x_px) pixels.

    Every batch from ``loader`` is shifted with ``shift_complex_batch`` and passed
    through ``model`` (switched to eval mode, gradients disabled). For each sample the
    mean of output channel 0 inside every ``(x0, x1, y0, y1)`` evaluation region forms
    a weight vector, which is L2-normalised: over its first ``num_modes`` entries when
    ``pred_case == 3`` and enough regions exist, over the whole vector otherwise.
    All-zero vectors are left as they are.

    Returns:
        Whatever ``compute_model_prediction_metrics`` returns for the collected
        weight vectors and the supplied reference data.
    """
    model.eval()
    predicted_weights: list[np.ndarray] = []

    with torch.no_grad():
        for batch_images, _unused_labels in loader:
            batch_images = batch_images.to(device, dtype=torch.complex64, non_blocking=True)
            outputs = model(shift_complex_batch(batch_images, shift_y_px, shift_x_px))
            outputs_np = outputs.detach().cpu().numpy()

            # Iterate over channel 0 of each sample in the batch.
            for sample_map in outputs_np[:, 0]:
                region_means = np.asarray(
                    [
                        float(sample_map[y0:y1, x0:x1].mean())
                        for (x0, x1, y0, y1) in evaluation_regions
                    ],
                    dtype=np.float64,
                )

                # `target` is either the whole vector or a view of its leading entries,
                # so the in-place division below updates `region_means` itself.
                if pred_case == 3 and num_modes <= len(region_means):
                    target = region_means[:num_modes]
                else:
                    target = region_means
                scale = np.linalg.norm(target)
                if scale > 0:
                    target /= scale

                predicted_weights.append(region_means)

    return compute_model_prediction_metrics(
        predicted_weights,
        eval_amplitudes,
        eval_amplitudes_phases,
        eval_phases,
        phase_option,
        pred_case,
        num_modes,
        mmf_modes,
        field_size,
        image_test_data,
    )

Detection Regions: [(17, 27, 27, 37), (45, 55, 27, 37), (73, 83, 27, 37), (17, 27, 63, 73), (45, 55, 63, 73), (73, 83, 63, 73)]


In [None]:
# Train one D2NN per requested layer count, export its training curves, propagation
# plots and checkpoint, then evaluate it and append the results to the all_* lists.
for num_layer in num_layer_option:
    print(f"\nTraining D2NN with {num_layer} layers...\n")

    D2NN = D2NNModel(
        num_layers=num_layer,
        layer_size=layer_size,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        device=device,
        padding_ratio=0.5,
        z_input_to_first=z_input_to_first,   # NEW
    ).to(device)

    print(D2NN)

    # Training
    criterion = nn.MSELoss()  # Define loss function (the loss is what gets compared)
    # NOTE(review): lr=1.99 is unusually large for Adam; presumably acceptable because the
    # trainable parameters are phase masks -- confirm.
    optimizer = optim.Adam(D2NN.parameters(), lr=1.99) 
    # Learning rate is multiplied by 0.99 once per epoch (scheduler.step() below).
    scheduler = ExponentialLR(optimizer, gamma=0.99)  
    epochs = 1000
    losses = []
    epoch_durations: list[float] = []
    training_start_time = time.time()

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()
        D2NN.train()
        epoch_loss = 0
        for images, labels in train_loader:
            images = images.to(device, dtype=torch.complex64, non_blocking=True)
            labels = labels.to(device, dtype=torch.float32,   non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = D2NN(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)  # Calculate average loss for the epoch
        losses.append(avg_loss)  # the loss for each model
        # Wait for queued GPU work so the measured epoch time is accurate.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        epoch_duration = time.time() - epoch_start_time
        epoch_durations.append(epoch_duration)

        # Report on the first epoch, every 100th epoch and the final epoch.
        if epoch % 100 == 0 or epoch == 1 or epoch == epochs:
            print(
                f'Epoch [{epoch}/{epochs}], Loss: {avg_loss:.18f}, '
                f'Epoch Time: {epoch_duration:.2f} seconds'
            )

    if device.type == "cuda":
        torch.cuda.synchronize(device)
    total_training_time = time.time() - training_start_time
    print(
        f'Total training time for {num_layer}-layer model: {total_training_time:.2f} seconds '
        f'(~{total_training_time / 60:.2f} minutes)'
    )
    all_losses.append(losses)  # save the loss for each model
    training_output_dir = Path("results/training_analysis")
    training_output_dir.mkdir(parents=True, exist_ok=True)
    epochs_array = np.arange(1, epochs + 1, dtype=np.int32)
    cumulative_epoch_times = np.cumsum(epoch_durations)
    # Shared tag so every artefact written for this model can be matched up later.
    timestamp_tag = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Loss curve.
    fig, ax = plt.subplots()
    ax.plot(epochs_array, losses, label="Training Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(f"D2NN Training Loss ({num_layer} layers)")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend()
    loss_plot_path = training_output_dir / f"loss_curve_layers{num_layer}_{timestamp_tag}.png"
    fig.savefig(loss_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig)

    # Cumulative wall-clock time per epoch.
    fig_time, ax_time = plt.subplots()
    ax_time.plot(epochs_array, cumulative_epoch_times, label="Cumulative Time")
    ax_time.set_xlabel("Epoch")
    ax_time.set_ylabel("Time (seconds)")
    ax_time.set_title(f"Cumulative Training Time ({num_layer} layers)")
    ax_time.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax_time.legend()
    time_plot_path = training_output_dir / f"epoch_time_layers{num_layer}_{timestamp_tag}.png"
    fig_time.savefig(time_plot_path, dpi=300, bbox_inches="tight")
    plt.close(fig_time)

    # Raw training curves as a MATLAB file.
    mat_path = training_output_dir / f"training_curves_layers{num_layer}_{timestamp_tag}.mat"
    savemat(
        str(mat_path),
        {
            "epochs": epochs_array,
            "losses": np.array(losses, dtype=np.float64),
            "epoch_durations": np.array(epoch_durations, dtype=np.float64),
            "cumulative_epoch_times": np.array(cumulative_epoch_times, dtype=np.float64),
            "total_training_time": np.array([total_training_time], dtype=np.float64),
            "num_layers": np.array([num_layer], dtype=np.int32),
        },
    )

    print(f"✔ Saved training loss plot -> {loss_plot_path}")
    print(f"✔ Saved cumulative time plot -> {time_plot_path}")
    print(f"✔ Saved training log data (.mat) -> {mat_path}")

    # Propagation record for one eigenmode: index 2, or the last mode if fewer exist.
    # NOTE(review): capture_eigenmode_propagation is not among the imports at the top of
    # the notebook; presumably it is defined in an earlier cell -- confirm.
    propagation_dir = Path("results/propagation_slices")
    eigenmode_index = min(2, MMF_data_ts.shape[0] - 1)
    propagation_summary = capture_eigenmode_propagation(
        model=D2NN,
        eigenmode_field=MMF_data_ts[eigenmode_index],
        mode_index=eigenmode_index,
        layer_size=layer_size,
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop=z_prop,
        pixel_size=pixel_size,
        wavelength=wavelength,
        output_dir=propagation_dir,
        tag=f"layers{num_layer}_{timestamp_tag}",
    )
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation plot -> {propagation_summary['fig_path']}")
    print(f"✔ Saved eigenmode-{eigenmode_index + 1} propagation data (.mat) -> {propagation_summary['mat_path']}")

    # Per-mode triptychs are produced only when evaluating on the base eigenmodes.
    # NOTE(review): save_mode_triptych is likewise not among the visible imports -- confirm
    # it is defined in an earlier cell.
    mode_triptych_records: list[dict[str, str | int]] = []
    if evaluation_mode == "eigenmode":
        triptych_dir = Path("results/mode_triptychs")
        mode_tag = f"layers{num_layer}_{timestamp_tag}"
        for mode_idx in range(min(num_modes, len(MMF_data_ts))):
            label_tensor = label_data[mode_idx, 0]
            record = save_mode_triptych(
                model=D2NN,
                mode_index=mode_idx,
                eigenmode_field=MMF_data_ts[mode_idx],
                label_field=label_tensor,
                layer_size=layer_size,
                output_dir=triptych_dir,
                tag=mode_tag,
            )
            mode_triptych_records.append(
                {
                    "mode": mode_idx + 1,
                    "fig": record["fig_path"],
                    "mat": record["mat_path"],
                }
            )
            print(
                f"✔ Saved mode {mode_idx + 1} triptych -> {record['fig_path']}\n"
                f"  MAT -> {record['mat_path']}"
            )

    # Everything the summary cell prints for this model.
    all_training_summaries.append(
        {
            "num_layers": num_layer,
            "total_time": total_training_time,
            "loss_plot": str(loss_plot_path),
            "time_plot": str(time_plot_path),
            "mat_path": str(mat_path),
            "propagation_fig": propagation_summary["fig_path"],
            "propagation_mat": propagation_summary["mat_path"],
            "mode_triptychs": mode_triptych_records,
        }
    )
   
    # === after training ===
    ckpt_dir = "checkpoints"
    os.makedirs(ckpt_dir, exist_ok=True)

    # Weights plus the construction parameters recorded alongside them.
    # The padding_ratio entry repeats the literal passed to D2NNModel above; keep them in sync.
    ckpt = {
        "state_dict": D2NN.state_dict(),
        "meta": {
            "num_layers":        len(D2NN.layers),
            "layer_size":        layer_size,
            "z_layers":          z_layers,
            "z_prop":            z_prop,
            "pixel_size":        pixel_size,
            "wavelength":        wavelength,
            "padding_ratio":     0.5,         
            "field_size":        field_size,  
            "num_modes":         num_modes, 
            "z_input_to_first":  z_input_to_first, 
        }
    }
    # The file name depends only on the layer count, so a later run with the same
    # layer count overwrites this checkpoint.
    save_path = os.path.join(ckpt_dir, f"odnn_{len(D2NN.layers)}layers.pth")
    torch.save(ckpt, save_path)
    print("✔ Saved model ->", save_path)
    # Free GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Cache phase masks for later visualization/export
    # (wrapped into [0, 2*pi) with np.remainder).
    phase_masks = []
    for layer in D2NN.layers:
        phase_np = layer.phase.detach().cpu().numpy()
        phase_masks.append(np.remainder(phase_np, 2 * np.pi))
    all_phase_masks.append(phase_masks)

    # Collect evaluation metrics for this model
    metrics = evaluate_spot_metrics(
        D2NN,
        test_loader,
        evaluation_regions,
        detect_radius=detectsize,
        device=device,
        pred_case=pred_case,
        num_modes=num_modes,
        phase_option=phase_option,
        amplitudes=eval_amplitudes,
        amplitudes_phases=eval_amplitudes_phases,
        phases=eval_phases,
        mmf_modes=MMF_data_ts,
        field_size=field_size,
        image_test_data=image_test_data,
    )

    # Missing metric keys fall back to an empty array / NaN so the lists stay aligned
    # with one entry per trained model.
    model_metrics.append(metrics)
    all_amplitudes_diff.append(metrics.get("amplitudes_diff", np.array([])))
    all_average_amplitudes_diff.append(float(metrics.get("avg_amplitudes_diff", float("nan"))))
    all_amplitudes_relative_diff.append(float(metrics.get("avg_relative_amp_err", float("nan"))))
    all_complex_weights_pred.append(metrics.get("complex_weights_pred", np.array([])))
    all_image_data_pred.append(metrics.get("image_data_pred", np.array([])))
    all_cc_recon_amp.append(metrics.get("cc_recon_amp", np.array([])))
    all_cc_recon_phase.append(metrics.get("cc_recon_phase", np.array([])))
    all_cc_real.append(metrics.get("cc_real", np.array([])))
    all_cc_imag.append(metrics.get("cc_imag", np.array([])))

    print(
        format_metric_report(
            num_modes=num_modes,
            phase_option=phase_option,
            pred_case=pred_case,
            label=f"{num_layer} layers",
            metrics=metrics,
        )
    )


Training D2NN with 3 layers...

D2NNModel(
  (pre_propagation): Propagation()
  (layers): ModuleList(
    (0-2): 3 x DiffractionLayer()
  )
  (propagation): Propagation()
  (regression): RegressionDetector()
)
Epoch [1/1000], Loss: 0.008090289309620857, Epoch Time: 0.09 seconds
Epoch [100/1000], Loss: 0.002760213334113359, Epoch Time: 0.01 seconds
Epoch [200/1000], Loss: 0.002684592502191663, Epoch Time: 0.01 seconds
Epoch [300/1000], Loss: 0.002658451907336712, Epoch Time: 0.01 seconds
Epoch [400/1000], Loss: 0.002648203633725643, Epoch Time: 0.01 seconds
Epoch [500/1000], Loss: 0.002644307911396027, Epoch Time: 0.01 seconds
Epoch [600/1000], Loss: 0.002642895793542266, Epoch Time: 0.01 seconds
Epoch [700/1000], Loss: 0.002642374718561769, Epoch Time: 0.01 seconds
Epoch [800/1000], Loss: 0.002642178907990456, Epoch Time: 0.01 seconds
Epoch [900/1000], Loss: 0.002642104867845774, Epoch Time: 0.01 seconds
Epoch [1000/1000], Loss: 0.002642077626660466, Epoch Time: 0.01 seconds
Total tra

  plt.tight_layout(rect=[0, 0, 1, 0.97])


✔ Saved eigenmode-3 propagation plot -> results/propagation_slices/propagation_mode3_layers3_20251104_142909.png
✔ Saved eigenmode-3 propagation data (.mat) -> results/propagation_slices/propagation_mode3_layers3_20251104_142909.mat
✔ Saved model -> checkpoints/odnn_3layers.pth
3 layers: modes=6, phase_opt=4, pred_case=1
  amp_err=0.085485, amp_err_rel=0.209394
  snr_full=0.590977, snr_crop=0.831166, throughput=0.710960
  cc_amp=0.963570±0.021786, cc_phase=0.793870±0.137667, cc_real=0.969919±0.025222, cc_imag=0.970586±0.030134


In [None]:
# Print where the artefacts of every trained model were written.
if all_training_summaries:
    print("\nTraining duration summary:")
    for summary in all_training_summaries:
        minutes = summary["total_time"] / 60
        print(
            f" - {summary['num_layers']} layers: {summary['total_time']:.2f} s "
            f"(~{minutes:.2f} min)"
        )
        print(f"   Loss curve: {summary['loss_plot']}")
        print(f"   Time curve: {summary['time_plot']}")
        print(f"   Data (.mat): {summary['mat_path']}")
        print(f"   Propagation plot: {summary['propagation_fig']}")
        print(f"   Propagation data (.mat): {summary['propagation_mat']}")
        # Triptych records are only filled in eigenmode evaluation mode.
        mode_triptychs = summary.get("mode_triptychs", [])
        if mode_triptychs:
            print("   Mode triptychs:")
            for trip in mode_triptychs:
                print(
                    f"     Mode {trip['mode']}: fig={trip['fig']}, mat={trip['mat']}"
                )

# Amplitude comparison grid (real vs. predicted) for each trained model.
# The per-model lists are indexed by position in num_layer_option, which assumes they
# hold exactly one entry per option (i.e. the training cell ran once since the lists
# were created).
save_dir = "results/plots"
os.makedirs(save_dir, exist_ok=True)
num_samples_to_display = 6
for idx, num_layer in enumerate(num_layer_option):
    plot_amplitude_comparison_grid(
        image_test_data,
        all_image_data_pred[idx],
        all_cc_recon_amp[idx],
        max_samples=num_samples_to_display,
        save_path=os.path.join(save_dir, f"Amp_{num_layer}layers.png"),
        title=f"Amp. distribution of Real and Predicted Images({num_layer}_layer_ODNN)",
    )

# # Visually compare the system output with the label (currently disabled)
# for s in [0, 1, 2, 5]:
#     plot_sys_vs_label_strict(
#         D2NN,
#         test_dataset,
#         sample_idx=s,
#         evaluation_regions=evaluation_regions,
#         detect_radius=detectsize,
#         save_path=f"results/plots/IO_Pred_Label_RAW_{s}.png",
#         device=device,
#         use_big_canvas=False,
#         sys_scale="bg_pct",
#         sys_pct=99.5,
#         clip_pct=99.5,
#         mask_roi_for_scale=True,
#         show_signed=True,
#     )
#     plot_reconstruction_vs_input(
#         image_test_data=image_test_data,
#         reconstructed_fields=all_image_data_pred,
#         sample_idx=s,
#         model_idx=0,
#         save_path=f"results/plots/Reconstruction_vs_Input_{s}.png",
#     )

# #


Training duration summary:
 - 3 layers: 10.74 s (~0.18 min)
   Loss curve: results/training_analysis/loss_curve_layers3_20251104_142909.png
   Time curve: results/training_analysis/epoch_time_layers3_20251104_142909.png
   Data (.mat): results/training_analysis/training_curves_layers3_20251104_142909.mat
   Propagation plot: results/propagation_slices/propagation_mode3_layers3_20251104_142909.png
   Propagation data (.mat): results/propagation_slices/propagation_mode3_layers3_20251104_142909.mat
✔ Saved: /home/ydzhang/Desktop/odnn_code/results/plots/Amp_3layers.png


In [None]:
# Dataset the fixed visualisation input is drawn from (the evaluation set).
temp_dataset = test_dataset
# Index of the sample to visualise; it is wrapped modulo the dataset length before use.
FIXED_E_INDEX = 4

def get_fixed_input(dataset, idx, device):
    """
    Fetch the input tensor of sample ``idx`` and move it to ``device``.

    A plain ``list`` is read as a sequence of per-sample tuples whose first element is
    the input; any other object is read through ``dataset.tensors[0]``. Dimension 0 of
    the selected tensor is squeezed away (when it has size 1) before the transfer.
    """
    source = dataset[idx][0] if isinstance(dataset, list) else dataset.tensors[0][idx]
    return source.squeeze(0).to(device)


# Pick the fixed input field used for every propagation scan below.
assert len(temp_dataset) > 0, "test_dataset 为空"
temp_E = get_fixed_input(temp_dataset, FIXED_E_INDEX % len(temp_dataset), device)

# Scan settings. z_start is only recorded in the exported metadata in this cell.
z_start = 0.0
z_step = 5e-6
z_prop_plus = z_prop

save_root = Path("results_MD")
save_root.mkdir(parents=True, exist_ok=True)
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename_prefix = f"ODNN_vis_{run_stamp}"

# One output folder (m1, m2, ...) per cached set of phase masks.
# NOTE(review): D2NN is whatever model the training cell built last, so with several
# entries in num_layer_option every mask set is paired with that last model -- confirm
# that visualize_model_slices tolerates this.
for i_model, phase_masks in enumerate(all_phase_masks, start=1):
    model_dir = save_root / f"m{i_model}"
    scans, camera_field = visualize_model_slices(
        D2NN,
        phase_masks,
        temp_E,
        output_dir=model_dir,
        sample_tag=f"m{i_model}",
        z_input_to_first=z_input_to_first,
        z_layers=z_layers,
        z_prop_plus=z_prop_plus,
        z_step=z_step,
        pixel_size=pixel_size,
        wavelength=wavelength,
    )

    # Stack the per-layer masks along a new leading axis, as float32.
    phase_stack = np.stack([np.asarray(mask, dtype=np.float32) for mask in phase_masks], axis=0)
    # Metadata stored next to the exported fields.
    # The padding_ratio entry is a literal; it is not read back from the model.
    meta = {
        "z_start": float(z_start),
        "z_step": float(z_step),
        "z_layers": float(z_layers),
        "z_prop": float(z_prop),
        "z_prop_plus": float(z_prop_plus),
        "pixel_size": float(pixel_size),
        "wavelength": float(wavelength),
        "layer_size": int(layer_size),
        "padding_ratio": 0.5,
    }

    # Masks, input field, scans and camera field in a single .mat file.
    mat_path = model_dir / f"{filename_prefix}_LIGHT_m{i_model}.mat"
    save_to_mat_light_plus(
        mat_path,
        phase_stack=phase_stack,
        input_field=temp_E.detach().cpu().numpy(),
        scans=scans,
        camera_field=camera_field,
        sample_stacks_kmax=20,
        save_amplitude_only=False,
        meta=meta,
    )
    print("Saved ->", mat_path)

    # Additionally export each layer's mask to its own file.
    save_masks_one_file_per_layer(
        phase_masks,
        out_dir=model_dir,
        base_name=f"{filename_prefix}_MASK",
        save_degree=False,
        use_xlsx=True,
    )

  plt.tight_layout()


Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_input.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer1.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer2.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_layer3.png
Saved figure -> /home/ydzhang/Desktop/odnn_code/results_MD/m1/m1_scan_to_camera.png
Saved (v5 plus): results_MD/m1/ODNN_vis_20251104_142915_LIGHT_m1.mat
Saved -> results_MD/m1/ODNN_vis_20251104_142915_LIGHT_m1.mat


In [None]:
# Debug pass for amplitude-only prediction: draw random mode superpositions, run them
# through the most recently trained model (D2NN) and save one triptych per sample.
# NOTE(review): save_superposition_triptych is not among the imports at the top of the
# notebook; presumably it is defined in an earlier cell -- confirm.
if pred_case == 1 and run_superposition_debug:
    super_dir = Path("results_superposition")
    super_dir.mkdir(parents=True, exist_ok=True)
    super_tag = datetime.now().strftime("%Y%m%d_%H%M%S")
    super_records: list[dict[str, str | int]] = []
    # Input of the first generated sample, reused for the optional slice export below.
    slice_reference_input: torch.Tensor | None = None

    for sample_idx in range(num_superposition_visual_samples):
        super_sample = generate_superposition_sample(
            num_modes=num_modes,
            field_size=field_size,
            layer_size=layer_size,
            mmf_modes=MMF_data_ts,
            mmf_label_data=MMF_Label_data,
        )
        super_output_map = infer_superposition_output(
            D2NN,
            super_sample["padded_image"],
            device,
        )

        sample_tag = f"{super_tag}_s{sample_idx:02d}"
        triptych_paths = save_superposition_triptych(
            input_field=super_sample["padded_image"][0],
            output_intensity_map=super_output_map,
            amplitudes=super_sample["amplitudes"],
            phases=super_sample["phases"],
            complex_weights=super_sample["complex_weights"],
            label_map=super_sample["padded_label"][0],
            evaluation_regions=evaluation_regions,
            detect_radius=detectsize,
            output_dir=super_dir,
            tag=sample_tag,
            save_plot=save_superposition_plots,
        )

        # Resolve both paths once, tolerating a missing result or a missing/empty entry,
        # so the prints and the stored record always agree and the record holds strings.
        triptych_paths = triptych_paths or {}
        super_fig_path = triptych_paths.get("fig_path") or ""
        super_mat_path = triptych_paths.get("mat_path") or ""

        if super_fig_path:
            print(
                f"Superposition sample {sample_idx + 1}/{num_superposition_visual_samples} -> "
                f"{super_fig_path}"
            )
        if super_mat_path:
            print(f"  MAT saved -> {super_mat_path}")

        super_records.append(
            {
                "index": sample_idx,
                "tag": sample_tag,
                "fig": super_fig_path,
                "mat": super_mat_path,
            }
        )

        if slice_reference_input is None:
            slice_reference_input = (
                super_sample["padded_image"].squeeze(0).to(device, dtype=torch.complex64)
            )

    # Optional propagation slices for the first sample. z_step comes from the slice
    # visualisation cell, which therefore has to run before this one.
    if save_superposition_slices and all_phase_masks and slice_reference_input is not None:
        slices_root = super_dir / f"slices_{super_tag}"
        export_superposition_slices(
            D2NN,
            all_phase_masks,
            slice_reference_input,
            slices_root,
            sample_tag="superposition",
            z_input_to_first=z_input_to_first,
            z_layers=z_layers,
            z_prop=z_prop,
            z_step=z_step,
            pixel_size=pixel_size,
            wavelength=wavelength,
        )

    if super_records:
        print("\nSuperposition sample outputs:")
        for record in super_records:
            print(
                f" - Sample {record['index'] + 1:02d} ({record['tag']}): "
                f"fig={record['fig']}, mat={record['mat']}"
            )

[Superposition] 20251104_142934_s00 label weights: [0.6808 0.1229 0.4798 0.2562 0.1763 0.441 ] (sum_sq=1.000000)
[Superposition] 20251104_142934_s00 predicted weights: [0.6607 0.1947 0.463  0.2374 0.1971 0.4648] (sum_sq=1.000000)
Superposition sample 1/20 -> results_superposition/super_triptych_20251104_142934_s00.png
  MAT saved -> results_superposition/super_triptych_20251104_142934_s00.mat
[Superposition] 20251104_142934_s01 label weights: [0.5822 0.5824 0.1089 0.1084 0.2438 0.4887] (sum_sq=1.000000)
[Superposition] 20251104_142934_s01 predicted weights: [0.5665 0.6057 0.1407 0.1843 0.2586 0.4377] (sum_sq=1.000000)
Superposition sample 2/20 -> results_superposition/super_triptych_20251104_142934_s01.png
  MAT saved -> results_superposition/super_triptych_20251104_142934_s01.mat
[Superposition] 20251104_142934_s02 label weights: [0.5402 0.2013 0.0856 0.2285 0.401  0.6688] (sum_sq=1.000000)
[Superposition] 20251104_142934_s02 predicted weights: [0.5394 0.1929 0.1665 0.2175 0.4406 0.63

  plt.tight_layout()


KeyboardInterrupt: 